import torch
import math

from torchinductor.compile_fx import compile_fx
import torchdynamo
import torchinductor

torchinductor.config.debug = True
torchinductor.config.triton.cudagraphs = False

def _gelu_python(x):
    # return torch.nn.functional.relu(x)
    return torch.nn.functional.gelu(x)
    # return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))

class BertIntermediate(torch.nn.Module):
    def __init__(self, hidden_size, intermediate_size):
        super().__init__()
        self.dense0 = torch.nn.Linear(hidden_size, intermediate_size)
        self.dense1 = torch.nn.Linear(intermediate_size, hidden_size)
        self.intermediate_act_fn = _gelu_python

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense0(hidden_states)
        hidden_states = self.intermediate_act_fn(hidden_states)
        hidden_states = self.dense1(hidden_states)
        return hidden_states

hidden_size = 768
intermediate_size = 768
dtype = torch.half
device = "cuda"
mod = BertIntermediate(hidden_size, intermediate_size).to(device, dtype)

bs = 16
seq_length = 512
inp = torch.randn(bs, seq_length, hidden_size, device=device, dtype=dtype, requires_grad=True)

# Eager warm-up: one forward/backward pass to initialize grads and the CUDA context.
out = mod(inp)
gO = torch.rand_like(out)
out.backward(gO)
torch.cuda.synchronize()

# optimize_ctx = torchdynamo.optimize(compile_fx_training, nopython=True)
optimize_ctx = torchdynamo.optimize("inductor")

# Compiled forward/backward loop under dynamo + inductor.
with optimize_ctx:
    for _ in range(10):
        inp.grad = None
        out = mod(inp)
        out.backward(gO)
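
Before reading the compiler output it is worth checking that the compiled path matches eager numerics. A minimal sketch (the tolerances are my own guess for fp16; inductor's polynomial erf approximation will not be bit-identical to eager gelu):

    # Sanity check (assumed tolerances): rerun the same input in eager mode.
    inp_ref = inp.detach().clone().requires_grad_(True)
    out_ref = mod(inp_ref)              # outside optimize_ctx, so this runs eagerly
    out_ref.backward(gO)
    torch.testing.assert_close(out, out_ref, rtol=1e-2, atol=1e-2)
    torch.testing.assert_close(inp.grad, inp_ref.grad, rtol=1e-2, atol=1e-2)

With torchinductor.config.debug = True, the run prints the traced forward graph, the scheduler's fusion decisions, and the generated Triton code, reproduced below.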
FORWARD GRAPH:
opcode         name                target                   args                                  kwargs
-------------  ------------------  -----------------------  ------------------------------------  --------
placeholder    primals_1           primals_1                ()                                    {}
placeholder    primals_2           primals_2                ()                                    {}
placeholder    primals_3           primals_3                ()                                    {}
placeholder    primals_4           primals_4                ()                                    {}
placeholder    primals_5           primals_5                ()                                    {}
call_function  permute_default     aten.permute.default     (primals_2, [1, 0])                   {}
call_function  view_default        aten.view.default        (primals_5, [8192, 768])              {}
call_function  mm_default          aten.mm.default          (view_default, permute_default)       {}
call_function  add_tensor          aten.add.Tensor          (mm_default, primals_1)               {}
call_function  view_default_1      aten.view.default        (add_tensor, [16, 512, 768])          {}
call_function  mul_tensor          aten.mul.Tensor          (view_default_1, 0.5)                 {}
call_function  mul_tensor_1        aten.mul.Tensor          (view_default_1, 0.7071067811865476)  {}
call_function  sign_default        aten.sign.default        (mul_tensor_1,)                       {}
call_function  abs_default         aten.abs.default         (mul_tensor_1,)                       {}
call_function  mul_tensor_2        aten.mul.Tensor          (abs_default, 0.3275911)              {}
call_function  add_tensor_1        aten.add.Tensor          (mul_tensor_2, 1.0)                   {}
call_function  reciprocal_default  aten.reciprocal.default  (add_tensor_1,)                       {}
call_function  mul_tensor_3        aten.mul.Tensor          (reciprocal_default, 1.0)             {}
call_function  mul_tensor_4        aten.mul.Tensor          (mul_tensor_3, 1.061405429)           {}
call_function  add_tensor_2        aten.add.Tensor          (mul_tensor_4, -1.453152027)          {}
call_function  mul_tensor_5        aten.mul.Tensor          (add_tensor_2, mul_tensor_3)          {}
call_function  add_tensor_3        aten.add.Tensor          (mul_tensor_5, 1.421413741)           {}
call_function  mul_tensor_6        aten.mul.Tensor          (add_tensor_3, mul_tensor_3)          {}
call_function  add_tensor_4        aten.add.Tensor          (mul_tensor_6, -0.284496736)          {}
call_function  mul_tensor_7        aten.mul.Tensor          (add_tensor_4, mul_tensor_3)          {}
call_function  add_tensor_5        aten.add.Tensor          (mul_tensor_7, 0.254829592)           {}
call_function  mul_tensor_8        aten.mul.Tensor          (add_tensor_5, mul_tensor_3)          {}
call_function  neg_default         aten.neg.default         (abs_default,)                        {}
call_function  mul_tensor_9        aten.mul.Tensor          (neg_default, abs_default)            {}
call_function  exp_default         aten.exp.default         (mul_tensor_9,)                       {}
call_function  mul_tensor_10       aten.mul.Tensor          (mul_tensor_8, exp_default)           {}
get_attr       _tensor_constant0   _tensor_constant0        ()                                    {}
call_function  sub_tensor          aten.sub.Tensor          (_tensor_constant0, mul_tensor_10)    {}
call_function  mul_tensor_11       aten.mul.Tensor          (sign_default, sub_tensor)            {}
call_function  add_tensor_6        aten.add.Tensor          (mul_tensor_11, 1.0)                  {}
call_function  mul_tensor_12       aten.mul.Tensor          (mul_tensor, add_tensor_6)            {}
call_function  permute_default_1   aten.permute.default     (primals_4, [1, 0])                   {}
call_function  view_default_2      aten.view.default        (mul_tensor_12, [8192, 768])          {}
call_function  mm_default_1        aten.mm.default          (view_default_2, permute_default_1)   {}
call_function  add_tensor_7        aten.add.Tensor          (mm_default_1, primals_3)             {}
call_function  view_default_3      aten.view.default        (add_tensor_7, [16, 512, 768])        {}
call_function  permute_default_2   aten.permute.default     (permute_default_1, [1, 0])           {}
call_function  permute_default_6   aten.permute.default     (permute_default, [1, 0])             {}
output         output              output                   ([view_default_3, view_default, view_default_1, view_default_2, permute_default_2, permute_default_6],)  {}
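
Two things stand out in this graph. First, gelu has been decomposed into primitive ops: the constants 0.3275911, 0.254829592, -0.284496736, 1.421413741, -1.453152027 and 1.061405429 are the coefficients of the Abramowitz & Stegun rational approximation 7.1.26 of erf, evaluated in Horner form by the alternating mul_tensor/add_tensor chain, with sign/abs factored out because erf is odd:

    erf(x) ~= 1 - (a1*t + a2*t^2 + a3*t^3 + a4*t^4 + a5*t^5) * exp(-x^2),  t = 1/(1 + p*x),  x >= 0
    p = 0.3275911, a1 = 0.254829592, a2 = -0.284496736, a3 = 1.421413741, a4 = -1.453152027, a5 = 1.061405429

so gelu(x) = 0.5 * x * (1 + erf(x / sqrt(2))) maps onto mul_tensor (the 0.5*x factor), mul_tensor_1 (x / sqrt(2)), and add_tensor_6 (1 + erf); _tensor_constant0 is the literal 1.0 lifted into a graph constant. Second, the output node returns not only the result view_default_3 but also the tensors AOT Autograd saves for the backward pass: the flattened input, the pre-gelu and post-gelu activations, and the two weight matrices (permuted for the backward matmuls).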
INFO torchinductor.scheduler: RUN EXTERN buf0
INFO torchinductor.scheduler: blocked names: {MemoryDep(name='buf0', index=c0, size=(6291456,)): [SchedulerNodeBox(value=SchedulerNode(name='buf1'))], MemoryDep(name='buf1', index=c0, size=(6291456,)): [SchedulerNodeBox(value=SchedulerNode(name='buf2')), SchedulerNodeBox(value=SchedulerNode(name='buf5'))], StarDep(name='buf2'): [SchedulerNodeBox(value=ExternKernelSchedulerNode(name='buf3'))], MemoryDep(name='buf3', index=c0, size=(6291456,)): [SchedulerNodeBox(value=SchedulerNode(name='buf4'))]}
INFO torchinductor.scheduler: blocked deps: {'buf0': [SchedulerNodeBox(value=SchedulerNode(name='buf1'))], 'buf1': [SchedulerNodeBox(value=SchedulerNode(name='buf2')), SchedulerNodeBox(value=SchedulerNode(name='buf5'))], 'buf2': [SchedulerNodeBox(value=ExternKernelSchedulerNode(name='buf3'))], 'buf3': [SchedulerNodeBox(value=SchedulerNode(name='buf4'))]}
INFO torchinductor.scheduler: new fusable_deps: set()
INFO torchinductor.codegen.triton: codegen (6291456, 1)
INFO torchinductor.scheduler: NEW KERNEL
INFO torchinductor.scheduler: RUN buf1
INFO torchinductor.scheduler: RUN buf2
INFO torchinductor.scheduler: RUN buf5
INFO torchinductor.scheduler: blocked names: {StarDep(name='buf2'): [SchedulerNodeBox(value=ExternKernelSchedulerNode(name='buf3'))], MemoryDep(name='buf3', index=c0, size=(6291456,)): [SchedulerNodeBox(value=SchedulerNode(name='buf4'))]}
INFO torchinductor.scheduler: blocked deps: {'buf2': [SchedulerNodeBox(value=ExternKernelSchedulerNode(name='buf3'))], 'buf3': [SchedulerNodeBox(value=SchedulerNode(name='buf4'))]}
INFO torchinductor.scheduler: new fusable_deps: {MemoryDep(name='buf1', index=c0, size=(6291456,)), MemoryDep(name='buf5', index=c0, size=(6291456,)), MemoryDep(name='buf2', index=c0, size=(6291456,))}
INFO torchinductor.scheduler: RUN EXTERN buf3
INFO torchinductor.scheduler: blocked names: {MemoryDep(name='buf3', index=c0, size=(6291456,)): [SchedulerNodeBox(value=SchedulerNode(name='buf4'))]}
INFO torchinductor.scheduler: blocked deps: {'buf3': [SchedulerNodeBox(value=SchedulerNode(name='buf4'))]}
INFO torchinductor.scheduler: new fusable_deps: {MemoryDep(name='buf1', index=c0, size=(6291456,)), MemoryDep(name='buf5', index=c0, size=(6291456,)), MemoryDep(name='buf2', index=c0, size=(6291456,))}
INFO torchinductor.codegen.triton: codegen (6291456, 1)
INFO torchinductor.scheduler: NEW KERNEL
INFO torchinductor.scheduler: RUN buf4
INFO torchinductor.scheduler: blocked names: {}
INFO torchinductor.scheduler: blocked deps: {}
INFO torchinductor.scheduler: new fusable_deps: {MemoryDep(name='buf4', index=c0, size=(6291456,))}
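
Decoding the schedule: the two matmuls (buf0 and buf3) are dispatched as extern kernels (RUN EXTERN, i.e. cuBLAS), and all the pointwise work between them is fused into two Triton kernels. In eager terms the buffers map roughly as follows (the W0/b0/W1/b1 names are mine, inferred from the graph, standing for the two Linear layers' weights and biases):

    buf0 = inp2d @ W0.t()    # RUN EXTERN buf0: first mm
    buf1 = buf0 + b0         # kernel0: bias add, saved for backward
    buf2 = gelu(buf1)        # kernel0: fused erf-approximation gelu
    buf5 = buf2.clone()      # kernel0: second copy of the gelu output, saved for backward
    buf3 = buf2 @ W1.t()     # RUN EXTERN buf3: second mm
    buf4 = buf3 + b1         # kernel1: bias add epilogue

"blocked names"/"blocked deps" list nodes still waiting on a producer, and "fusable_deps" lists freshly produced buffers that later pointwise nodes may fuse with.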
from ctypes import c_void_p, c_long
import torch
import random
from torch import empty_strided, as_strided
from torchinductor.codecache import CppCodeCache, TritonCodeCache

aten = torch.ops.aten

import triton
import triton.language as tl
from torchinductor.triton_ops.autotune import pointwise_heuristics
from torchinductor.triton_ops.autotune import reduction_heuristics
from torchinductor.triton_ops.autotune import grid

@pointwise_heuristics(size_hints=[8388608])
@triton.jit
def kernel0(in_ptr0, in_ptr1, out_ptr0, out_ptr1, out_ptr2, xnumel, XBLOCK : tl.constexpr):
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.reshape(tl.arange(0, XBLOCK), [XBLOCK])
    xmask = xindex < xnumel
    x0 = xindex % 768
    x0_next = xindex // 768
    x1 = x0_next % 8192
    x3 = xindex
    tmp0 = tl.load(in_ptr0 + x0 + (768*x1), xmask).to(tl.float32)
    tmp1 = tl.load(in_ptr1 + x0, xmask).to(tl.float32)
    tmp2 = tmp0 + tmp1
    tmp3 = 0.5
    tmp4 = tmp2 * tmp3
    tmp5 = 0.7071067811865476
    tmp6 = tmp2 * tmp5
    tmp7 = tmp6 < 0
    tmp8 = -1
    tmp9 = 1
    tmp10 = tmp7 | tl.zeros(tmp8.shape, tmp7.dtype) if tmp8.numel > 1 else tmp7
    tmp11 = tmp10 | tl.zeros(tmp9.shape, tmp10.dtype) if tmp9.numel > 1 else tmp10
    tmp12 = tl.where(tmp11, tmp8, tmp9)
    tmp13 = 1.0
    tmp14 = tl.abs(tmp6)
    tmp15 = 0.3275911
    tmp16 = tmp14 * tmp15
    tmp17 = tmp16 + tmp13
    tmp18 = 1 / tmp17
    tmp19 = tmp18 * tmp13
    tmp20 = 1.061405429
    tmp21 = tmp19 * tmp20
    tmp22 = -1.453152027
    tmp23 = tmp21 + tmp22
    tmp24 = tmp23 * tmp19
    tmp25 = 1.421413741
    tmp26 = tmp24 + tmp25
    tmp27 = tmp26 * tmp19
    tmp28 = -0.284496736
    tmp29 = tmp27 + tmp28
    tmp30 = tmp29 * tmp19
    tmp31 = 0.254829592
    tmp32 = tmp30 + tmp31
    tmp33 = tmp32 * tmp19
    tmp34 = -tmp14
    tmp35 = tmp34 * tmp14
    tmp36 = tl.exp(tmp35)
    tmp37 = tmp33 * tmp36
    tmp38 = tmp13 - tmp37
    tmp39 = tmp12 * tmp38
    tmp40 = tmp39 + tmp13
    tmp41 = tmp4 * tmp40
    tl.store(out_ptr0 + x0 + (768*x1), tmp2, xmask)
    tl.store(out_ptr1 + x3, tmp41, xmask)
    tl.store(out_ptr2 + x3, tmp41, xmask)
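
kernel0 is the fusion of the first bias add with the whole gelu polynomial: tmp2 (the bias-added pre-activation) is stored to out_ptr0 (buf1), and the gelu result tmp41 is stored to both out_ptr1 (buf2) and out_ptr2 (buf5), which is why the kernel has three outputs. Note that the fp16 inputs are upcast with .to(tl.float32), the polynomial is evaluated in fp32, and the results land back in fp16 buffers. A reference for what this computes, under the buffer naming sketched above:

    pre  = inp2d @ W0.t() + b0             # buf1, saved for backward
    post = torch.nn.functional.gelu(pre)   # buf2 and buf5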
@pointwise_heuristics(size_hints=[8388608])
@triton.jit
def kernel1(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.reshape(tl.arange(0, XBLOCK), [XBLOCK])
    xmask = xindex < xnumel
    x2 = xindex
    x0 = xindex % 768
    tmp0 = tl.load(in_ptr0 + x2, xmask).to(tl.float32)
    tmp1 = tl.load(in_ptr1 + x0, xmask).to(tl.float32)
    tmp2 = tmp0 + tmp1
    tl.store(out_ptr0 + x2, tmp2, xmask)
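
kernel1 is just the broadcasted bias add for the second matmul: x0 = xindex % 768 re-reads the 768-element bias for every row of the (8192, 768) output. Its eager equivalent, under the same naming assumptions:

    buf4 = buf3 + b1    # (8192, 768) + (768,) broadcast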
def call(primals_1, primals_2, primals_3, primals_4, primals_5):
    primals_1_size = primals_1.size()
    s0 = primals_1_size[0]
    primals_5_size = primals_5.size()
    s1 = primals_5_size[0]
    s2 = primals_5_size[1]
    buf0 = empty_strided((8192, 768), (768, 1), device='cuda', dtype=torch.float16)
    aten.mm.out(as_strided(primals_5, (8192, 768), (768, 1)), as_strided(primals_2, (768, 768), (1, 768)), out=buf0)
    buf1 = empty_strided((8192, 768), (768, 1), device='cuda', dtype=torch.float16)
    buf2 = empty_strided((8192, 768), (768, 1), device='cuda', dtype=torch.float16)
    buf5 = empty_strided((8192, 768), (768, 1), device='cuda', dtype=torch.float16)
    kernel0[grid(6291456)](buf0, primals_1, buf1, buf2, buf5, 6291456)
    buf3 = buf0; del buf0  # reuse
    aten.mm.out(buf2, as_strided(primals_4, (768, 768), (1, 768)), out=buf3)
    buf4 = as_strided(buf2, (16, 512, 768), (393216, 768, 1)); del buf2  # reuse
    kernel1[grid(6291456)](buf3, primals_3, buf4, 6291456)
    return (buf4, as_strided(primals_5, (8192, 768), (768, 1)), as_strided(buf1, (16, 512, 768), (393216, 768, 1)), buf5, as_strided(primals_4, (768, 768), (768, 1)), as_strided(primals_2, (768, 768), (768, 1)), )
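
The `# reuse` comments are inductor's memory planning: buf3 takes over buf0's allocation (the first mm's output is dead once kernel0 has read it), and buf4 is an as_strided view over buf2's storage, which is free once the second mm has consumed it. This is presumably also why kernel0 wrote the gelu result twice: buf2 feeds the second mm and its storage is then recycled into buf4, while buf5 is the copy that survives into the return tuple for backward. A minimal illustration of the aliasing pattern (standalone, hypothetical buffers):

    a = torch.empty(8192, 768, device='cuda', dtype=torch.float16)
    b = a                                  # mirrors `buf3 = buf0; del buf0  # reuse`
    assert b.data_ptr() == a.data_ptr()    # same storage, no new allocation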
if __name__ == "__main__":
    from torchdynamo.testing import rand_strided
    from benchmarks.microbenchmarks.microbench import print_performance
    primals_1 = rand_strided((768, ), (1, ), device='cuda', dtype=torch.float16)
    primals_2 = rand_strided((768, 768), (768, 1), device='cuda', dtype=torch.float16)
    primals_3 = rand_strided((768, ), (1, ), device='cuda', dtype=torch.float16)
    primals_4 = rand_strided((768, 768), (768, 1), device='cuda', dtype=torch.float16)
    primals_5 = rand_strided((16, 512, 768), (393216, 768, 1), device='cuda', dtype=torch.float16)
    print_performance(lambda: call(primals_1, primals_2, primals_3, primals_4, primals_5))