Tensor Parallel latency
== Results torch.int8 meta-llama/Llama-2-7b-hf-TP1 ====
[--------------------------------------- scaled-torch.int8-gemm --------------------------------------]
| pytorch_bf16_bf16_bf16_matmul-no-scales | cutlass_i8_i8_bf16_scaled_mm
1 threads: --------------------------------------------------------------------------------------------
MKN=(1x4096x12288) | 195.3 | 142.4
MKN=(1x4096x4096) | 64.5 | 47.5
MKN=(1x4096x22016) | 322.9 | 235.6
MKN=(1x11008x4096) | 162.6 | 112.9
MKN=(16x4096x12288) | 187.5 | 142.6
MKN=(16x4096x4096) | 66.2 | 47.6
MKN=(16x4096x22016) | 331.4 | 237.2
MKN=(16x11008x4096) | 168.2 | 113.1
MKN=(32x4096x12288) | 206.4 | 142.9
MKN=(32x4096x4096) | 66.1 | 47.8
MKN=(32x4096x22016) | 362.7 | 238.9
MKN=(32x11008x4096) | 169.3 | 113.3
MKN=(64x4096x12288) | 207.9 | 143.5
MKN=(64x4096x4096) | 73.1 | 48.2
MKN=(64x4096x22016) | 428.4 | 242.2
MKN=(64x11008x4096) | 179.3 | 113.8
MKN=(128x4096x12288) | 224.3 | 145.0
MKN=(128x4096x4096) | 71.3 | 48.8
MKN=(128x4096x22016) | 435.7 | 248.1
MKN=(128x11008x4096) | 179.8 | 114.4
MKN=(256x4096x12288) | 340.8 | 176.5
MKN=(256x4096x4096) | 129.3 | 58.4
MKN=(256x4096x22016) | 571.7 | 306.6
MKN=(256x11008x4096) | 335.7 | 136.1
MKN=(512x4096x12288) | 497.0 | 297.3
MKN=(512x4096x4096) | 164.3 | 110.4
MKN=(512x4096x22016) | 877.7 | 529.1
MKN=(512x11008x4096) | 430.7 | 261.4
Times are in microseconds (us).
== Results torch.int8 meta-llama/Llama-2-7b-hf-TP2 ====
[--------------------------------------- scaled-torch.int8-gemm --------------------------------------]
| pytorch_bf16_bf16_bf16_matmul-no-scales | cutlass_i8_i8_bf16_scaled_mm
1 threads: --------------------------------------------------------------------------------------------
MKN=(1x4096x6144) | 93.8 | 63.6
MKN=(1x2048x4096) | 34.8 | 28.0
MKN=(1x4096x11008) | 171.2 | 140.7
MKN=(1x5504x4096) | 84.8 | 60.6
MKN=(16x4096x6144) | 105.2 | 63.7
MKN=(16x2048x4096) | 34.8 | 27.9
MKN=(16x4096x11008) | 176.2 | 141.4
MKN=(16x5504x4096) | 87.3 | 60.6
MKN=(32x4096x6144) | 99.7 | 64.2
MKN=(32x2048x4096) | 36.2 | 28.1
MKN=(32x4096x11008) | 205.7 | 142.1
MKN=(32x5504x4096) | 87.7 | 60.9
MKN=(64x4096x6144) | 105.6 | 65.2
MKN=(64x2048x4096) | 36.2 | 28.3
MKN=(64x4096x11008) | 207.7 | 143.5
MKN=(64x5504x4096) | 94.1 | 61.2
MKN=(128x4096x6144) | 120.5 | 66.5
MKN=(128x2048x4096) | 39.4 | 29.1
MKN=(128x4096x11008) | 203.7 | 146.2
MKN=(128x5504x4096) | 92.2 | 62.4
MKN=(256x4096x6144) | 151.4 | 108.4
MKN=(256x2048x4096) | 70.7 | 36.2
MKN=(256x4096x11008) | 339.5 | 174.2
MKN=(256x5504x4096) | 172.4 | 74.7
MKN=(512x4096x6144) | 273.9 | 165.2
MKN=(512x2048x4096) | 86.4 | 67.7
MKN=(512x4096x11008) | 491.2 | 285.1
MKN=(512x5504x4096) | 219.6 | 140.8
Times are in microseconds (us).
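
The K and N values in each TP table appear to follow from the TP1 shapes by dividing one dimension by the tensor-parallel degree. The sketch below reproduces the shape lists under that reading. The layer labels in the comments and the choice of which dimension is split are assumptions inferred from the numbers, not something the logs state, and `BASE_SHAPES` and `sharded_shapes` are hypothetical names.

```python
# Hypothetical helper: the names and the split-dimension column are assumptions
# chosen so that the output matches the K x N pairs listed in the tables.
BASE_SHAPES = [
    # (K, N, dimension divided by the TP degree)
    (4096, 12288, "N"),  # assumed: fused QKV projection
    (4096, 4096, "K"),   # assumed: attention output projection
    (4096, 22016, "N"),  # assumed: fused gate/up projection
    (11008, 4096, "K"),  # assumed: MLP down projection
]


def sharded_shapes(tp_size):
    """Return the per-rank (K, N) pairs for a given tensor-parallel degree."""
    shapes = []
    for k, n, split in BASE_SHAPES:
        if split == "N":
            shapes.append((k, n // tp_size))
        else:
            shapes.append((k // tp_size, n))
    return shapes


if __name__ == "__main__":
    for tp in (1, 2, 4, 8):
        print(f"TP{tp}:", sharded_shapes(tp))
```

Each table then sweeps M over 1, 16, 32, 64, 128, 256, and 512 for every (K, N) pair.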
== Results torch.int8 meta-llama/Llama-2-7b-hf-TP4 ====
[-------------------------------------- scaled-torch.int8-gemm --------------------------------------]
| pytorch_bf16_bf16_bf16_matmul-no-scales | cutlass_i8_i8_bf16_scaled_mm
1 threads: -------------------------------------------------------------------------------------------
MKN=(1x4096x3072) | 50.4 | 46.9
MKN=(1x1024x4096) | 20.1 | 16.8
MKN=(1x4096x5504) | 84.7 | 58.0
MKN=(1x2752x4096) | 45.3 | 34.5
MKN=(16x4096x3072) | 50.1 | 46.9
MKN=(16x1024x4096) | 19.9 | 16.9
MKN=(16x4096x5504) | 104.1 | 58.3
MKN=(16x2752x4096) | 45.5 | 34.6
MKN=(32x4096x3072) | 55.8 | 47.1
MKN=(32x1024x4096) | 21.1 | 17.1
MKN=(32x4096x5504) | 90.4 | 58.7
MKN=(32x2752x4096) | 46.1 | 34.7
MKN=(64x4096x3072) | 68.9 | 47.1
MKN=(64x1024x4096) | 20.7 | 17.6
MKN=(64x4096x5504) | 97.4 | 59.7
MKN=(64x2752x4096) | 51.9 | 35.3
MKN=(128x4096x3072) | 73.3 | 47.2
MKN=(128x1024x4096) | 22.6 | 18.2
MKN=(128x4096x5504) | 131.7 | 60.8
MKN=(128x2752x4096) | 51.0 | 35.4
MKN=(256x4096x3072) | 81.8 | 51.7
MKN=(256x1024x4096) | 38.2 | 26.5
MKN=(256x4096x5504) | 147.4 | 107.8
MKN=(256x2752x4096) | 91.4 | 44.1
MKN=(512x4096x3072) | 167.8 | 101.0
MKN=(512x1024x4096) | 46.8 | 49.5
MKN=(512x4096x5504) | 246.3 | 161.2
MKN=(512x2752x4096) | 113.0 | 82.3
Times are in microseconds (us).
== Results torch.int8 meta-llama/Llama-2-7b-hf-TP8 ====
[-------------------------------------- scaled-torch.int8-gemm --------------------------------------]
| pytorch_bf16_bf16_bf16_matmul-no-scales | cutlass_i8_i8_bf16_scaled_mm
1 threads: -------------------------------------------------------------------------------------------
MKN=(1x4096x1536) | 25.6 | 46.1
MKN=(1x512x4096) | 8.6 | 11.1
MKN=(1x4096x2752) | 47.0 | 46.7
MKN=(1x1376x4096) | 25.2 | 21.2
MKN=(16x4096x1536) | 29.5 | 46.2
MKN=(16x512x4096) | 8.5 | 11.2
MKN=(16x4096x2752) | 45.6 | 46.7
MKN=(16x1376x4096) | 25.4 | 21.3
MKN=(32x4096x1536) | 34.0 | 46.3
MKN=(32x512x4096) | 8.5 | 11.3
MKN=(32x4096x2752) | 50.7 | 46.9
MKN=(32x1376x4096) | 26.4 | 21.5
MKN=(64x4096x1536) | 37.8 | 46.4
MKN=(64x512x4096) | 8.4 | 11.4
MKN=(64x4096x2752) | 62.8 | 47.0
MKN=(64x1376x4096) | 26.4 | 21.7
MKN=(128x4096x1536) | 31.5 | 46.7
MKN=(128x512x4096) | 9.9 | 11.9
MKN=(128x4096x2752) | 62.8 | 46.9
MKN=(128x1376x4096) | 29.3 | 22.3
MKN=(256x4096x1536) | 45.4 | 46.6
MKN=(256x512x4096) | 16.5 | 21.1
MKN=(256x4096x2752) | 100.2 | 50.1
MKN=(256x1376x4096) | 54.9 | 29.6
MKN=(512x4096x1536) | 101.4 | 50.4
MKN=(512x512x4096) | 27.3 | 39.2
MKN=(512x4096x2752) | 165.5 | 101.1
MKN=(512x1376x4096) | 60.6 | 56.0
Times are in microseconds (us).
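
To compare the two columns row by row, the tables can be parsed into a bf16/int8 latency ratio. This is a minimal sketch that assumes the rows keep the `MKN=(MxKxN) | bf16 | int8` layout shown above; `ROW` and `speedups` are hypothetical names, not part of the benchmark script.

```python
import re

# Matches rows such as: MKN=(1x4096x1536) | 25.6 | 46.1
ROW = re.compile(
    r"MKN=\((\d+)x(\d+)x(\d+)\)\s*\|\s*([\d.]+)\s*\|\s*([\d.]+)"
)


def speedups(text):
    """Return [((M, K, N), bf16_us / int8_us), ...] for every row found in text."""
    results = []
    for match in ROW.finditer(text):
        m, k, n = (int(match.group(i)) for i in (1, 2, 3))
        bf16_us = float(match.group(4))
        int8_us = float(match.group(5))
        results.append(((m, k, n), bf16_us / int8_us))
    return results


if __name__ == "__main__":
    sample = """MKN=(1x4096x1536) | 25.6 | 46.1
MKN=(512x4096x2752) | 165.5 | 101.1"""
    for shape, ratio in speedups(sample):
        print(shape, f"{ratio:.2f}x")
```

A ratio above 1 means the int8 kernel was faster than the bf16 matmul for that shape; a ratio below 1 means it was slower, as in several of the TP8 rows.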