triton_output_with_fp16_inputs=tensor([[-17.7500, 12.5938, -28.2500, ..., -23.9531, 2.9141, -9.3359],
[ 5.3750, -24.9844, 7.1016, ..., 2.7383, -42.6562, 1.9766],
[-22.8906, -5.9766, -8.2031, ..., -0.2485, -41.5312, 19.0938],
...,
[ 5.4648, -2.0977, 18.4531, ..., -36.9688, -7.6680, -20.1719],
[ -9.2031, -12.2812, -20.5312, ..., -24.5625, -50.9062, -3.6387],
[ 44.2188, -7.1328, -28.3750, ..., 4.6914, 7.9648, -8.6641]],
device='cuda:0', dtype=torch.float16)
torch_output_with_fp16_inputs=tensor([[-17.7500, 12.5938, -28.2500, ..., -23.9531, 2.9141, -9.3359],
[ 5.3750, -24.9844, 7.1016, ..., 2.7383, -42.6562, 1.9766],
[-22.8906, -5.9766, -8.2031, ..., -0.2485, -41.5312, 19.0938],
...,
[ 5.4648, -2.0977, 18.4531, ..., -36.9688, -7.6680, -20.1719],
[ -9.2031, -12.2812, -20.5312, ..., -24.5625, -50.9062, -3.6387],
[ 44.2188, -7.1328, -28.3750, ..., 4.6914, 7.9648, -8.6641]],
device='cuda:0', dtype=torch.float16)
✅ Triton and Torch match
triton_output_with_fp8_inputs=tensor([[-23.5312, 13.2500, 8.5703, ..., -13.9297, -4.4258, 33.1562],
[ 7.8984, 34.5312, -5.1758, ..., -24.3281, 5.9062, -44.2812],
[ 20.3594, -4.2500, -20.5625, ..., -43.0000, 0.3276, -22.6250],
...,
[-46.7188, 16.9062, -22.2344, ..., 30.7344, -6.5781, 5.5703],
[ -2.3770, 17.4375, -1.1807, ..., -18.4375, 2.1602, 34.6875],
[ -4.8633, -15.5547, 9.5234, ..., -0.4426, -10.0938, 3.5762]],
device='cuda:0', dtype=torch.float16)
torch_output_with_fp8_inputs=tensor([[-23.5312, 13.2500, 8.5703, ..., -13.9297, -4.4258, 33.1562],
[ 7.8984, 34.5312, -5.1758, ..., -24.3281, 5.9062, -44.2812],
[ 20.3594, -4.2500, -20.5625, ..., -43.0000, 0.3276, -22.6250],
...,
[-46.7188, 16.9062, -22.2344, ..., 30.7344, -6.5781, 5.5703],
[ -2.3770, 17.4375, -1.1807, ..., -18.4375, 2.1602, 34.6875],
[ -4.8633, -15.5547, 9.5234, ..., -0.4426, -10.0938, 3.5762]],
device='cuda:0', dtype=torch.float16)
✅ Triton and Torch match
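The "✅ Triton and Torch match" lines above come from an elementwise closeness check between the Triton kernel's output and the PyTorch reference. A minimal sketch of that kind of check (using NumPy here rather than the tutorial's `torch.allclose`; the tolerances are illustrative assumptions, not the tutorial's exact values):

```python
import numpy as np

def report_match(triton_out, torch_out, atol=1e-2, rtol=0.0):
    """Elementwise closeness check, mimicking the tutorial's verification step."""
    if np.allclose(triton_out, torch_out, atol=atol, rtol=rtol):
        return "✅ Triton and Torch match"
    return "❌ Triton and Torch differ"

rng = np.random.default_rng(0)
a = rng.standard_normal((64, 64)).astype(np.float16)
b = rng.standard_normal((64, 64)).astype(np.float16)
ref = a @ b
print(report_match(ref, ref.copy()))   # identical outputs match
print(report_match(ref, ref + 1.0))    # a large deviation is flagged
```

An absolute tolerance on the order of 1e-2 is typical for fp16 matmuls, where accumulation order differences between kernels produce small low-order discrepancies.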
matmul-performance-fp16:
M N K cuBLAS (TFLOPS) Triton (TFLOPS)
0 256.0 256.0 256.0 5.461333 5.216796
1 384.0 384.0 384.0 11.059200 13.824000
2 512.0 512.0 512.0 18.724571 26.214401
3 640.0 640.0 640.0 25.600001 36.571428
4 768.0 768.0 768.0 27.648000 40.215272
5 896.0 896.0 896.0 46.830935 54.759677
6 1024.0 1024.0 1024.0 65.536000 64.965022
7 1152.0 1152.0 1152.0 46.656000 64.912694
8 1280.0 1280.0 1280.0 58.514284 73.142856
9 1408.0 1408.0 1408.0 71.733898 69.894567
10 1536.0 1536.0 1536.0 58.982401 70.778882
11 1664.0 1664.0 1664.0 69.222400 80.347427
12 1792.0 1792.0 1792.0 80.281598 80.263686
13 1920.0 1920.0 1920.0 71.999997 78.545454
14 2048.0 2048.0 2048.0 81.442797 79.891505
15 2176.0 2176.0 2176.0 73.443973 79.855747
16 2304.0 2304.0 2304.0 81.992053 82.005246
17 2432.0 2432.0 2432.0 78.040177 82.320564
18 2560.0 2560.0 2560.0 86.459102 85.049558
19 2688.0 2688.0 2688.0 72.116078 85.682066
20 2816.0 2816.0 2816.0 79.298560 88.697466
21 2944.0 2944.0 2944.0 86.563154 85.335668
22 3072.0 3072.0 3072.0 83.357378 86.315709
23 3200.0 3200.0 3200.0 82.419461 89.887639
24 3328.0 3328.0 3328.0 89.098141 87.794262
25 3456.0 3456.0 3456.0 80.483233 88.692595
26 3584.0 3584.0 3584.0 86.540320 87.381330
27 3712.0 3712.0 3712.0 79.926869 86.587255
28 3840.0 3840.0 3840.0 85.465227 90.279183
29 3968.0 3968.0 3968.0 80.211296 89.636975
30 4096.0 4096.0 4096.0 85.434583 89.197881
matmul-performance-fp8:
M N K cuBLAS (TFLOPS) Triton (TFLOPS)
0 256.0 256.0 256.0 5.461333 5.461333
1 384.0 384.0 384.0 14.327709 13.824000
2 512.0 512.0 512.0 32.768000 26.214401
3 640.0 640.0 640.0 51.200001 36.571428
4 768.0 768.0 768.0 64.491007 44.236801
5 896.0 896.0 896.0 79.290468 58.538665
6 1024.0 1024.0 1024.0 114.422619 65.536000
7 1152.0 1152.0 1152.0 106.642284 71.094858
8 1280.0 1280.0 1280.0 136.533337 75.851852
9 1408.0 1408.0 1408.0 129.804192 80.173175
10 1536.0 1536.0 1536.0 131.072000 76.933564
11 1664.0 1664.0 1664.0 155.153651 83.323259
12 1792.0 1792.0 1792.0 156.103106 81.611433
13 1920.0 1920.0 1920.0 157.090908 86.400002
14 2048.0 2048.0 2048.0 158.275623 86.480498
15 2176.0 2176.0 2176.0 160.068783 86.739860
16 2304.0 2304.0 2304.0 163.615555 87.182017
17 2432.0 2432.0 2432.0 167.228954 88.347369
18 2560.0 2560.0 2560.0 170.666661 89.530056
19 2688.0 2688.0 2688.0 163.504548 88.628636
20 2816.0 2816.0 2816.0 178.769648 90.538740
21 2944.0 2944.0 2944.0 172.872962 88.992912
22 3072.0 3072.0 3072.0 167.663488 90.887803
23 3200.0 3200.0 3200.0 181.818180 91.446943
24 3328.0 3328.0 3328.0 168.993647 90.587138
25 3456.0 3456.0 3456.0 174.884100 91.734322
26 3584.0 3584.0 3584.0 172.914215 91.163187
27 3712.0 3712.0 3712.0 171.644919 90.101605
28 3840.0 3840.0 3840.0 183.184640 91.832370
29 3968.0 3968.0 3968.0 182.125283 91.648265
30 4096.0 4096.0 4096.0 181.375311 91.180520
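The cuBLAS and Triton columns in the tables above are throughput figures; the matmul tutorial derives them as TFLOPS from the measured runtime. A minimal sketch of that conversion (the 1.5 ms runtime below is a made-up illustration, not a measured value):

```python
def matmul_tflops(M, N, K, ms):
    # An (M x K) @ (K x N) matmul performs 2*M*N*K floating-point operations;
    # convert a runtime measured in milliseconds to TFLOPS.
    return 2 * M * N * K * 1e-12 / (ms * 1e-3)

# A hypothetical 4096^3 matmul finishing in 1.5 ms:
print(f"{matmul_tflops(4096, 4096, 4096, 1.5):.1f} TFLOPS")
```

At 1.5 ms this comes out to roughly 91.6 TFLOPS, in the same range as the large-shape Triton numbers in the fp8 table above.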
triton_output_with_fp16_inputs=tensor([[-17.7500, 12.5938, -28.2500, ..., -23.9531, 2.9141, -9.3359],
[ 5.3750, -24.9844, 7.1016, ..., 2.7383, -42.6562, 1.9766],
[-22.8906, -5.9766, -8.2031, ..., -0.2485, -41.5312, 19.0938],
...,
[ 5.4648, -2.0977, 18.4531, ..., -36.9688, -7.6680, -20.1719],
[ -9.2031, -12.2812, -20.5312, ..., -24.5625, -50.9062, -3.6387],
[ 44.2188, -7.1328, -28.3750, ..., 4.6914, 7.9648, -8.6641]],
device='cuda:0', dtype=torch.float16)
torch_output_with_fp16_inputs=tensor([[-17.7500, 12.5938, -28.2500, ..., -23.9531, 2.9141, -9.3359],
[ 5.3750, -24.9844, 7.1016, ..., 2.7383, -42.6562, 1.9766],
[-22.8906, -5.9766, -8.2031, ..., -0.2485, -41.5312, 19.0938],
...,
[ 5.4648, -2.0977, 18.4531, ..., -36.9688, -7.6680, -20.1719],
[ -9.2031, -12.2812, -20.5312, ..., -24.5625, -50.9062, -3.6387],
[ 44.2188, -7.1328, -28.3750, ..., 4.6914, 7.9648, -8.6641]],
device='cuda:0', dtype=torch.float16)
✅ Triton and Torch match
triton_output_with_fp8_inputs=tensor([[-21.4375, 13.1719, 6.0352, ..., -14.4375, -4.6836, 29.9688],
[ 10.0000, 37.0000, -5.5664, ..., -25.2344, 4.5859, -44.3125],
[ 19.5625, -3.0078, -20.0469, ..., -43.5312, 0.3984, -25.5156],
...,
[-46.9375, 14.7734, -18.9062, ..., 24.7188, -7.2773, 5.3867],
[ -0.9683, 15.9688, -2.1523, ..., -21.9688, 1.3428, 33.5625],
[ -2.8809, -16.5000, 4.0898, ..., 0.8066, -6.4258, 3.4766]],
device='cuda:0', dtype=torch.float16)
torch_output_with_fp8_inputs=tensor([[-21.4375, 13.1719, 6.0352, ..., -14.4375, -4.6836, 29.9688],
[ 10.0000, 37.0000, -5.5664, ..., -25.2344, 4.5859, -44.3125],
[ 19.5625, -3.0078, -20.0469, ..., -43.5312, 0.3984, -25.5156],
...,
[-46.9375, 14.7734, -18.9062, ..., 24.7188, -7.2773, 5.3867],
[ -0.9683, 15.9688, -2.1523, ..., -21.9688, 1.3428, 33.5625],
[ -2.8809, -16.5000, 4.0898, ..., 0.8066, -6.4258, 3.4766]],
device='cuda:0', dtype=torch.float16)
✅ Triton and Torch match
matmul-performance-fp16:
M N K cuBLAS (TFLOPS) Triton (TFLOPS)
0 256.0 256.0 256.0 5.461333 5.216796
1 384.0 384.0 384.0 11.059200 13.824000
2 512.0 512.0 512.0 18.724571 26.214401
3 640.0 640.0 640.0 25.600001 36.571428
4 768.0 768.0 768.0 27.648000 40.215272
5 896.0 896.0 896.0 46.830935 57.860614
6 1024.0 1024.0 1024.0 65.536000 64.965022
7 1152.0 1152.0 1152.0 46.656000 64.912694
8 1280.0 1280.0 1280.0 58.514284 73.142856
9 1408.0 1408.0 1408.0 71.733898 69.894567
10 1536.0 1536.0 1536.0 58.015477 70.778882
11 1664.0 1664.0 1664.0 69.222400 80.347427
12 1792.0 1792.0 1792.0 80.281598 79.150871
13 1920.0 1920.0 1920.0 71.257735 78.545454
14 2048.0 2048.0 2048.0 81.442797 79.891505
15 2176.0 2176.0 2176.0 72.986143 79.855747
16 2304.0 2304.0 2304.0 81.807778 81.807778
17 2432.0 2432.0 2432.0 77.952214 82.147552
18 2560.0 2560.0 2560.0 86.231576 84.891192
19 2688.0 2688.0 2688.0 72.116078 85.434810
20 2816.0 2816.0 2816.0 79.011245 88.646760
21 2944.0 2944.0 2944.0 86.520886 85.044426
22 3072.0 3072.0 3072.0 83.146995 85.877975
23 3200.0 3200.0 3200.0 82.051282 89.385477
24 3328.0 3328.0 3328.0 88.659230 87.368079
25 3456.0 3456.0 3456.0 80.061141 88.304015
26 3584.0 3584.0 3584.0 86.043434 86.982457
27 3712.0 3712.0 3712.0 79.472826 86.192706
28 3840.0 3840.0 3840.0 85.005380 89.859682
29 3968.0 3968.0 3968.0 79.702113 89.264032
30 4096.0 4096.0 4096.0 84.909301 88.902474
matmul-performance-fp8:
M N K Triton (TFLOPS)
0 256.0 256.0 256.0 5.461333
1 384.0 384.0 384.0 13.824000
2 512.0 512.0 512.0 26.214401
3 640.0 640.0 640.0 36.571428
4 768.0 768.0 768.0 44.236801
5 896.0 896.0 896.0 58.538665
6 1024.0 1024.0 1024.0 65.536000
7 1152.0 1152.0 1152.0 71.094858
8 1280.0 1280.0 1280.0 75.851852
9 1408.0 1408.0 1408.0 80.173175
10 1536.0 1536.0 1536.0 76.933564
11 1664.0 1664.0 1664.0 83.323259
12 1792.0 1792.0 1792.0 81.445100
13 1920.0 1920.0 1920.0 86.400002
14 2048.0 2048.0 2048.0 86.341416
15 2176.0 2176.0 2176.0 86.634836
16 2304.0 2304.0 2304.0 87.182017
17 2432.0 2432.0 2432.0 88.347369
18 2560.0 2560.0 2560.0 89.530056
19 2688.0 2688.0 2688.0 88.628636
20 2816.0 2816.0 2816.0 90.485908
21 2944.0 2944.0 2944.0 88.720610
22 3072.0 3072.0 3072.0 90.742157
23 3200.0 3200.0 3200.0 91.334794
24 3328.0 3328.0 3328.0 90.214654
25 3456.0 3456.0 3456.0 91.407671
26 3584.0 3584.0 3584.0 90.889618
27 3712.0 3712.0 3712.0 89.835744
28 3840.0 3840.0 3840.0 91.473945
29 3968.0 3968.0 3968.0 91.403695
30 4096.0 4096.0 4096.0 90.933423
TMA benchmarks will be running with experimental grid constant TMA descriptor.
fused-attention-batch4-head32-d64-fwd-causal=True:
N_CTX Triton [FP16] (TFLOPS) Triton [FP8] (TFLOPS)
0 1024.0 55.643018 53.744011
1 2048.0 65.710240 60.274943
2 4096.0 71.499191 64.313866
3 8192.0 74.618951 66.866475
4 16384.0 76.327836 68.260768
fused-attention-batch4-head32-d64-fwd-causal=False:
N_CTX Triton [FP16] (TFLOPS) Triton [FP8] (TFLOPS)
0 1024.0 70.472140 58.175808
1 2048.0 72.946867 60.506134
2 4096.0 75.205399 68.400012
3 8192.0 75.743638 68.771560
4 16384.0 76.213386 69.202232
fused-attention-batch4-head32-d64-bwd-causal=True:
N_CTX Triton [FP16] (TFLOPS) Triton [FP8] (TFLOPS)
0 1024.0 30.488678 30.431360
1 2048.0 37.188401 37.177122
2 4096.0 41.899225 41.887539
3 8192.0 44.713159 44.716139
4 16384.0 46.197745 46.199733
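The fused-attention numbers are also throughput in TFLOPS. As best I recall, the tutorial's FLOP count treats each attention call as two large matmuls per head (Q @ K^T and P @ V), halves the count for causal masking, and scales the backward pass by 2.5; treat these exact factors as assumptions. A sketch (the 115 ms runtime is a made-up illustration):

```python
def attention_tflops(batch, heads, n_ctx, head_dim, ms, causal=False, bwd=False):
    # Two big matmuls per attention call (Q @ K^T and P @ V),
    # each costing 2 * batch * heads * n_ctx^2 * head_dim flops.
    flops_per_matmul = 2.0 * batch * heads * n_ctx * n_ctx * head_dim
    total_flops = 2 * flops_per_matmul
    if causal:
        total_flops *= 0.5   # only the lower triangle of scores is computed
    if bwd:
        total_flops *= 2.5   # backward does roughly 2.5x the forward work
    return total_flops * 1e-12 / (ms * 1e-3)

# Hypothetical: batch=4, heads=32, head_dim=64, N_CTX=16384, 115 ms forward
print(f"{attention_tflops(4, 32, 16384, 64, 115.0):.1f} TFLOPS")
```

With those assumed factors, a 115 ms non-causal forward pass at this shape works out to about 76.5 TFLOPS, consistent with the N_CTX=16384 row above.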

Perf numbers for an NVIDIA 50-series GPU running the Triton tutorials (matmul and fused attention).
