# AOT ID: ['0_inference']
from ctypes import c_void_p, c_long, c_int
import torch
import math
import random
import os
import tempfile
from math import inf, nan
from cmath import nanj
from torch._inductor.hooks import run_intermediate_hooks
from torch._inductor.utils import maybe_profile
from torch._inductor.codegen.memory_planning import _align as align
from torch import device, empty_strided
from torch._inductor.async_compile import AsyncCompile
from torch._inductor.select_algorithm import extern_kernels
import triton
import triton.language as tl
from torch._inductor.runtime.triton_heuristics import start_graph, end_graph
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
aten = torch.ops.aten
inductor_ops = torch.ops.inductor
_quantized = torch.ops._quantized
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
assert_alignment = torch._C._dynamo.guards.assert_alignment
empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned
empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia
reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
alloc_from_pool = torch.ops.inductor._alloc_from_pool
async_compile = AsyncCompile()
empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p
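# Note (added for readability; not part of Inductor's output): the aliases
# above are the standard Inductor wrapper prelude. The _dynamo.guards helpers
# validate input shapes/strides and allocate output buffers per device, and
# AsyncCompile compiles the Triton kernel source strings below in parallel
# while this wrapper module is being loaded.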
# kernel path: /tmp/torchinductor_shunting/sj/csj2knbbznawjmwg3xtedoebcqs5idli5oewacj76qxn5wisa27b.py
# Topologically Sorted Source Nodes: [convert_element_type, unsqueeze, ge, scalar_tensor, neg, where, , mul_1, div, exp, sum_1, div_1, inductor_lookup_seed, inductor_random, ge_1, full, log, where_1, mul_2, div_2, argmax, convert_element_type_1], Original ATen: [prims.convert_element_type, aten.unsqueeze, aten.ge, aten.scalar_tensor, aten.neg, aten.where, aten.mul, aten.amax, aten.sub, aten.div, aten.exp, aten.sum, prims.inductor_lookup_seed, prims.inductor_random, aten.full, aten.log, aten.argmax]
# Source node to ATen node mapping:
# => amax_default, ge_scalar, mul_tensor, mul_tensor_1, mul_tensor_2, neg_default, scalar_tensor_default, sub_tensor, where_self
# argmax => argmax
# convert_element_type => convert_element_type
# convert_element_type_1 => convert_element_type_1
# div => div
# div_1 => div_1
# div_2 => div_2
# exp => exp
# full => full_default_2
# ge => ge
# ge_1 => ge_1
# inductor_lookup_seed => inductor_lookup_seed
# inductor_random => inductor_random
# log => log
# mul_1 => mul_10
# mul_2 => mul_26
# neg => full_default_1
# scalar_tensor => full_default
# sum_1 => sum_1
# unsqueeze => unsqueeze
# where => where
# where_1 => where_1
# Graph fragment:
# %inductor_seeds : Tensor "i64[1][1]cuda:0" = PlaceHolder[target=inductor_seeds]
# %arg1_1 : Tensor "bf16[s26, 128256][128256, 1]cuda:0" = PlaceHolder[target=arg1_1]
# %arg3_1 : Tensor "f32[s26][1]cuda:0" = PlaceHolder[target=arg3_1]
# %amax_default : Tensor "f32[s26, 1][1, s26]cuda:0" = PlaceHolder[target=amax_default]
# %sum_1 : Tensor "f32[s26, 1][1, s26]cuda:0" = PlaceHolder[target=sum_1]
# %inductor_random : Tensor "f32[s26, 128256][128256, 1]cuda:0" = PlaceHolder[target=inductor_random]
# %argmax : Tensor "i64[s26][1]cuda:0" = PlaceHolder[target=argmax]
# %convert_element_type : Tensor "f32[s26, 128256][128256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%arg1_1, torch.float32), kwargs = {})
# %unsqueeze : Tensor "f32[s26, 1][1, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%arg3_1, 1), kwargs = {})
# %ge : Tensor "b8[s26, 1][1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.ge.Scalar](args = (%unsqueeze, 0), kwargs = {})
# %full_default : Tensor "f32[][]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([], 1.0), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:0, pin_memory: False})
# %full_default_1 : Tensor "f32[][]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([], -1.0), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:0, pin_memory: False})
# %where : Tensor "f32[s26, 1][1, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.aten.where.self](args = (%ge, %full_default, %full_default_1), kwargs = {})
# %ge_scalar : Tensor "b8[s26, 1][1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.ge.Scalar](args = (%where, 0), kwargs = {})
# %scalar_tensor_default : Tensor "f32[][]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.scalar_tensor.default](args = (1,), kwargs = {dtype: torch.float32, device: cuda:0, pin_memory: False})
# %neg_default : Tensor "f32[][]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.neg.default](args = (%scalar_tensor_default,), kwargs = {})
# %where_self : Tensor "f32[s26, 1][1, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.where.self](args = (%ge_scalar, %scalar_tensor_default, %neg_default), kwargs = {})
# %mul_tensor : Tensor "f32[s26, 128256][128256, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type, %where_self), kwargs = {})
# %amax_default : Tensor "f32[s26, 1][1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.amax.default](args = (%mul_tensor, [-1], True), kwargs = {})
# %sub_tensor : Tensor "f32[s26, 128256][128256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%mul_tensor, %amax_default), kwargs = {})
# %mul_tensor_1 : Tensor "f32[s26, 1][1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%where_self, %where), kwargs = {})
# %mul_tensor_2 : Tensor "f32[s26, 128256][128256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%sub_tensor, %mul_tensor_1), kwargs = {})
# %mul_10 : Tensor "f32[s26, 1][1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%where, %unsqueeze), kwargs = {})
# %div : Tensor "f32[s26, 128256][128256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.div.Tensor](args = (%mul_tensor_2, %mul_10), kwargs = {})
# %exp : Tensor "f32[s26, 128256][128256, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.exp.default](args = (%div,), kwargs = {})
# %sum_1 : Tensor "f32[s26, 1][1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%exp, [-1], True), kwargs = {})
# %div_1 : Tensor "f32[s26, 128256][128256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.div.Tensor](args = (%exp, %sum_1), kwargs = {})
# %inductor_lookup_seed : Tensor "i64[][]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.inductor_lookup_seed.default](args = (%inductor_seeds, 0), kwargs = {})
# %inductor_random : Tensor "f32[s26, 128256][128256, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.prims.inductor_random.default](args = ([%arg0_1, 128256], %inductor_lookup_seed, rand), kwargs = {})
# %ge_1 : Tensor "b8[s26, 128256][128256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.ge.Scalar](args = (%inductor_random, 0.9999999403953552), kwargs = {})
# %full_default_2 : Tensor "f32[][]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([], -5.960464477539063e-08), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:0, pin_memory: False})
# %log : Tensor "f32[s26, 128256][128256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.log.default](args = (%inductor_random,), kwargs = {})
# %where_1 : Tensor "f32[s26, 128256][128256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.where.self](args = (%ge_1, %full_default_2, %log), kwargs = {})
# %mul_26 : Tensor "f32[s26, 128256][128256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%where_1, -1.0), kwargs = {})
# %div_2 : Tensor "f32[s26, 128256][128256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.div.Tensor](args = (%div_1, %mul_26), kwargs = {})
# %argmax : Tensor "i64[s26][1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.argmax.default](args = (%div_2, -1), kwargs = {})
# %convert_element_type_1 : Tensor "i32[s26][1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%argmax, torch.int32), kwargs = {})
# return %inductor_random,%amax_default,%sum_1,%argmax,%convert_element_type_1
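# ---------------------------------------------------------------------------
# Reference sketch (added for readability; not part of Inductor's output).
# The graph fragment above is a temperature-scaled softmax followed by
# Gumbel-max-style sampling: divide the probabilities by exponential noise
# -log(u), u ~ Uniform(0, 1), and take the argmax. A minimal eager-mode
# equivalent, assuming a positive temperature (the %where / %where_self sign
# juggling in the graph also handles negative temperatures, and the fused
# kernel additionally clamps log(u) away from zero when u rounds to 1):
def _reference_sample(logits: torch.Tensor, temperature: torch.Tensor) -> torch.Tensor:
    # logits: (s26, 128256) bf16, temperature: (s26,) f32 -- hypothetical names.
    probs = torch.softmax(logits.float() / temperature.unsqueeze(1), dim=-1)
    u = torch.rand_like(probs).clamp_min(torch.finfo(torch.float32).tiny)
    # argmax of p_i / Exp_i samples index i with probability p_i.
    return torch.argmax(probs / u.log().neg(), dim=-1).to(torch.int32)
# ---------------------------------------------------------------------------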
triton_red_fused_amax_argmax_convert_element_type_div_exp_full_ge_inductor_lookup_seed_inductor_random_log_mul_neg_scalar_tensor_sub_sum_unsqueeze_where_0_0 = async_compile.triton('triton_red_fused_amax_argmax_convert_element_type_div_exp_full_ge_inductor_lookup_seed_inductor_random_log_mul_neg_scalar_tensor_sub_sum_unsqueeze_where_0_0', '''
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import torch
@triton_heuristics.reduction(
size_hints={'x': 256, 'r0_': 131072},
reduction_hint=ReductionHint.INNER,
filename=__file__,
triton_meta={'signature': {'in_ptr0': '*i64', 'in_ptr1': '*bf16', 'in_ptr2': '*fp32', 'out_ptr0': '*fp32', 'out_ptr4': '*i32', 'load_seed_offset': 'i32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=148, cc=100, major=10, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'native_matmul': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]]}], 'enable_fp_fusion': True},
inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_amax_argmax_convert_element_type_div_exp_full_ge_inductor_lookup_seed_inductor_random_log_mul_neg_scalar_tensor_sub_sum_unsqueeze_where_0_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 5, 'num_store': 2, 'num_reduction': 3, 'backend_hash': '0BE79B16E554042AA7B1EB4102B8EA61128454EAFD0C7CABEEF1703B1EAEF73E', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': True, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'are_deterministic_algorithms_enabled': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False, 'kernel_num_gb': 0.197003272, 'kernel_flop': 0}
)
@triton.jit
def triton_red_fused_amax_argmax_convert_element_type_div_exp_full_ge_inductor_lookup_seed_inductor_random_log_mul_neg_scalar_tensor_sub_sum_unsqueeze_where_0_0(in_ptr0, in_ptr1, in_ptr2, out_ptr0, out_ptr4, load_seed_offset, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
    r0_numel = 128256
    rnumel = r0_numel
    RBLOCK: tl.constexpr = R0_BLOCK
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    r0_base = tl.arange(0, R0_BLOCK)[None, :]
    rbase = r0_base
    x0 = xindex
    # Pass 1: materialize the uniform randoms (inductor_random) into out_ptr0.
    for r0_offset in tl.range(0, r0_numel, R0_BLOCK):
        r0_index = r0_offset + r0_base
        r0_mask = r0_index < r0_numel
        roffset = r0_offset
        rindex = r0_index
        r0_1 = r0_index
        tmp0 = tl.load(in_ptr0 + load_seed_offset)
        tmp1 = r0_1 + 128256*x0
        tmp2 = tl.rand(tmp0, (tmp1).to(tl.uint32))
        tl.store(out_ptr0 + (r0_1 + 128256*x0), tmp2, r0_mask & xmask)
    tmp5 = tl.load(in_ptr2 + (x0), xmask, eviction_policy='evict_last')
    # Pass 2: running row-wise max of the sign-adjusted logits (amax),
    # used for a numerically stable softmax.
    _tmp15 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32)
    for r0_offset in tl.range(0, r0_numel, R0_BLOCK):
        r0_index = r0_offset + r0_base
        r0_mask = r0_index < r0_numel
        roffset = r0_offset
        rindex = r0_index
        r0_1 = r0_index
        tmp3 = tl.load(in_ptr1 + (r0_1 + 128256*x0), r0_mask & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp4 = tmp3.to(tl.float32)
        tmp6 = 0.0
        tmp7 = tmp5 >= tmp6
        tmp8 = 1.0
        tmp9 = -1.0
        tmp10 = tl.where(tmp7, tmp8, tmp9)
        tmp11 = tmp10 >= tmp6
        tmp12 = tl.where(tmp11, tmp8, tmp9)
        tmp13 = tmp4 * tmp12
        tmp14 = tl.broadcast_to(tmp13, [XBLOCK, R0_BLOCK])
        tmp16 = triton_helpers.maximum(_tmp15, tmp14)
        _tmp15 = tl.where(r0_mask & xmask, tmp16, _tmp15)
    tmp15 = triton_helpers.max2(_tmp15, 1)[:, None]
    # Pass 3: accumulate sum of exp((x - amax) / temperature) (sum_1).
    _tmp34 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
    for r0_offset in tl.range(0, r0_numel, R0_BLOCK):
        r0_index = r0_offset + r0_base
        r0_mask = r0_index < r0_numel
        roffset = r0_offset
        rindex = r0_index
        r0_1 = r0_index
        tmp17 = tl.load(in_ptr1 + (r0_1 + 128256*x0), r0_mask & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp18 = tmp17.to(tl.float32)
        tmp19 = 0.0
        tmp20 = tmp5 >= tmp19
        tmp21 = 1.0
        tmp22 = -1.0
        tmp23 = tl.where(tmp20, tmp21, tmp22)
        tmp24 = tmp23 >= tmp19
        tmp25 = tl.where(tmp24, tmp21, tmp22)
        tmp26 = tmp18 * tmp25
        tmp27 = tmp26 - tmp15
        tmp28 = tmp25 * tmp23
        tmp29 = tmp27 * tmp28
        tmp30 = tmp23 * tmp5
        tmp31 = (tmp29 / tmp30)
        tmp32 = libdevice.exp(tmp31)
        tmp33 = tl.broadcast_to(tmp32, [XBLOCK, R0_BLOCK])
        tmp35 = _tmp34 + tmp33
        _tmp34 = tl.where(r0_mask & xmask, tmp35, _tmp34)
    tmp34 = tl.sum(_tmp34, 1)[:, None]
    # Pass 4: argmax of softmax(x) / (-log(u)), i.e. the exponential-race
    # sampling step (div_1, where_1, mul_2, div_2, argmax).
    _tmp62 = tl.full([XBLOCK, R0_BLOCK], float("-inf"), tl.float32)
    _tmp62_index = tl.full([XBLOCK, R0_BLOCK], 2147483647, tl.int32)
    for r0_offset in tl.range(0, r0_numel, R0_BLOCK):
        r0_index = r0_offset + r0_base
        r0_mask = r0_index < r0_numel
        roffset = r0_offset
        rindex = r0_index
        r0_1 = r0_index
        tmp36 = tl.load(in_ptr1 + (r0_1 + 128256*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp53 = tl.load(out_ptr0 + (r0_1 + 128256*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0)
        tmp37 = tmp36.to(tl.float32)
        tmp38 = 0.0
        tmp39 = tmp5 >= tmp38
        tmp40 = 1.0
        tmp41 = -1.0
        tmp42 = tl.where(tmp39, tmp40, tmp41)
        tmp43 = tmp42 >= tmp38
        tmp44 = tl.where(tmp43, tmp40, tmp41)
        tmp45 = tmp37 * tmp44
        tmp46 = tmp45 - tmp15
        tmp47 = tmp44 * tmp42
        tmp48 = tmp46 * tmp47
        tmp49 = tmp42 * tmp5
        tmp50 = (tmp48 / tmp49)
        tmp51 = libdevice.exp(tmp50)
        tmp52 = (tmp51 / tmp34)
        tmp54 = 0.9999999403953552
        tmp55 = tmp53 >= tmp54
        tmp56 = tl_math.log(tmp53)
        tmp57 = -5.960464477539063e-08
        tmp58 = tl.where(tmp55, tmp57, tmp56)
        tmp59 = tmp58 * tmp41
        tmp60 = (tmp52 / tmp59)
        tmp61 = tl.broadcast_to(tmp60, [XBLOCK, R0_BLOCK])
        _tmp62_next, _tmp62_index_next = triton_helpers.maximum_with_index(
            _tmp62, _tmp62_index, tmp61, rindex
        )
        _tmp62 = tl.where(r0_mask & xmask, _tmp62_next, _tmp62)
        _tmp62_index = tl.where(r0_mask & xmask, _tmp62_index_next, _tmp62_index)
    tmp62_val, tmp62_idx = triton_helpers.max_with_index(_tmp62, _tmp62_index, 1)
    tmp62 = tmp62_idx[:, None]
    # Store the sampled index as int32 (convert_element_type_1).
    tmp63 = tmp62.to(tl.int32)
    tl.store(out_ptr4 + (x0), tmp63, xmask)
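# Standalone harness emitted with the kernel: get_args() fabricates matching
# inputs (256 rows here), call() launches the kernel on the current stream,
# and the __main__ block reports achieved bandwidth against the analytic
# 0.197 GB of memory traffic recorded in inductor_meta['kernel_num_gb'].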
def get_args():
    arg_0 = rand_strided((1,), (1,), device='cuda:0', dtype=torch.int64)
    arg_1 = rand_strided((256, 128256), (128256, 1), device='cuda:0', dtype=torch.bfloat16)
    arg_2 = rand_strided((256,), (1,), device='cuda:0', dtype=torch.float32)
    arg_3 = rand_strided((256, 128256), (128256, 1), device='cuda:0', dtype=torch.float32)
    arg_4 = rand_strided((256,), (1,), device='cuda:0', dtype=torch.int32)
    arg_5 = 0
    return arg_0, arg_1, arg_2, arg_3, arg_4, arg_5, 256, 128256,
def call(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        stream0 = get_raw_stream(0)
        triton_red_fused_amax_argmax_convert_element_type_div_exp_full_ge_inductor_lookup_seed_inductor_random_log_mul_neg_scalar_tensor_sub_sum_unsqueeze_where_0_0.run(*args, stream=stream0)
def benchmark_all_configs(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        return triton_red_fused_amax_argmax_convert_element_type_div_exp_full_ge_inductor_lookup_seed_inductor_random_log_mul_neg_scalar_tensor_sub_sum_unsqueeze_where_0_0.benchmark_all_configs(*args)
if __name__ == '__main__':
    from torch._inductor.runtime.benchmarking import benchmarker
    args = get_args()
    ms = benchmarker.benchmark(lambda: call(args), device='cuda', rep=40)
    num_gb = 0.197003272
    gb_per_s = num_gb / (ms / 1e3)
    print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")
''', device_str='cuda')
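# Block until every async-compiled Triton kernel above has finished building
# before the wrapper below references it.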
async_compile.wait(globals())
del async_compile
class Runner:
    def __init__(self, partitions):
        self.partitions = partitions
    def recursively_apply_fns(self, fns):
        new_callables = []
        for fn, c in zip(fns, self.partitions):
            new_callables.append(fn(c))
        self.partitions = new_callables
    def call(self, args):
        arg0_1, arg1_1, arg2_1, arg3_1 = args
        args.clear()
        s26 = arg0_1
        s9 = arg2_1
        assert_size_stride(arg1_1, (s26, 128256), (128256, 1))
        assert_size_stride(arg3_1, (s26, ), (1, ))
        with torch.cuda._DeviceGuard(0):
            torch.cuda.set_device(0)
            buf2 = empty_strided_cuda((1, ), (1, ), torch.int64)
            # Topologically Sorted Source Nodes: [inductor_seeds], Original ATen: [prims.inductor_seeds]
            # [Provenance debug handles] aten.randint.low_out:1
            aten.randint.low_out(-9223372036854775808, 9223372036854775807, [1], out=buf2)
            buf3 = empty_strided_cuda((s26, 128256), (128256, 1), torch.float32)
            buf5 = empty_strided_cuda((s26, ), (1, ), torch.int32)
            # Topologically Sorted Source Nodes: [convert_element_type, unsqueeze, ge, scalar_tensor, neg, where, , mul_1, div, exp, sum_1, div_1, inductor_lookup_seed, inductor_random, ge_1, full, log, where_1, mul_2, div_2, argmax, convert_element_type_1], Original ATen: [prims.convert_element_type, aten.unsqueeze, aten.ge, aten.scalar_tensor, aten.neg, aten.where, aten.mul, aten.amax, aten.sub, aten.div, aten.exp, aten.sum, prims.inductor_lookup_seed, prims.inductor_random, aten.full, aten.log, aten.argmax]
            # [Provenance debug handles] triton_red_fused_amax_argmax_convert_element_type_div_exp_full_ge_inductor_lookup_seed_inductor_random_log_mul_neg_scalar_tensor_sub_sum_unsqueeze_where_0_0:2
            stream0 = get_raw_stream(0)
            triton_red_fused_amax_argmax_convert_element_type_div_exp_full_ge_inductor_lookup_seed_inductor_random_log_mul_neg_scalar_tensor_sub_sum_unsqueeze_where_0_0.run(buf2, arg1_1, arg3_1, buf3, buf5, 0, s26, 128256, stream=stream0)
            del arg1_1
            del arg3_1
            del buf2
            del buf3
        return (reinterpret_tensor(buf5, (s26, 1), (1, 1), 0), )
runner = Runner(partitions=[])
call = runner.call
recursively_apply_fns = runner.recursively_apply_fns
def benchmark_compiled_module(times=10, repeat=10):
    from torch._dynamo.testing import rand_strided
    from torch._inductor.utils import print_performance
    arg0_1 = 256
    arg1_1 = rand_strided((256, 128256), (128256, 1), device='cuda:0', dtype=torch.bfloat16)
    arg2_1 = 256
    arg3_1 = rand_strided((256, ), (1, ), device='cuda:0', dtype=torch.float32)
    fn = lambda: call([arg0_1, arg1_1, arg2_1, arg3_1])
    return print_performance(fn, times=times, repeat=repeat)
if __name__ == "__main__":
    from torch._inductor.wrapper_benchmark import compiled_module_main
    compiled_module_main('None', benchmark_compiled_module)
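# Usage sketch (hypothetical file name): if this module is saved as
# output_code.py, then
#   python output_code.py
# runs benchmark_compiled_module() via compiled_module_main, while
#   out, = __import__("output_code").call([256, logits_bf16, 256, temp_f32])
# samples one token id per row (int32, shape (s26, 1)) from real inputs.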