Ma Mingfei (mingfeima), Intel Asia-Pacific R&D

PyTorch Performance Optimization on CPU

  1. pytorch mkldnn integration prototype design
     • mkldnn conv integration
     • conv3d parallelization: vol2col, col2vol
     • LSTM optimization non-fused: tanh/sigmoid parallelization
  2. Create MKLDNN conda channel
  3. MKLDNN tensor type
     • create lib/THMKL?
@mingfeima
mingfeima / mkldnn_integration_plan.md
Last active May 15, 2020 20:03
mkldnn integration plan, RFC draft

MKL-DNN Integration Plan

The purpose is to further improve PyTorch CPU performance on both the imperative path and the jit path. MKLDNN requires memory to be reordered from plain layout to blocked layout to achieve optimal performance on CPU, e.g. from nchw to nChw16c. At the moment, MKLDNN operators in PyTorch reuse the CPU tensor, which means each MKLDNN operator takes three steps to finish the computation:

input_reorder(plain_layout, blocked_layout)
mkldnn_computation()
output_reorder(blocked_layout, plain_layout)

These reorders take about 50% of the time on a typical ImageNet topology, e.g. ResNet50. Also, MKLDNN chooses a different blocked format depending on the input configuration of the convolution; since nn.Conv2d always outputs in plain layout, subsequent layers (BatchNorm, Pooling) can only execute on plain layout, and this is the slow path for MKLDNN. With these problems solved, CNN models would see a 3~4x speedup over current performance.
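For illustration, the plain-to-blocked reorder can be sketched with NumPy (a hedged stand-in; MKLDNN's actual reorder is an optimized native primitive). `nChw16c` packs channels into groups of 16 as the innermost dimension, so 16 consecutive floats of the same spatial position sit contiguously for vector loads:

```python
import numpy as np

def to_nChw16c(x):
    # pack channels into blocks of 16 so the innermost dim is
    # contiguous for 16-wide (e.g. AVX-512 float32) vector loads
    n, c, h, w = x.shape
    assert c % 16 == 0, "channels must be divisible by the block size"
    return x.reshape(n, c // 16, 16, h, w).transpose(0, 1, 3, 4, 2).copy()

def to_nchw(xb):
    # inverse reorder, back to plain layout
    n, cb, h, w, blk = xb.shape
    return xb.transpose(0, 1, 4, 2, 3).reshape(n, cb * blk, h, w)

x = np.arange(2 * 32 * 4 * 4, dtype=np.float32).reshape(2, 32, 4, 4)
assert np.array_equal(to_nchw(to_nchw16c := to_nChw16c(x)), x)  # lossless round trip
```

Doing this pair of copies around every single operator is exactly the overhead the integration plan wants to eliminate by keeping tensors in blocked layout across operators.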

@mingfeima
mingfeima / [BKM] VTune.md
Last active May 22, 2019 01:42
vtune tips

Hotspot analysis:

/opt/intel/vtune_amplifier/bin64/amplxe-cl -collect hotspots -knob analyze-openmp=true -knob sampling-interval=10 --resume-after 5 -d 20 \
  -- /home/mingfeim/pytorch/unit_tests/run.sh
/opt/intel/vtune_amplifier/bin64/amplxe-cl -archive -r $1

Interpret VTune log function names, e.g.

@mingfeima
mingfeima / topk.md
Last active July 2, 2019 02:43
topk_optimization_backups

Backups for PR19736, a topk() performance optimization on CPU.


description

Suppose the input tensor has shape [N, C]; measure the performance of input.topk(K, sorted=Sorted) for the following scenarios:

  1. C = 10000, 40000, 320000
  2. K = 10, 50, 100, C/10, C/2, C-5
  3. Test with 20 threads and 1 thread
  4. Test with Sorted=True and Sorted=False
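The row-wise semantics being benchmarked can be sketched with a minimal pure-Python stand-in (hedged: the PR benchmarks torch.topk itself; heapq.nlargest is used here only to mimic topk along the last dim of an [N, C] input):

```python
import heapq
import random

def topk_rows(x, k):
    # mimic input.topk(k, sorted=True) along dim=-1 of an [N, C] table:
    # the k largest values per row, in descending order
    return [heapq.nlargest(k, row) for row in x]

random.seed(0)
N, C, K = 4, 10000, 10
x = [[random.random() for _ in range(C)] for _ in range(N)]
top = topk_rows(x, K)
assert all(len(row) == K for row in top)
assert top[0][0] == max(x[0])  # largest element leads each sorted row
```

The K = C/10, C/2 and C-5 cases in the list above stress the selection strategy: for large K relative to C, a full sort can beat a heap-based partial selection, which is part of what the PR tunes.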
@mingfeima
mingfeima / bert_optimization.md
Last active July 8, 2022 06:13
BERT Optimization

benchmark

Performance evaluation is based on the huggingface repo; the actual benchmark run script is placed in the repo. How to reproduce the performance numbers:

  1. prepare the dataset according to link.
  2. update GLUE_DIR to the actual dataset path in run_inference.sh.
  3. change the env settings; the default setting uses 20 cores.

MKL v.s. MKLDNN

Inference performance result on Xeon 6148 (2x20 cores), single socket and single thread.

@mingfeima
mingfeima / rnn_perf_optimization.md
Last active May 10, 2023 10:58
MKLDNN RNN integration in PyTorch

This gist keeps a record of the MKLDNN RNN integration work in PyTorch and serves as a backup of PR26387; only the inference feature is provided at the moment.

To use MKLDNN RNN in PyTorch:

  1. convert model to mkldnn
  2. (optional) convert input and hx/cx to mkldnn

example: how to enable mkl-dnn RNN

import torch
from torch.utils import mkldnn as mkldnn_utils

# hypothetical continuation (the original example is truncated here):
model = mkldnn_utils.to_mkldnn(model.eval())  # step 1: convert model to mkldnn
@mingfeima
mingfeima / pytorch_cpu_perf_bkm.md
Last active September 6, 2024 01:40
BKM for PyTorch CPU Performance

General guidelines for CPU performance on PyTorch

This file serves as a BKM to get better performance on CPU for PyTorch, mostly focusing on inference or deployment. A Chinese version is available here.

1. Use channels last memory format

Right now, on the PyTorch CPU path, you may choose between 3 types of memory formats.

  • torch.contiguous_format: the default memory format, also referred to as NCHW.
  • torch.channels_last: also referred to as NHWC.
  • torch._mkldnn: mkldnn blocked format.
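To see what channels_last changes, here is a hedged NumPy sketch: PyTorch's channels_last keeps the logical NCHW shape but stores the data with NHWC physical strides, which is exactly what the transposed view below produces.

```python
import numpy as np

x_nchw = np.zeros((1, 3, 4, 5), dtype=np.float32)            # contiguous_format
x_nhwc = np.ascontiguousarray(x_nchw.transpose(0, 2, 3, 1))  # physically NHWC
x_cl = x_nhwc.transpose(0, 3, 1, 2)  # same logical NCHW shape, NHWC strides

assert x_cl.shape == x_nchw.shape
assert x_cl.strides[1] == x_cl.itemsize  # channel is now the fastest-moving dim
```

In PyTorch itself the equivalent is `x.to(memory_format=torch.channels_last)`; with the channel dimension innermost, vectorized kernels can read all channels of one pixel contiguously.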
@mingfeima
mingfeima / pytorch_check_mkl_mkldnn.md
Last active July 8, 2022 06:09
BKMs to check whether mkl or mkldnn is enabled on PyTorch


PyTorch can be installed via different channels: conda, pip, docker, source code...

By default, mkl and mkl-dnn are enabled, but this might not always be true, so it is still useful to know how to check it yourself:

1. How to check whether mkl is enabled?

### check where your torch is installed
python -c 'import torch; print(torch.__path__)'
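Beyond locating the install, torch exposes runtime flags for both backends. Assuming torch is importable, `torch.backends.mkl.is_available()` and `torch.backends.mkldnn.is_available()` report whether each was compiled in:

```python
import torch

# both return plain bools reflecting the build configuration
print("mkl   :", torch.backends.mkl.is_available())
print("mkldnn:", torch.backends.mkldnn.is_available())
```

`print(torch.__config__.show())` gives the same information in more detail, including the BLAS backend the build was linked against.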
@mingfeima
mingfeima / cat_perf_regression.md
Last active February 14, 2020 00:40
keep log of cat performance regression

Trace of #30806, a torch.cat() performance regression.

benchmark_all_test result, command line:

python -m benchmark_all_test --operators cat --tag_filter all

with commit 7b50e76255aebbbcdae702ee1f00d07d86b0112b

(pytorch-mingfei) [mingfeim@mlt-skx090 operator_benchmark]$ python -m benchmark_all_test --operators cat --tag_filter all