Skip to content

Instantly share code, notes, and snippets.

@davidberard98
Created January 24, 2025 17:45
Show Gist options
  • Save davidberard98/a17b200a7f4a5cb5963512c356b9fec6 to your computer and use it in GitHub Desktop.
Save davidberard98/a17b200a7f4a5cb5963512c356b9fec6 to your computer and use it in GitHub Desktop.
import torch
import triton
import triton.language as tl
@triton.jit
def kernel(in_ptr, out_ptr, BLOCK_SIZE: tl.constexpr, STRING_CONSTEXPR: tl.constexpr):
    """Apply an elementwise sin or cos to a single block of elements.

    Loads BLOCK_SIZE elements at offsets 0..BLOCK_SIZE-1 from in_ptr,
    applies sin when STRING_CONSTEXPR equals "sin" and cos for any other
    value, then stores the result at the same offsets from out_ptr.

    The loads and stores are unmasked and no program id is used, so every
    launched program touches the same BLOCK_SIZE elements.
    """
    idx = tl.arange(0, BLOCK_SIZE)
    values = tl.load(in_ptr + idx)
    # Both parameters are constexpr, so this comparison is on a
    # compile-time string value.
    if STRING_CONSTEXPR == "sin":
        result = values.sin()
    else:
        result = values.cos()
    tl.store(out_ptr + idx, result)
def run_kernel(x):
    """Launch `kernel` once over all of `x` and return the output tensor.

    The output is allocated with the same shape/dtype/device as `x`. The
    kernel is launched on a one-program grid with BLOCK_SIZE set to the
    total element count of `x`. The string argument is not "sin", so the
    kernel takes its cos branch.
    """
    out = torch.empty_like(x)
    launch_grid = (1,)
    # NOTE(review): BLOCK_SIZE feeds tl.arange, which per the Triton docs
    # needs a power-of-two extent -- confirm callers only pass such sizes.
    kernel[launch_grid](x, out, x.numel(), "STRING_HAS_\"QUOTE_AND\\BACKSLASH")
    return out
def fn(x):
    """Forward `x` to `run_kernel` and return its result."""
    result = run_kernel(x)
    return result
# Driver: run the compiled function under the profiler and dump a Chrome trace.
x = torch.randn(128, device="cuda")
fn_c = torch.compile(fn)

# One call before the profiler starts, so the first invocation of the
# compiled function happens outside the profiled region.
fn_c(x)

with torch.profiler.profile(
    activities=torch.profiler.supported_activities(),
    record_shapes=True,
) as prof:
    # Three profiled invocations of the compiled function.
    for _ in range(3):
        fn_c(x)

# breakpoint()
# Export after the profiler context has exited.
prof.export_chrome_trace("udtk.json")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment