Created January 2, 2024 08:32
Demonstrating the 2x FLOPs available on gamer GPUs when FP16 accumulators are used. (On consumer GeForce cards, FP16 tensor-core math with FP32 accumulation runs at half the rate of FP16 accumulation, a cap the datacenter parts don't have.)
from pathlib import Path
savepath = Path('mm')
savepath.mkdir(exist_ok=True)
import torch
import triton
from triton.ops.matmul import matmul as triton_matmul
#matmul = lambda a,b: _matmul.forward(a,b, acc_dtype=torch.float16, allow_tf32=True, output_dtype=torch.float16) # nightly
matmul = lambda a,b: triton_matmul(a,b, torch.float16) # stable
torch.manual_seed(0)
for size in (512,4096):
    a,b = [torch.randn((size,size), device='cuda', dtype=torch.float16) for _ in 'ab']
    print(f"triton_output={matmul(a,b)}")
    print(f"torch_output={torch.matmul(a,b)}")
# stolen from https://github.com/openai/triton/blob/main/python/tutorials/03-matrix-multiplication.py
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=['M', 'N', 'K'],
        x_vals=[128 * i for i in range(2, 33)],
        line_arg='provider',
        line_vals=['cublas', 'triton'],
        line_names=["cuBLAS", "Triton"],
        styles=[('green', '-'), ('blue', '-')],
        ylabel="TFLOPS",
        plot_name="Matmul perf (3090)",  # NOTE: This will *not* add a graph plot header unless you edit triton/testing.py
        args={},
    ))
def benchmark(M, N, K, provider):
    a = torch.randn((M, K), device='cuda', dtype=torch.float16)
    b = torch.randn((K, N), device='cuda', dtype=torch.float16)
    quantiles = [0.5, 0.2, 0.8]
    f = torch.matmul if provider == 'cublas' else matmul  # pytorch defers to cuBLAS.
    ms, min_ms, max_ms = triton.testing.do_bench(lambda: f(a, b), quantiles=quantiles)
    # 2*M*N*K FLOPs per GEMM (one multiply and one add per MAC), reported as TFLOP/s
    perf = lambda ms: 2 * M * N * K * 1e-12 / (ms * 1e-3)
    t = perf(ms), perf(max_ms), perf(min_ms)
    print(provider, t)
    return t
benchmark.run(show_plots=True, print_data=True, save_path=str(savepath))
'''
To run this file, you should either:
* remove the `prune_configs_by` key in triton/ops/matmul.py (which prunes, early on, configs that are only useful at larger block sizes), or
* import the matmul impl from the matmul tutorial at https://github.com/openai/triton/blob/main/python/tutorials/03-matrix-multiplication.py instead.
'''
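For the second option, a minimal sketch (the local filename tutorial_matmul.py is an assumption, and the tutorial's kernel accumulates in fp32 by default, so you would still need to switch its accumulator dtype to fp16 yourself to reproduce the 2x):

# sketch of option 2: swap the tutorial kernel in for triton.ops.matmul.
# assumes the tutorial file was saved locally as tutorial_matmul.py and that
# its accumulator dtype was changed from fp32 to fp16 by hand.
from tutorial_matmul import matmul as triton_matmul
matmul = lambda a, b: triton_matmul(a, b)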
Author 152334H commented on Jan 2, 2024:
All tests were done on torch==2.1.1+cu121 and triton==2.1.0. The GPU used is a Zotac 3090 under ~no load (some minimal memory allocated to the window manager).
It is currently impractical to just directly import triton.ops.matmul into existing torch code, for a variety of reasons, including:
- the autotuner being extremely slow on launch
- breaking internal torch optimizations that fuse layers etc. (which presumably rely on the exact `@` op used)
- the autotuner recompiling (or something?) every time a differently shaped input is used. Particularly bad for fine-tuning without a global fixed padding length (see the padding sketch after this list).
- it will not work for anything other than fp16 (bf16 is a no-go)
- it will most likely be too imprecise for training out-of-the-box. Less certain about this part, considering fp8...
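On the recompilation point, a hedged mitigation sketch (not something the gist does; the bucket sizes and the pad_to_bucket helper are made up for illustration): pad variable-length inputs up to a small set of fixed buckets so a shape-specialised kernel only ever sees a handful of shapes.

import torch

# sketch: bucket sequence lengths so a shape-specialised matmul kernel is only
# ever compiled for a few distinct shapes (bucket sizes are arbitrary examples)
BUCKETS = (128, 256, 512, 1024, 2048)

def pad_to_bucket(x: torch.Tensor, dim: int = 1) -> torch.Tensor:
    n = x.shape[dim]
    target = next(b for b in BUCKETS if b >= n)  # assumes n <= max(BUCKETS)
    pad = target - n
    if pad == 0:
        return x
    # F.pad pads from the last dimension backwards; zero-pad only `dim`
    pad_spec = [0, 0] * (x.ndim - dim - 1) + [0, pad]
    return torch.nn.functional.pad(x, pad_spec)

# e.g. a (batch, seq, hidden) activation with seq=300 gets padded to seq=512,
# so the autotuner/compiler only ever has to handle len(BUCKETS) shapes.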