("mkldnn" has been renamed to "oneDNN", but exsiting PyTorch APIs still use "mkldnn", future work will align PyTorch user level APIs to "oneDNN")
- PyTorch Channels Last memory format introduction
- oneDNN API for NHWC layout
- Generic Channels Last memory format optimization with ATen native
- oneDNN NHWC integration
NB: Memory format refers to the data representation that describes how a multidimensional (nD) array is stored in linear (1D) memory address space. Memory format has the same semantics as layout in oneDNN. Layout in PyTorch has a different semantic: describing dense or sparse with the attributes 'torch.strided', 'torch.sparse_coo'.
On CNN models, the canonical order of tensor dimensions is assigned semantic meaning. For example, the input tensor of a 2D convolution is NCHW by default in PyTorch - <batch_size, channels, height, width>. NHWC is an alternative way of describing the tensor dimensions - <batch_size, height, width, channels>.
Take a look at the following image illustrating NCHW and NHWC when N=1. Note that when N=1, NHWC has the same format as a BMP image file.
PyTorch refers to NCHW as torch.contiguous_format, which is the default memory format, and to NHWC as torch.channels_last, which is a new feature since the 1.5 release.
TF takes NHWC as the default memory format, and from a performance point of view NHWC has an advantage over NCHW. On the CPU platform, we propose to optimize the Channels Last memory path for the following reasons:
- Performance - NHWC performance is not as good as blocked memory format (nChw16c) but it is close, and much better than NCHW.
- User Experience - Operator coverage of NHWC would be higher than that of blocked memory format (the to_mkldnn() method), so the user experience is better. To be specific, it would be very difficult to enable an operator that manipulates dim on blocked format, such as sum(dim=?), so you would need to convert the tensor from blocked memory format back to NHWC with to_dense() before feeding it into sum(). But this is naturally supported on Channels Last memory format already.
- Upstream - Will be easier since CPU doesn't hold a secret ingredient, and both inference and training will be covered.
On CNN models, memory format is almost the foundation of any upper-level design. One important fact is that converting memory format can be very expensive, so in case multiple CNN operators are performed in a row, e.g. Conv2d -> ReLU -> Conv2d, it's beneficial to transform to the different memory format once, do the computation, and reorder back.
In PyTorch, you can use 3 types of memory formats on CNN models:
### NB: internally, blocked memory format will still be used,
### aka. we do 'reorder' for 'input', 'weight' and 'output',
### and believe me this is expensive, roughly 50% perf loss...
input = torch.randn(1, 10, 32, 32)
model = torch.nn.Conv2d(10, 20, 1, 1)
output = model(input)
input = torch.randn(1, 10, 32, 32)
model = torch.nn.Conv2d(10, 20, 1, 1)
### NB: convert to Channels Last memory format.
### oneDNN supports NHWC for feature maps (input, output),
### but the weight still needs to be in blocked format.
### Still we can save reorders for feature maps.
input = input.to(memory_format=torch.channels_last)
model = model.to(memory_format=torch.channels_last)
output = model(input)
from torch.utils import mkldnn as mkldnn_utils
input = torch.randn(1, 10, 32, 32)
model = torch.nn.Conv2d(10, 20, 1, 1)
### NB: convert to blocked memory format.
### Note that 'output' is in blocked memory format;
### in case the subsequent operator doesn't support blocked memory format,
### you need to manually reorder it back to NCHW with output.to_dense().
### mkldnn_utils.to_mkldnn(model) is used to prepack the weight; this saves weight reorder time
### for inference. For training, it is not needed.
input = input.to_mkldnn()
model = mkldnn_utils.to_mkldnn(model)
output = model(input)
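For instance, before feeding into an operator that doesn't understand blocked format (the sum(dim=?) case from the list of reasons above), reorder back first. A minimal sketch continuing the snippet above:

### 'output' from the snippet above is in blocked memory format;
### reorder back to a strided NCHW tensor before an unsupported op.
output = output.to_dense()
result = output.sum(dim=1)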
It is better to explain the concepts here with a diagram; the dotted lines indicate a simple memory view, no hard copy.
The conclusion is that the NHWC path saves the reorders of feature maps compared with the NCHW path, but a weight reorder is still necessary since oneDNN requires the weight to be in blocked memory format. From a performance perspective, when batch_size=N, the weight reorder is minimal compared with the feature map reorder. But when batch_size=1, the weight reorder is usually not negligible. So whether to enable weight prepacking under Channels Last memory format needs further discussion.
Before moving on, I feel it necessary to explain how PyTorch organizes tensors in memory - the layout. Here we only focus on dense tensors, skipping the 'coo' layout of sparse tensors.
The question can be reinterpreted as: for a tensor of size <N, C, H, W>, how does PyTorch access the element with index <n, c, h, w> in memory? The answer is stride:
tensor: <N, C, H, W>
index: <n, c, h, w>
strides: <CHW, HW, W, 1>
offset(n,c,h,w) = stride_n * n + stride_c * c + stride_h * h + stride_w * w
= CHW * n + HW * c + W * h + 1 * w
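A quick sanity check of the formula above with PyTorch itself (a minimal sketch; the shape is arbitrary):

import torch

N, C, H, W = 2, 3, 4, 5
x = torch.randn(N, C, H, W)
### a contiguous NCHW tensor carries strides <CHW, HW, W, 1>
assert x.stride() == (C * H * W, H * W, W, 1)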
One merit of introducing stride is the ability to express noncontiguous tensors, e.g. a slice of a big tensor. For example, the 'Xs' in the following image will have a stride of <n1+n2, 1>.
Keep in mind that a PyTorch Tensor does not have an attribute called 'memory_format' or anything like it. The memory format expression relies completely on size and stride; the design principle can be found at reference: RFC: Memory format (aka layout aka NHWC) support. So no matter what the tensor's memory format is, we need a logical canonical order for the dimensions - that is NCHW on PyTorch. Thus size and stride are ALWAYS described in the order of NCHW. Now let's take a look at the Channels Last case of the previous question:
tensor: <N, C, H, W>
index: <n, c, h, w>
strides: <HWC, 1, WC, C>
offset(n,c,h,w) = stride_n * n + stride_c * c + stride_h * h + stride_w * w
= HWC * n + 1 * c + WC * h + C * w
Actually, this pattern applies to ALL other memory formats as long as it is 4-dim, e.g. strides for CHWN would be <1, HWN, WN, N>.
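And the Channels Last counterpart (again, an arbitrary shape):

import torch

N, C, H, W = 2, 3, 4, 5
x = torch.randn(N, C, H, W).to(memory_format=torch.channels_last)
### size is still reported in the canonical NCHW order...
assert x.size() == (N, C, H, W)
### ...while the strides are now <HWC, 1, WC, C>
assert x.stride() == (H * W * C, 1, W * C, C)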
x = torch.empty(N, C, H, W, memory_format=torch.channels_last)
### .contiguous() transforms an NHWC noncontiguous tensor to NHWC contiguous.
### .to() converts an NCHW tensor to an NHWC one; it is outplace.
x = x.contiguous(memory_format=torch.channels_last)
x = x.to(memory_format=torch.channels_last)
### contiguous check
x.is_contiguous(memory_format=torch.channels_last)
### NB: tensor.to() is an outplace operation
### model.to() is inplace. It calls _apply() which is inplace.
model = model.to(memory_format=torch.channels_last)
input = input.to(memory_format=torch.channels_last)
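A small sketch to verify the inplace/outplace difference (a 3x3 Conv2d is chosen here just to get a 4-dim weight with non-trivial spatial dims):

import torch

model = torch.nn.Conv2d(10, 20, 3)
input = torch.randn(1, 10, 32, 32)

### model.to() converts the 4-dim parameters in place...
model.to(memory_format=torch.channels_last)
assert model.weight.is_contiguous(memory_format=torch.channels_last)

### ...while tensor.to() returns a new tensor and leaves the original intact.
converted = input.to(memory_format=torch.channels_last)
assert not input.is_contiguous(memory_format=torch.channels_last)
assert converted.is_contiguous(memory_format=torch.channels_last)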
Detailed operator coverage information is listed at reference Operators-with-Channels-Last-support. In brief, ImageNet training topologies on GPU already have full support for Channels Last memory format, while CPU doesn't.
Some spontaneous questions:
- How to tell whether this model or operator supports Channels Last? - This requires a manual memory format check, aka. 'torch.channels_last' input and weight shall NOT generate 'torch.contiguous_format' output; see the sketch below.
- What if the model comprises an operator that doesn't support Channels Last? - No error message will be shown; the NHWC tensor will be handled by the operator as a non-contiguous NCHW tensor, so the result might not be correct depending on the algorithm of this operator.
- No support - Requires registering a Channels Last kernel for the CPU path, e.g. Conv2d;
- Explicit support - Already has a Channels Last kernel for the CPU path (in ATen native manner); need to compare with the oneDNN counterpart's performance, e.g. BatchNorm;
- Implicit support - Supported via meta structures like 'TensorIterator'; need to compare with the oneDNN counterpart's performance, e.g. ReLU.
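A minimal sketch of the manual check mentioned above (the operator and shapes are just placeholders):

import torch

def preserves_channels_last(module, input):
    input = input.to(memory_format=torch.channels_last)
    module = module.to(memory_format=torch.channels_last)
    output = module(input)
    ### a supporting operator propagates channels_last to its output
    return output.is_contiguous(memory_format=torch.channels_last)

### e.g. check Conv2d with a dummy input
print(preserves_channels_last(torch.nn.Conv2d(10, 20, 3), torch.randn(1, 10, 32, 32)))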
The general guideline is listed under reference Writing-memory-format-aware-operators and will not be repeated here. You may take one of my recent PRs, optimize upsample performance linear mode on CPU, as an example, which also demonstrates the NHWC performance advantage over NCHW because of the ease of vectorization.
The essence of registering an oneDNN kernel under Channels Last memory format on CPU is no different from cuDNN: only very few upper-level changes are needed, such as accommodating 'contiguous()' to 'contiguous(suggested_memory_format)'. The automatic reorder of the oneDNN weight shall be hidden in ideep.
Compared to the NCHW interfaces, 2 parts need to be addressed on the NHWC interfaces:
The logical size and stride description of oneDNN is always in NCHW; this is identical to PyTorch. Example code:
/* create md from memory::format_tag */
auto src_md = memory::desc(
{N, C, H, W}, // logical dims, the order is defined by a primitive
memory::data_type::f32, // tensor's data type
memory::format_tag::nhwc // memory format, NHWC in this case
);
/* alternative: create md from strides */
auto src_md = memory::desc(
{N, C, H, W}, // logical dims, the order is defined by a primitive
memory::data_type::f32, // tensor's data type
{stride_N, stride_C, stride_H, stride_W} // the strides
);
/* create memory */
auto src_mem = memory(src_md, src_data_ptr, engine);
- NCHW - create memory::desc with format_tag::any for 'input', 'output' and 'weight'; query the proposed memory::desc from the convolution primitive;
- NHWC - create memory::desc with format_tag::nhwc for 'input' and 'output', and use format_tag::any for 'weight'; if we use hwio for 'weight', the convolution primitive will be created with gemm rather than jit avx512.
- User Experience - No special user level code change, only 'input' and 'model' conversion is required;
- Scenarios - cover both training and inference;
- Models - ResNet50 and ResNext101, extended targets: torchvision models, detectron2;
- Performance Targets - training >0.8x blocked; inference throughput > 0.8x blocked; inference latency? (need further discussion)
- Operator Coverage - No less than GPU path;
- BFloat16 - This part shall align with big picture of BFloat16 integration (need further discussion);
- int8 - Need further discussion.
- oneDNN - upgrade to 1.5 or higher;
- ideep - interface change: ideep::tensor, ideep::computation;
- ATen integration - ConvNd shall PR directly; BatchNorm, Pooling, etc. need performance comparison with native ATen Channels Last kernels; PR inference and training at the same time, one operator at a time; Training first; Inference weight prepacking under discussion;
- validation - compare oneDNN kernel-level performance between NCHW and NHWC kernels; compare oneDNN NHWC kernel performance with native ATen Channels Last kernels; TTT measurement?
- distributed - gloo backend or ccl backend? Or do we compare with only 1S on CPU?
- Conv2d/ConvTransposed2d - Review
- AdaptiveAvgPool2d - Merged
- MaxPool2d - Merged
- AvgPool2d - Review
- BatchNorm2d - Review
- AdaptiveMaxPool2d - Review
- thnn_conv2d - Review
- GroupNorm - Review
- thnn_convtranpose2d - Review
- affine_grid/grid_sampler - Spatial transformer, needed??
- ChannelShuffle - Review (both NCHW and NHWC need to be rewritten, original impl not performant)
- ConstantPad2d - TODO
- FractionalMaxPool2d - TODO
- MaxUnpool2d - Review
- NativeDilatedConvolution - TODO
- PixelShuffle - Review
- ReflectionPad2d - TODO
- ReplicationPad2d - TODO
- RowwisePrune - Pruning, needed??
- UpsampleXXX2d - the dispatching logic has a flaw, needs refinement
- weight_norm - Review
Play around with this script that mimics framework integration.
Results were obtained on Xeon Platinum 8180 with a single thread.
Takeaway: for this specific shape, NHWC performs better than NCHW in the naive path, i.e. just plain in plain out.
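A minimal sketch of such a comparison, with a single Conv2d standing in for the full script (shapes, conv config and iteration counts are arbitrary):

import time
import torch

def bench(model, input, iters=100):
    with torch.no_grad():
        ### warm up before timing
        for _ in range(10):
            model(input)
        start = time.time()
        for _ in range(iters):
            model(input)
    ### ms per iteration
    return (time.time() - start) / iters * 1000.0

model = torch.nn.Conv2d(10, 20, 3)
input = torch.randn(1, 10, 32, 32)
print('NCHW: %.3f ms' % bench(model, input))

model = model.to(memory_format=torch.channels_last)
input = input.to(memory_format=torch.channels_last)
print('NHWC: %.3f ms' % bench(model, input))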