MKL-DNN requires a specific weight format to perform convolution efficiently. By freezing the weights inside a TorchScript graph, one can embed more information about the weight tensor into the graph and use constant propagation to move its content closer to its use, possibly avoiding the weight-format transformation at runtime altogether.
For example, we insert ops before aten::conv2d that reorder the weight into the format MKL-DNN prefers.
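As a point of reference, the same kind of reorder can be done eagerly with Tensor.to_mkldnn(), which converts a dense float tensor into the opaque MKL-DNN layout. This is only an illustration of the reorder itself, not part of the pass, and it assumes a PyTorch build with MKL-DNN support:

import torch

# An ordinary conv weight in the default strided layout.
weight = torch.randn(16, 3, 3, 3, dtype=torch.float32)

# Reorder into the opaque MKL-DNN layout; this is the kind of
# transformation the inserted reorder op performs inside the graph.
weight_mkldnn = weight.to_mkldnn()

print(weight_mkldnn.is_mkldnn)  # True
print(weight_mkldnn.layout)     # torch._mkldnn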
Suppose we start with this IR:
%30 : Float(*, *, *, *) = prim::GetAttr[name="weight"]
%289 : Float(*, *, *, *) = aten::conv2d(%x.1, %30, %4, %611, %612, %613, %23)
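Such an IR can be reproduced, for instance, by scripting a module that calls conv2d on its own weight parameter and printing the graph (a minimal sketch; the value numbers will differ from run to run):

import torch
import torch.nn as nn
import torch.nn.functional as F

class Conv(nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(16, 3, 3, 3))
        self.bias = nn.Parameter(torch.zeros(16))

    def forward(self, x):
        # Scripting this call produces an aten::conv2d node whose weight
        # argument comes from prim::GetAttr[name="weight"].
        return F.conv2d(x, self.weight, self.bias)

scripted = torch.jit.script(Conv())
print(scripted.graph)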
We optimize it slightly by inserting a reorder:
%30 : Float(*, *, *, *) = prim::GetAttr[name="weight"]
%30.weight : Float(*, *, *, *) = some::reorder(%30)
%289 : Float(*, *, *, *) = aten::conv2d(%x.1, %30.weight, %4, %611, %612, %613, %23)
However, the runtime overhead of the transformation is still there every time the graph is evaluated. We would like the graph to become:
%672.weight : Float(16, 3, 3, 3) = prim::Constant[value=<Tensor>]()
%30.weight : Float(*, *, *, *) = some::reorder(%672.weight)
%289 : Float(*, *, *, *) = aten::conv2d(%x.1, %30.weight, %4, %611, %612, %613, %23)
After constant propagation (and DCE) the graph looks like:
%30.weight : Tensor = prim::Constant[value=<Tensor>]()
%289 : Float(*, *, *, *) = aten::conv2d(%x.1, %30.weight, %4, %611, %612, %613, %23)
Here the constant Tensor is an opaque MKL-DNN tensor.
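For reference, constant propagation and DCE can be invoked directly on a graph through the internal pass bindings. The helper name below is hypothetical, and the torch._C entry points are internal APIs that may change between PyTorch versions:

import torch

def fold_reorders(graph):
    # Propagate constants (folding a reorder applied to a constant weight
    # into a single Constant node) and then drop the now-dead nodes.
    torch._C._jit_pass_constant_propagation(graph)
    torch._C._jit_pass_dce(graph)
    return graph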
Unfortunately we cannot freeze parameters in a standard registered pass, because inside the pass we have no reference to the owning Module. The provided code therefore exposes a function called _jit_pass_freeze_params, which changes prim::GetAttr nodes into Constants by poking inside a Module and grabbing the data out. You can freeze parameters like this:
ConvBnRelu = ScriptedCascadedConv2dBnRelu(3, 16, 32, kernel_size=3, stride=1)
freezer._jit_pass_freeze_params(ConvBnRelu._c, 'forward', 'weight')
freezer._jit_pass_freeze_params(ConvBnRelu._c, 'forward', 'bias')
freezer._jit_pass_freeze_params(ConvBnRelu._c, 'forward', 'running_mean')
freezer._jit_pass_freeze_params(ConvBnRelu._c, 'forward', 'running_var')
After running this code, ConvBnRelu.graph treats the internal weight/bias/running_mean/running_var as Constant nodes instead of prim::GetAttr nodes, which lets the constant propagation pass perform the optimization described above.
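As a quick sanity check (a sketch assuming the ConvBnRelu script module from above), one can list the prim::GetAttr nodes that remain in the graph after freezing; the frozen parameters should no longer appear:

# Attributes still fetched at runtime after freezing.
remaining = ConvBnRelu.graph.findAllNodes('prim::GetAttr')
print([n.s('name') for n in remaining])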