# Gist by @davidberard98, created December 13, 2024 01:12
# Results:
#
# Vertical indices ms: 2.8862898349761963
# Horizontal indices ms: 0.3734990060329437
import torch
import triton
import triton.language as tl
BLOCK_SIZE = 64
@triton.jit
def kernel(
    indices_ptr,
    dest_ptr,
    stride_im,
    stride_in,
    M,
    N,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    idx_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    idx_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    mask = (idx_m[:, None] < M) & (idx_n[None, :] < N)
    # Load a BLOCK_M x BLOCK_N tile of destination indices, then atomically
    # accumulate the index values into dest at those positions.
    indices = tl.load(indices_ptr + idx_m[:, None] * stride_im + idx_n[None, :] * stride_in, mask=mask)
    # data = tl.full([BLOCK_M, BLOCK_N], 1.0, tl.float32)
    tl.atomic_add(dest_ptr + indices, indices, mask=mask)
def run(indices, dest):
    M, N = indices.size()
    BLOCK_M, BLOCK_N = BLOCK_SIZE, BLOCK_SIZE
    grid = (triton.cdiv(M, BLOCK_M), triton.cdiv(N, BLOCK_N), 1)
    kernel[grid](indices, dest, *indices.stride(), M, N, BLOCK_M, BLOCK_N)
def vertical():
    # Indices (every row is identical):
    # 0 1 2 3 .. N-1
    # 0 1 2 3 .. N-1
    # 0 1 2 3 .. N-1
    # .. .. .. .. ..
    # Shape: (BLOCK_SIZE x N) with N = 32768, so each column holds the same
    # destination index repeated down all BLOCK_SIZE rows.
    N = 1024 * 32
    indices = torch.arange(0, N, device="cuda", dtype=torch.int32).unsqueeze(0).expand(BLOCK_SIZE, N).contiguous()
    offsets = torch.zeros(N, device="cuda")
    return indices, offsets
def horizontal():
    # Indices (illustrated for a small 32 x 1024 case; the actual tensor is
    # BLOCK_SIZE x 32768 with the same structure):
    #  0  0  0  0 ...  0 32 32 32 ... 992
    #  1  1  1  1 ...  1 33 33 33 ... 993
    # .. .. .. .. ... .. .. .. .. ...  ..
    # 31 31 31 31 ... 31 63 63 63 ... 1023
    # Each destination index is repeated in a contiguous run of BLOCK_SIZE
    # elements along a row, so neighboring columns share the same destination.
    N = 1024 * 32
    indices = (
        torch.arange(0, N, device="cuda", dtype=torch.int32)
        .reshape(BLOCK_SIZE, N // BLOCK_SIZE)
        .unsqueeze(2)
        .expand(BLOCK_SIZE, N // BLOCK_SIZE, BLOCK_SIZE)
        .permute(1, 0, 2)
        .reshape(BLOCK_SIZE, N)
    )
    offsets = torch.zeros(N, device="cuda")
    return indices, offsets
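# The repetition patterns sketched in the comments above can be eyeballed with
# a small helper like the following (a hypothetical addition, not part of the
# original benchmark); it just prints a corner of each index tensor.
def show_layouts():
    v, _ = vertical()
    h, _ = horizontal()
    print("vertical[:3, :8]:\n", v[:3, :8].cpu())
    print("horizontal[:3, :8]:\n", h[:3, :8].cpu())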
def get_perf(indices, dest):
    fn = lambda: run(indices, dest)
    return triton.testing.do_bench(fn, return_mode="mean")
print("Vertical indices ms: ", get_perf(*vertical()))
print("Horizontal indices ms: ", get_perf(*horizontal()))