run_triton.py
import torch
import triton
import triton.language as tl

@triton.jit
def add_kernel(x_ptr, y_ptr, output_ptr, N,
               BLOCK_SIZE: tl.constexpr):
    # Each program instance handles one contiguous block of BLOCK_SIZE elements.
    pid = tl.program_id(0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    # Guard against out-of-bounds accesses when N is not a multiple of BLOCK_SIZE.
    mask = offsets < N
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    output = x + y
    tl.store(output_ptr + offsets, output, mask=mask)
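
# For intuition, each kernel instance above computes one block of the result,
# roughly equivalent to this slice-level sketch (illustrative only, not part
# of the original gist; `pid` and `BLOCK_SIZE` as in the kernel):
#   lo, hi = pid * BLOCK_SIZE, min((pid + 1) * BLOCK_SIZE, N)
#   output[lo:hi] = x[lo:hi] + y[lo:hi]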

def add(x: torch.Tensor, y: torch.Tensor):
    # We need to preallocate the output.
    output = torch.empty_like(x)
    assert x.is_cuda and y.is_cuda and output.is_cuda
    n_elements = output.numel()
    # The SPMD launch grid denotes the number of kernel instances that run in
    # parallel. It is analogous to CUDA launch grids. It can be either
    # Tuple[int] or Callable(metaparameters) -> Tuple[int]. In this case, we
    # use a 1D grid where the size is the number of blocks:
    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), )
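    # For reference, a hard-coded grid would be equivalent here, since the
    # launch below always uses BLOCK_SIZE=1024 (a sketch, not in the original):
    #   grid = (triton.cdiv(n_elements, 1024), )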
    # NOTE:
    # - Each torch.tensor object is implicitly converted into a pointer to its first element.
    # - `triton.jit`'ed functions can be indexed with a launch grid to obtain a callable GPU kernel.
    # - Don't forget to pass meta-parameters as keyword arguments.
    compiled = add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)
    # Dump the intermediate representations produced during compilation.
    print("IR", compiled.asm['ttir'])
    print("TTGIR", compiled.asm['ttgir'])
    print("PTX", compiled.asm['ptx'])
    # We return a handle to `output` but, since `torch.cuda.synchronize()`
    # hasn't been called, the kernel may still be running asynchronously at
    # this point.
    return output

torch.manual_seed(0)
size = 98432
x = torch.rand(size, device='cuda')
y = torch.rand(size, device='cuda')
output_torch = x + y
output_triton = add(x, y)
print(output_torch)
print(output_triton)
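
# A simple correctness check (a sketch, not in the original gist): both paths
# perform the same elementwise float32 addition, so the results should match
# exactly.
print("Max difference:",
      torch.max(torch.abs(output_torch - output_triton)).item())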

@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=['size'],  # Argument names to use as an x-axis for the plot.
        x_vals=[2**i for i in range(12, 28, 1)],  # Different possible values for `x_names`.
        x_log=True,  # x axis is logarithmic.
        line_arg='provider',  # Argument name whose value corresponds to a different line in the plot.
        line_vals=['triton', 'torch'],  # Possible values for `line_arg`.
        line_names=['Triton', 'Torch'],  # Label name for the lines.
        styles=[('blue', '-'), ('green', '-')],  # Line styles.
        ylabel='GB/s',  # Label name for the y-axis.
        plot_name='vector-add-performance',  # Name for the plot. Also used as a file name when saving it.
        args={},  # Values for function arguments not in `x_names`.
    ))
def benchmark(size, provider):
    x = torch.rand(size, device='cuda', dtype=torch.float32)
    y = torch.rand(size, device='cuda', dtype=torch.float32)
    quantiles = [0.5, 0.2, 0.8]
    if provider == 'torch':
        ms, min_ms, max_ms = triton.testing.do_bench(lambda: x + y, quantiles=quantiles)
    elif provider == 'triton':
        ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y), quantiles=quantiles)
    # Each element moves 12 bytes: two float32 reads plus one float32 write.
    # GB/s = (12 * size bytes) / (ms * 1e-3 s) / 1e9 = 12 * size / ms * 1e-6.
    gbps = lambda ms: 12 * size / ms * 1e-6
    return gbps(ms), gbps(max_ms), gbps(min_ms)


benchmark.run(print_data=True, show_plots=True)

# Inspect the JIT compilation cache and dump the PTX of the first cached
# kernel to a file. Note: `add_kernel.cache` is an internal Triton attribute
# (keyed by device index in the version this gist targets) and may change
# across releases.
print(dir(add_kernel.cache))
with open("add_kernel.ptx", "w") as a:
    print(list(add_kernel.cache[0].values())[0].asm['ptx'], file=a)