@ngimel
Created August 1, 2022 05:43
import torch
import math
from torchinductor.compile_fx import compile_fx
import torchdynamo
import torchinductor
torchinductor.config.debug=True
torchinductor.config.triton.cudagraphs=False
def _gelu_python(x):
    #return torch.nn.functional.relu(x)
    return torch.nn.functional.gelu(x)
    #return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
class BertIntermediate(torch.nn.Module):
    def __init__(self, hidden_size, intermediate_size):
        super().__init__()
        self.dense0 = torch.nn.Linear(hidden_size, intermediate_size)
        self.dense1 = torch.nn.Linear(intermediate_size, hidden_size)
        self.intermediate_act_fn = _gelu_python

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense0(hidden_states)
        hidden_states = self.intermediate_act_fn(hidden_states)
        hidden_states = self.dense1(hidden_states)
        return hidden_states
hidden_size = 768
intermediate_size = 768
dtype=torch.half
device="cuda"
mod = BertIntermediate(hidden_size, intermediate_size).to(device, dtype)
bs = 16
seq_length = 512
inp = torch.randn(bs, seq_length, hidden_size, device=device, dtype=dtype, requires_grad=True)
out = mod(inp)
gO = torch.rand_like(out)
out.backward(gO)
torch.cuda.synchronize()
#optimize_ctx = torchdynamo.optimize(compile_fx_training, nopython=True)
optimize_ctx = torchdynamo.optimize("inductor")
with optimize_ctx:
    for _ in range(10):
        inp.grad = None
        out = mod(inp)
        out.backward(gO)
FORWARD GRAPH:
opcode name target args kwargs
------------- ------------------ ----------------------- ------------------------------------------------------------------------------------------------------- --------
placeholder primals_1 primals_1 () {}
placeholder primals_2 primals_2 () {}
placeholder primals_3 primals_3 () {}
placeholder primals_4 primals_4 () {}
placeholder primals_5 primals_5 () {}
call_function permute_default aten.permute.default (primals_2, [1, 0]) {}
call_function view_default aten.view.default (primals_5, [8192, 768]) {}
call_function mm_default aten.mm.default (view_default, permute_default) {}
call_function add_tensor aten.add.Tensor (mm_default, primals_1) {}
call_function view_default_1 aten.view.default (add_tensor, [16, 512, 768]) {}
call_function mul_tensor aten.mul.Tensor (view_default_1, 0.5) {}
call_function mul_tensor_1 aten.mul.Tensor (view_default_1, 0.7071067811865476) {}
call_function sign_default aten.sign.default (mul_tensor_1,) {}
call_function abs_default aten.abs.default (mul_tensor_1,) {}
call_function mul_tensor_2 aten.mul.Tensor (abs_default, 0.3275911) {}
call_function add_tensor_1 aten.add.Tensor (mul_tensor_2, 1.0) {}
call_function reciprocal_default aten.reciprocal.default (add_tensor_1,) {}
call_function mul_tensor_3 aten.mul.Tensor (reciprocal_default, 1.0) {}
call_function mul_tensor_4 aten.mul.Tensor (mul_tensor_3, 1.061405429) {}
call_function add_tensor_2 aten.add.Tensor (mul_tensor_4, -1.453152027) {}
call_function mul_tensor_5 aten.mul.Tensor (add_tensor_2, mul_tensor_3) {}
call_function add_tensor_3 aten.add.Tensor (mul_tensor_5, 1.421413741) {}
call_function mul_tensor_6 aten.mul.Tensor (add_tensor_3, mul_tensor_3) {}
call_function add_tensor_4 aten.add.Tensor (mul_tensor_6, -0.284496736) {}
call_function mul_tensor_7 aten.mul.Tensor (add_tensor_4, mul_tensor_3) {}
call_function add_tensor_5 aten.add.Tensor (mul_tensor_7, 0.254829592) {}
call_function mul_tensor_8 aten.mul.Tensor (add_tensor_5, mul_tensor_3) {}
call_function neg_default aten.neg.default (abs_default,) {}
call_function mul_tensor_9 aten.mul.Tensor (neg_default, abs_default) {}
call_function exp_default aten.exp.default (mul_tensor_9,) {}
call_function mul_tensor_10 aten.mul.Tensor (mul_tensor_8, exp_default) {}
get_attr _tensor_constant0 _tensor_constant0 () {}
call_function sub_tensor aten.sub.Tensor (_tensor_constant0, mul_tensor_10) {}
call_function mul_tensor_11 aten.mul.Tensor (sign_default, sub_tensor) {}
call_function add_tensor_6 aten.add.Tensor (mul_tensor_11, 1.0) {}
call_function mul_tensor_12 aten.mul.Tensor (mul_tensor, add_tensor_6) {}
call_function permute_default_1 aten.permute.default (primals_4, [1, 0]) {}
call_function view_default_2 aten.view.default (mul_tensor_12, [8192, 768]) {}
call_function mm_default_1 aten.mm.default (view_default_2, permute_default_1) {}
call_function add_tensor_7 aten.add.Tensor (mm_default_1, primals_3) {}
call_function view_default_3 aten.view.default (add_tensor_7, [16, 512, 768]) {}
call_function permute_default_2 aten.permute.default (permute_default_1, [1, 0]) {}
call_function permute_default_6 aten.permute.default (permute_default, [1, 0]) {}
output output output ([view_default_3, view_default, view_default_1, view_default_2, permute_default_2, permute_default_6],) {}
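For reference, the long chain of sign/abs/reciprocal/polynomial/exp nodes above is the traced decomposition of gelu's erf: the constants are those of the Abramowitz-Stegun rational approximation erf(x) ~ sign(x) * (1 - (a1*t + a2*t^2 + ... + a5*t^5) * exp(-x^2)) with t = 1/(1 + p*|x|). A minimal sketch of that approximation (not part of the gist output; erf_approx/gelu_approx are illustrative names):

import torch

def erf_approx(x):
    # Abramowitz-Stegun style approximation; the coefficients below are the same
    # constants that appear in mul_tensor_2 .. add_tensor_5 in the graph above.
    a1, a2, a3, a4, a5 = 0.254829592, -0.284496736, 1.421413741, -1.453152027, 1.061405429
    p = 0.3275911
    s = torch.sign(x)
    ax = torch.abs(x)
    t = 1.0 / (1.0 + p * ax)
    poly = ((((a5 * t + a4) * t + a3) * t + a2) * t + a1) * t
    return s * (1.0 - poly * torch.exp(-ax * ax))

def gelu_approx(x):
    # gelu(x) = x * 0.5 * (1 + erf(x / sqrt(2))); 0.7071067811865476 = 1/sqrt(2), matching mul_tensor_1
    return x * 0.5 * (1.0 + erf_approx(x * 0.7071067811865476))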
INFO torchinductor.scheduler: RUN EXTERN buf0
INFO torchinductor.scheduler: blocked names: {MemoryDep(name='buf0', index=c0, size=(6291456,)): [SchedulerNodeBox(value=SchedulerNode(name='buf1'))], MemoryDep(name='buf1', index=c0, size=(6291456,)): [SchedulerNodeBox(value=SchedulerNode(name='buf2')), SchedulerNodeBox(value=SchedulerNode(name='buf5'))], StarDep(name='buf2'): [SchedulerNodeBox(value=ExternKernelSchedulerNode(name='buf3'))], MemoryDep(name='buf3', index=c0, size=(6291456,)): [SchedulerNodeBox(value=SchedulerNode(name='buf4'))]}
INFO torchinductor.scheduler: blocked deps: {'buf0': [SchedulerNodeBox(value=SchedulerNode(name='buf1'))], 'buf1': [SchedulerNodeBox(value=SchedulerNode(name='buf2')), SchedulerNodeBox(value=SchedulerNode(name='buf5'))], 'buf2': [SchedulerNodeBox(value=ExternKernelSchedulerNode(name='buf3'))], 'buf3': [SchedulerNodeBox(value=SchedulerNode(name='buf4'))]}
INFO torchinductor.scheduler: new fusable_deps: set()
INFO torchinductor.codegen.triton: codegen (6291456, 1)
INFO torchinductor.scheduler: NEW KERNEL
INFO torchinductor.scheduler: RUN buf1
INFO torchinductor.scheduler: RUN buf2
INFO torchinductor.scheduler: RUN buf5
INFO torchinductor.scheduler: blocked names: {StarDep(name='buf2'): [SchedulerNodeBox(value=ExternKernelSchedulerNode(name='buf3'))], MemoryDep(name='buf3', index=c0, size=(6291456,)): [SchedulerNodeBox(value=SchedulerNode(name='buf4'))]}
INFO torchinductor.scheduler: blocked deps: {'buf2': [SchedulerNodeBox(value=ExternKernelSchedulerNode(name='buf3'))], 'buf3': [SchedulerNodeBox(value=SchedulerNode(name='buf4'))]}
INFO torchinductor.scheduler: new fusable_deps: {MemoryDep(name='buf1', index=c0, size=(6291456,)), MemoryDep(name='buf5', index=c0, size=(6291456,)), MemoryDep(name='buf2', index=c0, size=(6291456,))}
INFO torchinductor.scheduler: RUN EXTERN buf3
INFO torchinductor.scheduler: blocked names: {MemoryDep(name='buf3', index=c0, size=(6291456,)): [SchedulerNodeBox(value=SchedulerNode(name='buf4'))]}
INFO torchinductor.scheduler: blocked deps: {'buf3': [SchedulerNodeBox(value=SchedulerNode(name='buf4'))]}
INFO torchinductor.scheduler: new fusable_deps: {MemoryDep(name='buf1', index=c0, size=(6291456,)), MemoryDep(name='buf5', index=c0, size=(6291456,)), MemoryDep(name='buf2', index=c0, size=(6291456,))}
INFO torchinductor.codegen.triton: codegen (6291456, 1)
INFO torchinductor.scheduler: NEW KERNEL
INFO torchinductor.scheduler: RUN buf4
INFO torchinductor.scheduler: blocked names: {}
INFO torchinductor.scheduler: blocked deps: {}
INFO torchinductor.scheduler: new fusable_deps: {MemoryDep(name='buf4', index=c0, size=(6291456,))}
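Reading the scheduler output together with the generated module below: buf0 and buf3 are the two matmuls dispatched as extern aten.mm calls; buf1 (the bias-added activation saved for backward), buf2 (the gelu output fed into the second matmul) and buf5 (a copy of the gelu output saved for backward) fuse into the single pointwise kernel0; buf4 (the final bias add) becomes kernel1.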
from ctypes import c_void_p, c_long
import torch
import random
from torch import empty_strided, as_strided
from torchinductor.codecache import CppCodeCache, TritonCodeCache
aten = torch.ops.aten
import triton
import triton.language as tl
from torchinductor.triton_ops.autotune import pointwise_heuristics
from torchinductor.triton_ops.autotune import reduction_heuristics
from torchinductor.triton_ops.autotune import grid
@pointwise_heuristics(size_hints=[8388608])
@triton.jit
def kernel0(in_ptr0, in_ptr1, out_ptr0, out_ptr1, out_ptr2, xnumel, XBLOCK : tl.constexpr):
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.reshape(tl.arange(0, XBLOCK), [XBLOCK])
    xmask = xindex < xnumel
    x0 = xindex % 768
    x0_next = xindex // 768
    x1 = x0_next % 8192
    x3 = xindex
    tmp0 = tl.load(in_ptr0 + x0 + (768*x1), xmask).to(tl.float32)
    tmp1 = tl.load(in_ptr1 + x0, xmask).to(tl.float32)
    tmp2 = tmp0 + tmp1
    tmp3 = 0.5
    tmp4 = tmp2 * tmp3
    tmp5 = 0.7071067811865476
    tmp6 = tmp2 * tmp5
    tmp7 = tmp6 < 0
    tmp8 = -1
    tmp9 = 1
    tmp10 = tmp7 | tl.zeros(tmp8.shape, tmp7.dtype) if tmp8.numel > 1 else tmp7
    tmp11 = tmp10 | tl.zeros(tmp9.shape, tmp10.dtype) if tmp9.numel > 1 else tmp10
    tmp12 = tl.where(tmp11, tmp8, tmp9)
    tmp13 = 1.0
    tmp14 = tl.abs(tmp6)
    tmp15 = 0.3275911
    tmp16 = tmp14 * tmp15
    tmp17 = tmp16 + tmp13
    tmp18 = 1 / tmp17
    tmp19 = tmp18 * tmp13
    tmp20 = 1.061405429
    tmp21 = tmp19 * tmp20
    tmp22 = -1.453152027
    tmp23 = tmp21 + tmp22
    tmp24 = tmp23 * tmp19
    tmp25 = 1.421413741
    tmp26 = tmp24 + tmp25
    tmp27 = tmp26 * tmp19
    tmp28 = -0.284496736
    tmp29 = tmp27 + tmp28
    tmp30 = tmp29 * tmp19
    tmp31 = 0.254829592
    tmp32 = tmp30 + tmp31
    tmp33 = tmp32 * tmp19
    tmp34 = -tmp14
    tmp35 = tmp34 * tmp14
    tmp36 = tl.exp(tmp35)
    tmp37 = tmp33 * tmp36
    tmp38 = tmp13 - tmp37
    tmp39 = tmp12 * tmp38
    tmp40 = tmp39 + tmp13
    tmp41 = tmp4 * tmp40
    tl.store(out_ptr0 + x0 + (768*x1), tmp2, xmask)
    tl.store(out_ptr1 + x3, tmp41, xmask)
    tl.store(out_ptr2 + x3, tmp41, xmask)
@pointwise_heuristics(size_hints=[8388608])
@triton.jit
def kernel1(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.reshape(tl.arange(0, XBLOCK), [XBLOCK])
    xmask = xindex < xnumel
    x2 = xindex
    x0 = xindex % 768
    tmp0 = tl.load(in_ptr0 + x2, xmask).to(tl.float32)
    tmp1 = tl.load(in_ptr1 + x0, xmask).to(tl.float32)
    tmp2 = tmp0 + tmp1
    tl.store(out_ptr0 + x2, tmp2, xmask)
def call(primals_1, primals_2, primals_3, primals_4, primals_5):
    primals_1_size = primals_1.size()
    s0 = primals_1_size[0]
    primals_5_size = primals_5.size()
    s1 = primals_5_size[0]
    s2 = primals_5_size[1]
    buf0 = empty_strided((8192, 768), (768, 1), device='cuda', dtype=torch.float16)
    aten.mm.out(as_strided(primals_5, (8192, 768), (768, 1)), as_strided(primals_2, (768, 768), (1, 768)), out=buf0)
    buf1 = empty_strided((8192, 768), (768, 1), device='cuda', dtype=torch.float16)
    buf2 = empty_strided((8192, 768), (768, 1), device='cuda', dtype=torch.float16)
    buf5 = empty_strided((8192, 768), (768, 1), device='cuda', dtype=torch.float16)
    kernel0[grid(6291456)](buf0, primals_1, buf1, buf2, buf5, 6291456)
    buf3 = buf0; del buf0  # reuse
    aten.mm.out(buf2, as_strided(primals_4, (768, 768), (1, 768)), out=buf3)
    buf4 = as_strided(buf2, (16, 512, 768), (393216, 768, 1)); del buf2  # reuse
    kernel1[grid(6291456)](buf3, primals_3, buf4, 6291456)
    return (buf4, as_strided(primals_5, (8192, 768), (768, 1)), as_strided(buf1, (16, 512, 768), (393216, 768, 1)), buf5, as_strided(primals_4, (768, 768), (768, 1)), as_strided(primals_2, (768, 768), (768, 1)), )
if __name__ == "__main__":
    from torchdynamo.testing import rand_strided
    from benchmarks.microbenchmarks.microbench import print_performance
    primals_1 = rand_strided((768, ), (1, ), device='cuda', dtype=torch.float16)
    primals_2 = rand_strided((768, 768), (768, 1), device='cuda', dtype=torch.float16)
    primals_3 = rand_strided((768, ), (1, ), device='cuda', dtype=torch.float16)
    primals_4 = rand_strided((768, 768), (768, 1), device='cuda', dtype=torch.float16)
    primals_5 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda', dtype=torch.float16)
    print_performance(lambda: call(primals_1, primals_2, primals_3, primals_4, primals_5))
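As a quick sanity check of the generated call() against eager mode, one can feed it the module's own parameters. This sketch is not part of the gist; it assumes that mod and inp from the repro script at the top are in scope, and it relies on the primal ordering visible in the forward graph above (primals_1/primals_2 = dense0 bias/weight, primals_3/primals_4 = dense1 bias/weight, primals_5 = the input):

# sketch: compare the compiled forward (first return value of call) with the eager fp16 output
with torch.no_grad():
    ref = mod(inp)
    res = call(mod.dense0.bias, mod.dense0.weight,
               mod.dense1.bias, mod.dense1.weight, inp)[0]
    print((ref - res).abs().max())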