import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import torch
@triton_heuristics.pointwise(
    size_hints={'x': 67108864}, tile_hint=TileHint.DEFAULT,
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*i64', 'in_ptr2': '*bf16', 'out_ptr0': '*bf16', 'xnumel_0': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=148, cc=100, major=10, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'SequentialComboKernelGrid', 'combo_grid_meta': {'num_kernels': 1, 'min_blocks': None, 'default_config': None, 'no_x_dim_0': False, 'xnumel_0': None}, 'kernel_name': 'triton_poi_fused_1_3', 'mutated_arg_names': [], 'backend_hash': '0BE79B16E554042AA7B1EB4102B8EA61128454EAFD0C7CABEEF1703B1EAEF73E', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'are_deterministic_algorithms_enabled': False}
)
@triton.jit
def triton_poi_fused_1_3(in_ptr0, in_ptr1, in_ptr2, out_ptr0, xnumel_0, XBLOCK : tl.constexpr):
    pid = tl.program_id(0)
    num_xblocks_0 = tl.cdiv(xnumel_0, XBLOCK)
    if pid < num_xblocks_0:
        pid_offset = pid
        r0_numel = 1
        xoffset = pid_offset * XBLOCK
        xindex = xoffset + tl.arange(0, XBLOCK)[:]
        xmask = xindex < xnumel_0
        # Decompose the flat index over the (16384, 32, 128) output:
        # x0 = offset within the 128-wide head dim, x1 = head, x2 = token.
        x0 = (xindex % 128)
        x1 = ((xindex // 128) % 32)
        x2 = xindex // 4096
        x4 = xindex
        tmp0 = x0
        tmp1 = tl.full([1], 0, tl.int64)
        tmp2 = tmp0 >= tmp1
        tmp3 = tl.full([1], 64, tl.int64)
        tmp4 = tmp0 < tmp3
        # First half of the head dim (x0 < 64): x[:64] * t[:64] - x[64:] * t[64:],
        # where t is the row of in_ptr2 gathered via the index in in_ptr1.
        tmp5 = tl.load(in_ptr0 + (128*x1 + 6144*x2 + (x0)), tmp4 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp6 = tl.load(in_ptr1 + (x2), tmp4 & xmask, eviction_policy='evict_last', other=0.0)
        tmp7 = tl.full([XBLOCK], 8192, tl.int32)
        tmp8 = tmp6 + tmp7
        tmp9 = tmp6 < 0
        # Wrap negative indices into [0, 8192) before gathering from in_ptr2.
        tmp10 = tl.where(tmp9, tmp8, tmp6)
        tl.device_assert(((0 <= tl.broadcast_to(tmp10, [XBLOCK])) & (tl.broadcast_to(tmp10, [XBLOCK]) < 8192)) | ~(tmp4 & xmask), "index out of bounds: 0 <= tl.broadcast_to(tmp10, [XBLOCK]) < 8192")
        tmp12 = tl.load(in_ptr2 + (128*tmp10 + (x0)), tmp4 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp13 = tmp5 * tmp12
        tmp14 = tl.load(in_ptr0 + (64 + 128*x1 + 6144*x2 + (x0)), tmp4 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp15 = tl.load(in_ptr2 + (64 + 128*tmp10 + (x0)), tmp4 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp16 = tmp14 * tmp15
        tmp17 = tmp13 - tmp16
        tmp18 = tl.full(tmp17.shape, 0.0, tmp17.dtype)
        tmp19 = tl.where(tmp4, tmp17, tmp18)
        # Second half of the head dim (x0 >= 64): x[64:] * t[:64] + x[:64] * t[64:].
        tmp20 = tmp0 >= tmp3
        tmp21 = tl.full([1], 128, tl.int64)
        tmp22 = tmp0 < tmp21
        tmp23 = tl.load(in_ptr0 + (64 + 128*x1 + 6144*x2 + ((-64) + x0)), tmp20 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp24 = tl.load(in_ptr1 + (x2), tmp20 & xmask, eviction_policy='evict_last', other=0.0)
        tmp25 = tl.full([XBLOCK], 8192, tl.int32)
        tmp26 = tmp24 + tmp25
        tmp27 = tmp24 < 0
        tmp28 = tl.where(tmp27, tmp26, tmp24)
        tl.device_assert(((0 <= tl.broadcast_to(tmp28, [XBLOCK])) & (tl.broadcast_to(tmp28, [XBLOCK]) < 8192)) | ~(tmp20 & xmask), "index out of bounds: 0 <= tl.broadcast_to(tmp28, [XBLOCK]) < 8192")
        tmp30 = tl.load(in_ptr2 + (128*tmp28 + ((-64) + x0)), tmp20 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp31 = tmp23 * tmp30
        tmp32 = tl.load(in_ptr0 + (128*x1 + 6144*x2 + ((-64) + x0)), tmp20 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp33 = tl.load(in_ptr2 + (64 + 128*tmp28 + ((-64) + x0)), tmp20 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp34 = tmp32 * tmp33
        tmp35 = tmp31 + tmp34
        tmp36 = tl.full(tmp35.shape, 0.0, tmp35.dtype)
        tmp37 = tl.where(tmp20, tmp35, tmp36)
        # Select between the two halves and write the combined value.
        tmp38 = tl.where(tmp4, tmp19, tmp37)
        tl.store(out_ptr0 + (x4), tmp38, xmask)
    else:
        pass
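
# What the fused kernel above appears to compute, written out as a hedged PyTorch
# reference: for each of the 16384 tokens it gathers one row of the (8192, 128)
# bf16 table in_ptr2 (columns 0:64 read as "cos", columns 64:128 as "sin") using
# the int64 indices in in_ptr1, and applies a rotary-embedding-style rotation to
# the first 32 heads (128 dims each, i.e. columns 0:4096) of the (16384, 6144)
# activation in_ptr0. The function and argument names below are illustrative only;
# they are not part of the generated code.
def rope_reference_sketch(qkv, pos, cos_sin):
    # qkv: (16384, 6144) bf16, pos: (16384,) int64, cos_sin: (8192, 128) bf16
    x = qkv[:, :4096].float().reshape(16384, 32, 128)
    a, b = x[..., :64], x[..., 64:]                    # first / second half of each head
    row = cos_sin.float()[pos]                         # per-token gather; negative indices wrap, as in the kernel
    cos, sin = row[:, None, :64], row[:, None, 64:]    # broadcast over the 32 heads
    out = torch.cat([a * cos - b * sin, b * cos + a * sin], dim=-1)
    return out.to(torch.bfloat16)                      # (16384, 32, 128), matching out_ptr0
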
def get_args():
    arg_0 = rand_strided((16384, 6144), (6144, 1), device='cuda:0', dtype=torch.bfloat16)
    arg_1 = rand_strided((16384,), (1,), device='cuda:0', dtype=torch.int64)
    arg_2 = rand_strided((8192, 128), (128, 1), device='cuda:0', dtype=torch.bfloat16)
    arg_3 = rand_strided((16384, 32, 128), (4096, 128, 1), device='cuda:0', dtype=torch.bfloat16)
    # 67108864 = 16384 * 32 * 128, the number of output elements (xnumel_0)
    return arg_0, arg_1, arg_2, arg_3, 67108864,

def call(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        stream0 = get_raw_stream(0)
        triton_poi_fused_1_3.run(*args, stream=stream0)

def benchmark_all_configs(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        return triton_poi_fused_1_3.benchmark_all_configs(*args)

if __name__ == '__main__':
    from torch._inductor.runtime.benchmarking import benchmarker
    args = get_args()
    ms = benchmarker.benchmark(call, fn_args=(args,), device="cuda", rep=40)
    num_gb = 0
    gb_per_s = num_gb / (ms / 1e3)
    print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")
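
# Note: the harness hard-codes num_gb = 0, so the printed GB/s figure is always
# zero. A rough, hedged estimate of the bytes moved by this kernel (assuming each
# touched element of in_ptr0, each position index, one gathered cos/sin row per
# token, and each output element crosses DRAM exactly once) could be derived from
# the shapes in get_args(); the helper below is illustrative and not part of the
# generated harness.
def estimate_num_gb_sketch():
    qkv_read = 16384 * 32 * 128 * 2    # bf16 activations read (first 32 heads only)
    pos_read = 16384 * 8               # int64 position indices
    cache_read = 16384 * 128 * 2       # one gathered bf16 row of in_ptr2 per token
    out_write = 16384 * 32 * 128 * 2   # bf16 output
    return (qkv_read + pos_read + cache_read + out_write) / 1e9   # roughly 0.27 GB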