Created
January 24, 2025 17:45
-
-
Save davidberard98/a17b200a7f4a5cb5963512c356b9fec6 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import triton | |
import triton.language as tl | |
@triton.jit
def kernel(in_ptr, out_ptr, BLOCK_SIZE: tl.constexpr, STRING_CONSTEXPR: tl.constexpr):
    """Apply ``sin`` or ``cos`` elementwise to one block of ``BLOCK_SIZE`` values.

    The operation is selected at compile time. ``sin`` is used when the
    ``STRING_CONSTEXPR`` constant equals ``"sin"``; ``cos`` is used for any
    other value.

    There is no bounds mask. Exactly ``BLOCK_SIZE`` consecutive elements are
    read from ``in_ptr`` and written to ``out_ptr``, both starting at offset 0.
    Callers must therefore size the block to the data. Triton's ``tl.arange``
    also requires ``BLOCK_SIZE`` to be a power of two.
    """
    idx = tl.arange(0, BLOCK_SIZE)
    vals = tl.load(in_ptr + idx)
    # Compile-time branch: only one side is generated for a given constexpr.
    if STRING_CONSTEXPR == "sin":
        result = tl.sin(vals)
    else:
        result = tl.cos(vals)
    tl.store(out_ptr + idx, result)


def run_kernel(x):
    """Launch ``kernel`` once over every element of ``x`` and return the output.

    The output buffer is allocated with ``torch.empty_like(x)``. The launch
    uses a single program, with ``BLOCK_SIZE`` set to ``x.numel()``.

    The string constexpr deliberately contains an embedded double quote and a
    backslash. It is not ``"sin"``, so the kernel takes its ``cos`` branch.
    NOTE(review): this presumably exercises string escaping somewhere
    downstream (e.g. compile caching or the profiler trace); confirm.

    Assumes ``x`` is contiguous and that ``x.numel()`` is a power of two,
    because the kernel has no mask and uses ``tl.arange``. TODO: confirm
    before using this with other inputs.
    """
    out = torch.empty_like(x)
    grid = (1,)
    mode = "STRING_HAS_\"QUOTE_AND\\BACKSLASH"
    kernel[grid](x, out, x.numel(), mode)
    return out


def fn(x):
    """Function handed to ``torch.compile``; it simply forwards to ``run_kernel``."""
    out = run_kernel(x)
    return out


# 128 elements: a power of two, matching the single unmasked kernel block.
x = torch.randn(128, device="cuda")
fn_c = torch.compile(fn)

# Warm-up call. torch.compile compiles on first invocation, so this keeps
# compilation out of the profiled region below.
fn_c(x)

# record_shapes=True makes the profiler record operator input information.
# NOTE(review): presumably this is how the quote/backslash string reaches the
# exported JSON trace; confirm.
with torch.profiler.profile(
    activities=torch.profiler.supported_activities(),
    record_shapes=True,
) as prof:
    for _ in range(3):
        fn_c(x)

prof.export_chrome_trace("udtk.json")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment