from ctypes import c_void_p, c_long
import torch
import math
import random
import os
import tempfile
from math import inf, nan
from torch._inductor.hooks import run_intermediate_hooks
from torch._inductor.utils import maybe_profile
from torch import empty_strided, as_strided, device
from torch._inductor.codecache import AsyncCompile
from torch._inductor.select_algorithm import extern_kernels
aten = torch.ops.aten
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
async_compile = AsyncCompile()
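# The Triton kernels below are handed to AsyncCompile as source strings; the
# async_compile.wait(globals()) call further down blocks until compilation finishes
# and rebinds triton_poi_fused_0 / triton_poi_fused_1 to launchable kernels.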
# kernel path: /tmp/torchinductor_yhao24/xb/cxb65b7rgoqlag5gww5pcmeomprb5odoghkdbrxvopsuuyn6ohhi.py
# Original ATen:
triton_poi_fused_0 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from torch._inductor.ir import ReductionHint
from torch._inductor.ir import TileHint
from torch._inductor.triton_heuristics import pointwise
from torch._inductor.utils import instance_descriptor
from torch._inductor import triton_helpers
@pointwise(size_hints=[1024], filename=__file__, meta={'signature': {0: '*i64', 1: '*i64', 2: 'i32', 3: 'i32'}, 'device': 0, 'constants': {}, 'mutated_arg_names': [], 'configs': [instance_descriptor(divisible_by_16=(0, 1, 3), equal_to_1=())]})
@triton.jit
def triton_(in_ptr0, out_ptr0, load_seed_offset, xnumel, XBLOCK : tl.constexpr):
    xnumel = 1024
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + load_seed_offset)
    tmp1 = x0
    tmp2 = triton_helpers.randint64(tmp0, (tmp1).to(tl.uint32), 0, 10)
    tl.store(out_ptr0 + (x0), tmp2, xmask)
''')
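# triton_poi_fused_0: pointwise kernel over 1024 elements. It loads one int64 seed from
# in_ptr0[load_seed_offset] and, for every output index x0, stores a pseudo-random int64
# in the half-open range [0, 10) computed by triton_helpers.randint64(seed, x0, 0, 10).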
import triton
import triton.language as tl
from torch._inductor.triton_heuristics import grid, start_graph, end_graph
from torch._C import _cuda_getCurrentRawStream as get_cuda_stream
# kernel path: /tmp/torchinductor_yhao24/d2/cd2ayffmuyvv5aoobwqfnqwhctwg42bu25tzkns24sl4papbfbqk.py
# Original ATen:
triton_poi_fused_1 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from torch._inductor.ir import ReductionHint
from torch._inductor.ir import TileHint
from torch._inductor.triton_heuristics import pointwise
from torch._inductor.utils import instance_descriptor
from torch._inductor import triton_helpers
@pointwise(size_hints=[1024], filename=__file__, meta={'signature': {0: '*i64', 1: '*i64', 2: 'i32', 3: 'i32'}, 'device': 0, 'constants': {}, 'mutated_arg_names': [], 'configs': [instance_descriptor(divisible_by_16=(0, 1, 3), equal_to_1=())]})
@triton.jit
def triton_(in_ptr0, out_ptr0, load_seed_offset, xnumel, XBLOCK : tl.constexpr):
    xnumel = 1024
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + load_seed_offset)
    tmp1 = x0
    tmp2 = triton_helpers.randint64(tmp0, (tmp1).to(tl.uint32), 0, 11)
    tl.store(out_ptr0 + (x0), tmp2, xmask)
''')
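# triton_poi_fused_1 is identical to triton_poi_fused_0 except that the random range is [0, 11).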
async_compile.wait(globals())
del async_compile
def call(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0) # no-op to ensure context
        # Seed buffer: three int64 seeds, one per random-number kernel launch below.
        buf0 = empty_strided((3, ), (1, ), device='cuda', dtype=torch.int64)
        aten.randint.low_out(-9223372036854775808, 9223372036854775807, [3], out=buf0)
        buf1 = empty_strided((1024, ), (1, ), device='cuda', dtype=torch.int64)
        stream0 = get_cuda_stream(0)
        stream0_raw = torch._C._cuda_getCurrentStream(0)
        # Launch the second kernel on a separate CUDA stream (stream1); the other two
        # launches use the current stream (stream0).
        stream1_raw = torch.cuda.Stream()
        stream1 = stream1_raw.stream_id
        triton_poi_fused_0.run(buf0, buf1, 0, 1024, grid=grid(1024), stream=stream0)
        buf2 = empty_strided((1024, ), (1, ), device='cuda', dtype=torch.int64)
        triton_poi_fused_1.run(buf0, buf2, 1, 1024, grid=grid(1024), stream=stream1)
        buf3 = empty_strided((1024, ), (1, ), device='cuda', dtype=torch.int64)
        triton_poi_fused_0.run(buf0, buf3, 2, 1024, grid=grid(1024), stream=stream0)
        return (buf1, buf2, buf3, )
def benchmark_compiled_module(times=10, repeat=10):
    from torch._dynamo.testing import rand_strided
    from torch._inductor.utils import print_performance
    return print_performance(lambda: call([]), times=times, repeat=repeat)
if __name__ == "__main__":
    from torch._inductor.utils import compiled_module_main
    compiled_module_main('None', benchmark_compiled_module)
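# How output like this is typically obtained (a sketch, not part of the generated file):
# TorchInductor writes compiled modules such as this one under /tmp/torchinductor_<user>/
# and dumps the full module (output_code.py) when TORCH_COMPILE_DEBUG=1 is set.
# One plausible source program, assuming the three output buffers simply come from three
# torch.randint calls (an assumption, not taken from this gist):
#
#   import torch
#
#   def f():
#       a = torch.randint(0, 10, (1024,), device="cuda")
#       b = torch.randint(0, 11, (1024,), device="cuda")
#       c = torch.randint(0, 10, (1024,), device="cuda")
#       return a, b, c
#
#   torch.compile(f)()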