run_matmul_gtx1060.py
import torch
import triton
import triton.language as tl


@triton.jit
def matmul_kernel(
    a_ptr, b_ptr, c_ptr,
    stride_am, stride_ak,
    stride_bk, stride_bn,
    stride_cm, stride_cn,
    M: tl.constexpr, N: tl.constexpr, K: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
):
    # Map the 1D program id to a 2D tile of C; without this, every program
    # in the grid would compute and overwrite the same top-left tile.
    pid = tl.program_id(axis=0)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    pid_m = pid // num_pid_n
    pid_n = pid % num_pid_n
    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    # Pointers to the first (BLOCK_SIZE_M x BLOCK_SIZE_K) slice of A and the
    # first (BLOCK_SIZE_K x BLOCK_SIZE_N) slice of B.
    a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
    b_ptrs = b_ptr + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    # Walk along K, accumulating one K-slice per iteration. Loads are
    # unmasked, so M, N, K must be multiples of the block sizes.
    for k in range(0, K, BLOCK_SIZE_K):
        a = tl.load(a_ptrs)
        b = tl.load(b_ptrs)
        accumulator += tl.dot(a, b)
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk
    c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
    tl.store(c_ptrs, accumulator)
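# Boundary-mask sketch (an assumption, only needed when M, N, K are not
# multiples of the block sizes; the 1024x1024 run below divides evenly).
# Inside the loop the loads would become
#     a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k, other=0.0)
#     b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k, other=0.0)
# and the final store would take
#     mask=(offs_m[:, None] < M) & (offs_n[None, :] < N)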
X = torch.normal(0, 1, size=(1024, 1024), device='cuda')
# Give Y real values too; torch.empty_like(X) would leave it uninitialized.
Y = torch.randn_like(X)
def matmul(a, b, activation=""):
    # Check constraints. `activation` is accepted but unused here.
    assert a.shape[1] == b.shape[0], f"Incompatible dimensions {a.shape[1]} != {b.shape[0]}"
    assert a.is_contiguous(), "Matrix A must be contiguous"
    assert b.is_contiguous(), "Matrix B must be contiguous"
    M, K = a.shape
    K, N = b.shape
    # Allocate output.
    c = torch.empty((M, N), device=a.device, dtype=a.dtype)
    # 1D launch grid: one program per (BLOCK_SIZE_M x BLOCK_SIZE_N) tile of C.
    grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )
    # Block size for all three dimensions; tl.dot needs each to be >= 16.
    t = 16
    compiled = matmul_kernel[grid](
        a, b, c,
        a.stride(0), a.stride(1),
        b.stride(0), b.stride(1),
        c.stride(0), c.stride(1),
        M, N, K,
        t, t, t,
    )
    # Dump the compiler's intermediate representations: Triton IR,
    # Triton-GPU IR, and LLVM IR.
    print("IR", compiled.asm['ttir'])
    print("TTGIR", compiled.asm['ttgir'])
    print("LLIR", compiled.asm['llir'])
    return c
Z = matmul(X, Y)
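# Sanity check (a sketch): the kernel should agree with torch.matmul up to
# float32 accumulation-order differences, hence the loose tolerances.
assert torch.allclose(Z, torch.matmul(X, Y), rtol=1e-3, atol=1e-3), "Triton and torch results differ"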
# Inspect the per-device kernel cache, then dump the PTX of the first
# cached compilation to a file.
print(dir(matmul_kernel.cache))
with open("matmul_kernel.ptx", "w") as f:
    print(list(matmul_kernel.cache[0].values())[0].asm['ptx'], file=f)
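# Optionally dump the compiled binary as well. A sketch, assuming this Triton
# build also exposes a 'cubin' entry (raw bytes) in the same asm dict as
# 'ttir', 'ttgir', 'llir', and 'ptx':
with open("matmul_kernel.cubin", "wb") as f:
    f.write(list(matmul_kernel.cache[0].values())[0].asm['cubin'])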
#
# @triton.testing.perf_report(
#     triton.testing.Benchmark(
#         x_names=['size'],  # Argument names to use as an x-axis for the plot.
#         x_vals=[2**i for i in range(4, 10, 1)],  # Different possible values for `x_name`.
#         x_log=True,  # x axis is logarithmic.
#         line_arg='provider',  # Argument name whose value corresponds to a different line in the plot.
#         line_vals=['triton', 'torch'],  # Possible values for `line_arg`.
#         line_names=['Triton', 'Torch'],  # Label name for the lines.
#         styles=[('blue', '-'), ('green', '-')],  # Line styles.
#         ylabel='TFLOPS',  # Label name for the y-axis.
#         plot_name='mat-mul-performance',  # Name for the plot. Used also as a file name for saving the plot.
#         args={},  # Values for function arguments not in `x_names` and `y_name`.
#     ))
# def benchmark(size, provider):
#     x = torch.rand((size, size), device='cuda', dtype=torch.float32)
#     y = torch.rand((size, size), device='cuda', dtype=torch.float32)
#     quantiles = [0.5, 0.2, 0.8]
#     if provider == 'torch':
#         ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(x, y), quantiles=quantiles)
#     if provider == 'triton':
#         ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(x, y), quantiles=quantiles)
#     # A square matmul does 2 * size**3 FLOPs; report TFLOPS.
#     perf = lambda ms: 2 * size**3 * 1e-12 / (ms * 1e-3)
#     return perf(ms), perf(max_ms), perf(min_ms)
#
# benchmark.run(print_data=True, show_plots=True)