"""
Compile-time auto-tuning block:
import torch
from torch._dynamo.testing import rand_strided
from torch._dynamo.utils import preserve_rng_state
from torch._inductor.select_algorithm import AlgorithmSelectorCache
from torch._inductor.async_compile import AsyncCompile
async_compile = AsyncCompile()
generate_example_value = AlgorithmSelectorCache.generate_example_value
empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
get_raw_stream = torch._C._cuda_getCurrentRawStream
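# Note: each kernel below is emitted as a self-contained benchmark module:
# get_args() materializes representative inputs with rand_strided, call()
# launches the kernel, benchmark_all_configs() sweeps the autotuner configs,
# and the __main__ block reports achieved bandwidth against the analytically
# derived traffic (kernel_num_gb in inductor_meta).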
# kernel path: /tmp/torchinductor_shunting/42/c42leys5rljfhifk6vdsbsryfxnk2d4b7cbcivnt5jwtzncieeyc.py
# Topologically Sorted Source Nodes: [to, add, pow_1, mean, add_1, rsqrt, mul, to_2, mul_1], Original ATen: [aten._to_copy, aten.add, aten.pow, aten.mean, aten.rsqrt, aten.mul]
# Source node to ATen node mapping:
# add => add_9
# add_1 => add_22
# mean => mean
# mul => mul_15
# mul_1 => mul_20
# pow_1 => pow_1
# rsqrt => rsqrt
# to => convert_element_type_2
# to_2 => convert_element_type_4
# Graph fragment:
# %mm : Tensor "bf16[s72, 4096][4096, 1]cuda:0" = PlaceHolder[target=mm]
# %arg4_1 : Tensor "bf16[s72, 4096][4096, 1]cuda:0" = PlaceHolder[target=arg4_1]
# %buf1 : Tensor "f32[s72, 1][1, s72]cuda:0" = PlaceHolder[target=buf1]
# %arg3_1 : Tensor "bf16[4096][1]cuda:0" = PlaceHolder[target=arg3_1]
# %convert_element_type_2 : Tensor "f32[s72, 4096][4096, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mm, torch.float32), kwargs = {})
# %add_9 : Tensor "f32[s72, 4096][4096, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.aten.add.Tensor](args = (%convert_element_type_2, %arg4_1), kwargs = {})
# %pow_1 : Tensor "f32[s72, 4096][4096, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%add_9, 2), kwargs = {})
# %mean : Tensor "f32[s72, 1][1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mean.dim](args = (%pow_1, [-1], True), kwargs = {})
# %add_22 : Tensor "f32[s72, 1][1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mean, 1e-05), kwargs = {})
# %rsqrt : Tensor "f32[s72, 1][1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.rsqrt.default](args = (%add_22,), kwargs = {})
# %mul_15 : Tensor "f32[s72, 4096][4096, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_9, %rsqrt), kwargs = {})
# %convert_element_type_4 : Tensor "bf16[s72, 4096][4096, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_15, torch.bfloat16), kwargs = {})
# %mul_20 : Tensor "bf16[s72, 4096][4096, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_4, %arg3_1), kwargs = {})
# return %buf1,%mul_20
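# The fragment above fuses a residual add with an RMSNorm, accumulating in fp32
# and casting back to bf16. A rough eager-mode equivalent (a sketch reconstructed
# from the graph fragment, not part of the generated file; names are illustrative):
#
#   def add_rmsnorm(mm_out, residual, weight, eps=1e-05):
#       x = mm_out.float() + residual.float()                       # add_9
#       x = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)  # mul_15
#       return x.to(torch.bfloat16) * weight                        # mul_20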
triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_2_0 = async_compile.triton('triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_2_0', '''
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import torch
@triton_heuristics.reduction(
size_hints={'x': 16384, 'r0_': 4096},
reduction_hint=ReductionHint.INNER,
filename=__file__,
triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'out_ptr1': '*bf16', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=148, cc=100, major=10, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'native_matmul': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}], 'enable_fp_fusion': True},
inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_2_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 5, 'num_store': 1, 'num_reduction': 1, 'backend_hash': '0BE79B16E554042AA7B1EB4102B8EA61128454EAFD0C7CABEEF1703B1EAEF73E', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'are_deterministic_algorithms_enabled': False, 'add_persistent_rblock': True, 'kernel_num_gb': 0.402661376, 'kernel_flop': 0}
)
@triton.jit
def triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_2_0(in_ptr0, in_ptr1, in_ptr2, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
    r0_numel = 4096
    rnumel = r0_numel
    RBLOCK: tl.constexpr = R0_BLOCK
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    r0_base = tl.arange(0, R0_BLOCK)[None, :]
    rbase = r0_base
    x0 = xindex
    _tmp7 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
    for r0_offset in tl.range(0, r0_numel, R0_BLOCK):
        r0_index = r0_offset + r0_base
        r0_mask = r0_index < r0_numel
        roffset = r0_offset
        rindex = r0_index
        r0_1 = r0_index
        tmp0 = tl.load(in_ptr0 + (r0_1 + 4096*x0), r0_mask & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp2 = tl.load(in_ptr1 + (r0_1 + 4096*x0), r0_mask & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp1 = tmp0.to(tl.float32)
        tmp3 = tmp2.to(tl.float32)
        tmp4 = tmp1 + tmp3
        tmp5 = tmp4 * tmp4
        tmp6 = tl.broadcast_to(tmp5, [XBLOCK, R0_BLOCK])
        tmp8 = _tmp7 + tmp6
        _tmp7 = tl.where(r0_mask & xmask, tmp8, _tmp7)
    tmp7 = tl.sum(_tmp7, 1)[:, None]
    for r0_offset in tl.range(0, r0_numel, R0_BLOCK):
        r0_index = r0_offset + r0_base
        r0_mask = r0_index < r0_numel
        roffset = r0_offset
        rindex = r0_index
        r0_1 = r0_index
        tmp9 = tl.load(in_ptr0 + (r0_1 + 4096*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp11 = tl.load(in_ptr1 + (r0_1 + 4096*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp21 = tl.load(in_ptr2 + (r0_1), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp10 = tmp9.to(tl.float32)
        tmp12 = tmp11.to(tl.float32)
        tmp13 = tmp10 + tmp12
        tmp14 = 4096.0
        tmp15 = (tmp7 / tmp14)
        tmp16 = 1e-05
        tmp17 = tmp15 + tmp16
        tmp18 = libdevice.rsqrt(tmp17)
        tmp19 = tmp13 * tmp18
        tmp20 = tmp19.to(tl.float32)
        tmp22 = tmp20 * tmp21
        tl.store(out_ptr1 + (r0_1 + 4096*x0), tmp22, r0_mask & xmask)
def get_args():
    arg_0 = rand_strided((16384, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16)
    arg_1 = rand_strided((16384, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16)
    arg_2 = rand_strided((4096,), (1,), device='cuda:0', dtype=torch.bfloat16)
    arg_3 = rand_strided((16384, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16)
    return arg_0, arg_1, arg_2, arg_3, 16384, 4096,
def call(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        stream0 = get_raw_stream(0)
        triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_2_0.run(*args, stream=stream0)
def benchmark_all_configs(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        return triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_2_0.benchmark_all_configs(*args)
if __name__ == '__main__':
    from torch._inductor.runtime.benchmarking import benchmarker
    args = get_args()
    ms = benchmarker.benchmark(lambda: call(args), device='cuda', rep=40)
    num_gb = 0.402661376
    gb_per_s = num_gb / (ms / 1e3)
    print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")
''', device_str='cuda')
# kernel path: /tmp/torchinductor_shunting/fg/cfgveke5baxhpthp4y7p5c6zky2cpc7pw23by4evkf7o3ouvuzvp.py
# Topologically Sorted Source Nodes: [getitem, silu, getitem_1, mul_2], Original ATen: [aten.slice, aten.silu, aten.mul]
# Source node to ATen node mapping:
# getitem => slice_1
# getitem_1 => slice_2
# mul_2 => mul_32
# silu => convert_element_type_7, convert_element_type_8, mul_27, sigmoid
# Graph fragment:
# %mm_1 : Tensor "bf16[s72, 28672][28672, 1]cuda:0" = PlaceHolder[target=mm_1]
# %slice_1 : Tensor "bf16[s72, 14336][28672, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%mm_1, 1, 0, 14336), kwargs = {})
# %convert_element_type_7 : Tensor "f32[s72, 14336][14336, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%slice_1, torch.float32), kwargs = {})
# %sigmoid : Tensor "f32[s72, 14336][14336, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.sigmoid.default](args = (%convert_element_type_7,), kwargs = {})
# %mul_27 : Tensor "f32[s72, 14336][14336, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_7, %sigmoid), kwargs = {})
# %convert_element_type_8 : Tensor "bf16[s72, 14336][14336, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_27, torch.bfloat16), kwargs = {})
# %slice_2 : Tensor "bf16[s72, 14336][28672, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%mm_1, 1, 14336, 9223372036854775807), kwargs = {})
# %mul_32 : Tensor "bf16[s72, 14336][14336, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_8, %slice_2), kwargs = {})
# return %mul_32
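# The fragment above is the SwiGLU gate applied to the fused gate/up projection:
# the first half of mm_1 goes through silu in fp32, then scales the second half.
# A rough eager-mode equivalent (a sketch, names illustrative):
#
#   def swiglu(gate_up):                       # bf16[s72, 28672]
#       gate, up = gate_up.chunk(2, dim=-1)    # slice_1, slice_2
#       gate = torch.nn.functional.silu(gate.float()).to(torch.bfloat16)
#       return gate * up                       # mul_32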
triton_poi_fused_mul_silu_slice_2_1 = async_compile.triton('triton_poi_fused_mul_silu_slice_2_1', '''
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import torch
@triton_heuristics.pointwise(
size_hints={'x': 268435456},
filename=__file__,
triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=148, cc=100, major=10, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'native_matmul': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}], 'enable_fp_fusion': True},
inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_silu_slice_2_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 2, 'num_store': 1, 'num_reduction': 0, 'backend_hash': '0BE79B16E554042AA7B1EB4102B8EA61128454EAFD0C7CABEEF1703B1EAEF73E', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'are_deterministic_algorithms_enabled': False, 'kernel_num_gb': 1.409286144, 'kernel_flop': 0},
min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_mul_silu_slice_2_1(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = (xindex % 14336)
    x1 = xindex // 14336
    x2 = xindex
    tmp0 = tl.load(in_ptr0 + (x0 + 28672*x1), xmask).to(tl.float32)
    tmp5 = tl.load(in_ptr0 + (14336 + x0 + 28672*x1), xmask).to(tl.float32)
    tmp1 = tmp0.to(tl.float32)
    tmp2 = tl.sigmoid(tmp1)
    tmp3 = tmp1 * tmp2
    tmp4 = tmp3.to(tl.float32)
    tmp6 = tmp4 * tmp5
    tl.store(out_ptr0 + (x2), tmp6, xmask)
def get_args():
    arg_0 = rand_strided((16384, 28672), (28672, 1), device='cuda:0', dtype=torch.bfloat16)
    arg_1 = rand_strided((16384, 14336), (14336, 1), device='cuda:0', dtype=torch.bfloat16)
    return arg_0, arg_1, 234881024,
def call(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        stream0 = get_raw_stream(0)
        triton_poi_fused_mul_silu_slice_2_1.run(*args, stream=stream0)
def benchmark_all_configs(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        return triton_poi_fused_mul_silu_slice_2_1.benchmark_all_configs(*args)
if __name__ == '__main__':
    from torch._inductor.runtime.benchmarking import benchmarker
    args = get_args()
    ms = benchmarker.benchmark(lambda: call(args), device='cuda', rep=40)
    num_gb = 1.409286144
    gb_per_s = num_gb / (ms / 1e3)
    print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")
''', device_str='cuda')
# kernel path: /tmp/torchinductor_shunting/hm/chmnjo5ntuw2zivkwmaix3fbqyb7wqq6g66iqgual6qywjjsb4jy.py
# Topologically Sorted Source Nodes: [to, add, to_3, to_1, add_2, pow_2, mean_1, add_3, rsqrt_1, mul_3, to_5, mul_4], Original ATen: [aten._to_copy, aten.add, aten.pow, aten.mean, aten.rsqrt, aten.mul]
# Source node to ATen node mapping:
# add => add_9
# add_2 => add_59
# add_3 => add_72
# mean_1 => mean_1
# mul_3 => mul_48
# mul_4 => mul_53
# pow_2 => pow_2
# rsqrt_1 => rsqrt_1
# to => convert_element_type_2
# to_1 => convert_element_type_3
# to_3 => convert_element_type_11
# to_5 => convert_element_type_13
# Graph fragment:
# %mm_2 : Tensor "bf16[s72, 4096][4096, 1]cuda:0" = PlaceHolder[target=mm_2]
# %mm : Tensor "bf16[s72, 4096][4096, 1]cuda:0" = PlaceHolder[target=mm]
# %arg4_1 : Tensor "bf16[s72, 4096][4096, 1]cuda:0" = PlaceHolder[target=arg4_1]
# %buf6 : Tensor "f32[s72, 1][1, s72]cuda:0" = PlaceHolder[target=buf6]
# %arg7_1 : Tensor "bf16[4096][1]cuda:0" = PlaceHolder[target=arg7_1]
# %convert_element_type_2 : Tensor "f32[s72, 4096][4096, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mm, torch.float32), kwargs = {})
# %add_9 : Tensor "f32[s72, 4096][4096, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.aten.add.Tensor](args = (%convert_element_type_2, %arg4_1), kwargs = {})
# %convert_element_type_11 : Tensor "f32[s72, 4096][4096, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mm_2, torch.float32), kwargs = {})
# %convert_element_type_3 : Tensor "bf16[s72, 4096][4096, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_9, torch.bfloat16), kwargs = {})
# %add_59 : Tensor "f32[s72, 4096][4096, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%convert_element_type_11, %convert_element_type_3), kwargs = {})
# %pow_2 : Tensor "f32[s72, 4096][4096, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%add_59, 2), kwargs = {})
# %mean_1 : Tensor "f32[s72, 1][1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mean.dim](args = (%pow_2, [-1], True), kwargs = {})
# %add_72 : Tensor "f32[s72, 1][1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mean_1, 1e-05), kwargs = {})
# %rsqrt_1 : Tensor "f32[s72, 1][1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.rsqrt.default](args = (%add_72,), kwargs = {})
# %mul_48 : Tensor "f32[s72, 4096][4096, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_59, %rsqrt_1), kwargs = {})
# %convert_element_type_13 : Tensor "bf16[s72, 4096][4096, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_48, torch.bfloat16), kwargs = {})
# %mul_53 : Tensor "bf16[s72, 4096][4096, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_13, %arg7_1), kwargs = {})
# return %buf6,%mul_53
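# Same fused RMSNorm pattern as the first kernel, but the normalized input stacks
# a second residual: mm_2 plus the bf16-rounded first residual sum (mm + arg4_1),
# and the result is written back in place into in_out_ptr0. In eager terms (a sketch):
#
#   x = mm2_out.float() + (mm_out.float() + residual.float()).to(torch.bfloat16).float()
#   out = (x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-05)).to(torch.bfloat16) * weight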
triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_2_2 = async_compile.triton('triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_2_2', '''
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import torch
@triton_heuristics.reduction(
size_hints={'x': 16384, 'r0_': 4096},
reduction_hint=ReductionHint.INNER,
filename=__file__,
triton_meta={'signature': {'in_out_ptr0': '*bf16', 'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=148, cc=100, major=10, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'native_matmul': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}], 'enable_fp_fusion': True},
inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_2_2', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 7, 'num_store': 1, 'num_reduction': 1, 'backend_hash': '0BE79B16E554042AA7B1EB4102B8EA61128454EAFD0C7CABEEF1703B1EAEF73E', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'are_deterministic_algorithms_enabled': False, 'add_persistent_rblock': True, 'kernel_num_gb': 0.536879104, 'kernel_flop': 0}
)
@triton.jit
def triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_2_2(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
    r0_numel = 4096
    rnumel = r0_numel
    RBLOCK: tl.constexpr = R0_BLOCK
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    r0_base = tl.arange(0, R0_BLOCK)[None, :]
    rbase = r0_base
    x0 = xindex
    _tmp12 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
    for r0_offset in tl.range(0, r0_numel, R0_BLOCK):
        r0_index = r0_offset + r0_base
        r0_mask = r0_index < r0_numel
        roffset = r0_offset
        rindex = r0_index
        r0_1 = r0_index
        tmp0 = tl.load(in_out_ptr0 + (r0_1 + 4096*x0), r0_mask & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp2 = tl.load(in_ptr0 + (r0_1 + 4096*x0), r0_mask & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp4 = tl.load(in_ptr1 + (r0_1 + 4096*x0), r0_mask & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp1 = tmp0.to(tl.float32)
        tmp3 = tmp2.to(tl.float32)
        tmp5 = tmp4.to(tl.float32)
        tmp6 = tmp3 + tmp5
        tmp7 = tmp6.to(tl.float32)
        tmp8 = tmp7.to(tl.float32)
        tmp9 = tmp1 + tmp8
        tmp10 = tmp9 * tmp9
        tmp11 = tl.broadcast_to(tmp10, [XBLOCK, R0_BLOCK])
        tmp13 = _tmp12 + tmp11
        _tmp12 = tl.where(r0_mask & xmask, tmp13, _tmp12)
    tmp12 = tl.sum(_tmp12, 1)[:, None]
    for r0_offset in tl.range(0, r0_numel, R0_BLOCK):
        r0_index = r0_offset + r0_base
        r0_mask = r0_index < r0_numel
        roffset = r0_offset
        rindex = r0_index
        r0_1 = r0_index
        tmp14 = tl.load(in_out_ptr0 + (r0_1 + 4096*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp16 = tl.load(in_ptr0 + (r0_1 + 4096*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp18 = tl.load(in_ptr1 + (r0_1 + 4096*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp31 = tl.load(in_ptr2 + (r0_1), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp15 = tmp14.to(tl.float32)
        tmp17 = tmp16.to(tl.float32)
        tmp19 = tmp18.to(tl.float32)
        tmp20 = tmp17 + tmp19
        tmp21 = tmp20.to(tl.float32)
        tmp22 = tmp21.to(tl.float32)
        tmp23 = tmp15 + tmp22
        tmp24 = 4096.0
        tmp25 = (tmp12 / tmp24)
        tmp26 = 1e-05
        tmp27 = tmp25 + tmp26
        tmp28 = libdevice.rsqrt(tmp27)
        tmp29 = tmp23 * tmp28
        tmp30 = tmp29.to(tl.float32)
        tmp32 = tmp30 * tmp31
        tl.store(in_out_ptr0 + (r0_1 + 4096*x0), tmp32, r0_mask & xmask)
def get_args():
    arg_0 = rand_strided((16384, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16)
    arg_1 = rand_strided((16384, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16)
    arg_2 = rand_strided((16384, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16)
    arg_3 = rand_strided((4096,), (1,), device='cuda:0', dtype=torch.bfloat16)
    return arg_0, arg_1, arg_2, arg_3, 16384, 4096,
def call(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        stream0 = get_raw_stream(0)
        triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_2_2.run(*args, stream=stream0)
def benchmark_all_configs(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        return triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_2_2.benchmark_all_configs(*args)
if __name__ == '__main__':
    from torch._inductor.runtime.benchmarking import benchmarker
    args = get_args()
    ms = benchmarker.benchmark(lambda: call(args), device='cuda', rep=40)
    num_gb = 0.536879104
    gb_per_s = num_gb / (ms / 1e3)
    print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")
''', device_str='cuda')
async_compile.wait(globals())
del async_compile
import triton
import triton.language as tl
from torch._inductor.runtime.triton_heuristics import start_graph, end_graph
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
with torch.cuda._DeviceGuard(0):
    stream0 = get_raw_stream(0)
stream0 = get_raw_stream(0)
buf0 = generate_example_value((16384, 4096), (4096, 1), 'cuda:0', torch.bfloat16, 0, (16384, 4096))
arg4_1 = generate_example_value((16384, 4096), (4096, 1), 'cuda:0', torch.bfloat16, 0, (16384, 4096))
arg3_1 = generate_example_value((4096,), (1,), 'cuda:0', torch.bfloat16, 0, (4096,))
buf2 = generate_example_value((16384, 4096), (4096, 1), 'cuda:0', torch.bfloat16, 0, (16384, 4096))
with torch.cuda._DeviceGuard(0):
    triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_2_0.run(buf0, arg4_1, arg3_1, buf2, 16384, 4096, stream=stream0)
del arg3_1, buf2
stream0 = get_raw_stream(0)
buf3 = generate_example_value((16384, 28672), (28672, 1), 'cuda:0', torch.bfloat16, 0, (16384, 28672))
buf4 = generate_example_value((16384, 14336), (14336, 1), 'cuda:0', torch.bfloat16, 0, (16384, 14336))
with torch.cuda._DeviceGuard(0):
    triton_poi_fused_mul_silu_slice_2_1.run(buf3, buf4, 234881024, stream=stream0)
del buf3, buf4
stream0 = get_raw_stream(0)
buf7 = generate_example_value((16384, 4096), (4096, 1), 'cuda:0', torch.bfloat16, 0, (16384, 4096))
arg7_1 = generate_example_value((4096,), (1,), 'cuda:0', torch.bfloat16, 0, (4096,))
with torch.cuda._DeviceGuard(0):
    triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_2_2.run(buf7, buf0, arg4_1, arg7_1, 16384, 4096, stream=stream0)
del buf0, arg4_1, buf7, arg7_1
"""
# AOT ID: ['32_inference']
from ctypes import c_void_p, c_long, c_int
import torch
import math
import random
import os
import tempfile
from math import inf, nan
from cmath import nanj
from torch._inductor.hooks import run_intermediate_hooks
from torch._inductor.utils import maybe_profile
from torch._inductor.codegen.memory_planning import _align as align
from torch import device, empty_strided
from torch._inductor.async_compile import AsyncCompile
from torch._inductor.select_algorithm import extern_kernels
import triton
import triton.language as tl
from torch._inductor.runtime.triton_heuristics import start_graph, end_graph
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
aten = torch.ops.aten
inductor_ops = torch.ops.inductor
_quantized = torch.ops._quantized
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
assert_alignment = torch._C._dynamo.guards.assert_alignment
empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned
empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia
reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
alloc_from_pool = torch.ops.inductor._alloc_from_pool
async_compile = AsyncCompile()
empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p
# kernel path: /tmp/torchinductor_shunting/42/c42leys5rljfhifk6vdsbsryfxnk2d4b7cbcivnt5jwtzncieeyc.py
# Topologically Sorted Source Nodes: [to, add, pow_1, mean, add_1, rsqrt, mul, to_2, mul_1], Original ATen: [aten._to_copy, aten.add, aten.pow, aten.mean, aten.rsqrt, aten.mul]
# Source node to ATen node mapping:
# add => add_9
# add_1 => add_22
# mean => mean
# mul => mul_15
# mul_1 => mul_20
# pow_1 => pow_1
# rsqrt => rsqrt
# to => convert_element_type_2
# to_2 => convert_element_type_4
# Graph fragment:
# %mm : Tensor "bf16[s72, 4096][4096, 1]cuda:0" = PlaceHolder[target=mm]
# %arg4_1 : Tensor "bf16[s72, 4096][4096, 1]cuda:0" = PlaceHolder[target=arg4_1]
# %buf1 : Tensor "f32[s72, 1][1, s72]cuda:0" = PlaceHolder[target=buf1]
# %arg3_1 : Tensor "bf16[4096][1]cuda:0" = PlaceHolder[target=arg3_1]
# %convert_element_type_2 : Tensor "f32[s72, 4096][4096, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mm, torch.float32), kwargs = {})
# %add_9 : Tensor "f32[s72, 4096][4096, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.aten.add.Tensor](args = (%convert_element_type_2, %arg4_1), kwargs = {})
# %pow_1 : Tensor "f32[s72, 4096][4096, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%add_9, 2), kwargs = {})
# %mean : Tensor "f32[s72, 1][1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mean.dim](args = (%pow_1, [-1], True), kwargs = {})
# %add_22 : Tensor "f32[s72, 1][1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mean, 1e-05), kwargs = {})
# %rsqrt : Tensor "f32[s72, 1][1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.rsqrt.default](args = (%add_22,), kwargs = {})
# %mul_15 : Tensor "f32[s72, 4096][4096, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_9, %rsqrt), kwargs = {})
# %convert_element_type_4 : Tensor "bf16[s72, 4096][4096, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_15, torch.bfloat16), kwargs = {})
# %mul_20 : Tensor "bf16[s72, 4096][4096, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_4, %arg3_1), kwargs = {})
# return %buf1,%mul_20
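# The kernel below is a verbatim re-emission of the one benchmarked in the
# compile-time auto-tuning block above; see the eager-mode sketch there.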
triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_2_0 = async_compile.triton('triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_2_0', '''
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import torch
@triton_heuristics.reduction(
size_hints={'x': 16384, 'r0_': 4096},
reduction_hint=ReductionHint.INNER,
filename=__file__,
triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'out_ptr1': '*bf16', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=148, cc=100, major=10, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'native_matmul': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}], 'enable_fp_fusion': True},
inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_2_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 5, 'num_store': 1, 'num_reduction': 1, 'backend_hash': '0BE79B16E554042AA7B1EB4102B8EA61128454EAFD0C7CABEEF1703B1EAEF73E', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'are_deterministic_algorithms_enabled': False, 'add_persistent_rblock': True, 'kernel_num_gb': 0.402661376, 'kernel_flop': 0}
)
@triton.jit
def triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_2_0(in_ptr0, in_ptr1, in_ptr2, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
    r0_numel = 4096
    rnumel = r0_numel
    RBLOCK: tl.constexpr = R0_BLOCK
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    r0_base = tl.arange(0, R0_BLOCK)[None, :]
    rbase = r0_base
    x0 = xindex
    _tmp7 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
    for r0_offset in tl.range(0, r0_numel, R0_BLOCK):
        r0_index = r0_offset + r0_base
        r0_mask = r0_index < r0_numel
        roffset = r0_offset
        rindex = r0_index
        r0_1 = r0_index
        tmp0 = tl.load(in_ptr0 + (r0_1 + 4096*x0), r0_mask & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp2 = tl.load(in_ptr1 + (r0_1 + 4096*x0), r0_mask & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp1 = tmp0.to(tl.float32)
        tmp3 = tmp2.to(tl.float32)
        tmp4 = tmp1 + tmp3
        tmp5 = tmp4 * tmp4
        tmp6 = tl.broadcast_to(tmp5, [XBLOCK, R0_BLOCK])
        tmp8 = _tmp7 + tmp6
        _tmp7 = tl.where(r0_mask & xmask, tmp8, _tmp7)
    tmp7 = tl.sum(_tmp7, 1)[:, None]
    for r0_offset in tl.range(0, r0_numel, R0_BLOCK):
        r0_index = r0_offset + r0_base
        r0_mask = r0_index < r0_numel
        roffset = r0_offset
        rindex = r0_index
        r0_1 = r0_index
        tmp9 = tl.load(in_ptr0 + (r0_1 + 4096*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp11 = tl.load(in_ptr1 + (r0_1 + 4096*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp21 = tl.load(in_ptr2 + (r0_1), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp10 = tmp9.to(tl.float32)
        tmp12 = tmp11.to(tl.float32)
        tmp13 = tmp10 + tmp12
        tmp14 = 4096.0
        tmp15 = (tmp7 / tmp14)
        tmp16 = 1e-05
        tmp17 = tmp15 + tmp16
        tmp18 = libdevice.rsqrt(tmp17)
        tmp19 = tmp13 * tmp18
        tmp20 = tmp19.to(tl.float32)
        tmp22 = tmp20 * tmp21
        tl.store(out_ptr1 + (r0_1 + 4096*x0), tmp22, r0_mask & xmask)
def get_args():
    arg_0 = rand_strided((16384, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16)
    arg_1 = rand_strided((16384, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16)
    arg_2 = rand_strided((4096,), (1,), device='cuda:0', dtype=torch.bfloat16)
    arg_3 = rand_strided((16384, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16)
    return arg_0, arg_1, arg_2, arg_3, 16384, 4096,
def call(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        stream0 = get_raw_stream(0)
        triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_2_0.run(*args, stream=stream0)
def benchmark_all_configs(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        return triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_2_0.benchmark_all_configs(*args)
if __name__ == '__main__':
    from torch._inductor.runtime.benchmarking import benchmarker
    args = get_args()
    ms = benchmarker.benchmark(lambda: call(args), device='cuda', rep=40)
    num_gb = 0.402661376
    gb_per_s = num_gb / (ms / 1e3)
    print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")
''', device_str='cuda')
# kernel path: /tmp/torchinductor_shunting/fg/cfgveke5baxhpthp4y7p5c6zky2cpc7pw23by4evkf7o3ouvuzvp.py
# Topologically Sorted Source Nodes: [getitem, silu, getitem_1, mul_2], Original ATen: [aten.slice, aten.silu, aten.mul]
# Source node to ATen node mapping:
# getitem => slice_1
# getitem_1 => slice_2
# mul_2 => mul_32
# silu => convert_element_type_7, convert_element_type_8, mul_27, sigmoid
# Graph fragment:
# %mm_1 : Tensor "bf16[s72, 28672][28672, 1]cuda:0" = PlaceHolder[target=mm_1]
# %slice_1 : Tensor "bf16[s72, 14336][28672, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%mm_1, 1, 0, 14336), kwargs = {})
# %convert_element_type_7 : Tensor "f32[s72, 14336][14336, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%slice_1, torch.float32), kwargs = {})
# %sigmoid : Tensor "f32[s72, 14336][14336, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.sigmoid.default](args = (%convert_element_type_7,), kwargs = {})
# %mul_27 : Tensor "f32[s72, 14336][14336, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_7, %sigmoid), kwargs = {})
# %convert_element_type_8 : Tensor "bf16[s72, 14336][14336, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_27, torch.bfloat16), kwargs = {})
# %slice_2 : Tensor "bf16[s72, 14336][28672, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.slice.Tensor](args = (%mm_1, 1, 14336, 9223372036854775807), kwargs = {})
# %mul_32 : Tensor "bf16[s72, 14336][14336, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_8, %slice_2), kwargs = {})
# return %mul_32
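# Same SwiGLU kernel as benchmarked in the auto-tuning block above.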
triton_poi_fused_mul_silu_slice_2_1 = async_compile.triton('triton_poi_fused_mul_silu_slice_2_1', '''
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import torch
@triton_heuristics.pointwise(
size_hints={'x': 268435456},
filename=__file__,
triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=148, cc=100, major=10, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'native_matmul': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}], 'enable_fp_fusion': True},
inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_silu_slice_2_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 2, 'num_store': 1, 'num_reduction': 0, 'backend_hash': '0BE79B16E554042AA7B1EB4102B8EA61128454EAFD0C7CABEEF1703B1EAEF73E', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'are_deterministic_algorithms_enabled': False, 'kernel_num_gb': 1.409286144, 'kernel_flop': 0},
min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_mul_silu_slice_2_1(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = (xindex % 14336)
    x1 = xindex // 14336
    x2 = xindex
    tmp0 = tl.load(in_ptr0 + (x0 + 28672*x1), xmask).to(tl.float32)
    tmp5 = tl.load(in_ptr0 + (14336 + x0 + 28672*x1), xmask).to(tl.float32)
    tmp1 = tmp0.to(tl.float32)
    tmp2 = tl.sigmoid(tmp1)
    tmp3 = tmp1 * tmp2
    tmp4 = tmp3.to(tl.float32)
    tmp6 = tmp4 * tmp5
    tl.store(out_ptr0 + (x2), tmp6, xmask)
def get_args():
    arg_0 = rand_strided((16384, 28672), (28672, 1), device='cuda:0', dtype=torch.bfloat16)
    arg_1 = rand_strided((16384, 14336), (14336, 1), device='cuda:0', dtype=torch.bfloat16)
    return arg_0, arg_1, 234881024,
def call(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        stream0 = get_raw_stream(0)
        triton_poi_fused_mul_silu_slice_2_1.run(*args, stream=stream0)
def benchmark_all_configs(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        return triton_poi_fused_mul_silu_slice_2_1.benchmark_all_configs(*args)
if __name__ == '__main__':
    from torch._inductor.runtime.benchmarking import benchmarker
    args = get_args()
    ms = benchmarker.benchmark(lambda: call(args), device='cuda', rep=40)
    num_gb = 1.409286144
    gb_per_s = num_gb / (ms / 1e3)
    print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")
''', device_str='cuda')
# kernel path: /tmp/torchinductor_shunting/hm/chmnjo5ntuw2zivkwmaix3fbqyb7wqq6g66iqgual6qywjjsb4jy.py
# Topologically Sorted Source Nodes: [to, add, to_3, to_1, add_2, pow_2, mean_1, add_3, rsqrt_1, mul_3, to_5, mul_4], Original ATen: [aten._to_copy, aten.add, aten.pow, aten.mean, aten.rsqrt, aten.mul]
# Source node to ATen node mapping:
# add => add_9
# add_2 => add_59
# add_3 => add_72
# mean_1 => mean_1
# mul_3 => mul_48
# mul_4 => mul_53
# pow_2 => pow_2
# rsqrt_1 => rsqrt_1
# to => convert_element_type_2
# to_1 => convert_element_type_3
# to_3 => convert_element_type_11
# to_5 => convert_element_type_13
# Graph fragment:
# %mm_2 : Tensor "bf16[s72, 4096][4096, 1]cuda:0" = PlaceHolder[target=mm_2]
# %mm : Tensor "bf16[s72, 4096][4096, 1]cuda:0" = PlaceHolder[target=mm]
# %arg4_1 : Tensor "bf16[s72, 4096][4096, 1]cuda:0" = PlaceHolder[target=arg4_1]
# %buf6 : Tensor "f32[s72, 1][1, s72]cuda:0" = PlaceHolder[target=buf6]
# %arg7_1 : Tensor "bf16[4096][1]cuda:0" = PlaceHolder[target=arg7_1]
# %convert_element_type_2 : Tensor "f32[s72, 4096][4096, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mm, torch.float32), kwargs = {})
# %add_9 : Tensor "f32[s72, 4096][4096, 1]cuda:0"[num_users=3] = call_function[target=torch.ops.aten.add.Tensor](args = (%convert_element_type_2, %arg4_1), kwargs = {})
# %convert_element_type_11 : Tensor "f32[s72, 4096][4096, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mm_2, torch.float32), kwargs = {})
# %convert_element_type_3 : Tensor "bf16[s72, 4096][4096, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_9, torch.bfloat16), kwargs = {})
# %add_59 : Tensor "f32[s72, 4096][4096, 1]cuda:0"[num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%convert_element_type_11, %convert_element_type_3), kwargs = {})
# %pow_2 : Tensor "f32[s72, 4096][4096, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%add_59, 2), kwargs = {})
# %mean_1 : Tensor "f32[s72, 1][1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mean.dim](args = (%pow_2, [-1], True), kwargs = {})
# %add_72 : Tensor "f32[s72, 1][1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mean_1, 1e-05), kwargs = {})
# %rsqrt_1 : Tensor "f32[s72, 1][1, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.rsqrt.default](args = (%add_72,), kwargs = {})
# %mul_48 : Tensor "f32[s72, 4096][4096, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_59, %rsqrt_1), kwargs = {})
# %convert_element_type_13 : Tensor "bf16[s72, 4096][4096, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_48, torch.bfloat16), kwargs = {})
# %mul_53 : Tensor "bf16[s72, 4096][4096, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_13, %arg7_1), kwargs = {})
# return %buf6,%mul_53
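# Same in-place residual-add + RMSNorm kernel as benchmarked in the auto-tuning block above.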
triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_2_2 = async_compile.triton('triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_2_2', '''
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import torch
@triton_heuristics.reduction(
size_hints={'x': 16384, 'r0_': 4096},
reduction_hint=ReductionHint.INNER,
filename=__file__,
triton_meta={'signature': {'in_out_ptr0': '*bf16', 'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=148, cc=100, major=10, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'native_matmul': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}], 'enable_fp_fusion': True},
inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_2_2', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 7, 'num_store': 1, 'num_reduction': 1, 'backend_hash': '0BE79B16E554042AA7B1EB4102B8EA61128454EAFD0C7CABEEF1703B1EAEF73E', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'are_deterministic_algorithms_enabled': False, 'add_persistent_rblock': True, 'kernel_num_gb': 0.536879104, 'kernel_flop': 0}
)
@triton.jit
def triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_2_2(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
    r0_numel = 4096
    rnumel = r0_numel
    RBLOCK: tl.constexpr = R0_BLOCK
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    r0_base = tl.arange(0, R0_BLOCK)[None, :]
    rbase = r0_base
    x0 = xindex
    _tmp12 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
    for r0_offset in tl.range(0, r0_numel, R0_BLOCK):
        r0_index = r0_offset + r0_base
        r0_mask = r0_index < r0_numel
        roffset = r0_offset
        rindex = r0_index
        r0_1 = r0_index
        tmp0 = tl.load(in_out_ptr0 + (r0_1 + 4096*x0), r0_mask & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp2 = tl.load(in_ptr0 + (r0_1 + 4096*x0), r0_mask & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp4 = tl.load(in_ptr1 + (r0_1 + 4096*x0), r0_mask & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp1 = tmp0.to(tl.float32)
        tmp3 = tmp2.to(tl.float32)
        tmp5 = tmp4.to(tl.float32)
        tmp6 = tmp3 + tmp5
        tmp7 = tmp6.to(tl.float32)
        tmp8 = tmp7.to(tl.float32)
        tmp9 = tmp1 + tmp8
        tmp10 = tmp9 * tmp9
        tmp11 = tl.broadcast_to(tmp10, [XBLOCK, R0_BLOCK])
        tmp13 = _tmp12 + tmp11
        _tmp12 = tl.where(r0_mask & xmask, tmp13, _tmp12)
    tmp12 = tl.sum(_tmp12, 1)[:, None]
    for r0_offset in tl.range(0, r0_numel, R0_BLOCK):
        r0_index = r0_offset + r0_base
        r0_mask = r0_index < r0_numel
        roffset = r0_offset
        rindex = r0_index
        r0_1 = r0_index
        tmp14 = tl.load(in_out_ptr0 + (r0_1 + 4096*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp16 = tl.load(in_ptr0 + (r0_1 + 4096*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp18 = tl.load(in_ptr1 + (r0_1 + 4096*x0), r0_mask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp31 = tl.load(in_ptr2 + (r0_1), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp15 = tmp14.to(tl.float32)
        tmp17 = tmp16.to(tl.float32)
        tmp19 = tmp18.to(tl.float32)
        tmp20 = tmp17 + tmp19
        tmp21 = tmp20.to(tl.float32)
        tmp22 = tmp21.to(tl.float32)
        tmp23 = tmp15 + tmp22
        tmp24 = 4096.0
        tmp25 = (tmp12 / tmp24)
        tmp26 = 1e-05
        tmp27 = tmp25 + tmp26
        tmp28 = libdevice.rsqrt(tmp27)
        tmp29 = tmp23 * tmp28
        tmp30 = tmp29.to(tl.float32)
        tmp32 = tmp30 * tmp31
        tl.store(in_out_ptr0 + (r0_1 + 4096*x0), tmp32, r0_mask & xmask)
def get_args():
    arg_0 = rand_strided((16384, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16)
    arg_1 = rand_strided((16384, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16)
    arg_2 = rand_strided((16384, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16)
    arg_3 = rand_strided((4096,), (1,), device='cuda:0', dtype=torch.bfloat16)
    return arg_0, arg_1, arg_2, arg_3, 16384, 4096,
def call(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        stream0 = get_raw_stream(0)
        triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_2_2.run(*args, stream=stream0)
def benchmark_all_configs(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        return triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_2_2.benchmark_all_configs(*args)
if __name__ == '__main__':
    from torch._inductor.runtime.benchmarking import benchmarker
    args = get_args()
    ms = benchmarker.benchmark(lambda: call(args), device='cuda', rep=40)
    num_gb = 0.536879104
    gb_per_s = num_gb / (ms / 1e3)
    print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")
''', device_str='cuda')
async_compile.wait(globals())
del async_compile
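# Reading Runner.call() below top to bottom, the compiled graph is one
# transformer-style MLP block with dynamic batch s72: a 4096x4096 projection
# (extern mm into buf0), fused residual-add + RMSNorm, the fused gate/up mm
# (4096 -> 28672), the SwiGLU kernel, the down mm (14336 -> 4096), and a final
# fused residual-add + RMSNorm written in place into the down-projection buffer.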
class Runner:
    def __init__(self, partitions):
        self.partitions = partitions
    def recursively_apply_fns(self, fns):
        new_callables = []
        for fn, c in zip(fns, self.partitions):
            new_callables.append(fn(c))
        self.partitions = new_callables
    def call(self, args):
        arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1 = args
        args.clear()
        s72 = arg1_1
        assert_size_stride(arg0_1, (s72, 32, 128), (4096, 128, 1))
        assert_size_stride(arg2_1, (4096, 4096), (4096, 1))
        assert_size_stride(arg3_1, (4096, ), (1, ))
        assert_size_stride(arg4_1, (s72, 4096), (4096, 1))
        assert_size_stride(arg5_1, (28672, 4096), (4096, 1))
        assert_size_stride(arg6_1, (4096, 14336), (14336, 1))
        assert_size_stride(arg7_1, (4096, ), (1, ))
        with torch.cuda._DeviceGuard(0):
            torch.cuda.set_device(0)
            buf0 = empty_strided_cuda((s72, 4096), (4096, 1), torch.bfloat16)
            # Topologically Sorted Source Nodes: [view, linear], Original ATen: [aten.view, aten.t, aten.mm]
            extern_kernels.mm(reinterpret_tensor(arg0_1, (s72, 4096), (4096, 1), 0), reinterpret_tensor(arg2_1, (4096, 4096), (1, 4096), 0), out=buf0)
            del arg0_1
            del arg2_1
            buf2 = empty_strided_cuda((s72, 4096), (4096, 1), torch.bfloat16)
            # Topologically Sorted Source Nodes: [to, add, pow_1, mean, add_1, rsqrt, mul, to_2, mul_1], Original ATen: [aten._to_copy, aten.add, aten.pow, aten.mean, aten.rsqrt, aten.mul]
            stream0 = get_raw_stream(0)
            triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_2_0.run(buf0, arg4_1, arg3_1, buf2, s72, 4096, stream=stream0)
            del arg3_1
            buf3 = empty_strided_cuda((s72, 28672), (28672, 1), torch.bfloat16)
            # Topologically Sorted Source Nodes: [to, add, pow_1, mean, add_1, rsqrt, mul, to_2, mul_1, linear_1], Original ATen: [aten._to_copy, aten.add, aten.pow, aten.mean, aten.rsqrt, aten.mul, aten.t, aten.mm]
            extern_kernels.mm(buf2, reinterpret_tensor(arg5_1, (4096, 28672), (1, 4096), 0), out=buf3)
            del arg5_1
            del buf2
            buf4 = empty_strided_cuda((s72, 14336), (14336, 1), torch.bfloat16)
            # Topologically Sorted Source Nodes: [getitem, silu, getitem_1, mul_2], Original ATen: [aten.slice, aten.silu, aten.mul]
            triton_poi_fused_mul_silu_slice_2_1_xnumel = 14336*s72
            stream0 = get_raw_stream(0)
            triton_poi_fused_mul_silu_slice_2_1.run(buf3, buf4, triton_poi_fused_mul_silu_slice_2_1_xnumel, stream=stream0)
            del buf3
            buf5 = empty_strided_cuda((s72, 4096), (4096, 1), torch.bfloat16)
            # Topologically Sorted Source Nodes: [getitem, silu, getitem_1, mul_2, linear_2], Original ATen: [aten.slice, aten.silu, aten.mul, aten.t, aten.mm]
            extern_kernels.mm(buf4, reinterpret_tensor(arg6_1, (14336, 4096), (1, 14336), 0), out=buf5)
            del arg6_1
            del buf4
            buf7 = buf5; del buf5  # reuse
            # Topologically Sorted Source Nodes: [to, add, to_3, to_1, add_2, pow_2, mean_1, add_3, rsqrt_1, mul_3, to_5, mul_4], Original ATen: [aten._to_copy, aten.add, aten.pow, aten.mean, aten.rsqrt, aten.mul]
            stream0 = get_raw_stream(0)
            triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_2_2.run(buf7, buf0, arg4_1, arg7_1, s72, 4096, stream=stream0)
            del arg4_1
            del arg7_1
            del buf0
        return (buf7, )
runner = Runner(partitions=[])
call = runner.call
recursively_apply_fns = runner.recursively_apply_fns
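# Standalone benchmark entry point: rand_strided materializes representative
# inputs and the dynamic dimension s72 is pinned to 16384 (passed as arg1_1);
# running this file directly times the whole compiled graph via print_performance.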
def benchmark_compiled_module(times=10, repeat=10):
    from torch._dynamo.testing import rand_strided
    from torch._inductor.utils import print_performance
    arg0_1 = rand_strided((16384, 32, 128), (4096, 128, 1), device='cuda:0', dtype=torch.bfloat16)
    arg1_1 = 16384
    arg2_1 = rand_strided((4096, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16)
    arg3_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
    arg4_1 = rand_strided((16384, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16)
    arg5_1 = rand_strided((28672, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16)
    arg6_1 = rand_strided((4096, 14336), (14336, 1), device='cuda:0', dtype=torch.bfloat16)
    arg7_1 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
    fn = lambda: call([arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1])
    return print_performance(fn, times=times, repeat=repeat)
if __name__ == "__main__":
    from torch._inductor.wrapper_benchmark import compiled_module_main
    compiled_module_main('None', benchmark_compiled_module)