Benchmark backup for PR19736: topk() performance optimization on CPU.
Suppose the input tensor has shape `[N, C]`; benchmark `input.topk(K, sorted=Sorted)` for the following scenarios:
- C = 10000, 40000, 320000
- K = 10, 50, 100, C/10, C/2, C-5
- Test with 20 threads and 1 thread
- Test with Sorted=True and Sorted=False
- run_topk_scale.sh

```bash
CORES=`lscpu | grep Core | awk '{print $4}'`
SOCKETS=`lscpu | grep Socket | awk '{print $2}'`
TOTAL_CORES=`expr $CORES \* $SOCKETS`
LAST_CORE=`expr $CORES - 1`

KMP_SETTING="KMP_AFFINITY=granularity=fine,compact,1,0"
KMP_BLOCKTIME=1
PREFIX="numactl --physcpubind=0-$LAST_CORE --membind=0"

export $KMP_SETTING
export KMP_BLOCKTIME=$KMP_BLOCKTIME

echo -e "\n### using $KMP_SETTING"
echo -e "### using KMP_BLOCKTIME=$KMP_BLOCKTIME"
echo -e "### using $PREFIX\n"

### single socket test
echo -e "\n### using OMP_NUM_THREADS=$CORES"
OMP_NUM_THREADS=$CORES $PREFIX python -u test_topk.py

### single thread test
echo -e "\n### using OMP_NUM_THREADS=1"
OMP_NUM_THREADS=1 $PREFIX python -u test_topk.py
```

- test_topk.py
```python
import torch
from time import time

def bench_topk(N=8, C=168560, K=10, Sorted=True, num_iters=1000):
    # warmup
    a = torch.randn(N, C)
    for i in range(int(num_iters / 10)):
        torch.topk(a, K)
    t = 0
    for i in range(num_iters):
        a = torch.randn(N, C)
        start = time()
        value, indice = torch.topk(a, K, sorted=Sorted)
        t += time() - start
    print("#[%d, %d], k=%d, sorted=%s times: %f ms" % (N, C, K,
          ('True' if Sorted else 'False'), t / num_iters * 1000))

def bench_topk_scale(Sorted):
    Ns = [10]
    Cs = [10000, 40000, 320000]
    for n in Ns:
        for c in Cs:
            for k in [10, 50, 100, int(c/10), int(c/2), int(c-5)]:
                iters = 500 if k > 5000 else 1000
                bench_topk(n, c, k, Sorted, iters)

bench_topk_scale(True)
bench_topk_scale(False)
```

Tested on Intel Xeon 6148, 2 sockets x 20 cores @ 2.5GHz. To reproduce the results:

```bash
./run_topk_scale.sh
```

All numbers are reported in ms; "oob" refers to the original (out-of-box) topk performance.
Table-1: OMP=20, Sorted=True
| Input Size | oob | this pr | speed up |
|---|---|---|---|
| #[10, 10000], k=10 | 0.745 | 0.061 | 12.13 |
| #[10, 10000], k=50 | 0.765 | 0.077 | 9.93 |
| #[10, 10000], k=100 | 0.789 | 0.096 | 8.25 |
| #[10, 10000], k=1000 | 1.252 | 0.219 | 5.72 |
| #[10, 10000], k=5000 | 3.552 | 0.504 | 7.05 |
| #[10, 10000], k=9995 | 6.185 | 0.814 | 7.60 |
| #[10, 40000], k=10 | 2.882 | 0.176 | 16.37 |
| #[10, 40000], k=50 | 2.895 | 0.193 | 14.98 |
| #[10, 40000], k=100 | 2.914 | 0.222 | 13.10 |
| #[10, 40000], k=4000 | 5.213 | 0.835 | 6.24 |
| #[10, 40000], k=20000 | 15.811 | 2.019 | 7.83 |
| #[10, 40000], k=39995 | 27.980 | 3.231 | 8.66 |
| #[10, 320000], k=10 | 22.928 | 2.492 | 9.20 |
| #[10, 320000], k=50 | 22.835 | 2.498 | 9.14 |
| #[10, 320000], k=100 | 22.859 | 2.508 | 9.11 |
| #[10, 320000], k=32000 | 45.197 | 7.523 | 6.01 |
| #[10, 320000], k=160000 | 146.211 | 17.432 | 8.39 |
| #[10, 320000], k=319995 | 263.868 | 29.179 | 9.04 |
Table-2: OMP=20, Sorted=False
| Input Size | oob | this pr | speed up |
|---|---|---|---|
| #[10, 10000], k=10 | 0.746 | 0.061 | 12.20 |
| #[10, 10000], k=50 | 0.752 | 0.077 | 9.74 |
| #[10, 10000], k=100 | 0.756 | 0.096 | 7.88 |
| #[10, 10000], k=1000 | 0.847 | 0.172 | 4.93 |
| #[10, 10000], k=5000 | 1.038 | 0.186 | 5.58 |
| #[10, 10000], k=9995 | 0.848 | 0.171 | 4.95 |
| #[10, 40000], k=10 | 2.841 | 0.177 | 16.06 |
| #[10, 40000], k=50 | 2.866 | 0.189 | 15.18 |
| #[10, 40000], k=100 | 2.857 | 0.222 | 12.87 |
| #[10, 40000], k=4000 | 3.227 | 0.609 | 5.30 |
| #[10, 40000], k=20000 | 3.970 | 0.668 | 5.95 |
| #[10, 40000], k=39995 | 3.255 | 0.609 | 5.35 |
| #[10, 320000], k=10 | 22.597 | 2.487 | 9.09 |
| #[10, 320000], k=50 | 22.468 | 2.499 | 8.99 |
| #[10, 320000], k=100 | 22.553 | 2.517 | 8.96 |
| #[10, 320000], k=32000 | 25.606 | 5.480 | 4.67 |
| #[10, 320000], k=160000 | 32.419 | 6.124 | 5.29 |
| #[10, 320000], k=319995 | 28.623 | 6.005 | 4.77 |
Table-3: OMP=1, Sorted=True
| Input Size | oob | this pr | speed up |
|---|---|---|---|
| #[10, 10000], k=10 | 0.748 | 0.261 | 2.87 |
| #[10, 10000], k=50 | 0.766 | 0.391 | 1.96 |
| #[10, 10000], k=100 | 0.788 | 0.550 | 1.43 |
| #[10, 10000], k=1000 | 1.255 | 1.296 | 0.97 |
| #[10, 10000], k=5000 | 3.554 | 3.441 | 1.03 |
| #[10, 10000], k=9995 | 6.185 | 5.710 | 1.08 |
| #[10, 40000], k=10 | 2.877 | 0.933 | 3.08 |
| #[10, 40000], k=50 | 2.875 | 1.112 | 2.59 |
| #[10, 40000], k=100 | 2.895 | 1.282 | 2.26 |
| #[10, 40000], k=4000 | 5.184 | 5.304 | 0.98 |
| #[10, 40000], k=20000 | 15.905 | 15.106 | 1.05 |
| #[10, 40000], k=39995 | 27.970 | 25.741 | 1.09 |
| #[10, 320000], k=10 | 23.036 | 7.914 | 2.91 |
| #[10, 320000], k=50 | 22.857 | 8.181 | 2.79 |
| #[10, 320000], k=100 | 23.075 | 8.404 | 2.75 |
| #[10, 320000], k=32000 | 45.292 | 46.478 | 0.97 |
| #[10, 320000], k=160000 | 146.232 | 140.205 | 1.04 |
| #[10, 320000], k=319995 | 263.640 | 244.572 | 1.08 |
Table-4: OMP=1, Sorted=False
| Input Size | oob | this pr | speed up |
|---|---|---|---|
| #[10, 10000], k=10 | 0.747 | 0.260 | 2.87 |
| #[10, 10000], k=50 | 0.749 | 0.389 | 1.92 |
| #[10, 10000], k=100 | 0.758 | 0.548 | 1.38 |
| #[10, 10000], k=1000 | 0.845 | 0.933 | 0.91 |
| #[10, 10000], k=5000 | 1.037 | 1.132 | 0.92 |
| #[10, 10000], k=9995 | 0.848 | 0.951 | 0.89 |
| #[10, 40000], k=10 | 2.856 | 0.935 | 3.06 |
| #[10, 40000], k=50 | 2.863 | 1.090 | 2.63 |
| #[10, 40000], k=100 | 2.860 | 1.284 | 2.23 |
| #[10, 40000], k=4000 | 3.231 | 3.556 | 0.91 |
| #[10, 40000], k=20000 | 3.975 | 4.330 | 0.92 |
| #[10, 40000], k=39995 | 3.247 | 3.590 | 0.90 |
| #[10, 320000], k=10 | 22.570 | 7.943 | 2.84 |
| #[10, 320000], k=50 | 22.504 | 8.143 | 2.76 |
| #[10, 320000], k=100 | 22.489 | 8.407 | 2.68 |
| #[10, 320000], k=32000 | 25.558 | 29.042 | 0.88 |
| #[10, 320000], k=160000 | 32.357 | 36.160 | 0.89 |
| #[10, 320000], k=319995 | 28.541 | 31.501 | 0.91 |
`std::partial_sort` (heap sort) is quite fast when `K` is small; `std::sort` (introsort) is faster when `K` is large.
- Use `std::partial_sort` when `K` is in the small range, no matter whether `sorted` is `True` or `False`
- Use `std::nth_element` + `std::sort` when `K` is in the large range and `sorted` is `True`
- Use `std::nth_element` when `K` is in the large range and `sorted` is `False`
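The dispatch above can be sketched in Python, with `heapq.nlargest` standing in for `std::partial_sort` (both are heap-based partial selection) and a quickselect standing in for `std::nth_element`. The `small_k_cutoff` threshold is illustrative only, not the actual cutoff chosen in the PR:

```python
import heapq

def select_top(vals, k):
    """Quickselect: return the k largest elements in unspecified order
    (a Python stand-in for std::nth_element)."""
    if k >= len(vals):
        return list(vals)
    pivot = vals[len(vals) // 2]
    left = [v for v in vals if v > pivot]
    if k <= len(left):
        return select_top(left, k)
    mid = [v for v in vals if v == pivot]
    if k <= len(left) + len(mid):
        return left + mid[:k - len(left)]
    return left + mid + select_top([v for v in vals if v < pivot],
                                   k - len(left) - len(mid))

def topk(vals, k, is_sorted=True, small_k_cutoff=8):
    # hypothetical boundary between the "small K" and "large K" regimes
    if k * small_k_cutoff <= len(vals):
        # small K: heap-based partial selection (std::partial_sort analogue);
        # nlargest already returns the result in descending order
        return heapq.nlargest(k, vals)
    if is_sorted:
        # large K + sorted output: select, then sort the selected part
        # (std::nth_element followed by std::sort)
        return sorted(select_top(vals, k), reverse=True)
    # large K + unsorted output: selection only, skip the final sort
    return select_top(vals, k)
```

The unsorted large-`K` branch is where Table-2/Table-4 gain over Table-1/Table-3: the O(K log K) final sort is skipped entirely.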
- remove `emplace_back` and use pre-allocation for `std::vector`
- inline the comparator lambda: gcc has trouble properly inlining the lambda when it is chosen via a conditional expression (even with `-O3`), e.g. `auto comp = cond ? lambda1 : lambda2` is marginally slower than writing the lambda directly inside `std::sort`.
- caffe: uses `std::partial_sort`, no parallelization.
- mxnet: uses `std::partial_sort` when K < C/8, otherwise `std::sort`; parallelized with OpenMP.
- tensorflow: uses a min heap, no inner parallelization.
- cntk: uses `std::partial_sort`, only parallelized when K=1.
- when `K` is in the small range and `N` is smaller than the number of physical cores, parallelizing only on the `N` dimension won't utilize all cores. For example, in transformer beam search the typical input size is `[N, C]` with `N < 10`. Additional performance gain is possible:
  - Step 1: reorder `input` to be `[N, S, C/S]` and perform parallel topk on the `N * S` dimension.
  - Step 2: the output from step 1 is `[N, S*K]`; sort on `S*K` to find the topk values of each channel.
NB: try this in mlperf transformer training...
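The two-step scheme can be checked with a small pure-Python sketch over a single row (`s` and `k` are illustrative; the real implementation would reshape the tensor and run step 1 in parallel across the `N * S` dimension):

```python
def two_step_topk(row, k, s):
    """Split one row of C values into s chunks, take top-k of each chunk
    (step 1, the part that would run in parallel), then take top-k of the
    resulting s*k candidates (step 2). Assumes len(row) is divisible by s."""
    chunk = len(row) // s
    candidates = []
    for i in range(s):  # step 1: each chunk is independent
        part = sorted(row[i * chunk:(i + 1) * chunk], reverse=True)
        candidates.extend(part[:k])  # each chunk contributes its own top-k
    return sorted(candidates, reverse=True)[:k]  # step 2: merge s*k candidates
```

The result matches a plain top-k because every global top-k element is necessarily within the top-k of its own chunk, so it survives step 1.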
- SIMD sort: some reference designs: avx2-sort, avx512-sort.
NB: try avx512 quick select...
Attached raw logs: original topk()