This file serves as a BKM (best known method) for getting better performance on CPU for PyTorch, mostly focusing on inference or deployment. Chinese version available here.
Right now, on the PyTorch CPU path, you may choose among 3 memory formats:
- torch.contiguous_format: default memory format, also referred to as NCHW.
- torch.channels_last: also referred to as NHWC.
- torch._mkldnn: mkldnn blocked format.
The default NCHW format generally performs worse than the NHWC and MKLDNN blocked memory formats.
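To make the difference between the first two formats concrete, here is a minimal pure-Python sketch of the element strides for the same logical (N, C, H, W) shape in each layout. The helper names are hypothetical and only illustrate the layout arithmetic; they are not PyTorch APIs.

```python
def contiguous_strides(n, c, h, w):
    """Element strides, in (N, C, H, W) order, of a dense NCHW tensor."""
    return (c * h * w, h * w, w, 1)


def channels_last_strides(n, c, h, w):
    """Element strides, in (N, C, H, W) order, of the same tensor stored as NHWC."""
    return (h * w * c, 1, w * c, c)


shape = (1, 3, 224, 224)  # same logical shape as the resnet50 input used later
print("NCHW strides:", contiguous_strides(*shape))
print("NHWC strides:", channels_last_strides(*shape))
```

In NHWC the channel dimension has stride 1, so all channels of one pixel sit next to each other in memory, while the logical shape stays (N, C, H, W).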
### 1. default (NCHW)
output = model(input)
### 2. channels last
input = input.to(memory_format=torch.channels_last)
model = model.to(memory_format=torch.channels_last)
### Note: Most CV models have a convolution as the 1st layer, and channels last has higher priority in Conv2d.
### So you can also convert only the weight to channels last, and the input will be converted accordingly.
### The channels last memory format will then be propagated through the model (until an operator without
### channels last support, if any).
### 3a. mkldnn blocked format (inference)
input = input.to_mkldnn()
model = torch.utils.mkldnn.to_mkldnn(model)
output = model(input)
### 3b. mkldnn blocked format (training)
input = input.to_mkldnn()
output = model(input)
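As a self-contained illustration of option 2, the sketch below converts a single Conv2d layer and its input to channels last and checks that the result matches the NCHW baseline. The layer sizes and input shape are arbitrary assumptions chosen to keep the example small; this is not a benchmark.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Conv2d(3, 8, kernel_size=3, padding=1).eval()
x = torch.randn(1, 3, 32, 32)

with torch.no_grad():
    ref = model(x)  # NCHW baseline

# Same logical shape, different physical layout.
x_cl = x.to(memory_format=torch.channels_last)
model_cl = model.to(memory_format=torch.channels_last)

with torch.no_grad():
    out = model_cl(x_cl)

print("input strides:", x_cl.stride())
print("output is channels last:", out.is_contiguous(memory_format=torch.channels_last))
```

Only the strides change, so `x_cl.shape` is still `(1, 3, 32, 32)` and indexing code does not need to be rewritten.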
In case the model has operators which don't support the channels last memory format, you might not get optimal performance, since NHWC will be treated as non-contiguous NCHW and the rest of the model will propagate NCHW.
In case the model has operators which don't support the mkldnn blocked memory format, you need to insert to_dense() and to_mkldnn() in between:
import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.conv1 = nn.Conv2d(10, 10, 3)
        # MyModel has mkldnn unsupported operators X()
        # (nn.X is a placeholder for such a module, not a real PyTorch class)
        self.unsupported_mod = nn.X()
        self.linear1 = nn.Linear(10, 20)

    def forward(self, x):
        x = self.conv1(x)
        # use default layout for module without mkldnn support
        x = x.to_dense()
        x = self.unsupported_mod(x)
        # convert back to the mkldnn blocked layout for the following layers
        x = x.to_mkldnn()
        x = self.linear1(x)
        return x
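The two conversions used above can be tried in isolation. The following is a minimal sketch, assuming a PyTorch build with oneDNN (mkldnn) support; the helper name `mkldnn_round_trip` is hypothetical, and it returns None on builds without that support.

```python
import torch


def mkldnn_round_trip(x):
    """Convert x to the mkldnn blocked layout and back to the default layout.

    Returns None when this PyTorch build has no oneDNN (mkldnn) support.
    """
    if not torch.backends.mkldnn.is_available():
        return None
    blocked = x.to_mkldnn()      # opaque blocked layout, torch._mkldnn
    print("layout after to_mkldnn():", blocked.layout)
    return blocked.to_dense()    # back to the default strided layout


x = torch.randn(1, 3, 8, 8)
restored = mkldnn_round_trip(x)
if restored is not None:
    print("round trip preserves values:", torch.equal(restored, x))
```

Each conversion is a memory reorder, so it is worth keeping the number of to_dense() / to_mkldnn() pairs in a model as small as possible.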
For further explanation of the channels last memory format optimization, see PyTorch Channels Last Memory Format Performance Optimization on CPU Path.
Results on Intel(R) Xeon(R) Gold 6248 CPU @ 2.50GHz, single socket with 20 cores, are available here.
NHWC performance is collected with: torch-opt-test and torchvision-opt-test. Upstreaming to public is ongoing.
### NCHW run
Running on torch: 1.8.1+cpu
Running on torchvision: 0.9.1+cpu
ModelType: resnet50, Kernels: nn Input shape: 1x3x224x224
nn :forward: 55.89 (ms) 17.89 (imgs/s)
nn :backward: 0.00 (ms)
nn :update: 0.00 (ms)
nn :total: 55.89 (ms) 17.89 (imgs/s)
### NHWC run
Running on torch: 1.9.0a0+git850a6bd
Running on torchvision: 0.10.0a0+4f34ae5
ModelType: resnet50, Kernels: nn Input shape: 1x3x224x224
nn :forward: 14.02 (ms) 71.31 (imgs/s)
nn :backward: 0.00 (ms)
nn :update: 0.00 (ms)
nn :total: 14.02 (ms) 71.31 (imgs/s)
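The harness that produced the logs above is not shown here. As a stand-in, the following hypothetical sketch measures mean forward latency for any callable and converts it to throughput the same way the log lines relate milliseconds to imgs/s at batch size 1. The function names and the warmup/iteration counts are assumptions.

```python
import time


def time_forward(fn, warmup=3, iters=10):
    """Return the mean wall-clock latency of fn() in milliseconds."""
    for _ in range(warmup):  # warm-up runs are excluded from the timing
        fn()
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    return (time.perf_counter() - start) * 1000.0 / iters


def imgs_per_sec(latency_ms, batch_size=1):
    """Convert a per-batch latency in milliseconds to images per second."""
    return batch_size * 1000.0 / latency_ms


# Placeholder workload; replace the lambda with e.g. `lambda: model(input)`.
ms = time_forward(lambda: sum(range(10000)))
print(f"forward: {ms:.2f} (ms) {imgs_per_sec(ms):.2f} (imgs/s)")
```

For a real model, wrap the call in `torch.no_grad()` for inference and keep the thread settings described below fixed across runs, otherwise the numbers are not comparable.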
If your model needs csrc modules from torchvision, e.g. "ROIAlign", torchvision also needs to have channels last support. For example, Mask R-CNN related info has been placed here.
Results on Intel(R) Xeon(R) Gold 6248 CPU @ 2.50GHz, single socket with 20 cores:
### with config "fast_rcnn_R_50_FPN_1x.yaml"
### NCHW (torch-1.8.1/vision-0.9.1): 300 iters in 326.0195782049559 seconds.
### NCHW (torch-opt/vision-0.9.1): 300 iters in 185.4384527085349 seconds.
### NCHW (torch-opt/vision-opt): 300 iters in 80.56146793198423 seconds.
### NHWC (torch-opt/vision-opt): 300 iters in 55.49435344198719 seconds.
Upstreaming to public pytorch repo is ongoing. Further optimization is also WIP.
For a single instance run, regulate the OMP thread count and core binding as:
export OMP_NUM_THREADS=[number_of_physical_cores]
export KMP_AFFINITY=granularity=fine,compact,1,0
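The environment variables above have to be set before the process starts. When that is not convenient, the intra-op thread count can also be set from inside the script; the sketch below assumes a value of 2 purely for illustration (use the number of physical cores on a real machine). Core binding is not covered by this call.

```python
import os

import torch

# Assumption: 2 threads, for illustration only.
num_threads = 2
torch.set_num_threads(num_threads)

print("intra-op threads:", torch.get_num_threads())
print("OMP_NUM_THREADS in environment:", os.environ.get("OMP_NUM_THREADS"))
```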
For a single socket run, avoid remote memory access with numactl:
# e.g. say each socket has 20 cores, to use the 1st socket:
numactl --physcpubind=0-19 --membind=0
For a multi instance run, in case each instance will spawn its own OMP thread pool, regulate OMP_NUM_THREADS per instance. Make sure omp_threads * num_instances does not exceed the number of physical cores, so as to prevent over subscription.
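The constraint above is simple arithmetic, sketched here in plain Python. The helper names are hypothetical, and the even split of consecutive core ids is an assumption that matches the 20-core single-socket example used earlier.

```python
def plan_instances(physical_cores, num_instances):
    """Split cores evenly across instances.

    Returns (threads_per_instance, core_ranges), where each core range is a
    'first-last' string usable as a --physcpubind value.
    """
    if num_instances < 1 or num_instances > physical_cores:
        raise ValueError("need 1 <= num_instances <= physical_cores")
    threads = physical_cores // num_instances
    ranges = [
        f"{i * threads}-{(i + 1) * threads - 1}" for i in range(num_instances)
    ]
    return threads, ranges


def is_oversubscribed(omp_threads, num_instances, physical_cores):
    """True when the instances together ask for more threads than cores."""
    return omp_threads * num_instances > physical_cores


threads, ranges = plan_instances(physical_cores=20, num_instances=4)
for core_range in ranges:
    print(f"OMP_NUM_THREADS={threads} cores={core_range}")
```

When the core count is not divisible by the number of instances, this sketch leaves the remaining cores unused rather than giving some instances an extra thread.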
The multi instance case is much more complicated than the single instance case, since there are a number of upper-level threading models you may use, such as torch.multiprocessing, std::thread, TBB, etc. Be careful with over subscription, since it results in a dramatic performance drop on CPU. The easiest way to detect this issue on Intel CPUs is vtune.
PyTorch uses a dynamic graph, which has the flaw that the output of each operator must be allocated on every execution. This increases the burden of memory allocation and will trigger clear page for large buffers. This issue can be alleviated to some extent with jemalloc or tcmalloc.
### jemalloc
export MALLOC_CONF="oversize_threshold:1,background_thread:true,metadata_thp:auto,dirty_decay_ms:-1,muzzy_decay_ms:-1"
export LD_PRELOAD=/home/mingfeim/packages/jemalloc-5.2.1/lib/libjemalloc.so
### tcmalloc
export LD_PRELOAD=/home/mingfeim/packages/gperftools-2.8/install/lib/libtcmalloc.so
If vtune shows that clear_page from vmlinux.so is consuming a lot of time, it is time to apply jemalloc.
torch.utils.data.DataLoader may be slower in case num_workers > 0; try to compare with num_workers = 0.
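A quick way to run that comparison is sketched below, using a small in-memory dataset. The dataset size, batch size and worker counts are arbitrary assumptions; with a real dataset that does heavy decoding or augmentation the outcome may be different.

```python
import time

import torch
from torch.utils.data import DataLoader, TensorDataset


def time_loader(num_workers, batches=20, batch_size=32):
    """Return the seconds needed to iterate once over a synthetic dataset."""
    data = TensorDataset(torch.randn(batches * batch_size, 3, 32, 32))
    loader = DataLoader(data, batch_size=batch_size, num_workers=num_workers)
    start = time.perf_counter()
    seen = 0
    for (batch,) in loader:
        seen += batch.shape[0]
    assert seen == batches * batch_size  # every sample was visited once
    return time.perf_counter() - start


if __name__ == "__main__":
    # The guard matters when worker processes are started with "spawn".
    for workers in (0, 2):
        print(f"num_workers={workers}: {time_loader(workers):.3f} s")
```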
@adelelwan24 Hi, you can use torch.backends.mkldnn.verbose(2) or directly set an environment variable in your shell, export DNNL_VERBOSE=2, to get the oneDNN verbose log (mkldnn is the old name for oneDNN).
From your log it tells:
- avx10.1/512 with AMX extension; the data type is bf16
- brgconv indicates the algorithm used; b-r is batch reduced
- the ABCD stuff indicates the memory layouts: ABCD is NCHW and ACDB is NHWC
By default, oneDNN will use the latest ISA available on your machine to generate the runtime kernels with JIT. The order is documented in the link you provided: https://oneapi-src.github.io/oneDNN/dev_guide_cpu_dispatcher_control.html
If you intend to use an older ISA on a new machine, for example, run avx2 on a CPU with avx512, you can use export DNNL_MAX_CPU_ISA=AVX2. This is mostly used for debugging purposes, for example, if a primitive gets different results on different ISAs, then it is a oneDNN internal error.
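For reference, a minimal sketch of the in-Python verbose switch mentioned above, wrapped around a single convolution. It assumes a PyTorch version that provides torch.backends.mkldnn.verbose as a context manager and a build with oneDNN support; the layer and input sizes are arbitrary.

```python
import torch
import torch.nn as nn

conv = nn.Conv2d(3, 8, kernel_size=3).eval()
x = torch.randn(1, 3, 32, 32)

with torch.no_grad():
    if torch.backends.mkldnn.is_available():
        # oneDNN prints its verbose lines to stdout while the scope is active.
        with torch.backends.mkldnn.verbose(torch.backends.mkldnn.VERBOSE_ON):
            out = conv(x)
    else:
        out = conv(x)  # no oneDNN in this build, so nothing to log

print(out.shape)
```

Whether any verbose lines appear depends on whether the operator is actually dispatched to oneDNN on your machine. DNNL_MAX_CPU_ISA, in contrast, is an environment variable that should be set in the shell before the process starts.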