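# Gist by @shunting314, created 2025-11-25.
#
# Minified torch.compile/Inductor repro: the captured graph below looks like an
# LLM sampling step (temperature-scaled softmax over a 128256-wide vocab,
# followed by exponential-noise argmax), and the __main__ block benchmarks it
# with a dynamic batch dimension. s36 is the symbolic batch size the repro was
# recorded with (3 here); get_args() builds inputs for any batch size.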
s36 = 3
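# Environment knobs for the compile: profiler/trace output dirs, a fixed local
# Inductor/Triton cache dir, benchmark-harness emission and unique names for
# generated kernels, remote FX-graph caching disabled, single compile thread.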
import os
os.environ['VLLM_TORCH_PROFILER_DIR'] = '/tmp/myprofile'
os.environ['TORCH_TRACE'] = '/tmp/tlp'
os.environ['INDUCTOR_PROVENANCE'] = '1'
os.environ['TORCHINDUCTOR_CACHE_DIR'] = '/tmp/torchinductor_shunting/'
os.environ['TORCHINDUCTOR_BENCHMARK_KERNEL'] = '1'
os.environ['TORCH_LOGS_FORMAT'] = '%(levelname)s: %(message)s'
os.environ['TORCHINDUCTOR_FX_GRAPH_CACHE_DEFAULT'] = '1'
os.environ['TORCHINDUCTOR_UNIQUE_KERNEL_NAMES'] = '1'
os.environ['TORCHINDUCTOR_FX_GRAPH_REMOTE_CACHE'] = '0'
os.environ['INDUCTOR_TEST_DISABLE_FRESH_CACHE'] = '1'
os.environ['PYTORCH_NVML_BASED_CUDA_CHECK'] = '1'
os.environ['TORCHINDUCTOR_COMPILE_THREADS'] = '1'
os.environ['TRITON_CACHE_DIR'] = '/tmp/torchinductor_shunting/triton/0'
import torch
from torch import tensor, device
import torch.fx as fx
from torch._dynamo.testing import rand_strided
from math import inf
import torch._inductor.inductor_prims
import torch._dynamo.config
import torch._inductor.config
import torch._functorch.config
import torch.fx.experimental._config
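# Dynamo/Inductor/functorch config values, presumably dumped from the original
# compile so the repro runs under the same settings.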
torch._dynamo.config.enable_cpp_symbolic_shape_guards = False
torch._dynamo.config.enable_aot_compile = True
torch._inductor.config.enable_auto_functionalized_v2 = True
torch._inductor.config.graph_partition = True
torch._inductor.config.deterministic = False
torch._inductor.config.benchmark_kernel = True
torch._inductor.config.combo_kernels = False
torch._inductor.config.benchmark_combo_kernel = False
torch._inductor.config.compile_threads = 1
torch._inductor.config.triton.autotune_at_compile_time = None
torch._inductor.config.triton.store_cubin = False
torch._inductor.config.test_configs.runtime_triton_dtype_assert = False
torch._functorch.config.functionalize_rng_ops = False
torch._functorch.config.bundled_autograd_cache = False
torch._functorch.config.fake_tensor_allow_unsafe_data_ptr_access = True
torch._functorch.config.unlift_effect_tokens = True
torch._functorch.config.force_non_lazy_backward_lowering = False
torch._functorch.config.selective_decompose = False
isolate_fails_code_str = None
# torch version: 2.10.0a0+git1095402
# torch cuda version: 12.9
# torch git version: 10954021311cc837423f69ecf8ea53dcbdff2c50
# CUDA Info:
# nvcc: NVIDIA (R) Cuda compiler driver
# Copyright (c) 2005-2025 NVIDIA Corporation
# Built on Tue_May_27_02:21:03_PDT_2025
# Cuda compilation tools, release 12.9, V12.9.86
# Build cuda_12.9.r12.9/compiler.36037853_0
# GPU Hardware Info:
# NVIDIA B200 : 8
from torch.nn import *
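# The captured graph: cast the bf16 logits to fp32, take a numerically stable
# temperature-scaled softmax over the last (vocab) dim, divide the resulting
# probabilities by Exponential(1) noise, and return the argmax as int32 token
# ids of shape (bs, 1).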
class Repro(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()

    def forward(self, arg1_1, arg2_1):
        bs = arg1_1.size(0)
        convert_element_type = torch.ops.prims.convert_element_type.default(arg1_1, torch.float32); arg1_1 = None
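        # Temperature-scaled softmax over the last dim; the where/neg sign
        # trick makes the subtraction remove max(logits / temperature)
        # regardless of the temperature's sign.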
        unsqueeze = torch.ops.aten.unsqueeze.default(arg2_1, 1); arg2_1 = None
        scalar_tensor_default = torch.ops.aten.scalar_tensor.default(1, dtype = torch.float32, device = device(type='cuda', index=0), pin_memory = False)
        ge_scalar = torch.ops.aten.ge.Scalar(unsqueeze, 0)
        neg_default = torch.ops.aten.neg.default(scalar_tensor_default)
        where_self = torch.ops.aten.where.self(ge_scalar, scalar_tensor_default, neg_default); ge_scalar = scalar_tensor_default = neg_default = None
        mul_tensor = torch.ops.aten.mul.Tensor(convert_element_type, where_self); convert_element_type = None
        amax_default = torch.ops.aten.amax.default(mul_tensor, [-1], True)
        sub_tensor = torch.ops.aten.sub.Tensor(mul_tensor, amax_default); mul_tensor = amax_default = None
        mul_tensor_1 = torch.ops.aten.mul.Tensor(where_self, unsqueeze); where_self = unsqueeze = None
        div_tensor = torch.ops.aten.div.Tensor(sub_tensor, mul_tensor_1); sub_tensor = mul_tensor_1 = None
        exp = torch.ops.aten.exp.default(div_tensor); div_tensor = None
        sum_1 = torch.ops.aten.sum.dim_IntList(exp, [-1], True)
        div_1 = torch.ops.aten.div.Tensor(exp, sum_1); exp = sum_1 = None
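        # Exponential(1) noise via -log(U), U ~ Uniform(0, 1); uniform draws at
        # (or extremely close to) 1 are clamped so the log never returns 0.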
        inductor_seeds_default = torch.ops.prims.inductor_seeds.default(1, device(type='cuda', index=0))
        inductor_lookup_seed_default = torch.ops.prims.inductor_lookup_seed.default(inductor_seeds_default, 0); inductor_seeds_default = None
        inductor_random_default = torch.ops.prims.inductor_random.default([bs, 128256], inductor_lookup_seed_default, 'rand'); inductor_lookup_seed_default = None
        ge_10 = torch.ops.aten.ge.Scalar(inductor_random_default, 0.9999999403953552)
        log = torch.ops.aten.log.default(inductor_random_default); inductor_random_default = None
        full_default = torch.ops.aten.full.default([], -5.960464477539063e-08, dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False)
        where = torch.ops.aten.where.self(ge_10, full_default, log); ge_10 = full_default = log = None
        mul_53 = torch.ops.aten.mul.Tensor(where, -1.0); where = None
        div_2 = torch.ops.aten.div.Tensor(div_1, mul_53); div_1 = mul_53 = None
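        # Exponential-race (Gumbel-max style) sampling: argmax of probs / noise
        # picks each index with probability proportional to probs.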
        argmax = torch.ops.aten.argmax.default(div_2, -1); div_2 = None
        convert_element_type_1 = torch.ops.prims.convert_element_type.default(argmax, torch.int32); argmax = None
        unsqueeze_1 = torch.ops.aten.unsqueeze.default(convert_element_type_1, -1); convert_element_type_1 = None
        return (unsqueeze_1,)
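
# Input loader used by run_repro: recreates the two leaf tensors the repro was
# recorded with, bf16 logits of shape (s36, 128256) and an fp32 per-row
# temperature-like scale of shape (s36,).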
def load_args(reader):
    buf0 = reader.storage(None, 256512*s36, device=device(type='cuda', index=0), dtype_hint=torch.bfloat16)
    reader.tensor(buf0, (s36, 128256), dtype=torch.bfloat16, is_leaf=True)  # arg1_1
    buf1 = reader.storage(None, 4096, device=device(type='cuda', index=0))
    reader.tensor(buf1, (s36,), is_leaf=True)  # arg2_1
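
# Builds fresh random inputs for an arbitrary batch size, used below to
# benchmark the dynamically-shaped compile. Note that randn can produce scale
# values at or near zero, which the graph divides by.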
def get_args(bs):
    return torch.randn(bs, 128256, dtype=torch.bfloat16, device="cuda"), torch.randn(bs, dtype=torch.float, device="cuda")
load_args._version = 0
mod = Repro()
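
# Fetch the recorded args via run_repro, compile the module, mark the batch
# dim of both inputs dynamic so the compiled code handles any batch size,
# warm up, then time the batch-256 case with Triton's do_bench.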
if __name__ == '__main__':
    from torch._dynamo.repro.after_aot import run_repro
    with torch.no_grad():
        # run_repro(mod, load_args, accuracy=False, command='run', save_dir=None, tracing_mode='symbolic', check_str=None); exit()
        # To run it separately, do
        mod, args = run_repro(mod, load_args, accuracy=False, command='get_args', save_dir=None, tracing_mode='symbolic', check_str=None)
        opt_mod = torch.compile(mod)
        args_3 = get_args(3)
        args_256 = get_args(256)

        print("Warmup")
        args = args_256
        torch._dynamo.mark_dynamic(args[0], 0)
        torch._dynamo.mark_dynamic(args[1], 0)
        for _ in range(10):
            opt_mod(*args)

        import triton
        ms = triton.testing.do_bench(lambda: opt_mod(*args_256))
        print(f"{ms=:.3f}")
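        # args_3 is built above but never timed; a small optional addition
        # (assuming the batch-3 timing is also of interest) to benchmark the
        # dynamic-shape kernel at the originally recorded batch size too.
        ms_small = triton.testing.do_bench(lambda: opt_mod(*args_3))
        print(f"{ms_small=:.3f}")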