This Gist records the optimization effort for DLRM on the PyTorch CPU path.
Tracking branch: `dlrm`
Task list:
- LAMB fused optimizer (fp32)
- Adagrad fused optimizer (fp32)
- Split-SGD (bf16)
- Bucketize (bf16)
- Sum (bf16)
- LayerNorm (bf16)
- Softmax (bf16)
- cumsum (int64_t)
- transposed copy (fp32/bf16)
- offset range (int64_t)
- sigmoid/sigmoid_backward (bf16)
LAMB optimizer: proposed in the paper "Large Batch Optimization for Deep Learning: Training BERT in 76 minutes".
This implementation follows fbgemm's GPU code at gpu_ref.
To use this CPU fused LAMB kernel, you need to cherry-pick cf5e826b and build from source.
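```python
import torch.optim as optim

# optim.Lamb with fused= requires the cherry-picked commit above
### fused=True will use the native C++ fused kernel from ATen
### fused=False will fall back to the imperative torch impl, used for validation purposes
optimizer = optim.Lamb(model.parameters(), lr=0.01, fused=True)
```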
The test case is posted below as `test_fused_lamb.py`; both contiguous and non-contiguous cases are tested. The weight tensor can be non-contiguous when multiple `nn.Linear` modules are explicitly fused.
The MNIST example from pytorch/examples converges as follows:
```
Test set: Average loss: 0.0297, Accuracy: 9934/10000 (99%)
```
I tested on an Intel(R) Xeon(R) Gold 6248 CPU @ 2.50GHz (20 cores per socket, dual sockets). For a single-socket run (with jemalloc), the update step of a [1024, 1024] weight tensor achieves a 4.9x speedup:
```
### LAMB optimier bench:
unfused: 0.4495 ms; fused: 0.0923 ms
```
To reproduce the result (note that jemalloc is applied):
```bash
./run.sh test_fused_lamb.py
```
[Notes]
- The speedup comes primarily from: a) reduced memory bandwidth for intermediate tensors; b) no additional memory allocation in the kernel. The temporary result `adam_step` reuses the memory of `grad`, so the kernel overwrites the gradient tensor; this is safe because the gradient is no longer used after the weight update.
- The 4.9x speedup is measured on the weight of `nn.Linear(1024, 1024)`. The speedup ratio would be greater for a bigger weight tensor.
- Thread synchronization: the algorithm itself requires synchronization across threads (e.g. for the norms of the weight and of `adam_step`). Ideally, we could do this with `#pragma omp barrier` and finish the whole computation within a single OpenMP parallel region. But this would trigger a bug: PyTorch's OpenMP wrapper `at::parallel` does not guarantee that all OpenMP threads in the same team are used (N=64 launches only 16 threads even though there are 20 cores), so the unused threads never reach the barrier and the others keep waiting. So I break the code into 2 OpenMP parallel regions.
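The two-region structure can be illustrated with a minimal, single-threaded pure-Python sketch of one LAMB step on flat lists. It assumes the standard update from the LAMB paper (bias-corrected Adam moments, weight decay added to `adam_step`, trust ratio `||w|| / ||adam_step||`); fbgemm's gpu_ref may differ in details such as bias correction or trust-ratio clamping. The name `lamb_step_inplace` and the hyperparameter defaults are hypothetical. The first loop plays the role of the first parallel region (writing `adam_step` into `grad` and accumulating both squared norms), computing the trust ratio is the synchronization point, and the second loop is the second region.

```python
import math


def lamb_step_inplace(weight, grad, exp_avg, exp_avg_sq, step,
                      lr=0.01, beta1=0.9, beta2=0.999, eps=1e-6, weight_decay=0.0):
    """Hypothetical reference for one LAMB step on flat lists of floats.

    Like the fused kernel described above, `grad` is overwritten with
    adam_step, so no extra buffer is allocated. Returns the trust ratio.
    """
    bias_c1 = 1.0 - beta1 ** step
    bias_c2 = 1.0 - beta2 ** step

    # Phase 1 (first parallel region): moments, adam_step, squared norms.
    w_norm_sq = 0.0
    a_norm_sq = 0.0
    for i in range(len(weight)):
        g = grad[i]
        exp_avg[i] = beta1 * exp_avg[i] + (1.0 - beta1) * g
        exp_avg_sq[i] = beta2 * exp_avg_sq[i] + (1.0 - beta2) * g * g
        m_hat = exp_avg[i] / bias_c1
        v_hat = exp_avg_sq[i] / bias_c2
        adam_step = m_hat / (math.sqrt(v_hat) + eps) + weight_decay * weight[i]
        grad[i] = adam_step  # reuse grad memory for the temporary result
        w_norm_sq += weight[i] * weight[i]
        a_norm_sq += adam_step * adam_step

    # Synchronization point: both norms need contributions from every element.
    w_norm = math.sqrt(w_norm_sq)
    a_norm = math.sqrt(a_norm_sq)
    trust_ratio = w_norm / a_norm if w_norm > 0 and a_norm > 0 else 1.0

    # Phase 2 (second parallel region): apply the scaled update.
    for i in range(len(weight)):
        weight[i] -= lr * trust_ratio * grad[i]
    return trust_ratio
```

In a threaded kernel, each loop would be split across threads and the per-thread partial norms reduced between the two regions; the sketch only shows the data flow.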
Adagrad fused optimizer (fp32) usage:
```python
import torch.optim as optim

### fused=True will use the native C++ fused kernel from ATen
### fused=False will fall back to the imperative torch impl, used for validation purposes
optimizer = optim.Adagrad(model.parameters(), lr=0.01, fused=True)
```
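As a rough picture of what "fused" means here, the sketch below applies one Adagrad step element by element in a single pass. It assumes the standard Adagrad rule `state_sum += g * g; w -= lr * g / (sqrt(state_sum) + eps)` without weight decay or learning-rate decay; the function name `adagrad_step_fused` is hypothetical, and `eps=1e-10` is borrowed from the `torch.optim.Adagrad` default. The imperative torch implementation instead runs each elementwise op as a separate kernel and materializes temporaries such as `sqrt(state_sum) + eps`.

```python
import math


def adagrad_step_fused(weight, grad, state_sum, lr=0.01, eps=1e-10):
    """Hypothetical single-pass Adagrad update on flat lists of floats."""
    for i in range(len(weight)):
        g = grad[i]
        state_sum[i] += g * g  # accumulate the squared gradient
        # update the weight in the same pass, without a temporary tensor
        weight[i] -= lr * g / (math.sqrt(state_sum[i]) + eps)
```

For example, with `w = [1.0, -2.0]`, a zero `state_sum`, gradients `[0.5, 0.25]` and `lr=0.1`, the first step moves each weight by almost exactly `lr` (to about `0.9` and `-2.1`), since `g / sqrt(g * g)` is 1 in magnitude.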
The MNIST example from pytorch/examples converges as follows:
```
Test set: Average loss: 0.0363, Accuracy: 9881/10000 (99%)
```
I tested on an Intel(R) Xeon(R) Gold 6248 CPU @ 2.50GHz (20 cores per socket, dual sockets). For a single-socket run (with jemalloc), the update step of a [1024, 1024] weight tensor achieves a 3.2x speedup:
```
### ADAGRAD optimier bench:
unfused: 0.1022 ms; fused: 0.0321 ms
```
To reproduce the result (note that jemalloc is applied):
```bash
./run.sh test_fused_adagrad.py
```
Split-SGD (bf16): the basic idea of the algorithm is to keep an fp32 master copy of the weight by splitting it into its upper 16 bits and lower 16 bits. The upper half is the bf16 parameter itself, and the lower half is stored in the optimizer as a state, so the weight can be updated in fp32 by packing and unpacking.
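A pure-Python sketch of the packing and unpacking, assuming the bf16 parameter holds the upper 16 bits of the fp32 master weight (plain truncation) and the optimizer state holds the lower 16 bits. The names `split_fp32`, `pack_fp32`, `bf16_to_float` and `sgd_update_split` are hypothetical, and plain SGD stands in for whichever optimizer is used. Because the two halves together are exactly the fp32 bit pattern, small updates accumulate in the master weight even when they are too small to change the bf16 value.

```python
import struct


def split_fp32(x):
    """Round x to fp32 and return (upper 16 bits, lower 16 bits) of its bit pattern."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    return bits >> 16, bits & 0xFFFF


def pack_fp32(hi, lo):
    """Rebuild the fp32 value from its two 16-bit halves."""
    return struct.unpack("<f", struct.pack("<I", (hi << 16) | lo))[0]


def bf16_to_float(hi):
    """The bf16 parameter alone is the fp32 value with the lower 16 bits zeroed."""
    return pack_fp32(hi, 0)


def sgd_update_split(param_hi, state_lo, grad, lr):
    """Hypothetical split-SGD step: unpack to fp32, update, split again."""
    w = pack_fp32(param_hi, state_lo)  # exact fp32 master weight
    w = w - lr * grad                  # update at (at least) fp32 precision
    return split_fp32(w)               # new bf16 param bits, new optimizer state
```

Starting from `1.0` (`split_fp32(1.0)` gives `(0x3F80, 0)`) and applying ten updates of `1e-4` leaves the fp32 master weight at about `0.999`, while the bf16 parameter reads `0.99609375`. With only a bf16 copy and round-to-nearest, `1.0 - 1e-4` would round straight back to `1.0` and every update would be lost.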
The usage is identical to the normal fp32 fused kernel: with `fused=True`, parameters of data type `torch.bfloat16` automatically use the split-SGD algorithm:
```python
import torch.optim as optim

# model parameters are torch.bfloat16 here
### fused=True will use the native C++ fused kernel from ATen
### fused=False will fall back to the imperative torch impl, used for validation purposes
optimizer = optim.Lamb(model.parameters(), lr=0.01, fused=True)
```
```
### LAMB
unfused (fp32): 0.4526 ms; fused (fp32): 0.0940 ms; split fused (bf16): 0.0879 ms
```
Test:
```bash
python test_optim.py TestSplitSGD.test_lamb_bfloat16_cpu
python test_optim.py TestSplitSGD.test_adagrad_bfloat16_cpu
```
[Notes]: Known issue: this implementation is expected to raise a runtime error on AVX-only machines, so make sure you have an AVX2 (or newer) CPU. (I did not register the AVX kernels.)
BFloat16 is not an actual compute data type (arithmetic is not done natively in bf16), so we need to handle BFloat16 operators in the following manner:
- input/output: load: bf16->fp32; store: fp32->bf16
- intermediate operations (including accumulation): use fp32
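The load/compute/store rule can be emulated in pure Python to see why accumulation must stay in fp32. `to_fp32` and `to_bf16` are hypothetical helpers that round a Python float through its fp32 bit pattern; the bf16 rounding here is round-to-nearest-even, an assumption for the sketch. Summing 1024 ones with a bf16 accumulator gets stuck at 256: above 256 the bf16 spacing is 2, so `256 + 1` is a tie and rounds back to the even value 256. Loading bf16 but accumulating in fp32 gives the exact 1024, and only the final store rounds.

```python
import struct


def to_fp32(x):
    """Round a Python float to the nearest fp32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def to_bf16(x):
    """Round to the nearest bf16 value (ties to even); NaN/Inf are not handled."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits = ((bits + 0x7FFF + ((bits >> 16) & 1)) >> 16) << 16
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def sum_bf16_accumulator(values):
    # The accumulator itself is bf16: every partial sum is rounded to bf16.
    acc = 0.0
    for v in values:
        acc = to_bf16(acc + to_bf16(v))
    return acc


def sum_fp32_accumulator(values):
    # load: bf16 -> fp32; accumulate in fp32; store: fp32 -> bf16 once.
    acc = 0.0
    for v in values:
        acc = to_fp32(acc + to_bf16(v))
    return to_bf16(acc)
```

With `ones = [1.0] * 1024`, `sum_bf16_accumulator(ones)` returns `256.0` and `sum_fp32_accumulator(ones)` returns `1024.0`.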
We have multiple ways to enable BFloat16 ops in PyTorch, namely:
- Naive Impl (Impl-1): add `kBFloat16` to the `AT_DISPATCH_FLOATING_TYPES` macro. Since PyTorch has `BFloat16` specializations for both the scalar and the `Vec256<>` logic, this runs smoothly, but the naive implementation is not good in terms of either performance or accuracy.
- Functional Specialization (Impl-2): specialize `vec256::Map<>` from functional.cpp for `BFloat16`, similar to the oneDNN implementation.
- Cache FP32 Data (Impl-3): convert the bf16 data to fp32 per input row and cache it (possibly in L1), similar to the CUDA counterpart implementation.
Consider the following example:
```cpp
using Vec = Vec256<BFloat16>;
Vec one = Vec(BFloat16(1));
// `one` must be captured for the lambda to use it
vec256::map([one](Vec x) { return one / (one + x.exp()); }, y_ptr, x_ptr, N);
```
Impl-1 ends up with 3 pairs of dtype conversions, one each for `.exp()`, `+` and `/`. Impl-2 and Impl-3 only need dtype conversions for the input and the output. Benefits:
- better performance, since there are fewer dtype conversions;
- less rounding error, since intermediate results are kept in fp32;
- accumulation is done in fp32.
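The rounding-error benefit can be checked with the same kind of emulation on the example expression `1 / (1 + exp(x))`: `impl1` rounds every intermediate to bf16 (which is what Impl-1 effectively does, since each op converts back to bf16), and `impl2` keeps fp32 intermediates and rounds once at the store (Impl-2/Impl-3). `to_fp32`, `to_bf16`, `impl1` and `impl2` are hypothetical helpers; Python's double-precision `math.exp` stands in for the vectorized fp32 `exp`, so this illustrates the rounding behavior only, not the ATen kernels.

```python
import math
import struct


def to_fp32(x):
    """Round a Python float to the nearest fp32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def to_bf16(x):
    """Round to the nearest bf16 value (ties to even); NaN/Inf are not handled."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits = ((bits + 0x7FFF + ((bits >> 16) & 1)) >> 16) << 16
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def impl1(x):
    # Impl-1: each op converts bf16 -> fp32 -> bf16, so each result is rounded.
    e = to_bf16(math.exp(x))
    d = to_bf16(1.0 + e)
    return to_bf16(1.0 / d)


def impl2(x):
    # Impl-2/3: fp32 intermediates, a single rounding to bf16 on store.
    e = to_fp32(math.exp(x))
    d = to_fp32(1.0 + e)
    return to_bf16(to_fp32(1.0 / d))
```

For `x = 2.0`, Impl-1 rounds `exp(2)` (about 7.389) to `7.375` before the addition and returns `0.11962890625`, while Impl-2 returns `0.119140625`, the bf16 value nearest to the exact result `0.1192...`.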
Between Impl-2 and Impl-3: with emulated dtype conversion, Impl-3 is faster in most cases; with native conversion instructions, Impl-2 is faster. So I follow Impl-2 in these patches.
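One way to read this trade-off is to count conversions in a multi-pass op such as softmax, which reads each input row three times (max, sum of exponentials, normalize). In the sketch below, `load_bf16` is a hypothetical stand-in for a bf16->fp32 conversion that only counts calls (the row is already stored as Python floats), and `ConversionCounter`, `softmax_convert_every_pass` and `softmax_cached_fp32` are likewise hypothetical. Converting on every pass costs `3 * N` conversions per row, while caching the converted row in an fp32 buffer costs `N` conversions plus the buffer traffic. This is an interpretation of the statement above, not a description of the actual kernels: when conversion is emulated (expensive), saving conversions plausibly wins; when native instructions make conversion cheap, the extra buffer is not worth it.

```python
import math


class ConversionCounter:
    def __init__(self):
        self.loads = 0


def load_bf16(row, i, counter):
    # Stand-in for a bf16 -> fp32 conversion; only counts how often it runs.
    counter.loads += 1
    return row[i]


def softmax_convert_every_pass(row, counter):
    # Impl-2 style: every pass converts its inputs again.
    n = len(row)
    m = max(load_bf16(row, i, counter) for i in range(n))                # pass 1
    s = sum(math.exp(load_bf16(row, i, counter) - m) for i in range(n))  # pass 2
    return [math.exp(load_bf16(row, i, counter) - m) / s for i in range(n)]  # pass 3


def softmax_cached_fp32(row, counter):
    # Impl-3 style: convert the row once into an fp32 buffer, then reuse it.
    buf = [load_bf16(row, i, counter) for i in range(len(row))]
    m = max(buf)
    s = sum(math.exp(v - m) for v in buf)
    return [math.exp(v - m) / s for v in buf]
```

On a 1024-element row, both functions return identical outputs, but the first performs 3072 conversions and the second 1024.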
Naive Impl:
```
Softmax: 128x1024: fp32: 150.324 us; bf16: 356.587 us
tensor max (abs) diff: 2.9515125788748264e-05
```
Functional Specialization:
```
log_softmax: 128x1024: fp32: 150.132 us; bf16: 194.974 us
tensor max (abs) diff: 1.509662251919508e-05
```
Test:
```bash
cd pytorch/build/bin/
./vec256_test_all_types_AVX
./vec256_test_all_types_AVX2
./vec256_test_all_types_DEFAULT
python test_nn.py TestNN.test_log_softmax_cpu
python test_nn.py TestNN.test_softmax_cpu
```
Naive Impl:
```
sum size: 128x30678, fp32: 0.588 ms; bf16: 0.899 ms
```
Functional Specialization:
```
sum size: 128x30678, fp32: 0.590 ms; bf16: 0.335 ms
```
Naive Impl:
```
LayerNorm((1024,), eps=1e-05, elementwise_affine=True) : 32x128x1024: fp32: 2.806 ms; bf16: 9.901 ms
tensor max (abs) diff: 0.1355377435684204
```
Functional Specialization:
```
LayerNorm((1024,), eps=1e-05, elementwise_affine=True) : 32x128x1024: fp32: 2.813 ms; bf16: 2.306 ms
tensor max (abs) diff: 0.04277598857879639
```
Test:
```bash
python test_nn.py TestNNDeviceTypeCPU.test_LayerNorm_general_cpu
```
- `run.sh`: benchmark script launcher
- `test_fused_lamb.py`: test case and benchmark
- `test_fused_adagrad.py`: test case and benchmark