Skip to content

Instantly share code, notes, and snippets.

@davidberard98
Last active May 1, 2025 23:39
Show Gist options
Save davidberard98/7fd6af7e6291787c246c705945a25554 to your computer and use it in GitHub Desktop.
import torch
import triton
def fn(x, y, buckets):
    """Scale ``y`` by the bucket index of every element of ``x``.

    Args:
        x: Values to locate among the boundaries (presumably 1-D, as in the
            benchmark inputs — confirm for other callers).
        y: Vector combined with the bucket indices by broadcasting.
        buckets: Sorted 1-D boundaries, as required by ``torch.bucketize``.

    Returns:
        Broadcast product whose entry ``[i, j]`` is
        ``torch.bucketize(x, buckets)[i] * y[j]`` (an outer product when
        ``x`` and ``y`` are 1-D).
    """
    bucket_idx = torch.bucketize(x, buckets)
    # Column of indices times a row of y; broadcasting produces the outer product.
    return bucket_idx.unsqueeze(1) * y.unsqueeze(0)
# 1024 random segment lengths in [0, 512). The last segment is enlarged so the
# total is rounded up to a multiple of 16.
lengths = torch.randint(512, (1024,), device="cuda", dtype=torch.int32)
total = lengths.sum().item()
new_total = ((total + 15) // 16) * 16
lengths[-1] += new_total - total
assert lengths.sum().item() % 16 == 0

# Segment offsets: buckets[0] == 0 and buckets[i] == sum(lengths[:i]).
# torch.bucketize requires sorted boundaries, so the leading entry must be an
# explicit 0. torch.empty would leave it as uninitialized memory, because
# cumsum only writes buckets[1:].
buckets = torch.zeros(1025, device="cuda", dtype=torch.int32)
torch.cumsum(lengths, dim=0, out=buckets[1:])

# Query points spread over [0, new_total), plus the vector that fn multiplies
# by the bucket indices.
x = torch.rand(2**15, device="cuda") * new_total
y = torch.randn(384, device="cuda")
# Compile fn and call it once, so compilation happens before any timing.
compile_fn = torch.compile(fn)
compile_fn(x, y, buckets)

# Time eager first, then compiled, on identical inputs. do_bench calls each
# lambda within the same iteration, so capturing `impl` is safe.
eager_ms, compile_ms = [
    triton.testing.do_bench(lambda: impl(x, y, buckets))
    for impl in (fn, compile_fn)
]

print(f" eager ms: {eager_ms}")
print(f"torch.compile ms: {compile_ms}")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment