The purpose is to further improve PyTorch CPU performance on both the imperative path and the JIT path.
MKLDNN requires reordering memory from plain layout to blocked layout to achieve optimal performance on CPU, e.g. from `nchw` to `nChw16c`, etc. At the moment, MKLDNN operators in PyTorch reuse the plain CPU tensor, which means each MKLDNN operator takes three steps to finish its computation:
input_reorder(plain_layout, blocked_layout)
mkldnn_computation()
output_reorder(blocked_layout, plain_layout)
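To make the reorder concrete: `nChw16c` is a blocked layout in which the channel dimension is split into blocks of 16 and the block becomes the innermost dimension. The sketch below only mimics that data arrangement with plain tensor ops (assuming C is divisible by 16); the real reorder is done internally by MKLDNN.

import torch

x = torch.randn(1, 64, 56, 56)   # plain nchw: (N, C, H, W)
n, c, h, w = x.shape
# nChw16c arranges the data as (N, C/16, H, W, 16)
blocked = x.reshape(n, c // 16, 16, h, w).permute(0, 1, 3, 4, 2).contiguous()
print(blocked.shape)             # torch.Size([1, 4, 56, 56, 16])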
These reorders take about 50% of the time on a typical ImageNet topology, e.g. ResNet50. Also, MKLDNN chooses a different blocked format depending on the input configuration of `Convolution`, while `nn.Conv2d` always outputs in plain layout, so subsequent layers (`BatchNorm`, `Pooling`) can only execute on plain layout, which is the slow path for MKLDNN. With these problems solved, CNN models would see a 3~4x speedup vs. current performance.
The general idea to solve these problems is to expose the MKLDNN blocked layout among MKLDNN operators. That is, within the MKLDNN domain, tensors can be passed around in blocked layout; but wherever a tensor is used outside the MKLDNN domain, it should be reordered back to plain layout.
The final picture might look like this: given a model, the user can transfer its weights to `mkldnn` via `model.mkldnn()`, similar to `model.cuda()`, which allows weight tensors to stay in blocked layout and avoids unnecessary reorders. The transition can be both automatic and explicit, i.e. it is transparent to the user, but the user can also choose to make the transition manually.
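A usage sketch of that final picture (`model.mkldnn()` is the API proposed here, not something that exists yet; the rest is standard PyTorch):

import torch
import torchvision

model = torchvision.models.resnet50().eval()
model.mkldnn()                    # proposed: convert weights to blocked layout, analogous to model.cuda()

x = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    y = model(x)                  # weights stay in blocked layout, no per-iteration weight reorders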
To achieve that, we can roughly take 3 steps:
This means a separate `TensorTypeId` for `mkldnn`, e.g. `MkldnnCPUTensorId`. It can be implemented in two ways:
- Opaque data handle under `TensorImpl`: ongoing effort can be found here; an opaque data handle is registered under `TensorImpl`. This opaque handle would be `nullptr` for a `CPUTensorId` tensor and hold `mkldnn::memory` for `MkldnnCPUTensorId`.
- Deriving `TensorImpl`: e.g. `MKLDNNTensorImpl`, whose underlying storage can be `mkldnn::memory` plus a memory descriptor.
Aside from the implementation details, it requires only two things here:
- Python interface: an `mkldnn` tensor and a `cpu` tensor can be converted from one to the other:
mkldnn_tensor = cpu_tensor.to_mkldnn()
cpu_tensor = mkldnn_tensor.to_dense() # treat mkldnn as a layout for CPU
- C++ interface: an `mkldnn` tensor can be reordered into blocked layout:
// itensor is an ideep::tensor which wraps mkldnn::memory
// atensor is an at::Tensor
auto itensor = get_itensor_mkldnn(atensor);
itensor.reorder_to(blocked_layout_descriptor);
This means exposing blocked layout for weights so as to remove redundant weight reorders. The general idea is to keep a list of MKLDNN-supported operators, e.g. `mkldnn_supported_ops`, traverse the model recursively, and convert the weights of supported operators to `mkldnn`.
- Python interface: add an `mkldnn()` method under `Module.py`, similar to `cuda()`:
# prototype pseudo-code
# torch/backends/mkldnn/__init__.py
mkldnn_supported_ops = {nn.Conv2d, nn.BatchNorm2d, nn.ReLU}

# torch/nn/modules/module.py
def mkldnn(self):
    def convert_weight_to_mkldnn(m):
        if type(m) in mkldnn_supported_ops:
            if m.weight is not None:
                m.weight.data = m.weight.data.to_mkldnn()
    return self.apply(convert_weight_to_mkldnn)
- Inference: the weight will be reordered only once on the first run, and the blocked layout will be reused for all subsequent runs.
- Training: the weight will stay in blocked layout for the whole training process, which means the optimizer needs to be able to ingest `mkldnn` tensors directly. We can register `mkldnn_sgd_step` and `mkldnn_adam_step` in `native_functions.yaml` to fulfill this (a sketch of the intended semantics follows this list).
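A rough sketch of the semantics such a fused step would implement (Python pseudo-code for illustration only; the real kernels would be registered in `native_functions.yaml` and operate on `mkldnn::memory` directly, and the argument names here are hypothetical):

def mkldnn_sgd_step(weight, grad, momentum_buf, lr, momentum, weight_decay):
    # all tensors are assumed to be mkldnn tensors in blocked layout,
    # so no reorder to plain layout happens inside the optimizer
    if weight_decay != 0:
        grad = grad + weight_decay * weight
    if momentum != 0:
        momentum_buf.mul_(momentum).add_(grad)
        grad = momentum_buf
    weight.add_(grad, alpha=-lr)
    return weight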
Also, with the weight cache, mkldnn-rnn won't need to cache the transposed weights at the `Module` level, which is a much cleaner approach.
And before saving model weights to disk, the model needs to be reordered back to plain layout.
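A minimal sketch of that save path, assuming the `to_dense()` / `is_mkldnn()` interface proposed above:

# reorder weights back to plain layout before serialization
for p in model.parameters():
    if p.data.is_mkldnn():
        p.data = p.data.to_dense()
torch.save(model.state_dict(), "model.pth")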
This means allowing adjacent `mkldnn`-supported operators to pass blocked layout between them, so as to remove redundant reorders between outputs and inputs. The transition is transparent to end users, while the user always keeps full control over which layers are computed by `mkldnn`.
- Python interface:
# prototype pseudo-code:
# provide two transition operators
def Dense2Mkldnn(input):
    return input if input.is_mkldnn() else input.to_mkldnn()

def Mkldnn2Dense(input):
    return input if not input.is_mkldnn() else input.to_dense()
User can manually enable/disable `mkldnn` in the model definition for debugging purposes:
Dense2Mkldnn() -> Conv2d() -> BatchNorm2d() -> Mkldnn2Dense()
The behavior of the underlying implementations needs to be changed as well: an operator should output an `mkldnn` tensor in case its input is `mkldnn`.
# for example:
mkldnn_operator(dense_input) -> (dense_output)
mkldnn_operator(mkldnn_input) -> (mkldnn_output)
Also, we can provide a method to transform the model both automatically and explicitly; this process is also transparent to users.
For MKLDNN-supported operators, OP -> Dense2Mkldnn() + OP; for MKLDNN non-supported operators, OP -> Mkldnn2Dense() + OP.
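A hedged sketch of such an automatic pass, assuming the `to_mkldnn()`/`to_dense()`/`is_mkldnn()` interface proposed above; the module wrappers and the `auto_mkldnn` name are only illustrative:

import torch.nn as nn

class Dense2Mkldnn(nn.Module):
    def forward(self, input):
        return input if input.is_mkldnn() else input.to_mkldnn()

class Mkldnn2Dense(nn.Module):
    def forward(self, input):
        return input if not input.is_mkldnn() else input.to_dense()

mkldnn_supported_ops = {nn.Conv2d, nn.BatchNorm2d, nn.ReLU, nn.MaxPool2d}

def auto_mkldnn(module):
    # recursively prepend a layout transition to every leaf module:
    # supported ops get Dense2Mkldnn(), unsupported ops get Mkldnn2Dense()
    for name, child in module.named_children():
        if list(child.children()):
            auto_mkldnn(child)
        elif type(child) in mkldnn_supported_ops:
            setattr(module, name, nn.Sequential(Dense2Mkldnn(), child))
        else:
            setattr(module, name, nn.Sequential(Mkldnn2Dense(), child))
    return module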
We can also add a shortcut for containers (`nn.Sequential`); for example, we can transform alexnet.features into:
self.features = nn.Sequential(
    nn.Dense2Mkldnn(),
    nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
    nn.ReLU(inplace=True),
    nn.MaxPool2d(kernel_size=3, stride=2),
    nn.Conv2d(64, 192, kernel_size=5, padding=2),
    nn.ReLU(inplace=True),
    nn.MaxPool2d(kernel_size=3, stride=2),
    nn.Conv2d(192, 384, kernel_size=3, padding=1),
    nn.ReLU(inplace=True),
    nn.Conv2d(384, 256, kernel_size=3, padding=1),
    nn.ReLU(inplace=True),
    nn.Conv2d(256, 256, kernel_size=3, padding=1),
    nn.ReLU(inplace=True),
    nn.MaxPool2d(kernel_size=3, stride=2),
    nn.Mkldnn2Dense(),
)
In this way, supposing `mkldnn` supports all the operators above, the blocked format will be used through the whole list, and reorders between outputs and inputs are no longer needed.