# Gist by @shunting314, created November 25, 2025 (gist id 532ba2e60ddb96057ebd7a74b6ea3add)
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import torch
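
# Standalone reproducer for an Inductor-generated reduction kernel. Judging by
# the fused-op name and the kernel body, it implements temperature-scaled
# softmax followed by exponential-noise argmax sampling over a (256, 128256)
# bf16 logits tensor, autotuned for the 148-SM, CC 10.0 device recorded in
# triton_meta below.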
@triton_heuristics.reduction(
    size_hints={'x': 256, 'r0_': 131072},
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*i64', 'in_ptr1': '*bf16', 'in_ptr2': '*fp32', 'out_ptr0': '*fp32', 'out_ptr4': '*i32', 'load_seed_offset': 'i32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=148, cc=100, major=10, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'native_matmul': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]]}], 'enable_fp_fusion': True},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__softmax__to_copy_amax_argmax_div_exponential_ge_mul_neg_scalar_tensor_sub_unsqueeze_where_5_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 5, 'num_store': 2, 'num_reduction': 3, 'backend_hash': '0BE79B16E554042AA7B1EB4102B8EA61128454EAFD0C7CABEEF1703B1EAEF73E', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': True, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'are_deterministic_algorithms_enabled': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False, 'kernel_num_gb': 0.197003272, 'kernel_flop': 0}
)
@triton.jit
def triton_red_fused__softmax__to_copy_amax_argmax_div_exponential_ge_mul_neg_scalar_tensor_sub_unsqueeze_where_5_0(in_ptr0, in_ptr1, in_ptr2, out_ptr0, out_ptr4, load_seed_offset, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
    r0_numel = 128256
    rnumel = r0_numel
    RBLOCK: tl.constexpr = R0_BLOCK
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    r0_base = tl.arange(0, R0_BLOCK)[None, :]
    rbase = r0_base
    x0 = xindex
    # Pass 1: draw one uniform random per (row, vocab) element from the seed
    # and stage it in out_ptr0 for the sampling pass below.
    for r0_offset in tl.range(0, r0_numel, R0_BLOCK):
        r0_index = r0_offset + r0_base
        r0_mask = r0_index < r0_numel
        roffset = r0_offset
        rindex = r0_index
        r0_1 = r0_index
        tmp0 = tl.load(in_ptr0 + load_seed_offset)
        tmp1 = r0_1 + 128256*x0
        tmp2 = tl.rand(tmp0, (tmp1).to(tl.uint32))
        tl.store(out_ptr0 + (r0_1 + 128256*x0), tmp2, r0_mask & xmask)
    # Per-row temperature scalar; its sign decides whether the logits are flipped.
    tmp5 = tl.load(in_ptr2 + (x0), xmask, eviction_policy='evict_last')
    # Pass 2: running max of the sign-adjusted logits (numerically stable softmax).
    _tmp13 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32)
    for r0_offset in tl.range(0, r0_numel, R0_BLOCK):
        r0_index = r0_offset + r0_base
        r0_mask = r0_index < r0_numel
        roffset = r0_offset
        rindex = r0_index
        r0_1 = r0_index
        tmp3 = tl.load(in_ptr1 + (r0_1 + 128256*x0), r0_mask & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp4 = tmp3.to(tl.float32)
        tmp6 = 0.0
        tmp7 = tmp5 >= tmp6
        tmp8 = 1.0
        tmp9 = -1.0
        tmp10 = tl.where(tmp7, tmp8, tmp9)  # sign = +1 if temperature >= 0 else -1
        tmp11 = tmp4 * tmp10
        tmp12 = tl.broadcast_to(tmp11, [XBLOCK, R0_BLOCK])
        tmp14 = triton_helpers.maximum(_tmp13, tmp12)
        _tmp13 = tl.where(r0_mask & xmask, tmp14, _tmp13)
    tmp13 = triton_helpers.max2(_tmp13, 1)[:, None]
    # Pass 3: softmax denominator sum(exp((sign*logit - max) / (sign*temperature))).
    _tmp28 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
    for r0_offset in tl.range(0, r0_numel, R0_BLOCK):
        r0_index = r0_offset + r0_base
        r0_mask = r0_index < r0_numel
        roffset = r0_offset
        rindex = r0_index
        r0_1 = r0_index
        tmp15 = tl.load(in_ptr1 + (r0_1 + 128256*x0), r0_mask & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp16 = tmp15.to(tl.float32)
        tmp17 = 0.0
        tmp18 = tmp5 >= tmp17
        tmp19 = 1.0
        tmp20 = -1.0
        tmp21 = tl.where(tmp18, tmp19, tmp20)
        tmp22 = tmp16 * tmp21
        tmp23 = tmp22 - tmp13
        tmp24 = tmp21 * tmp5  # sign * temperature == abs(temperature)
        tmp25 = (tmp23 / tmp24)
        tmp26 = libdevice.exp(tmp25)
        tmp27 = tl.broadcast_to(tmp26, [XBLOCK, R0_BLOCK])
        tmp29 = _tmp28 + tmp27
        _tmp28 = tl.where(r0_mask & xmask, tmp29, _tmp28)
    tmp28 = tl.sum(_tmp28, 1)[:, None]
    # Pass 4: recompute each softmax probability, divide it by an Exp(1) noise
    # sample derived from the staged uniforms, and track the running argmax.
    _tmp52 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32)
    _tmp52_index = tl.full([XBLOCK, R0_BLOCK], 2147483647, tl.int32)
    for r0_offset in tl.range(0, r0_numel, R0_BLOCK):
        r0_index = r0_offset + r0_base
        r0_mask = r0_index < r0_numel
        roffset = r0_offset
        rindex = r0_index
        r0_1 = r0_index
        tmp30 = tl.load(in_ptr1 + (r0_1 + 128256*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp43 = tl.load(out_ptr0 + (r0_1 + 128256*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0)
        tmp31 = tmp30.to(tl.float32)
        tmp32 = 0.0
        tmp33 = tmp5 >= tmp32
        tmp34 = 1.0
        tmp35 = -1.0
        tmp36 = tl.where(tmp33, tmp34, tmp35)
        tmp37 = tmp31 * tmp36
        tmp38 = tmp37 - tmp13
        tmp39 = tmp36 * tmp5
        tmp40 = (tmp38 / tmp39)
        tmp41 = libdevice.exp(tmp40)
        tmp42 = (tmp41 / tmp28)  # softmax probability for this element
        tmp44 = 0.9999999403953552
        tmp45 = tmp43 >= tmp44  # guard log(u) as u approaches 1
        tmp46 = tl_math.log(tmp43)
        tmp47 = -5.960464477539063e-08
        tmp48 = tl.where(tmp45, tmp47, tmp46)
        tmp49 = tmp48 * tmp35  # -log(u): an Exp(1) sample
        tmp50 = (tmp42 / tmp49)
        tmp51 = tl.broadcast_to(tmp50, [XBLOCK, R0_BLOCK])
        _tmp52_next, _tmp52_index_next = triton_helpers.maximum_with_index(
            _tmp52, _tmp52_index, tmp51, rindex
        )
        _tmp52 = tl.where(r0_mask & xmask, _tmp52_next, _tmp52)
        _tmp52_index = tl.where(r0_mask & xmask, _tmp52_index_next, _tmp52_index)
    tmp52_val, tmp52_idx = triton_helpers.max_with_index(_tmp52, _tmp52_index, 1)
    tmp52 = tmp52_idx[:, None]
    tmp53 = tmp52.to(tl.int32)
    tl.store(out_ptr4 + (x0), tmp53, xmask)
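

# For readability, here is a minimal eager-mode PyTorch sketch of what the
# fused kernel above appears to compute per row. It is reconstructed from the
# kernel body, not taken from the traced model; `reference_sample` and its
# argument names are hypothetical.
def reference_sample(logits, temperature, generator=None):
    # sign is +1 where temperature >= 0 and -1 otherwise; scaling the logits
    # by sign and dividing by sign * temperature == |temperature| is
    # equivalent to softmax(logits / temperature) for any nonzero temperature.
    sign = torch.where(temperature >= 0, torch.ones_like(temperature),
                       -torch.ones_like(temperature)).unsqueeze(-1)
    probs = torch.softmax(logits.float() * sign / temperature.abs().unsqueeze(-1), dim=-1)
    # Exponential-race sampling: argmax(p_i / E_i) with i.i.d. E_i ~ Exp(1)
    # selects index i with probability p_i. The kernel clamps log(u) near
    # u == 1.0 so the Exp(1) sample never becomes exactly zero.
    u = torch.rand(probs.shape, device=probs.device, generator=generator)
    eps = 5.960464477539063e-08  # 2**-24, the same guard the kernel uses
    e = torch.where(u >= 1.0 - eps, torch.full_like(u, eps), -torch.log(u))
    return torch.argmax(probs / e, dim=-1).to(torch.int32)
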
def get_args():
    # Positional order matches the kernel signature: in_ptr0 (seed),
    # in_ptr1 (bf16 logits), in_ptr2 (fp32 per-row temperature), out_ptr0
    # (fp32 staged uniforms), out_ptr4 (int32 sampled indices),
    # load_seed_offset, xnumel=256, r0_numel=128256.
    arg_0 = rand_strided((1,), (1,), device='cuda:0', dtype=torch.int64)
    arg_1 = rand_strided((256, 128256), (128256, 1), device='cuda:0', dtype=torch.bfloat16)
    arg_2 = rand_strided((256,), (1,), device='cuda:0', dtype=torch.float32)
    arg_3 = rand_strided((256, 128256), (128256, 1), device='cuda:0', dtype=torch.float32)
    arg_4 = rand_strided((256,), (1,), device='cuda:0', dtype=torch.int32)
    arg_5 = 0
    return arg_0, arg_1, arg_2, arg_3, arg_4, arg_5, 256, 128256,


def call(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        stream0 = get_raw_stream(0)
        triton_red_fused__softmax__to_copy_amax_argmax_div_exponential_ge_mul_neg_scalar_tensor_sub_unsqueeze_where_5_0.run(*args, stream=stream0)


def benchmark_all_configs(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        return triton_red_fused__softmax__to_copy_amax_argmax_div_exponential_ge_mul_neg_scalar_tensor_sub_unsqueeze_where_5_0.benchmark_all_configs(*args)


if __name__ == '__main__':
    from torch._inductor.runtime.benchmarking import benchmarker
    args = get_args()
    ms = benchmarker.benchmark(lambda: call(args), device='cuda', rep=40)
    num_gb = 0.197003272
    gb_per_s = num_gb / (ms / 1e3)
    print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")