Skip to content

Instantly share code, notes, and snippets.

@YuchenJin
Last active March 13, 2022 07:32
Show Gist options
  • Save YuchenJin/14cb5c8791d47e98203aba32b130d8fc to your computer and use it in GitHub Desktop.
from tvm.script import relax as R
from tvm import relax
resnet_mod_text = """
@tvm.script.ir_module
class Module:
@tir.prim_func
def batch_norm2(rxplaceholder_5: tir.Buffer[(1, 64, 56, 56), "float32"], rxplaceholder_6: tir.Buffer[(64,), "float32"], rxplaceholder_7: tir.Buffer[(64,), "float32"], rxplaceholder_8: tir.Buffer[(64,), "float32"], rxplaceholder_9: tir.Buffer[(64,), "float32"], T_add_2: tir.Buffer[(1, 64, 56, 56), "float32"], T_multiply_3: tir.Buffer[(64,), "float32"], T_multiply_4: tir.Buffer[(64,), "float32"]) -> None:
# function attr dict
tir.func_attr({"global_symbol": "batch_norm2", "tir.noalias": True})
# body
# with tir.block("root")
T_reshape_4 = tir.alloc_buffer([1, 64, 1, 1], dtype="float32")
T_subtract_1 = tir.alloc_buffer([1, 64, 56, 56], dtype="float32")
T_reshape_5 = tir.alloc_buffer([1, 64, 1, 1], dtype="float32")
T_add_3 = tir.alloc_buffer([1, 64, 1, 1], dtype="float32")
compute_1 = tir.alloc_buffer([1, 64, 1, 1], dtype="float32")
T_divide_1 = tir.alloc_buffer([1, 64, 56, 56], dtype="float32")
T_reshape_6 = tir.alloc_buffer([1, 64, 1, 1], dtype="float32")
T_multiply_5 = tir.alloc_buffer([1, 64, 56, 56], dtype="float32")
T_reshape_7 = tir.alloc_buffer([1, 64, 1, 1], dtype="float32")
for i0, i1, i2, i3 in tir.grid(1, 64, 1, 1):
with tir.block("T_reshape"):
ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
tir.reads(rxplaceholder_8[(ax1 + ax2 + ax3) % 64])
tir.writes(T_reshape_4[ax0, ax1, ax2, ax3])
T_reshape_4[ax0, ax1, ax2, ax3] = rxplaceholder_8[(ax0 * 64 + ax1 + ax2 + ax3) % 64]
for i0, i1, i2, i3 in tir.grid(1, 64, 56, 56):
with tir.block("T_subtract"):
ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
tir.reads(rxplaceholder_5[ax0, ax1, ax2, ax3], T_reshape_4[ax0, ax1, 0, 0])
tir.writes(T_subtract_1[ax0, ax1, ax2, ax3])
T_subtract_1[ax0, ax1, ax2, ax3] = rxplaceholder_5[ax0, ax1, ax2, ax3] - T_reshape_4[ax0, ax1, 0, 0]
for i0, i1, i2, i3 in tir.grid(1, 64, 1, 1):
with tir.block("T_reshape_1"):
ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
tir.reads(rxplaceholder_9[(ax1 + ax2 + ax3) % 64])
tir.writes(T_reshape_5[ax0, ax1, ax2, ax3])
T_reshape_5[ax0, ax1, ax2, ax3] = rxplaceholder_9[(ax0 * 64 + ax1 + ax2 + ax3) % 64]
for i0, i1, i2, i3 in tir.grid(1, 64, 1, 1):
with tir.block("T_add"):
ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
tir.reads(T_reshape_5[ax0, ax1, ax2, ax3])
tir.writes(T_add_3[ax0, ax1, ax2, ax3])
T_add_3[ax0, ax1, ax2, ax3] = T_reshape_5[ax0, ax1, ax2, ax3] + tir.float32(1.9999999494757503e-05)
for i0, i1, i2, i3 in tir.grid(1, 64, 1, 1):
with tir.block("compute"):
i0_2, i1_2, i2_2, i3_2 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
tir.reads(T_add_3[i0_2, i1_2, i2_2, i3_2])
tir.writes(compute_1[i0_2, i1_2, i2_2, i3_2])
compute_1[i0_2, i1_2, i2_2, i3_2] = tir.sqrt(T_add_3[i0_2, i1_2, i2_2, i3_2], dtype="float32")
for i0_3, i1_3, i2_3, i3_3 in tir.grid(1, 64, 56, 56):
with tir.block("T_divide"):
ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0_3, i1_3, i2_3, i3_3])
tir.reads(T_subtract_1[ax0, ax1, ax2, ax3], compute_1[ax0, ax1, 0, 0])
tir.writes(T_divide_1[ax0, ax1, ax2, ax3])
T_divide_1[ax0, ax1, ax2, ax3] = T_subtract_1[ax0, ax1, ax2, ax3] / compute_1[ax0, ax1, 0, 0]
for i0_4, i1_4, i2_4, i3_4 in tir.grid(1, 64, 1, 1):
with tir.block("T_reshape_2"):
ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0_4, i1_4, i2_4, i3_4])
tir.reads(rxplaceholder_6[(ax1 + ax2 + ax3) % 64])
tir.writes(T_reshape_6[ax0, ax1, ax2, ax3])
T_reshape_6[ax0, ax1, ax2, ax3] = rxplaceholder_6[(ax0 * 64 + ax1 + ax2 + ax3) % 64]
for i0_5, i1_5, i2_5, i3_5 in tir.grid(1, 64, 56, 56):
with tir.block("T_multiply"):
ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0_5, i1_5, i2_5, i3_5])
tir.reads(T_divide_1[ax0, ax1, ax2, ax3], T_reshape_6[ax0, ax1, 0, 0])
tir.writes(T_multiply_5[ax0, ax1, ax2, ax3])
T_multiply_5[ax0, ax1, ax2, ax3] = T_divide_1[ax0, ax1, ax2, ax3] * T_reshape_6[ax0, ax1, 0, 0]
for i0_6, i1_6, i2_6, i3_6 in tir.grid(1, 64, 1, 1):
with tir.block("T_reshape_3"):
ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0_6, i1_6, i2_6, i3_6])
tir.reads(rxplaceholder_7[(ax1 + ax2 + ax3) % 64])
tir.writes(T_reshape_7[ax0, ax1, ax2, ax3])
T_reshape_7[ax0, ax1, ax2, ax3] = rxplaceholder_7[(ax0 * 64 + ax1 + ax2 + ax3) % 64]
for i0_7, i1_7, i2_7, i3_7 in tir.grid(1, 64, 56, 56):
with tir.block("T_add_1"):
ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0_7, i1_7, i2_7, i3_7])
tir.reads(T_multiply_5[ax0, ax1, ax2, ax3], T_reshape_7[ax0, ax1, 0, 0])
tir.writes(T_add_2[ax0, ax1, ax2, ax3])
T_add_2[ax0, ax1, ax2, ax3] = T_multiply_5[ax0, ax1, ax2, ax3] + T_reshape_7[ax0, ax1, 0, 0]
for i0_8 in tir.serial(64):
with tir.block("T_multiply_1"):
ax0 = tir.axis.spatial(64, i0_8)
tir.reads(rxplaceholder_8[ax0])
tir.writes(T_multiply_3[ax0])
T_multiply_3[ax0] = rxplaceholder_8[ax0]
for i0_9 in tir.serial(64):
with tir.block("T_multiply_2"):
ax0 = tir.axis.spatial(64, i0_9)
tir.reads(rxplaceholder_9[ax0])
tir.writes(T_multiply_4[ax0])
T_multiply_4[ax0] = rxplaceholder_9[ax0]
@tir.prim_func
def conv2d_nchw1(rxplaceholder_2: tir.Buffer[(1, 64, 56, 56), "float32"], rxplaceholder_3: tir.Buffer[(64, 64, 1, 1), "float32"], conv2d_nchw_1: tir.Buffer[(1, 64, 56, 56), "float32"]) -> None:
# function attr dict
tir.func_attr({"global_symbol": "conv2d_nchw1", "tir.noalias": True})
# body
# with tir.block("root")
pad_temp_1 = tir.alloc_buffer([1, 64, 56, 56], dtype="float32")
for i0, i1, i2, i3 in tir.grid(1, 64, 56, 56):
with tir.block("pad_temp"):
i0_2, i1_2, i2_2, i3_2 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
tir.reads(rxplaceholder_2[i0_2, i1_2, i2_2, i3_2])
tir.writes(pad_temp_1[i0_2, i1_2, i2_2, i3_2])
pad_temp_1[i0_2, i1_2, i2_2, i3_2] = rxplaceholder_2[i0_2, i1_2, i2_2, i3_2]
for i0_3, i1_3, i2_3, i3_3, i4, i5, i6 in tir.grid(1, 64, 56, 56, 64, 1, 1):
with tir.block("conv2d_nchw"):
nn, ff, yy, xx, rc, ry, rx = tir.axis.remap("SSSSRRR", [i0_3, i1_3, i2_3, i3_3, i4, i5, i6])
tir.reads(conv2d_nchw_1[nn, ff, yy, xx], pad_temp_1[nn, rc, yy + ry, xx + rx], rxplaceholder_3[ff, rc, ry, rx])
tir.writes(conv2d_nchw_1[nn, ff, yy, xx])
tir.block_attr({"layout_free_placeholders":[rxplaceholder_3]})
with tir.init():
conv2d_nchw_1[nn, ff, yy, xx] = tir.float32(0)
conv2d_nchw_1[nn, ff, yy, xx] = conv2d_nchw_1[nn, ff, yy, xx] + pad_temp_1[nn, rc, yy + ry, xx + rx] * rxplaceholder_3[ff, rc, ry, rx]
@tir.prim_func
def conv2d_nchw(rxplaceholder_2: tir.Buffer[(1, 3, 224, 224), "float32"], rxplaceholder_3: tir.Buffer[(64, 3, 7, 7), "float32"], conv2d_nchw_1: tir.Buffer[(1, 64, 112, 112), "float32"]) -> None:
# function attr dict
tir.func_attr({"global_symbol": "conv2d_nchw", "tir.noalias": True})
# body
# with tir.block("root")
pad_temp_1 = tir.alloc_buffer([1, 3, 230, 230], dtype="float32")
for i0, i1, i2, i3 in tir.grid(1, 3, 230, 230):
with tir.block("pad_temp"):
i0_2, i1_2, i2_2, i3_2 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
tir.reads(rxplaceholder_2[i0_2, i1_2, i2_2 - 3, i3_2 - 3])
tir.writes(pad_temp_1[i0_2, i1_2, i2_2, i3_2])
pad_temp_1[i0_2, i1_2, i2_2, i3_2] = tir.if_then_else(i2_2 >= 3 and i2_2 < 227 and i3_2 >= 3 and i3_2 < 227, rxplaceholder_2[i0_2, i1_2, i2_2 - 3, i3_2 - 3], tir.float32(0), dtype="float32")
for i0_3, i1_3, i2_3, i3_3, i4, i5, i6 in tir.grid(1, 64, 112, 112, 3, 7, 7):
with tir.block("conv2d_nchw"):
nn, ff, yy, xx, rc, ry, rx = tir.axis.remap("SSSSRRR", [i0_3, i1_3, i2_3, i3_3, i4, i5, i6])
tir.reads(conv2d_nchw_1[nn, ff, yy, xx], pad_temp_1[nn, rc, yy * 2 + ry, xx * 2 + rx], rxplaceholder_3[ff, rc, ry, rx])
tir.writes(conv2d_nchw_1[nn, ff, yy, xx])
tir.block_attr({"layout_free_placeholders":[rxplaceholder_3]})
with tir.init():
conv2d_nchw_1[nn, ff, yy, xx] = tir.float32(0)
conv2d_nchw_1[nn, ff, yy, xx] = conv2d_nchw_1[nn, ff, yy, xx] + pad_temp_1[nn, rc, yy * 2 + ry, xx * 2 + rx] * rxplaceholder_3[ff, rc, ry, rx]
@tir.prim_func
def add(rxplaceholder_2: tir.Buffer[(1, 256, 56, 56), "float32"], rxplaceholder_3: tir.Buffer[(1, 256, 56, 56), "float32"], T_add_1: tir.Buffer[(1, 256, 56, 56), "float32"]) -> None:
# function attr dict
tir.func_attr({"global_symbol": "add", "tir.noalias": True})
# body
# with tir.block("root")
for i0, i1, i2, i3 in tir.grid(1, 256, 56, 56):
with tir.block("T_add"):
ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
tir.reads(rxplaceholder_2[ax0, ax1, ax2, ax3], rxplaceholder_3[ax0, ax1, ax2, ax3])
tir.writes(T_add_1[ax0, ax1, ax2, ax3])
T_add_1[ax0, ax1, ax2, ax3] = rxplaceholder_2[ax0, ax1, ax2, ax3] + rxplaceholder_3[ax0, ax1, ax2, ax3]
@tir.prim_func
def conv2d_nchw5(rxplaceholder_2: tir.Buffer[(1, 256, 56, 56), "float32"], rxplaceholder_3: tir.Buffer[(128, 256, 1, 1), "float32"], conv2d_nchw_1: tir.Buffer[(1, 128, 28, 28), "float32"]) -> None:
# function attr dict
tir.func_attr({"global_symbol": "conv2d_nchw5", "tir.noalias": True})
# body
# with tir.block("root")
pad_temp_1 = tir.alloc_buffer([1, 256, 56, 56], dtype="float32")
for i0, i1, i2, i3 in tir.grid(1, 256, 56, 56):
with tir.block("pad_temp"):
i0_2, i1_2, i2_2, i3_2 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
tir.reads(rxplaceholder_2[i0_2, i1_2, i2_2, i3_2])
tir.writes(pad_temp_1[i0_2, i1_2, i2_2, i3_2])
pad_temp_1[i0_2, i1_2, i2_2, i3_2] = rxplaceholder_2[i0_2, i1_2, i2_2, i3_2]
for i0_3, i1_3, i2_3, i3_3, i4, i5, i6 in tir.grid(1, 128, 28, 28, 256, 1, 1):
with tir.block("conv2d_nchw"):
nn, ff, yy, xx, rc, ry, rx = tir.axis.remap("SSSSRRR", [i0_3, i1_3, i2_3, i3_3, i4, i5, i6])
tir.reads(conv2d_nchw_1[nn, ff, yy, xx], pad_temp_1[nn, rc, yy * 2 + ry, xx * 2 + rx], rxplaceholder_3[ff, rc, ry, rx])
tir.writes(conv2d_nchw_1[nn, ff, yy, xx])
tir.block_attr({"layout_free_placeholders":[rxplaceholder_3]})
with tir.init():
conv2d_nchw_1[nn, ff, yy, xx] = tir.float32(0)
conv2d_nchw_1[nn, ff, yy, xx] = conv2d_nchw_1[nn, ff, yy, xx] + pad_temp_1[nn, rc, yy * 2 + ry, xx * 2 + rx] * rxplaceholder_3[ff, rc, ry, rx]
@tir.prim_func
def relu(rxplaceholder_1: tir.Buffer[(1, 64, 112, 112), "float32"], T_relu_1: tir.Buffer[(1, 64, 112, 112), "float32"]) -> None:
# function attr dict
tir.func_attr({"global_symbol": "relu", "tir.noalias": True})
# body
# with tir.block("root")
for i0, i1, i2, i3 in tir.grid(1, 64, 112, 112):
with tir.block("T_relu"):
ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
tir.reads(rxplaceholder_1[ax0, ax1, ax2, ax3])
tir.writes(T_relu_1[ax0, ax1, ax2, ax3])
T_relu_1[ax0, ax1, ax2, ax3] = tir.max(rxplaceholder_1[ax0, ax1, ax2, ax3], tir.float32(0))
@tir.prim_func
def batch_norm5(rxplaceholder_5: tir.Buffer[(1, 512, 28, 28), "float32"], rxplaceholder_6: tir.Buffer[(512,), "float32"], rxplaceholder_7: tir.Buffer[(512,), "float32"], rxplaceholder_8: tir.Buffer[(512,), "float32"], rxplaceholder_9: tir.Buffer[(512,), "float32"], T_add_2: tir.Buffer[(1, 512, 28, 28), "float32"], T_multiply_3: tir.Buffer[(512,), "float32"], T_multiply_4: tir.Buffer[(512,), "float32"]) -> None:
# function attr dict
tir.func_attr({"global_symbol": "batch_norm5", "tir.noalias": True})
# body
# with tir.block("root")
T_reshape_4 = tir.alloc_buffer([1, 512, 1, 1], dtype="float32")
T_subtract_1 = tir.alloc_buffer([1, 512, 28, 28], dtype="float32")
T_reshape_5 = tir.alloc_buffer([1, 512, 1, 1], dtype="float32")
T_add_3 = tir.alloc_buffer([1, 512, 1, 1], dtype="float32")
compute_1 = tir.alloc_buffer([1, 512, 1, 1], dtype="float32")
T_divide_1 = tir.alloc_buffer([1, 512, 28, 28], dtype="float32")
T_reshape_6 = tir.alloc_buffer([1, 512, 1, 1], dtype="float32")
T_multiply_5 = tir.alloc_buffer([1, 512, 28, 28], dtype="float32")
T_reshape_7 = tir.alloc_buffer([1, 512, 1, 1], dtype="float32")
for i0, i1, i2, i3 in tir.grid(1, 512, 1, 1):
with tir.block("T_reshape"):
ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
tir.reads(rxplaceholder_8[(ax1 + ax2 + ax3) % 512])
tir.writes(T_reshape_4[ax0, ax1, ax2, ax3])
T_reshape_4[ax0, ax1, ax2, ax3] = rxplaceholder_8[(ax0 * 512 + ax1 + ax2 + ax3) % 512]
for i0, i1, i2, i3 in tir.grid(1, 512, 28, 28):
with tir.block("T_subtract"):
ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
tir.reads(rxplaceholder_5[ax0, ax1, ax2, ax3], T_reshape_4[ax0, ax1, 0, 0])
tir.writes(T_subtract_1[ax0, ax1, ax2, ax3])
T_subtract_1[ax0, ax1, ax2, ax3] = rxplaceholder_5[ax0, ax1, ax2, ax3] - T_reshape_4[ax0, ax1, 0, 0]
for i0, i1, i2, i3 in tir.grid(1, 512, 1, 1):
with tir.block("T_reshape_1"):
ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
tir.reads(rxplaceholder_9[(ax1 + ax2 + ax3) % 512])
tir.writes(T_reshape_5[ax0, ax1, ax2, ax3])
T_reshape_5[ax0, ax1, ax2, ax3] = rxplaceholder_9[(ax0 * 512 + ax1 + ax2 + ax3) % 512]
for i0, i1, i2, i3 in tir.grid(1, 512, 1, 1):
with tir.block("T_add"):
ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
tir.reads(T_reshape_5[ax0, ax1, ax2, ax3])
tir.writes(T_add_3[ax0, ax1, ax2, ax3])
T_add_3[ax0, ax1, ax2, ax3] = T_reshape_5[ax0, ax1, ax2, ax3] + tir.float32(1.9999999494757503e-05)
for i0, i1, i2, i3 in tir.grid(1, 512, 1, 1):
with tir.block("compute"):
i0_2, i1_2, i2_2, i3_2 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
tir.reads(T_add_3[i0_2, i1_2, i2_2, i3_2])
tir.writes(compute_1[i0_2, i1_2, i2_2, i3_2])
compute_1[i0_2, i1_2, i2_2, i3_2] = tir.sqrt(T_add_3[i0_2, i1_2, i2_2, i3_2], dtype="float32")
for i0_3, i1_3, i2_3, i3_3 in tir.grid(1, 512, 28, 28):
with tir.block("T_divide"):
ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0_3, i1_3, i2_3, i3_3])
tir.reads(T_subtract_1[ax0, ax1, ax2, ax3], compute_1[ax0, ax1, 0, 0])
tir.writes(T_divide_1[ax0, ax1, ax2, ax3])
T_divide_1[ax0, ax1, ax2, ax3] = T_subtract_1[ax0, ax1, ax2, ax3] / compute_1[ax0, ax1, 0, 0]
for i0_4, i1_4, i2_4, i3_4 in tir.grid(1, 512, 1, 1):
with tir.block("T_reshape_2"):
ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0_4, i1_4, i2_4, i3_4])
tir.reads(rxplaceholder_6[(ax1 + ax2 + ax3) % 512])
tir.writes(T_reshape_6[ax0, ax1, ax2, ax3])
T_reshape_6[ax0, ax1, ax2, ax3] = rxplaceholder_6[(ax0 * 512 + ax1 + ax2 + ax3) % 512]
for i0_5, i1_5, i2_5, i3_5 in tir.grid(1, 512, 28, 28):
with tir.block("T_multiply"):
ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0_5, i1_5, i2_5, i3_5])
tir.reads(T_divide_1[ax0, ax1, ax2, ax3], T_reshape_6[ax0, ax1, 0, 0])
tir.writes(T_multiply_5[ax0, ax1, ax2, ax3])
T_multiply_5[ax0, ax1, ax2, ax3] = T_divide_1[ax0, ax1, ax2, ax3] * T_reshape_6[ax0, ax1, 0, 0]
for i0_6, i1_6, i2_6, i3_6 in tir.grid(1, 512, 1, 1):
with tir.block("T_reshape_3"):
ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0_6, i1_6, i2_6, i3_6])
tir.reads(rxplaceholder_7[(ax1 + ax2 + ax3) % 512])
tir.writes(T_reshape_7[ax0, ax1, ax2, ax3])
T_reshape_7[ax0, ax1, ax2, ax3] = rxplaceholder_7[(ax0 * 512 + ax1 + ax2 + ax3) % 512]
for i0_7, i1_7, i2_7, i3_7 in tir.grid(1, 512, 28, 28):
with tir.block("T_add_1"):
ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0_7, i1_7, i2_7, i3_7])
tir.reads(T_multiply_5[ax0, ax1, ax2, ax3], T_reshape_7[ax0, ax1, 0, 0])
tir.writes(T_add_2[ax0, ax1, ax2, ax3])
T_add_2[ax0, ax1, ax2, ax3] = T_multiply_5[ax0, ax1, ax2, ax3] + T_reshape_7[ax0, ax1, 0, 0]
for i0_8 in tir.serial(512):
with tir.block("T_multiply_1"):
ax0 = tir.axis.spatial(512, i0_8)
tir.reads(rxplaceholder_8[ax0])
tir.writes(T_multiply_3[ax0])
T_multiply_3[ax0] = rxplaceholder_8[ax0]
for i0_9 in tir.serial(512):
with tir.block("T_multiply_2"):
ax0 = tir.axis.spatial(512, i0_9)
tir.reads(rxplaceholder_9[ax0])
tir.writes(T_multiply_4[ax0])
T_multiply_4[ax0] = rxplaceholder_9[ax0]
@tir.prim_func
def relu6(rxplaceholder_1: tir.Buffer[(1, 1024, 14, 14), "float32"], T_relu_1: tir.Buffer[(1, 1024, 14, 14), "float32"]) -> None:
# function attr dict
tir.func_attr({"global_symbol": "relu6", "tir.noalias": True})
# body
# with tir.block("root")
for i0, i1, i2, i3 in tir.grid(1, 1024, 14, 14):
with tir.block("T_relu"):
ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
tir.reads(rxplaceholder_1[ax0, ax1, ax2, ax3])
tir.writes(T_relu_1[ax0, ax1, ax2, ax3])
T_relu_1[ax0, ax1, ax2, ax3] = tir.max(rxplaceholder_1[ax0, ax1, ax2, ax3], tir.float32(0))
@tir.prim_func
def conv2d_nchw18(rxplaceholder_2: tir.Buffer[(1, 1024, 14, 14), "float32"], rxplaceholder_3: tir.Buffer[(2048, 1024, 1, 1), "float32"], conv2d_nchw_1: tir.Buffer[(1, 2048, 7, 7), "float32"]) -> None:
# function attr dict
tir.func_attr({"global_symbol": "conv2d_nchw18", "tir.noalias": True})
# body
# with tir.block("root")
pad_temp_1 = tir.alloc_buffer([1, 1024, 14, 14], dtype="float32")
for i0, i1, i2, i3 in tir.grid(1, 1024, 14, 14):
with tir.block("pad_temp"):
i0_2, i1_2, i2_2, i3_2 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
tir.reads(rxplaceholder_2[i0_2, i1_2, i2_2, i3_2])
tir.writes(pad_temp_1[i0_2, i1_2, i2_2, i3_2])
pad_temp_1[i0_2, i1_2, i2_2, i3_2] = rxplaceholder_2[i0_2, i1_2, i2_2, i3_2]
for i0_3, i1_3, i2_3, i3_3, i4, i5, i6 in tir.grid(1, 2048, 7, 7, 1024, 1, 1):
with tir.block("conv2d_nchw"):
nn, ff, yy, xx, rc, ry, rx = tir.axis.remap("SSSSRRR", [i0_3, i1_3, i2_3, i3_3, i4, i5, i6])
tir.reads(conv2d_nchw_1[nn, ff, yy, xx], pad_temp_1[nn, rc, yy * 2 + ry, xx * 2 + rx], rxplaceholder_3[ff, rc, ry, rx])
tir.writes(conv2d_nchw_1[nn, ff, yy, xx])
tir.block_attr({"layout_free_placeholders":[rxplaceholder_3]})
with tir.init():
conv2d_nchw_1[nn, ff, yy, xx] = tir.float32(0)
conv2d_nchw_1[nn, ff, yy, xx] = conv2d_nchw_1[nn, ff, yy, xx] + pad_temp_1[nn, rc, yy * 2 + ry, xx * 2 + rx] * rxplaceholder_3[ff, rc, ry, rx]
@tir.prim_func
def dense(rxplaceholder_2: tir.Buffer[(1, 2048), "float32"], rxplaceholder_3: tir.Buffer[(1000, 2048), "float32"], T_matmul_NT_1: tir.Buffer[(1, 1000), "float32"]) -> None:
# function attr dict
tir.func_attr({"global_symbol": "dense", "tir.noalias": True})
# body
# with tir.block("root")
for i0, i1, i2 in tir.grid(1, 1000, 2048):
with tir.block("T_matmul_NT"):
i, j, k = tir.axis.remap("SSR", [i0, i1, i2])
tir.reads(T_matmul_NT_1[i, j], rxplaceholder_2[i, k], rxplaceholder_3[j, k])
tir.writes(T_matmul_NT_1[i, j])
tir.block_attr({"layout_free_placeholders":[rxplaceholder_3]})
with tir.init():
T_matmul_NT_1[i, j] = tir.float32(0)
T_matmul_NT_1[i, j] = T_matmul_NT_1[i, j] + rxplaceholder_2[i, k] * rxplaceholder_3[j, k]
@tir.prim_func
def conv2d_nchw15(rxplaceholder_2: tir.Buffer[(1, 1024, 14, 14), "float32"], rxplaceholder_3: tir.Buffer[(512, 1024, 1, 1), "float32"], conv2d_nchw_1: tir.Buffer[(1, 512, 7, 7), "float32"]) -> None:
# function attr dict
tir.func_attr({"global_symbol": "conv2d_nchw15", "tir.noalias": True})
# body
# with tir.block("root")
pad_temp_1 = tir.alloc_buffer([1, 1024, 14, 14], dtype="float32")
for i0, i1, i2, i3 in tir.grid(1, 1024, 14, 14):
with tir.block("pad_temp"):
i0_2, i1_2, i2_2, i3_2 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
tir.reads(rxplaceholder_2[i0_2, i1_2, i2_2, i3_2])
tir.writes(pad_temp_1[i0_2, i1_2, i2_2, i3_2])
pad_temp_1[i0_2, i1_2, i2_2, i3_2] = rxplaceholder_2[i0_2, i1_2, i2_2, i3_2]
for i0_3, i1_3, i2_3, i3_3, i4, i5, i6 in tir.grid(1, 512, 7, 7, 1024, 1, 1):
with tir.block("conv2d_nchw"):
nn, ff, yy, xx, rc, ry, rx = tir.axis.remap("SSSSRRR", [i0_3, i1_3, i2_3, i3_3, i4, i5, i6])
tir.reads(conv2d_nchw_1[nn, ff, yy, xx], pad_temp_1[nn, rc, yy * 2 + ry, xx * 2 + rx], rxplaceholder_3[ff, rc, ry, rx])
tir.writes(conv2d_nchw_1[nn, ff, yy, xx])
tir.block_attr({"layout_free_placeholders":[rxplaceholder_3]})
with tir.init():
conv2d_nchw_1[nn, ff, yy, xx] = tir.float32(0)
conv2d_nchw_1[nn, ff, yy, xx] = conv2d_nchw_1[nn, ff, yy, xx] + pad_temp_1[nn, rc, yy * 2 + ry, xx * 2 + rx] * rxplaceholder_3[ff, rc, ry, rx]
@tir.prim_func
def conv2d_nchw13(rxplaceholder_2: tir.Buffer[(1, 512, 28, 28), "float32"], rxplaceholder_3: tir.Buffer[(1024, 512, 1, 1), "float32"], conv2d_nchw_1: tir.Buffer[(1, 1024, 14, 14), "float32"]) -> None:
# function attr dict
tir.func_attr({"global_symbol": "conv2d_nchw13", "tir.noalias": True})
# body
# with tir.block("root")
pad_temp_1 = tir.alloc_buffer([1, 512, 28, 28], dtype="float32")
for i0, i1, i2, i3 in tir.grid(1, 512, 28, 28):
with tir.block("pad_temp"):
i0_2, i1_2, i2_2, i3_2 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
tir.reads(rxplaceholder_2[i0_2, i1_2, i2_2, i3_2])
tir.writes(pad_temp_1[i0_2, i1_2, i2_2, i3_2])
pad_temp_1[i0_2, i1_2, i2_2, i3_2] = rxplaceholder_2[i0_2, i1_2, i2_2, i3_2]
for i0_3, i1_3, i2_3, i3_3, i4, i5, i6 in tir.grid(1, 1024, 14, 14, 512, 1, 1):
with tir.block("conv2d_nchw"):
nn, ff, yy, xx, rc, ry, rx = tir.axis.remap("SSSSRRR", [i0_3, i1_3, i2_3, i3_3, i4, i5, i6])
tir.reads(conv2d_nchw_1[nn, ff, yy, xx], pad_temp_1[nn, rc, yy * 2 + ry, xx * 2 + rx], rxplaceholder_3[ff, rc, ry, rx])
tir.writes(conv2d_nchw_1[nn, ff, yy, xx])
tir.block_attr({"layout_free_placeholders":[rxplaceholder_3]})
with tir.init():
conv2d_nchw_1[nn, ff, yy, xx] = tir.float32(0)
conv2d_nchw_1[nn, ff, yy, xx] = conv2d_nchw_1[nn, ff, yy, xx] + pad_temp_1[nn, rc, yy * 2 + ry, xx * 2 + rx] * rxplaceholder_3[ff, rc, ry, rx]
@tir.prim_func
def add3(rxplaceholder_2: tir.Buffer[(1, 2048, 7, 7), "float32"], rxplaceholder_3: tir.Buffer[(1, 2048, 7, 7), "float32"], T_add_1: tir.Buffer[(1, 2048, 7, 7), "float32"]) -> None:
# function attr dict
tir.func_attr({"global_symbol": "add3", "tir.noalias": True})
# body
# with tir.block("root")
for i0, i1, i2, i3 in tir.grid(1, 2048, 7, 7):
with tir.block("T_add"):
ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
tir.reads(rxplaceholder_2[ax0, ax1, ax2, ax3], rxplaceholder_3[ax0, ax1, ax2, ax3])
tir.writes(T_add_1[ax0, ax1, ax2, ax3])
T_add_1[ax0, ax1, ax2, ax3] = rxplaceholder_2[ax0, ax1, ax2, ax3] + rxplaceholder_3[ax0, ax1, ax2, ax3]
@tir.prim_func
def bias_add(rxplaceholder_2: tir.Buffer[(1, 1000), "float32"], rxplaceholder_3: tir.Buffer[(1000,), "float32"], T_add_1: tir.Buffer[(1, 1000), "float32"]) -> None:
# function attr dict
tir.func_attr({"global_symbol": "bias_add", "tir.noalias": True})
# body
# with tir.block("root")
for i0, i1 in tir.grid(1, 1000):
with tir.block("T_add"):
ax0, ax1 = tir.axis.remap("SS", [i0, i1])
tir.reads(rxplaceholder_2[ax0, ax1], rxplaceholder_3[ax1])
tir.writes(T_add_1[ax0, ax1])
T_add_1[ax0, ax1] = rxplaceholder_2[ax0, ax1] + rxplaceholder_3[ax1]
@tir.prim_func
def batch_norm(rxplaceholder_5: tir.Buffer[(1, 3, 224, 224), "float32"], rxplaceholder_6: tir.Buffer[(3,), "float32"], rxplaceholder_7: tir.Buffer[(3,), "float32"], rxplaceholder_8: tir.Buffer[(3,), "float32"], rxplaceholder_9: tir.Buffer[(3,), "float32"], T_add_2: tir.Buffer[(1, 3, 224, 224), "float32"], T_multiply_2: tir.Buffer[(3,), "float32"], T_multiply_3: tir.Buffer[(3,), "float32"]) -> None:
# function attr dict
tir.func_attr({"global_symbol": "batch_norm", "tir.noalias": True})
# body
# with tir.block("root")
T_reshape_3 = tir.alloc_buffer([1, 3, 1, 1], dtype="float32")
T_subtract_1 = tir.alloc_buffer([1, 3, 224, 224], dtype="float32")
T_reshape_4 = tir.alloc_buffer([1, 3, 1, 1], dtype="float32")
T_add_3 = tir.alloc_buffer([1, 3, 1, 1], dtype="float32")
compute_1 = tir.alloc_buffer([1, 3, 1, 1], dtype="float32")
T_divide_1 = tir.alloc_buffer([1, 3, 224, 224], dtype="float32")
T_reshape_5 = tir.alloc_buffer([1, 3, 1, 1], dtype="float32")
for i0, i1, i2, i3 in tir.grid(1, 3, 1, 1):
with tir.block("T_reshape"):
ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
tir.reads(rxplaceholder_8[(ax1 + ax2 + ax3) % 3])
tir.writes(T_reshape_3[ax0, ax1, ax2, ax3])
T_reshape_3[ax0, ax1, ax2, ax3] = rxplaceholder_8[(ax0 * 3 + ax1 + ax2 + ax3) % 3]
for i0, i1, i2, i3 in tir.grid(1, 3, 224, 224):
with tir.block("T_subtract"):
ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
tir.reads(rxplaceholder_5[ax0, ax1, ax2, ax3], T_reshape_3[ax0, ax1, 0, 0])
tir.writes(T_subtract_1[ax0, ax1, ax2, ax3])
T_subtract_1[ax0, ax1, ax2, ax3] = rxplaceholder_5[ax0, ax1, ax2, ax3] - T_reshape_3[ax0, ax1, 0, 0]
for i0, i1, i2, i3 in tir.grid(1, 3, 1, 1):
with tir.block("T_reshape_1"):
ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
tir.reads(rxplaceholder_9[(ax1 + ax2 + ax3) % 3])
tir.writes(T_reshape_4[ax0, ax1, ax2, ax3])
T_reshape_4[ax0, ax1, ax2, ax3] = rxplaceholder_9[(ax0 * 3 + ax1 + ax2 + ax3) % 3]
for i0, i1, i2, i3 in tir.grid(1, 3, 1, 1):
with tir.block("T_add"):
ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
tir.reads(T_reshape_4[ax0, ax1, ax2, ax3])
tir.writes(T_add_3[ax0, ax1, ax2, ax3])
T_add_3[ax0, ax1, ax2, ax3] = T_reshape_4[ax0, ax1, ax2, ax3] + tir.float32(1.9999999494757503e-05)
for i0, i1, i2, i3 in tir.grid(1, 3, 1, 1):
with tir.block("compute"):
i0_2, i1_2, i2_2, i3_2 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
tir.reads(T_add_3[i0_2, i1_2, i2_2, i3_2])
tir.writes(compute_1[i0_2, i1_2, i2_2, i3_2])
compute_1[i0_2, i1_2, i2_2, i3_2] = tir.sqrt(T_add_3[i0_2, i1_2, i2_2, i3_2], dtype="float32")
for i0_3, i1_3, i2_3, i3_3 in tir.grid(1, 3, 224, 224):
with tir.block("T_divide"):
ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0_3, i1_3, i2_3, i3_3])
tir.reads(T_subtract_1[ax0, ax1, ax2, ax3], compute_1[ax0, ax1, 0, 0])
tir.writes(T_divide_1[ax0, ax1, ax2, ax3])
T_divide_1[ax0, ax1, ax2, ax3] = T_subtract_1[ax0, ax1, ax2, ax3] / compute_1[ax0, ax1, 0, 0]
for i0_4, i1_4, i2_4, i3_4 in tir.grid(1, 3, 1, 1):
with tir.block("T_reshape_2"):
ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0_4, i1_4, i2_4, i3_4])
tir.reads(rxplaceholder_7[(ax1 + ax2 + ax3) % 3])
tir.writes(T_reshape_5[ax0, ax1, ax2, ax3])
T_reshape_5[ax0, ax1, ax2, ax3] = rxplaceholder_7[(ax0 * 3 + ax1 + ax2 + ax3) % 3]
for i0_5, i1_5, i2_5, i3_5 in tir.grid(1, 3, 224, 224):
with tir.block("T_add_1"):
ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0_5, i1_5, i2_5, i3_5])
tir.reads(T_divide_1[ax0, ax1, ax2, ax3], T_reshape_5[ax0, ax1, 0, 0])
tir.writes(T_add_2[ax0, ax1, ax2, ax3])
T_add_2[ax0, ax1, ax2, ax3] = T_divide_1[ax0, ax1, ax2, ax3] + T_reshape_5[ax0, ax1, 0, 0]
for i0_6 in tir.serial(3):
with tir.block("T_multiply"):
ax0 = tir.axis.spatial(3, i0_6)
tir.reads(rxplaceholder_8[ax0])
tir.writes(T_multiply_2[ax0])
T_multiply_2[ax0] = rxplaceholder_8[ax0]
for i0_7 in tir.serial(3):
with tir.block("T_multiply_1"):
ax0 = tir.axis.spatial(3, i0_7)
tir.reads(rxplaceholder_9[ax0])
tir.writes(T_multiply_3[ax0])
T_multiply_3[ax0] = rxplaceholder_9[ax0]
@tir.prim_func
def relu5(rxplaceholder_1: tir.Buffer[(1, 256, 14, 14), "float32"], T_relu_1: tir.Buffer[(1, 256, 14, 14), "float32"]) -> None:
# function attr dict
tir.func_attr({"global_symbol": "relu5", "tir.noalias": True})
# body
# with tir.block("root")
for i0, i1, i2, i3 in tir.grid(1, 256, 14, 14):
with tir.block("T_relu"):
ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
tir.reads(rxplaceholder_1[ax0, ax1, ax2, ax3])
tir.writes(T_relu_1[ax0, ax1, ax2, ax3])
T_relu_1[ax0, ax1, ax2, ax3] = tir.max(rxplaceholder_1[ax0, ax1, ax2, ax3], tir.float32(0))
@tir.prim_func
def conv2d_nchw8(rxplaceholder_2: tir.Buffer[(1, 256, 56, 56), "float32"], rxplaceholder_3: tir.Buffer[(512, 256, 1, 1), "float32"], conv2d_nchw_1: tir.Buffer[(1, 512, 28, 28), "float32"]) -> None:
# function attr dict
tir.func_attr({"global_symbol": "conv2d_nchw8", "tir.noalias": True})
# body
# with tir.block("root")
pad_temp_1 = tir.alloc_buffer([1, 256, 56, 56], dtype="float32")
for i0, i1, i2, i3 in tir.grid(1, 256, 56, 56):
with tir.block("pad_temp"):
i0_2, i1_2, i2_2, i3_2 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
tir.reads(rxplaceholder_2[i0_2, i1_2, i2_2, i3_2])
tir.writes(pad_temp_1[i0_2, i1_2, i2_2, i3_2])
pad_temp_1[i0_2, i1_2, i2_2, i3_2] = rxplaceholder_2[i0_2, i1_2, i2_2, i3_2]
for i0_3, i1_3, i2_3, i3_3, i4, i5, i6 in tir.grid(1, 512, 28, 28, 256, 1, 1):
with tir.block("conv2d_nchw"):
nn, ff, yy, xx, rc, ry, rx = tir.axis.remap("SSSSRRR", [i0_3, i1_3, i2_3, i3_3, i4, i5, i6])
tir.reads(conv2d_nchw_1[nn, ff, yy, xx], pad_temp_1[nn, rc, yy * 2 + ry, xx * 2 + rx], rxplaceholder_3[ff, rc, ry, rx])
tir.writes(conv2d_nchw_1[nn, ff, yy, xx])
tir.block_attr({"layout_free_placeholders":[rxplaceholder_3]})
with tir.init():
conv2d_nchw_1[nn, ff, yy, xx] = tir.float32(0)
conv2d_nchw_1[nn, ff, yy, xx] = conv2d_nchw_1[nn, ff, yy, xx] + pad_temp_1[nn, rc, yy * 2 + ry, xx * 2 + rx] * rxplaceholder_3[ff, rc, ry, rx]
@tir.prim_func
def relu1(rxplaceholder_1: tir.Buffer[(1, 64, 56, 56), "float32"], T_relu_1: tir.Buffer[(1, 64, 56, 56), "float32"]) -> None:
# function attr dict
tir.func_attr({"global_symbol": "relu1", "tir.noalias": True})
# body
# with tir.block("root")
for i0, i1, i2, i3 in tir.grid(1, 64, 56, 56):
with tir.block("T_relu"):
ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
tir.reads(rxplaceholder_1[ax0, ax1, ax2, ax3])
tir.writes(T_relu_1[ax0, ax1, ax2, ax3])
T_relu_1[ax0, ax1, ax2, ax3] = tir.max(rxplaceholder_1[ax0, ax1, ax2, ax3], tir.float32(0))
@tir.prim_func
def conv2d_nchw10(rxplaceholder_2: tir.Buffer[(1, 512, 28, 28), "float32"], rxplaceholder_3: tir.Buffer[(256, 512, 1, 1), "float32"], conv2d_nchw_1: tir.Buffer[(1, 256, 14, 14), "float32"]) -> None:
# function attr dict
tir.func_attr({"global_symbol": "conv2d_nchw10", "tir.noalias": True})
# body
# with tir.block("root")
pad_temp_1 = tir.alloc_buffer([1, 512, 28, 28], dtype="float32")
for i0, i1, i2, i3 in tir.grid(1, 512, 28, 28):
with tir.block("pad_temp"):
i0_2, i1_2, i2_2, i3_2 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
tir.reads(rxplaceholder_2[i0_2, i1_2, i2_2, i3_2])
tir.writes(pad_temp_1[i0_2, i1_2, i2_2, i3_2])
pad_temp_1[i0_2, i1_2, i2_2, i3_2] = rxplaceholder_2[i0_2, i1_2, i2_2, i3_2]
for i0_3, i1_3, i2_3, i3_3, i4, i5, i6 in tir.grid(1, 256, 14, 14, 512, 1, 1):
with tir.block("conv2d_nchw"):
nn, ff, yy, xx, rc, ry, rx = tir.axis.remap("SSSSRRR", [i0_3, i1_3, i2_3, i3_3, i4, i5, i6])
tir.reads(conv2d_nchw_1[nn, ff, yy, xx], pad_temp_1[nn, rc, yy * 2 + ry, xx * 2 + rx], rxplaceholder_3[ff, rc, ry, rx])
tir.writes(conv2d_nchw_1[nn, ff, yy, xx])
tir.block_attr({"layout_free_placeholders":[rxplaceholder_3]})
with tir.init():
conv2d_nchw_1[nn, ff, yy, xx] = tir.float32(0)
conv2d_nchw_1[nn, ff, yy, xx] = conv2d_nchw_1[nn, ff, yy, xx] + pad_temp_1[nn, rc, yy * 2 + ry, xx * 2 + rx] * rxplaceholder_3[ff, rc, ry, rx]
@tir.prim_func
def conv2d_nchw19(rxplaceholder_2: tir.Buffer[(1, 2048, 7, 7), "float32"], rxplaceholder_3: tir.Buffer[(512, 2048, 1, 1), "float32"], conv2d_nchw_1: tir.Buffer[(1, 512, 7, 7), "float32"]) -> None:
# function attr dict
tir.func_attr({"global_symbol": "conv2d_nchw19", "tir.noalias": True})
# body
# with tir.block("root")
pad_temp_1 = tir.alloc_buffer([1, 2048, 7, 7], dtype="float32")
for i0, i1, i2, i3 in tir.grid(1, 2048, 7, 7):
with tir.block("pad_temp"):
i0_2, i1_2, i2_2, i3_2 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
tir.reads(rxplaceholder_2[i0_2, i1_2, i2_2, i3_2])
tir.writes(pad_temp_1[i0_2, i1_2, i2_2, i3_2])
pad_temp_1[i0_2, i1_2, i2_2, i3_2] = rxplaceholder_2[i0_2, i1_2, i2_2, i3_2]
for i0_3, i1_3, i2_3, i3_3, i4, i5, i6 in tir.grid(1, 512, 7, 7, 2048, 1, 1):
with tir.block("conv2d_nchw"):
nn, ff, yy, xx, rc, ry, rx = tir.axis.remap("SSSSRRR", [i0_3, i1_3, i2_3, i3_3, i4, i5, i6])
tir.reads(conv2d_nchw_1[nn, ff, yy, xx], pad_temp_1[nn, rc, yy + ry, xx + rx], rxplaceholder_3[ff, rc, ry, rx])
tir.writes(conv2d_nchw_1[nn, ff, yy, xx])
tir.block_attr({"layout_free_placeholders":[rxplaceholder_3]})
with tir.init():
conv2d_nchw_1[nn, ff, yy, xx] = tir.float32(0)
conv2d_nchw_1[nn, ff, yy, xx] = conv2d_nchw_1[nn, ff, yy, xx] + pad_temp_1[nn, rc, yy + ry, xx + rx] * rxplaceholder_3[ff, rc, ry, rx]
@tir.prim_func
def conv2d_nchw6(rxplaceholder_2: tir.Buffer[(1, 128, 28, 28), "float32"], rxplaceholder_3: tir.Buffer[(128, 128, 3, 3), "float32"], conv2d_nchw_1: tir.Buffer[(1, 128, 28, 28), "float32"]) -> None:
    # 3x3 convolution, NCHW layout: 128 -> 128 channels on a 28x28 feature map,
    # stride 1, padding 1 (explicit zero-padding into a 30x30 staging buffer).
    # function attr dict
    tir.func_attr({"global_symbol": "conv2d_nchw6", "tir.noalias": True})
    # body
    # with tir.block("root")
    pad_temp_1 = tir.alloc_buffer([1, 128, 30, 30], dtype="float32")
    # Zero-pad by one pixel on every side: interior positions (1..28 in each
    # spatial axis of the 30x30 buffer) copy the input, the border is 0.
    for i0, i1, i2, i3 in tir.grid(1, 128, 30, 30):
        with tir.block("pad_temp"):
            i0_2, i1_2, i2_2, i3_2 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
            tir.reads(rxplaceholder_2[i0_2, i1_2, i2_2 - 1, i3_2 - 1])
            tir.writes(pad_temp_1[i0_2, i1_2, i2_2, i3_2])
            pad_temp_1[i0_2, i1_2, i2_2, i3_2] = tir.if_then_else(i2_2 >= 1 and i2_2 < 29 and i3_2 >= 1 and i3_2 < 29, rxplaceholder_2[i0_2, i1_2, i2_2 - 1, i3_2 - 1], tir.float32(0), dtype="float32")
    # Direct convolution: rc reduces over 128 input channels, ry/rx over the
    # 3x3 kernel window read from the padded buffer.
    for i0_3, i1_3, i2_3, i3_3, i4, i5, i6 in tir.grid(1, 128, 28, 28, 128, 3, 3):
        with tir.block("conv2d_nchw"):
            nn, ff, yy, xx, rc, ry, rx = tir.axis.remap("SSSSRRR", [i0_3, i1_3, i2_3, i3_3, i4, i5, i6])
            tir.reads(conv2d_nchw_1[nn, ff, yy, xx], pad_temp_1[nn, rc, yy + ry, xx + rx], rxplaceholder_3[ff, rc, ry, rx])
            tir.writes(conv2d_nchw_1[nn, ff, yy, xx])
            tir.block_attr({"layout_free_placeholders":[rxplaceholder_3]})
            with tir.init():
                # Zero-initialise the accumulator before the reduction.
                conv2d_nchw_1[nn, ff, yy, xx] = tir.float32(0)
            conv2d_nchw_1[nn, ff, yy, xx] = conv2d_nchw_1[nn, ff, yy, xx] + pad_temp_1[nn, rc, yy + ry, xx + rx] * rxplaceholder_3[ff, rc, ry, rx]
@tir.prim_func
def batch_norm3(rxplaceholder_5: tir.Buffer[(1, 256, 56, 56), "float32"], rxplaceholder_6: tir.Buffer[(256,), "float32"], rxplaceholder_7: tir.Buffer[(256,), "float32"], rxplaceholder_8: tir.Buffer[(256,), "float32"], rxplaceholder_9: tir.Buffer[(256,), "float32"], T_add_2: tir.Buffer[(1, 256, 56, 56), "float32"], T_multiply_3: tir.Buffer[(256,), "float32"], T_multiply_4: tir.Buffer[(256,), "float32"]) -> None:
    # Inference-mode batch norm over a (1, 256, 56, 56) NCHW input:
    #   T_add_2 = (x - rxplaceholder_8) / sqrt(rxplaceholder_9 + eps)
    #             * rxplaceholder_6 + rxplaceholder_7
    # From the arithmetic below, rxplaceholder_8 acts as the per-channel mean,
    # rxplaceholder_9 as the variance, rxplaceholder_6 as the scale (gamma),
    # rxplaceholder_7 as the shift (beta).  T_multiply_3 / T_multiply_4 pass
    # rxplaceholder_8 / rxplaceholder_9 through unchanged (running stats out).
    # function attr dict
    tir.func_attr({"global_symbol": "batch_norm3", "tir.noalias": True})
    # body
    # with tir.block("root")
    T_reshape_4 = tir.alloc_buffer([1, 256, 1, 1], dtype="float32")
    T_subtract_1 = tir.alloc_buffer([1, 256, 56, 56], dtype="float32")
    T_reshape_5 = tir.alloc_buffer([1, 256, 1, 1], dtype="float32")
    T_add_3 = tir.alloc_buffer([1, 256, 1, 1], dtype="float32")
    compute_1 = tir.alloc_buffer([1, 256, 1, 1], dtype="float32")
    T_divide_1 = tir.alloc_buffer([1, 256, 56, 56], dtype="float32")
    T_reshape_6 = tir.alloc_buffer([1, 256, 1, 1], dtype="float32")
    T_multiply_5 = tir.alloc_buffer([1, 256, 56, 56], dtype="float32")
    T_reshape_7 = tir.alloc_buffer([1, 256, 1, 1], dtype="float32")
    # Reshape the mean vector (256,) to (1, 256, 1, 1) for broadcasting.
    # NOTE(review): the reads annotation drops the ax0 * 256 term that appears
    # in the statement; equivalent here because ax0 is always 0.
    for i0, i1, i2, i3 in tir.grid(1, 256, 1, 1):
        with tir.block("T_reshape"):
            ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
            tir.reads(rxplaceholder_8[(ax1 + ax2 + ax3) % 256])
            tir.writes(T_reshape_4[ax0, ax1, ax2, ax3])
            T_reshape_4[ax0, ax1, ax2, ax3] = rxplaceholder_8[(ax0 * 256 + ax1 + ax2 + ax3) % 256]
    # x - mean, broadcast over the 56x56 spatial dims.
    for i0, i1, i2, i3 in tir.grid(1, 256, 56, 56):
        with tir.block("T_subtract"):
            ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
            tir.reads(rxplaceholder_5[ax0, ax1, ax2, ax3], T_reshape_4[ax0, ax1, 0, 0])
            tir.writes(T_subtract_1[ax0, ax1, ax2, ax3])
            T_subtract_1[ax0, ax1, ax2, ax3] = rxplaceholder_5[ax0, ax1, ax2, ax3] - T_reshape_4[ax0, ax1, 0, 0]
    # Reshape the variance vector (256,) to (1, 256, 1, 1).
    for i0, i1, i2, i3 in tir.grid(1, 256, 1, 1):
        with tir.block("T_reshape_1"):
            ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
            tir.reads(rxplaceholder_9[(ax1 + ax2 + ax3) % 256])
            tir.writes(T_reshape_5[ax0, ax1, ax2, ax3])
            T_reshape_5[ax0, ax1, ax2, ax3] = rxplaceholder_9[(ax0 * 256 + ax1 + ax2 + ax3) % 256]
    # var + eps (eps ~= 2e-5 after float32 rounding).
    for i0, i1, i2, i3 in tir.grid(1, 256, 1, 1):
        with tir.block("T_add"):
            ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
            tir.reads(T_reshape_5[ax0, ax1, ax2, ax3])
            tir.writes(T_add_3[ax0, ax1, ax2, ax3])
            T_add_3[ax0, ax1, ax2, ax3] = T_reshape_5[ax0, ax1, ax2, ax3] + tir.float32(1.9999999494757503e-05)
    # sqrt(var + eps)
    for i0, i1, i2, i3 in tir.grid(1, 256, 1, 1):
        with tir.block("compute"):
            i0_2, i1_2, i2_2, i3_2 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
            tir.reads(T_add_3[i0_2, i1_2, i2_2, i3_2])
            tir.writes(compute_1[i0_2, i1_2, i2_2, i3_2])
            compute_1[i0_2, i1_2, i2_2, i3_2] = tir.sqrt(T_add_3[i0_2, i1_2, i2_2, i3_2], dtype="float32")
    # (x - mean) / sqrt(var + eps)
    for i0_3, i1_3, i2_3, i3_3 in tir.grid(1, 256, 56, 56):
        with tir.block("T_divide"):
            ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0_3, i1_3, i2_3, i3_3])
            tir.reads(T_subtract_1[ax0, ax1, ax2, ax3], compute_1[ax0, ax1, 0, 0])
            tir.writes(T_divide_1[ax0, ax1, ax2, ax3])
            T_divide_1[ax0, ax1, ax2, ax3] = T_subtract_1[ax0, ax1, ax2, ax3] / compute_1[ax0, ax1, 0, 0]
    # Reshape the scale vector (gamma) to (1, 256, 1, 1).
    for i0_4, i1_4, i2_4, i3_4 in tir.grid(1, 256, 1, 1):
        with tir.block("T_reshape_2"):
            ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0_4, i1_4, i2_4, i3_4])
            tir.reads(rxplaceholder_6[(ax1 + ax2 + ax3) % 256])
            tir.writes(T_reshape_6[ax0, ax1, ax2, ax3])
            T_reshape_6[ax0, ax1, ax2, ax3] = rxplaceholder_6[(ax0 * 256 + ax1 + ax2 + ax3) % 256]
    # normalized * gamma
    for i0_5, i1_5, i2_5, i3_5 in tir.grid(1, 256, 56, 56):
        with tir.block("T_multiply"):
            ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0_5, i1_5, i2_5, i3_5])
            tir.reads(T_divide_1[ax0, ax1, ax2, ax3], T_reshape_6[ax0, ax1, 0, 0])
            tir.writes(T_multiply_5[ax0, ax1, ax2, ax3])
            T_multiply_5[ax0, ax1, ax2, ax3] = T_divide_1[ax0, ax1, ax2, ax3] * T_reshape_6[ax0, ax1, 0, 0]
    # Reshape the shift vector (beta) to (1, 256, 1, 1).
    for i0_6, i1_6, i2_6, i3_6 in tir.grid(1, 256, 1, 1):
        with tir.block("T_reshape_3"):
            ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0_6, i1_6, i2_6, i3_6])
            tir.reads(rxplaceholder_7[(ax1 + ax2 + ax3) % 256])
            tir.writes(T_reshape_7[ax0, ax1, ax2, ax3])
            T_reshape_7[ax0, ax1, ax2, ax3] = rxplaceholder_7[(ax0 * 256 + ax1 + ax2 + ax3) % 256]
    # normalized * gamma + beta -> first output.
    for i0_7, i1_7, i2_7, i3_7 in tir.grid(1, 256, 56, 56):
        with tir.block("T_add_1"):
            ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0_7, i1_7, i2_7, i3_7])
            tir.reads(T_multiply_5[ax0, ax1, ax2, ax3], T_reshape_7[ax0, ax1, 0, 0])
            tir.writes(T_add_2[ax0, ax1, ax2, ax3])
            T_add_2[ax0, ax1, ax2, ax3] = T_multiply_5[ax0, ax1, ax2, ax3] + T_reshape_7[ax0, ax1, 0, 0]
    # Copy the mean through to the second output unchanged.
    for i0_8 in tir.serial(256):
        with tir.block("T_multiply_1"):
            ax0 = tir.axis.spatial(256, i0_8)
            tir.reads(rxplaceholder_8[ax0])
            tir.writes(T_multiply_3[ax0])
            T_multiply_3[ax0] = rxplaceholder_8[ax0]
    # Copy the variance through to the third output unchanged.
    for i0_9 in tir.serial(256):
        with tir.block("T_multiply_2"):
            ax0 = tir.axis.spatial(256, i0_9)
            tir.reads(rxplaceholder_9[ax0])
            tir.writes(T_multiply_4[ax0])
            T_multiply_4[ax0] = rxplaceholder_9[ax0]
@tir.prim_func
def batch_norm4(rxplaceholder_5: tir.Buffer[(1, 128, 28, 28), "float32"], rxplaceholder_6: tir.Buffer[(128,), "float32"], rxplaceholder_7: tir.Buffer[(128,), "float32"], rxplaceholder_8: tir.Buffer[(128,), "float32"], rxplaceholder_9: tir.Buffer[(128,), "float32"], T_add_2: tir.Buffer[(1, 128, 28, 28), "float32"], T_multiply_3: tir.Buffer[(128,), "float32"], T_multiply_4: tir.Buffer[(128,), "float32"]) -> None:
    # Inference-mode batch norm over a (1, 128, 28, 28) NCHW input:
    #   T_add_2 = (x - rxplaceholder_8) / sqrt(rxplaceholder_9 + eps)
    #             * rxplaceholder_6 + rxplaceholder_7
    # From the arithmetic below, rxplaceholder_8 acts as the per-channel mean,
    # rxplaceholder_9 as the variance, rxplaceholder_6 as the scale (gamma),
    # rxplaceholder_7 as the shift (beta).  T_multiply_3 / T_multiply_4 pass
    # rxplaceholder_8 / rxplaceholder_9 through unchanged (running stats out).
    # function attr dict
    tir.func_attr({"global_symbol": "batch_norm4", "tir.noalias": True})
    # body
    # with tir.block("root")
    T_reshape_4 = tir.alloc_buffer([1, 128, 1, 1], dtype="float32")
    T_subtract_1 = tir.alloc_buffer([1, 128, 28, 28], dtype="float32")
    T_reshape_5 = tir.alloc_buffer([1, 128, 1, 1], dtype="float32")
    T_add_3 = tir.alloc_buffer([1, 128, 1, 1], dtype="float32")
    compute_1 = tir.alloc_buffer([1, 128, 1, 1], dtype="float32")
    T_divide_1 = tir.alloc_buffer([1, 128, 28, 28], dtype="float32")
    T_reshape_6 = tir.alloc_buffer([1, 128, 1, 1], dtype="float32")
    T_multiply_5 = tir.alloc_buffer([1, 128, 28, 28], dtype="float32")
    T_reshape_7 = tir.alloc_buffer([1, 128, 1, 1], dtype="float32")
    # Reshape the mean vector (128,) to (1, 128, 1, 1) for broadcasting.
    # NOTE(review): the reads annotation drops the ax0 * 128 term that appears
    # in the statement; equivalent here because ax0 is always 0.
    for i0, i1, i2, i3 in tir.grid(1, 128, 1, 1):
        with tir.block("T_reshape"):
            ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
            tir.reads(rxplaceholder_8[(ax1 + ax2 + ax3) % 128])
            tir.writes(T_reshape_4[ax0, ax1, ax2, ax3])
            T_reshape_4[ax0, ax1, ax2, ax3] = rxplaceholder_8[(ax0 * 128 + ax1 + ax2 + ax3) % 128]
    # x - mean, broadcast over the 28x28 spatial dims.
    for i0, i1, i2, i3 in tir.grid(1, 128, 28, 28):
        with tir.block("T_subtract"):
            ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
            tir.reads(rxplaceholder_5[ax0, ax1, ax2, ax3], T_reshape_4[ax0, ax1, 0, 0])
            tir.writes(T_subtract_1[ax0, ax1, ax2, ax3])
            T_subtract_1[ax0, ax1, ax2, ax3] = rxplaceholder_5[ax0, ax1, ax2, ax3] - T_reshape_4[ax0, ax1, 0, 0]
    # Reshape the variance vector (128,) to (1, 128, 1, 1).
    for i0, i1, i2, i3 in tir.grid(1, 128, 1, 1):
        with tir.block("T_reshape_1"):
            ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
            tir.reads(rxplaceholder_9[(ax1 + ax2 + ax3) % 128])
            tir.writes(T_reshape_5[ax0, ax1, ax2, ax3])
            T_reshape_5[ax0, ax1, ax2, ax3] = rxplaceholder_9[(ax0 * 128 + ax1 + ax2 + ax3) % 128]
    # var + eps (eps ~= 2e-5 after float32 rounding).
    for i0, i1, i2, i3 in tir.grid(1, 128, 1, 1):
        with tir.block("T_add"):
            ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
            tir.reads(T_reshape_5[ax0, ax1, ax2, ax3])
            tir.writes(T_add_3[ax0, ax1, ax2, ax3])
            T_add_3[ax0, ax1, ax2, ax3] = T_reshape_5[ax0, ax1, ax2, ax3] + tir.float32(1.9999999494757503e-05)
    # sqrt(var + eps)
    for i0, i1, i2, i3 in tir.grid(1, 128, 1, 1):
        with tir.block("compute"):
            i0_2, i1_2, i2_2, i3_2 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
            tir.reads(T_add_3[i0_2, i1_2, i2_2, i3_2])
            tir.writes(compute_1[i0_2, i1_2, i2_2, i3_2])
            compute_1[i0_2, i1_2, i2_2, i3_2] = tir.sqrt(T_add_3[i0_2, i1_2, i2_2, i3_2], dtype="float32")
    # (x - mean) / sqrt(var + eps)
    for i0_3, i1_3, i2_3, i3_3 in tir.grid(1, 128, 28, 28):
        with tir.block("T_divide"):
            ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0_3, i1_3, i2_3, i3_3])
            tir.reads(T_subtract_1[ax0, ax1, ax2, ax3], compute_1[ax0, ax1, 0, 0])
            tir.writes(T_divide_1[ax0, ax1, ax2, ax3])
            T_divide_1[ax0, ax1, ax2, ax3] = T_subtract_1[ax0, ax1, ax2, ax3] / compute_1[ax0, ax1, 0, 0]
    # Reshape the scale vector (gamma) to (1, 128, 1, 1).
    for i0_4, i1_4, i2_4, i3_4 in tir.grid(1, 128, 1, 1):
        with tir.block("T_reshape_2"):
            ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0_4, i1_4, i2_4, i3_4])
            tir.reads(rxplaceholder_6[(ax1 + ax2 + ax3) % 128])
            tir.writes(T_reshape_6[ax0, ax1, ax2, ax3])
            T_reshape_6[ax0, ax1, ax2, ax3] = rxplaceholder_6[(ax0 * 128 + ax1 + ax2 + ax3) % 128]
    # normalized * gamma
    for i0_5, i1_5, i2_5, i3_5 in tir.grid(1, 128, 28, 28):
        with tir.block("T_multiply"):
            ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0_5, i1_5, i2_5, i3_5])
            tir.reads(T_divide_1[ax0, ax1, ax2, ax3], T_reshape_6[ax0, ax1, 0, 0])
            tir.writes(T_multiply_5[ax0, ax1, ax2, ax3])
            T_multiply_5[ax0, ax1, ax2, ax3] = T_divide_1[ax0, ax1, ax2, ax3] * T_reshape_6[ax0, ax1, 0, 0]
    # Reshape the shift vector (beta) to (1, 128, 1, 1).
    for i0_6, i1_6, i2_6, i3_6 in tir.grid(1, 128, 1, 1):
        with tir.block("T_reshape_3"):
            ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0_6, i1_6, i2_6, i3_6])
            tir.reads(rxplaceholder_7[(ax1 + ax2 + ax3) % 128])
            tir.writes(T_reshape_7[ax0, ax1, ax2, ax3])
            T_reshape_7[ax0, ax1, ax2, ax3] = rxplaceholder_7[(ax0 * 128 + ax1 + ax2 + ax3) % 128]
    # normalized * gamma + beta -> first output.
    for i0_7, i1_7, i2_7, i3_7 in tir.grid(1, 128, 28, 28):
        with tir.block("T_add_1"):
            ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0_7, i1_7, i2_7, i3_7])
            tir.reads(T_multiply_5[ax0, ax1, ax2, ax3], T_reshape_7[ax0, ax1, 0, 0])
            tir.writes(T_add_2[ax0, ax1, ax2, ax3])
            T_add_2[ax0, ax1, ax2, ax3] = T_multiply_5[ax0, ax1, ax2, ax3] + T_reshape_7[ax0, ax1, 0, 0]
    # Copy the mean through to the second output unchanged.
    for i0_8 in tir.serial(128):
        with tir.block("T_multiply_1"):
            ax0 = tir.axis.spatial(128, i0_8)
            tir.reads(rxplaceholder_8[ax0])
            tir.writes(T_multiply_3[ax0])
            T_multiply_3[ax0] = rxplaceholder_8[ax0]
    # Copy the variance through to the third output unchanged.
    for i0_9 in tir.serial(128):
        with tir.block("T_multiply_2"):
            ax0 = tir.axis.spatial(128, i0_9)
            tir.reads(rxplaceholder_9[ax0])
            tir.writes(T_multiply_4[ax0])
            T_multiply_4[ax0] = rxplaceholder_9[ax0]
@tir.prim_func
def conv2d_nchw3(rxplaceholder_2: tir.Buffer[(1, 64, 56, 56), "float32"], rxplaceholder_3: tir.Buffer[(256, 64, 1, 1), "float32"], conv2d_nchw_1: tir.Buffer[(1, 256, 56, 56), "float32"]) -> None:
    # 1x1 convolution, NCHW layout: 64 -> 256 channels on a 56x56 feature map,
    # stride 1, no padding (the "pad_temp" stage is a plain copy of the input).
    # function attr dict
    tir.func_attr({"global_symbol": "conv2d_nchw3", "tir.noalias": True})
    # body
    # with tir.block("root")
    pad_temp_1 = tir.alloc_buffer([1, 64, 56, 56], dtype="float32")
    # Identity "padding" stage: the kernel is 1x1, so no halo is required.
    for i0, i1, i2, i3 in tir.grid(1, 64, 56, 56):
        with tir.block("pad_temp"):
            i0_2, i1_2, i2_2, i3_2 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
            tir.reads(rxplaceholder_2[i0_2, i1_2, i2_2, i3_2])
            tir.writes(pad_temp_1[i0_2, i1_2, i2_2, i3_2])
            pad_temp_1[i0_2, i1_2, i2_2, i3_2] = rxplaceholder_2[i0_2, i1_2, i2_2, i3_2]
    # Direct convolution: rc reduces over the 64 input channels; ry/rx have
    # unit extent because the kernel is 1x1.
    for i0_3, i1_3, i2_3, i3_3, i4, i5, i6 in tir.grid(1, 256, 56, 56, 64, 1, 1):
        with tir.block("conv2d_nchw"):
            nn, ff, yy, xx, rc, ry, rx = tir.axis.remap("SSSSRRR", [i0_3, i1_3, i2_3, i3_3, i4, i5, i6])
            tir.reads(conv2d_nchw_1[nn, ff, yy, xx], pad_temp_1[nn, rc, yy + ry, xx + rx], rxplaceholder_3[ff, rc, ry, rx])
            tir.writes(conv2d_nchw_1[nn, ff, yy, xx])
            tir.block_attr({"layout_free_placeholders":[rxplaceholder_3]})
            with tir.init():
                # Zero-initialise the accumulator before the reduction.
                conv2d_nchw_1[nn, ff, yy, xx] = tir.float32(0)
            conv2d_nchw_1[nn, ff, yy, xx] = conv2d_nchw_1[nn, ff, yy, xx] + pad_temp_1[nn, rc, yy + ry, xx + rx] * rxplaceholder_3[ff, rc, ry, rx]
@tir.prim_func
def conv2d_nchw12(rxplaceholder_2: tir.Buffer[(1, 256, 14, 14), "float32"], rxplaceholder_3: tir.Buffer[(1024, 256, 1, 1), "float32"], conv2d_nchw_1: tir.Buffer[(1, 1024, 14, 14), "float32"]) -> None:
    # 1x1 convolution, NCHW layout: 256 -> 1024 channels on a 14x14 feature
    # map, stride 1, no padding ("pad_temp" is a plain copy of the input).
    # function attr dict
    tir.func_attr({"global_symbol": "conv2d_nchw12", "tir.noalias": True})
    # body
    # with tir.block("root")
    pad_temp_1 = tir.alloc_buffer([1, 256, 14, 14], dtype="float32")
    # Identity "padding" stage: the kernel is 1x1, so no halo is required.
    for i0, i1, i2, i3 in tir.grid(1, 256, 14, 14):
        with tir.block("pad_temp"):
            i0_2, i1_2, i2_2, i3_2 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
            tir.reads(rxplaceholder_2[i0_2, i1_2, i2_2, i3_2])
            tir.writes(pad_temp_1[i0_2, i1_2, i2_2, i3_2])
            pad_temp_1[i0_2, i1_2, i2_2, i3_2] = rxplaceholder_2[i0_2, i1_2, i2_2, i3_2]
    # Direct convolution: rc reduces over the 256 input channels; ry/rx have
    # unit extent because the kernel is 1x1.
    for i0_3, i1_3, i2_3, i3_3, i4, i5, i6 in tir.grid(1, 1024, 14, 14, 256, 1, 1):
        with tir.block("conv2d_nchw"):
            nn, ff, yy, xx, rc, ry, rx = tir.axis.remap("SSSSRRR", [i0_3, i1_3, i2_3, i3_3, i4, i5, i6])
            tir.reads(conv2d_nchw_1[nn, ff, yy, xx], pad_temp_1[nn, rc, yy + ry, xx + rx], rxplaceholder_3[ff, rc, ry, rx])
            tir.writes(conv2d_nchw_1[nn, ff, yy, xx])
            tir.block_attr({"layout_free_placeholders":[rxplaceholder_3]})
            with tir.init():
                # Zero-initialise the accumulator before the reduction.
                conv2d_nchw_1[nn, ff, yy, xx] = tir.float32(0)
            conv2d_nchw_1[nn, ff, yy, xx] = conv2d_nchw_1[nn, ff, yy, xx] + pad_temp_1[nn, rc, yy + ry, xx + rx] * rxplaceholder_3[ff, rc, ry, rx]
@tir.prim_func
def conv2d_nchw14(rxplaceholder_2: tir.Buffer[(1, 1024, 14, 14), "float32"], rxplaceholder_3: tir.Buffer[(256, 1024, 1, 1), "float32"], conv2d_nchw_1: tir.Buffer[(1, 256, 14, 14), "float32"]) -> None:
    # 1x1 convolution, NCHW layout: 1024 -> 256 channels on a 14x14 feature
    # map, stride 1, no padding ("pad_temp" is a plain copy of the input).
    # function attr dict
    tir.func_attr({"global_symbol": "conv2d_nchw14", "tir.noalias": True})
    # body
    # with tir.block("root")
    pad_temp_1 = tir.alloc_buffer([1, 1024, 14, 14], dtype="float32")
    # Identity "padding" stage: the kernel is 1x1, so no halo is required.
    for i0, i1, i2, i3 in tir.grid(1, 1024, 14, 14):
        with tir.block("pad_temp"):
            i0_2, i1_2, i2_2, i3_2 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
            tir.reads(rxplaceholder_2[i0_2, i1_2, i2_2, i3_2])
            tir.writes(pad_temp_1[i0_2, i1_2, i2_2, i3_2])
            pad_temp_1[i0_2, i1_2, i2_2, i3_2] = rxplaceholder_2[i0_2, i1_2, i2_2, i3_2]
    # Direct convolution: rc reduces over the 1024 input channels; ry/rx have
    # unit extent because the kernel is 1x1.
    for i0_3, i1_3, i2_3, i3_3, i4, i5, i6 in tir.grid(1, 256, 14, 14, 1024, 1, 1):
        with tir.block("conv2d_nchw"):
            nn, ff, yy, xx, rc, ry, rx = tir.axis.remap("SSSSRRR", [i0_3, i1_3, i2_3, i3_3, i4, i5, i6])
            tir.reads(conv2d_nchw_1[nn, ff, yy, xx], pad_temp_1[nn, rc, yy + ry, xx + rx], rxplaceholder_3[ff, rc, ry, rx])
            tir.writes(conv2d_nchw_1[nn, ff, yy, xx])
            tir.block_attr({"layout_free_placeholders":[rxplaceholder_3]})
            with tir.init():
                # Zero-initialise the accumulator before the reduction.
                conv2d_nchw_1[nn, ff, yy, xx] = tir.float32(0)
            conv2d_nchw_1[nn, ff, yy, xx] = conv2d_nchw_1[nn, ff, yy, xx] + pad_temp_1[nn, rc, yy + ry, xx + rx] * rxplaceholder_3[ff, rc, ry, rx]
@tir.prim_func
def conv2d_nchw16(rxplaceholder_2: tir.Buffer[(1, 512, 7, 7), "float32"], rxplaceholder_3: tir.Buffer[(512, 512, 3, 3), "float32"], conv2d_nchw_1: tir.Buffer[(1, 512, 7, 7), "float32"]) -> None:
    # 3x3 convolution, NCHW layout: 512 -> 512 channels on a 7x7 feature map,
    # stride 1, padding 1 (explicit zero-padding into a 9x9 staging buffer).
    # function attr dict
    tir.func_attr({"global_symbol": "conv2d_nchw16", "tir.noalias": True})
    # body
    # with tir.block("root")
    pad_temp_1 = tir.alloc_buffer([1, 512, 9, 9], dtype="float32")
    # Zero-pad by one pixel on every side: interior positions (1..7 in each
    # spatial axis of the 9x9 buffer) copy the input, the border is 0.
    for i0, i1, i2, i3 in tir.grid(1, 512, 9, 9):
        with tir.block("pad_temp"):
            i0_2, i1_2, i2_2, i3_2 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
            tir.reads(rxplaceholder_2[i0_2, i1_2, i2_2 - 1, i3_2 - 1])
            tir.writes(pad_temp_1[i0_2, i1_2, i2_2, i3_2])
            pad_temp_1[i0_2, i1_2, i2_2, i3_2] = tir.if_then_else(i2_2 >= 1 and i2_2 < 8 and i3_2 >= 1 and i3_2 < 8, rxplaceholder_2[i0_2, i1_2, i2_2 - 1, i3_2 - 1], tir.float32(0), dtype="float32")
    # Direct convolution: rc reduces over 512 input channels, ry/rx over the
    # 3x3 kernel window read from the padded buffer.
    for i0_3, i1_3, i2_3, i3_3, i4, i5, i6 in tir.grid(1, 512, 7, 7, 512, 3, 3):
        with tir.block("conv2d_nchw"):
            nn, ff, yy, xx, rc, ry, rx = tir.axis.remap("SSSSRRR", [i0_3, i1_3, i2_3, i3_3, i4, i5, i6])
            tir.reads(conv2d_nchw_1[nn, ff, yy, xx], pad_temp_1[nn, rc, yy + ry, xx + rx], rxplaceholder_3[ff, rc, ry, rx])
            tir.writes(conv2d_nchw_1[nn, ff, yy, xx])
            tir.block_attr({"layout_free_placeholders":[rxplaceholder_3]})
            with tir.init():
                # Zero-initialise the accumulator before the reduction.
                conv2d_nchw_1[nn, ff, yy, xx] = tir.float32(0)
            conv2d_nchw_1[nn, ff, yy, xx] = conv2d_nchw_1[nn, ff, yy, xx] + pad_temp_1[nn, rc, yy + ry, xx + rx] * rxplaceholder_3[ff, rc, ry, rx]
@tir.prim_func
def batch_norm9(rxplaceholder_5: tir.Buffer[(1, 2048, 7, 7), "float32"], rxplaceholder_6: tir.Buffer[(2048,), "float32"], rxplaceholder_7: tir.Buffer[(2048,), "float32"], rxplaceholder_8: tir.Buffer[(2048,), "float32"], rxplaceholder_9: tir.Buffer[(2048,), "float32"], T_add_2: tir.Buffer[(1, 2048, 7, 7), "float32"], T_multiply_3: tir.Buffer[(2048,), "float32"], T_multiply_4: tir.Buffer[(2048,), "float32"]) -> None:
    # Inference-mode batch norm over a (1, 2048, 7, 7) NCHW input:
    #   T_add_2 = (x - rxplaceholder_8) / sqrt(rxplaceholder_9 + eps)
    #             * rxplaceholder_6 + rxplaceholder_7
    # From the arithmetic below, rxplaceholder_8 acts as the per-channel mean,
    # rxplaceholder_9 as the variance, rxplaceholder_6 as the scale (gamma),
    # rxplaceholder_7 as the shift (beta).  T_multiply_3 / T_multiply_4 pass
    # rxplaceholder_8 / rxplaceholder_9 through unchanged (running stats out).
    # function attr dict
    tir.func_attr({"global_symbol": "batch_norm9", "tir.noalias": True})
    # body
    # with tir.block("root")
    T_reshape_4 = tir.alloc_buffer([1, 2048, 1, 1], dtype="float32")
    T_subtract_1 = tir.alloc_buffer([1, 2048, 7, 7], dtype="float32")
    T_reshape_5 = tir.alloc_buffer([1, 2048, 1, 1], dtype="float32")
    T_add_3 = tir.alloc_buffer([1, 2048, 1, 1], dtype="float32")
    compute_1 = tir.alloc_buffer([1, 2048, 1, 1], dtype="float32")
    T_divide_1 = tir.alloc_buffer([1, 2048, 7, 7], dtype="float32")
    T_reshape_6 = tir.alloc_buffer([1, 2048, 1, 1], dtype="float32")
    T_multiply_5 = tir.alloc_buffer([1, 2048, 7, 7], dtype="float32")
    T_reshape_7 = tir.alloc_buffer([1, 2048, 1, 1], dtype="float32")
    # Reshape the mean vector (2048,) to (1, 2048, 1, 1) for broadcasting.
    # NOTE(review): the reads annotation drops the ax0 * 2048 term that appears
    # in the statement; equivalent here because ax0 is always 0.
    for i0, i1, i2, i3 in tir.grid(1, 2048, 1, 1):
        with tir.block("T_reshape"):
            ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
            tir.reads(rxplaceholder_8[(ax1 + ax2 + ax3) % 2048])
            tir.writes(T_reshape_4[ax0, ax1, ax2, ax3])
            T_reshape_4[ax0, ax1, ax2, ax3] = rxplaceholder_8[(ax0 * 2048 + ax1 + ax2 + ax3) % 2048]
    # x - mean, broadcast over the 7x7 spatial dims.
    for i0, i1, i2, i3 in tir.grid(1, 2048, 7, 7):
        with tir.block("T_subtract"):
            ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
            tir.reads(rxplaceholder_5[ax0, ax1, ax2, ax3], T_reshape_4[ax0, ax1, 0, 0])
            tir.writes(T_subtract_1[ax0, ax1, ax2, ax3])
            T_subtract_1[ax0, ax1, ax2, ax3] = rxplaceholder_5[ax0, ax1, ax2, ax3] - T_reshape_4[ax0, ax1, 0, 0]
    # Reshape the variance vector (2048,) to (1, 2048, 1, 1).
    for i0, i1, i2, i3 in tir.grid(1, 2048, 1, 1):
        with tir.block("T_reshape_1"):
            ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
            tir.reads(rxplaceholder_9[(ax1 + ax2 + ax3) % 2048])
            tir.writes(T_reshape_5[ax0, ax1, ax2, ax3])
            T_reshape_5[ax0, ax1, ax2, ax3] = rxplaceholder_9[(ax0 * 2048 + ax1 + ax2 + ax3) % 2048]
    # var + eps (eps ~= 2e-5 after float32 rounding).
    for i0, i1, i2, i3 in tir.grid(1, 2048, 1, 1):
        with tir.block("T_add"):
            ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
            tir.reads(T_reshape_5[ax0, ax1, ax2, ax3])
            tir.writes(T_add_3[ax0, ax1, ax2, ax3])
            T_add_3[ax0, ax1, ax2, ax3] = T_reshape_5[ax0, ax1, ax2, ax3] + tir.float32(1.9999999494757503e-05)
    # sqrt(var + eps)
    for i0, i1, i2, i3 in tir.grid(1, 2048, 1, 1):
        with tir.block("compute"):
            i0_2, i1_2, i2_2, i3_2 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
            tir.reads(T_add_3[i0_2, i1_2, i2_2, i3_2])
            tir.writes(compute_1[i0_2, i1_2, i2_2, i3_2])
            compute_1[i0_2, i1_2, i2_2, i3_2] = tir.sqrt(T_add_3[i0_2, i1_2, i2_2, i3_2], dtype="float32")
    # (x - mean) / sqrt(var + eps)
    for i0_3, i1_3, i2_3, i3_3 in tir.grid(1, 2048, 7, 7):
        with tir.block("T_divide"):
            ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0_3, i1_3, i2_3, i3_3])
            tir.reads(T_subtract_1[ax0, ax1, ax2, ax3], compute_1[ax0, ax1, 0, 0])
            tir.writes(T_divide_1[ax0, ax1, ax2, ax3])
            T_divide_1[ax0, ax1, ax2, ax3] = T_subtract_1[ax0, ax1, ax2, ax3] / compute_1[ax0, ax1, 0, 0]
    # Reshape the scale vector (gamma) to (1, 2048, 1, 1).
    for i0_4, i1_4, i2_4, i3_4 in tir.grid(1, 2048, 1, 1):
        with tir.block("T_reshape_2"):
            ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0_4, i1_4, i2_4, i3_4])
            tir.reads(rxplaceholder_6[(ax1 + ax2 + ax3) % 2048])
            tir.writes(T_reshape_6[ax0, ax1, ax2, ax3])
            T_reshape_6[ax0, ax1, ax2, ax3] = rxplaceholder_6[(ax0 * 2048 + ax1 + ax2 + ax3) % 2048]
    # normalized * gamma
    for i0_5, i1_5, i2_5, i3_5 in tir.grid(1, 2048, 7, 7):
        with tir.block("T_multiply"):
            ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0_5, i1_5, i2_5, i3_5])
            tir.reads(T_divide_1[ax0, ax1, ax2, ax3], T_reshape_6[ax0, ax1, 0, 0])
            tir.writes(T_multiply_5[ax0, ax1, ax2, ax3])
            T_multiply_5[ax0, ax1, ax2, ax3] = T_divide_1[ax0, ax1, ax2, ax3] * T_reshape_6[ax0, ax1, 0, 0]
    # Reshape the shift vector (beta) to (1, 2048, 1, 1).
    for i0_6, i1_6, i2_6, i3_6 in tir.grid(1, 2048, 1, 1):
        with tir.block("T_reshape_3"):
            ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0_6, i1_6, i2_6, i3_6])
            tir.reads(rxplaceholder_7[(ax1 + ax2 + ax3) % 2048])
            tir.writes(T_reshape_7[ax0, ax1, ax2, ax3])
            T_reshape_7[ax0, ax1, ax2, ax3] = rxplaceholder_7[(ax0 * 2048 + ax1 + ax2 + ax3) % 2048]
    # normalized * gamma + beta -> first output.
    for i0_7, i1_7, i2_7, i3_7 in tir.grid(1, 2048, 7, 7):
        with tir.block("T_add_1"):
            ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0_7, i1_7, i2_7, i3_7])
            tir.reads(T_multiply_5[ax0, ax1, ax2, ax3], T_reshape_7[ax0, ax1, 0, 0])
            tir.writes(T_add_2[ax0, ax1, ax2, ax3])
            T_add_2[ax0, ax1, ax2, ax3] = T_multiply_5[ax0, ax1, ax2, ax3] + T_reshape_7[ax0, ax1, 0, 0]
    # Copy the mean through to the second output unchanged.
    for i0_8 in tir.serial(2048):
        with tir.block("T_multiply_1"):
            ax0 = tir.axis.spatial(2048, i0_8)
            tir.reads(rxplaceholder_8[ax0])
            tir.writes(T_multiply_3[ax0])
            T_multiply_3[ax0] = rxplaceholder_8[ax0]
    # Copy the variance through to the third output unchanged.
    for i0_9 in tir.serial(2048):
        with tir.block("T_multiply_2"):
            ax0 = tir.axis.spatial(2048, i0_9)
            tir.reads(rxplaceholder_9[ax0])
            tir.writes(T_multiply_4[ax0])
            T_multiply_4[ax0] = rxplaceholder_9[ax0]
@tir.prim_func
def batch_norm1(rxplaceholder_5: tir.Buffer[(1, 64, 112, 112), "float32"], rxplaceholder_6: tir.Buffer[(64,), "float32"], rxplaceholder_7: tir.Buffer[(64,), "float32"], rxplaceholder_8: tir.Buffer[(64,), "float32"], rxplaceholder_9: tir.Buffer[(64,), "float32"], T_add_2: tir.Buffer[(1, 64, 112, 112), "float32"], T_multiply_3: tir.Buffer[(64,), "float32"], T_multiply_4: tir.Buffer[(64,), "float32"]) -> None:
    # Inference-mode batch norm over a (1, 64, 112, 112) NCHW input:
    #   T_add_2 = (x - rxplaceholder_8) / sqrt(rxplaceholder_9 + eps)
    #             * rxplaceholder_6 + rxplaceholder_7
    # From the arithmetic below, rxplaceholder_8 acts as the per-channel mean,
    # rxplaceholder_9 as the variance, rxplaceholder_6 as the scale (gamma),
    # rxplaceholder_7 as the shift (beta).  T_multiply_3 / T_multiply_4 pass
    # rxplaceholder_8 / rxplaceholder_9 through unchanged (running stats out).
    # function attr dict
    tir.func_attr({"global_symbol": "batch_norm1", "tir.noalias": True})
    # body
    # with tir.block("root")
    T_reshape_4 = tir.alloc_buffer([1, 64, 1, 1], dtype="float32")
    T_subtract_1 = tir.alloc_buffer([1, 64, 112, 112], dtype="float32")
    T_reshape_5 = tir.alloc_buffer([1, 64, 1, 1], dtype="float32")
    T_add_3 = tir.alloc_buffer([1, 64, 1, 1], dtype="float32")
    compute_1 = tir.alloc_buffer([1, 64, 1, 1], dtype="float32")
    T_divide_1 = tir.alloc_buffer([1, 64, 112, 112], dtype="float32")
    T_reshape_6 = tir.alloc_buffer([1, 64, 1, 1], dtype="float32")
    T_multiply_5 = tir.alloc_buffer([1, 64, 112, 112], dtype="float32")
    T_reshape_7 = tir.alloc_buffer([1, 64, 1, 1], dtype="float32")
    # Reshape the mean vector (64,) to (1, 64, 1, 1) for broadcasting.
    # NOTE(review): the reads annotation drops the ax0 * 64 term that appears
    # in the statement; equivalent here because ax0 is always 0.
    for i0, i1, i2, i3 in tir.grid(1, 64, 1, 1):
        with tir.block("T_reshape"):
            ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
            tir.reads(rxplaceholder_8[(ax1 + ax2 + ax3) % 64])
            tir.writes(T_reshape_4[ax0, ax1, ax2, ax3])
            T_reshape_4[ax0, ax1, ax2, ax3] = rxplaceholder_8[(ax0 * 64 + ax1 + ax2 + ax3) % 64]
    # x - mean, broadcast over the 112x112 spatial dims.
    for i0, i1, i2, i3 in tir.grid(1, 64, 112, 112):
        with tir.block("T_subtract"):
            ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
            tir.reads(rxplaceholder_5[ax0, ax1, ax2, ax3], T_reshape_4[ax0, ax1, 0, 0])
            tir.writes(T_subtract_1[ax0, ax1, ax2, ax3])
            T_subtract_1[ax0, ax1, ax2, ax3] = rxplaceholder_5[ax0, ax1, ax2, ax3] - T_reshape_4[ax0, ax1, 0, 0]
    # Reshape the variance vector (64,) to (1, 64, 1, 1).
    for i0, i1, i2, i3 in tir.grid(1, 64, 1, 1):
        with tir.block("T_reshape_1"):
            ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
            tir.reads(rxplaceholder_9[(ax1 + ax2 + ax3) % 64])
            tir.writes(T_reshape_5[ax0, ax1, ax2, ax3])
            T_reshape_5[ax0, ax1, ax2, ax3] = rxplaceholder_9[(ax0 * 64 + ax1 + ax2 + ax3) % 64]
    # var + eps (eps ~= 2e-5 after float32 rounding).
    for i0, i1, i2, i3 in tir.grid(1, 64, 1, 1):
        with tir.block("T_add"):
            ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
            tir.reads(T_reshape_5[ax0, ax1, ax2, ax3])
            tir.writes(T_add_3[ax0, ax1, ax2, ax3])
            T_add_3[ax0, ax1, ax2, ax3] = T_reshape_5[ax0, ax1, ax2, ax3] + tir.float32(1.9999999494757503e-05)
    # sqrt(var + eps)
    for i0, i1, i2, i3 in tir.grid(1, 64, 1, 1):
        with tir.block("compute"):
            i0_2, i1_2, i2_2, i3_2 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
            tir.reads(T_add_3[i0_2, i1_2, i2_2, i3_2])
            tir.writes(compute_1[i0_2, i1_2, i2_2, i3_2])
            compute_1[i0_2, i1_2, i2_2, i3_2] = tir.sqrt(T_add_3[i0_2, i1_2, i2_2, i3_2], dtype="float32")
    # (x - mean) / sqrt(var + eps)
    for i0_3, i1_3, i2_3, i3_3 in tir.grid(1, 64, 112, 112):
        with tir.block("T_divide"):
            ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0_3, i1_3, i2_3, i3_3])
            tir.reads(T_subtract_1[ax0, ax1, ax2, ax3], compute_1[ax0, ax1, 0, 0])
            tir.writes(T_divide_1[ax0, ax1, ax2, ax3])
            T_divide_1[ax0, ax1, ax2, ax3] = T_subtract_1[ax0, ax1, ax2, ax3] / compute_1[ax0, ax1, 0, 0]
    # Reshape the scale vector (gamma) to (1, 64, 1, 1).
    for i0_4, i1_4, i2_4, i3_4 in tir.grid(1, 64, 1, 1):
        with tir.block("T_reshape_2"):
            ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0_4, i1_4, i2_4, i3_4])
            tir.reads(rxplaceholder_6[(ax1 + ax2 + ax3) % 64])
            tir.writes(T_reshape_6[ax0, ax1, ax2, ax3])
            T_reshape_6[ax0, ax1, ax2, ax3] = rxplaceholder_6[(ax0 * 64 + ax1 + ax2 + ax3) % 64]
    # normalized * gamma
    for i0_5, i1_5, i2_5, i3_5 in tir.grid(1, 64, 112, 112):
        with tir.block("T_multiply"):
            ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0_5, i1_5, i2_5, i3_5])
            tir.reads(T_divide_1[ax0, ax1, ax2, ax3], T_reshape_6[ax0, ax1, 0, 0])
            tir.writes(T_multiply_5[ax0, ax1, ax2, ax3])
            T_multiply_5[ax0, ax1, ax2, ax3] = T_divide_1[ax0, ax1, ax2, ax3] * T_reshape_6[ax0, ax1, 0, 0]
    # Reshape the shift vector (beta) to (1, 64, 1, 1).
    for i0_6, i1_6, i2_6, i3_6 in tir.grid(1, 64, 1, 1):
        with tir.block("T_reshape_3"):
            ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0_6, i1_6, i2_6, i3_6])
            tir.reads(rxplaceholder_7[(ax1 + ax2 + ax3) % 64])
            tir.writes(T_reshape_7[ax0, ax1, ax2, ax3])
            T_reshape_7[ax0, ax1, ax2, ax3] = rxplaceholder_7[(ax0 * 64 + ax1 + ax2 + ax3) % 64]
    # normalized * gamma + beta -> first output.
    for i0_7, i1_7, i2_7, i3_7 in tir.grid(1, 64, 112, 112):
        with tir.block("T_add_1"):
            ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0_7, i1_7, i2_7, i3_7])
            tir.reads(T_multiply_5[ax0, ax1, ax2, ax3], T_reshape_7[ax0, ax1, 0, 0])
            tir.writes(T_add_2[ax0, ax1, ax2, ax3])
            T_add_2[ax0, ax1, ax2, ax3] = T_multiply_5[ax0, ax1, ax2, ax3] + T_reshape_7[ax0, ax1, 0, 0]
    # Copy the mean through to the second output unchanged.
    for i0_8 in tir.serial(64):
        with tir.block("T_multiply_1"):
            ax0 = tir.axis.spatial(64, i0_8)
            tir.reads(rxplaceholder_8[ax0])
            tir.writes(T_multiply_3[ax0])
            T_multiply_3[ax0] = rxplaceholder_8[ax0]
    # Copy the variance through to the third output unchanged.
    for i0_9 in tir.serial(64):
        with tir.block("T_multiply_2"):
            ax0 = tir.axis.spatial(64, i0_9)
            tir.reads(rxplaceholder_9[ax0])
            tir.writes(T_multiply_4[ax0])
            T_multiply_4[ax0] = rxplaceholder_9[ax0]
@relax.function
def main(data: Tensor[(1, 3, 224, 224), "float32"], bn_data_gamma: Tensor[(3,), "float32"], bn_data_beta: Tensor[(3,), "float32"], bn_data_moving_mean: Tensor[(3,), "float32"], bn_data_moving_var: Tensor[(3,), "float32"], conv0_weight: Tensor[(64, 3, 7, 7), "float32"], bn0_gamma: Tensor[(64,), "float32"], bn0_beta: Tensor[(64,), "float32"], bn0_moving_mean: Tensor[(64,), "float32"], bn0_moving_var: Tensor[(64,), "float32"], stage1_unit1_bn1_gamma: Tensor[(64,), "float32"], stage1_unit1_bn1_beta: Tensor[(64,), "float32"], stage1_unit1_bn1_moving_mean: Tensor[(64,), "float32"], stage1_unit1_bn1_moving_var: Tensor[(64,), "float32"], stage1_unit1_conv1_weight: Tensor[(64, 64, 1, 1), "float32"], stage1_unit1_bn2_gamma: Tensor[(64,), "float32"], stage1_unit1_bn2_beta: Tensor[(64,), "float32"], stage1_unit1_bn2_moving_mean: Tensor[(64,), "float32"], stage1_unit1_bn2_moving_var: Tensor[(64,), "float32"], stage1_unit1_conv2_weight: Tensor[(64, 64, 3, 3), "float32"], stage1_unit1_bn3_gamma: Tensor[(64,), "float32"], stage1_unit1_bn3_beta: Tensor[(64,), "float32"], stage1_unit1_bn3_moving_mean: Tensor[(64,), "float32"], stage1_unit1_bn3_moving_var: Tensor[(64,), "float32"], stage1_unit1_conv3_weight: Tensor[(256, 64, 1, 1), "float32"], stage1_unit1_sc_weight: Tensor[(256, 64, 1, 1), "float32"], stage1_unit2_bn1_gamma: Tensor[(256,), "float32"], stage1_unit2_bn1_beta: Tensor[(256,), "float32"], stage1_unit2_bn1_moving_mean: Tensor[(256,), "float32"], stage1_unit2_bn1_moving_var: Tensor[(256,), "float32"], stage1_unit2_conv1_weight: Tensor[(64, 256, 1, 1), "float32"], stage1_unit2_bn2_gamma: Tensor[(64,), "float32"], stage1_unit2_bn2_beta: Tensor[(64,), "float32"], stage1_unit2_bn2_moving_mean: Tensor[(64,), "float32"], stage1_unit2_bn2_moving_var: Tensor[(64,), "float32"], stage1_unit2_conv2_weight: Tensor[(64, 64, 3, 3), "float32"], stage1_unit2_bn3_gamma: Tensor[(64,), "float32"], stage1_unit2_bn3_beta: Tensor[(64,), "float32"], stage1_unit2_bn3_moving_mean: Tensor[(64,), 
"float32"], stage1_unit2_bn3_moving_var: Tensor[(64,), "float32"], stage1_unit2_conv3_weight: Tensor[(256, 64, 1, 1), "float32"], stage1_unit3_bn1_gamma: Tensor[(256,), "float32"], stage1_unit3_bn1_beta: Tensor[(256,), "float32"], stage1_unit3_bn1_moving_mean: Tensor[(256,), "float32"], stage1_unit3_bn1_moving_var: Tensor[(256,), "float32"], stage1_unit3_conv1_weight: Tensor[(64, 256, 1, 1), "float32"], stage1_unit3_bn2_gamma: Tensor[(64,), "float32"], stage1_unit3_bn2_beta: Tensor[(64,), "float32"], stage1_unit3_bn2_moving_mean: Tensor[(64,), "float32"], stage1_unit3_bn2_moving_var: Tensor[(64,), "float32"], stage1_unit3_conv2_weight: Tensor[(64, 64, 3, 3), "float32"], stage1_unit3_bn3_gamma: Tensor[(64,), "float32"], stage1_unit3_bn3_beta: Tensor[(64,), "float32"], stage1_unit3_bn3_moving_mean: Tensor[(64,), "float32"], stage1_unit3_bn3_moving_var: Tensor[(64,), "float32"], stage1_unit3_conv3_weight: Tensor[(256, 64, 1, 1), "float32"], stage2_unit1_bn1_gamma: Tensor[(256,), "float32"], stage2_unit1_bn1_beta: Tensor[(256,), "float32"], stage2_unit1_bn1_moving_mean: Tensor[(256,), "float32"], stage2_unit1_bn1_moving_var: Tensor[(256,), "float32"], stage2_unit1_conv1_weight: Tensor[(128, 256, 1, 1), "float32"], stage2_unit1_bn2_gamma: Tensor[(128,), "float32"], stage2_unit1_bn2_beta: Tensor[(128,), "float32"], stage2_unit1_bn2_moving_mean: Tensor[(128,), "float32"], stage2_unit1_bn2_moving_var: Tensor[(128,), "float32"], stage2_unit1_conv2_weight: Tensor[(128, 128, 3, 3), "float32"], stage2_unit1_bn3_gamma: Tensor[(128,), "float32"], stage2_unit1_bn3_beta: Tensor[(128,), "float32"], stage2_unit1_bn3_moving_mean: Tensor[(128,), "float32"], stage2_unit1_bn3_moving_var: Tensor[(128,), "float32"], stage2_unit1_conv3_weight: Tensor[(512, 128, 1, 1), "float32"], stage2_unit1_sc_weight: Tensor[(512, 256, 1, 1), "float32"], stage2_unit2_bn1_gamma: Tensor[(512,), "float32"], stage2_unit2_bn1_beta: Tensor[(512,), "float32"], stage2_unit2_bn1_moving_mean: Tensor[(512,), 
"float32"], stage2_unit2_bn1_moving_var: Tensor[(512,), "float32"], stage2_unit2_conv1_weight: Tensor[(128, 512, 1, 1), "float32"], stage2_unit2_bn2_gamma: Tensor[(128,), "float32"], stage2_unit2_bn2_beta: Tensor[(128,), "float32"], stage2_unit2_bn2_moving_mean: Tensor[(128,), "float32"], stage2_unit2_bn2_moving_var: Tensor[(128,), "float32"], stage2_unit2_conv2_weight: Tensor[(128, 128, 3, 3), "float32"], stage2_unit2_bn3_gamma: Tensor[(128,), "float32"], stage2_unit2_bn3_beta: Tensor[(128,), "float32"], stage2_unit2_bn3_moving_mean: Tensor[(128,), "float32"], stage2_unit2_bn3_moving_var: Tensor[(128,), "float32"], stage2_unit2_conv3_weight: Tensor[(512, 128, 1, 1), "float32"], stage2_unit3_bn1_gamma: Tensor[(512,), "float32"], stage2_unit3_bn1_beta: Tensor[(512,), "float32"], stage2_unit3_bn1_moving_mean: Tensor[(512,), "float32"], stage2_unit3_bn1_moving_var: Tensor[(512,), "float32"], stage2_unit3_conv1_weight: Tensor[(128, 512, 1, 1), "float32"], stage2_unit3_bn2_gamma: Tensor[(128,), "float32"], stage2_unit3_bn2_beta: Tensor[(128,), "float32"], stage2_unit3_bn2_moving_mean: Tensor[(128,), "float32"], stage2_unit3_bn2_moving_var: Tensor[(128,), "float32"], stage2_unit3_conv2_weight: Tensor[(128, 128, 3, 3), "float32"], stage2_unit3_bn3_gamma: Tensor[(128,), "float32"], stage2_unit3_bn3_beta: Tensor[(128,), "float32"], stage2_unit3_bn3_moving_mean: Tensor[(128,), "float32"], stage2_unit3_bn3_moving_var: Tensor[(128,), "float32"], stage2_unit3_conv3_weight: Tensor[(512, 128, 1, 1), "float32"], stage2_unit4_bn1_gamma: Tensor[(512,), "float32"], stage2_unit4_bn1_beta: Tensor[(512,), "float32"], stage2_unit4_bn1_moving_mean: Tensor[(512,), "float32"], stage2_unit4_bn1_moving_var: Tensor[(512,), "float32"], stage2_unit4_conv1_weight: Tensor[(128, 512, 1, 1), "float32"], stage2_unit4_bn2_gamma: Tensor[(128,), "float32"], stage2_unit4_bn2_beta: Tensor[(128,), "float32"], stage2_unit4_bn2_moving_mean: Tensor[(128,), "float32"], stage2_unit4_bn2_moving_var: 
Tensor[(128,), "float32"], stage2_unit4_conv2_weight: Tensor[(128, 128, 3, 3), "float32"], stage2_unit4_bn3_gamma: Tensor[(128,), "float32"], stage2_unit4_bn3_beta: Tensor[(128,), "float32"], stage2_unit4_bn3_moving_mean: Tensor[(128,), "float32"], stage2_unit4_bn3_moving_var: Tensor[(128,), "float32"], stage2_unit4_conv3_weight: Tensor[(512, 128, 1, 1), "float32"], stage3_unit1_bn1_gamma: Tensor[(512,), "float32"], stage3_unit1_bn1_beta: Tensor[(512,), "float32"], stage3_unit1_bn1_moving_mean: Tensor[(512,), "float32"], stage3_unit1_bn1_moving_var: Tensor[(512,), "float32"], stage3_unit1_conv1_weight: Tensor[(256, 512, 1, 1), "float32"], stage3_unit1_bn2_gamma: Tensor[(256,), "float32"], stage3_unit1_bn2_beta: Tensor[(256,), "float32"], stage3_unit1_bn2_moving_mean: Tensor[(256,), "float32"], stage3_unit1_bn2_moving_var: Tensor[(256,), "float32"], stage3_unit1_conv2_weight: Tensor[(256, 256, 3, 3), "float32"], stage3_unit1_bn3_gamma: Tensor[(256,), "float32"], stage3_unit1_bn3_beta: Tensor[(256,), "float32"], stage3_unit1_bn3_moving_mean: Tensor[(256,), "float32"], stage3_unit1_bn3_moving_var: Tensor[(256,), "float32"], stage3_unit1_conv3_weight: Tensor[(1024, 256, 1, 1), "float32"], stage3_unit1_sc_weight: Tensor[(1024, 512, 1, 1), "float32"], stage3_unit2_bn1_gamma: Tensor[(1024,), "float32"], stage3_unit2_bn1_beta: Tensor[(1024,), "float32"], stage3_unit2_bn1_moving_mean: Tensor[(1024,), "float32"], stage3_unit2_bn1_moving_var: Tensor[(1024,), "float32"], stage3_unit2_conv1_weight: Tensor[(256, 1024, 1, 1), "float32"], stage3_unit2_bn2_gamma: Tensor[(256,), "float32"], stage3_unit2_bn2_beta: Tensor[(256,), "float32"], stage3_unit2_bn2_moving_mean: Tensor[(256,), "float32"], stage3_unit2_bn2_moving_var: Tensor[(256,), "float32"], stage3_unit2_conv2_weight: Tensor[(256, 256, 3, 3), "float32"], stage3_unit2_bn3_gamma: Tensor[(256,), "float32"], stage3_unit2_bn3_beta: Tensor[(256,), "float32"], stage3_unit2_bn3_moving_mean: Tensor[(256,), "float32"], 
stage3_unit2_bn3_moving_var: Tensor[(256,), "float32"], stage3_unit2_conv3_weight: Tensor[(1024, 256, 1, 1), "float32"], stage3_unit3_bn1_gamma: Tensor[(1024,), "float32"], stage3_unit3_bn1_beta: Tensor[(1024,), "float32"], stage3_unit3_bn1_moving_mean: Tensor[(1024,), "float32"], stage3_unit3_bn1_moving_var: Tensor[(1024,), "float32"], stage3_unit3_conv1_weight: Tensor[(256, 1024, 1, 1), "float32"], stage3_unit3_bn2_gamma: Tensor[(256,), "float32"], stage3_unit3_bn2_beta: Tensor[(256,), "float32"], stage3_unit3_bn2_moving_mean: Tensor[(256,), "float32"], stage3_unit3_bn2_moving_var: Tensor[(256,), "float32"], stage3_unit3_conv2_weight: Tensor[(256, 256, 3, 3), "float32"], stage3_unit3_bn3_gamma: Tensor[(256,), "float32"], stage3_unit3_bn3_beta: Tensor[(256,), "float32"], stage3_unit3_bn3_moving_mean: Tensor[(256,), "float32"], stage3_unit3_bn3_moving_var: Tensor[(256,), "float32"], stage3_unit3_conv3_weight: Tensor[(1024, 256, 1, 1), "float32"], stage3_unit4_bn1_gamma: Tensor[(1024,), "float32"], stage3_unit4_bn1_beta: Tensor[(1024,), "float32"], stage3_unit4_bn1_moving_mean: Tensor[(1024,), "float32"], stage3_unit4_bn1_moving_var: Tensor[(1024,), "float32"], stage3_unit4_conv1_weight: Tensor[(256, 1024, 1, 1), "float32"], stage3_unit4_bn2_gamma: Tensor[(256,), "float32"], stage3_unit4_bn2_beta: Tensor[(256,), "float32"], stage3_unit4_bn2_moving_mean: Tensor[(256,), "float32"], stage3_unit4_bn2_moving_var: Tensor[(256,), "float32"], stage3_unit4_conv2_weight: Tensor[(256, 256, 3, 3), "float32"], stage3_unit4_bn3_gamma: Tensor[(256,), "float32"], stage3_unit4_bn3_beta: Tensor[(256,), "float32"], stage3_unit4_bn3_moving_mean: Tensor[(256,), "float32"], stage3_unit4_bn3_moving_var: Tensor[(256,), "float32"], stage3_unit4_conv3_weight: Tensor[(1024, 256, 1, 1), "float32"], stage3_unit5_bn1_gamma: Tensor[(1024,), "float32"], stage3_unit5_bn1_beta: Tensor[(1024,), "float32"], stage3_unit5_bn1_moving_mean: Tensor[(1024,), "float32"], stage3_unit5_bn1_moving_var: 
Tensor[(1024,), "float32"], stage3_unit5_conv1_weight: Tensor[(256, 1024, 1, 1), "float32"], stage3_unit5_bn2_gamma: Tensor[(256,), "float32"], stage3_unit5_bn2_beta: Tensor[(256,), "float32"], stage3_unit5_bn2_moving_mean: Tensor[(256,), "float32"], stage3_unit5_bn2_moving_var: Tensor[(256,), "float32"], stage3_unit5_conv2_weight: Tensor[(256, 256, 3, 3), "float32"], stage3_unit5_bn3_gamma: Tensor[(256,), "float32"], stage3_unit5_bn3_beta: Tensor[(256,), "float32"], stage3_unit5_bn3_moving_mean: Tensor[(256,), "float32"], stage3_unit5_bn3_moving_var: Tensor[(256,), "float32"], stage3_unit5_conv3_weight: Tensor[(1024, 256, 1, 1), "float32"], stage3_unit6_bn1_gamma: Tensor[(1024,), "float32"], stage3_unit6_bn1_beta: Tensor[(1024,), "float32"], stage3_unit6_bn1_moving_mean: Tensor[(1024,), "float32"], stage3_unit6_bn1_moving_var: Tensor[(1024,), "float32"], stage3_unit6_conv1_weight: Tensor[(256, 1024, 1, 1), "float32"], stage3_unit6_bn2_gamma: Tensor[(256,), "float32"], stage3_unit6_bn2_beta: Tensor[(256,), "float32"], stage3_unit6_bn2_moving_mean: Tensor[(256,), "float32"], stage3_unit6_bn2_moving_var: Tensor[(256,), "float32"], stage3_unit6_conv2_weight: Tensor[(256, 256, 3, 3), "float32"], stage3_unit6_bn3_gamma: Tensor[(256,), "float32"], stage3_unit6_bn3_beta: Tensor[(256,), "float32"], stage3_unit6_bn3_moving_mean: Tensor[(256,), "float32"], stage3_unit6_bn3_moving_var: Tensor[(256,), "float32"], stage3_unit6_conv3_weight: Tensor[(1024, 256, 1, 1), "float32"], stage4_unit1_bn1_gamma: Tensor[(1024,), "float32"], stage4_unit1_bn1_beta: Tensor[(1024,), "float32"], stage4_unit1_bn1_moving_mean: Tensor[(1024,), "float32"], stage4_unit1_bn1_moving_var: Tensor[(1024,), "float32"], stage4_unit1_conv1_weight: Tensor[(512, 1024, 1, 1), "float32"], stage4_unit1_bn2_gamma: Tensor[(512,), "float32"], stage4_unit1_bn2_beta: Tensor[(512,), "float32"], stage4_unit1_bn2_moving_mean: Tensor[(512,), "float32"], stage4_unit1_bn2_moving_var: Tensor[(512,), "float32"], 
stage4_unit1_conv2_weight: Tensor[(512, 512, 3, 3), "float32"], stage4_unit1_bn3_gamma: Tensor[(512,), "float32"], stage4_unit1_bn3_beta: Tensor[(512,), "float32"], stage4_unit1_bn3_moving_mean: Tensor[(512,), "float32"], stage4_unit1_bn3_moving_var: Tensor[(512,), "float32"], stage4_unit1_conv3_weight: Tensor[(2048, 512, 1, 1), "float32"], stage4_unit1_sc_weight: Tensor[(2048, 1024, 1, 1), "float32"], stage4_unit2_bn1_gamma: Tensor[(2048,), "float32"], stage4_unit2_bn1_beta: Tensor[(2048,), "float32"], stage4_unit2_bn1_moving_mean: Tensor[(2048,), "float32"], stage4_unit2_bn1_moving_var: Tensor[(2048,), "float32"], stage4_unit2_conv1_weight: Tensor[(512, 2048, 1, 1), "float32"], stage4_unit2_bn2_gamma: Tensor[(512,), "float32"], stage4_unit2_bn2_beta: Tensor[(512,), "float32"], stage4_unit2_bn2_moving_mean: Tensor[(512,), "float32"], stage4_unit2_bn2_moving_var: Tensor[(512,), "float32"], stage4_unit2_conv2_weight: Tensor[(512, 512, 3, 3), "float32"], stage4_unit2_bn3_gamma: Tensor[(512,), "float32"], stage4_unit2_bn3_beta: Tensor[(512,), "float32"], stage4_unit2_bn3_moving_mean: Tensor[(512,), "float32"], stage4_unit2_bn3_moving_var: Tensor[(512,), "float32"], stage4_unit2_conv3_weight: Tensor[(2048, 512, 1, 1), "float32"], stage4_unit3_bn1_gamma: Tensor[(2048,), "float32"], stage4_unit3_bn1_beta: Tensor[(2048,), "float32"], stage4_unit3_bn1_moving_mean: Tensor[(2048,), "float32"], stage4_unit3_bn1_moving_var: Tensor[(2048,), "float32"], stage4_unit3_conv1_weight: Tensor[(512, 2048, 1, 1), "float32"], stage4_unit3_bn2_gamma: Tensor[(512,), "float32"], stage4_unit3_bn2_beta: Tensor[(512,), "float32"], stage4_unit3_bn2_moving_mean: Tensor[(512,), "float32"], stage4_unit3_bn2_moving_var: Tensor[(512,), "float32"], stage4_unit3_conv2_weight: Tensor[(512, 512, 3, 3), "float32"], stage4_unit3_bn3_gamma: Tensor[(512,), "float32"], stage4_unit3_bn3_beta: Tensor[(512,), "float32"], stage4_unit3_bn3_moving_mean: Tensor[(512,), "float32"], stage4_unit3_bn3_moving_var: 
Tensor[(512,), "float32"], stage4_unit3_conv3_weight: Tensor[(2048, 512, 1, 1), "float32"], bn1_gamma: Tensor[(2048,), "float32"], bn1_beta: Tensor[(2048,), "float32"], bn1_moving_mean: Tensor[(2048,), "float32"], bn1_moving_var: Tensor[(2048,), "float32"], fc1_weight: Tensor[(1000, 2048), "float32"], fc1_bias: Tensor[(1000,), "float32"]) -> Tensor[_, "float32"]:
# block 0
with relax.dataflow():
gv: Tuple[Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"]] = relax.call_tir(((1, 3, 224, 224), (3,), (3,)), batch_norm, (data, bn_data_gamma, bn_data_beta, bn_data_moving_mean, bn_data_moving_var))
gv1: Tensor[(1, 3, 224, 224), "float32"] = gv[0]
gv2: Tensor[(1, 64, 112, 112), "float32"] = relax.call_tir((1, 64, 112, 112), conv2d_nchw, (gv1, conv0_weight))
gv3: Tuple[Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"]] = relax.call_tir(((1, 64, 112, 112), (64,), (64,)), batch_norm1, (gv2, bn0_gamma, bn0_beta, bn0_moving_mean, bn0_moving_var))
gv4: Tensor[(1, 64, 112, 112), "float32"] = gv3[0]
gv5: Tensor[(1, 64, 112, 112), "float32"] = relax.call_tir((1, 64, 112, 112), relu, (gv4,))
gv6: Tensor[(1, 64, 56, 56), "float32"] = relax.call_tir((1, 64, 56, 56), max_pool2d, (gv5,))
gv7: Tuple[Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"]] = relax.call_tir(((1, 64, 56, 56), (64,), (64,)), batch_norm2, (gv6, stage1_unit1_bn1_gamma, stage1_unit1_bn1_beta, stage1_unit1_bn1_moving_mean, stage1_unit1_bn1_moving_var))
gv8: Tensor[(1, 64, 56, 56), "float32"] = gv7[0]
gv9: Tensor[(1, 64, 56, 56), "float32"] = relax.call_tir((1, 64, 56, 56), relu1, (gv8,))
gv10: Tensor[(1, 64, 56, 56), "float32"] = relax.call_tir((1, 64, 56, 56), conv2d_nchw1, (gv9, stage1_unit1_conv1_weight))
gv11: Tuple[Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"]] = relax.call_tir(((1, 64, 56, 56), (64,), (64,)), batch_norm2, (gv10, stage1_unit1_bn2_gamma, stage1_unit1_bn2_beta, stage1_unit1_bn2_moving_mean, stage1_unit1_bn2_moving_var))
gv12: Tensor[(1, 64, 56, 56), "float32"] = gv11[0]
gv13: Tensor[(1, 64, 56, 56), "float32"] = relax.call_tir((1, 64, 56, 56), relu1, (gv12,))
gv14: Tensor[(1, 64, 56, 56), "float32"] = relax.call_tir((1, 64, 56, 56), conv2d_nchw2, (gv13, stage1_unit1_conv2_weight))
gv15: Tuple[Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"]] = relax.call_tir(((1, 64, 56, 56), (64,), (64,)), batch_norm2, (gv14, stage1_unit1_bn3_gamma, stage1_unit1_bn3_beta, stage1_unit1_bn3_moving_mean, stage1_unit1_bn3_moving_var))
gv16: Tensor[(1, 64, 56, 56), "float32"] = gv15[0]
gv17: Tensor[(1, 64, 56, 56), "float32"] = relax.call_tir((1, 64, 56, 56), relu1, (gv16,))
gv18: Tensor[(1, 256, 56, 56), "float32"] = relax.call_tir((1, 256, 56, 56), conv2d_nchw3, (gv17, stage1_unit1_conv3_weight))
gv19: Tensor[(1, 256, 56, 56), "float32"] = relax.call_tir((1, 256, 56, 56), conv2d_nchw3, (gv9, stage1_unit1_sc_weight))
gv20: Tensor[(1, 256, 56, 56), "float32"] = relax.call_tir((1, 256, 56, 56), add, (gv18, gv19))
gv21: Tuple[Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"]] = relax.call_tir(((1, 256, 56, 56), (256,), (256,)), batch_norm3, (gv20, stage1_unit2_bn1_gamma, stage1_unit2_bn1_beta, stage1_unit2_bn1_moving_mean, stage1_unit2_bn1_moving_var))
gv22: Tensor[(1, 256, 56, 56), "float32"] = gv21[0]
gv23: Tensor[(1, 256, 56, 56), "float32"] = relax.call_tir((1, 256, 56, 56), relu2, (gv22,))
gv24: Tensor[(1, 64, 56, 56), "float32"] = relax.call_tir((1, 64, 56, 56), conv2d_nchw4, (gv23, stage1_unit2_conv1_weight))
gv25: Tuple[Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"]] = relax.call_tir(((1, 64, 56, 56), (64,), (64,)), batch_norm2, (gv24, stage1_unit2_bn2_gamma, stage1_unit2_bn2_beta, stage1_unit2_bn2_moving_mean, stage1_unit2_bn2_moving_var))
gv26: Tensor[(1, 64, 56, 56), "float32"] = gv25[0]
gv27: Tensor[(1, 64, 56, 56), "float32"] = relax.call_tir((1, 64, 56, 56), relu1, (gv26,))
gv28: Tensor[(1, 64, 56, 56), "float32"] = relax.call_tir((1, 64, 56, 56), conv2d_nchw2, (gv27, stage1_unit2_conv2_weight))
gv29: Tuple[Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"]] = relax.call_tir(((1, 64, 56, 56), (64,), (64,)), batch_norm2, (gv28, stage1_unit2_bn3_gamma, stage1_unit2_bn3_beta, stage1_unit2_bn3_moving_mean, stage1_unit2_bn3_moving_var))
gv30: Tensor[(1, 64, 56, 56), "float32"] = gv29[0]
gv31: Tensor[(1, 64, 56, 56), "float32"] = relax.call_tir((1, 64, 56, 56), relu1, (gv30,))
gv32: Tensor[(1, 256, 56, 56), "float32"] = relax.call_tir((1, 256, 56, 56), conv2d_nchw3, (gv31, stage1_unit2_conv3_weight))
gv33: Tensor[(1, 256, 56, 56), "float32"] = relax.call_tir((1, 256, 56, 56), add, (gv32, gv20))
gv34: Tuple[Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"]] = relax.call_tir(((1, 256, 56, 56), (256,), (256,)), batch_norm3, (gv33, stage1_unit3_bn1_gamma, stage1_unit3_bn1_beta, stage1_unit3_bn1_moving_mean, stage1_unit3_bn1_moving_var))
gv35: Tensor[(1, 256, 56, 56), "float32"] = gv34[0]
gv36: Tensor[(1, 256, 56, 56), "float32"] = relax.call_tir((1, 256, 56, 56), relu2, (gv35,))
gv37: Tensor[(1, 64, 56, 56), "float32"] = relax.call_tir((1, 64, 56, 56), conv2d_nchw4, (gv36, stage1_unit3_conv1_weight))
gv38: Tuple[Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"]] = relax.call_tir(((1, 64, 56, 56), (64,), (64,)), batch_norm2, (gv37, stage1_unit3_bn2_gamma, stage1_unit3_bn2_beta, stage1_unit3_bn2_moving_mean, stage1_unit3_bn2_moving_var))
gv39: Tensor[(1, 64, 56, 56), "float32"] = gv38[0]
gv40: Tensor[(1, 64, 56, 56), "float32"] = relax.call_tir((1, 64, 56, 56), relu1, (gv39,))
gv41: Tensor[(1, 64, 56, 56), "float32"] = relax.call_tir((1, 64, 56, 56), conv2d_nchw2, (gv40, stage1_unit3_conv2_weight))
gv42: Tuple[Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"]] = relax.call_tir(((1, 64, 56, 56), (64,), (64,)), batch_norm2, (gv41, stage1_unit3_bn3_gamma, stage1_unit3_bn3_beta, stage1_unit3_bn3_moving_mean, stage1_unit3_bn3_moving_var))
gv43: Tensor[(1, 64, 56, 56), "float32"] = gv42[0]
gv44: Tensor[(1, 64, 56, 56), "float32"] = relax.call_tir((1, 64, 56, 56), relu1, (gv43,))
gv45: Tensor[(1, 256, 56, 56), "float32"] = relax.call_tir((1, 256, 56, 56), conv2d_nchw3, (gv44, stage1_unit3_conv3_weight))
gv46: Tensor[(1, 256, 56, 56), "float32"] = relax.call_tir((1, 256, 56, 56), add, (gv45, gv33))
gv47: Tuple[Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"]] = relax.call_tir(((1, 256, 56, 56), (256,), (256,)), batch_norm3, (gv46, stage2_unit1_bn1_gamma, stage2_unit1_bn1_beta, stage2_unit1_bn1_moving_mean, stage2_unit1_bn1_moving_var))
gv48: Tensor[(1, 256, 56, 56), "float32"] = gv47[0]
gv49: Tensor[(1, 256, 56, 56), "float32"] = relax.call_tir((1, 256, 56, 56), relu2, (gv48,))
gv50: Tensor[(1, 128, 28, 28), "float32"] = relax.call_tir((1, 128, 28, 28), conv2d_nchw5, (gv49, stage2_unit1_conv1_weight))
gv51: Tuple[Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"]] = relax.call_tir(((1, 128, 28, 28), (128,), (128,)), batch_norm4, (gv50, stage2_unit1_bn2_gamma, stage2_unit1_bn2_beta, stage2_unit1_bn2_moving_mean, stage2_unit1_bn2_moving_var))
gv52: Tensor[(1, 128, 28, 28), "float32"] = gv51[0]
gv53: Tensor[(1, 128, 28, 28), "float32"] = relax.call_tir((1, 128, 28, 28), relu3, (gv52,))
gv54: Tensor[(1, 128, 28, 28), "float32"] = relax.call_tir((1, 128, 28, 28), conv2d_nchw6, (gv53, stage2_unit1_conv2_weight))
gv55: Tuple[Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"]] = relax.call_tir(((1, 128, 28, 28), (128,), (128,)), batch_norm4, (gv54, stage2_unit1_bn3_gamma, stage2_unit1_bn3_beta, stage2_unit1_bn3_moving_mean, stage2_unit1_bn3_moving_var))
gv56: Tensor[(1, 128, 28, 28), "float32"] = gv55[0]
gv57: Tensor[(1, 128, 28, 28), "float32"] = relax.call_tir((1, 128, 28, 28), relu3, (gv56,))
gv58: Tensor[(1, 512, 28, 28), "float32"] = relax.call_tir((1, 512, 28, 28), conv2d_nchw7, (gv57, stage2_unit1_conv3_weight))
gv59: Tensor[(1, 512, 28, 28), "float32"] = relax.call_tir((1, 512, 28, 28), conv2d_nchw8, (gv49, stage2_unit1_sc_weight))
gv60: Tensor[(1, 512, 28, 28), "float32"] = relax.call_tir((1, 512, 28, 28), add1, (gv58, gv59))
gv61: Tuple[Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"]] = relax.call_tir(((1, 512, 28, 28), (512,), (512,)), batch_norm5, (gv60, stage2_unit2_bn1_gamma, stage2_unit2_bn1_beta, stage2_unit2_bn1_moving_mean, stage2_unit2_bn1_moving_var))
gv62: Tensor[(1, 512, 28, 28), "float32"] = gv61[0]
gv63: Tensor[(1, 512, 28, 28), "float32"] = relax.call_tir((1, 512, 28, 28), relu4, (gv62,))
gv64: Tensor[(1, 128, 28, 28), "float32"] = relax.call_tir((1, 128, 28, 28), conv2d_nchw9, (gv63, stage2_unit2_conv1_weight))
gv65: Tuple[Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"]] = relax.call_tir(((1, 128, 28, 28), (128,), (128,)), batch_norm4, (gv64, stage2_unit2_bn2_gamma, stage2_unit2_bn2_beta, stage2_unit2_bn2_moving_mean, stage2_unit2_bn2_moving_var))
gv66: Tensor[(1, 128, 28, 28), "float32"] = gv65[0]
gv67: Tensor[(1, 128, 28, 28), "float32"] = relax.call_tir((1, 128, 28, 28), relu3, (gv66,))
gv68: Tensor[(1, 128, 28, 28), "float32"] = relax.call_tir((1, 128, 28, 28), conv2d_nchw6, (gv67, stage2_unit2_conv2_weight))
gv69: Tuple[Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"]] = relax.call_tir(((1, 128, 28, 28), (128,), (128,)), batch_norm4, (gv68, stage2_unit2_bn3_gamma, stage2_unit2_bn3_beta, stage2_unit2_bn3_moving_mean, stage2_unit2_bn3_moving_var))
gv70: Tensor[(1, 128, 28, 28), "float32"] = gv69[0]
gv71: Tensor[(1, 128, 28, 28), "float32"] = relax.call_tir((1, 128, 28, 28), relu3, (gv70,))
gv72: Tensor[(1, 512, 28, 28), "float32"] = relax.call_tir((1, 512, 28, 28), conv2d_nchw7, (gv71, stage2_unit2_conv3_weight))
gv73: Tensor[(1, 512, 28, 28), "float32"] = relax.call_tir((1, 512, 28, 28), add1, (gv72, gv60))
gv74: Tuple[Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"]] = relax.call_tir(((1, 512, 28, 28), (512,), (512,)), batch_norm5, (gv73, stage2_unit3_bn1_gamma, stage2_unit3_bn1_beta, stage2_unit3_bn1_moving_mean, stage2_unit3_bn1_moving_var))
gv75: Tensor[(1, 512, 28, 28), "float32"] = gv74[0]
gv76: Tensor[(1, 512, 28, 28), "float32"] = relax.call_tir((1, 512, 28, 28), relu4, (gv75,))
gv77: Tensor[(1, 128, 28, 28), "float32"] = relax.call_tir((1, 128, 28, 28), conv2d_nchw9, (gv76, stage2_unit3_conv1_weight))
gv78: Tuple[Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"]] = relax.call_tir(((1, 128, 28, 28), (128,), (128,)), batch_norm4, (gv77, stage2_unit3_bn2_gamma, stage2_unit3_bn2_beta, stage2_unit3_bn2_moving_mean, stage2_unit3_bn2_moving_var))
gv79: Tensor[(1, 128, 28, 28), "float32"] = gv78[0]
gv80: Tensor[(1, 128, 28, 28), "float32"] = relax.call_tir((1, 128, 28, 28), relu3, (gv79,))
gv81: Tensor[(1, 128, 28, 28), "float32"] = relax.call_tir((1, 128, 28, 28), conv2d_nchw6, (gv80, stage2_unit3_conv2_weight))
gv82: Tuple[Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"]] = relax.call_tir(((1, 128, 28, 28), (128,), (128,)), batch_norm4, (gv81, stage2_unit3_bn3_gamma, stage2_unit3_bn3_beta, stage2_unit3_bn3_moving_mean, stage2_unit3_bn3_moving_var))
gv83: Tensor[(1, 128, 28, 28), "float32"] = gv82[0]
gv84: Tensor[(1, 128, 28, 28), "float32"] = relax.call_tir((1, 128, 28, 28), relu3, (gv83,))
gv85: Tensor[(1, 512, 28, 28), "float32"] = relax.call_tir((1, 512, 28, 28), conv2d_nchw7, (gv84, stage2_unit3_conv3_weight))
gv86: Tensor[(1, 512, 28, 28), "float32"] = relax.call_tir((1, 512, 28, 28), add1, (gv85, gv73))
gv87: Tuple[Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"]] = relax.call_tir(((1, 512, 28, 28), (512,), (512,)), batch_norm5, (gv86, stage2_unit4_bn1_gamma, stage2_unit4_bn1_beta, stage2_unit4_bn1_moving_mean, stage2_unit4_bn1_moving_var))
gv88: Tensor[(1, 512, 28, 28), "float32"] = gv87[0]
gv89: Tensor[(1, 512, 28, 28), "float32"] = relax.call_tir((1, 512, 28, 28), relu4, (gv88,))
gv90: Tensor[(1, 128, 28, 28), "float32"] = relax.call_tir((1, 128, 28, 28), conv2d_nchw9, (gv89, stage2_unit4_conv1_weight))
gv91: Tuple[Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"]] = relax.call_tir(((1, 128, 28, 28), (128,), (128,)), batch_norm4, (gv90, stage2_unit4_bn2_gamma, stage2_unit4_bn2_beta, stage2_unit4_bn2_moving_mean, stage2_unit4_bn2_moving_var))
gv92: Tensor[(1, 128, 28, 28), "float32"] = gv91[0]
gv93: Tensor[(1, 128, 28, 28), "float32"] = relax.call_tir((1, 128, 28, 28), relu3, (gv92,))
gv94: Tensor[(1, 128, 28, 28), "float32"] = relax.call_tir((1, 128, 28, 28), conv2d_nchw6, (gv93, stage2_unit4_conv2_weight))
gv95: Tuple[Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"]] = relax.call_tir(((1, 128, 28, 28), (128,), (128,)), batch_norm4, (gv94, stage2_unit4_bn3_gamma, stage2_unit4_bn3_beta, stage2_unit4_bn3_moving_mean, stage2_unit4_bn3_moving_var))
gv96: Tensor[(1, 128, 28, 28), "float32"] = gv95[0]
gv97: Tensor[(1, 128, 28, 28), "float32"] = relax.call_tir((1, 128, 28, 28), relu3, (gv96,))
gv98: Tensor[(1, 512, 28, 28), "float32"] = relax.call_tir((1, 512, 28, 28), conv2d_nchw7, (gv97, stage2_unit4_conv3_weight))
gv99: Tensor[(1, 512, 28, 28), "float32"] = relax.call_tir((1, 512, 28, 28), add1, (gv98, gv86))
gv100: Tuple[Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"]] = relax.call_tir(((1, 512, 28, 28), (512,), (512,)), batch_norm5, (gv99, stage3_unit1_bn1_gamma, stage3_unit1_bn1_beta, stage3_unit1_bn1_moving_mean, stage3_unit1_bn1_moving_var))
gv101: Tensor[(1, 512, 28, 28), "float32"] = gv100[0]
gv102: Tensor[(1, 512, 28, 28), "float32"] = relax.call_tir((1, 512, 28, 28), relu4, (gv101,))
gv103: Tensor[(1, 256, 14, 14), "float32"] = relax.call_tir((1, 256, 14, 14), conv2d_nchw10, (gv102, stage3_unit1_conv1_weight))
gv104: Tuple[Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"]] = relax.call_tir(((1, 256, 14, 14), (256,), (256,)), batch_norm6, (gv103, stage3_unit1_bn2_gamma, stage3_unit1_bn2_beta, stage3_unit1_bn2_moving_mean, stage3_unit1_bn2_moving_var))
gv105: Tensor[(1, 256, 14, 14), "float32"] = gv104[0]
gv106: Tensor[(1, 256, 14, 14), "float32"] = relax.call_tir((1, 256, 14, 14), relu5, (gv105,))
gv107: Tensor[(1, 256, 14, 14), "float32"] = relax.call_tir((1, 256, 14, 14), conv2d_nchw11, (gv106, stage3_unit1_conv2_weight))
gv108: Tuple[Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"]] = relax.call_tir(((1, 256, 14, 14), (256,), (256,)), batch_norm6, (gv107, stage3_unit1_bn3_gamma, stage3_unit1_bn3_beta, stage3_unit1_bn3_moving_mean, stage3_unit1_bn3_moving_var))
gv109: Tensor[(1, 256, 14, 14), "float32"] = gv108[0]
gv110: Tensor[(1, 256, 14, 14), "float32"] = relax.call_tir((1, 256, 14, 14), relu5, (gv109,))
gv111: Tensor[(1, 1024, 14, 14), "float32"] = relax.call_tir((1, 1024, 14, 14), conv2d_nchw12, (gv110, stage3_unit1_conv3_weight))
gv112: Tensor[(1, 1024, 14, 14), "float32"] = relax.call_tir((1, 1024, 14, 14), conv2d_nchw13, (gv102, stage3_unit1_sc_weight))
gv113: Tensor[(1, 1024, 14, 14), "float32"] = relax.call_tir((1, 1024, 14, 14), add2, (gv111, gv112))
gv114: Tuple[Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"]] = relax.call_tir(((1, 1024, 14, 14), (1024,), (1024,)), batch_norm7, (gv113, stage3_unit2_bn1_gamma, stage3_unit2_bn1_beta, stage3_unit2_bn1_moving_mean, stage3_unit2_bn1_moving_var))
gv115: Tensor[(1, 1024, 14, 14), "float32"] = gv114[0]
gv116: Tensor[(1, 1024, 14, 14), "float32"] = relax.call_tir((1, 1024, 14, 14), relu6, (gv115,))
gv117: Tensor[(1, 256, 14, 14), "float32"] = relax.call_tir((1, 256, 14, 14), conv2d_nchw14, (gv116, stage3_unit2_conv1_weight))
gv118: Tuple[Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"]] = relax.call_tir(((1, 256, 14, 14), (256,), (256,)), batch_norm6, (gv117, stage3_unit2_bn2_gamma, stage3_unit2_bn2_beta, stage3_unit2_bn2_moving_mean, stage3_unit2_bn2_moving_var))
gv119: Tensor[(1, 256, 14, 14), "float32"] = gv118[0]
gv120: Tensor[(1, 256, 14, 14), "float32"] = relax.call_tir((1, 256, 14, 14), relu5, (gv119,))
gv121: Tensor[(1, 256, 14, 14), "float32"] = relax.call_tir((1, 256, 14, 14), conv2d_nchw11, (gv120, stage3_unit2_conv2_weight))
gv122: Tuple[Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"]] = relax.call_tir(((1, 256, 14, 14), (256,), (256,)), batch_norm6, (gv121, stage3_unit2_bn3_gamma, stage3_unit2_bn3_beta, stage3_unit2_bn3_moving_mean, stage3_unit2_bn3_moving_var))
gv123: Tensor[(1, 256, 14, 14), "float32"] = gv122[0]
gv124: Tensor[(1, 256, 14, 14), "float32"] = relax.call_tir((1, 256, 14, 14), relu5, (gv123,))
gv125: Tensor[(1, 1024, 14, 14), "float32"] = relax.call_tir((1, 1024, 14, 14), conv2d_nchw12, (gv124, stage3_unit2_conv3_weight))
gv126: Tensor[(1, 1024, 14, 14), "float32"] = relax.call_tir((1, 1024, 14, 14), add2, (gv125, gv113))
gv127: Tuple[Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"]] = relax.call_tir(((1, 1024, 14, 14), (1024,), (1024,)), batch_norm7, (gv126, stage3_unit3_bn1_gamma, stage3_unit3_bn1_beta, stage3_unit3_bn1_moving_mean, stage3_unit3_bn1_moving_var))
gv128: Tensor[(1, 1024, 14, 14), "float32"] = gv127[0]
gv129: Tensor[(1, 1024, 14, 14), "float32"] = relax.call_tir((1, 1024, 14, 14), relu6, (gv128,))
gv130: Tensor[(1, 256, 14, 14), "float32"] = relax.call_tir((1, 256, 14, 14), conv2d_nchw14, (gv129, stage3_unit3_conv1_weight))
gv131: Tuple[Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"]] = relax.call_tir(((1, 256, 14, 14), (256,), (256,)), batch_norm6, (gv130, stage3_unit3_bn2_gamma, stage3_unit3_bn2_beta, stage3_unit3_bn2_moving_mean, stage3_unit3_bn2_moving_var))
gv132: Tensor[(1, 256, 14, 14), "float32"] = gv131[0]
gv133: Tensor[(1, 256, 14, 14), "float32"] = relax.call_tir((1, 256, 14, 14), relu5, (gv132,))
gv134: Tensor[(1, 256, 14, 14), "float32"] = relax.call_tir((1, 256, 14, 14), conv2d_nchw11, (gv133, stage3_unit3_conv2_weight))
gv135: Tuple[Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"]] = relax.call_tir(((1, 256, 14, 14), (256,), (256,)), batch_norm6, (gv134, stage3_unit3_bn3_gamma, stage3_unit3_bn3_beta, stage3_unit3_bn3_moving_mean, stage3_unit3_bn3_moving_var))
gv136: Tensor[(1, 256, 14, 14), "float32"] = gv135[0]
gv137: Tensor[(1, 256, 14, 14), "float32"] = relax.call_tir((1, 256, 14, 14), relu5, (gv136,))
gv138: Tensor[(1, 1024, 14, 14), "float32"] = relax.call_tir((1, 1024, 14, 14), conv2d_nchw12, (gv137, stage3_unit3_conv3_weight))
gv139: Tensor[(1, 1024, 14, 14), "float32"] = relax.call_tir((1, 1024, 14, 14), add2, (gv138, gv126))
gv140: Tuple[Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"]] = relax.call_tir(((1, 1024, 14, 14), (1024,), (1024,)), batch_norm7, (gv139, stage3_unit4_bn1_gamma, stage3_unit4_bn1_beta, stage3_unit4_bn1_moving_mean, stage3_unit4_bn1_moving_var))
gv141: Tensor[(1, 1024, 14, 14), "float32"] = gv140[0]
gv142: Tensor[(1, 1024, 14, 14), "float32"] = relax.call_tir((1, 1024, 14, 14), relu6, (gv141,))
gv143: Tensor[(1, 256, 14, 14), "float32"] = relax.call_tir((1, 256, 14, 14), conv2d_nchw14, (gv142, stage3_unit4_conv1_weight))
gv144: Tuple[Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"]] = relax.call_tir(((1, 256, 14, 14), (256,), (256,)), batch_norm6, (gv143, stage3_unit4_bn2_gamma, stage3_unit4_bn2_beta, stage3_unit4_bn2_moving_mean, stage3_unit4_bn2_moving_var))
gv145: Tensor[(1, 256, 14, 14), "float32"] = gv144[0]
gv146: Tensor[(1, 256, 14, 14), "float32"] = relax.call_tir((1, 256, 14, 14), relu5, (gv145,))
gv147: Tensor[(1, 256, 14, 14), "float32"] = relax.call_tir((1, 256, 14, 14), conv2d_nchw11, (gv146, stage3_unit4_conv2_weight))
gv148: Tuple[Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"]] = relax.call_tir(((1, 256, 14, 14), (256,), (256,)), batch_norm6, (gv147, stage3_unit4_bn3_gamma, stage3_unit4_bn3_beta, stage3_unit4_bn3_moving_mean, stage3_unit4_bn3_moving_var))
gv149: Tensor[(1, 256, 14, 14), "float32"] = gv148[0]
gv150: Tensor[(1, 256, 14, 14), "float32"] = relax.call_tir((1, 256, 14, 14), relu5, (gv149,))
gv151: Tensor[(1, 1024, 14, 14), "float32"] = relax.call_tir((1, 1024, 14, 14), conv2d_nchw12, (gv150, stage3_unit4_conv3_weight))
gv152: Tensor[(1, 1024, 14, 14), "float32"] = relax.call_tir((1, 1024, 14, 14), add2, (gv151, gv139))
gv153: Tuple[Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"]] = relax.call_tir(((1, 1024, 14, 14), (1024,), (1024,)), batch_norm7, (gv152, stage3_unit5_bn1_gamma, stage3_unit5_bn1_beta, stage3_unit5_bn1_moving_mean, stage3_unit5_bn1_moving_var))
gv154: Tensor[(1, 1024, 14, 14), "float32"] = gv153[0]
gv155: Tensor[(1, 1024, 14, 14), "float32"] = relax.call_tir((1, 1024, 14, 14), relu6, (gv154,))
gv156: Tensor[(1, 256, 14, 14), "float32"] = relax.call_tir((1, 256, 14, 14), conv2d_nchw14, (gv155, stage3_unit5_conv1_weight))
gv157: Tuple[Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"]] = relax.call_tir(((1, 256, 14, 14), (256,), (256,)), batch_norm6, (gv156, stage3_unit5_bn2_gamma, stage3_unit5_bn2_beta, stage3_unit5_bn2_moving_mean, stage3_unit5_bn2_moving_var))
gv158: Tensor[(1, 256, 14, 14), "float32"] = gv157[0]
gv159: Tensor[(1, 256, 14, 14), "float32"] = relax.call_tir((1, 256, 14, 14), relu5, (gv158,))
gv160: Tensor[(1, 256, 14, 14), "float32"] = relax.call_tir((1, 256, 14, 14), conv2d_nchw11, (gv159, stage3_unit5_conv2_weight))
gv161: Tuple[Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"]] = relax.call_tir(((1, 256, 14, 14), (256,), (256,)), batch_norm6, (gv160, stage3_unit5_bn3_gamma, stage3_unit5_bn3_beta, stage3_unit5_bn3_moving_mean, stage3_unit5_bn3_moving_var))
gv162: Tensor[(1, 256, 14, 14), "float32"] = gv161[0]
gv163: Tensor[(1, 256, 14, 14), "float32"] = relax.call_tir((1, 256, 14, 14), relu5, (gv162,))
gv164: Tensor[(1, 1024, 14, 14), "float32"] = relax.call_tir((1, 1024, 14, 14), conv2d_nchw12, (gv163, stage3_unit5_conv3_weight))
gv165: Tensor[(1, 1024, 14, 14), "float32"] = relax.call_tir((1, 1024, 14, 14), add2, (gv164, gv152))
gv166: Tuple[Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"]] = relax.call_tir(((1, 1024, 14, 14), (1024,), (1024,)), batch_norm7, (gv165, stage3_unit6_bn1_gamma, stage3_unit6_bn1_beta, stage3_unit6_bn1_moving_mean, stage3_unit6_bn1_moving_var))
gv167: Tensor[(1, 1024, 14, 14), "float32"] = gv166[0]
gv168: Tensor[(1, 1024, 14, 14), "float32"] = relax.call_tir((1, 1024, 14, 14), relu6, (gv167,))
gv169: Tensor[(1, 256, 14, 14), "float32"] = relax.call_tir((1, 256, 14, 14), conv2d_nchw14, (gv168, stage3_unit6_conv1_weight))
gv170: Tuple[Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"]] = relax.call_tir(((1, 256, 14, 14), (256,), (256,)), batch_norm6, (gv169, stage3_unit6_bn2_gamma, stage3_unit6_bn2_beta, stage3_unit6_bn2_moving_mean, stage3_unit6_bn2_moving_var))
gv171: Tensor[(1, 256, 14, 14), "float32"] = gv170[0]
gv172: Tensor[(1, 256, 14, 14), "float32"] = relax.call_tir((1, 256, 14, 14), relu5, (gv171,))
gv173: Tensor[(1, 256, 14, 14), "float32"] = relax.call_tir((1, 256, 14, 14), conv2d_nchw11, (gv172, stage3_unit6_conv2_weight))
gv174: Tuple[Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"]] = relax.call_tir(((1, 256, 14, 14), (256,), (256,)), batch_norm6, (gv173, stage3_unit6_bn3_gamma, stage3_unit6_bn3_beta, stage3_unit6_bn3_moving_mean, stage3_unit6_bn3_moving_var))
gv175: Tensor[(1, 256, 14, 14), "float32"] = gv174[0]
gv176: Tensor[(1, 256, 14, 14), "float32"] = relax.call_tir((1, 256, 14, 14), relu5, (gv175,))
gv177: Tensor[(1, 1024, 14, 14), "float32"] = relax.call_tir((1, 1024, 14, 14), conv2d_nchw12, (gv176, stage3_unit6_conv3_weight))
gv178: Tensor[(1, 1024, 14, 14), "float32"] = relax.call_tir((1, 1024, 14, 14), add2, (gv177, gv165))
gv179: Tuple[Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"]] = relax.call_tir(((1, 1024, 14, 14), (1024,), (1024,)), batch_norm7, (gv178, stage4_unit1_bn1_gamma, stage4_unit1_bn1_beta, stage4_unit1_bn1_moving_mean, stage4_unit1_bn1_moving_var))
gv180: Tensor[(1, 1024, 14, 14), "float32"] = gv179[0]
gv181: Tensor[(1, 1024, 14, 14), "float32"] = relax.call_tir((1, 1024, 14, 14), relu6, (gv180,))
gv182: Tensor[(1, 512, 7, 7), "float32"] = relax.call_tir((1, 512, 7, 7), conv2d_nchw15, (gv181, stage4_unit1_conv1_weight))
gv183: Tuple[Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"]] = relax.call_tir(((1, 512, 7, 7), (512,), (512,)), batch_norm8, (gv182, stage4_unit1_bn2_gamma, stage4_unit1_bn2_beta, stage4_unit1_bn2_moving_mean, stage4_unit1_bn2_moving_var))
gv184: Tensor[(1, 512, 7, 7), "float32"] = gv183[0]
gv185: Tensor[(1, 512, 7, 7), "float32"] = relax.call_tir((1, 512, 7, 7), relu7, (gv184,))
gv186: Tensor[(1, 512, 7, 7), "float32"] = relax.call_tir((1, 512, 7, 7), conv2d_nchw16, (gv185, stage4_unit1_conv2_weight))
gv187: Tuple[Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"]] = relax.call_tir(((1, 512, 7, 7), (512,), (512,)), batch_norm8, (gv186, stage4_unit1_bn3_gamma, stage4_unit1_bn3_beta, stage4_unit1_bn3_moving_mean, stage4_unit1_bn3_moving_var))
gv188: Tensor[(1, 512, 7, 7), "float32"] = gv187[0]
gv189: Tensor[(1, 512, 7, 7), "float32"] = relax.call_tir((1, 512, 7, 7), relu7, (gv188,))
gv190: Tensor[(1, 2048, 7, 7), "float32"] = relax.call_tir((1, 2048, 7, 7), conv2d_nchw17, (gv189, stage4_unit1_conv3_weight))
gv191: Tensor[(1, 2048, 7, 7), "float32"] = relax.call_tir((1, 2048, 7, 7), conv2d_nchw18, (gv181, stage4_unit1_sc_weight))
gv192: Tensor[(1, 2048, 7, 7), "float32"] = relax.call_tir((1, 2048, 7, 7), add3, (gv190, gv191))
gv193: Tuple[Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"]] = relax.call_tir(((1, 2048, 7, 7), (2048,), (2048,)), batch_norm9, (gv192, stage4_unit2_bn1_gamma, stage4_unit2_bn1_beta, stage4_unit2_bn1_moving_mean, stage4_unit2_bn1_moving_var))
gv194: Tensor[(1, 2048, 7, 7), "float32"] = gv193[0]
gv195: Tensor[(1, 2048, 7, 7), "float32"] = relax.call_tir((1, 2048, 7, 7), relu8, (gv194,))
gv196: Tensor[(1, 512, 7, 7), "float32"] = relax.call_tir((1, 512, 7, 7), conv2d_nchw19, (gv195, stage4_unit2_conv1_weight))
gv197: Tuple[Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"]] = relax.call_tir(((1, 512, 7, 7), (512,), (512,)), batch_norm8, (gv196, stage4_unit2_bn2_gamma, stage4_unit2_bn2_beta, stage4_unit2_bn2_moving_mean, stage4_unit2_bn2_moving_var))
gv198: Tensor[(1, 512, 7, 7), "float32"] = gv197[0]
gv199: Tensor[(1, 512, 7, 7), "float32"] = relax.call_tir((1, 512, 7, 7), relu7, (gv198,))
gv200: Tensor[(1, 512, 7, 7), "float32"] = relax.call_tir((1, 512, 7, 7), conv2d_nchw16, (gv199, stage4_unit2_conv2_weight))
gv201: Tuple[Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"]] = relax.call_tir(((1, 512, 7, 7), (512,), (512,)), batch_norm8, (gv200, stage4_unit2_bn3_gamma, stage4_unit2_bn3_beta, stage4_unit2_bn3_moving_mean, stage4_unit2_bn3_moving_var))
gv202: Tensor[(1, 512, 7, 7), "float32"] = gv201[0]
gv203: Tensor[(1, 512, 7, 7), "float32"] = relax.call_tir((1, 512, 7, 7), relu7, (gv202,))
gv204: Tensor[(1, 2048, 7, 7), "float32"] = relax.call_tir((1, 2048, 7, 7), conv2d_nchw17, (gv203, stage4_unit2_conv3_weight))
gv205: Tensor[(1, 2048, 7, 7), "float32"] = relax.call_tir((1, 2048, 7, 7), add3, (gv204, gv192))
gv206: Tuple[Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"]] = relax.call_tir(((1, 2048, 7, 7), (2048,), (2048,)), batch_norm9, (gv205, stage4_unit3_bn1_gamma, stage4_unit3_bn1_beta, stage4_unit3_bn1_moving_mean, stage4_unit3_bn1_moving_var))
gv207: Tensor[(1, 2048, 7, 7), "float32"] = gv206[0]
gv208: Tensor[(1, 2048, 7, 7), "float32"] = relax.call_tir((1, 2048, 7, 7), relu8, (gv207,))
gv209: Tensor[(1, 512, 7, 7), "float32"] = relax.call_tir((1, 512, 7, 7), conv2d_nchw19, (gv208, stage4_unit3_conv1_weight))
gv210: Tuple[Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"]] = relax.call_tir(((1, 512, 7, 7), (512,), (512,)), batch_norm8, (gv209, stage4_unit3_bn2_gamma, stage4_unit3_bn2_beta, stage4_unit3_bn2_moving_mean, stage4_unit3_bn2_moving_var))
gv211: Tensor[(1, 512, 7, 7), "float32"] = gv210[0]
gv212: Tensor[(1, 512, 7, 7), "float32"] = relax.call_tir((1, 512, 7, 7), relu7, (gv211,))
gv213: Tensor[(1, 512, 7, 7), "float32"] = relax.call_tir((1, 512, 7, 7), conv2d_nchw16, (gv212, stage4_unit3_conv2_weight))
gv214: Tuple[Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"]] = relax.call_tir(((1, 512, 7, 7), (512,), (512,)), batch_norm8, (gv213, stage4_unit3_bn3_gamma, stage4_unit3_bn3_beta, stage4_unit3_bn3_moving_mean, stage4_unit3_bn3_moving_var))
gv215: Tensor[(1, 512, 7, 7), "float32"] = gv214[0]
gv216: Tensor[(1, 512, 7, 7), "float32"] = relax.call_tir((1, 512, 7, 7), relu7, (gv215,))
gv217: Tensor[(1, 2048, 7, 7), "float32"] = relax.call_tir((1, 2048, 7, 7), conv2d_nchw17, (gv216, stage4_unit3_conv3_weight))
gv218: Tensor[(1, 2048, 7, 7), "float32"] = relax.call_tir((1, 2048, 7, 7), add3, (gv217, gv205))
gv219: Tuple[Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"]] = relax.call_tir(((1, 2048, 7, 7), (2048,), (2048,)), batch_norm9, (gv218, bn1_gamma, bn1_beta, bn1_moving_mean, bn1_moving_var))
gv220: Tensor[(1, 2048, 7, 7), "float32"] = gv219[0]
gv221: Tensor[(1, 2048, 7, 7), "float32"] = relax.call_tir((1, 2048, 7, 7), relu8, (gv220,))
gv222: Tensor[(1, 2048, 1, 1), "float32"] = relax.call_tir((1, 2048, 1, 1), global_avg_pool2d, (gv221,))
gv223: Tensor[(1, 2048), "float32"] = relax.call_tir((1, 2048), batch_flatten, (gv222,))
gv224: Tensor[(1, 1000), "float32"] = relax.call_tir((1, 1000), dense, (gv223, fc1_weight))
gv225: Tensor[(1, 1000), "float32"] = relax.call_tir((1, 1000), bias_add, (gv224, fc1_bias))
gv226: Tensor[(1, 1000), "float32"] = relax.call_tir((1, 1000), softmax, (gv225,))
relax.output(gv226)
return gv226
# Elementwise ReLU over a (1, 2048, 7, 7) float32 tensor: out = max(x, 0).
@tir.prim_func
def relu8(rxplaceholder_1: tir.Buffer[(1, 2048, 7, 7), "float32"], T_relu_1: tir.Buffer[(1, 2048, 7, 7), "float32"]) -> None:
# function attr dict
tir.func_attr({"global_symbol": "relu8", "tir.noalias": True})
# body
# with tir.block("root")
for i0, i1, i2, i3 in tir.grid(1, 2048, 7, 7):
with tir.block("T_relu"):
ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
tir.reads(rxplaceholder_1[ax0, ax1, ax2, ax3])
tir.writes(T_relu_1[ax0, ax1, ax2, ax3])
# Clamp negatives to zero, one element per (n, c, h, w) position.
T_relu_1[ax0, ax1, ax2, ax3] = tir.max(rxplaceholder_1[ax0, ax1, ax2, ax3], tir.float32(0))
# Flatten a (1, 2048, 1, 1) tensor to (1, 2048) by copying each channel value.
@tir.prim_func
def batch_flatten(rxplaceholder_1: tir.Buffer[(1, 2048, 1, 1), "float32"], tensor_1: tir.Buffer[(1, 2048), "float32"]) -> None:
# function attr dict
tir.func_attr({"global_symbol": "batch_flatten", "tir.noalias": True})
# body
# with tir.block("root")
for i0, i1 in tir.grid(1, 2048):
with tir.block("tensor"):
ax0, ax1 = tir.axis.remap("SS", [i0, i1])
tir.reads(rxplaceholder_1[ax0, ax1 % 2048, 0, 0])
tir.writes(tensor_1[ax0, ax1])
# ax1 % 2048 is a no-op here (ax1 < 2048); it is the generic
# flatten index emitted by the op lowering.
tensor_1[ax0, ax1] = rxplaceholder_1[ax0, ax1 % 2048, 0, 0]
# 3x3 NCHW convolution, 64 -> 64 channels on a 56x56 map, stride 1,
# zero padding 1 (58x58 padded intermediate keeps the output at 56x56).
@tir.prim_func
def conv2d_nchw2(rxplaceholder_2: tir.Buffer[(1, 64, 56, 56), "float32"], rxplaceholder_3: tir.Buffer[(64, 64, 3, 3), "float32"], conv2d_nchw_1: tir.Buffer[(1, 64, 56, 56), "float32"]) -> None:
# function attr dict
tir.func_attr({"global_symbol": "conv2d_nchw2", "tir.noalias": True})
# body
# with tir.block("root")
pad_temp_1 = tir.alloc_buffer([1, 64, 58, 58], dtype="float32")
# Stage 1: materialize the zero-padded input (1 pixel on each border).
for i0, i1, i2, i3 in tir.grid(1, 64, 58, 58):
with tir.block("pad_temp"):
i0_2, i1_2, i2_2, i3_2 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
tir.reads(rxplaceholder_2[i0_2, i1_2, i2_2 - 1, i3_2 - 1])
tir.writes(pad_temp_1[i0_2, i1_2, i2_2, i3_2])
# Interior (rows/cols 1..56) copies the input; the border is zero.
pad_temp_1[i0_2, i1_2, i2_2, i3_2] = tir.if_then_else(i2_2 >= 1 and i2_2 < 57 and i3_2 >= 1 and i3_2 < 57, rxplaceholder_2[i0_2, i1_2, i2_2 - 1, i3_2 - 1], tir.float32(0), dtype="float32")
# Stage 2: direct convolution; rc/ry/rx are reduction axes over
# input channels and the 3x3 kernel window.
for i0_3, i1_3, i2_3, i3_3, i4, i5, i6 in tir.grid(1, 64, 56, 56, 64, 3, 3):
with tir.block("conv2d_nchw"):
nn, ff, yy, xx, rc, ry, rx = tir.axis.remap("SSSSRRR", [i0_3, i1_3, i2_3, i3_3, i4, i5, i6])
tir.reads(conv2d_nchw_1[nn, ff, yy, xx], pad_temp_1[nn, rc, yy + ry, xx + rx], rxplaceholder_3[ff, rc, ry, rx])
tir.writes(conv2d_nchw_1[nn, ff, yy, xx])
tir.block_attr({"layout_free_placeholders":[rxplaceholder_3]})
with tir.init():
# Zero the accumulator on the first reduction iteration.
conv2d_nchw_1[nn, ff, yy, xx] = tir.float32(0)
conv2d_nchw_1[nn, ff, yy, xx] = conv2d_nchw_1[nn, ff, yy, xx] + pad_temp_1[nn, rc, yy + ry, xx + rx] * rxplaceholder_3[ff, rc, ry, rx]
# 1x1 NCHW convolution (pointwise), 128 -> 512 channels on a 28x28 map,
# stride 1, no padding (the "pad" stage is a plain copy).
@tir.prim_func
def conv2d_nchw7(rxplaceholder_2: tir.Buffer[(1, 128, 28, 28), "float32"], rxplaceholder_3: tir.Buffer[(512, 128, 1, 1), "float32"], conv2d_nchw_1: tir.Buffer[(1, 512, 28, 28), "float32"]) -> None:
# function attr dict
tir.func_attr({"global_symbol": "conv2d_nchw7", "tir.noalias": True})
# body
# with tir.block("root")
pad_temp_1 = tir.alloc_buffer([1, 128, 28, 28], dtype="float32")
# Stage 1: identity copy of the input (1x1 kernel needs no border).
for i0, i1, i2, i3 in tir.grid(1, 128, 28, 28):
with tir.block("pad_temp"):
i0_2, i1_2, i2_2, i3_2 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
tir.reads(rxplaceholder_2[i0_2, i1_2, i2_2, i3_2])
tir.writes(pad_temp_1[i0_2, i1_2, i2_2, i3_2])
pad_temp_1[i0_2, i1_2, i2_2, i3_2] = rxplaceholder_2[i0_2, i1_2, i2_2, i3_2]
# Stage 2: accumulate over the 128 input channels (ry/rx are degenerate).
for i0_3, i1_3, i2_3, i3_3, i4, i5, i6 in tir.grid(1, 512, 28, 28, 128, 1, 1):
with tir.block("conv2d_nchw"):
nn, ff, yy, xx, rc, ry, rx = tir.axis.remap("SSSSRRR", [i0_3, i1_3, i2_3, i3_3, i4, i5, i6])
tir.reads(conv2d_nchw_1[nn, ff, yy, xx], pad_temp_1[nn, rc, yy + ry, xx + rx], rxplaceholder_3[ff, rc, ry, rx])
tir.writes(conv2d_nchw_1[nn, ff, yy, xx])
tir.block_attr({"layout_free_placeholders":[rxplaceholder_3]})
with tir.init():
conv2d_nchw_1[nn, ff, yy, xx] = tir.float32(0)
conv2d_nchw_1[nn, ff, yy, xx] = conv2d_nchw_1[nn, ff, yy, xx] + pad_temp_1[nn, rc, yy + ry, xx + rx] * rxplaceholder_3[ff, rc, ry, rx]
# 3x3 NCHW convolution, 256 -> 256 channels on a 14x14 map, stride 1,
# zero padding 1 (16x16 padded intermediate keeps the output at 14x14).
@tir.prim_func
def conv2d_nchw11(rxplaceholder_2: tir.Buffer[(1, 256, 14, 14), "float32"], rxplaceholder_3: tir.Buffer[(256, 256, 3, 3), "float32"], conv2d_nchw_1: tir.Buffer[(1, 256, 14, 14), "float32"]) -> None:
# function attr dict
tir.func_attr({"global_symbol": "conv2d_nchw11", "tir.noalias": True})
# body
# with tir.block("root")
pad_temp_1 = tir.alloc_buffer([1, 256, 16, 16], dtype="float32")
# Stage 1: materialize the zero-padded input (1 pixel on each border).
for i0, i1, i2, i3 in tir.grid(1, 256, 16, 16):
with tir.block("pad_temp"):
i0_2, i1_2, i2_2, i3_2 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
tir.reads(rxplaceholder_2[i0_2, i1_2, i2_2 - 1, i3_2 - 1])
tir.writes(pad_temp_1[i0_2, i1_2, i2_2, i3_2])
# Interior (rows/cols 1..14) copies the input; the border is zero.
pad_temp_1[i0_2, i1_2, i2_2, i3_2] = tir.if_then_else(i2_2 >= 1 and i2_2 < 15 and i3_2 >= 1 and i3_2 < 15, rxplaceholder_2[i0_2, i1_2, i2_2 - 1, i3_2 - 1], tir.float32(0), dtype="float32")
# Stage 2: direct convolution; rc/ry/rx reduce over input channels
# and the 3x3 kernel window.
for i0_3, i1_3, i2_3, i3_3, i4, i5, i6 in tir.grid(1, 256, 14, 14, 256, 3, 3):
with tir.block("conv2d_nchw"):
nn, ff, yy, xx, rc, ry, rx = tir.axis.remap("SSSSRRR", [i0_3, i1_3, i2_3, i3_3, i4, i5, i6])
tir.reads(conv2d_nchw_1[nn, ff, yy, xx], pad_temp_1[nn, rc, yy + ry, xx + rx], rxplaceholder_3[ff, rc, ry, rx])
tir.writes(conv2d_nchw_1[nn, ff, yy, xx])
tir.block_attr({"layout_free_placeholders":[rxplaceholder_3]})
with tir.init():
conv2d_nchw_1[nn, ff, yy, xx] = tir.float32(0)
conv2d_nchw_1[nn, ff, yy, xx] = conv2d_nchw_1[nn, ff, yy, xx] + pad_temp_1[nn, rc, yy + ry, xx + rx] * rxplaceholder_3[ff, rc, ry, rx]
# Elementwise ReLU over a (1, 256, 56, 56) float32 tensor: out = max(x, 0).
@tir.prim_func
def relu2(rxplaceholder_1: tir.Buffer[(1, 256, 56, 56), "float32"], T_relu_1: tir.Buffer[(1, 256, 56, 56), "float32"]) -> None:
# function attr dict
tir.func_attr({"global_symbol": "relu2", "tir.noalias": True})
# body
# with tir.block("root")
for i0, i1, i2, i3 in tir.grid(1, 256, 56, 56):
with tir.block("T_relu"):
ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
tir.reads(rxplaceholder_1[ax0, ax1, ax2, ax3])
tir.writes(T_relu_1[ax0, ax1, ax2, ax3])
# Clamp negatives to zero, one element per (n, c, h, w) position.
T_relu_1[ax0, ax1, ax2, ax3] = tir.max(rxplaceholder_1[ax0, ax1, ax2, ax3], tir.float32(0))
# Elementwise ReLU over a (1, 128, 28, 28) float32 tensor: out = max(x, 0).
@tir.prim_func
def relu3(rxplaceholder_1: tir.Buffer[(1, 128, 28, 28), "float32"], T_relu_1: tir.Buffer[(1, 128, 28, 28), "float32"]) -> None:
# function attr dict
tir.func_attr({"global_symbol": "relu3", "tir.noalias": True})
# body
# with tir.block("root")
for i0, i1, i2, i3 in tir.grid(1, 128, 28, 28):
with tir.block("T_relu"):
ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
tir.reads(rxplaceholder_1[ax0, ax1, ax2, ax3])
tir.writes(T_relu_1[ax0, ax1, ax2, ax3])
# Clamp negatives to zero, one element per (n, c, h, w) position.
T_relu_1[ax0, ax1, ax2, ax3] = tir.max(rxplaceholder_1[ax0, ax1, ax2, ax3], tir.float32(0))
# Elementwise addition of two (1, 512, 28, 28) float32 tensors
# (the residual/shortcut merge at 28x28 resolution).
@tir.prim_func
def add1(rxplaceholder_2: tir.Buffer[(1, 512, 28, 28), "float32"], rxplaceholder_3: tir.Buffer[(1, 512, 28, 28), "float32"], T_add_1: tir.Buffer[(1, 512, 28, 28), "float32"]) -> None:
# function attr dict
tir.func_attr({"global_symbol": "add1", "tir.noalias": True})
# body
# with tir.block("root")
for i0, i1, i2, i3 in tir.grid(1, 512, 28, 28):
with tir.block("T_add"):
ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
tir.reads(rxplaceholder_2[ax0, ax1, ax2, ax3], rxplaceholder_3[ax0, ax1, ax2, ax3])
tir.writes(T_add_1[ax0, ax1, ax2, ax3])
T_add_1[ax0, ax1, ax2, ax3] = rxplaceholder_2[ax0, ax1, ax2, ax3] + rxplaceholder_3[ax0, ax1, ax2, ax3]
# 1x1 NCHW convolution (pointwise), 256 -> 64 channels on a 56x56 map,
# stride 1, no padding (the "pad" stage is a plain copy).
@tir.prim_func
def conv2d_nchw4(rxplaceholder_2: tir.Buffer[(1, 256, 56, 56), "float32"], rxplaceholder_3: tir.Buffer[(64, 256, 1, 1), "float32"], conv2d_nchw_1: tir.Buffer[(1, 64, 56, 56), "float32"]) -> None:
# function attr dict
tir.func_attr({"global_symbol": "conv2d_nchw4", "tir.noalias": True})
# body
# with tir.block("root")
pad_temp_1 = tir.alloc_buffer([1, 256, 56, 56], dtype="float32")
# Stage 1: identity copy of the input (1x1 kernel needs no border).
for i0, i1, i2, i3 in tir.grid(1, 256, 56, 56):
with tir.block("pad_temp"):
i0_2, i1_2, i2_2, i3_2 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
tir.reads(rxplaceholder_2[i0_2, i1_2, i2_2, i3_2])
tir.writes(pad_temp_1[i0_2, i1_2, i2_2, i3_2])
pad_temp_1[i0_2, i1_2, i2_2, i3_2] = rxplaceholder_2[i0_2, i1_2, i2_2, i3_2]
# Stage 2: accumulate over the 256 input channels (ry/rx are degenerate).
for i0_3, i1_3, i2_3, i3_3, i4, i5, i6 in tir.grid(1, 64, 56, 56, 256, 1, 1):
with tir.block("conv2d_nchw"):
nn, ff, yy, xx, rc, ry, rx = tir.axis.remap("SSSSRRR", [i0_3, i1_3, i2_3, i3_3, i4, i5, i6])
tir.reads(conv2d_nchw_1[nn, ff, yy, xx], pad_temp_1[nn, rc, yy + ry, xx + rx], rxplaceholder_3[ff, rc, ry, rx])
tir.writes(conv2d_nchw_1[nn, ff, yy, xx])
tir.block_attr({"layout_free_placeholders":[rxplaceholder_3]})
with tir.init():
conv2d_nchw_1[nn, ff, yy, xx] = tir.float32(0)
conv2d_nchw_1[nn, ff, yy, xx] = conv2d_nchw_1[nn, ff, yy, xx] + pad_temp_1[nn, rc, yy + ry, xx + rx] * rxplaceholder_3[ff, rc, ry, rx]
# Inference-mode batch norm over a (1, 256, 14, 14) tensor:
#   T_add_2 = (x - mean) / sqrt(var + eps) * scale + shift
# Per the call sites in this module, the argument order is
# (data, gamma, beta, moving_mean, moving_var), i.e.
# rxplaceholder_6 = gamma, _7 = beta, _8 = moving_mean, _9 = moving_var.
# T_multiply_3 / T_multiply_4 are pass-through copies of the moving mean
# and variance (the op's second and third outputs).
@tir.prim_func
def batch_norm6(rxplaceholder_5: tir.Buffer[(1, 256, 14, 14), "float32"], rxplaceholder_6: tir.Buffer[(256,), "float32"], rxplaceholder_7: tir.Buffer[(256,), "float32"], rxplaceholder_8: tir.Buffer[(256,), "float32"], rxplaceholder_9: tir.Buffer[(256,), "float32"], T_add_2: tir.Buffer[(1, 256, 14, 14), "float32"], T_multiply_3: tir.Buffer[(256,), "float32"], T_multiply_4: tir.Buffer[(256,), "float32"]) -> None:
# function attr dict
tir.func_attr({"global_symbol": "batch_norm6", "tir.noalias": True})
# body
# with tir.block("root")
T_reshape_4 = tir.alloc_buffer([1, 256, 1, 1], dtype="float32")
T_subtract_1 = tir.alloc_buffer([1, 256, 14, 14], dtype="float32")
T_reshape_5 = tir.alloc_buffer([1, 256, 1, 1], dtype="float32")
T_add_3 = tir.alloc_buffer([1, 256, 1, 1], dtype="float32")
compute_1 = tir.alloc_buffer([1, 256, 1, 1], dtype="float32")
T_divide_1 = tir.alloc_buffer([1, 256, 14, 14], dtype="float32")
T_reshape_6 = tir.alloc_buffer([1, 256, 1, 1], dtype="float32")
T_multiply_5 = tir.alloc_buffer([1, 256, 14, 14], dtype="float32")
T_reshape_7 = tir.alloc_buffer([1, 256, 1, 1], dtype="float32")
# Reshape moving mean (256,) -> (1, 256, 1, 1) for broadcasting.
for i0, i1, i2, i3 in tir.grid(1, 256, 1, 1):
with tir.block("T_reshape"):
ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
tir.reads(rxplaceholder_8[(ax1 + ax2 + ax3) % 256])
tir.writes(T_reshape_4[ax0, ax1, ax2, ax3])
T_reshape_4[ax0, ax1, ax2, ax3] = rxplaceholder_8[(ax0 * 256 + ax1 + ax2 + ax3) % 256]
# Center: x - mean, broadcast over the 14x14 spatial extent.
for i0, i1, i2, i3 in tir.grid(1, 256, 14, 14):
with tir.block("T_subtract"):
ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
tir.reads(rxplaceholder_5[ax0, ax1, ax2, ax3], T_reshape_4[ax0, ax1, 0, 0])
tir.writes(T_subtract_1[ax0, ax1, ax2, ax3])
T_subtract_1[ax0, ax1, ax2, ax3] = rxplaceholder_5[ax0, ax1, ax2, ax3] - T_reshape_4[ax0, ax1, 0, 0]
# Reshape moving variance (256,) -> (1, 256, 1, 1).
for i0, i1, i2, i3 in tir.grid(1, 256, 1, 1):
with tir.block("T_reshape_1"):
ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
tir.reads(rxplaceholder_9[(ax1 + ax2 + ax3) % 256])
tir.writes(T_reshape_5[ax0, ax1, ax2, ax3])
T_reshape_5[ax0, ax1, ax2, ax3] = rxplaceholder_9[(ax0 * 256 + ax1 + ax2 + ax3) % 256]
# var + eps (eps ~= 2e-5) for numerical stability.
for i0, i1, i2, i3 in tir.grid(1, 256, 1, 1):
with tir.block("T_add"):
ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
tir.reads(T_reshape_5[ax0, ax1, ax2, ax3])
tir.writes(T_add_3[ax0, ax1, ax2, ax3])
T_add_3[ax0, ax1, ax2, ax3] = T_reshape_5[ax0, ax1, ax2, ax3] + tir.float32(1.9999999494757503e-05)
# sqrt(var + eps).
for i0, i1, i2, i3 in tir.grid(1, 256, 1, 1):
with tir.block("compute"):
i0_2, i1_2, i2_2, i3_2 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
tir.reads(T_add_3[i0_2, i1_2, i2_2, i3_2])
tir.writes(compute_1[i0_2, i1_2, i2_2, i3_2])
compute_1[i0_2, i1_2, i2_2, i3_2] = tir.sqrt(T_add_3[i0_2, i1_2, i2_2, i3_2], dtype="float32")
# Normalize: (x - mean) / sqrt(var + eps).
for i0_3, i1_3, i2_3, i3_3 in tir.grid(1, 256, 14, 14):
with tir.block("T_divide"):
ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0_3, i1_3, i2_3, i3_3])
tir.reads(T_subtract_1[ax0, ax1, ax2, ax3], compute_1[ax0, ax1, 0, 0])
tir.writes(T_divide_1[ax0, ax1, ax2, ax3])
T_divide_1[ax0, ax1, ax2, ax3] = T_subtract_1[ax0, ax1, ax2, ax3] / compute_1[ax0, ax1, 0, 0]
# Reshape gamma (256,) -> (1, 256, 1, 1).
for i0_4, i1_4, i2_4, i3_4 in tir.grid(1, 256, 1, 1):
with tir.block("T_reshape_2"):
ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0_4, i1_4, i2_4, i3_4])
tir.reads(rxplaceholder_6[(ax1 + ax2 + ax3) % 256])
tir.writes(T_reshape_6[ax0, ax1, ax2, ax3])
T_reshape_6[ax0, ax1, ax2, ax3] = rxplaceholder_6[(ax0 * 256 + ax1 + ax2 + ax3) % 256]
# Scale by gamma.
for i0_5, i1_5, i2_5, i3_5 in tir.grid(1, 256, 14, 14):
with tir.block("T_multiply"):
ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0_5, i1_5, i2_5, i3_5])
tir.reads(T_divide_1[ax0, ax1, ax2, ax3], T_reshape_6[ax0, ax1, 0, 0])
tir.writes(T_multiply_5[ax0, ax1, ax2, ax3])
T_multiply_5[ax0, ax1, ax2, ax3] = T_divide_1[ax0, ax1, ax2, ax3] * T_reshape_6[ax0, ax1, 0, 0]
# Reshape beta (256,) -> (1, 256, 1, 1).
for i0_6, i1_6, i2_6, i3_6 in tir.grid(1, 256, 1, 1):
with tir.block("T_reshape_3"):
ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0_6, i1_6, i2_6, i3_6])
tir.reads(rxplaceholder_7[(ax1 + ax2 + ax3) % 256])
tir.writes(T_reshape_7[ax0, ax1, ax2, ax3])
T_reshape_7[ax0, ax1, ax2, ax3] = rxplaceholder_7[(ax0 * 256 + ax1 + ax2 + ax3) % 256]
# Shift by beta -> primary output.
for i0_7, i1_7, i2_7, i3_7 in tir.grid(1, 256, 14, 14):
with tir.block("T_add_1"):
ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0_7, i1_7, i2_7, i3_7])
tir.reads(T_multiply_5[ax0, ax1, ax2, ax3], T_reshape_7[ax0, ax1, 0, 0])
tir.writes(T_add_2[ax0, ax1, ax2, ax3])
T_add_2[ax0, ax1, ax2, ax3] = T_multiply_5[ax0, ax1, ax2, ax3] + T_reshape_7[ax0, ax1, 0, 0]
# Secondary outputs: unchanged copies of moving mean / variance.
for i0_8 in tir.serial(256):
with tir.block("T_multiply_1"):
ax0 = tir.axis.spatial(256, i0_8)
tir.reads(rxplaceholder_8[ax0])
tir.writes(T_multiply_3[ax0])
T_multiply_3[ax0] = rxplaceholder_8[ax0]
for i0_9 in tir.serial(256):
with tir.block("T_multiply_2"):
ax0 = tir.axis.spatial(256, i0_9)
tir.reads(rxplaceholder_9[ax0])
tir.writes(T_multiply_4[ax0])
T_multiply_4[ax0] = rxplaceholder_9[ax0]
# 1x1 NCHW convolution (pointwise), 512 -> 2048 channels on a 7x7 map,
# stride 1, no padding (the "pad" stage is a plain copy).
@tir.prim_func
def conv2d_nchw17(rxplaceholder_2: tir.Buffer[(1, 512, 7, 7), "float32"], rxplaceholder_3: tir.Buffer[(2048, 512, 1, 1), "float32"], conv2d_nchw_1: tir.Buffer[(1, 2048, 7, 7), "float32"]) -> None:
# function attr dict
tir.func_attr({"global_symbol": "conv2d_nchw17", "tir.noalias": True})
# body
# with tir.block("root")
pad_temp_1 = tir.alloc_buffer([1, 512, 7, 7], dtype="float32")
# Stage 1: identity copy of the input (1x1 kernel needs no border).
for i0, i1, i2, i3 in tir.grid(1, 512, 7, 7):
with tir.block("pad_temp"):
i0_2, i1_2, i2_2, i3_2 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
tir.reads(rxplaceholder_2[i0_2, i1_2, i2_2, i3_2])
tir.writes(pad_temp_1[i0_2, i1_2, i2_2, i3_2])
pad_temp_1[i0_2, i1_2, i2_2, i3_2] = rxplaceholder_2[i0_2, i1_2, i2_2, i3_2]
# Stage 2: accumulate over the 512 input channels (ry/rx are degenerate).
for i0_3, i1_3, i2_3, i3_3, i4, i5, i6 in tir.grid(1, 2048, 7, 7, 512, 1, 1):
with tir.block("conv2d_nchw"):
nn, ff, yy, xx, rc, ry, rx = tir.axis.remap("SSSSRRR", [i0_3, i1_3, i2_3, i3_3, i4, i5, i6])
tir.reads(conv2d_nchw_1[nn, ff, yy, xx], pad_temp_1[nn, rc, yy + ry, xx + rx], rxplaceholder_3[ff, rc, ry, rx])
tir.writes(conv2d_nchw_1[nn, ff, yy, xx])
tir.block_attr({"layout_free_placeholders":[rxplaceholder_3]})
with tir.init():
conv2d_nchw_1[nn, ff, yy, xx] = tir.float32(0)
conv2d_nchw_1[nn, ff, yy, xx] = conv2d_nchw_1[nn, ff, yy, xx] + pad_temp_1[nn, rc, yy + ry, xx + rx] * rxplaceholder_3[ff, rc, ry, rx]
# Inference-mode batch norm over a (1, 1024, 14, 14) tensor:
#   T_add_2 = (x - mean) / sqrt(var + eps) * scale + shift
# Per the call sites in this module, the argument order is
# (data, gamma, beta, moving_mean, moving_var), i.e.
# rxplaceholder_6 = gamma, _7 = beta, _8 = moving_mean, _9 = moving_var.
# T_multiply_3 / T_multiply_4 are pass-through copies of the moving mean
# and variance (the op's second and third outputs).
@tir.prim_func
def batch_norm7(rxplaceholder_5: tir.Buffer[(1, 1024, 14, 14), "float32"], rxplaceholder_6: tir.Buffer[(1024,), "float32"], rxplaceholder_7: tir.Buffer[(1024,), "float32"], rxplaceholder_8: tir.Buffer[(1024,), "float32"], rxplaceholder_9: tir.Buffer[(1024,), "float32"], T_add_2: tir.Buffer[(1, 1024, 14, 14), "float32"], T_multiply_3: tir.Buffer[(1024,), "float32"], T_multiply_4: tir.Buffer[(1024,), "float32"]) -> None:
# function attr dict
tir.func_attr({"global_symbol": "batch_norm7", "tir.noalias": True})
# body
# with tir.block("root")
T_reshape_4 = tir.alloc_buffer([1, 1024, 1, 1], dtype="float32")
T_subtract_1 = tir.alloc_buffer([1, 1024, 14, 14], dtype="float32")
T_reshape_5 = tir.alloc_buffer([1, 1024, 1, 1], dtype="float32")
T_add_3 = tir.alloc_buffer([1, 1024, 1, 1], dtype="float32")
compute_1 = tir.alloc_buffer([1, 1024, 1, 1], dtype="float32")
T_divide_1 = tir.alloc_buffer([1, 1024, 14, 14], dtype="float32")
T_reshape_6 = tir.alloc_buffer([1, 1024, 1, 1], dtype="float32")
T_multiply_5 = tir.alloc_buffer([1, 1024, 14, 14], dtype="float32")
T_reshape_7 = tir.alloc_buffer([1, 1024, 1, 1], dtype="float32")
# Reshape moving mean (1024,) -> (1, 1024, 1, 1) for broadcasting.
for i0, i1, i2, i3 in tir.grid(1, 1024, 1, 1):
with tir.block("T_reshape"):
ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
tir.reads(rxplaceholder_8[(ax1 + ax2 + ax3) % 1024])
tir.writes(T_reshape_4[ax0, ax1, ax2, ax3])
T_reshape_4[ax0, ax1, ax2, ax3] = rxplaceholder_8[(ax0 * 1024 + ax1 + ax2 + ax3) % 1024]
# Center: x - mean, broadcast over the 14x14 spatial extent.
for i0, i1, i2, i3 in tir.grid(1, 1024, 14, 14):
with tir.block("T_subtract"):
ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
tir.reads(rxplaceholder_5[ax0, ax1, ax2, ax3], T_reshape_4[ax0, ax1, 0, 0])
tir.writes(T_subtract_1[ax0, ax1, ax2, ax3])
T_subtract_1[ax0, ax1, ax2, ax3] = rxplaceholder_5[ax0, ax1, ax2, ax3] - T_reshape_4[ax0, ax1, 0, 0]
# Reshape moving variance (1024,) -> (1, 1024, 1, 1).
for i0, i1, i2, i3 in tir.grid(1, 1024, 1, 1):
with tir.block("T_reshape_1"):
ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
tir.reads(rxplaceholder_9[(ax1 + ax2 + ax3) % 1024])
tir.writes(T_reshape_5[ax0, ax1, ax2, ax3])
T_reshape_5[ax0, ax1, ax2, ax3] = rxplaceholder_9[(ax0 * 1024 + ax1 + ax2 + ax3) % 1024]
# var + eps (eps ~= 2e-5) for numerical stability.
for i0, i1, i2, i3 in tir.grid(1, 1024, 1, 1):
with tir.block("T_add"):
ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
tir.reads(T_reshape_5[ax0, ax1, ax2, ax3])
tir.writes(T_add_3[ax0, ax1, ax2, ax3])
T_add_3[ax0, ax1, ax2, ax3] = T_reshape_5[ax0, ax1, ax2, ax3] + tir.float32(1.9999999494757503e-05)
# sqrt(var + eps).
for i0, i1, i2, i3 in tir.grid(1, 1024, 1, 1):
with tir.block("compute"):
i0_2, i1_2, i2_2, i3_2 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
tir.reads(T_add_3[i0_2, i1_2, i2_2, i3_2])
tir.writes(compute_1[i0_2, i1_2, i2_2, i3_2])
compute_1[i0_2, i1_2, i2_2, i3_2] = tir.sqrt(T_add_3[i0_2, i1_2, i2_2, i3_2], dtype="float32")
# Normalize: (x - mean) / sqrt(var + eps).
for i0_3, i1_3, i2_3, i3_3 in tir.grid(1, 1024, 14, 14):
with tir.block("T_divide"):
ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0_3, i1_3, i2_3, i3_3])
tir.reads(T_subtract_1[ax0, ax1, ax2, ax3], compute_1[ax0, ax1, 0, 0])
tir.writes(T_divide_1[ax0, ax1, ax2, ax3])
T_divide_1[ax0, ax1, ax2, ax3] = T_subtract_1[ax0, ax1, ax2, ax3] / compute_1[ax0, ax1, 0, 0]
# Reshape gamma (1024,) -> (1, 1024, 1, 1).
for i0_4, i1_4, i2_4, i3_4 in tir.grid(1, 1024, 1, 1):
with tir.block("T_reshape_2"):
ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0_4, i1_4, i2_4, i3_4])
tir.reads(rxplaceholder_6[(ax1 + ax2 + ax3) % 1024])
tir.writes(T_reshape_6[ax0, ax1, ax2, ax3])
T_reshape_6[ax0, ax1, ax2, ax3] = rxplaceholder_6[(ax0 * 1024 + ax1 + ax2 + ax3) % 1024]
# Scale by gamma.
for i0_5, i1_5, i2_5, i3_5 in tir.grid(1, 1024, 14, 14):
with tir.block("T_multiply"):
ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0_5, i1_5, i2_5, i3_5])
tir.reads(T_divide_1[ax0, ax1, ax2, ax3], T_reshape_6[ax0, ax1, 0, 0])
tir.writes(T_multiply_5[ax0, ax1, ax2, ax3])
T_multiply_5[ax0, ax1, ax2, ax3] = T_divide_1[ax0, ax1, ax2, ax3] * T_reshape_6[ax0, ax1, 0, 0]
# Reshape beta (1024,) -> (1, 1024, 1, 1).
for i0_6, i1_6, i2_6, i3_6 in tir.grid(1, 1024, 1, 1):
with tir.block("T_reshape_3"):
ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0_6, i1_6, i2_6, i3_6])
tir.reads(rxplaceholder_7[(ax1 + ax2 + ax3) % 1024])
tir.writes(T_reshape_7[ax0, ax1, ax2, ax3])
T_reshape_7[ax0, ax1, ax2, ax3] = rxplaceholder_7[(ax0 * 1024 + ax1 + ax2 + ax3) % 1024]
# Shift by beta -> primary output.
for i0_7, i1_7, i2_7, i3_7 in tir.grid(1, 1024, 14, 14):
with tir.block("T_add_1"):
ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0_7, i1_7, i2_7, i3_7])
tir.reads(T_multiply_5[ax0, ax1, ax2, ax3], T_reshape_7[ax0, ax1, 0, 0])
tir.writes(T_add_2[ax0, ax1, ax2, ax3])
T_add_2[ax0, ax1, ax2, ax3] = T_multiply_5[ax0, ax1, ax2, ax3] + T_reshape_7[ax0, ax1, 0, 0]
# Secondary outputs: unchanged copies of moving mean / variance.
for i0_8 in tir.serial(1024):
with tir.block("T_multiply_1"):
ax0 = tir.axis.spatial(1024, i0_8)
tir.reads(rxplaceholder_8[ax0])
tir.writes(T_multiply_3[ax0])
T_multiply_3[ax0] = rxplaceholder_8[ax0]
for i0_9 in tir.serial(1024):
with tir.block("T_multiply_2"):
ax0 = tir.axis.spatial(1024, i0_9)
tir.reads(rxplaceholder_9[ax0])
tir.writes(T_multiply_4[ax0])
T_multiply_4[ax0] = rxplaceholder_9[ax0]
@tir.prim_func
def conv2d_nchw9(rxplaceholder_2: tir.Buffer[(1, 512, 28, 28), "float32"], rxplaceholder_3: tir.Buffer[(128, 512, 1, 1), "float32"], conv2d_nchw_1: tir.Buffer[(1, 128, 28, 28), "float32"]) -> None:
# function attr dict
tir.func_attr({"global_symbol": "conv2d_nchw9", "tir.noalias": True})
# body
# with tir.block("root")
pad_temp_1 = tir.alloc_buffer([1, 512, 28, 28], dtype="float32")
for i0, i1, i2, i3 in tir.grid(1, 512, 28, 28):
with tir.block("pad_temp"):
i0_2, i1_2, i2_2, i3_2 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
tir.reads(rxplaceholder_2[i0_2, i1_2, i2_2, i3_2])
tir.writes(pad_temp_1[i0_2, i1_2, i2_2, i3_2])
pad_temp_1[i0_2, i1_2, i2_2, i3_2] = rxplaceholder_2[i0_2, i1_2, i2_2, i3_2]
for i0_3, i1_3, i2_3, i3_3, i4, i5, i6 in tir.grid(1, 128, 28, 28, 512, 1, 1):
with tir.block("conv2d_nchw"):
nn, ff, yy, xx, rc, ry, rx = tir.axis.remap("SSSSRRR", [i0_3, i1_3, i2_3, i3_3, i4, i5, i6])
tir.reads(conv2d_nchw_1[nn, ff, yy, xx], pad_temp_1[nn, rc, yy + ry, xx + rx], rxplaceholder_3[ff, rc, ry, rx])
tir.writes(conv2d_nchw_1[nn, ff, yy, xx])
tir.block_attr({"layout_free_placeholders":[rxplaceholder_3]})
with tir.init():
conv2d_nchw_1[nn, ff, yy, xx] = tir.float32(0)
conv2d_nchw_1[nn, ff, yy, xx] = conv2d_nchw_1[nn, ff, yy, xx] + pad_temp_1[nn, rc, yy + ry, xx + rx] * rxplaceholder_3[ff, rc, ry, rx]
@tir.prim_func
def global_avg_pool2d(rxplaceholder_1: tir.Buffer[(1, 2048, 7, 7), "float32"], tensor_2: tir.Buffer[(1, 2048, 1, 1), "float32"]) -> None:
# function attr dict
tir.func_attr({"global_symbol": "global_avg_pool2d", "tir.noalias": True})
# body
# with tir.block("root")
tensor_3 = tir.alloc_buffer([1, 2048, 1, 1], dtype="float32")
for i0, i1, i2, i3, i4, i5 in tir.grid(1, 2048, 1, 1, 7, 7):
with tir.block("tensor"):
ax0, ax1, ax2, ax3, rv0, rv1 = tir.axis.remap("SSSSRR", [i0, i1, i2, i3, i4, i5])
tir.reads(tensor_3[ax0, ax1, ax2, ax3], rxplaceholder_1[ax0, ax1, ax2 * 7 + rv0, ax3 * 7 + rv1])
tir.writes(tensor_3[ax0, ax1, ax2, ax3])
with tir.init():
tensor_3[ax0, ax1, ax2, ax3] = tir.float32(0)
tensor_3[ax0, ax1, ax2, ax3] = tensor_3[ax0, ax1, ax2, ax3] + rxplaceholder_1[ax0, ax1, ax2 * 7 + rv0, ax3 * 7 + rv1]
for i0, i1, i2, i3 in tir.grid(1, 2048, 1, 1):
with tir.block("tensor_1"):
ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
tir.reads(tensor_3[ax0, ax1, ax2, ax3])
tir.writes(tensor_2[ax0, ax1, ax2, ax3])
tensor_2[ax0, ax1, ax2, ax3] = tensor_3[ax0, ax1, ax2, ax3] / (tir.cast(tir.Select(True, (ax2 + 1) * 7, (ax2 + 1) * 7 + 1) - ax2 * 7, "float32") * tir.cast(tir.Select(True, (ax3 + 1) * 7, (ax3 + 1) * 7 + 1) - ax3 * 7, "float32"))
@tir.prim_func
def relu7(rxplaceholder_1: tir.Buffer[(1, 512, 7, 7), "float32"], T_relu_1: tir.Buffer[(1, 512, 7, 7), "float32"]) -> None:
# function attr dict
tir.func_attr({"global_symbol": "relu7", "tir.noalias": True})
# body
# with tir.block("root")
for i0, i1, i2, i3 in tir.grid(1, 512, 7, 7):
with tir.block("T_relu"):
ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
tir.reads(rxplaceholder_1[ax0, ax1, ax2, ax3])
tir.writes(T_relu_1[ax0, ax1, ax2, ax3])
T_relu_1[ax0, ax1, ax2, ax3] = tir.max(rxplaceholder_1[ax0, ax1, ax2, ax3], tir.float32(0))
@tir.prim_func
def softmax(rxplaceholder_1: tir.Buffer[(1, 1000), "float32"], T_softmax_norm_1: tir.Buffer[(1, 1000), "float32"]) -> None:
# function attr dict
tir.func_attr({"global_symbol": "softmax", "tir.noalias": True})
# body
# with tir.block("root")
T_softmax_maxelem_1 = tir.alloc_buffer([1], dtype="float32")
T_softmax_exp_1 = tir.alloc_buffer([1, 1000], dtype="float32")
T_softmax_expsum_1 = tir.alloc_buffer([1], dtype="float32")
for i0_7, i1_3 in tir.grid(1, 1000):
with tir.block("T_softmax_maxelem"):
i0_8, k = tir.axis.remap("SR", [i0_7, i1_3])
tir.reads(T_softmax_maxelem_1[i0_8], rxplaceholder_1[i0_8, k])
tir.writes(T_softmax_maxelem_1[i0_8])
with tir.init():
T_softmax_maxelem_1[i0_8] = tir.float32(-3.4028234663852886e+38)
T_softmax_maxelem_1[i0_8] = tir.max(T_softmax_maxelem_1[i0_8], rxplaceholder_1[i0_8, k])
for i0_9, i1_4 in tir.grid(1, 1000):
with tir.block("T_softmax_exp"):
i0_10, i1_5 = tir.axis.remap("SS", [i0_9, i1_4])
tir.reads(rxplaceholder_1[i0_10, i1_5], T_softmax_maxelem_1[i0_10])
tir.writes(T_softmax_exp_1[i0_10, i1_5])
T_softmax_exp_1[i0_10, i1_5] = tir.exp(rxplaceholder_1[i0_10, i1_5] - T_softmax_maxelem_1[i0_10], dtype="float32")
for i0_11, i1_6 in tir.grid(1, 1000):
with tir.block("T_softmax_expsum"):
i0_12, k = tir.axis.remap("SR", [i0_11, i1_6])
tir.reads(T_softmax_expsum_1[i0_12], T_softmax_exp_1[i0_12, k])
tir.writes(T_softmax_expsum_1[i0_12])
with tir.init():
T_softmax_expsum_1[i0_12] = tir.float32(0)
T_softmax_expsum_1[i0_12] = T_softmax_expsum_1[i0_12] + T_softmax_exp_1[i0_12, k]
for i0_13, i1_7 in tir.grid(1, 1000):
with tir.block("T_softmax_norm"):
i0_14, i1_8 = tir.axis.remap("SS", [i0_13, i1_7])
tir.reads(T_softmax_exp_1[i0_14, i1_8], T_softmax_expsum_1[i0_14])
tir.writes(T_softmax_norm_1[i0_14, i1_8])
tir.block_attr({"axis":1})
T_softmax_norm_1[i0_14, i1_8] = T_softmax_exp_1[i0_14, i1_8] / T_softmax_expsum_1[i0_14]
@tir.prim_func
def relu4(rxplaceholder_1: tir.Buffer[(1, 512, 28, 28), "float32"], T_relu_1: tir.Buffer[(1, 512, 28, 28), "float32"]) -> None:
# function attr dict
tir.func_attr({"global_symbol": "relu4", "tir.noalias": True})
# body
# with tir.block("root")
for i0, i1, i2, i3 in tir.grid(1, 512, 28, 28):
with tir.block("T_relu"):
ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
tir.reads(rxplaceholder_1[ax0, ax1, ax2, ax3])
tir.writes(T_relu_1[ax0, ax1, ax2, ax3])
T_relu_1[ax0, ax1, ax2, ax3] = tir.max(rxplaceholder_1[ax0, ax1, ax2, ax3], tir.float32(0))
@tir.prim_func
def max_pool2d(rxplaceholder_1: tir.Buffer[(1, 64, 112, 112), "float32"], tensor_1: tir.Buffer[(1, 64, 56, 56), "float32"]) -> None:
# function attr dict
tir.func_attr({"global_symbol": "max_pool2d", "tir.noalias": True})
# body
# with tir.block("root")
pad_temp_1 = tir.alloc_buffer([1, 64, 114, 114], dtype="float32")
for i0, i1, i2, i3 in tir.grid(1, 64, 114, 114):
with tir.block("pad_temp"):
ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
tir.reads(rxplaceholder_1[ax0, ax1, ax2 - 1, ax3 - 1])
tir.writes(pad_temp_1[ax0, ax1, ax2, ax3])
pad_temp_1[ax0, ax1, ax2, ax3] = tir.if_then_else(ax2 >= 1 and ax2 < 113 and ax3 >= 1 and ax3 < 113, rxplaceholder_1[ax0, ax1, ax2 - 1, ax3 - 1], tir.float32(-3.4028234663852886e+38), dtype="float32")
for i0, i1, i2, i3, i4, i5 in tir.grid(1, 64, 56, 56, 3, 3):
with tir.block("tensor"):
ax0, ax1, ax2, ax3, rv0, rv1 = tir.axis.remap("SSSSRR", [i0, i1, i2, i3, i4, i5])
tir.reads(tensor_1[ax0, ax1, ax2, ax3], pad_temp_1[ax0, ax1, ax2 * 2 + rv0, ax3 * 2 + rv1])
tir.writes(tensor_1[ax0, ax1, ax2, ax3])
with tir.init():
tensor_1[ax0, ax1, ax2, ax3] = tir.float32(-3.4028234663852886e+38)
tensor_1[ax0, ax1, ax2, ax3] = tir.max(tensor_1[ax0, ax1, ax2, ax3], pad_temp_1[ax0, ax1, ax2 * 2 + rv0, ax3 * 2 + rv1])
@tir.prim_func
def batch_norm8(rxplaceholder_5: tir.Buffer[(1, 512, 7, 7), "float32"], rxplaceholder_6: tir.Buffer[(512,), "float32"], rxplaceholder_7: tir.Buffer[(512,), "float32"], rxplaceholder_8: tir.Buffer[(512,), "float32"], rxplaceholder_9: tir.Buffer[(512,), "float32"], T_add_2: tir.Buffer[(1, 512, 7, 7), "float32"], T_multiply_3: tir.Buffer[(512,), "float32"], T_multiply_4: tir.Buffer[(512,), "float32"]) -> None:
# function attr dict
tir.func_attr({"global_symbol": "batch_norm8", "tir.noalias": True})
# body
# with tir.block("root")
T_reshape_4 = tir.alloc_buffer([1, 512, 1, 1], dtype="float32")
T_subtract_1 = tir.alloc_buffer([1, 512, 7, 7], dtype="float32")
T_reshape_5 = tir.alloc_buffer([1, 512, 1, 1], dtype="float32")
T_add_3 = tir.alloc_buffer([1, 512, 1, 1], dtype="float32")
compute_1 = tir.alloc_buffer([1, 512, 1, 1], dtype="float32")
T_divide_1 = tir.alloc_buffer([1, 512, 7, 7], dtype="float32")
T_reshape_6 = tir.alloc_buffer([1, 512, 1, 1], dtype="float32")
T_multiply_5 = tir.alloc_buffer([1, 512, 7, 7], dtype="float32")
T_reshape_7 = tir.alloc_buffer([1, 512, 1, 1], dtype="float32")
for i0, i1, i2, i3 in tir.grid(1, 512, 1, 1):
with tir.block("T_reshape"):
ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
tir.reads(rxplaceholder_8[(ax1 + ax2 + ax3) % 512])
tir.writes(T_reshape_4[ax0, ax1, ax2, ax3])
T_reshape_4[ax0, ax1, ax2, ax3] = rxplaceholder_8[(ax0 * 512 + ax1 + ax2 + ax3) % 512]
for i0, i1, i2, i3 in tir.grid(1, 512, 7, 7):
with tir.block("T_subtract"):
ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
tir.reads(rxplaceholder_5[ax0, ax1, ax2, ax3], T_reshape_4[ax0, ax1, 0, 0])
tir.writes(T_subtract_1[ax0, ax1, ax2, ax3])
T_subtract_1[ax0, ax1, ax2, ax3] = rxplaceholder_5[ax0, ax1, ax2, ax3] - T_reshape_4[ax0, ax1, 0, 0]
for i0, i1, i2, i3 in tir.grid(1, 512, 1, 1):
with tir.block("T_reshape_1"):
ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
tir.reads(rxplaceholder_9[(ax1 + ax2 + ax3) % 512])
tir.writes(T_reshape_5[ax0, ax1, ax2, ax3])
T_reshape_5[ax0, ax1, ax2, ax3] = rxplaceholder_9[(ax0 * 512 + ax1 + ax2 + ax3) % 512]
for i0, i1, i2, i3 in tir.grid(1, 512, 1, 1):
with tir.block("T_add"):
ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
tir.reads(T_reshape_5[ax0, ax1, ax2, ax3])
tir.writes(T_add_3[ax0, ax1, ax2, ax3])
T_add_3[ax0, ax1, ax2, ax3] = T_reshape_5[ax0, ax1, ax2, ax3] + tir.float32(1.9999999494757503e-05)
for i0, i1, i2, i3 in tir.grid(1, 512, 1, 1):
with tir.block("compute"):
i0_2, i1_2, i2_2, i3_2 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
tir.reads(T_add_3[i0_2, i1_2, i2_2, i3_2])
tir.writes(compute_1[i0_2, i1_2, i2_2, i3_2])
compute_1[i0_2, i1_2, i2_2, i3_2] = tir.sqrt(T_add_3[i0_2, i1_2, i2_2, i3_2], dtype="float32")
for i0_3, i1_3, i2_3, i3_3 in tir.grid(1, 512, 7, 7):
with tir.block("T_divide"):
ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0_3, i1_3, i2_3, i3_3])
tir.reads(T_subtract_1[ax0, ax1, ax2, ax3], compute_1[ax0, ax1, 0, 0])
tir.writes(T_divide_1[ax0, ax1, ax2, ax3])
T_divide_1[ax0, ax1, ax2, ax3] = T_subtract_1[ax0, ax1, ax2, ax3] / compute_1[ax0, ax1, 0, 0]
for i0_4, i1_4, i2_4, i3_4 in tir.grid(1, 512, 1, 1):
with tir.block("T_reshape_2"):
ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0_4, i1_4, i2_4, i3_4])
tir.reads(rxplaceholder_6[(ax1 + ax2 + ax3) % 512])
tir.writes(T_reshape_6[ax0, ax1, ax2, ax3])
T_reshape_6[ax0, ax1, ax2, ax3] = rxplaceholder_6[(ax0 * 512 + ax1 + ax2 + ax3) % 512]
for i0_5, i1_5, i2_5, i3_5 in tir.grid(1, 512, 7, 7):
with tir.block("T_multiply"):
ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0_5, i1_5, i2_5, i3_5])
tir.reads(T_divide_1[ax0, ax1, ax2, ax3], T_reshape_6[ax0, ax1, 0, 0])
tir.writes(T_multiply_5[ax0, ax1, ax2, ax3])
T_multiply_5[ax0, ax1, ax2, ax3] = T_divide_1[ax0, ax1, ax2, ax3] * T_reshape_6[ax0, ax1, 0, 0]
for i0_6, i1_6, i2_6, i3_6 in tir.grid(1, 512, 1, 1):
with tir.block("T_reshape_3"):
ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0_6, i1_6, i2_6, i3_6])
tir.reads(rxplaceholder_7[(ax1 + ax2 + ax3) % 512])
tir.writes(T_reshape_7[ax0, ax1, ax2, ax3])
T_reshape_7[ax0, ax1, ax2, ax3] = rxplaceholder_7[(ax0 * 512 + ax1 + ax2 + ax3) % 512]
for i0_7, i1_7, i2_7, i3_7 in tir.grid(1, 512, 7, 7):
with tir.block("T_add_1"):
ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0_7, i1_7, i2_7, i3_7])
tir.reads(T_multiply_5[ax0, ax1, ax2, ax3], T_reshape_7[ax0, ax1, 0, 0])
tir.writes(T_add_2[ax0, ax1, ax2, ax3])
T_add_2[ax0, ax1, ax2, ax3] = T_multiply_5[ax0, ax1, ax2, ax3] + T_reshape_7[ax0, ax1, 0, 0]
for i0_8 in tir.serial(512):
with tir.block("T_multiply_1"):
ax0 = tir.axis.spatial(512, i0_8)
tir.reads(rxplaceholder_8[ax0])
tir.writes(T_multiply_3[ax0])
T_multiply_3[ax0] = rxplaceholder_8[ax0]
for i0_9 in tir.serial(512):
with tir.block("T_multiply_2"):
ax0 = tir.axis.spatial(512, i0_9)
tir.reads(rxplaceholder_9[ax0])
tir.writes(T_multiply_4[ax0])
T_multiply_4[ax0] = rxplaceholder_9[ax0]
@tir.prim_func
def add2(rxplaceholder_2: tir.Buffer[(1, 1024, 14, 14), "float32"], rxplaceholder_3: tir.Buffer[(1, 1024, 14, 14), "float32"], T_add_1: tir.Buffer[(1, 1024, 14, 14), "float32"]) -> None:
# function attr dict
tir.func_attr({"global_symbol": "add2", "tir.noalias": True})
# body
# with tir.block("root")
for i0, i1, i2, i3 in tir.grid(1, 1024, 14, 14):
with tir.block("T_add"):
ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
tir.reads(rxplaceholder_2[ax0, ax1, ax2, ax3], rxplaceholder_3[ax0, ax1, ax2, ax3])
tir.writes(T_add_1[ax0, ax1, ax2, ax3])
T_add_1[ax0, ax1, ax2, ax3] = rxplaceholder_2[ax0, ax1, ax2, ax3] + rxplaceholder_3[ax0, ax1, ax2, ax3]
"""
# Parse the TVMScript text above into an in-memory Relax IRModule.
relax_mod = R.parser.from_source(resnet_mod_text)

# Pretty-print the parsed "main" relax function for inspection.
main_fn = relax_mod["main"]
R.parser.pretty_print(main_fn)

# Inspect the inferred and annotated types of two bindings in the first
# binding block of "main", including the rank of the checked tensor types.
first_block_bindings = main_fn.body.blocks[0].bindings
second_binding_var = first_block_bindings[1].var
print(second_binding_var.checked_type)
print(second_binding_var.type_annotation)
print(second_binding_var.checked_type.rank)
print(first_block_bindings[2].var.checked_type.rank)
@Hzfengsy
Copy link

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=redefined-builtin
"""Resnet models written by relax"""

from __future__ import annotations

import tvm.script
from tvm.script import tir, relax


@tvm.script.ir_module
class ResNet50:
    @tir.prim_func
    def batch_norm9(rxplaceholder_5: tir.Buffer[(1, 2048, 7, 7), "float32"], rxplaceholder_6: tir.Buffer[(2048,), "float32"], rxplaceholder_7: tir.Buffer[(2048,), "float32"], rxplaceholder_8: tir.Buffer[(2048,), "float32"], rxplaceholder_9: tir.Buffer[(2048,), "float32"], T_add_2: tir.Buffer[(1, 2048, 7, 7), "float32"], T_multiply_3: tir.Buffer[(2048,), "float32"], T_multiply_4: tir.Buffer[(2048,), "float32"]) -> None:
        # Inference-mode batch norm over an NCHW (1, 2048, 7, 7) input:
        #     T_add_2 = (x - m) / sqrt(v + eps) * g + b
        # Judging from the arithmetic below: rxplaceholder_5 = x,
        # rxplaceholder_6 = g (scale), rxplaceholder_7 = b (shift),
        # rxplaceholder_8 = m (mean), rxplaceholder_9 = v (variance).
        # NOTE(review): role names inferred from the formula; confirm with caller.
        # T_multiply_3 / T_multiply_4 are element-wise copies of m / v.
        # function attr dict
        tir.func_attr({"global_symbol": "batch_norm9", "tir.noalias": True})
        # body
        # with tir.block("root")
        T_reshape_4 = tir.alloc_buffer([1, 2048, 1, 1], dtype="float32")
        T_subtract_1 = tir.alloc_buffer([1, 2048, 7, 7], dtype="float32")
        T_reshape_5 = tir.alloc_buffer([1, 2048, 1, 1], dtype="float32")
        T_add_3 = tir.alloc_buffer([1, 2048, 1, 1], dtype="float32")
        compute_1 = tir.alloc_buffer([1, 2048, 1, 1], dtype="float32")
        T_divide_1 = tir.alloc_buffer([1, 2048, 7, 7], dtype="float32")
        T_reshape_6 = tir.alloc_buffer([1, 2048, 1, 1], dtype="float32")
        T_multiply_5 = tir.alloc_buffer([1, 2048, 7, 7], dtype="float32")
        T_reshape_7 = tir.alloc_buffer([1, 2048, 1, 1], dtype="float32")
        # Reshape m from (2048,) to (1, 2048, 1, 1); with these loop bounds the
        # index (ax0 * 2048 + ax1 + ax2 + ax3) % 2048 reduces to ax1.
        for i0, i1, i2, i3 in tir.grid(1, 2048, 1, 1):
            with tir.block("T_reshape"):
                ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
                tir.reads(rxplaceholder_8[(ax1 + ax2 + ax3) % 2048])
                tir.writes(T_reshape_4[ax0, ax1, ax2, ax3])
                T_reshape_4[ax0, ax1, ax2, ax3] = rxplaceholder_8[(ax0 * 2048 + ax1 + ax2 + ax3) % 2048]
        # x - m, broadcasting the per-channel mean over the 7x7 spatial dims.
        for i0, i1, i2, i3 in tir.grid(1, 2048, 7, 7):
            with tir.block("T_subtract"):
                ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
                tir.reads(rxplaceholder_5[ax0, ax1, ax2, ax3], T_reshape_4[ax0, ax1, 0, 0])
                tir.writes(T_subtract_1[ax0, ax1, ax2, ax3])
                T_subtract_1[ax0, ax1, ax2, ax3] = rxplaceholder_5[ax0, ax1, ax2, ax3] - T_reshape_4[ax0, ax1, 0, 0]
        # Reshape v to (1, 2048, 1, 1); same index simplification as above.
        for i0, i1, i2, i3 in tir.grid(1, 2048, 1, 1):
            with tir.block("T_reshape_1"):
                ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
                tir.reads(rxplaceholder_9[(ax1 + ax2 + ax3) % 2048])
                tir.writes(T_reshape_5[ax0, ax1, ax2, ax3])
                T_reshape_5[ax0, ax1, ax2, ax3] = rxplaceholder_9[(ax0 * 2048 + ax1 + ax2 + ax3) % 2048]
        # v + eps (eps baked in as a float32 constant).
        for i0, i1, i2, i3 in tir.grid(1, 2048, 1, 1):
            with tir.block("T_add"):
                ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
                tir.reads(T_reshape_5[ax0, ax1, ax2, ax3])
                tir.writes(T_add_3[ax0, ax1, ax2, ax3])
                T_add_3[ax0, ax1, ax2, ax3] = T_reshape_5[ax0, ax1, ax2, ax3] + tir.float32(1.9999999494757503e-05)
        # sqrt(v + eps)
        for i0, i1, i2, i3 in tir.grid(1, 2048, 1, 1):
            with tir.block("compute"):
                i0_2, i1_2, i2_2, i3_2 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
                tir.reads(T_add_3[i0_2, i1_2, i2_2, i3_2])
                tir.writes(compute_1[i0_2, i1_2, i2_2, i3_2])
                compute_1[i0_2, i1_2, i2_2, i3_2] = tir.sqrt(T_add_3[i0_2, i1_2, i2_2, i3_2], dtype="float32")
        # (x - m) / sqrt(v + eps)
        for i0_3, i1_3, i2_3, i3_3 in tir.grid(1, 2048, 7, 7):
            with tir.block("T_divide"):
                ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0_3, i1_3, i2_3, i3_3])
                tir.reads(T_subtract_1[ax0, ax1, ax2, ax3], compute_1[ax0, ax1, 0, 0])
                tir.writes(T_divide_1[ax0, ax1, ax2, ax3])
                T_divide_1[ax0, ax1, ax2, ax3] = T_subtract_1[ax0, ax1, ax2, ax3] / compute_1[ax0, ax1, 0, 0]
        # Reshape g (scale) to (1, 2048, 1, 1).
        for i0_4, i1_4, i2_4, i3_4 in tir.grid(1, 2048, 1, 1):
            with tir.block("T_reshape_2"):
                ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0_4, i1_4, i2_4, i3_4])
                tir.reads(rxplaceholder_6[(ax1 + ax2 + ax3) % 2048])
                tir.writes(T_reshape_6[ax0, ax1, ax2, ax3])
                T_reshape_6[ax0, ax1, ax2, ax3] = rxplaceholder_6[(ax0 * 2048 + ax1 + ax2 + ax3) % 2048]
        # ... * g
        for i0_5, i1_5, i2_5, i3_5 in tir.grid(1, 2048, 7, 7):
            with tir.block("T_multiply"):
                ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0_5, i1_5, i2_5, i3_5])
                tir.reads(T_divide_1[ax0, ax1, ax2, ax3], T_reshape_6[ax0, ax1, 0, 0])
                tir.writes(T_multiply_5[ax0, ax1, ax2, ax3])
                T_multiply_5[ax0, ax1, ax2, ax3] = T_divide_1[ax0, ax1, ax2, ax3] * T_reshape_6[ax0, ax1, 0, 0]
        # Reshape b (shift) to (1, 2048, 1, 1).
        for i0_6, i1_6, i2_6, i3_6 in tir.grid(1, 2048, 1, 1):
            with tir.block("T_reshape_3"):
                ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0_6, i1_6, i2_6, i3_6])
                tir.reads(rxplaceholder_7[(ax1 + ax2 + ax3) % 2048])
                tir.writes(T_reshape_7[ax0, ax1, ax2, ax3])
                T_reshape_7[ax0, ax1, ax2, ax3] = rxplaceholder_7[(ax0 * 2048 + ax1 + ax2 + ax3) % 2048]
        # ... + b -> final normalized output T_add_2.
        for i0_7, i1_7, i2_7, i3_7 in tir.grid(1, 2048, 7, 7):
            with tir.block("T_add_1"):
                ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0_7, i1_7, i2_7, i3_7])
                tir.reads(T_multiply_5[ax0, ax1, ax2, ax3], T_reshape_7[ax0, ax1, 0, 0])
                tir.writes(T_add_2[ax0, ax1, ax2, ax3])
                T_add_2[ax0, ax1, ax2, ax3] = T_multiply_5[ax0, ax1, ax2, ax3] + T_reshape_7[ax0, ax1, 0, 0]
        # Pass m through unchanged as the second output.
        for i0_8 in tir.serial(2048):
            with tir.block("T_multiply_1"):
                ax0 = tir.axis.spatial(2048, i0_8)
                tir.reads(rxplaceholder_8[ax0])
                tir.writes(T_multiply_3[ax0])
                T_multiply_3[ax0] = rxplaceholder_8[ax0]
        # Pass v through unchanged as the third output.
        for i0_9 in tir.serial(2048):
            with tir.block("T_multiply_2"):
                ax0 = tir.axis.spatial(2048, i0_9)
                tir.reads(rxplaceholder_9[ax0])
                tir.writes(T_multiply_4[ax0])
                T_multiply_4[ax0] = rxplaceholder_9[ax0]

    @tir.prim_func
    def conv2d_nchw3(rxplaceholder_2: tir.Buffer[(1, 64, 56, 56), "float32"], rxplaceholder_3: tir.Buffer[(256, 64, 1, 1), "float32"], conv2d_nchw_1: tir.Buffer[(1, 256, 56, 56), "float32"]) -> None:
        # 1x1 convolution, NCHW layout: (1, 64, 56, 56) input x (256, 64, 1, 1)
        # weight -> (1, 256, 56, 56) output. Stride 1; same indices in and out.
        # function attr dict
        tir.func_attr({"global_symbol": "conv2d_nchw3", "tir.noalias": True})
        # body
        # with tir.block("root")
        pad_temp_1 = tir.alloc_buffer([1, 64, 56, 56], dtype="float32")
        # Kernel is 1x1, so no halo is needed: "pad_temp" is a plain copy.
        for i0_3, i1_3, i2_3, i3_3 in tir.grid(1, 64, 56, 56):
            with tir.block("pad_temp"):
                i0_4, i1_4, i2_4, i3_4 = tir.axis.remap("SSSS", [i0_3, i1_3, i2_3, i3_3])
                tir.reads(rxplaceholder_2[i0_4, i1_4, i2_4, i3_4])
                tir.writes(pad_temp_1[i0_4, i1_4, i2_4, i3_4])
                pad_temp_1[i0_4, i1_4, i2_4, i3_4] = rxplaceholder_2[i0_4, i1_4, i2_4, i3_4]
        # Reduce over the 64 input channels (rc); r0/r1 are the degenerate 1x1
        # kernel axes. i1_6 // 256 * 64 is 0 for every i1_6 in [0, 256), so the
        # input channel index is simply rc.
        for i0_5, i1_5, i2_5, i3_5, i4, i5, i6 in tir.grid(1, 256, 56, 56, 64, 1, 1):
            with tir.block("conv2d_nchw"):
                i0_6, i1_6, i2_6, i3_6, rc, r0, r1 = tir.axis.remap("SSSSRRR", [i0_5, i1_5, i2_5, i3_5, i4, i5, i6])
                tir.reads(conv2d_nchw_1[i0_6, i1_6, i2_6, i3_6], pad_temp_1[i0_6, i1_6 // 256 * 64 + rc, i2_6 + r0, i3_6 + r1], rxplaceholder_3[i1_6, rc, r0, r1])
                tir.writes(conv2d_nchw_1[i0_6, i1_6, i2_6, i3_6])
                with tir.init():
                    conv2d_nchw_1[i0_6, i1_6, i2_6, i3_6] = tir.float32(0)
                conv2d_nchw_1[i0_6, i1_6, i2_6, i3_6] = conv2d_nchw_1[i0_6, i1_6, i2_6, i3_6] + pad_temp_1[i0_6, i1_6 // 256 * 64 + rc, i2_6 + r0, i3_6 + r1] * rxplaceholder_3[i1_6, rc, r0, r1]

    @tir.prim_func
    def batch_norm3(rxplaceholder_5: tir.Buffer[(1, 256, 56, 56), "float32"], rxplaceholder_6: tir.Buffer[(256,), "float32"], rxplaceholder_7: tir.Buffer[(256,), "float32"], rxplaceholder_8: tir.Buffer[(256,), "float32"], rxplaceholder_9: tir.Buffer[(256,), "float32"], T_add_2: tir.Buffer[(1, 256, 56, 56), "float32"], T_multiply_3: tir.Buffer[(256,), "float32"], T_multiply_4: tir.Buffer[(256,), "float32"]) -> None:
        # Inference-mode batch norm over an NCHW (1, 256, 56, 56) input:
        #     T_add_2 = (x - m) / sqrt(v + eps) * g + b
        # Judging from the arithmetic below: rxplaceholder_5 = x,
        # rxplaceholder_6 = g (scale), rxplaceholder_7 = b (shift),
        # rxplaceholder_8 = m (mean), rxplaceholder_9 = v (variance).
        # NOTE(review): role names inferred from the formula; confirm with caller.
        # T_multiply_3 / T_multiply_4 are element-wise copies of m / v.
        # function attr dict
        tir.func_attr({"global_symbol": "batch_norm3", "tir.noalias": True})
        # body
        # with tir.block("root")
        T_reshape_4 = tir.alloc_buffer([1, 256, 1, 1], dtype="float32")
        T_subtract_1 = tir.alloc_buffer([1, 256, 56, 56], dtype="float32")
        T_reshape_5 = tir.alloc_buffer([1, 256, 1, 1], dtype="float32")
        T_add_3 = tir.alloc_buffer([1, 256, 1, 1], dtype="float32")
        compute_1 = tir.alloc_buffer([1, 256, 1, 1], dtype="float32")
        T_divide_1 = tir.alloc_buffer([1, 256, 56, 56], dtype="float32")
        T_reshape_6 = tir.alloc_buffer([1, 256, 1, 1], dtype="float32")
        T_multiply_5 = tir.alloc_buffer([1, 256, 56, 56], dtype="float32")
        T_reshape_7 = tir.alloc_buffer([1, 256, 1, 1], dtype="float32")
        # Reshape m from (256,) to (1, 256, 1, 1); with these loop bounds the
        # index (ax0 * 256 + ax1 + ax2 + ax3) % 256 reduces to ax1.
        for i0, i1, i2, i3 in tir.grid(1, 256, 1, 1):
            with tir.block("T_reshape"):
                ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
                tir.reads(rxplaceholder_8[(ax1 + ax2 + ax3) % 256])
                tir.writes(T_reshape_4[ax0, ax1, ax2, ax3])
                T_reshape_4[ax0, ax1, ax2, ax3] = rxplaceholder_8[(ax0 * 256 + ax1 + ax2 + ax3) % 256]
        # x - m, broadcasting the per-channel mean over the 56x56 spatial dims.
        for i0, i1, i2, i3 in tir.grid(1, 256, 56, 56):
            with tir.block("T_subtract"):
                ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
                tir.reads(rxplaceholder_5[ax0, ax1, ax2, ax3], T_reshape_4[ax0, ax1, 0, 0])
                tir.writes(T_subtract_1[ax0, ax1, ax2, ax3])
                T_subtract_1[ax0, ax1, ax2, ax3] = rxplaceholder_5[ax0, ax1, ax2, ax3] - T_reshape_4[ax0, ax1, 0, 0]
        # Reshape v to (1, 256, 1, 1); same index simplification as above.
        for i0, i1, i2, i3 in tir.grid(1, 256, 1, 1):
            with tir.block("T_reshape_1"):
                ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
                tir.reads(rxplaceholder_9[(ax1 + ax2 + ax3) % 256])
                tir.writes(T_reshape_5[ax0, ax1, ax2, ax3])
                T_reshape_5[ax0, ax1, ax2, ax3] = rxplaceholder_9[(ax0 * 256 + ax1 + ax2 + ax3) % 256]
        # v + eps (eps baked in as a float32 constant).
        for i0, i1, i2, i3 in tir.grid(1, 256, 1, 1):
            with tir.block("T_add"):
                ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
                tir.reads(T_reshape_5[ax0, ax1, ax2, ax3])
                tir.writes(T_add_3[ax0, ax1, ax2, ax3])
                T_add_3[ax0, ax1, ax2, ax3] = T_reshape_5[ax0, ax1, ax2, ax3] + tir.float32(1.9999999494757503e-05)
        # sqrt(v + eps)
        for i0, i1, i2, i3 in tir.grid(1, 256, 1, 1):
            with tir.block("compute"):
                i0_2, i1_2, i2_2, i3_2 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
                tir.reads(T_add_3[i0_2, i1_2, i2_2, i3_2])
                tir.writes(compute_1[i0_2, i1_2, i2_2, i3_2])
                compute_1[i0_2, i1_2, i2_2, i3_2] = tir.sqrt(T_add_3[i0_2, i1_2, i2_2, i3_2], dtype="float32")
        # (x - m) / sqrt(v + eps)
        for i0_3, i1_3, i2_3, i3_3 in tir.grid(1, 256, 56, 56):
            with tir.block("T_divide"):
                ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0_3, i1_3, i2_3, i3_3])
                tir.reads(T_subtract_1[ax0, ax1, ax2, ax3], compute_1[ax0, ax1, 0, 0])
                tir.writes(T_divide_1[ax0, ax1, ax2, ax3])
                T_divide_1[ax0, ax1, ax2, ax3] = T_subtract_1[ax0, ax1, ax2, ax3] / compute_1[ax0, ax1, 0, 0]
        # Reshape g (scale) to (1, 256, 1, 1).
        for i0_4, i1_4, i2_4, i3_4 in tir.grid(1, 256, 1, 1):
            with tir.block("T_reshape_2"):
                ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0_4, i1_4, i2_4, i3_4])
                tir.reads(rxplaceholder_6[(ax1 + ax2 + ax3) % 256])
                tir.writes(T_reshape_6[ax0, ax1, ax2, ax3])
                T_reshape_6[ax0, ax1, ax2, ax3] = rxplaceholder_6[(ax0 * 256 + ax1 + ax2 + ax3) % 256]
        # ... * g
        for i0_5, i1_5, i2_5, i3_5 in tir.grid(1, 256, 56, 56):
            with tir.block("T_multiply"):
                ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0_5, i1_5, i2_5, i3_5])
                tir.reads(T_divide_1[ax0, ax1, ax2, ax3], T_reshape_6[ax0, ax1, 0, 0])
                tir.writes(T_multiply_5[ax0, ax1, ax2, ax3])
                T_multiply_5[ax0, ax1, ax2, ax3] = T_divide_1[ax0, ax1, ax2, ax3] * T_reshape_6[ax0, ax1, 0, 0]
        # Reshape b (shift) to (1, 256, 1, 1).
        for i0_6, i1_6, i2_6, i3_6 in tir.grid(1, 256, 1, 1):
            with tir.block("T_reshape_3"):
                ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0_6, i1_6, i2_6, i3_6])
                tir.reads(rxplaceholder_7[(ax1 + ax2 + ax3) % 256])
                tir.writes(T_reshape_7[ax0, ax1, ax2, ax3])
                T_reshape_7[ax0, ax1, ax2, ax3] = rxplaceholder_7[(ax0 * 256 + ax1 + ax2 + ax3) % 256]
        # ... + b -> final normalized output T_add_2.
        for i0_7, i1_7, i2_7, i3_7 in tir.grid(1, 256, 56, 56):
            with tir.block("T_add_1"):
                ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0_7, i1_7, i2_7, i3_7])
                tir.reads(T_multiply_5[ax0, ax1, ax2, ax3], T_reshape_7[ax0, ax1, 0, 0])
                tir.writes(T_add_2[ax0, ax1, ax2, ax3])
                T_add_2[ax0, ax1, ax2, ax3] = T_multiply_5[ax0, ax1, ax2, ax3] + T_reshape_7[ax0, ax1, 0, 0]
        # Pass m through unchanged as the second output.
        for i0_8 in tir.serial(256):
            with tir.block("T_multiply_1"):
                ax0 = tir.axis.spatial(256, i0_8)
                tir.reads(rxplaceholder_8[ax0])
                tir.writes(T_multiply_3[ax0])
                T_multiply_3[ax0] = rxplaceholder_8[ax0]
        # Pass v through unchanged as the third output.
        for i0_9 in tir.serial(256):
            with tir.block("T_multiply_2"):
                ax0 = tir.axis.spatial(256, i0_9)
                tir.reads(rxplaceholder_9[ax0])
                tir.writes(T_multiply_4[ax0])
                T_multiply_4[ax0] = rxplaceholder_9[ax0]

    @tir.prim_func
    def relu(rxplaceholder_1: tir.Buffer[(1, 64, 112, 112), "float32"], T_relu_1: tir.Buffer[(1, 64, 112, 112), "float32"]) -> None:
        # Element-wise ReLU on a (1, 64, 112, 112) tensor:
        # T_relu_1 = max(input, 0).
        # function attr dict
        tir.func_attr({"global_symbol": "relu", "tir.noalias": True})
        # body
        # with tir.block("root")
        for i0, i1, i2, i3 in tir.grid(1, 64, 112, 112):
            with tir.block("T_relu"):
                ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
                tir.reads(rxplaceholder_1[ax0, ax1, ax2, ax3])
                tir.writes(T_relu_1[ax0, ax1, ax2, ax3])
                T_relu_1[ax0, ax1, ax2, ax3] = tir.max(rxplaceholder_1[ax0, ax1, ax2, ax3], tir.float32(0))

    @tir.prim_func
    def conv2d_nchw1(rxplaceholder_2: tir.Buffer[(1, 64, 56, 56), "float32"], rxplaceholder_3: tir.Buffer[(64, 64, 1, 1), "float32"], conv2d_nchw_1: tir.Buffer[(1, 64, 56, 56), "float32"]) -> None:
        # 1x1 convolution, NCHW layout: (1, 64, 56, 56) input x (64, 64, 1, 1)
        # weight -> (1, 64, 56, 56) output. Stride 1; same indices in and out.
        # function attr dict
        tir.func_attr({"global_symbol": "conv2d_nchw1", "tir.noalias": True})
        # body
        # with tir.block("root")
        pad_temp_1 = tir.alloc_buffer([1, 64, 56, 56], dtype="float32")
        # Kernel is 1x1, so no halo is needed: "pad_temp" is a plain copy.
        for i0_3, i1_3, i2_3, i3_3 in tir.grid(1, 64, 56, 56):
            with tir.block("pad_temp"):
                i0_4, i1_4, i2_4, i3_4 = tir.axis.remap("SSSS", [i0_3, i1_3, i2_3, i3_3])
                tir.reads(rxplaceholder_2[i0_4, i1_4, i2_4, i3_4])
                tir.writes(pad_temp_1[i0_4, i1_4, i2_4, i3_4])
                pad_temp_1[i0_4, i1_4, i2_4, i3_4] = rxplaceholder_2[i0_4, i1_4, i2_4, i3_4]
        # Reduce over the 64 input channels (rc); r0/r1 are the degenerate 1x1
        # kernel axes. i1_6 // 64 * 64 is 0 for every i1_6 in [0, 64), so the
        # input channel index is simply rc.
        for i0_5, i1_5, i2_5, i3_5, i4, i5, i6 in tir.grid(1, 64, 56, 56, 64, 1, 1):
            with tir.block("conv2d_nchw"):
                i0_6, i1_6, i2_6, i3_6, rc, r0, r1 = tir.axis.remap("SSSSRRR", [i0_5, i1_5, i2_5, i3_5, i4, i5, i6])
                tir.reads(conv2d_nchw_1[i0_6, i1_6, i2_6, i3_6], pad_temp_1[i0_6, i1_6 // 64 * 64 + rc, i2_6 + r0, i3_6 + r1], rxplaceholder_3[i1_6, rc, r0, r1])
                tir.writes(conv2d_nchw_1[i0_6, i1_6, i2_6, i3_6])
                with tir.init():
                    conv2d_nchw_1[i0_6, i1_6, i2_6, i3_6] = tir.float32(0)
                conv2d_nchw_1[i0_6, i1_6, i2_6, i3_6] = conv2d_nchw_1[i0_6, i1_6, i2_6, i3_6] + pad_temp_1[i0_6, i1_6 // 64 * 64 + rc, i2_6 + r0, i3_6 + r1] * rxplaceholder_3[i1_6, rc, r0, r1]

    @tir.prim_func
    def relu1(rxplaceholder_1: tir.Buffer[(1, 64, 56, 56), "float32"], T_relu_1: tir.Buffer[(1, 64, 56, 56), "float32"]) -> None:
        # Element-wise ReLU on a (1, 64, 56, 56) tensor:
        # T_relu_1 = max(input, 0).
        # function attr dict
        tir.func_attr({"global_symbol": "relu1", "tir.noalias": True})
        # body
        # with tir.block("root")
        for i0, i1, i2, i3 in tir.grid(1, 64, 56, 56):
            with tir.block("T_relu"):
                ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
                tir.reads(rxplaceholder_1[ax0, ax1, ax2, ax3])
                tir.writes(T_relu_1[ax0, ax1, ax2, ax3])
                T_relu_1[ax0, ax1, ax2, ax3] = tir.max(rxplaceholder_1[ax0, ax1, ax2, ax3], tir.float32(0))

    @tir.prim_func
    def conv2d_nchw19(rxplaceholder_2: tir.Buffer[(1, 2048, 7, 7), "float32"], rxplaceholder_3: tir.Buffer[(512, 2048, 1, 1), "float32"], conv2d_nchw_1: tir.Buffer[(1, 512, 7, 7), "float32"]) -> None:
        # 1x1 convolution, NCHW layout: (1, 2048, 7, 7) input x (512, 2048, 1, 1)
        # weight -> (1, 512, 7, 7) output. Stride 1; same indices in and out.
        # function attr dict
        tir.func_attr({"global_symbol": "conv2d_nchw19", "tir.noalias": True})
        # body
        # with tir.block("root")
        pad_temp_1 = tir.alloc_buffer([1, 2048, 7, 7], dtype="float32")
        # Kernel is 1x1, so no halo is needed: "pad_temp" is a plain copy.
        for i0_3, i1_3, i2_3, i3_3 in tir.grid(1, 2048, 7, 7):
            with tir.block("pad_temp"):
                i0_4, i1_4, i2_4, i3_4 = tir.axis.remap("SSSS", [i0_3, i1_3, i2_3, i3_3])
                tir.reads(rxplaceholder_2[i0_4, i1_4, i2_4, i3_4])
                tir.writes(pad_temp_1[i0_4, i1_4, i2_4, i3_4])
                pad_temp_1[i0_4, i1_4, i2_4, i3_4] = rxplaceholder_2[i0_4, i1_4, i2_4, i3_4]
        # Reduce over the 2048 input channels (rc); r0/r1 are the degenerate 1x1
        # kernel axes. i1_6 // 512 * 2048 is 0 for every i1_6 in [0, 512), so the
        # input channel index is simply rc.
        for i0_5, i1_5, i2_5, i3_5, i4, i5, i6 in tir.grid(1, 512, 7, 7, 2048, 1, 1):
            with tir.block("conv2d_nchw"):
                i0_6, i1_6, i2_6, i3_6, rc, r0, r1 = tir.axis.remap("SSSSRRR", [i0_5, i1_5, i2_5, i3_5, i4, i5, i6])
                tir.reads(conv2d_nchw_1[i0_6, i1_6, i2_6, i3_6], pad_temp_1[i0_6, i1_6 // 512 * 2048 + rc, i2_6 + r0, i3_6 + r1], rxplaceholder_3[i1_6, rc, r0, r1])
                tir.writes(conv2d_nchw_1[i0_6, i1_6, i2_6, i3_6])
                with tir.init():
                    conv2d_nchw_1[i0_6, i1_6, i2_6, i3_6] = tir.float32(0)
                conv2d_nchw_1[i0_6, i1_6, i2_6, i3_6] = conv2d_nchw_1[i0_6, i1_6, i2_6, i3_6] + pad_temp_1[i0_6, i1_6 // 512 * 2048 + rc, i2_6 + r0, i3_6 + r1] * rxplaceholder_3[i1_6, rc, r0, r1]

    @tir.prim_func
    def relu4(rxplaceholder_1: tir.Buffer[(1, 512, 28, 28), "float32"], T_relu_1: tir.Buffer[(1, 512, 28, 28), "float32"]) -> None:
        # Elementwise ReLU over a (1, 512, 28, 28) activation:
        # T_relu_1[n, c, h, w] = max(rxplaceholder_1[n, c, h, w], 0.0).
        # Same pattern as relu1, specialized for this tensor shape.
        # function attr dict
        tir.func_attr({"global_symbol": "relu4", "tir.noalias": True})
        # body
        # with tir.block("root")
        for i0, i1, i2, i3 in tir.grid(1, 512, 28, 28):
            with tir.block("T_relu"):
                # All four block axes are spatial ("S"); reads/writes declare
                # the single element touched per iteration.
                ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
                tir.reads(rxplaceholder_1[ax0, ax1, ax2, ax3])
                tir.writes(T_relu_1[ax0, ax1, ax2, ax3])
                T_relu_1[ax0, ax1, ax2, ax3] = tir.max(rxplaceholder_1[ax0, ax1, ax2, ax3], tir.float32(0))

    @tir.prim_func
    def batch_flatten(rxplaceholder_1: tir.Buffer[(1, 2048, 1, 1), "float32"], tensor_1: tir.Buffer[(1, 2048), "float32"]) -> None:
        # Flattens the pooled (1, 2048, 1, 1) feature map to (1, 2048) ahead
        # of the fully-connected layer; a pure copy, no values change.
        # function attr dict
        tir.func_attr({"global_symbol": "batch_flatten", "tir.noalias": True})
        # body
        # with tir.block("root")
        for i0, i1 in tir.grid(1, 2048):
            with tir.block("tensor"):
                ax0, ax1 = tir.axis.remap("SS", [i0, i1])
                # ax1 % 2048 is a no-op for ax1 in [0, 2048) — an artifact of
                # the generic flatten index expression after simplification.
                tir.reads(rxplaceholder_1[ax0, ax1 % 2048, 0, 0])
                tir.writes(tensor_1[ax0, ax1])
                tensor_1[ax0, ax1] = rxplaceholder_1[ax0, ax1 % 2048, 0, 0]

    @relax.function
    def main(data: Tensor[(1, 3, 224, 224), "float32"], bn_data_gamma: Tensor[(3,), "float32"], bn_data_beta: Tensor[(3,), "float32"], bn_data_moving_mean: Tensor[(3,), "float32"], bn_data_moving_var: Tensor[(3,), "float32"], conv0_weight: Tensor[(64, 3, 7, 7), "float32"], bn0_gamma: Tensor[(64,), "float32"], bn0_beta: Tensor[(64,), "float32"], bn0_moving_mean: Tensor[(64,), "float32"], bn0_moving_var: Tensor[(64,), "float32"], stage1_unit1_bn1_gamma: Tensor[(64,), "float32"], stage1_unit1_bn1_beta: Tensor[(64,), "float32"], stage1_unit1_bn1_moving_mean: Tensor[(64,), "float32"], stage1_unit1_bn1_moving_var: Tensor[(64,), "float32"], stage1_unit1_conv1_weight: Tensor[(64, 64, 1, 1), "float32"], stage1_unit1_bn2_gamma: Tensor[(64,), "float32"], stage1_unit1_bn2_beta: Tensor[(64,), "float32"], stage1_unit1_bn2_moving_mean: Tensor[(64,), "float32"], stage1_unit1_bn2_moving_var: Tensor[(64,), "float32"], stage1_unit1_conv2_weight: Tensor[(64, 64, 3, 3), "float32"], stage1_unit1_bn3_gamma: Tensor[(64,), "float32"], stage1_unit1_bn3_beta: Tensor[(64,), "float32"], stage1_unit1_bn3_moving_mean: Tensor[(64,), "float32"], stage1_unit1_bn3_moving_var: Tensor[(64,), "float32"], stage1_unit1_conv3_weight: Tensor[(256, 64, 1, 1), "float32"], stage1_unit1_sc_weight: Tensor[(256, 64, 1, 1), "float32"], stage1_unit2_bn1_gamma: Tensor[(256,), "float32"], stage1_unit2_bn1_beta: Tensor[(256,), "float32"], stage1_unit2_bn1_moving_mean: Tensor[(256,), "float32"], stage1_unit2_bn1_moving_var: Tensor[(256,), "float32"], stage1_unit2_conv1_weight: Tensor[(64, 256, 1, 1), "float32"], stage1_unit2_bn2_gamma: Tensor[(64,), "float32"], stage1_unit2_bn2_beta: Tensor[(64,), "float32"], stage1_unit2_bn2_moving_mean: Tensor[(64,), "float32"], stage1_unit2_bn2_moving_var: Tensor[(64,), "float32"], stage1_unit2_conv2_weight: Tensor[(64, 64, 3, 3), "float32"], stage1_unit2_bn3_gamma: Tensor[(64,), "float32"], stage1_unit2_bn3_beta: Tensor[(64,), "float32"], stage1_unit2_bn3_moving_mean: 
Tensor[(64,), "float32"], stage1_unit2_bn3_moving_var: Tensor[(64,), "float32"], stage1_unit2_conv3_weight: Tensor[(256, 64, 1, 1), "float32"], stage1_unit3_bn1_gamma: Tensor[(256,), "float32"], stage1_unit3_bn1_beta: Tensor[(256,), "float32"], stage1_unit3_bn1_moving_mean: Tensor[(256,), "float32"], stage1_unit3_bn1_moving_var: Tensor[(256,), "float32"], stage1_unit3_conv1_weight: Tensor[(64, 256, 1, 1), "float32"], stage1_unit3_bn2_gamma: Tensor[(64,), "float32"], stage1_unit3_bn2_beta: Tensor[(64,), "float32"], stage1_unit3_bn2_moving_mean: Tensor[(64,), "float32"], stage1_unit3_bn2_moving_var: Tensor[(64,), "float32"], stage1_unit3_conv2_weight: Tensor[(64, 64, 3, 3), "float32"], stage1_unit3_bn3_gamma: Tensor[(64,), "float32"], stage1_unit3_bn3_beta: Tensor[(64,), "float32"], stage1_unit3_bn3_moving_mean: Tensor[(64,), "float32"], stage1_unit3_bn3_moving_var: Tensor[(64,), "float32"], stage1_unit3_conv3_weight: Tensor[(256, 64, 1, 1), "float32"], stage2_unit1_bn1_gamma: Tensor[(256,), "float32"], stage2_unit1_bn1_beta: Tensor[(256,), "float32"], stage2_unit1_bn1_moving_mean: Tensor[(256,), "float32"], stage2_unit1_bn1_moving_var: Tensor[(256,), "float32"], stage2_unit1_conv1_weight: Tensor[(128, 256, 1, 1), "float32"], stage2_unit1_bn2_gamma: Tensor[(128,), "float32"], stage2_unit1_bn2_beta: Tensor[(128,), "float32"], stage2_unit1_bn2_moving_mean: Tensor[(128,), "float32"], stage2_unit1_bn2_moving_var: Tensor[(128,), "float32"], stage2_unit1_conv2_weight: Tensor[(128, 128, 3, 3), "float32"], stage2_unit1_bn3_gamma: Tensor[(128,), "float32"], stage2_unit1_bn3_beta: Tensor[(128,), "float32"], stage2_unit1_bn3_moving_mean: Tensor[(128,), "float32"], stage2_unit1_bn3_moving_var: Tensor[(128,), "float32"], stage2_unit1_conv3_weight: Tensor[(512, 128, 1, 1), "float32"], stage2_unit1_sc_weight: Tensor[(512, 256, 1, 1), "float32"], stage2_unit2_bn1_gamma: Tensor[(512,), "float32"], stage2_unit2_bn1_beta: Tensor[(512,), "float32"], stage2_unit2_bn1_moving_mean: 
Tensor[(512,), "float32"], stage2_unit2_bn1_moving_var: Tensor[(512,), "float32"], stage2_unit2_conv1_weight: Tensor[(128, 512, 1, 1), "float32"], stage2_unit2_bn2_gamma: Tensor[(128,), "float32"], stage2_unit2_bn2_beta: Tensor[(128,), "float32"], stage2_unit2_bn2_moving_mean: Tensor[(128,), "float32"], stage2_unit2_bn2_moving_var: Tensor[(128,), "float32"], stage2_unit2_conv2_weight: Tensor[(128, 128, 3, 3), "float32"], stage2_unit2_bn3_gamma: Tensor[(128,), "float32"], stage2_unit2_bn3_beta: Tensor[(128,), "float32"], stage2_unit2_bn3_moving_mean: Tensor[(128,), "float32"], stage2_unit2_bn3_moving_var: Tensor[(128,), "float32"], stage2_unit2_conv3_weight: Tensor[(512, 128, 1, 1), "float32"], stage2_unit3_bn1_gamma: Tensor[(512,), "float32"], stage2_unit3_bn1_beta: Tensor[(512,), "float32"], stage2_unit3_bn1_moving_mean: Tensor[(512,), "float32"], stage2_unit3_bn1_moving_var: Tensor[(512,), "float32"], stage2_unit3_conv1_weight: Tensor[(128, 512, 1, 1), "float32"], stage2_unit3_bn2_gamma: Tensor[(128,), "float32"], stage2_unit3_bn2_beta: Tensor[(128,), "float32"], stage2_unit3_bn2_moving_mean: Tensor[(128,), "float32"], stage2_unit3_bn2_moving_var: Tensor[(128,), "float32"], stage2_unit3_conv2_weight: Tensor[(128, 128, 3, 3), "float32"], stage2_unit3_bn3_gamma: Tensor[(128,), "float32"], stage2_unit3_bn3_beta: Tensor[(128,), "float32"], stage2_unit3_bn3_moving_mean: Tensor[(128,), "float32"], stage2_unit3_bn3_moving_var: Tensor[(128,), "float32"], stage2_unit3_conv3_weight: Tensor[(512, 128, 1, 1), "float32"], stage2_unit4_bn1_gamma: Tensor[(512,), "float32"], stage2_unit4_bn1_beta: Tensor[(512,), "float32"], stage2_unit4_bn1_moving_mean: Tensor[(512,), "float32"], stage2_unit4_bn1_moving_var: Tensor[(512,), "float32"], stage2_unit4_conv1_weight: Tensor[(128, 512, 1, 1), "float32"], stage2_unit4_bn2_gamma: Tensor[(128,), "float32"], stage2_unit4_bn2_beta: Tensor[(128,), "float32"], stage2_unit4_bn2_moving_mean: Tensor[(128,), "float32"], 
stage2_unit4_bn2_moving_var: Tensor[(128,), "float32"], stage2_unit4_conv2_weight: Tensor[(128, 128, 3, 3), "float32"], stage2_unit4_bn3_gamma: Tensor[(128,), "float32"], stage2_unit4_bn3_beta: Tensor[(128,), "float32"], stage2_unit4_bn3_moving_mean: Tensor[(128,), "float32"], stage2_unit4_bn3_moving_var: Tensor[(128,), "float32"], stage2_unit4_conv3_weight: Tensor[(512, 128, 1, 1), "float32"], stage3_unit1_bn1_gamma: Tensor[(512,), "float32"], stage3_unit1_bn1_beta: Tensor[(512,), "float32"], stage3_unit1_bn1_moving_mean: Tensor[(512,), "float32"], stage3_unit1_bn1_moving_var: Tensor[(512,), "float32"], stage3_unit1_conv1_weight: Tensor[(256, 512, 1, 1), "float32"], stage3_unit1_bn2_gamma: Tensor[(256,), "float32"], stage3_unit1_bn2_beta: Tensor[(256,), "float32"], stage3_unit1_bn2_moving_mean: Tensor[(256,), "float32"], stage3_unit1_bn2_moving_var: Tensor[(256,), "float32"], stage3_unit1_conv2_weight: Tensor[(256, 256, 3, 3), "float32"], stage3_unit1_bn3_gamma: Tensor[(256,), "float32"], stage3_unit1_bn3_beta: Tensor[(256,), "float32"], stage3_unit1_bn3_moving_mean: Tensor[(256,), "float32"], stage3_unit1_bn3_moving_var: Tensor[(256,), "float32"], stage3_unit1_conv3_weight: Tensor[(1024, 256, 1, 1), "float32"], stage3_unit1_sc_weight: Tensor[(1024, 512, 1, 1), "float32"], stage3_unit2_bn1_gamma: Tensor[(1024,), "float32"], stage3_unit2_bn1_beta: Tensor[(1024,), "float32"], stage3_unit2_bn1_moving_mean: Tensor[(1024,), "float32"], stage3_unit2_bn1_moving_var: Tensor[(1024,), "float32"], stage3_unit2_conv1_weight: Tensor[(256, 1024, 1, 1), "float32"], stage3_unit2_bn2_gamma: Tensor[(256,), "float32"], stage3_unit2_bn2_beta: Tensor[(256,), "float32"], stage3_unit2_bn2_moving_mean: Tensor[(256,), "float32"], stage3_unit2_bn2_moving_var: Tensor[(256,), "float32"], stage3_unit2_conv2_weight: Tensor[(256, 256, 3, 3), "float32"], stage3_unit2_bn3_gamma: Tensor[(256,), "float32"], stage3_unit2_bn3_beta: Tensor[(256,), "float32"], stage3_unit2_bn3_moving_mean: 
Tensor[(256,), "float32"], stage3_unit2_bn3_moving_var: Tensor[(256,), "float32"], stage3_unit2_conv3_weight: Tensor[(1024, 256, 1, 1), "float32"], stage3_unit3_bn1_gamma: Tensor[(1024,), "float32"], stage3_unit3_bn1_beta: Tensor[(1024,), "float32"], stage3_unit3_bn1_moving_mean: Tensor[(1024,), "float32"], stage3_unit3_bn1_moving_var: Tensor[(1024,), "float32"], stage3_unit3_conv1_weight: Tensor[(256, 1024, 1, 1), "float32"], stage3_unit3_bn2_gamma: Tensor[(256,), "float32"], stage3_unit3_bn2_beta: Tensor[(256,), "float32"], stage3_unit3_bn2_moving_mean: Tensor[(256,), "float32"], stage3_unit3_bn2_moving_var: Tensor[(256,), "float32"], stage3_unit3_conv2_weight: Tensor[(256, 256, 3, 3), "float32"], stage3_unit3_bn3_gamma: Tensor[(256,), "float32"], stage3_unit3_bn3_beta: Tensor[(256,), "float32"], stage3_unit3_bn3_moving_mean: Tensor[(256,), "float32"], stage3_unit3_bn3_moving_var: Tensor[(256,), "float32"], stage3_unit3_conv3_weight: Tensor[(1024, 256, 1, 1), "float32"], stage3_unit4_bn1_gamma: Tensor[(1024,), "float32"], stage3_unit4_bn1_beta: Tensor[(1024,), "float32"], stage3_unit4_bn1_moving_mean: Tensor[(1024,), "float32"], stage3_unit4_bn1_moving_var: Tensor[(1024,), "float32"], stage3_unit4_conv1_weight: Tensor[(256, 1024, 1, 1), "float32"], stage3_unit4_bn2_gamma: Tensor[(256,), "float32"], stage3_unit4_bn2_beta: Tensor[(256,), "float32"], stage3_unit4_bn2_moving_mean: Tensor[(256,), "float32"], stage3_unit4_bn2_moving_var: Tensor[(256,), "float32"], stage3_unit4_conv2_weight: Tensor[(256, 256, 3, 3), "float32"], stage3_unit4_bn3_gamma: Tensor[(256,), "float32"], stage3_unit4_bn3_beta: Tensor[(256,), "float32"], stage3_unit4_bn3_moving_mean: Tensor[(256,), "float32"], stage3_unit4_bn3_moving_var: Tensor[(256,), "float32"], stage3_unit4_conv3_weight: Tensor[(1024, 256, 1, 1), "float32"], stage3_unit5_bn1_gamma: Tensor[(1024,), "float32"], stage3_unit5_bn1_beta: Tensor[(1024,), "float32"], stage3_unit5_bn1_moving_mean: Tensor[(1024,), "float32"], 
stage3_unit5_bn1_moving_var: Tensor[(1024,), "float32"], stage3_unit5_conv1_weight: Tensor[(256, 1024, 1, 1), "float32"], stage3_unit5_bn2_gamma: Tensor[(256,), "float32"], stage3_unit5_bn2_beta: Tensor[(256,), "float32"], stage3_unit5_bn2_moving_mean: Tensor[(256,), "float32"], stage3_unit5_bn2_moving_var: Tensor[(256,), "float32"], stage3_unit5_conv2_weight: Tensor[(256, 256, 3, 3), "float32"], stage3_unit5_bn3_gamma: Tensor[(256,), "float32"], stage3_unit5_bn3_beta: Tensor[(256,), "float32"], stage3_unit5_bn3_moving_mean: Tensor[(256,), "float32"], stage3_unit5_bn3_moving_var: Tensor[(256,), "float32"], stage3_unit5_conv3_weight: Tensor[(1024, 256, 1, 1), "float32"], stage3_unit6_bn1_gamma: Tensor[(1024,), "float32"], stage3_unit6_bn1_beta: Tensor[(1024,), "float32"], stage3_unit6_bn1_moving_mean: Tensor[(1024,), "float32"], stage3_unit6_bn1_moving_var: Tensor[(1024,), "float32"], stage3_unit6_conv1_weight: Tensor[(256, 1024, 1, 1), "float32"], stage3_unit6_bn2_gamma: Tensor[(256,), "float32"], stage3_unit6_bn2_beta: Tensor[(256,), "float32"], stage3_unit6_bn2_moving_mean: Tensor[(256,), "float32"], stage3_unit6_bn2_moving_var: Tensor[(256,), "float32"], stage3_unit6_conv2_weight: Tensor[(256, 256, 3, 3), "float32"], stage3_unit6_bn3_gamma: Tensor[(256,), "float32"], stage3_unit6_bn3_beta: Tensor[(256,), "float32"], stage3_unit6_bn3_moving_mean: Tensor[(256,), "float32"], stage3_unit6_bn3_moving_var: Tensor[(256,), "float32"], stage3_unit6_conv3_weight: Tensor[(1024, 256, 1, 1), "float32"], stage4_unit1_bn1_gamma: Tensor[(1024,), "float32"], stage4_unit1_bn1_beta: Tensor[(1024,), "float32"], stage4_unit1_bn1_moving_mean: Tensor[(1024,), "float32"], stage4_unit1_bn1_moving_var: Tensor[(1024,), "float32"], stage4_unit1_conv1_weight: Tensor[(512, 1024, 1, 1), "float32"], stage4_unit1_bn2_gamma: Tensor[(512,), "float32"], stage4_unit1_bn2_beta: Tensor[(512,), "float32"], stage4_unit1_bn2_moving_mean: Tensor[(512,), "float32"], stage4_unit1_bn2_moving_var: 
Tensor[(512,), "float32"], stage4_unit1_conv2_weight: Tensor[(512, 512, 3, 3), "float32"], stage4_unit1_bn3_gamma: Tensor[(512,), "float32"], stage4_unit1_bn3_beta: Tensor[(512,), "float32"], stage4_unit1_bn3_moving_mean: Tensor[(512,), "float32"], stage4_unit1_bn3_moving_var: Tensor[(512,), "float32"], stage4_unit1_conv3_weight: Tensor[(2048, 512, 1, 1), "float32"], stage4_unit1_sc_weight: Tensor[(2048, 1024, 1, 1), "float32"], stage4_unit2_bn1_gamma: Tensor[(2048,), "float32"], stage4_unit2_bn1_beta: Tensor[(2048,), "float32"], stage4_unit2_bn1_moving_mean: Tensor[(2048,), "float32"], stage4_unit2_bn1_moving_var: Tensor[(2048,), "float32"], stage4_unit2_conv1_weight: Tensor[(512, 2048, 1, 1), "float32"], stage4_unit2_bn2_gamma: Tensor[(512,), "float32"], stage4_unit2_bn2_beta: Tensor[(512,), "float32"], stage4_unit2_bn2_moving_mean: Tensor[(512,), "float32"], stage4_unit2_bn2_moving_var: Tensor[(512,), "float32"], stage4_unit2_conv2_weight: Tensor[(512, 512, 3, 3), "float32"], stage4_unit2_bn3_gamma: Tensor[(512,), "float32"], stage4_unit2_bn3_beta: Tensor[(512,), "float32"], stage4_unit2_bn3_moving_mean: Tensor[(512,), "float32"], stage4_unit2_bn3_moving_var: Tensor[(512,), "float32"], stage4_unit2_conv3_weight: Tensor[(2048, 512, 1, 1), "float32"], stage4_unit3_bn1_gamma: Tensor[(2048,), "float32"], stage4_unit3_bn1_beta: Tensor[(2048,), "float32"], stage4_unit3_bn1_moving_mean: Tensor[(2048,), "float32"], stage4_unit3_bn1_moving_var: Tensor[(2048,), "float32"], stage4_unit3_conv1_weight: Tensor[(512, 2048, 1, 1), "float32"], stage4_unit3_bn2_gamma: Tensor[(512,), "float32"], stage4_unit3_bn2_beta: Tensor[(512,), "float32"], stage4_unit3_bn2_moving_mean: Tensor[(512,), "float32"], stage4_unit3_bn2_moving_var: Tensor[(512,), "float32"], stage4_unit3_conv2_weight: Tensor[(512, 512, 3, 3), "float32"], stage4_unit3_bn3_gamma: Tensor[(512,), "float32"], stage4_unit3_bn3_beta: Tensor[(512,), "float32"], stage4_unit3_bn3_moving_mean: Tensor[(512,), "float32"], 
stage4_unit3_bn3_moving_var: Tensor[(512,), "float32"], stage4_unit3_conv3_weight: Tensor[(2048, 512, 1, 1), "float32"], bn1_gamma: Tensor[(2048,), "float32"], bn1_beta: Tensor[(2048,), "float32"], bn1_moving_mean: Tensor[(2048,), "float32"], bn1_moving_var: Tensor[(2048,), "float32"], fc1_weight: Tensor[(1000, 2048), "float32"], fc1_bias: Tensor[(1000,), "float32"]) -> Tensor[_, "float32"]:
        # block 0
        with relax.dataflow():
            lv: Tuple[Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"]] = relax.call_tir(((1, 3, 224, 224), (3,), (3,)), batch_norm, (data, bn_data_gamma, bn_data_beta, bn_data_moving_mean, bn_data_moving_var))
            lv1: Tensor[(1, 3, 224, 224), "float32"] = lv[0]
            lv2: Tensor[(1, 64, 112, 112), "float32"] = relax.call_tir((1, 64, 112, 112), conv2d_nchw, (lv1, conv0_weight))
            lv3: Tuple[Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"]] = relax.call_tir(((1, 64, 112, 112), (64,), (64,)), batch_norm1, (lv2, bn0_gamma, bn0_beta, bn0_moving_mean, bn0_moving_var))
            lv4: Tensor[(1, 64, 112, 112), "float32"] = lv3[0]
            lv5: Tensor[(1, 64, 112, 112), "float32"] = relax.call_tir((1, 64, 112, 112), relu, (lv4,))
            lv6: Tensor[(1, 64, 56, 56), "float32"] = relax.call_tir((1, 64, 56, 56), max_pool2d, (lv5,))
            lv7: Tuple[Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"]] = relax.call_tir(((1, 64, 56, 56), (64,), (64,)), batch_norm2, (lv6, stage1_unit1_bn1_gamma, stage1_unit1_bn1_beta, stage1_unit1_bn1_moving_mean, stage1_unit1_bn1_moving_var))
            lv8: Tensor[(1, 64, 56, 56), "float32"] = lv7[0]
            lv9: Tensor[(1, 64, 56, 56), "float32"] = relax.call_tir((1, 64, 56, 56), relu1, (lv8,))
            lv10: Tensor[(1, 64, 56, 56), "float32"] = relax.call_tir((1, 64, 56, 56), conv2d_nchw1, (lv9, stage1_unit1_conv1_weight))
            lv11: Tuple[Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"]] = relax.call_tir(((1, 64, 56, 56), (64,), (64,)), batch_norm2, (lv10, stage1_unit1_bn2_gamma, stage1_unit1_bn2_beta, stage1_unit1_bn2_moving_mean, stage1_unit1_bn2_moving_var))
            lv12: Tensor[(1, 64, 56, 56), "float32"] = lv11[0]
            lv13: Tensor[(1, 64, 56, 56), "float32"] = relax.call_tir((1, 64, 56, 56), relu1, (lv12,))
            lv14: Tensor[(1, 64, 56, 56), "float32"] = relax.call_tir((1, 64, 56, 56), conv2d_nchw2, (lv13, stage1_unit1_conv2_weight))
            lv15: Tuple[Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"]] = relax.call_tir(((1, 64, 56, 56), (64,), (64,)), batch_norm2, (lv14, stage1_unit1_bn3_gamma, stage1_unit1_bn3_beta, stage1_unit1_bn3_moving_mean, stage1_unit1_bn3_moving_var))
            lv16: Tensor[(1, 64, 56, 56), "float32"] = lv15[0]
            lv17: Tensor[(1, 64, 56, 56), "float32"] = relax.call_tir((1, 64, 56, 56), relu1, (lv16,))
            lv18: Tensor[(1, 256, 56, 56), "float32"] = relax.call_tir((1, 256, 56, 56), conv2d_nchw3, (lv17, stage1_unit1_conv3_weight))
            lv19: Tensor[(1, 256, 56, 56), "float32"] = relax.call_tir((1, 256, 56, 56), conv2d_nchw3, (lv9, stage1_unit1_sc_weight))
            lv20: Tensor[(1, 256, 56, 56), "float32"] = relax.call_tir((1, 256, 56, 56), add0, (lv18, lv19))
            lv21: Tuple[Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"]] = relax.call_tir(((1, 256, 56, 56), (256,), (256,)), batch_norm3, (lv20, stage1_unit2_bn1_gamma, stage1_unit2_bn1_beta, stage1_unit2_bn1_moving_mean, stage1_unit2_bn1_moving_var))
            lv22: Tensor[(1, 256, 56, 56), "float32"] = lv21[0]
            lv23: Tensor[(1, 256, 56, 56), "float32"] = relax.call_tir((1, 256, 56, 56), relu2, (lv22,))
            lv24: Tensor[(1, 64, 56, 56), "float32"] = relax.call_tir((1, 64, 56, 56), conv2d_nchw4, (lv23, stage1_unit2_conv1_weight))
            lv25: Tuple[Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"]] = relax.call_tir(((1, 64, 56, 56), (64,), (64,)), batch_norm2, (lv24, stage1_unit2_bn2_gamma, stage1_unit2_bn2_beta, stage1_unit2_bn2_moving_mean, stage1_unit2_bn2_moving_var))
            lv26: Tensor[(1, 64, 56, 56), "float32"] = lv25[0]
            lv27: Tensor[(1, 64, 56, 56), "float32"] = relax.call_tir((1, 64, 56, 56), relu1, (lv26,))
            lv28: Tensor[(1, 64, 56, 56), "float32"] = relax.call_tir((1, 64, 56, 56), conv2d_nchw2, (lv27, stage1_unit2_conv2_weight))
            lv29: Tuple[Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"]] = relax.call_tir(((1, 64, 56, 56), (64,), (64,)), batch_norm2, (lv28, stage1_unit2_bn3_gamma, stage1_unit2_bn3_beta, stage1_unit2_bn3_moving_mean, stage1_unit2_bn3_moving_var))
            lv30: Tensor[(1, 64, 56, 56), "float32"] = lv29[0]
            lv31: Tensor[(1, 64, 56, 56), "float32"] = relax.call_tir((1, 64, 56, 56), relu1, (lv30,))
            lv32: Tensor[(1, 256, 56, 56), "float32"] = relax.call_tir((1, 256, 56, 56), conv2d_nchw3, (lv31, stage1_unit2_conv3_weight))
            lv33: Tensor[(1, 256, 56, 56), "float32"] = relax.call_tir((1, 256, 56, 56), add0, (lv32, lv20))
            lv34: Tuple[Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"]] = relax.call_tir(((1, 256, 56, 56), (256,), (256,)), batch_norm3, (lv33, stage1_unit3_bn1_gamma, stage1_unit3_bn1_beta, stage1_unit3_bn1_moving_mean, stage1_unit3_bn1_moving_var))
            lv35: Tensor[(1, 256, 56, 56), "float32"] = lv34[0]
            lv36: Tensor[(1, 256, 56, 56), "float32"] = relax.call_tir((1, 256, 56, 56), relu2, (lv35,))
            lv37: Tensor[(1, 64, 56, 56), "float32"] = relax.call_tir((1, 64, 56, 56), conv2d_nchw4, (lv36, stage1_unit3_conv1_weight))
            lv38: Tuple[Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"]] = relax.call_tir(((1, 64, 56, 56), (64,), (64,)), batch_norm2, (lv37, stage1_unit3_bn2_gamma, stage1_unit3_bn2_beta, stage1_unit3_bn2_moving_mean, stage1_unit3_bn2_moving_var))
            lv39: Tensor[(1, 64, 56, 56), "float32"] = lv38[0]
            lv40: Tensor[(1, 64, 56, 56), "float32"] = relax.call_tir((1, 64, 56, 56), relu1, (lv39,))
            lv41: Tensor[(1, 64, 56, 56), "float32"] = relax.call_tir((1, 64, 56, 56), conv2d_nchw2, (lv40, stage1_unit3_conv2_weight))
            lv42: Tuple[Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"]] = relax.call_tir(((1, 64, 56, 56), (64,), (64,)), batch_norm2, (lv41, stage1_unit3_bn3_gamma, stage1_unit3_bn3_beta, stage1_unit3_bn3_moving_mean, stage1_unit3_bn3_moving_var))
            lv43: Tensor[(1, 64, 56, 56), "float32"] = lv42[0]
            lv44: Tensor[(1, 64, 56, 56), "float32"] = relax.call_tir((1, 64, 56, 56), relu1, (lv43,))
            lv45: Tensor[(1, 256, 56, 56), "float32"] = relax.call_tir((1, 256, 56, 56), conv2d_nchw3, (lv44, stage1_unit3_conv3_weight))
            lv46: Tensor[(1, 256, 56, 56), "float32"] = relax.call_tir((1, 256, 56, 56), add0, (lv45, lv33))
            lv47: Tuple[Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"]] = relax.call_tir(((1, 256, 56, 56), (256,), (256,)), batch_norm3, (lv46, stage2_unit1_bn1_gamma, stage2_unit1_bn1_beta, stage2_unit1_bn1_moving_mean, stage2_unit1_bn1_moving_var))
            lv48: Tensor[(1, 256, 56, 56), "float32"] = lv47[0]
            lv49: Tensor[(1, 256, 56, 56), "float32"] = relax.call_tir((1, 256, 56, 56), relu2, (lv48,))
            lv50: Tensor[(1, 128, 28, 28), "float32"] = relax.call_tir((1, 128, 28, 28), conv2d_nchw5, (lv49, stage2_unit1_conv1_weight))
            lv51: Tuple[Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"]] = relax.call_tir(((1, 128, 28, 28), (128,), (128,)), batch_norm4, (lv50, stage2_unit1_bn2_gamma, stage2_unit1_bn2_beta, stage2_unit1_bn2_moving_mean, stage2_unit1_bn2_moving_var))
            lv52: Tensor[(1, 128, 28, 28), "float32"] = lv51[0]
            lv53: Tensor[(1, 128, 28, 28), "float32"] = relax.call_tir((1, 128, 28, 28), relu3, (lv52,))
            lv54: Tensor[(1, 128, 28, 28), "float32"] = relax.call_tir((1, 128, 28, 28), conv2d_nchw6, (lv53, stage2_unit1_conv2_weight))
            lv55: Tuple[Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"]] = relax.call_tir(((1, 128, 28, 28), (128,), (128,)), batch_norm4, (lv54, stage2_unit1_bn3_gamma, stage2_unit1_bn3_beta, stage2_unit1_bn3_moving_mean, stage2_unit1_bn3_moving_var))
            lv56: Tensor[(1, 128, 28, 28), "float32"] = lv55[0]
            lv57: Tensor[(1, 128, 28, 28), "float32"] = relax.call_tir((1, 128, 28, 28), relu3, (lv56,))
            lv58: Tensor[(1, 512, 28, 28), "float32"] = relax.call_tir((1, 512, 28, 28), conv2d_nchw7, (lv57, stage2_unit1_conv3_weight))
            lv59: Tensor[(1, 512, 28, 28), "float32"] = relax.call_tir((1, 512, 28, 28), conv2d_nchw8, (lv49, stage2_unit1_sc_weight))
            lv60: Tensor[(1, 512, 28, 28), "float32"] = relax.call_tir((1, 512, 28, 28), add1, (lv58, lv59))
            lv61: Tuple[Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"]] = relax.call_tir(((1, 512, 28, 28), (512,), (512,)), batch_norm5, (lv60, stage2_unit2_bn1_gamma, stage2_unit2_bn1_beta, stage2_unit2_bn1_moving_mean, stage2_unit2_bn1_moving_var))
            lv62: Tensor[(1, 512, 28, 28), "float32"] = lv61[0]
            lv63: Tensor[(1, 512, 28, 28), "float32"] = relax.call_tir((1, 512, 28, 28), relu4, (lv62,))
            lv64: Tensor[(1, 128, 28, 28), "float32"] = relax.call_tir((1, 128, 28, 28), conv2d_nchw9, (lv63, stage2_unit2_conv1_weight))
            lv65: Tuple[Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"]] = relax.call_tir(((1, 128, 28, 28), (128,), (128,)), batch_norm4, (lv64, stage2_unit2_bn2_gamma, stage2_unit2_bn2_beta, stage2_unit2_bn2_moving_mean, stage2_unit2_bn2_moving_var))
            lv66: Tensor[(1, 128, 28, 28), "float32"] = lv65[0]
            lv67: Tensor[(1, 128, 28, 28), "float32"] = relax.call_tir((1, 128, 28, 28), relu3, (lv66,))
            lv68: Tensor[(1, 128, 28, 28), "float32"] = relax.call_tir((1, 128, 28, 28), conv2d_nchw6, (lv67, stage2_unit2_conv2_weight))
            lv69: Tuple[Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"]] = relax.call_tir(((1, 128, 28, 28), (128,), (128,)), batch_norm4, (lv68, stage2_unit2_bn3_gamma, stage2_unit2_bn3_beta, stage2_unit2_bn3_moving_mean, stage2_unit2_bn3_moving_var))
            lv70: Tensor[(1, 128, 28, 28), "float32"] = lv69[0]
            lv71: Tensor[(1, 128, 28, 28), "float32"] = relax.call_tir((1, 128, 28, 28), relu3, (lv70,))
            lv72: Tensor[(1, 512, 28, 28), "float32"] = relax.call_tir((1, 512, 28, 28), conv2d_nchw7, (lv71, stage2_unit2_conv3_weight))
            lv73: Tensor[(1, 512, 28, 28), "float32"] = relax.call_tir((1, 512, 28, 28), add1, (lv72, lv60))
            lv74: Tuple[Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"]] = relax.call_tir(((1, 512, 28, 28), (512,), (512,)), batch_norm5, (lv73, stage2_unit3_bn1_gamma, stage2_unit3_bn1_beta, stage2_unit3_bn1_moving_mean, stage2_unit3_bn1_moving_var))
            lv75: Tensor[(1, 512, 28, 28), "float32"] = lv74[0]
            lv76: Tensor[(1, 512, 28, 28), "float32"] = relax.call_tir((1, 512, 28, 28), relu4, (lv75,))
            lv77: Tensor[(1, 128, 28, 28), "float32"] = relax.call_tir((1, 128, 28, 28), conv2d_nchw9, (lv76, stage2_unit3_conv1_weight))
            lv78: Tuple[Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"]] = relax.call_tir(((1, 128, 28, 28), (128,), (128,)), batch_norm4, (lv77, stage2_unit3_bn2_gamma, stage2_unit3_bn2_beta, stage2_unit3_bn2_moving_mean, stage2_unit3_bn2_moving_var))
            lv79: Tensor[(1, 128, 28, 28), "float32"] = lv78[0]
            lv80: Tensor[(1, 128, 28, 28), "float32"] = relax.call_tir((1, 128, 28, 28), relu3, (lv79,))
            lv81: Tensor[(1, 128, 28, 28), "float32"] = relax.call_tir((1, 128, 28, 28), conv2d_nchw6, (lv80, stage2_unit3_conv2_weight))
            lv82: Tuple[Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"]] = relax.call_tir(((1, 128, 28, 28), (128,), (128,)), batch_norm4, (lv81, stage2_unit3_bn3_gamma, stage2_unit3_bn3_beta, stage2_unit3_bn3_moving_mean, stage2_unit3_bn3_moving_var))
            lv83: Tensor[(1, 128, 28, 28), "float32"] = lv82[0]
            lv84: Tensor[(1, 128, 28, 28), "float32"] = relax.call_tir((1, 128, 28, 28), relu3, (lv83,))
            lv85: Tensor[(1, 512, 28, 28), "float32"] = relax.call_tir((1, 512, 28, 28), conv2d_nchw7, (lv84, stage2_unit3_conv3_weight))
            lv86: Tensor[(1, 512, 28, 28), "float32"] = relax.call_tir((1, 512, 28, 28), add1, (lv85, lv73))
            lv87: Tuple[Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"]] = relax.call_tir(((1, 512, 28, 28), (512,), (512,)), batch_norm5, (lv86, stage2_unit4_bn1_gamma, stage2_unit4_bn1_beta, stage2_unit4_bn1_moving_mean, stage2_unit4_bn1_moving_var))
            lv88: Tensor[(1, 512, 28, 28), "float32"] = lv87[0]
            lv89: Tensor[(1, 512, 28, 28), "float32"] = relax.call_tir((1, 512, 28, 28), relu4, (lv88,))
            lv90: Tensor[(1, 128, 28, 28), "float32"] = relax.call_tir((1, 128, 28, 28), conv2d_nchw9, (lv89, stage2_unit4_conv1_weight))
            lv91: Tuple[Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"]] = relax.call_tir(((1, 128, 28, 28), (128,), (128,)), batch_norm4, (lv90, stage2_unit4_bn2_gamma, stage2_unit4_bn2_beta, stage2_unit4_bn2_moving_mean, stage2_unit4_bn2_moving_var))
            lv92: Tensor[(1, 128, 28, 28), "float32"] = lv91[0]
            lv93: Tensor[(1, 128, 28, 28), "float32"] = relax.call_tir((1, 128, 28, 28), relu3, (lv92,))
            lv94: Tensor[(1, 128, 28, 28), "float32"] = relax.call_tir((1, 128, 28, 28), conv2d_nchw6, (lv93, stage2_unit4_conv2_weight))
            lv95: Tuple[Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"]] = relax.call_tir(((1, 128, 28, 28), (128,), (128,)), batch_norm4, (lv94, stage2_unit4_bn3_gamma, stage2_unit4_bn3_beta, stage2_unit4_bn3_moving_mean, stage2_unit4_bn3_moving_var))
            lv96: Tensor[(1, 128, 28, 28), "float32"] = lv95[0]
            lv97: Tensor[(1, 128, 28, 28), "float32"] = relax.call_tir((1, 128, 28, 28), relu3, (lv96,))
            lv98: Tensor[(1, 512, 28, 28), "float32"] = relax.call_tir((1, 512, 28, 28), conv2d_nchw7, (lv97, stage2_unit4_conv3_weight))
            lv99: Tensor[(1, 512, 28, 28), "float32"] = relax.call_tir((1, 512, 28, 28), add1, (lv98, lv86))
            lv100: Tuple[Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"]] = relax.call_tir(((1, 512, 28, 28), (512,), (512,)), batch_norm5, (lv99, stage3_unit1_bn1_gamma, stage3_unit1_bn1_beta, stage3_unit1_bn1_moving_mean, stage3_unit1_bn1_moving_var))
            lv101: Tensor[(1, 512, 28, 28), "float32"] = lv100[0]
            lv102: Tensor[(1, 512, 28, 28), "float32"] = relax.call_tir((1, 512, 28, 28), relu4, (lv101,))
            lv103: Tensor[(1, 256, 14, 14), "float32"] = relax.call_tir((1, 256, 14, 14), conv2d_nchw10, (lv102, stage3_unit1_conv1_weight))
            lv104: Tuple[Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"]] = relax.call_tir(((1, 256, 14, 14), (256,), (256,)), batch_norm6, (lv103, stage3_unit1_bn2_gamma, stage3_unit1_bn2_beta, stage3_unit1_bn2_moving_mean, stage3_unit1_bn2_moving_var))
            lv105: Tensor[(1, 256, 14, 14), "float32"] = lv104[0]
            lv106: Tensor[(1, 256, 14, 14), "float32"] = relax.call_tir((1, 256, 14, 14), relu5, (lv105,))
            lv107: Tensor[(1, 256, 14, 14), "float32"] = relax.call_tir((1, 256, 14, 14), conv2d_nchw11, (lv106, stage3_unit1_conv2_weight))
            lv108: Tuple[Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"]] = relax.call_tir(((1, 256, 14, 14), (256,), (256,)), batch_norm6, (lv107, stage3_unit1_bn3_gamma, stage3_unit1_bn3_beta, stage3_unit1_bn3_moving_mean, stage3_unit1_bn3_moving_var))
            lv109: Tensor[(1, 256, 14, 14), "float32"] = lv108[0]
            lv110: Tensor[(1, 256, 14, 14), "float32"] = relax.call_tir((1, 256, 14, 14), relu5, (lv109,))
            lv111: Tensor[(1, 1024, 14, 14), "float32"] = relax.call_tir((1, 1024, 14, 14), conv2d_nchw12, (lv110, stage3_unit1_conv3_weight))
            lv112: Tensor[(1, 1024, 14, 14), "float32"] = relax.call_tir((1, 1024, 14, 14), conv2d_nchw13, (lv102, stage3_unit1_sc_weight))
            lv113: Tensor[(1, 1024, 14, 14), "float32"] = relax.call_tir((1, 1024, 14, 14), add2, (lv111, lv112))
            lv114: Tuple[Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"]] = relax.call_tir(((1, 1024, 14, 14), (1024,), (1024,)), batch_norm7, (lv113, stage3_unit2_bn1_gamma, stage3_unit2_bn1_beta, stage3_unit2_bn1_moving_mean, stage3_unit2_bn1_moving_var))
            lv115: Tensor[(1, 1024, 14, 14), "float32"] = lv114[0]
            lv116: Tensor[(1, 1024, 14, 14), "float32"] = relax.call_tir((1, 1024, 14, 14), relu6, (lv115,))
            lv117: Tensor[(1, 256, 14, 14), "float32"] = relax.call_tir((1, 256, 14, 14), conv2d_nchw14, (lv116, stage3_unit2_conv1_weight))
            lv118: Tuple[Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"]] = relax.call_tir(((1, 256, 14, 14), (256,), (256,)), batch_norm6, (lv117, stage3_unit2_bn2_gamma, stage3_unit2_bn2_beta, stage3_unit2_bn2_moving_mean, stage3_unit2_bn2_moving_var))
            lv119: Tensor[(1, 256, 14, 14), "float32"] = lv118[0]
            lv120: Tensor[(1, 256, 14, 14), "float32"] = relax.call_tir((1, 256, 14, 14), relu5, (lv119,))
            lv121: Tensor[(1, 256, 14, 14), "float32"] = relax.call_tir((1, 256, 14, 14), conv2d_nchw11, (lv120, stage3_unit2_conv2_weight))
            lv122: Tuple[Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"]] = relax.call_tir(((1, 256, 14, 14), (256,), (256,)), batch_norm6, (lv121, stage3_unit2_bn3_gamma, stage3_unit2_bn3_beta, stage3_unit2_bn3_moving_mean, stage3_unit2_bn3_moving_var))
            lv123: Tensor[(1, 256, 14, 14), "float32"] = lv122[0]
            lv124: Tensor[(1, 256, 14, 14), "float32"] = relax.call_tir((1, 256, 14, 14), relu5, (lv123,))
            lv125: Tensor[(1, 1024, 14, 14), "float32"] = relax.call_tir((1, 1024, 14, 14), conv2d_nchw12, (lv124, stage3_unit2_conv3_weight))
            lv126: Tensor[(1, 1024, 14, 14), "float32"] = relax.call_tir((1, 1024, 14, 14), add2, (lv125, lv113))
            lv127: Tuple[Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"]] = relax.call_tir(((1, 1024, 14, 14), (1024,), (1024,)), batch_norm7, (lv126, stage3_unit3_bn1_gamma, stage3_unit3_bn1_beta, stage3_unit3_bn1_moving_mean, stage3_unit3_bn1_moving_var))
            lv128: Tensor[(1, 1024, 14, 14), "float32"] = lv127[0]
            lv129: Tensor[(1, 1024, 14, 14), "float32"] = relax.call_tir((1, 1024, 14, 14), relu6, (lv128,))
            lv130: Tensor[(1, 256, 14, 14), "float32"] = relax.call_tir((1, 256, 14, 14), conv2d_nchw14, (lv129, stage3_unit3_conv1_weight))
            lv131: Tuple[Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"]] = relax.call_tir(((1, 256, 14, 14), (256,), (256,)), batch_norm6, (lv130, stage3_unit3_bn2_gamma, stage3_unit3_bn2_beta, stage3_unit3_bn2_moving_mean, stage3_unit3_bn2_moving_var))
            lv132: Tensor[(1, 256, 14, 14), "float32"] = lv131[0]
            lv133: Tensor[(1, 256, 14, 14), "float32"] = relax.call_tir((1, 256, 14, 14), relu5, (lv132,))
            lv134: Tensor[(1, 256, 14, 14), "float32"] = relax.call_tir((1, 256, 14, 14), conv2d_nchw11, (lv133, stage3_unit3_conv2_weight))
            lv135: Tuple[Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"]] = relax.call_tir(((1, 256, 14, 14), (256,), (256,)), batch_norm6, (lv134, stage3_unit3_bn3_gamma, stage3_unit3_bn3_beta, stage3_unit3_bn3_moving_mean, stage3_unit3_bn3_moving_var))
            lv136: Tensor[(1, 256, 14, 14), "float32"] = lv135[0]
            lv137: Tensor[(1, 256, 14, 14), "float32"] = relax.call_tir((1, 256, 14, 14), relu5, (lv136,))
            lv138: Tensor[(1, 1024, 14, 14), "float32"] = relax.call_tir((1, 1024, 14, 14), conv2d_nchw12, (lv137, stage3_unit3_conv3_weight))
            lv139: Tensor[(1, 1024, 14, 14), "float32"] = relax.call_tir((1, 1024, 14, 14), add2, (lv138, lv126))
            lv140: Tuple[Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"]] = relax.call_tir(((1, 1024, 14, 14), (1024,), (1024,)), batch_norm7, (lv139, stage3_unit4_bn1_gamma, stage3_unit4_bn1_beta, stage3_unit4_bn1_moving_mean, stage3_unit4_bn1_moving_var))
            lv141: Tensor[(1, 1024, 14, 14), "float32"] = lv140[0]
            lv142: Tensor[(1, 1024, 14, 14), "float32"] = relax.call_tir((1, 1024, 14, 14), relu6, (lv141,))
            lv143: Tensor[(1, 256, 14, 14), "float32"] = relax.call_tir((1, 256, 14, 14), conv2d_nchw14, (lv142, stage3_unit4_conv1_weight))
            lv144: Tuple[Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"]] = relax.call_tir(((1, 256, 14, 14), (256,), (256,)), batch_norm6, (lv143, stage3_unit4_bn2_gamma, stage3_unit4_bn2_beta, stage3_unit4_bn2_moving_mean, stage3_unit4_bn2_moving_var))
            lv145: Tensor[(1, 256, 14, 14), "float32"] = lv144[0]
            lv146: Tensor[(1, 256, 14, 14), "float32"] = relax.call_tir((1, 256, 14, 14), relu5, (lv145,))
            lv147: Tensor[(1, 256, 14, 14), "float32"] = relax.call_tir((1, 256, 14, 14), conv2d_nchw11, (lv146, stage3_unit4_conv2_weight))
            lv148: Tuple[Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"]] = relax.call_tir(((1, 256, 14, 14), (256,), (256,)), batch_norm6, (lv147, stage3_unit4_bn3_gamma, stage3_unit4_bn3_beta, stage3_unit4_bn3_moving_mean, stage3_unit4_bn3_moving_var))
            lv149: Tensor[(1, 256, 14, 14), "float32"] = lv148[0]
            lv150: Tensor[(1, 256, 14, 14), "float32"] = relax.call_tir((1, 256, 14, 14), relu5, (lv149,))
            lv151: Tensor[(1, 1024, 14, 14), "float32"] = relax.call_tir((1, 1024, 14, 14), conv2d_nchw12, (lv150, stage3_unit4_conv3_weight))
            lv152: Tensor[(1, 1024, 14, 14), "float32"] = relax.call_tir((1, 1024, 14, 14), add2, (lv151, lv139))
            lv153: Tuple[Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"]] = relax.call_tir(((1, 1024, 14, 14), (1024,), (1024,)), batch_norm7, (lv152, stage3_unit5_bn1_gamma, stage3_unit5_bn1_beta, stage3_unit5_bn1_moving_mean, stage3_unit5_bn1_moving_var))
            lv154: Tensor[(1, 1024, 14, 14), "float32"] = lv153[0]
            lv155: Tensor[(1, 1024, 14, 14), "float32"] = relax.call_tir((1, 1024, 14, 14), relu6, (lv154,))
            lv156: Tensor[(1, 256, 14, 14), "float32"] = relax.call_tir((1, 256, 14, 14), conv2d_nchw14, (lv155, stage3_unit5_conv1_weight))
            lv157: Tuple[Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"]] = relax.call_tir(((1, 256, 14, 14), (256,), (256,)), batch_norm6, (lv156, stage3_unit5_bn2_gamma, stage3_unit5_bn2_beta, stage3_unit5_bn2_moving_mean, stage3_unit5_bn2_moving_var))
            lv158: Tensor[(1, 256, 14, 14), "float32"] = lv157[0]
            lv159: Tensor[(1, 256, 14, 14), "float32"] = relax.call_tir((1, 256, 14, 14), relu5, (lv158,))
            lv160: Tensor[(1, 256, 14, 14), "float32"] = relax.call_tir((1, 256, 14, 14), conv2d_nchw11, (lv159, stage3_unit5_conv2_weight))
            lv161: Tuple[Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"]] = relax.call_tir(((1, 256, 14, 14), (256,), (256,)), batch_norm6, (lv160, stage3_unit5_bn3_gamma, stage3_unit5_bn3_beta, stage3_unit5_bn3_moving_mean, stage3_unit5_bn3_moving_var))
            lv162: Tensor[(1, 256, 14, 14), "float32"] = lv161[0]
            lv163: Tensor[(1, 256, 14, 14), "float32"] = relax.call_tir((1, 256, 14, 14), relu5, (lv162,))
            lv164: Tensor[(1, 1024, 14, 14), "float32"] = relax.call_tir((1, 1024, 14, 14), conv2d_nchw12, (lv163, stage3_unit5_conv3_weight))
            lv165: Tensor[(1, 1024, 14, 14), "float32"] = relax.call_tir((1, 1024, 14, 14), add2, (lv164, lv152))
            lv166: Tuple[Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"]] = relax.call_tir(((1, 1024, 14, 14), (1024,), (1024,)), batch_norm7, (lv165, stage3_unit6_bn1_gamma, stage3_unit6_bn1_beta, stage3_unit6_bn1_moving_mean, stage3_unit6_bn1_moving_var))
            lv167: Tensor[(1, 1024, 14, 14), "float32"] = lv166[0]
            lv168: Tensor[(1, 1024, 14, 14), "float32"] = relax.call_tir((1, 1024, 14, 14), relu6, (lv167,))
            lv169: Tensor[(1, 256, 14, 14), "float32"] = relax.call_tir((1, 256, 14, 14), conv2d_nchw14, (lv168, stage3_unit6_conv1_weight))
            lv170: Tuple[Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"]] = relax.call_tir(((1, 256, 14, 14), (256,), (256,)), batch_norm6, (lv169, stage3_unit6_bn2_gamma, stage3_unit6_bn2_beta, stage3_unit6_bn2_moving_mean, stage3_unit6_bn2_moving_var))
            lv171: Tensor[(1, 256, 14, 14), "float32"] = lv170[0]
            lv172: Tensor[(1, 256, 14, 14), "float32"] = relax.call_tir((1, 256, 14, 14), relu5, (lv171,))
            lv173: Tensor[(1, 256, 14, 14), "float32"] = relax.call_tir((1, 256, 14, 14), conv2d_nchw11, (lv172, stage3_unit6_conv2_weight))
            lv174: Tuple[Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"]] = relax.call_tir(((1, 256, 14, 14), (256,), (256,)), batch_norm6, (lv173, stage3_unit6_bn3_gamma, stage3_unit6_bn3_beta, stage3_unit6_bn3_moving_mean, stage3_unit6_bn3_moving_var))
            lv175: Tensor[(1, 256, 14, 14), "float32"] = lv174[0]
            lv176: Tensor[(1, 256, 14, 14), "float32"] = relax.call_tir((1, 256, 14, 14), relu5, (lv175,))
            lv177: Tensor[(1, 1024, 14, 14), "float32"] = relax.call_tir((1, 1024, 14, 14), conv2d_nchw12, (lv176, stage3_unit6_conv3_weight))
            lv178: Tensor[(1, 1024, 14, 14), "float32"] = relax.call_tir((1, 1024, 14, 14), add2, (lv177, lv165))
            lv179: Tuple[Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"]] = relax.call_tir(((1, 1024, 14, 14), (1024,), (1024,)), batch_norm7, (lv178, stage4_unit1_bn1_gamma, stage4_unit1_bn1_beta, stage4_unit1_bn1_moving_mean, stage4_unit1_bn1_moving_var))
            lv180: Tensor[(1, 1024, 14, 14), "float32"] = lv179[0]
            lv181: Tensor[(1, 1024, 14, 14), "float32"] = relax.call_tir((1, 1024, 14, 14), relu6, (lv180,))
            lv182: Tensor[(1, 512, 7, 7), "float32"] = relax.call_tir((1, 512, 7, 7), conv2d_nchw15, (lv181, stage4_unit1_conv1_weight))
            lv183: Tuple[Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"]] = relax.call_tir(((1, 512, 7, 7), (512,), (512,)), batch_norm8, (lv182, stage4_unit1_bn2_gamma, stage4_unit1_bn2_beta, stage4_unit1_bn2_moving_mean, stage4_unit1_bn2_moving_var))
            lv184: Tensor[(1, 512, 7, 7), "float32"] = lv183[0]
            lv185: Tensor[(1, 512, 7, 7), "float32"] = relax.call_tir((1, 512, 7, 7), relu7, (lv184,))
            lv186: Tensor[(1, 512, 7, 7), "float32"] = relax.call_tir((1, 512, 7, 7), conv2d_nchw16, (lv185, stage4_unit1_conv2_weight))
            lv187: Tuple[Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"]] = relax.call_tir(((1, 512, 7, 7), (512,), (512,)), batch_norm8, (lv186, stage4_unit1_bn3_gamma, stage4_unit1_bn3_beta, stage4_unit1_bn3_moving_mean, stage4_unit1_bn3_moving_var))
            lv188: Tensor[(1, 512, 7, 7), "float32"] = lv187[0]
            lv189: Tensor[(1, 512, 7, 7), "float32"] = relax.call_tir((1, 512, 7, 7), relu7, (lv188,))
            lv190: Tensor[(1, 2048, 7, 7), "float32"] = relax.call_tir((1, 2048, 7, 7), conv2d_nchw17, (lv189, stage4_unit1_conv3_weight))
            lv191: Tensor[(1, 2048, 7, 7), "float32"] = relax.call_tir((1, 2048, 7, 7), conv2d_nchw18, (lv181, stage4_unit1_sc_weight))
            lv192: Tensor[(1, 2048, 7, 7), "float32"] = relax.call_tir((1, 2048, 7, 7), add3, (lv190, lv191))
            lv193: Tuple[Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"]] = relax.call_tir(((1, 2048, 7, 7), (2048,), (2048,)), batch_norm9, (lv192, stage4_unit2_bn1_gamma, stage4_unit2_bn1_beta, stage4_unit2_bn1_moving_mean, stage4_unit2_bn1_moving_var))
            lv194: Tensor[(1, 2048, 7, 7), "float32"] = lv193[0]
            lv195: Tensor[(1, 2048, 7, 7), "float32"] = relax.call_tir((1, 2048, 7, 7), relu8, (lv194,))
            lv196: Tensor[(1, 512, 7, 7), "float32"] = relax.call_tir((1, 512, 7, 7), conv2d_nchw19, (lv195, stage4_unit2_conv1_weight))
            lv197: Tuple[Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"]] = relax.call_tir(((1, 512, 7, 7), (512,), (512,)), batch_norm8, (lv196, stage4_unit2_bn2_gamma, stage4_unit2_bn2_beta, stage4_unit2_bn2_moving_mean, stage4_unit2_bn2_moving_var))
            lv198: Tensor[(1, 512, 7, 7), "float32"] = lv197[0]
            lv199: Tensor[(1, 512, 7, 7), "float32"] = relax.call_tir((1, 512, 7, 7), relu7, (lv198,))
            lv200: Tensor[(1, 512, 7, 7), "float32"] = relax.call_tir((1, 512, 7, 7), conv2d_nchw16, (lv199, stage4_unit2_conv2_weight))
            lv201: Tuple[Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"]] = relax.call_tir(((1, 512, 7, 7), (512,), (512,)), batch_norm8, (lv200, stage4_unit2_bn3_gamma, stage4_unit2_bn3_beta, stage4_unit2_bn3_moving_mean, stage4_unit2_bn3_moving_var))
            lv202: Tensor[(1, 512, 7, 7), "float32"] = lv201[0]
            lv203: Tensor[(1, 512, 7, 7), "float32"] = relax.call_tir((1, 512, 7, 7), relu7, (lv202,))
            lv204: Tensor[(1, 2048, 7, 7), "float32"] = relax.call_tir((1, 2048, 7, 7), conv2d_nchw17, (lv203, stage4_unit2_conv3_weight))
            lv205: Tensor[(1, 2048, 7, 7), "float32"] = relax.call_tir((1, 2048, 7, 7), add3, (lv204, lv192))
            lv206: Tuple[Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"]] = relax.call_tir(((1, 2048, 7, 7), (2048,), (2048,)), batch_norm9, (lv205, stage4_unit3_bn1_gamma, stage4_unit3_bn1_beta, stage4_unit3_bn1_moving_mean, stage4_unit3_bn1_moving_var))
            lv207: Tensor[(1, 2048, 7, 7), "float32"] = lv206[0]
            lv208: Tensor[(1, 2048, 7, 7), "float32"] = relax.call_tir((1, 2048, 7, 7), relu8, (lv207,))
            lv209: Tensor[(1, 512, 7, 7), "float32"] = relax.call_tir((1, 512, 7, 7), conv2d_nchw19, (lv208, stage4_unit3_conv1_weight))
            lv210: Tuple[Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"]] = relax.call_tir(((1, 512, 7, 7), (512,), (512,)), batch_norm8, (lv209, stage4_unit3_bn2_gamma, stage4_unit3_bn2_beta, stage4_unit3_bn2_moving_mean, stage4_unit3_bn2_moving_var))
            lv211: Tensor[(1, 512, 7, 7), "float32"] = lv210[0]
            lv212: Tensor[(1, 512, 7, 7), "float32"] = relax.call_tir((1, 512, 7, 7), relu7, (lv211,))
            lv213: Tensor[(1, 512, 7, 7), "float32"] = relax.call_tir((1, 512, 7, 7), conv2d_nchw16, (lv212, stage4_unit3_conv2_weight))
            lv214: Tuple[Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"]] = relax.call_tir(((1, 512, 7, 7), (512,), (512,)), batch_norm8, (lv213, stage4_unit3_bn3_gamma, stage4_unit3_bn3_beta, stage4_unit3_bn3_moving_mean, stage4_unit3_bn3_moving_var))
            lv215: Tensor[(1, 512, 7, 7), "float32"] = lv214[0]
            lv216: Tensor[(1, 512, 7, 7), "float32"] = relax.call_tir((1, 512, 7, 7), relu7, (lv215,))
            lv217: Tensor[(1, 2048, 7, 7), "float32"] = relax.call_tir((1, 2048, 7, 7), conv2d_nchw17, (lv216, stage4_unit3_conv3_weight))
            lv218: Tensor[(1, 2048, 7, 7), "float32"] = relax.call_tir((1, 2048, 7, 7), add3, (lv217, lv205))
            lv219: Tuple[Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"]] = relax.call_tir(((1, 2048, 7, 7), (2048,), (2048,)), batch_norm9, (lv218, bn1_gamma, bn1_beta, bn1_moving_mean, bn1_moving_var))
            lv220: Tensor[(1, 2048, 7, 7), "float32"] = lv219[0]
            lv221: Tensor[(1, 2048, 7, 7), "float32"] = relax.call_tir((1, 2048, 7, 7), relu8, (lv220,))
            lv222: Tensor[(1, 2048, 1, 1), "float32"] = relax.call_tir((1, 2048, 1, 1), global_avg_pool2d, (lv221,))
            lv223: Tensor[(1, 2048), "float32"] = relax.call_tir((1, 2048), batch_flatten, (lv222,))
            lv224: Tensor[(1, 1000), "float32"] = relax.call_tir((1, 1000), dense, (lv223, fc1_weight))
            lv225: Tensor[(1, 1000), "float32"] = relax.call_tir((1, 1000), bias_add, (lv224, fc1_bias))
            gv0: Tensor[(1, 1000), "float32"] = relax.call_tir((1, 1000), softmax, (lv225,))
            relax.output(gv0)

        return gv0

    @tir.prim_func
    def conv2d_nchw13(rxplaceholder_2: tir.Buffer[(1, 512, 28, 28), "float32"], rxplaceholder_3: tir.Buffer[(1024, 512, 1, 1), "float32"], conv2d_nchw_1: tir.Buffer[(1, 1024, 14, 14), "float32"]) -> None:
        # 1x1 convolution, NCHW layout: (1, 512, 28, 28) input x (1024, 512, 1, 1)
        # weight -> (1, 1024, 14, 14) output. The output spatial indices are scaled
        # by 2 in the input reads (i2_6 * 2 + r0, i3_6 * 2 + r1), i.e. stride 2.
        # Used as the stage3_unit1 shortcut ("sc") projection in this module.
        # function attr dict
        tir.func_attr({"global_symbol": "conv2d_nchw13", "tir.noalias": True})
        # body
        # with tir.block("root")
        # A 1x1 kernel needs no halo, so the "pad" stage is an identity copy.
        pad_temp_1 = tir.alloc_buffer([1, 512, 28, 28], dtype="float32")
        for i0_3, i1_3, i2_3, i3_3 in tir.grid(1, 512, 28, 28):
            with tir.block("pad_temp"):
                i0_4, i1_4, i2_4, i3_4 = tir.axis.remap("SSSS", [i0_3, i1_3, i2_3, i3_3])
                tir.reads(rxplaceholder_2[i0_4, i1_4, i2_4, i3_4])
                tir.writes(pad_temp_1[i0_4, i1_4, i2_4, i3_4])
                pad_temp_1[i0_4, i1_4, i2_4, i3_4] = rxplaceholder_2[i0_4, i1_4, i2_4, i3_4]
        # Reduction over rc (512 input channels); r0/r1 are the trivial 1x1 kernel axes.
        for i0_5, i1_5, i2_5, i3_5, i4, i5, i6 in tir.grid(1, 1024, 14, 14, 512, 1, 1):
            with tir.block("conv2d_nchw"):
                i0_6, i1_6, i2_6, i3_6, rc, r0, r1 = tir.axis.remap("SSSSRRR", [i0_5, i1_5, i2_5, i3_5, i4, i5, i6])
                # i1_6 < 1024, so i1_6 // 1024 * 512 + rc reduces to just rc.
                tir.reads(conv2d_nchw_1[i0_6, i1_6, i2_6, i3_6], pad_temp_1[i0_6, i1_6 // 1024 * 512 + rc, i2_6 * 2 + r0, i3_6 * 2 + r1], rxplaceholder_3[i1_6, rc, r0, r1])
                tir.writes(conv2d_nchw_1[i0_6, i1_6, i2_6, i3_6])
                # Accumulator starts at zero on the first reduction iteration.
                with tir.init():
                    conv2d_nchw_1[i0_6, i1_6, i2_6, i3_6] = tir.float32(0)
                conv2d_nchw_1[i0_6, i1_6, i2_6, i3_6] = conv2d_nchw_1[i0_6, i1_6, i2_6, i3_6] + pad_temp_1[i0_6, i1_6 // 1024 * 512 + rc, i2_6 * 2 + r0, i3_6 * 2 + r1] * rxplaceholder_3[i1_6, rc, r0, r1]

    @tir.prim_func
    def conv2d_nchw16(rxplaceholder_2: tir.Buffer[(1, 512, 7, 7), "float32"], rxplaceholder_3: tir.Buffer[(512, 512, 3, 3), "float32"], conv2d_nchw_1: tir.Buffer[(1, 512, 7, 7), "float32"]) -> None:
        # 3x3 convolution, NCHW layout: (1, 512, 7, 7) -> (1, 512, 7, 7).
        # Input reads use i2_6 + r0 / i3_6 + r1 (stride 1) against a 9x9 padded
        # buffer, i.e. zero padding of 1 on each spatial edge.
        # function attr dict
        tir.func_attr({"global_symbol": "conv2d_nchw16", "tir.noalias": True})
        # body
        # with tir.block("root")
        # Padded copy of the input: interior (1 <= i < 8) holds data, border is 0.
        pad_temp_1 = tir.alloc_buffer([1, 512, 9, 9], dtype="float32")
        for i0_3, i1_3, i2_3, i3_3 in tir.grid(1, 512, 9, 9):
            with tir.block("pad_temp"):
                i0_4, i1_4, i2_4, i3_4 = tir.axis.remap("SSSS", [i0_3, i1_3, i2_3, i3_3])
                tir.reads(rxplaceholder_2[i0_4, i1_4, tir.max(i2_4 - 1, 0) : tir.max(i2_4 - 1, 0) + (tir.min(i2_4, 7) - tir.max(i2_4 - 1, 0)), tir.max(i3_4 - 1, 0) : tir.max(i3_4 - 1, 0) + (tir.min(i3_4, 7) - tir.max(i3_4 - 1, 0))])
                tir.writes(pad_temp_1[i0_4, i1_4, i2_4, i3_4])
                pad_temp_1[i0_4, i1_4, i2_4, i3_4] = tir.if_then_else(i2_4 >= 1 and i2_4 < 8 and i3_4 >= 1 and i3_4 < 8, rxplaceholder_2[i0_4, i1_4, i2_4 - 1, i3_4 - 1], tir.float32(0), dtype="float32")
        # Reduction over rc (512 input channels) and the 3x3 kernel window (r0, r1).
        for i0_5, i1_5, i2_5, i3_5, i4, i5, i6 in tir.grid(1, 512, 7, 7, 512, 3, 3):
            with tir.block("conv2d_nchw"):
                i0_6, i1_6, i2_6, i3_6, rc, r0, r1 = tir.axis.remap("SSSSRRR", [i0_5, i1_5, i2_5, i3_5, i4, i5, i6])
                # i1_6 < 512, so i1_6 // 512 * 512 + rc reduces to just rc.
                tir.reads(conv2d_nchw_1[i0_6, i1_6, i2_6, i3_6], pad_temp_1[i0_6, i1_6 // 512 * 512 + rc, i2_6 + r0, i3_6 + r1], rxplaceholder_3[i1_6, rc, r0, r1])
                tir.writes(conv2d_nchw_1[i0_6, i1_6, i2_6, i3_6])
                # Accumulator starts at zero on the first reduction iteration.
                with tir.init():
                    conv2d_nchw_1[i0_6, i1_6, i2_6, i3_6] = tir.float32(0)
                conv2d_nchw_1[i0_6, i1_6, i2_6, i3_6] = conv2d_nchw_1[i0_6, i1_6, i2_6, i3_6] + pad_temp_1[i0_6, i1_6 // 512 * 512 + rc, i2_6 + r0, i3_6 + r1] * rxplaceholder_3[i1_6, rc, r0, r1]

    @tir.prim_func
    def bias_add(rxplaceholder_2: tir.Buffer[(1, 1000), "float32"], rxplaceholder_3: tir.Buffer[(1000,), "float32"], T_add_1: tir.Buffer[(1, 1000), "float32"]) -> None:
        # Adds a per-class bias vector (1000,) to the (1, 1000) logits, broadcasting
        # the bias over the batch dimension. Feeds the final softmax in this module.
        # function attr dict
        tir.func_attr({"global_symbol": "bias_add", "tir.noalias": True})
        # body
        # with tir.block("root")
        for i0, i1 in tir.grid(1, 1000):
            with tir.block("T_add"):
                ax0, ax1 = tir.axis.remap("SS", [i0, i1])
                tir.reads(rxplaceholder_2[ax0, ax1], rxplaceholder_3[ax1])
                tir.writes(T_add_1[ax0, ax1])
                T_add_1[ax0, ax1] = rxplaceholder_2[ax0, ax1] + rxplaceholder_3[ax1]

    @tir.prim_func
    def batch_norm(rxplaceholder_5: tir.Buffer[(1, 3, 224, 224), "float32"], rxplaceholder_6: tir.Buffer[(3,), "float32"], rxplaceholder_7: tir.Buffer[(3,), "float32"], rxplaceholder_8: tir.Buffer[(3,), "float32"], rxplaceholder_9: tir.Buffer[(3,), "float32"], T_add_2: tir.Buffer[(1, 3, 224, 224), "float32"], T_multiply_2: tir.Buffer[(3,), "float32"], T_multiply_3: tir.Buffer[(3,), "float32"]) -> None:
        # Inference-mode batch norm over the channel axis of a (1, 3, 224, 224) input.
        # As the dataflow below shows: rxplaceholder_8 is subtracted (mean),
        # rxplaceholder_9 gets eps added then sqrt (variance), and rxplaceholder_7
        # is added at the end (beta / shift). This matches the (data, gamma, beta,
        # moving_mean, moving_var) argument order of the relax call sites in this
        # module. Computes:
        #     T_add_2 = (x - mean) / sqrt(var + eps) + beta
        # NOTE(review): rxplaceholder_6 (gamma) is never read — presumably the
        # gamma multiply was folded away as all-ones upstream; confirm against the
        # producer of this IR.
        # The two vector outputs are plain copies: T_multiply_2 = moving_mean,
        # T_multiply_3 = moving_var (the "updated stats" outputs of batch_norm,
        # which are pass-throughs at inference).
        # function attr dict
        tir.func_attr({"global_symbol": "batch_norm", "tir.noalias": True})
        # body
        # with tir.block("root")
        T_reshape_3 = tir.alloc_buffer([1, 3, 1, 1], dtype="float32")
        T_subtract_1 = tir.alloc_buffer([1, 3, 224, 224], dtype="float32")
        T_reshape_4 = tir.alloc_buffer([1, 3, 1, 1], dtype="float32")
        T_add_3 = tir.alloc_buffer([1, 3, 1, 1], dtype="float32")
        compute_1 = tir.alloc_buffer([1, 3, 1, 1], dtype="float32")
        T_divide_1 = tir.alloc_buffer([1, 3, 224, 224], dtype="float32")
        T_reshape_5 = tir.alloc_buffer([1, 3, 1, 1], dtype="float32")
        # Reshape mean (3,) -> (1, 3, 1, 1) so it broadcasts over N, H, W.
        for i0, i1, i2, i3 in tir.grid(1, 3, 1, 1):
            with tir.block("T_reshape"):
                ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
                tir.reads(rxplaceholder_8[(ax1 + ax2 + ax3) % 3])
                tir.writes(T_reshape_3[ax0, ax1, ax2, ax3])
                T_reshape_3[ax0, ax1, ax2, ax3] = rxplaceholder_8[(ax0 * 3 + ax1 + ax2 + ax3) % 3]
        # x - mean (per channel).
        for i0, i1, i2, i3 in tir.grid(1, 3, 224, 224):
            with tir.block("T_subtract"):
                ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
                tir.reads(rxplaceholder_5[ax0, ax1, ax2, ax3], T_reshape_3[ax0, ax1, 0, 0])
                tir.writes(T_subtract_1[ax0, ax1, ax2, ax3])
                T_subtract_1[ax0, ax1, ax2, ax3] = rxplaceholder_5[ax0, ax1, ax2, ax3] - T_reshape_3[ax0, ax1, 0, 0]
        # Reshape variance (3,) -> (1, 3, 1, 1).
        for i0, i1, i2, i3 in tir.grid(1, 3, 1, 1):
            with tir.block("T_reshape_1"):
                ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
                tir.reads(rxplaceholder_9[(ax1 + ax2 + ax3) % 3])
                tir.writes(T_reshape_4[ax0, ax1, ax2, ax3])
                T_reshape_4[ax0, ax1, ax2, ax3] = rxplaceholder_9[(ax0 * 3 + ax1 + ax2 + ax3) % 3]
        # var + eps (eps ~= 2e-5, stored as the nearest float32).
        for i0, i1, i2, i3 in tir.grid(1, 3, 1, 1):
            with tir.block("T_add"):
                ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
                tir.reads(T_reshape_4[ax0, ax1, ax2, ax3])
                tir.writes(T_add_3[ax0, ax1, ax2, ax3])
                T_add_3[ax0, ax1, ax2, ax3] = T_reshape_4[ax0, ax1, ax2, ax3] + tir.float32(1.9999999494757503e-05)
        # sqrt(var + eps).
        for i0, i1, i2, i3 in tir.grid(1, 3, 1, 1):
            with tir.block("compute"):
                i0_2, i1_2, i2_2, i3_2 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
                tir.reads(T_add_3[i0_2, i1_2, i2_2, i3_2])
                tir.writes(compute_1[i0_2, i1_2, i2_2, i3_2])
                compute_1[i0_2, i1_2, i2_2, i3_2] = tir.sqrt(T_add_3[i0_2, i1_2, i2_2, i3_2], dtype="float32")
        # (x - mean) / sqrt(var + eps).
        for i0_3, i1_3, i2_3, i3_3 in tir.grid(1, 3, 224, 224):
            with tir.block("T_divide"):
                ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0_3, i1_3, i2_3, i3_3])
                tir.reads(T_subtract_1[ax0, ax1, ax2, ax3], compute_1[ax0, ax1, 0, 0])
                tir.writes(T_divide_1[ax0, ax1, ax2, ax3])
                T_divide_1[ax0, ax1, ax2, ax3] = T_subtract_1[ax0, ax1, ax2, ax3] / compute_1[ax0, ax1, 0, 0]
        # Reshape beta (3,) -> (1, 3, 1, 1).
        for i0_4, i1_4, i2_4, i3_4 in tir.grid(1, 3, 1, 1):
            with tir.block("T_reshape_2"):
                ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0_4, i1_4, i2_4, i3_4])
                tir.reads(rxplaceholder_7[(ax1 + ax2 + ax3) % 3])
                tir.writes(T_reshape_5[ax0, ax1, ax2, ax3])
                T_reshape_5[ax0, ax1, ax2, ax3] = rxplaceholder_7[(ax0 * 3 + ax1 + ax2 + ax3) % 3]
        # Final normalized output: normalized + beta.
        for i0_5, i1_5, i2_5, i3_5 in tir.grid(1, 3, 224, 224):
            with tir.block("T_add_1"):
                ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0_5, i1_5, i2_5, i3_5])
                tir.reads(T_divide_1[ax0, ax1, ax2, ax3], T_reshape_5[ax0, ax1, 0, 0])
                tir.writes(T_add_2[ax0, ax1, ax2, ax3])
                T_add_2[ax0, ax1, ax2, ax3] = T_divide_1[ax0, ax1, ax2, ax3] + T_reshape_5[ax0, ax1, 0, 0]
        # Pass moving_mean through unchanged as the second output.
        for i0_6 in tir.serial(3):
            with tir.block("T_multiply"):
                ax0 = tir.axis.spatial(3, i0_6)
                tir.reads(rxplaceholder_8[ax0])
                tir.writes(T_multiply_2[ax0])
                T_multiply_2[ax0] = rxplaceholder_8[ax0]
        # Pass moving_var through unchanged as the third output.
        for i0_7 in tir.serial(3):
            with tir.block("T_multiply_1"):
                ax0 = tir.axis.spatial(3, i0_7)
                tir.reads(rxplaceholder_9[ax0])
                tir.writes(T_multiply_3[ax0])
                T_multiply_3[ax0] = rxplaceholder_9[ax0]

    @tir.prim_func
    def conv2d_nchw(rxplaceholder_2: tir.Buffer[(1, 3, 224, 224), "float32"], rxplaceholder_3: tir.Buffer[(64, 3, 7, 7), "float32"], conv2d_nchw_1: tir.Buffer[(1, 64, 112, 112), "float32"]) -> None:
        # ResNet stem convolution: 7x7 kernel, NCHW layout,
        # (1, 3, 224, 224) -> (1, 64, 112, 112). The padded buffer is 230x230
        # (224 + 2*3, i.e. padding 3) and input reads use i2_6 * 2 + r0 /
        # i3_6 * 2 + r1, i.e. stride 2.
        # function attr dict
        tir.func_attr({"global_symbol": "conv2d_nchw", "tir.noalias": True})
        # body
        # with tir.block("root")
        # Zero-padded copy of the input: interior (3 <= i < 227) holds data.
        pad_temp_1 = tir.alloc_buffer([1, 3, 230, 230], dtype="float32")
        for i0_3, i1_3, i2_3, i3_3 in tir.grid(1, 3, 230, 230):
            with tir.block("pad_temp"):
                i0_4, i1_4, i2_4, i3_4 = tir.axis.remap("SSSS", [i0_3, i1_3, i2_3, i3_3])
                tir.reads(rxplaceholder_2[i0_4, i1_4, tir.max(i2_4 - 3, 0) : tir.max(i2_4 - 3, 0) + (tir.min(i2_4 - 3, 223) + 1 - tir.max(i2_4 - 3, 0)), tir.max(i3_4 - 3, 0) : tir.max(i3_4 - 3, 0) + (tir.min(i3_4 - 3, 223) + 1 - tir.max(i3_4 - 3, 0))])
                tir.writes(pad_temp_1[i0_4, i1_4, i2_4, i3_4])
                pad_temp_1[i0_4, i1_4, i2_4, i3_4] = tir.if_then_else(i2_4 >= 3 and i2_4 < 227 and i3_4 >= 3 and i3_4 < 227, rxplaceholder_2[i0_4, i1_4, i2_4 - 3, i3_4 - 3], tir.float32(0), dtype="float32")
        # Reduction over rc (3 input channels) and the 7x7 kernel window (r0, r1).
        for i0_5, i1_5, i2_5, i3_5, i4, i5, i6 in tir.grid(1, 64, 112, 112, 3, 7, 7):
            with tir.block("conv2d_nchw"):
                i0_6, i1_6, i2_6, i3_6, rc, r0, r1 = tir.axis.remap("SSSSRRR", [i0_5, i1_5, i2_5, i3_5, i4, i5, i6])
                # i1_6 < 64, so i1_6 // 64 * 3 + rc reduces to just rc.
                tir.reads(conv2d_nchw_1[i0_6, i1_6, i2_6, i3_6], pad_temp_1[i0_6, i1_6 // 64 * 3 + rc, i2_6 * 2 + r0, i3_6 * 2 + r1], rxplaceholder_3[i1_6, rc, r0, r1])
                tir.writes(conv2d_nchw_1[i0_6, i1_6, i2_6, i3_6])
                # Accumulator starts at zero on the first reduction iteration.
                with tir.init():
                    conv2d_nchw_1[i0_6, i1_6, i2_6, i3_6] = tir.float32(0)
                conv2d_nchw_1[i0_6, i1_6, i2_6, i3_6] = conv2d_nchw_1[i0_6, i1_6, i2_6, i3_6] + pad_temp_1[i0_6, i1_6 // 64 * 3 + rc, i2_6 * 2 + r0, i3_6 * 2 + r1] * rxplaceholder_3[i1_6, rc, r0, r1]

    @tir.prim_func
    def conv2d_nchw7(rxplaceholder_2: tir.Buffer[(1, 128, 28, 28), "float32"], rxplaceholder_3: tir.Buffer[(512, 128, 1, 1), "float32"], conv2d_nchw_1: tir.Buffer[(1, 512, 28, 28), "float32"]) -> None:
        # 1x1 convolution, NCHW layout: (1, 128, 28, 28) -> (1, 512, 28, 28).
        # Input reads use i2_6 + r0 / i3_6 + r1 (stride 1); the bottleneck
        # channel-expansion conv of the stage2 units in this module.
        # function attr dict
        tir.func_attr({"global_symbol": "conv2d_nchw7", "tir.noalias": True})
        # body
        # with tir.block("root")
        # A 1x1 kernel needs no halo, so the "pad" stage is an identity copy.
        pad_temp_1 = tir.alloc_buffer([1, 128, 28, 28], dtype="float32")
        for i0_3, i1_3, i2_3, i3_3 in tir.grid(1, 128, 28, 28):
            with tir.block("pad_temp"):
                i0_4, i1_4, i2_4, i3_4 = tir.axis.remap("SSSS", [i0_3, i1_3, i2_3, i3_3])
                tir.reads(rxplaceholder_2[i0_4, i1_4, i2_4, i3_4])
                tir.writes(pad_temp_1[i0_4, i1_4, i2_4, i3_4])
                pad_temp_1[i0_4, i1_4, i2_4, i3_4] = rxplaceholder_2[i0_4, i1_4, i2_4, i3_4]
        # Reduction over rc (128 input channels); r0/r1 are the trivial 1x1 kernel axes.
        for i0_5, i1_5, i2_5, i3_5, i4, i5, i6 in tir.grid(1, 512, 28, 28, 128, 1, 1):
            with tir.block("conv2d_nchw"):
                i0_6, i1_6, i2_6, i3_6, rc, r0, r1 = tir.axis.remap("SSSSRRR", [i0_5, i1_5, i2_5, i3_5, i4, i5, i6])
                # i1_6 < 512, so i1_6 // 512 * 128 + rc reduces to just rc.
                tir.reads(conv2d_nchw_1[i0_6, i1_6, i2_6, i3_6], pad_temp_1[i0_6, i1_6 // 512 * 128 + rc, i2_6 + r0, i3_6 + r1], rxplaceholder_3[i1_6, rc, r0, r1])
                tir.writes(conv2d_nchw_1[i0_6, i1_6, i2_6, i3_6])
                # Accumulator starts at zero on the first reduction iteration.
                with tir.init():
                    conv2d_nchw_1[i0_6, i1_6, i2_6, i3_6] = tir.float32(0)
                conv2d_nchw_1[i0_6, i1_6, i2_6, i3_6] = conv2d_nchw_1[i0_6, i1_6, i2_6, i3_6] + pad_temp_1[i0_6, i1_6 // 512 * 128 + rc, i2_6 + r0, i3_6 + r1] * rxplaceholder_3[i1_6, rc, r0, r1]

    @tir.prim_func
    def add2(rxplaceholder_2: tir.Buffer[(1, 1024, 14, 14), "float32"], rxplaceholder_3: tir.Buffer[(1, 1024, 14, 14), "float32"], T_add_1: tir.Buffer[(1, 1024, 14, 14), "float32"]) -> None:
        # Elementwise add of two (1, 1024, 14, 14) tensors; used as the residual
        # (shortcut) addition in the stage3 units of this module.
        # function attr dict
        tir.func_attr({"global_symbol": "add2", "tir.noalias": True})
        # body
        # with tir.block("root")
        for i0, i1, i2, i3 in tir.grid(1, 1024, 14, 14):
            with tir.block("T_add"):
                ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
                tir.reads(rxplaceholder_2[ax0, ax1, ax2, ax3], rxplaceholder_3[ax0, ax1, ax2, ax3])
                tir.writes(T_add_1[ax0, ax1, ax2, ax3])
                T_add_1[ax0, ax1, ax2, ax3] = rxplaceholder_2[ax0, ax1, ax2, ax3] + rxplaceholder_3[ax0, ax1, ax2, ax3]

    @tir.prim_func
    def batch_norm1(rxplaceholder_5: tir.Buffer[(1, 64, 112, 112), "float32"], rxplaceholder_6: tir.Buffer[(64,), "float32"], rxplaceholder_7: tir.Buffer[(64,), "float32"], rxplaceholder_8: tir.Buffer[(64,), "float32"], rxplaceholder_9: tir.Buffer[(64,), "float32"], T_add_2: tir.Buffer[(1, 64, 112, 112), "float32"], T_multiply_3: tir.Buffer[(64,), "float32"], T_multiply_4: tir.Buffer[(64,), "float32"]) -> None:
        # Batch norm over a (1, 64, 112, 112) NCHW input x = rxplaceholder_5:
        #   T_add_2 = (x - reshape(_8)) / sqrt(reshape(_9) + eps) * reshape(_6) + reshape(_7)
        # with eps = 1.9999999494757503e-05. rxplaceholder_8 is subtracted and _9 feeds
        # sqrt(. + eps), so they are presumably the running mean / variance; _6 / _7 are
        # the scale / shift -- TODO confirm against the Relax caller.
        # _8 and _9 are also copied through unchanged to T_multiply_3 / T_multiply_4.
        # function attr dict
        tir.func_attr({"global_symbol": "batch_norm1", "tir.noalias": True})
        # body
        # with tir.block("root")
        T_reshape_4 = tir.alloc_buffer([1, 64, 1, 1], dtype="float32")
        T_subtract_1 = tir.alloc_buffer([1, 64, 112, 112], dtype="float32")
        T_reshape_5 = tir.alloc_buffer([1, 64, 1, 1], dtype="float32")
        T_add_3 = tir.alloc_buffer([1, 64, 1, 1], dtype="float32")
        compute_1 = tir.alloc_buffer([1, 64, 1, 1], dtype="float32")
        T_divide_1 = tir.alloc_buffer([1, 64, 112, 112], dtype="float32")
        T_reshape_6 = tir.alloc_buffer([1, 64, 1, 1], dtype="float32")
        T_multiply_5 = tir.alloc_buffer([1, 64, 112, 112], dtype="float32")
        T_reshape_7 = tir.alloc_buffer([1, 64, 1, 1], dtype="float32")
        # reshape rxplaceholder_8 (64,) -> (1, 64, 1, 1); the % 64 index is the
        # printer's flattened-reshape idiom and reduces to ax1 here.
        for i0, i1, i2, i3 in tir.grid(1, 64, 1, 1):
            with tir.block("T_reshape"):
                ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
                tir.reads(rxplaceholder_8[(ax1 + ax2 + ax3) % 64])
                tir.writes(T_reshape_4[ax0, ax1, ax2, ax3])
                T_reshape_4[ax0, ax1, ax2, ax3] = rxplaceholder_8[(ax0 * 64 + ax1 + ax2 + ax3) % 64]
        # x - reshaped _8, broadcast over H and W
        for i0, i1, i2, i3 in tir.grid(1, 64, 112, 112):
            with tir.block("T_subtract"):
                ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
                tir.reads(rxplaceholder_5[ax0, ax1, ax2, ax3], T_reshape_4[ax0, ax1, 0, 0])
                tir.writes(T_subtract_1[ax0, ax1, ax2, ax3])
                T_subtract_1[ax0, ax1, ax2, ax3] = rxplaceholder_5[ax0, ax1, ax2, ax3] - T_reshape_4[ax0, ax1, 0, 0]
        # reshape rxplaceholder_9 (64,) -> (1, 64, 1, 1)
        for i0, i1, i2, i3 in tir.grid(1, 64, 1, 1):
            with tir.block("T_reshape_1"):
                ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
                tir.reads(rxplaceholder_9[(ax1 + ax2 + ax3) % 64])
                tir.writes(T_reshape_5[ax0, ax1, ax2, ax3])
                T_reshape_5[ax0, ax1, ax2, ax3] = rxplaceholder_9[(ax0 * 64 + ax1 + ax2 + ax3) % 64]
        # add epsilon for numerical stability
        for i0, i1, i2, i3 in tir.grid(1, 64, 1, 1):
            with tir.block("T_add"):
                ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
                tir.reads(T_reshape_5[ax0, ax1, ax2, ax3])
                tir.writes(T_add_3[ax0, ax1, ax2, ax3])
                T_add_3[ax0, ax1, ax2, ax3] = T_reshape_5[ax0, ax1, ax2, ax3] + tir.float32(1.9999999494757503e-05)
        # sqrt(_9 + eps)
        for i0, i1, i2, i3 in tir.grid(1, 64, 1, 1):
            with tir.block("compute"):
                i0_2, i1_2, i2_2, i3_2 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
                tir.reads(T_add_3[i0_2, i1_2, i2_2, i3_2])
                tir.writes(compute_1[i0_2, i1_2, i2_2, i3_2])
                compute_1[i0_2, i1_2, i2_2, i3_2] = tir.sqrt(T_add_3[i0_2, i1_2, i2_2, i3_2], dtype="float32")
        # normalize: (x - _8) / sqrt(_9 + eps)
        for i0_3, i1_3, i2_3, i3_3 in tir.grid(1, 64, 112, 112):
            with tir.block("T_divide"):
                ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0_3, i1_3, i2_3, i3_3])
                tir.reads(T_subtract_1[ax0, ax1, ax2, ax3], compute_1[ax0, ax1, 0, 0])
                tir.writes(T_divide_1[ax0, ax1, ax2, ax3])
                T_divide_1[ax0, ax1, ax2, ax3] = T_subtract_1[ax0, ax1, ax2, ax3] / compute_1[ax0, ax1, 0, 0]
        # reshape rxplaceholder_6 (64,) -> (1, 64, 1, 1)
        for i0_4, i1_4, i2_4, i3_4 in tir.grid(1, 64, 1, 1):
            with tir.block("T_reshape_2"):
                ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0_4, i1_4, i2_4, i3_4])
                tir.reads(rxplaceholder_6[(ax1 + ax2 + ax3) % 64])
                tir.writes(T_reshape_6[ax0, ax1, ax2, ax3])
                T_reshape_6[ax0, ax1, ax2, ax3] = rxplaceholder_6[(ax0 * 64 + ax1 + ax2 + ax3) % 64]
        # scale by _6
        for i0_5, i1_5, i2_5, i3_5 in tir.grid(1, 64, 112, 112):
            with tir.block("T_multiply"):
                ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0_5, i1_5, i2_5, i3_5])
                tir.reads(T_divide_1[ax0, ax1, ax2, ax3], T_reshape_6[ax0, ax1, 0, 0])
                tir.writes(T_multiply_5[ax0, ax1, ax2, ax3])
                T_multiply_5[ax0, ax1, ax2, ax3] = T_divide_1[ax0, ax1, ax2, ax3] * T_reshape_6[ax0, ax1, 0, 0]
        # reshape rxplaceholder_7 (64,) -> (1, 64, 1, 1)
        for i0_6, i1_6, i2_6, i3_6 in tir.grid(1, 64, 1, 1):
            with tir.block("T_reshape_3"):
                ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0_6, i1_6, i2_6, i3_6])
                tir.reads(rxplaceholder_7[(ax1 + ax2 + ax3) % 64])
                tir.writes(T_reshape_7[ax0, ax1, ax2, ax3])
                T_reshape_7[ax0, ax1, ax2, ax3] = rxplaceholder_7[(ax0 * 64 + ax1 + ax2 + ax3) % 64]
        # shift by _7 into the main output T_add_2
        for i0_7, i1_7, i2_7, i3_7 in tir.grid(1, 64, 112, 112):
            with tir.block("T_add_1"):
                ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0_7, i1_7, i2_7, i3_7])
                tir.reads(T_multiply_5[ax0, ax1, ax2, ax3], T_reshape_7[ax0, ax1, 0, 0])
                tir.writes(T_add_2[ax0, ax1, ax2, ax3])
                T_add_2[ax0, ax1, ax2, ax3] = T_multiply_5[ax0, ax1, ax2, ax3] + T_reshape_7[ax0, ax1, 0, 0]
        # pass rxplaceholder_8 through unchanged as the second output
        for i0_8 in tir.serial(64):
            with tir.block("T_multiply_1"):
                ax0 = tir.axis.spatial(64, i0_8)
                tir.reads(rxplaceholder_8[ax0])
                tir.writes(T_multiply_3[ax0])
                T_multiply_3[ax0] = rxplaceholder_8[ax0]
        # pass rxplaceholder_9 through unchanged as the third output
        for i0_9 in tir.serial(64):
            with tir.block("T_multiply_2"):
                ax0 = tir.axis.spatial(64, i0_9)
                tir.reads(rxplaceholder_9[ax0])
                tir.writes(T_multiply_4[ax0])
                T_multiply_4[ax0] = rxplaceholder_9[ax0]

    @tir.prim_func
    def batch_norm2(rxplaceholder_5: tir.Buffer[(1, 64, 56, 56), "float32"], rxplaceholder_6: tir.Buffer[(64,), "float32"], rxplaceholder_7: tir.Buffer[(64,), "float32"], rxplaceholder_8: tir.Buffer[(64,), "float32"], rxplaceholder_9: tir.Buffer[(64,), "float32"], T_add_2: tir.Buffer[(1, 64, 56, 56), "float32"], T_multiply_3: tir.Buffer[(64,), "float32"], T_multiply_4: tir.Buffer[(64,), "float32"]) -> None:
        # Batch norm over a (1, 64, 56, 56) NCHW input x = rxplaceholder_5:
        #   T_add_2 = (x - reshape(_8)) / sqrt(reshape(_9) + eps) * reshape(_6) + reshape(_7)
        # with eps = 1.9999999494757503e-05. Same structure as batch_norm1 but on a
        # 56x56 map. rxplaceholder_8 / _9 (presumably running mean / variance --
        # TODO confirm) are also copied through to T_multiply_3 / T_multiply_4.
        # function attr dict
        tir.func_attr({"global_symbol": "batch_norm2", "tir.noalias": True})
        # body
        # with tir.block("root")
        T_reshape_4 = tir.alloc_buffer([1, 64, 1, 1], dtype="float32")
        T_subtract_1 = tir.alloc_buffer([1, 64, 56, 56], dtype="float32")
        T_reshape_5 = tir.alloc_buffer([1, 64, 1, 1], dtype="float32")
        T_add_3 = tir.alloc_buffer([1, 64, 1, 1], dtype="float32")
        compute_1 = tir.alloc_buffer([1, 64, 1, 1], dtype="float32")
        T_divide_1 = tir.alloc_buffer([1, 64, 56, 56], dtype="float32")
        T_reshape_6 = tir.alloc_buffer([1, 64, 1, 1], dtype="float32")
        T_multiply_5 = tir.alloc_buffer([1, 64, 56, 56], dtype="float32")
        T_reshape_7 = tir.alloc_buffer([1, 64, 1, 1], dtype="float32")
        # reshape rxplaceholder_8 (64,) -> (1, 64, 1, 1)
        for i0, i1, i2, i3 in tir.grid(1, 64, 1, 1):
            with tir.block("T_reshape"):
                ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
                tir.reads(rxplaceholder_8[(ax1 + ax2 + ax3) % 64])
                tir.writes(T_reshape_4[ax0, ax1, ax2, ax3])
                T_reshape_4[ax0, ax1, ax2, ax3] = rxplaceholder_8[(ax0 * 64 + ax1 + ax2 + ax3) % 64]
        # x - reshaped _8, broadcast over H and W
        for i0, i1, i2, i3 in tir.grid(1, 64, 56, 56):
            with tir.block("T_subtract"):
                ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
                tir.reads(rxplaceholder_5[ax0, ax1, ax2, ax3], T_reshape_4[ax0, ax1, 0, 0])
                tir.writes(T_subtract_1[ax0, ax1, ax2, ax3])
                T_subtract_1[ax0, ax1, ax2, ax3] = rxplaceholder_5[ax0, ax1, ax2, ax3] - T_reshape_4[ax0, ax1, 0, 0]
        # reshape rxplaceholder_9 (64,) -> (1, 64, 1, 1)
        for i0, i1, i2, i3 in tir.grid(1, 64, 1, 1):
            with tir.block("T_reshape_1"):
                ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
                tir.reads(rxplaceholder_9[(ax1 + ax2 + ax3) % 64])
                tir.writes(T_reshape_5[ax0, ax1, ax2, ax3])
                T_reshape_5[ax0, ax1, ax2, ax3] = rxplaceholder_9[(ax0 * 64 + ax1 + ax2 + ax3) % 64]
        # add epsilon for numerical stability
        for i0, i1, i2, i3 in tir.grid(1, 64, 1, 1):
            with tir.block("T_add"):
                ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
                tir.reads(T_reshape_5[ax0, ax1, ax2, ax3])
                tir.writes(T_add_3[ax0, ax1, ax2, ax3])
                T_add_3[ax0, ax1, ax2, ax3] = T_reshape_5[ax0, ax1, ax2, ax3] + tir.float32(1.9999999494757503e-05)
        # sqrt(_9 + eps)
        for i0, i1, i2, i3 in tir.grid(1, 64, 1, 1):
            with tir.block("compute"):
                i0_2, i1_2, i2_2, i3_2 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
                tir.reads(T_add_3[i0_2, i1_2, i2_2, i3_2])
                tir.writes(compute_1[i0_2, i1_2, i2_2, i3_2])
                compute_1[i0_2, i1_2, i2_2, i3_2] = tir.sqrt(T_add_3[i0_2, i1_2, i2_2, i3_2], dtype="float32")
        # normalize: (x - _8) / sqrt(_9 + eps)
        for i0_3, i1_3, i2_3, i3_3 in tir.grid(1, 64, 56, 56):
            with tir.block("T_divide"):
                ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0_3, i1_3, i2_3, i3_3])
                tir.reads(T_subtract_1[ax0, ax1, ax2, ax3], compute_1[ax0, ax1, 0, 0])
                tir.writes(T_divide_1[ax0, ax1, ax2, ax3])
                T_divide_1[ax0, ax1, ax2, ax3] = T_subtract_1[ax0, ax1, ax2, ax3] / compute_1[ax0, ax1, 0, 0]
        # reshape rxplaceholder_6 (64,) -> (1, 64, 1, 1)
        for i0_4, i1_4, i2_4, i3_4 in tir.grid(1, 64, 1, 1):
            with tir.block("T_reshape_2"):
                ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0_4, i1_4, i2_4, i3_4])
                tir.reads(rxplaceholder_6[(ax1 + ax2 + ax3) % 64])
                tir.writes(T_reshape_6[ax0, ax1, ax2, ax3])
                T_reshape_6[ax0, ax1, ax2, ax3] = rxplaceholder_6[(ax0 * 64 + ax1 + ax2 + ax3) % 64]
        # scale by _6
        for i0_5, i1_5, i2_5, i3_5 in tir.grid(1, 64, 56, 56):
            with tir.block("T_multiply"):
                ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0_5, i1_5, i2_5, i3_5])
                tir.reads(T_divide_1[ax0, ax1, ax2, ax3], T_reshape_6[ax0, ax1, 0, 0])
                tir.writes(T_multiply_5[ax0, ax1, ax2, ax3])
                T_multiply_5[ax0, ax1, ax2, ax3] = T_divide_1[ax0, ax1, ax2, ax3] * T_reshape_6[ax0, ax1, 0, 0]
        # reshape rxplaceholder_7 (64,) -> (1, 64, 1, 1)
        for i0_6, i1_6, i2_6, i3_6 in tir.grid(1, 64, 1, 1):
            with tir.block("T_reshape_3"):
                ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0_6, i1_6, i2_6, i3_6])
                tir.reads(rxplaceholder_7[(ax1 + ax2 + ax3) % 64])
                tir.writes(T_reshape_7[ax0, ax1, ax2, ax3])
                T_reshape_7[ax0, ax1, ax2, ax3] = rxplaceholder_7[(ax0 * 64 + ax1 + ax2 + ax3) % 64]
        # shift by _7 into the main output T_add_2
        for i0_7, i1_7, i2_7, i3_7 in tir.grid(1, 64, 56, 56):
            with tir.block("T_add_1"):
                ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0_7, i1_7, i2_7, i3_7])
                tir.reads(T_multiply_5[ax0, ax1, ax2, ax3], T_reshape_7[ax0, ax1, 0, 0])
                tir.writes(T_add_2[ax0, ax1, ax2, ax3])
                T_add_2[ax0, ax1, ax2, ax3] = T_multiply_5[ax0, ax1, ax2, ax3] + T_reshape_7[ax0, ax1, 0, 0]
        # pass rxplaceholder_8 through unchanged as the second output
        for i0_8 in tir.serial(64):
            with tir.block("T_multiply_1"):
                ax0 = tir.axis.spatial(64, i0_8)
                tir.reads(rxplaceholder_8[ax0])
                tir.writes(T_multiply_3[ax0])
                T_multiply_3[ax0] = rxplaceholder_8[ax0]
        # pass rxplaceholder_9 through unchanged as the third output
        for i0_9 in tir.serial(64):
            with tir.block("T_multiply_2"):
                ax0 = tir.axis.spatial(64, i0_9)
                tir.reads(rxplaceholder_9[ax0])
                tir.writes(T_multiply_4[ax0])
                T_multiply_4[ax0] = rxplaceholder_9[ax0]

    @tir.prim_func
    def add0(rxplaceholder_2: tir.Buffer[(1, 256, 56, 56), "float32"], rxplaceholder_3: tir.Buffer[(1, 256, 56, 56), "float32"], T_add_1: tir.Buffer[(1, 256, 56, 56), "float32"]) -> None:
        # Elementwise add of two (1, 256, 56, 56) tensors:
        #   T_add_1 = rxplaceholder_2 + rxplaceholder_3
        # (presumably a ResNet residual/skip connection -- TODO confirm against caller).
        # function attr dict
        tir.func_attr({"global_symbol": "add0", "tir.noalias": True})
        # body
        # with tir.block("root")
        for i0, i1, i2, i3 in tir.grid(1, 256, 56, 56):
            with tir.block("T_add"):
                ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
                tir.reads(rxplaceholder_2[ax0, ax1, ax2, ax3], rxplaceholder_3[ax0, ax1, ax2, ax3])
                tir.writes(T_add_1[ax0, ax1, ax2, ax3])
                T_add_1[ax0, ax1, ax2, ax3] = rxplaceholder_2[ax0, ax1, ax2, ax3] + rxplaceholder_3[ax0, ax1, ax2, ax3]

    @tir.prim_func
    def add1(rxplaceholder_2: tir.Buffer[(1, 512, 28, 28), "float32"], rxplaceholder_3: tir.Buffer[(1, 512, 28, 28), "float32"], T_add_1: tir.Buffer[(1, 512, 28, 28), "float32"]) -> None:
        # Elementwise add of two (1, 512, 28, 28) tensors:
        #   T_add_1 = rxplaceholder_2 + rxplaceholder_3
        # (presumably a ResNet residual/skip connection -- TODO confirm against caller).
        # function attr dict
        tir.func_attr({"global_symbol": "add1", "tir.noalias": True})
        # body
        # with tir.block("root")
        for i0, i1, i2, i3 in tir.grid(1, 512, 28, 28):
            with tir.block("T_add"):
                ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
                tir.reads(rxplaceholder_2[ax0, ax1, ax2, ax3], rxplaceholder_3[ax0, ax1, ax2, ax3])
                tir.writes(T_add_1[ax0, ax1, ax2, ax3])
                T_add_1[ax0, ax1, ax2, ax3] = rxplaceholder_2[ax0, ax1, ax2, ax3] + rxplaceholder_3[ax0, ax1, ax2, ax3]

    @tir.prim_func
    def conv2d_nchw14(rxplaceholder_2: tir.Buffer[(1, 1024, 14, 14), "float32"], rxplaceholder_3: tir.Buffer[(256, 1024, 1, 1), "float32"], conv2d_nchw_1: tir.Buffer[(1, 1024, 14, 14), "float32"][(1, 256, 14, 14), "float32"]) -> None:
        # 1x1 convolution in NCHW layout: 1024 -> 256 channels on a 14x14 map,
        # stride 1 (output index maps directly to input index). pad_temp is a
        # plain copy of the input -- no actual padding is applied.
        # function attr dict
        tir.func_attr({"global_symbol": "conv2d_nchw14", "tir.noalias": True})
        # body
        # with tir.block("root")
        pad_temp_1 = tir.alloc_buffer([1, 1024, 14, 14], dtype="float32")
        # copy input into the padding staging buffer (identity: no border here)
        for i0_3, i1_3, i2_3, i3_3 in tir.grid(1, 1024, 14, 14):
            with tir.block("pad_temp"):
                i0_4, i1_4, i2_4, i3_4 = tir.axis.remap("SSSS", [i0_3, i1_3, i2_3, i3_3])
                tir.reads(rxplaceholder_2[i0_4, i1_4, i2_4, i3_4])
                tir.writes(pad_temp_1[i0_4, i1_4, i2_4, i3_4])
                pad_temp_1[i0_4, i1_4, i2_4, i3_4] = rxplaceholder_2[i0_4, i1_4, i2_4, i3_4]
        # rc reduces over the 1024 input channels; r0/r1 are the trivial 1x1
        # kernel axes. "i1_6 // 256 * 1024" is always 0 since i1_6 < 256 (the
        # printer's grouped-conv indexing idiom with a single group).
        for i0_5, i1_5, i2_5, i3_5, i4, i5, i6 in tir.grid(1, 256, 14, 14, 1024, 1, 1):
            with tir.block("conv2d_nchw"):
                i0_6, i1_6, i2_6, i3_6, rc, r0, r1 = tir.axis.remap("SSSSRRR", [i0_5, i1_5, i2_5, i3_5, i4, i5, i6])
                tir.reads(conv2d_nchw_1[i0_6, i1_6, i2_6, i3_6], pad_temp_1[i0_6, i1_6 // 256 * 1024 + rc, i2_6 + r0, i3_6 + r1], rxplaceholder_3[i1_6, rc, r0, r1])
                tir.writes(conv2d_nchw_1[i0_6, i1_6, i2_6, i3_6])
                with tir.init():
                    conv2d_nchw_1[i0_6, i1_6, i2_6, i3_6] = tir.float32(0)
                conv2d_nchw_1[i0_6, i1_6, i2_6, i3_6] = conv2d_nchw_1[i0_6, i1_6, i2_6, i3_6] + pad_temp_1[i0_6, i1_6 // 256 * 1024 + rc, i2_6 + r0, i3_6 + r1] * rxplaceholder_3[i1_6, rc, r0, r1]

    @tir.prim_func
    def batch_norm5(rxplaceholder_5: tir.Buffer[(1, 512, 28, 28), "float32"], rxplaceholder_6: tir.Buffer[(512,), "float32"], rxplaceholder_7: tir.Buffer[(512,), "float32"], rxplaceholder_8: tir.Buffer[(512,), "float32"], rxplaceholder_9: tir.Buffer[(512,), "float32"], T_add_2: tir.Buffer[(1, 512, 28, 28), "float32"], T_multiply_3: tir.Buffer[(512,), "float32"], T_multiply_4: tir.Buffer[(512,), "float32"]) -> None:
        # Batch norm over a (1, 512, 28, 28) NCHW input x = rxplaceholder_5:
        #   T_add_2 = (x - reshape(_8)) / sqrt(reshape(_9) + eps) * reshape(_6) + reshape(_7)
        # with eps = 1.9999999494757503e-05. Same structure as batch_norm1/2 with
        # C=512 on a 28x28 map. rxplaceholder_8 / _9 (presumably running mean /
        # variance -- TODO confirm) are also copied to T_multiply_3 / T_multiply_4.
        # function attr dict
        tir.func_attr({"global_symbol": "batch_norm5", "tir.noalias": True})
        # body
        # with tir.block("root")
        T_reshape_4 = tir.alloc_buffer([1, 512, 1, 1], dtype="float32")
        T_subtract_1 = tir.alloc_buffer([1, 512, 28, 28], dtype="float32")
        T_reshape_5 = tir.alloc_buffer([1, 512, 1, 1], dtype="float32")
        T_add_3 = tir.alloc_buffer([1, 512, 1, 1], dtype="float32")
        compute_1 = tir.alloc_buffer([1, 512, 1, 1], dtype="float32")
        T_divide_1 = tir.alloc_buffer([1, 512, 28, 28], dtype="float32")
        T_reshape_6 = tir.alloc_buffer([1, 512, 1, 1], dtype="float32")
        T_multiply_5 = tir.alloc_buffer([1, 512, 28, 28], dtype="float32")
        T_reshape_7 = tir.alloc_buffer([1, 512, 1, 1], dtype="float32")
        # reshape rxplaceholder_8 (512,) -> (1, 512, 1, 1)
        for i0, i1, i2, i3 in tir.grid(1, 512, 1, 1):
            with tir.block("T_reshape"):
                ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
                tir.reads(rxplaceholder_8[(ax1 + ax2 + ax3) % 512])
                tir.writes(T_reshape_4[ax0, ax1, ax2, ax3])
                T_reshape_4[ax0, ax1, ax2, ax3] = rxplaceholder_8[(ax0 * 512 + ax1 + ax2 + ax3) % 512]
        # x - reshaped _8, broadcast over H and W
        for i0, i1, i2, i3 in tir.grid(1, 512, 28, 28):
            with tir.block("T_subtract"):
                ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
                tir.reads(rxplaceholder_5[ax0, ax1, ax2, ax3], T_reshape_4[ax0, ax1, 0, 0])
                tir.writes(T_subtract_1[ax0, ax1, ax2, ax3])
                T_subtract_1[ax0, ax1, ax2, ax3] = rxplaceholder_5[ax0, ax1, ax2, ax3] - T_reshape_4[ax0, ax1, 0, 0]
        # reshape rxplaceholder_9 (512,) -> (1, 512, 1, 1)
        for i0, i1, i2, i3 in tir.grid(1, 512, 1, 1):
            with tir.block("T_reshape_1"):
                ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
                tir.reads(rxplaceholder_9[(ax1 + ax2 + ax3) % 512])
                tir.writes(T_reshape_5[ax0, ax1, ax2, ax3])
                T_reshape_5[ax0, ax1, ax2, ax3] = rxplaceholder_9[(ax0 * 512 + ax1 + ax2 + ax3) % 512]
        # add epsilon for numerical stability
        for i0, i1, i2, i3 in tir.grid(1, 512, 1, 1):
            with tir.block("T_add"):
                ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
                tir.reads(T_reshape_5[ax0, ax1, ax2, ax3])
                tir.writes(T_add_3[ax0, ax1, ax2, ax3])
                T_add_3[ax0, ax1, ax2, ax3] = T_reshape_5[ax0, ax1, ax2, ax3] + tir.float32(1.9999999494757503e-05)
        # sqrt(_9 + eps)
        for i0, i1, i2, i3 in tir.grid(1, 512, 1, 1):
            with tir.block("compute"):
                i0_2, i1_2, i2_2, i3_2 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
                tir.reads(T_add_3[i0_2, i1_2, i2_2, i3_2])
                tir.writes(compute_1[i0_2, i1_2, i2_2, i3_2])
                compute_1[i0_2, i1_2, i2_2, i3_2] = tir.sqrt(T_add_3[i0_2, i1_2, i2_2, i3_2], dtype="float32")
        # normalize: (x - _8) / sqrt(_9 + eps)
        for i0_3, i1_3, i2_3, i3_3 in tir.grid(1, 512, 28, 28):
            with tir.block("T_divide"):
                ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0_3, i1_3, i2_3, i3_3])
                tir.reads(T_subtract_1[ax0, ax1, ax2, ax3], compute_1[ax0, ax1, 0, 0])
                tir.writes(T_divide_1[ax0, ax1, ax2, ax3])
                T_divide_1[ax0, ax1, ax2, ax3] = T_subtract_1[ax0, ax1, ax2, ax3] / compute_1[ax0, ax1, 0, 0]
        # reshape rxplaceholder_6 (512,) -> (1, 512, 1, 1)
        for i0_4, i1_4, i2_4, i3_4 in tir.grid(1, 512, 1, 1):
            with tir.block("T_reshape_2"):
                ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0_4, i1_4, i2_4, i3_4])
                tir.reads(rxplaceholder_6[(ax1 + ax2 + ax3) % 512])
                tir.writes(T_reshape_6[ax0, ax1, ax2, ax3])
                T_reshape_6[ax0, ax1, ax2, ax3] = rxplaceholder_6[(ax0 * 512 + ax1 + ax2 + ax3) % 512]
        # scale by _6
        for i0_5, i1_5, i2_5, i3_5 in tir.grid(1, 512, 28, 28):
            with tir.block("T_multiply"):
                ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0_5, i1_5, i2_5, i3_5])
                tir.reads(T_divide_1[ax0, ax1, ax2, ax3], T_reshape_6[ax0, ax1, 0, 0])
                tir.writes(T_multiply_5[ax0, ax1, ax2, ax3])
                T_multiply_5[ax0, ax1, ax2, ax3] = T_divide_1[ax0, ax1, ax2, ax3] * T_reshape_6[ax0, ax1, 0, 0]
        # reshape rxplaceholder_7 (512,) -> (1, 512, 1, 1)
        for i0_6, i1_6, i2_6, i3_6 in tir.grid(1, 512, 1, 1):
            with tir.block("T_reshape_3"):
                ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0_6, i1_6, i2_6, i3_6])
                tir.reads(rxplaceholder_7[(ax1 + ax2 + ax3) % 512])
                tir.writes(T_reshape_7[ax0, ax1, ax2, ax3])
                T_reshape_7[ax0, ax1, ax2, ax3] = rxplaceholder_7[(ax0 * 512 + ax1 + ax2 + ax3) % 512]
        # shift by _7 into the main output T_add_2
        for i0_7, i1_7, i2_7, i3_7 in tir.grid(1, 512, 28, 28):
            with tir.block("T_add_1"):
                ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0_7, i1_7, i2_7, i3_7])
                tir.reads(T_multiply_5[ax0, ax1, ax2, ax3], T_reshape_7[ax0, ax1, 0, 0])
                tir.writes(T_add_2[ax0, ax1, ax2, ax3])
                T_add_2[ax0, ax1, ax2, ax3] = T_multiply_5[ax0, ax1, ax2, ax3] + T_reshape_7[ax0, ax1, 0, 0]
        # pass rxplaceholder_8 through unchanged as the second output
        for i0_8 in tir.serial(512):
            with tir.block("T_multiply_1"):
                ax0 = tir.axis.spatial(512, i0_8)
                tir.reads(rxplaceholder_8[ax0])
                tir.writes(T_multiply_3[ax0])
                T_multiply_3[ax0] = rxplaceholder_8[ax0]
        # pass rxplaceholder_9 through unchanged as the third output
        for i0_9 in tir.serial(512):
            with tir.block("T_multiply_2"):
                ax0 = tir.axis.spatial(512, i0_9)
                tir.reads(rxplaceholder_9[ax0])
                tir.writes(T_multiply_4[ax0])
                T_multiply_4[ax0] = rxplaceholder_9[ax0]

    @tir.prim_func
    def conv2d_nchw9(rxplaceholder_2: tir.Buffer[(1, 512, 28, 28), "float32"], rxplaceholder_3: tir.Buffer[(128, 512, 1, 1), "float32"], conv2d_nchw_1: tir.Buffer[(1, 128, 28, 28), "float32"]) -> None:
        # 1x1 convolution in NCHW layout: 512 -> 128 channels on a 28x28 map,
        # stride 1. pad_temp is a plain copy of the input -- no actual padding.
        # function attr dict
        tir.func_attr({"global_symbol": "conv2d_nchw9", "tir.noalias": True})
        # body
        # with tir.block("root")
        pad_temp_1 = tir.alloc_buffer([1, 512, 28, 28], dtype="float32")
        # copy input into the padding staging buffer (identity: no border here)
        for i0_3, i1_3, i2_3, i3_3 in tir.grid(1, 512, 28, 28):
            with tir.block("pad_temp"):
                i0_4, i1_4, i2_4, i3_4 = tir.axis.remap("SSSS", [i0_3, i1_3, i2_3, i3_3])
                tir.reads(rxplaceholder_2[i0_4, i1_4, i2_4, i3_4])
                tir.writes(pad_temp_1[i0_4, i1_4, i2_4, i3_4])
                pad_temp_1[i0_4, i1_4, i2_4, i3_4] = rxplaceholder_2[i0_4, i1_4, i2_4, i3_4]
        # rc reduces over the 512 input channels; r0/r1 are the trivial 1x1
        # kernel axes. "i1_6 // 128 * 512" is always 0 since i1_6 < 128.
        for i0_5, i1_5, i2_5, i3_5, i4, i5, i6 in tir.grid(1, 128, 28, 28, 512, 1, 1):
            with tir.block("conv2d_nchw"):
                i0_6, i1_6, i2_6, i3_6, rc, r0, r1 = tir.axis.remap("SSSSRRR", [i0_5, i1_5, i2_5, i3_5, i4, i5, i6])
                tir.reads(conv2d_nchw_1[i0_6, i1_6, i2_6, i3_6], pad_temp_1[i0_6, i1_6 // 128 * 512 + rc, i2_6 + r0, i3_6 + r1], rxplaceholder_3[i1_6, rc, r0, r1])
                tir.writes(conv2d_nchw_1[i0_6, i1_6, i2_6, i3_6])
                with tir.init():
                    conv2d_nchw_1[i0_6, i1_6, i2_6, i3_6] = tir.float32(0)
                conv2d_nchw_1[i0_6, i1_6, i2_6, i3_6] = conv2d_nchw_1[i0_6, i1_6, i2_6, i3_6] + pad_temp_1[i0_6, i1_6 // 128 * 512 + rc, i2_6 + r0, i3_6 + r1] * rxplaceholder_3[i1_6, rc, r0, r1]

    @tir.prim_func
    def conv2d_nchw17(rxplaceholder_2: tir.Buffer[(1, 512, 7, 7), "float32"], rxplaceholder_3: tir.Buffer[(2048, 512, 1, 1), "float32"], conv2d_nchw_1: tir.Buffer[(1, 2048, 7, 7), "float32"]) -> None:
        # 1x1 convolution in NCHW layout: 512 -> 2048 channels on a 7x7 map,
        # stride 1. pad_temp is a plain copy of the input -- no actual padding.
        # function attr dict
        tir.func_attr({"global_symbol": "conv2d_nchw17", "tir.noalias": True})
        # body
        # with tir.block("root")
        pad_temp_1 = tir.alloc_buffer([1, 512, 7, 7], dtype="float32")
        # copy input into the padding staging buffer (identity: no border here)
        for i0_3, i1_3, i2_3, i3_3 in tir.grid(1, 512, 7, 7):
            with tir.block("pad_temp"):
                i0_4, i1_4, i2_4, i3_4 = tir.axis.remap("SSSS", [i0_3, i1_3, i2_3, i3_3])
                tir.reads(rxplaceholder_2[i0_4, i1_4, i2_4, i3_4])
                tir.writes(pad_temp_1[i0_4, i1_4, i2_4, i3_4])
                pad_temp_1[i0_4, i1_4, i2_4, i3_4] = rxplaceholder_2[i0_4, i1_4, i2_4, i3_4]
        # rc reduces over the 512 input channels; r0/r1 are the trivial 1x1
        # kernel axes. "i1_6 // 2048 * 512" is always 0 since i1_6 < 2048.
        for i0_5, i1_5, i2_5, i3_5, i4, i5, i6 in tir.grid(1, 2048, 7, 7, 512, 1, 1):
            with tir.block("conv2d_nchw"):
                i0_6, i1_6, i2_6, i3_6, rc, r0, r1 = tir.axis.remap("SSSSRRR", [i0_5, i1_5, i2_5, i3_5, i4, i5, i6])
                tir.reads(conv2d_nchw_1[i0_6, i1_6, i2_6, i3_6], pad_temp_1[i0_6, i1_6 // 2048 * 512 + rc, i2_6 + r0, i3_6 + r1], rxplaceholder_3[i1_6, rc, r0, r1])
                tir.writes(conv2d_nchw_1[i0_6, i1_6, i2_6, i3_6])
                with tir.init():
                    conv2d_nchw_1[i0_6, i1_6, i2_6, i3_6] = tir.float32(0)
                conv2d_nchw_1[i0_6, i1_6, i2_6, i3_6] = conv2d_nchw_1[i0_6, i1_6, i2_6, i3_6] + pad_temp_1[i0_6, i1_6 // 2048 * 512 + rc, i2_6 + r0, i3_6 + r1] * rxplaceholder_3[i1_6, rc, r0, r1]

    @tir.prim_func
    def relu5(rxplaceholder_1: tir.Buffer[(1, 256, 14, 14), "float32"], T_relu_1: tir.Buffer[(1, 256, 14, 14), "float32"]) -> None:
        # Elementwise ReLU on a (1, 256, 14, 14) tensor:
        #   T_relu_1 = max(rxplaceholder_1, 0)
        # function attr dict
        tir.func_attr({"global_symbol": "relu5", "tir.noalias": True})
        # body
        # with tir.block("root")
        for i0, i1, i2, i3 in tir.grid(1, 256, 14, 14):
            with tir.block("T_relu"):
                ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
                tir.reads(rxplaceholder_1[ax0, ax1, ax2, ax3])
                tir.writes(T_relu_1[ax0, ax1, ax2, ax3])
                T_relu_1[ax0, ax1, ax2, ax3] = tir.max(rxplaceholder_1[ax0, ax1, ax2, ax3], tir.float32(0))

    @tir.prim_func
    def conv2d_nchw12(rxplaceholder_2: tir.Buffer[(1, 256, 14, 14), "float32"], rxplaceholder_3: tir.Buffer[(1024, 256, 1, 1), "float32"], conv2d_nchw_1: tir.Buffer[(1, 1024, 14, 14), "float32"]) -> None:
        # 1x1 convolution in NCHW layout: 256 -> 1024 channels on a 14x14 map,
        # stride 1. pad_temp is a plain copy of the input -- no actual padding.
        # function attr dict
        tir.func_attr({"global_symbol": "conv2d_nchw12", "tir.noalias": True})
        # body
        # with tir.block("root")
        pad_temp_1 = tir.alloc_buffer([1, 256, 14, 14], dtype="float32")
        # copy input into the padding staging buffer (identity: no border here)
        for i0_3, i1_3, i2_3, i3_3 in tir.grid(1, 256, 14, 14):
            with tir.block("pad_temp"):
                i0_4, i1_4, i2_4, i3_4 = tir.axis.remap("SSSS", [i0_3, i1_3, i2_3, i3_3])
                tir.reads(rxplaceholder_2[i0_4, i1_4, i2_4, i3_4])
                tir.writes(pad_temp_1[i0_4, i1_4, i2_4, i3_4])
                pad_temp_1[i0_4, i1_4, i2_4, i3_4] = rxplaceholder_2[i0_4, i1_4, i2_4, i3_4]
        # rc reduces over the 256 input channels; r0/r1 are the trivial 1x1
        # kernel axes. "i1_6 // 1024 * 256" is always 0 since i1_6 < 1024.
        for i0_5, i1_5, i2_5, i3_5, i4, i5, i6 in tir.grid(1, 1024, 14, 14, 256, 1, 1):
            with tir.block("conv2d_nchw"):
                i0_6, i1_6, i2_6, i3_6, rc, r0, r1 = tir.axis.remap("SSSSRRR", [i0_5, i1_5, i2_5, i3_5, i4, i5, i6])
                tir.reads(conv2d_nchw_1[i0_6, i1_6, i2_6, i3_6], pad_temp_1[i0_6, i1_6 // 1024 * 256 + rc, i2_6 + r0, i3_6 + r1], rxplaceholder_3[i1_6, rc, r0, r1])
                tir.writes(conv2d_nchw_1[i0_6, i1_6, i2_6, i3_6])
                with tir.init():
                    conv2d_nchw_1[i0_6, i1_6, i2_6, i3_6] = tir.float32(0)
                conv2d_nchw_1[i0_6, i1_6, i2_6, i3_6] = conv2d_nchw_1[i0_6, i1_6, i2_6, i3_6] + pad_temp_1[i0_6, i1_6 // 1024 * 256 + rc, i2_6 + r0, i3_6 + r1] * rxplaceholder_3[i1_6, rc, r0, r1]

    @tir.prim_func
    def conv2d_nchw15(rxplaceholder_2: tir.Buffer[(1, 1024, 14, 14), "float32"], rxplaceholder_3: tir.Buffer[(512, 1024, 1, 1), "float32"], conv2d_nchw_1: tir.Buffer[(1, 512, 7, 7), "float32"]) -> None:
        # 1x1 convolution in NCHW layout with stride 2: 1024 -> 512 channels,
        # 14x14 -> 7x7 (input spatial index is i2_6 * 2 / i3_6 * 2). pad_temp is
        # a plain copy of the input -- no actual padding.
        # function attr dict
        tir.func_attr({"global_symbol": "conv2d_nchw15", "tir.noalias": True})
        # body
        # with tir.block("root")
        pad_temp_1 = tir.alloc_buffer([1, 1024, 14, 14], dtype="float32")
        # copy input into the padding staging buffer (identity: no border here)
        for i0_3, i1_3, i2_3, i3_3 in tir.grid(1, 1024, 14, 14):
            with tir.block("pad_temp"):
                i0_4, i1_4, i2_4, i3_4 = tir.axis.remap("SSSS", [i0_3, i1_3, i2_3, i3_3])
                tir.reads(rxplaceholder_2[i0_4, i1_4, i2_4, i3_4])
                tir.writes(pad_temp_1[i0_4, i1_4, i2_4, i3_4])
                pad_temp_1[i0_4, i1_4, i2_4, i3_4] = rxplaceholder_2[i0_4, i1_4, i2_4, i3_4]
        # rc reduces over the 1024 input channels; r0/r1 are the trivial 1x1
        # kernel axes. "i1_6 // 512 * 1024" is always 0 since i1_6 < 512.
        for i0_5, i1_5, i2_5, i3_5, i4, i5, i6 in tir.grid(1, 512, 7, 7, 1024, 1, 1):
            with tir.block("conv2d_nchw"):
                i0_6, i1_6, i2_6, i3_6, rc, r0, r1 = tir.axis.remap("SSSSRRR", [i0_5, i1_5, i2_5, i3_5, i4, i5, i6])
                tir.reads(conv2d_nchw_1[i0_6, i1_6, i2_6, i3_6], pad_temp_1[i0_6, i1_6 // 512 * 1024 + rc, i2_6 * 2 + r0, i3_6 * 2 + r1], rxplaceholder_3[i1_6, rc, r0, r1])
                tir.writes(conv2d_nchw_1[i0_6, i1_6, i2_6, i3_6])
                with tir.init():
                    conv2d_nchw_1[i0_6, i1_6, i2_6, i3_6] = tir.float32(0)
                conv2d_nchw_1[i0_6, i1_6, i2_6, i3_6] = conv2d_nchw_1[i0_6, i1_6, i2_6, i3_6] + pad_temp_1[i0_6, i1_6 // 512 * 1024 + rc, i2_6 * 2 + r0, i3_6 * 2 + r1] * rxplaceholder_3[i1_6, rc, r0, r1]

    @tir.prim_func
    def batch_norm7(rxplaceholder_5: tir.Buffer[(1, 1024, 14, 14), "float32"], rxplaceholder_6: tir.Buffer[(1024,), "float32"], rxplaceholder_7: tir.Buffer[(1024,), "float32"], rxplaceholder_8: tir.Buffer[(1024,), "float32"], rxplaceholder_9: tir.Buffer[(1024,), "float32"], T_add_2: tir.Buffer[(1, 1024, 14, 14), "float32"], T_multiply_3: tir.Buffer[(1024,), "float32"], T_multiply_4: tir.Buffer[(1024,), "float32"]) -> None:
        # Batch norm over a (1, 1024, 14, 14) NCHW input x = rxplaceholder_5:
        #   T_add_2 = (x - reshape(_8)) / sqrt(reshape(_9) + eps) * reshape(_6) + reshape(_7)
        # with eps = 1.9999999494757503e-05. Same structure as the other batch_norm
        # prim_funcs with C=1024 on a 14x14 map. rxplaceholder_8 / _9 (presumably
        # running mean / variance -- TODO confirm) are also copied through to
        # T_multiply_3 / T_multiply_4.
        # function attr dict
        tir.func_attr({"global_symbol": "batch_norm7", "tir.noalias": True})
        # body
        # with tir.block("root")
        T_reshape_4 = tir.alloc_buffer([1, 1024, 1, 1], dtype="float32")
        T_subtract_1 = tir.alloc_buffer([1, 1024, 14, 14], dtype="float32")
        T_reshape_5 = tir.alloc_buffer([1, 1024, 1, 1], dtype="float32")
        T_add_3 = tir.alloc_buffer([1, 1024, 1, 1], dtype="float32")
        compute_1 = tir.alloc_buffer([1, 1024, 1, 1], dtype="float32")
        T_divide_1 = tir.alloc_buffer([1, 1024, 14, 14], dtype="float32")
        T_reshape_6 = tir.alloc_buffer([1, 1024, 1, 1], dtype="float32")
        T_multiply_5 = tir.alloc_buffer([1, 1024, 14, 14], dtype="float32")
        T_reshape_7 = tir.alloc_buffer([1, 1024, 1, 1], dtype="float32")
        # reshape rxplaceholder_8 (1024,) -> (1, 1024, 1, 1)
        for i0, i1, i2, i3 in tir.grid(1, 1024, 1, 1):
            with tir.block("T_reshape"):
                ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
                tir.reads(rxplaceholder_8[(ax1 + ax2 + ax3) % 1024])
                tir.writes(T_reshape_4[ax0, ax1, ax2, ax3])
                T_reshape_4[ax0, ax1, ax2, ax3] = rxplaceholder_8[(ax0 * 1024 + ax1 + ax2 + ax3) % 1024]
        # x - reshaped _8, broadcast over H and W
        for i0, i1, i2, i3 in tir.grid(1, 1024, 14, 14):
            with tir.block("T_subtract"):
                ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
                tir.reads(rxplaceholder_5[ax0, ax1, ax2, ax3], T_reshape_4[ax0, ax1, 0, 0])
                tir.writes(T_subtract_1[ax0, ax1, ax2, ax3])
                T_subtract_1[ax0, ax1, ax2, ax3] = rxplaceholder_5[ax0, ax1, ax2, ax3] - T_reshape_4[ax0, ax1, 0, 0]
        # reshape rxplaceholder_9 (1024,) -> (1, 1024, 1, 1)
        for i0, i1, i2, i3 in tir.grid(1, 1024, 1, 1):
            with tir.block("T_reshape_1"):
                ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
                tir.reads(rxplaceholder_9[(ax1 + ax2 + ax3) % 1024])
                tir.writes(T_reshape_5[ax0, ax1, ax2, ax3])
                T_reshape_5[ax0, ax1, ax2, ax3] = rxplaceholder_9[(ax0 * 1024 + ax1 + ax2 + ax3) % 1024]
        # add epsilon for numerical stability
        for i0, i1, i2, i3 in tir.grid(1, 1024, 1, 1):
            with tir.block("T_add"):
                ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
                tir.reads(T_reshape_5[ax0, ax1, ax2, ax3])
                tir.writes(T_add_3[ax0, ax1, ax2, ax3])
                T_add_3[ax0, ax1, ax2, ax3] = T_reshape_5[ax0, ax1, ax2, ax3] + tir.float32(1.9999999494757503e-05)
        # sqrt(_9 + eps)
        for i0, i1, i2, i3 in tir.grid(1, 1024, 1, 1):
            with tir.block("compute"):
                i0_2, i1_2, i2_2, i3_2 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
                tir.reads(T_add_3[i0_2, i1_2, i2_2, i3_2])
                tir.writes(compute_1[i0_2, i1_2, i2_2, i3_2])
                compute_1[i0_2, i1_2, i2_2, i3_2] = tir.sqrt(T_add_3[i0_2, i1_2, i2_2, i3_2], dtype="float32")
        # normalize: (x - _8) / sqrt(_9 + eps)
        for i0_3, i1_3, i2_3, i3_3 in tir.grid(1, 1024, 14, 14):
            with tir.block("T_divide"):
                ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0_3, i1_3, i2_3, i3_3])
                tir.reads(T_subtract_1[ax0, ax1, ax2, ax3], compute_1[ax0, ax1, 0, 0])
                tir.writes(T_divide_1[ax0, ax1, ax2, ax3])
                T_divide_1[ax0, ax1, ax2, ax3] = T_subtract_1[ax0, ax1, ax2, ax3] / compute_1[ax0, ax1, 0, 0]
        # reshape rxplaceholder_6 (1024,) -> (1, 1024, 1, 1)
        for i0_4, i1_4, i2_4, i3_4 in tir.grid(1, 1024, 1, 1):
            with tir.block("T_reshape_2"):
                ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0_4, i1_4, i2_4, i3_4])
                tir.reads(rxplaceholder_6[(ax1 + ax2 + ax3) % 1024])
                tir.writes(T_reshape_6[ax0, ax1, ax2, ax3])
                T_reshape_6[ax0, ax1, ax2, ax3] = rxplaceholder_6[(ax0 * 1024 + ax1 + ax2 + ax3) % 1024]
        # scale by _6
        for i0_5, i1_5, i2_5, i3_5 in tir.grid(1, 1024, 14, 14):
            with tir.block("T_multiply"):
                ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0_5, i1_5, i2_5, i3_5])
                tir.reads(T_divide_1[ax0, ax1, ax2, ax3], T_reshape_6[ax0, ax1, 0, 0])
                tir.writes(T_multiply_5[ax0, ax1, ax2, ax3])
                T_multiply_5[ax0, ax1, ax2, ax3] = T_divide_1[ax0, ax1, ax2, ax3] * T_reshape_6[ax0, ax1, 0, 0]
        # reshape rxplaceholder_7 (1024,) -> (1, 1024, 1, 1)
        for i0_6, i1_6, i2_6, i3_6 in tir.grid(1, 1024, 1, 1):
            with tir.block("T_reshape_3"):
                ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0_6, i1_6, i2_6, i3_6])
                tir.reads(rxplaceholder_7[(ax1 + ax2 + ax3) % 1024])
                tir.writes(T_reshape_7[ax0, ax1, ax2, ax3])
                T_reshape_7[ax0, ax1, ax2, ax3] = rxplaceholder_7[(ax0 * 1024 + ax1 + ax2 + ax3) % 1024]
        # shift by _7 into the main output T_add_2
        for i0_7, i1_7, i2_7, i3_7 in tir.grid(1, 1024, 14, 14):
            with tir.block("T_add_1"):
                ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0_7, i1_7, i2_7, i3_7])
                tir.reads(T_multiply_5[ax0, ax1, ax2, ax3], T_reshape_7[ax0, ax1, 0, 0])
                tir.writes(T_add_2[ax0, ax1, ax2, ax3])
                T_add_2[ax0, ax1, ax2, ax3] = T_multiply_5[ax0, ax1, ax2, ax3] + T_reshape_7[ax0, ax1, 0, 0]
        # pass rxplaceholder_8 through unchanged as the second output
        for i0_8 in tir.serial(1024):
            with tir.block("T_multiply_1"):
                ax0 = tir.axis.spatial(1024, i0_8)
                tir.reads(rxplaceholder_8[ax0])
                tir.writes(T_multiply_3[ax0])
                T_multiply_3[ax0] = rxplaceholder_8[ax0]
        # pass rxplaceholder_9 through unchanged as the third output
        for i0_9 in tir.serial(1024):
            with tir.block("T_multiply_2"):
                ax0 = tir.axis.spatial(1024, i0_9)
                tir.reads(rxplaceholder_9[ax0])
                tir.writes(T_multiply_4[ax0])
                T_multiply_4[ax0] = rxplaceholder_9[ax0]

    # 1x1 convolution, stride 2 (NCHW): input (1, 256, 56, 56) x weight
    # (128, 256, 1, 1) -> output (1, 128, 28, 28). Auto-generated TVMScript;
    # the tir.reads/tir.writes annotations mirror the statements below exactly.
    @tir.prim_func
    def conv2d_nchw5(rxplaceholder_2: tir.Buffer[(1, 256, 56, 56), "float32"], rxplaceholder_3: tir.Buffer[(128, 256, 1, 1), "float32"], conv2d_nchw_1: tir.Buffer[(1, 128, 28, 28), "float32"]) -> None:
        # function attr dict
        tir.func_attr({"global_symbol": "conv2d_nchw5", "tir.noalias": True})
        # body
        # with tir.block("root")
        pad_temp_1 = tir.alloc_buffer([1, 256, 56, 56], dtype="float32")
        # Identity copy of the input: a 1x1 kernel needs no spatial padding.
        for i0_3, i1_3, i2_3, i3_3 in tir.grid(1, 256, 56, 56):
            with tir.block("pad_temp"):
                i0_4, i1_4, i2_4, i3_4 = tir.axis.remap("SSSS", [i0_3, i1_3, i2_3, i3_3])
                tir.reads(rxplaceholder_2[i0_4, i1_4, i2_4, i3_4])
                tir.writes(pad_temp_1[i0_4, i1_4, i2_4, i3_4])
                pad_temp_1[i0_4, i1_4, i2_4, i3_4] = rxplaceholder_2[i0_4, i1_4, i2_4, i3_4]
        # Reduction over rc (256 input channels); r0/r1 are the trivial 1x1
        # kernel axes. `i2_6 * 2` / `i3_6 * 2` implement stride 2, and
        # `i1_6 // 128 * 256 + rc` evaluates to `rc` because i1_6 < 128.
        for i0_5, i1_5, i2_5, i3_5, i4, i5, i6 in tir.grid(1, 128, 28, 28, 256, 1, 1):
            with tir.block("conv2d_nchw"):
                i0_6, i1_6, i2_6, i3_6, rc, r0, r1 = tir.axis.remap("SSSSRRR", [i0_5, i1_5, i2_5, i3_5, i4, i5, i6])
                tir.reads(conv2d_nchw_1[i0_6, i1_6, i2_6, i3_6], pad_temp_1[i0_6, i1_6 // 128 * 256 + rc, i2_6 * 2 + r0, i3_6 * 2 + r1], rxplaceholder_3[i1_6, rc, r0, r1])
                tir.writes(conv2d_nchw_1[i0_6, i1_6, i2_6, i3_6])
                with tir.init():
                    conv2d_nchw_1[i0_6, i1_6, i2_6, i3_6] = tir.float32(0)
                conv2d_nchw_1[i0_6, i1_6, i2_6, i3_6] = conv2d_nchw_1[i0_6, i1_6, i2_6, i3_6] + pad_temp_1[i0_6, i1_6 // 128 * 256 + rc, i2_6 * 2 + r0, i3_6 * 2 + r1] * rxplaceholder_3[i1_6, rc, r0, r1]

    # 3x3 convolution, stride 1, padding 1 (NCHW): input (1, 256, 14, 14) x
    # weight (256, 256, 3, 3) -> output (1, 256, 14, 14).
    @tir.prim_func
    def conv2d_nchw11(rxplaceholder_2: tir.Buffer[(1, 256, 14, 14), "float32"], rxplaceholder_3: tir.Buffer[(256, 256, 3, 3), "float32"], conv2d_nchw_1: tir.Buffer[(1, 256, 14, 14), "float32"]) -> None:
        # function attr dict
        tir.func_attr({"global_symbol": "conv2d_nchw11", "tir.noalias": True})
        # body
        # with tir.block("root")
        pad_temp_1 = tir.alloc_buffer([1, 256, 16, 16], dtype="float32")
        # Zero-pad 14x14 -> 16x16: interior (1 <= i < 15) copies the input,
        # the 1-pixel border is filled with 0.
        for i0_3, i1_3, i2_3, i3_3 in tir.grid(1, 256, 16, 16):
            with tir.block("pad_temp"):
                i0_4, i1_4, i2_4, i3_4 = tir.axis.remap("SSSS", [i0_3, i1_3, i2_3, i3_3])
                tir.reads(rxplaceholder_2[i0_4, i1_4, tir.max(i2_4 - 1, 0) : tir.max(i2_4 - 1, 0) + (tir.min(i2_4, 14) - tir.max(i2_4 - 1, 0)), tir.max(i3_4 - 1, 0) : tir.max(i3_4 - 1, 0) + (tir.min(i3_4, 14) - tir.max(i3_4 - 1, 0))])
                tir.writes(pad_temp_1[i0_4, i1_4, i2_4, i3_4])
                pad_temp_1[i0_4, i1_4, i2_4, i3_4] = tir.if_then_else(i2_4 >= 1 and i2_4 < 15 and i3_4 >= 1 and i3_4 < 15, rxplaceholder_2[i0_4, i1_4, i2_4 - 1, i3_4 - 1], tir.float32(0), dtype="float32")
        # Reduction over rc (input channels) and the 3x3 kernel window r0/r1;
        # `i1_6 // 256 * 256 + rc` evaluates to `rc` because i1_6 < 256.
        for i0_5, i1_5, i2_5, i3_5, i4, i5, i6 in tir.grid(1, 256, 14, 14, 256, 3, 3):
            with tir.block("conv2d_nchw"):
                i0_6, i1_6, i2_6, i3_6, rc, r0, r1 = tir.axis.remap("SSSSRRR", [i0_5, i1_5, i2_5, i3_5, i4, i5, i6])
                tir.reads(conv2d_nchw_1[i0_6, i1_6, i2_6, i3_6], pad_temp_1[i0_6, i1_6 // 256 * 256 + rc, i2_6 + r0, i3_6 + r1], rxplaceholder_3[i1_6, rc, r0, r1])
                tir.writes(conv2d_nchw_1[i0_6, i1_6, i2_6, i3_6])
                with tir.init():
                    conv2d_nchw_1[i0_6, i1_6, i2_6, i3_6] = tir.float32(0)
                conv2d_nchw_1[i0_6, i1_6, i2_6, i3_6] = conv2d_nchw_1[i0_6, i1_6, i2_6, i3_6] + pad_temp_1[i0_6, i1_6 // 256 * 256 + rc, i2_6 + r0, i3_6 + r1] * rxplaceholder_3[i1_6, rc, r0, r1]

    # Elementwise ReLU on (1, 256, 56, 56): out = max(x, 0).
    @tir.prim_func
    def relu2(rxplaceholder_1: tir.Buffer[(1, 256, 56, 56), "float32"], T_relu_1: tir.Buffer[(1, 256, 56, 56), "float32"]) -> None:
        # function attr dict
        tir.func_attr({"global_symbol": "relu2", "tir.noalias": True})
        # body
        # with tir.block("root")
        for i0, i1, i2, i3 in tir.grid(1, 256, 56, 56):
            with tir.block("T_relu"):
                ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
                tir.reads(rxplaceholder_1[ax0, ax1, ax2, ax3])
                tir.writes(T_relu_1[ax0, ax1, ax2, ax3])
                T_relu_1[ax0, ax1, ax2, ax3] = tir.max(rxplaceholder_1[ax0, ax1, ax2, ax3], tir.float32(0))

    # Numerically stable softmax over axis 1 of (1, 1000):
    # out[i, j] = exp(x[i, j] - max_j x[i, j]) / sum_j exp(x[i, j] - max_j x[i, j]).
    @tir.prim_func
    def softmax(rxplaceholder_1: tir.Buffer[(1, 1000), "float32"], T_softmax_norm_1: tir.Buffer[(1, 1000), "float32"]) -> None:
        # function attr dict
        tir.func_attr({"global_symbol": "softmax", "tir.noalias": True})
        # body
        # with tir.block("root")
        T_softmax_maxelem_1 = tir.alloc_buffer([1], dtype="float32")
        T_softmax_exp_1 = tir.alloc_buffer([1, 1000], dtype="float32")
        T_softmax_expsum_1 = tir.alloc_buffer([1], dtype="float32")
        # Row maximum; init is float32 lowest (-3.40e38) so any input wins.
        for i0_7, i1_3 in tir.grid(1, 1000):
            with tir.block("T_softmax_maxelem"):
                i0_8, k = tir.axis.remap("SR", [i0_7, i1_3])
                tir.reads(T_softmax_maxelem_1[i0_8], rxplaceholder_1[i0_8, k])
                tir.writes(T_softmax_maxelem_1[i0_8])
                with tir.init():
                    T_softmax_maxelem_1[i0_8] = tir.float32(-3.4028234663852886e+38)
                T_softmax_maxelem_1[i0_8] = tir.max(T_softmax_maxelem_1[i0_8], rxplaceholder_1[i0_8, k])
        # Shift by the row max before exp to avoid overflow.
        for i0_9, i1_4 in tir.grid(1, 1000):
            with tir.block("T_softmax_exp"):
                i0_10, i1_5 = tir.axis.remap("SS", [i0_9, i1_4])
                tir.reads(rxplaceholder_1[i0_10, i1_5], T_softmax_maxelem_1[i0_10])
                tir.writes(T_softmax_exp_1[i0_10, i1_5])
                T_softmax_exp_1[i0_10, i1_5] = tir.exp(rxplaceholder_1[i0_10, i1_5] - T_softmax_maxelem_1[i0_10], dtype="float32")
        # Row sum of the exponentials.
        for i0_11, i1_6 in tir.grid(1, 1000):
            with tir.block("T_softmax_expsum"):
                i0_12, k = tir.axis.remap("SR", [i0_11, i1_6])
                tir.reads(T_softmax_expsum_1[i0_12], T_softmax_exp_1[i0_12, k])
                tir.writes(T_softmax_expsum_1[i0_12])
                with tir.init():
                    T_softmax_expsum_1[i0_12] = tir.float32(0)
                T_softmax_expsum_1[i0_12] = T_softmax_expsum_1[i0_12] + T_softmax_exp_1[i0_12, k]
        # Normalize; block_attr records the softmax axis for downstream passes.
        for i0_13, i1_7 in tir.grid(1, 1000):
            with tir.block("T_softmax_norm"):
                i0_14, i1_8 = tir.axis.remap("SS", [i0_13, i1_7])
                tir.reads(T_softmax_exp_1[i0_14, i1_8], T_softmax_expsum_1[i0_14])
                tir.writes(T_softmax_norm_1[i0_14, i1_8])
                tir.block_attr({"axis":1})
                T_softmax_norm_1[i0_14, i1_8] = T_softmax_exp_1[i0_14, i1_8] / T_softmax_expsum_1[i0_14]

    # 1x1 convolution, stride 1 (NCHW): input (1, 256, 56, 56) x weight
    # (64, 256, 1, 1) -> output (1, 64, 56, 56).
    @tir.prim_func
    def conv2d_nchw4(rxplaceholder_2: tir.Buffer[(1, 256, 56, 56), "float32"], rxplaceholder_3: tir.Buffer[(64, 256, 1, 1), "float32"], conv2d_nchw_1: tir.Buffer[(1, 64, 56, 56), "float32"]) -> None:
        # function attr dict
        tir.func_attr({"global_symbol": "conv2d_nchw4", "tir.noalias": True})
        # body
        # with tir.block("root")
        pad_temp_1 = tir.alloc_buffer([1, 256, 56, 56], dtype="float32")
        # Identity copy of the input: a 1x1 kernel needs no spatial padding.
        for i0_3, i1_3, i2_3, i3_3 in tir.grid(1, 256, 56, 56):
            with tir.block("pad_temp"):
                i0_4, i1_4, i2_4, i3_4 = tir.axis.remap("SSSS", [i0_3, i1_3, i2_3, i3_3])
                tir.reads(rxplaceholder_2[i0_4, i1_4, i2_4, i3_4])
                tir.writes(pad_temp_1[i0_4, i1_4, i2_4, i3_4])
                pad_temp_1[i0_4, i1_4, i2_4, i3_4] = rxplaceholder_2[i0_4, i1_4, i2_4, i3_4]
        # Reduction over rc (256 input channels); `i1_6 // 64 * 256 + rc`
        # evaluates to `rc` because i1_6 < 64.
        for i0_5, i1_5, i2_5, i3_5, i4, i5, i6 in tir.grid(1, 64, 56, 56, 256, 1, 1):
            with tir.block("conv2d_nchw"):
                i0_6, i1_6, i2_6, i3_6, rc, r0, r1 = tir.axis.remap("SSSSRRR", [i0_5, i1_5, i2_5, i3_5, i4, i5, i6])
                tir.reads(conv2d_nchw_1[i0_6, i1_6, i2_6, i3_6], pad_temp_1[i0_6, i1_6 // 64 * 256 + rc, i2_6 + r0, i3_6 + r1], rxplaceholder_3[i1_6, rc, r0, r1])
                tir.writes(conv2d_nchw_1[i0_6, i1_6, i2_6, i3_6])
                with tir.init():
                    conv2d_nchw_1[i0_6, i1_6, i2_6, i3_6] = tir.float32(0)
                conv2d_nchw_1[i0_6, i1_6, i2_6, i3_6] = conv2d_nchw_1[i0_6, i1_6, i2_6, i3_6] + pad_temp_1[i0_6, i1_6 // 64 * 256 + rc, i2_6 + r0, i3_6 + r1] * rxplaceholder_3[i1_6, rc, r0, r1]

    # 3x3 convolution, stride 1, padding 1 (NCHW): input (1, 64, 56, 56) x
    # weight (64, 64, 3, 3) -> output (1, 64, 56, 56).
    @tir.prim_func
    def conv2d_nchw2(rxplaceholder_2: tir.Buffer[(1, 64, 56, 56), "float32"], rxplaceholder_3: tir.Buffer[(64, 64, 3, 3), "float32"], conv2d_nchw_1: tir.Buffer[(1, 64, 56, 56), "float32"]) -> None:
        # function attr dict
        tir.func_attr({"global_symbol": "conv2d_nchw2", "tir.noalias": True})
        # body
        # with tir.block("root")
        pad_temp_1 = tir.alloc_buffer([1, 64, 58, 58], dtype="float32")
        # Zero-pad 56x56 -> 58x58: interior (1 <= i < 57) copies the input,
        # the 1-pixel border is filled with 0.
        for i0_3, i1_3, i2_3, i3_3 in tir.grid(1, 64, 58, 58):
            with tir.block("pad_temp"):
                i0_4, i1_4, i2_4, i3_4 = tir.axis.remap("SSSS", [i0_3, i1_3, i2_3, i3_3])
                tir.reads(rxplaceholder_2[i0_4, i1_4, tir.max(i2_4 - 1, 0) : tir.max(i2_4 - 1, 0) + (tir.min(i2_4, 56) - tir.max(i2_4 - 1, 0)), tir.max(i3_4 - 1, 0) : tir.max(i3_4 - 1, 0) + (tir.min(i3_4, 56) - tir.max(i3_4 - 1, 0))])
                tir.writes(pad_temp_1[i0_4, i1_4, i2_4, i3_4])
                pad_temp_1[i0_4, i1_4, i2_4, i3_4] = tir.if_then_else(i2_4 >= 1 and i2_4 < 57 and i3_4 >= 1 and i3_4 < 57, rxplaceholder_2[i0_4, i1_4, i2_4 - 1, i3_4 - 1], tir.float32(0), dtype="float32")
        # Reduction over rc (input channels) and the 3x3 kernel window r0/r1;
        # `i1_6 // 64 * 64 + rc` evaluates to `rc` because i1_6 < 64.
        for i0_5, i1_5, i2_5, i3_5, i4, i5, i6 in tir.grid(1, 64, 56, 56, 64, 3, 3):
            with tir.block("conv2d_nchw"):
                i0_6, i1_6, i2_6, i3_6, rc, r0, r1 = tir.axis.remap("SSSSRRR", [i0_5, i1_5, i2_5, i3_5, i4, i5, i6])
                tir.reads(conv2d_nchw_1[i0_6, i1_6, i2_6, i3_6], pad_temp_1[i0_6, i1_6 // 64 * 64 + rc, i2_6 + r0, i3_6 + r1], rxplaceholder_3[i1_6, rc, r0, r1])
                tir.writes(conv2d_nchw_1[i0_6, i1_6, i2_6, i3_6])
                with tir.init():
                    conv2d_nchw_1[i0_6, i1_6, i2_6, i3_6] = tir.float32(0)
                conv2d_nchw_1[i0_6, i1_6, i2_6, i3_6] = conv2d_nchw_1[i0_6, i1_6, i2_6, i3_6] + pad_temp_1[i0_6, i1_6 // 64 * 64 + rc, i2_6 + r0, i3_6 + r1] * rxplaceholder_3[i1_6, rc, r0, r1]

    # 1x1 convolution, stride 2 (NCHW): input (1, 256, 56, 56) x weight
    # (512, 256, 1, 1) -> output (1, 512, 28, 28).
    @tir.prim_func
    def conv2d_nchw8(rxplaceholder_2: tir.Buffer[(1, 256, 56, 56), "float32"], rxplaceholder_3: tir.Buffer[(512, 256, 1, 1), "float32"], conv2d_nchw_1: tir.Buffer[(1, 512, 28, 28), "float32"]) -> None:
        # function attr dict
        tir.func_attr({"global_symbol": "conv2d_nchw8", "tir.noalias": True})
        # body
        # with tir.block("root")
        pad_temp_1 = tir.alloc_buffer([1, 256, 56, 56], dtype="float32")
        # Identity copy of the input: a 1x1 kernel needs no spatial padding.
        for i0_3, i1_3, i2_3, i3_3 in tir.grid(1, 256, 56, 56):
            with tir.block("pad_temp"):
                i0_4, i1_4, i2_4, i3_4 = tir.axis.remap("SSSS", [i0_3, i1_3, i2_3, i3_3])
                tir.reads(rxplaceholder_2[i0_4, i1_4, i2_4, i3_4])
                tir.writes(pad_temp_1[i0_4, i1_4, i2_4, i3_4])
                pad_temp_1[i0_4, i1_4, i2_4, i3_4] = rxplaceholder_2[i0_4, i1_4, i2_4, i3_4]
        # Reduction over rc (256 input channels); `i2_6 * 2` / `i3_6 * 2`
        # implement stride 2, and `i1_6 // 512 * 256 + rc` evaluates to `rc`
        # because i1_6 < 512.
        for i0_5, i1_5, i2_5, i3_5, i4, i5, i6 in tir.grid(1, 512, 28, 28, 256, 1, 1):
            with tir.block("conv2d_nchw"):
                i0_6, i1_6, i2_6, i3_6, rc, r0, r1 = tir.axis.remap("SSSSRRR", [i0_5, i1_5, i2_5, i3_5, i4, i5, i6])
                tir.reads(conv2d_nchw_1[i0_6, i1_6, i2_6, i3_6], pad_temp_1[i0_6, i1_6 // 512 * 256 + rc, i2_6 * 2 + r0, i3_6 * 2 + r1], rxplaceholder_3[i1_6, rc, r0, r1])
                tir.writes(conv2d_nchw_1[i0_6, i1_6, i2_6, i3_6])
                with tir.init():
                    conv2d_nchw_1[i0_6, i1_6, i2_6, i3_6] = tir.float32(0)
                conv2d_nchw_1[i0_6, i1_6, i2_6, i3_6] = conv2d_nchw_1[i0_6, i1_6, i2_6, i3_6] + pad_temp_1[i0_6, i1_6 // 512 * 256 + rc, i2_6 * 2 + r0, i3_6 * 2 + r1] * rxplaceholder_3[i1_6, rc, r0, r1]

    # 3x3 max pooling, stride 2, padding 1 (NCHW): (1, 64, 112, 112) ->
    # (1, 64, 56, 56).
    @tir.prim_func
    def max_pool2d(rxplaceholder_1: tir.Buffer[(1, 64, 112, 112), "float32"], tensor_1: tir.Buffer[(1, 64, 56, 56), "float32"]) -> None:
        # function attr dict
        tir.func_attr({"global_symbol": "max_pool2d", "tir.noalias": True})
        # body
        # with tir.block("root")
        pad_temp_1 = tir.alloc_buffer([1, 64, 114, 114], dtype="float32")
        # Pad 112x112 -> 114x114 with float32 lowest (-3.40e38) so padded
        # cells never win the max reduction.
        for i0, i1, i2, i3 in tir.grid(1, 64, 114, 114):
            with tir.block("pad_temp"):
                ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
                tir.reads(rxplaceholder_1[ax0, ax1, tir.max(ax2 - 1, 0) : tir.max(ax2 - 1, 0) + (tir.min(ax2, 112) - tir.max(ax2 - 1, 0)), tir.max(ax3 - 1, 0) : tir.max(ax3 - 1, 0) + (tir.min(ax3, 112) - tir.max(ax3 - 1, 0))])
                tir.writes(pad_temp_1[ax0, ax1, ax2, ax3])
                pad_temp_1[ax0, ax1, ax2, ax3] = tir.if_then_else(ax2 >= 1 and ax2 < 113 and ax3 >= 1 and ax3 < 113, rxplaceholder_1[ax0, ax1, ax2 - 1, ax3 - 1], tir.float32(-3.4028234663852886e+38), dtype="float32")
        # Max over the 3x3 window rv0/rv1; `ax2 * 2` / `ax3 * 2` give stride 2.
        for i0, i1, i2, i3, i4, i5 in tir.grid(1, 64, 56, 56, 3, 3):
            with tir.block("tensor"):
                ax0, ax1, ax2, ax3, rv0, rv1 = tir.axis.remap("SSSSRR", [i0, i1, i2, i3, i4, i5])
                tir.reads(tensor_1[ax0, ax1, ax2, ax3], pad_temp_1[ax0, ax1, ax2 * 2 + rv0, ax3 * 2 + rv1])
                tir.writes(tensor_1[ax0, ax1, ax2, ax3])
                with tir.init():
                    tensor_1[ax0, ax1, ax2, ax3] = tir.float32(-3.4028234663852886e+38)
                tensor_1[ax0, ax1, ax2, ax3] = tir.max(tensor_1[ax0, ax1, ax2, ax3], pad_temp_1[ax0, ax1, ax2 * 2 + rv0, ax3 * 2 + rv1])

    # Inference-mode batch norm over the 256 channels of (1, 256, 14, 14):
    #   T_add_2 = (x - r8) / sqrt(r9 + eps) * r6 + r7, eps ~= 2e-5,
    # where r6..r9 are the four per-channel (256,) parameter vectors.
    # Presumably r8/r9 are the moving mean/variance and r6/r7 scale/shift —
    # TODO confirm against the Relax caller. r8 and r9 are also copied
    # unchanged into the T_multiply_3 / T_multiply_4 outputs.
    @tir.prim_func
    def batch_norm6(rxplaceholder_5: tir.Buffer[(1, 256, 14, 14), "float32"], rxplaceholder_6: tir.Buffer[(256,), "float32"], rxplaceholder_7: tir.Buffer[(256,), "float32"], rxplaceholder_8: tir.Buffer[(256,), "float32"], rxplaceholder_9: tir.Buffer[(256,), "float32"], T_add_2: tir.Buffer[(1, 256, 14, 14), "float32"], T_multiply_3: tir.Buffer[(256,), "float32"], T_multiply_4: tir.Buffer[(256,), "float32"]) -> None:
        # function attr dict
        tir.func_attr({"global_symbol": "batch_norm6", "tir.noalias": True})
        # body
        # with tir.block("root")
        T_reshape_4 = tir.alloc_buffer([1, 256, 1, 1], dtype="float32")
        T_subtract_1 = tir.alloc_buffer([1, 256, 14, 14], dtype="float32")
        T_reshape_5 = tir.alloc_buffer([1, 256, 1, 1], dtype="float32")
        T_add_3 = tir.alloc_buffer([1, 256, 1, 1], dtype="float32")
        compute_1 = tir.alloc_buffer([1, 256, 1, 1], dtype="float32")
        T_divide_1 = tir.alloc_buffer([1, 256, 14, 14], dtype="float32")
        T_reshape_6 = tir.alloc_buffer([1, 256, 1, 1], dtype="float32")
        T_multiply_5 = tir.alloc_buffer([1, 256, 14, 14], dtype="float32")
        T_reshape_7 = tir.alloc_buffer([1, 256, 1, 1], dtype="float32")
        # Reshape r8 (256,) -> (1, 256, 1, 1); the modular index collapses to ax1.
        for i0, i1, i2, i3 in tir.grid(1, 256, 1, 1):
            with tir.block("T_reshape"):
                ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
                tir.reads(rxplaceholder_8[(ax1 + ax2 + ax3) % 256])
                tir.writes(T_reshape_4[ax0, ax1, ax2, ax3])
                T_reshape_4[ax0, ax1, ax2, ax3] = rxplaceholder_8[(ax0 * 256 + ax1 + ax2 + ax3) % 256]
        # x - r8, broadcast over the 14x14 spatial dims.
        for i0, i1, i2, i3 in tir.grid(1, 256, 14, 14):
            with tir.block("T_subtract"):
                ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
                tir.reads(rxplaceholder_5[ax0, ax1, ax2, ax3], T_reshape_4[ax0, ax1, 0, 0])
                tir.writes(T_subtract_1[ax0, ax1, ax2, ax3])
                T_subtract_1[ax0, ax1, ax2, ax3] = rxplaceholder_5[ax0, ax1, ax2, ax3] - T_reshape_4[ax0, ax1, 0, 0]
        # Reshape r9 (256,) -> (1, 256, 1, 1).
        for i0, i1, i2, i3 in tir.grid(1, 256, 1, 1):
            with tir.block("T_reshape_1"):
                ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
                tir.reads(rxplaceholder_9[(ax1 + ax2 + ax3) % 256])
                tir.writes(T_reshape_5[ax0, ax1, ax2, ax3])
                T_reshape_5[ax0, ax1, ax2, ax3] = rxplaceholder_9[(ax0 * 256 + ax1 + ax2 + ax3) % 256]
        # r9 + eps for numerical stability of the sqrt/divide below.
        for i0, i1, i2, i3 in tir.grid(1, 256, 1, 1):
            with tir.block("T_add"):
                ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
                tir.reads(T_reshape_5[ax0, ax1, ax2, ax3])
                tir.writes(T_add_3[ax0, ax1, ax2, ax3])
                T_add_3[ax0, ax1, ax2, ax3] = T_reshape_5[ax0, ax1, ax2, ax3] + tir.float32(1.9999999494757503e-05)
        # sqrt(r9 + eps)
        for i0, i1, i2, i3 in tir.grid(1, 256, 1, 1):
            with tir.block("compute"):
                i0_2, i1_2, i2_2, i3_2 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
                tir.reads(T_add_3[i0_2, i1_2, i2_2, i3_2])
                tir.writes(compute_1[i0_2, i1_2, i2_2, i3_2])
                compute_1[i0_2, i1_2, i2_2, i3_2] = tir.sqrt(T_add_3[i0_2, i1_2, i2_2, i3_2], dtype="float32")
        # (x - r8) / sqrt(r9 + eps)
        for i0_3, i1_3, i2_3, i3_3 in tir.grid(1, 256, 14, 14):
            with tir.block("T_divide"):
                ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0_3, i1_3, i2_3, i3_3])
                tir.reads(T_subtract_1[ax0, ax1, ax2, ax3], compute_1[ax0, ax1, 0, 0])
                tir.writes(T_divide_1[ax0, ax1, ax2, ax3])
                T_divide_1[ax0, ax1, ax2, ax3] = T_subtract_1[ax0, ax1, ax2, ax3] / compute_1[ax0, ax1, 0, 0]
        # Reshape r6 (256,) -> (1, 256, 1, 1).
        for i0_4, i1_4, i2_4, i3_4 in tir.grid(1, 256, 1, 1):
            with tir.block("T_reshape_2"):
                ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0_4, i1_4, i2_4, i3_4])
                tir.reads(rxplaceholder_6[(ax1 + ax2 + ax3) % 256])
                tir.writes(T_reshape_6[ax0, ax1, ax2, ax3])
                T_reshape_6[ax0, ax1, ax2, ax3] = rxplaceholder_6[(ax0 * 256 + ax1 + ax2 + ax3) % 256]
        # Multiply by r6 (per-channel scale).
        for i0_5, i1_5, i2_5, i3_5 in tir.grid(1, 256, 14, 14):
            with tir.block("T_multiply"):
                ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0_5, i1_5, i2_5, i3_5])
                tir.reads(T_divide_1[ax0, ax1, ax2, ax3], T_reshape_6[ax0, ax1, 0, 0])
                tir.writes(T_multiply_5[ax0, ax1, ax2, ax3])
                T_multiply_5[ax0, ax1, ax2, ax3] = T_divide_1[ax0, ax1, ax2, ax3] * T_reshape_6[ax0, ax1, 0, 0]
        # Reshape r7 (256,) -> (1, 256, 1, 1).
        for i0_6, i1_6, i2_6, i3_6 in tir.grid(1, 256, 1, 1):
            with tir.block("T_reshape_3"):
                ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0_6, i1_6, i2_6, i3_6])
                tir.reads(rxplaceholder_7[(ax1 + ax2 + ax3) % 256])
                tir.writes(T_reshape_7[ax0, ax1, ax2, ax3])
                T_reshape_7[ax0, ax1, ax2, ax3] = rxplaceholder_7[(ax0 * 256 + ax1 + ax2 + ax3) % 256]
        # Add r7 (per-channel shift) to produce the main output.
        for i0_7, i1_7, i2_7, i3_7 in tir.grid(1, 256, 14, 14):
            with tir.block("T_add_1"):
                ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0_7, i1_7, i2_7, i3_7])
                tir.reads(T_multiply_5[ax0, ax1, ax2, ax3], T_reshape_7[ax0, ax1, 0, 0])
                tir.writes(T_add_2[ax0, ax1, ax2, ax3])
                T_add_2[ax0, ax1, ax2, ax3] = T_multiply_5[ax0, ax1, ax2, ax3] + T_reshape_7[ax0, ax1, 0, 0]
        # Pass-through copies of r8 and r9 into the secondary outputs.
        for i0_8 in tir.serial(256):
            with tir.block("T_multiply_1"):
                ax0 = tir.axis.spatial(256, i0_8)
                tir.reads(rxplaceholder_8[ax0])
                tir.writes(T_multiply_3[ax0])
                T_multiply_3[ax0] = rxplaceholder_8[ax0]
        for i0_9 in tir.serial(256):
            with tir.block("T_multiply_2"):
                ax0 = tir.axis.spatial(256, i0_9)
                tir.reads(rxplaceholder_9[ax0])
                tir.writes(T_multiply_4[ax0])
                T_multiply_4[ax0] = rxplaceholder_9[ax0]

    # Inference-mode batch norm over the 512 channels of (1, 512, 7, 7):
    #   T_add_2 = (x - r8) / sqrt(r9 + eps) * r6 + r7, eps ~= 2e-5,
    # where r6..r9 are the four per-channel (512,) parameter vectors.
    # Presumably r8/r9 are the moving mean/variance and r6/r7 scale/shift —
    # TODO confirm against the Relax caller. r8 and r9 are also copied
    # unchanged into the T_multiply_3 / T_multiply_4 outputs.
    @tir.prim_func
    def batch_norm8(rxplaceholder_5: tir.Buffer[(1, 512, 7, 7), "float32"], rxplaceholder_6: tir.Buffer[(512,), "float32"], rxplaceholder_7: tir.Buffer[(512,), "float32"], rxplaceholder_8: tir.Buffer[(512,), "float32"], rxplaceholder_9: tir.Buffer[(512,), "float32"], T_add_2: tir.Buffer[(1, 512, 7, 7), "float32"], T_multiply_3: tir.Buffer[(512,), "float32"], T_multiply_4: tir.Buffer[(512,), "float32"]) -> None:
        # function attr dict
        tir.func_attr({"global_symbol": "batch_norm8", "tir.noalias": True})
        # body
        # with tir.block("root")
        T_reshape_4 = tir.alloc_buffer([1, 512, 1, 1], dtype="float32")
        T_subtract_1 = tir.alloc_buffer([1, 512, 7, 7], dtype="float32")
        T_reshape_5 = tir.alloc_buffer([1, 512, 1, 1], dtype="float32")
        T_add_3 = tir.alloc_buffer([1, 512, 1, 1], dtype="float32")
        compute_1 = tir.alloc_buffer([1, 512, 1, 1], dtype="float32")
        T_divide_1 = tir.alloc_buffer([1, 512, 7, 7], dtype="float32")
        T_reshape_6 = tir.alloc_buffer([1, 512, 1, 1], dtype="float32")
        T_multiply_5 = tir.alloc_buffer([1, 512, 7, 7], dtype="float32")
        T_reshape_7 = tir.alloc_buffer([1, 512, 1, 1], dtype="float32")
        # Reshape r8 (512,) -> (1, 512, 1, 1); the modular index collapses to ax1.
        for i0, i1, i2, i3 in tir.grid(1, 512, 1, 1):
            with tir.block("T_reshape"):
                ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
                tir.reads(rxplaceholder_8[(ax1 + ax2 + ax3) % 512])
                tir.writes(T_reshape_4[ax0, ax1, ax2, ax3])
                T_reshape_4[ax0, ax1, ax2, ax3] = rxplaceholder_8[(ax0 * 512 + ax1 + ax2 + ax3) % 512]
        # x - r8, broadcast over the 7x7 spatial dims.
        for i0, i1, i2, i3 in tir.grid(1, 512, 7, 7):
            with tir.block("T_subtract"):
                ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
                tir.reads(rxplaceholder_5[ax0, ax1, ax2, ax3], T_reshape_4[ax0, ax1, 0, 0])
                tir.writes(T_subtract_1[ax0, ax1, ax2, ax3])
                T_subtract_1[ax0, ax1, ax2, ax3] = rxplaceholder_5[ax0, ax1, ax2, ax3] - T_reshape_4[ax0, ax1, 0, 0]
        # Reshape r9 (512,) -> (1, 512, 1, 1).
        for i0, i1, i2, i3 in tir.grid(1, 512, 1, 1):
            with tir.block("T_reshape_1"):
                ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
                tir.reads(rxplaceholder_9[(ax1 + ax2 + ax3) % 512])
                tir.writes(T_reshape_5[ax0, ax1, ax2, ax3])
                T_reshape_5[ax0, ax1, ax2, ax3] = rxplaceholder_9[(ax0 * 512 + ax1 + ax2 + ax3) % 512]
        # r9 + eps for numerical stability of the sqrt/divide below.
        for i0, i1, i2, i3 in tir.grid(1, 512, 1, 1):
            with tir.block("T_add"):
                ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
                tir.reads(T_reshape_5[ax0, ax1, ax2, ax3])
                tir.writes(T_add_3[ax0, ax1, ax2, ax3])
                T_add_3[ax0, ax1, ax2, ax3] = T_reshape_5[ax0, ax1, ax2, ax3] + tir.float32(1.9999999494757503e-05)
        # sqrt(r9 + eps)
        for i0, i1, i2, i3 in tir.grid(1, 512, 1, 1):
            with tir.block("compute"):
                i0_2, i1_2, i2_2, i3_2 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
                tir.reads(T_add_3[i0_2, i1_2, i2_2, i3_2])
                tir.writes(compute_1[i0_2, i1_2, i2_2, i3_2])
                compute_1[i0_2, i1_2, i2_2, i3_2] = tir.sqrt(T_add_3[i0_2, i1_2, i2_2, i3_2], dtype="float32")
        # (x - r8) / sqrt(r9 + eps)
        for i0_3, i1_3, i2_3, i3_3 in tir.grid(1, 512, 7, 7):
            with tir.block("T_divide"):
                ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0_3, i1_3, i2_3, i3_3])
                tir.reads(T_subtract_1[ax0, ax1, ax2, ax3], compute_1[ax0, ax1, 0, 0])
                tir.writes(T_divide_1[ax0, ax1, ax2, ax3])
                T_divide_1[ax0, ax1, ax2, ax3] = T_subtract_1[ax0, ax1, ax2, ax3] / compute_1[ax0, ax1, 0, 0]
        # Reshape r6 (512,) -> (1, 512, 1, 1).
        for i0_4, i1_4, i2_4, i3_4 in tir.grid(1, 512, 1, 1):
            with tir.block("T_reshape_2"):
                ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0_4, i1_4, i2_4, i3_4])
                tir.reads(rxplaceholder_6[(ax1 + ax2 + ax3) % 512])
                tir.writes(T_reshape_6[ax0, ax1, ax2, ax3])
                T_reshape_6[ax0, ax1, ax2, ax3] = rxplaceholder_6[(ax0 * 512 + ax1 + ax2 + ax3) % 512]
        # Multiply by r6 (per-channel scale).
        for i0_5, i1_5, i2_5, i3_5 in tir.grid(1, 512, 7, 7):
            with tir.block("T_multiply"):
                ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0_5, i1_5, i2_5, i3_5])
                tir.reads(T_divide_1[ax0, ax1, ax2, ax3], T_reshape_6[ax0, ax1, 0, 0])
                tir.writes(T_multiply_5[ax0, ax1, ax2, ax3])
                T_multiply_5[ax0, ax1, ax2, ax3] = T_divide_1[ax0, ax1, ax2, ax3] * T_reshape_6[ax0, ax1, 0, 0]
        # Reshape r7 (512,) -> (1, 512, 1, 1).
        for i0_6, i1_6, i2_6, i3_6 in tir.grid(1, 512, 1, 1):
            with tir.block("T_reshape_3"):
                ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0_6, i1_6, i2_6, i3_6])
                tir.reads(rxplaceholder_7[(ax1 + ax2 + ax3) % 512])
                tir.writes(T_reshape_7[ax0, ax1, ax2, ax3])
                T_reshape_7[ax0, ax1, ax2, ax3] = rxplaceholder_7[(ax0 * 512 + ax1 + ax2 + ax3) % 512]
        # Add r7 (per-channel shift) to produce the main output.
        for i0_7, i1_7, i2_7, i3_7 in tir.grid(1, 512, 7, 7):
            with tir.block("T_add_1"):
                ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0_7, i1_7, i2_7, i3_7])
                tir.reads(T_multiply_5[ax0, ax1, ax2, ax3], T_reshape_7[ax0, ax1, 0, 0])
                tir.writes(T_add_2[ax0, ax1, ax2, ax3])
                T_add_2[ax0, ax1, ax2, ax3] = T_multiply_5[ax0, ax1, ax2, ax3] + T_reshape_7[ax0, ax1, 0, 0]
        # Pass-through copies of r8 and r9 into the secondary outputs.
        for i0_8 in tir.serial(512):
            with tir.block("T_multiply_1"):
                ax0 = tir.axis.spatial(512, i0_8)
                tir.reads(rxplaceholder_8[ax0])
                tir.writes(T_multiply_3[ax0])
                T_multiply_3[ax0] = rxplaceholder_8[ax0]
        for i0_9 in tir.serial(512):
            with tir.block("T_multiply_2"):
                ax0 = tir.axis.spatial(512, i0_9)
                tir.reads(rxplaceholder_9[ax0])
                tir.writes(T_multiply_4[ax0])
                T_multiply_4[ax0] = rxplaceholder_9[ax0]

    # 1x1 convolution, stride 2 (NCHW): input (1, 1024, 14, 14) x weight
    # (2048, 1024, 1, 1) -> output (1, 2048, 7, 7).
    @tir.prim_func
    def conv2d_nchw18(rxplaceholder_2: tir.Buffer[(1, 1024, 14, 14), "float32"], rxplaceholder_3: tir.Buffer[(2048, 1024, 1, 1), "float32"], conv2d_nchw_1: tir.Buffer[(1, 2048, 7, 7), "float32"]) -> None:
        # function attr dict
        tir.func_attr({"global_symbol": "conv2d_nchw18", "tir.noalias": True})
        # body
        # with tir.block("root")
        pad_temp_1 = tir.alloc_buffer([1, 1024, 14, 14], dtype="float32")
        # Identity copy of the input: a 1x1 kernel needs no spatial padding.
        for i0_3, i1_3, i2_3, i3_3 in tir.grid(1, 1024, 14, 14):
            with tir.block("pad_temp"):
                i0_4, i1_4, i2_4, i3_4 = tir.axis.remap("SSSS", [i0_3, i1_3, i2_3, i3_3])
                tir.reads(rxplaceholder_2[i0_4, i1_4, i2_4, i3_4])
                tir.writes(pad_temp_1[i0_4, i1_4, i2_4, i3_4])
                pad_temp_1[i0_4, i1_4, i2_4, i3_4] = rxplaceholder_2[i0_4, i1_4, i2_4, i3_4]
        # Reduction over rc (1024 input channels); `i2_6 * 2` / `i3_6 * 2`
        # implement stride 2, and `i1_6 // 2048 * 1024 + rc` evaluates to `rc`
        # because i1_6 < 2048.
        for i0_5, i1_5, i2_5, i3_5, i4, i5, i6 in tir.grid(1, 2048, 7, 7, 1024, 1, 1):
            with tir.block("conv2d_nchw"):
                i0_6, i1_6, i2_6, i3_6, rc, r0, r1 = tir.axis.remap("SSSSRRR", [i0_5, i1_5, i2_5, i3_5, i4, i5, i6])
                tir.reads(conv2d_nchw_1[i0_6, i1_6, i2_6, i3_6], pad_temp_1[i0_6, i1_6 // 2048 * 1024 + rc, i2_6 * 2 + r0, i3_6 * 2 + r1], rxplaceholder_3[i1_6, rc, r0, r1])
                tir.writes(conv2d_nchw_1[i0_6, i1_6, i2_6, i3_6])
                with tir.init():
                    conv2d_nchw_1[i0_6, i1_6, i2_6, i3_6] = tir.float32(0)
                conv2d_nchw_1[i0_6, i1_6, i2_6, i3_6] = conv2d_nchw_1[i0_6, i1_6, i2_6, i3_6] + pad_temp_1[i0_6, i1_6 // 2048 * 1024 + rc, i2_6 * 2 + r0, i3_6 * 2 + r1] * rxplaceholder_3[i1_6, rc, r0, r1]

    # Elementwise ReLU on (1, 1024, 14, 14): out = max(x, 0).
    @tir.prim_func
    def relu6(rxplaceholder_1: tir.Buffer[(1, 1024, 14, 14), "float32"], T_relu_1: tir.Buffer[(1, 1024, 14, 14), "float32"]) -> None:
        # function attr dict
        tir.func_attr({"global_symbol": "relu6", "tir.noalias": True})
        # body
        # with tir.block("root")
        for i0, i1, i2, i3 in tir.grid(1, 1024, 14, 14):
            with tir.block("T_relu"):
                ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
                tir.reads(rxplaceholder_1[ax0, ax1, ax2, ax3])
                tir.writes(T_relu_1[ax0, ax1, ax2, ax3])
                T_relu_1[ax0, ax1, ax2, ax3] = tir.max(rxplaceholder_1[ax0, ax1, ax2, ax3], tir.float32(0))

    # Global average pooling (NCHW): (1, 2048, 7, 7) -> (1, 2048, 1, 1).
    # Sums each 7x7 channel plane, then divides by the window area (49).
    @tir.prim_func
    def global_avg_pool2d(rxplaceholder_1: tir.Buffer[(1, 2048, 7, 7), "float32"], tensor_2: tir.Buffer[(1, 2048, 1, 1), "float32"]) -> None:
        # function attr dict
        tir.func_attr({"global_symbol": "global_avg_pool2d", "tir.noalias": True})
        # body
        # with tir.block("root")
        tensor_3 = tir.alloc_buffer([1, 2048, 1, 1], dtype="float32")
        # Sum over the 7x7 window rv0/rv1 per channel.
        for i0, i1, i2, i3, i4, i5 in tir.grid(1, 2048, 1, 1, 7, 7):
            with tir.block("tensor"):
                ax0, ax1, ax2, ax3, rv0, rv1 = tir.axis.remap("SSSSRR", [i0, i1, i2, i3, i4, i5])
                tir.reads(tensor_3[ax0, ax1, ax2, ax3], rxplaceholder_1[ax0, ax1, ax2 * 7 + rv0, ax3 * 7 + rv1])
                tir.writes(tensor_3[ax0, ax1, ax2, ax3])
                with tir.init():
                    tensor_3[ax0, ax1, ax2, ax3] = tir.float32(0)
                tensor_3[ax0, ax1, ax2, ax3] = tensor_3[ax0, ax1, ax2, ax3] + rxplaceholder_1[ax0, ax1, ax2 * 7 + rv0, ax3 * 7 + rv1]
        # Divide by the window area: both Select conditions are the constant
        # True, so the divisor simplifies to ((ax2+1)*7 - ax2*7) * ((ax3+1)*7 -
        # ax3*7) = 7 * 7 = 49 (a generated count_include_pad expression).
        for i0, i1, i2, i3 in tir.grid(1, 2048, 1, 1):
            with tir.block("tensor_1"):
                ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
                tir.reads(tensor_3[ax0, ax1, ax2, ax3])
                tir.writes(tensor_2[ax0, ax1, ax2, ax3])
                tensor_2[ax0, ax1, ax2, ax3] = tensor_3[ax0, ax1, ax2, ax3] / (tir.cast(tir.Select(True, (ax2 + 1) * 7, (ax2 + 1) * 7 + 1) - ax2 * 7, "float32") * tir.cast(tir.Select(True, (ax3 + 1) * 7, (ax3 + 1) * 7 + 1) - ax3 * 7, "float32"))

    # 1x1 convolution, stride 2 (NCHW): input (1, 512, 28, 28) x weight
    # (256, 512, 1, 1) -> output (1, 256, 14, 14).
    @tir.prim_func
    def conv2d_nchw10(rxplaceholder_2: tir.Buffer[(1, 512, 28, 28), "float32"], rxplaceholder_3: tir.Buffer[(256, 512, 1, 1), "float32"], conv2d_nchw_1: tir.Buffer[(1, 256, 14, 14), "float32"]) -> None:
        # function attr dict
        tir.func_attr({"global_symbol": "conv2d_nchw10", "tir.noalias": True})
        # body
        # with tir.block("root")
        pad_temp_1 = tir.alloc_buffer([1, 512, 28, 28], dtype="float32")
        # Identity copy of the input: a 1x1 kernel needs no spatial padding.
        for i0_3, i1_3, i2_3, i3_3 in tir.grid(1, 512, 28, 28):
            with tir.block("pad_temp"):
                i0_4, i1_4, i2_4, i3_4 = tir.axis.remap("SSSS", [i0_3, i1_3, i2_3, i3_3])
                tir.reads(rxplaceholder_2[i0_4, i1_4, i2_4, i3_4])
                tir.writes(pad_temp_1[i0_4, i1_4, i2_4, i3_4])
                pad_temp_1[i0_4, i1_4, i2_4, i3_4] = rxplaceholder_2[i0_4, i1_4, i2_4, i3_4]
        # Reduction over rc (512 input channels); `i2_6 * 2` / `i3_6 * 2`
        # implement stride 2, and `i1_6 // 256 * 512 + rc` evaluates to `rc`
        # because i1_6 < 256.
        for i0_5, i1_5, i2_5, i3_5, i4, i5, i6 in tir.grid(1, 256, 14, 14, 512, 1, 1):
            with tir.block("conv2d_nchw"):
                i0_6, i1_6, i2_6, i3_6, rc, r0, r1 = tir.axis.remap("SSSSRRR", [i0_5, i1_5, i2_5, i3_5, i4, i5, i6])
                tir.reads(conv2d_nchw_1[i0_6, i1_6, i2_6, i3_6], pad_temp_1[i0_6, i1_6 // 256 * 512 + rc, i2_6 * 2 + r0, i3_6 * 2 + r1], rxplaceholder_3[i1_6, rc, r0, r1])
                tir.writes(conv2d_nchw_1[i0_6, i1_6, i2_6, i3_6])
                with tir.init():
                    conv2d_nchw_1[i0_6, i1_6, i2_6, i3_6] = tir.float32(0)
                conv2d_nchw_1[i0_6, i1_6, i2_6, i3_6] = conv2d_nchw_1[i0_6, i1_6, i2_6, i3_6] + pad_temp_1[i0_6, i1_6 // 256 * 512 + rc, i2_6 * 2 + r0, i3_6 * 2 + r1] * rxplaceholder_3[i1_6, rc, r0, r1]

    @tir.prim_func
    def dense(rxplaceholder_2: tir.Buffer[(1, 2048), "float32"], rxplaceholder_3: tir.Buffer[(1000, 2048), "float32"], T_matmul_NT_1: tir.Buffer[(1, 1000), "float32"]) -> None:
        # Dense (fully-connected) layer as a matmul with a transposed second
        # operand ("NT"): out[i, j] = sum_k A[i, k] * B[j, k].
        # Shapes: (1, 2048) x (1000, 2048)^T -> (1, 1000).
        # function attr dict
        tir.func_attr({"global_symbol": "dense", "tir.noalias": True})
        # body
        # with tir.block("root")
        for i0, i1, i2 in tir.grid(1, 1000, 2048):
            with tir.block("T_matmul_NT"):
                # k is the reduction axis over the 2048 input features.
                i, j, k = tir.axis.remap("SSR", [i0, i1, i2])
                tir.reads(T_matmul_NT_1[i, j], rxplaceholder_2[i, k], rxplaceholder_3[j, k])
                tir.writes(T_matmul_NT_1[i, j])
                # Marks the weight buffer as layout-free so schedulers may
                # re-lay it out.
                tir.block_attr({"layout_free_placeholders":[rxplaceholder_3]})
                with tir.init():
                    T_matmul_NT_1[i, j] = tir.float32(0)
                T_matmul_NT_1[i, j] = T_matmul_NT_1[i, j] + rxplaceholder_2[i, k] * rxplaceholder_3[j, k]

    @tir.prim_func
    def batch_norm4(rxplaceholder_5: tir.Buffer[(1, 128, 28, 28), "float32"], rxplaceholder_6: tir.Buffer[(128,), "float32"], rxplaceholder_7: tir.Buffer[(128,), "float32"], rxplaceholder_8: tir.Buffer[(128,), "float32"], rxplaceholder_9: tir.Buffer[(128,), "float32"], T_add_2: tir.Buffer[(1, 128, 28, 28), "float32"], T_multiply_3: tir.Buffer[(128,), "float32"], T_multiply_4: tir.Buffer[(128,), "float32"]) -> None:
        # Inference-mode batch norm over (1, 128, 28, 28):
        #   T_add_2 = (x - r8) / sqrt(r9 + eps) * r6 + r7
        # where x = rxplaceholder_5, eps ~= 2e-5, and the per-channel vectors are
        # broadcast via (1, 128, 1, 1) reshapes. From their roles: r8 is the mean,
        # r9 the variance, r6 the scale (presumably gamma) and r7 the shift
        # (presumably beta) — TODO confirm naming against the caller.
        # T_multiply_3 / T_multiply_4 just copy r8 / r9 through as extra outputs.
        # function attr dict
        tir.func_attr({"global_symbol": "batch_norm4", "tir.noalias": True})
        # body
        # with tir.block("root")
        T_reshape_4 = tir.alloc_buffer([1, 128, 1, 1], dtype="float32")
        T_subtract_1 = tir.alloc_buffer([1, 128, 28, 28], dtype="float32")
        T_reshape_5 = tir.alloc_buffer([1, 128, 1, 1], dtype="float32")
        T_add_3 = tir.alloc_buffer([1, 128, 1, 1], dtype="float32")
        compute_1 = tir.alloc_buffer([1, 128, 1, 1], dtype="float32")
        T_divide_1 = tir.alloc_buffer([1, 128, 28, 28], dtype="float32")
        T_reshape_6 = tir.alloc_buffer([1, 128, 1, 1], dtype="float32")
        T_multiply_5 = tir.alloc_buffer([1, 128, 28, 28], dtype="float32")
        T_reshape_7 = tir.alloc_buffer([1, 128, 1, 1], dtype="float32")
        # Reshape the mean vector (128,) -> (1, 128, 1, 1) for broadcasting.
        # The index expression folds to ax1 since ax0/ax2/ax3 are always 0.
        for i0, i1, i2, i3 in tir.grid(1, 128, 1, 1):
            with tir.block("T_reshape"):
                ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
                tir.reads(rxplaceholder_8[(ax1 + ax2 + ax3) % 128])
                tir.writes(T_reshape_4[ax0, ax1, ax2, ax3])
                T_reshape_4[ax0, ax1, ax2, ax3] = rxplaceholder_8[(ax0 * 128 + ax1 + ax2 + ax3) % 128]
        # Center the input: x - mean (per channel).
        for i0, i1, i2, i3 in tir.grid(1, 128, 28, 28):
            with tir.block("T_subtract"):
                ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
                tir.reads(rxplaceholder_5[ax0, ax1, ax2, ax3], T_reshape_4[ax0, ax1, 0, 0])
                tir.writes(T_subtract_1[ax0, ax1, ax2, ax3])
                T_subtract_1[ax0, ax1, ax2, ax3] = rxplaceholder_5[ax0, ax1, ax2, ax3] - T_reshape_4[ax0, ax1, 0, 0]
        # Reshape the variance vector for broadcasting.
        for i0, i1, i2, i3 in tir.grid(1, 128, 1, 1):
            with tir.block("T_reshape_1"):
                ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
                tir.reads(rxplaceholder_9[(ax1 + ax2 + ax3) % 128])
                tir.writes(T_reshape_5[ax0, ax1, ax2, ax3])
                T_reshape_5[ax0, ax1, ax2, ax3] = rxplaceholder_9[(ax0 * 128 + ax1 + ax2 + ax3) % 128]
        # var + eps (epsilon keeps the subsequent sqrt/divide numerically safe).
        for i0, i1, i2, i3 in tir.grid(1, 128, 1, 1):
            with tir.block("T_add"):
                ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
                tir.reads(T_reshape_5[ax0, ax1, ax2, ax3])
                tir.writes(T_add_3[ax0, ax1, ax2, ax3])
                T_add_3[ax0, ax1, ax2, ax3] = T_reshape_5[ax0, ax1, ax2, ax3] + tir.float32(1.9999999494757503e-05)
        # sqrt(var + eps) -> per-channel standard deviation.
        for i0, i1, i2, i3 in tir.grid(1, 128, 1, 1):
            with tir.block("compute"):
                i0_2, i1_2, i2_2, i3_2 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
                tir.reads(T_add_3[i0_2, i1_2, i2_2, i3_2])
                tir.writes(compute_1[i0_2, i1_2, i2_2, i3_2])
                compute_1[i0_2, i1_2, i2_2, i3_2] = tir.sqrt(T_add_3[i0_2, i1_2, i2_2, i3_2], dtype="float32")
        # Normalize: (x - mean) / std.
        for i0_3, i1_3, i2_3, i3_3 in tir.grid(1, 128, 28, 28):
            with tir.block("T_divide"):
                ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0_3, i1_3, i2_3, i3_3])
                tir.reads(T_subtract_1[ax0, ax1, ax2, ax3], compute_1[ax0, ax1, 0, 0])
                tir.writes(T_divide_1[ax0, ax1, ax2, ax3])
                T_divide_1[ax0, ax1, ax2, ax3] = T_subtract_1[ax0, ax1, ax2, ax3] / compute_1[ax0, ax1, 0, 0]
        # Reshape the scale vector for broadcasting.
        for i0_4, i1_4, i2_4, i3_4 in tir.grid(1, 128, 1, 1):
            with tir.block("T_reshape_2"):
                ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0_4, i1_4, i2_4, i3_4])
                tir.reads(rxplaceholder_6[(ax1 + ax2 + ax3) % 128])
                tir.writes(T_reshape_6[ax0, ax1, ax2, ax3])
                T_reshape_6[ax0, ax1, ax2, ax3] = rxplaceholder_6[(ax0 * 128 + ax1 + ax2 + ax3) % 128]
        # Apply the per-channel scale.
        for i0_5, i1_5, i2_5, i3_5 in tir.grid(1, 128, 28, 28):
            with tir.block("T_multiply"):
                ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0_5, i1_5, i2_5, i3_5])
                tir.reads(T_divide_1[ax0, ax1, ax2, ax3], T_reshape_6[ax0, ax1, 0, 0])
                tir.writes(T_multiply_5[ax0, ax1, ax2, ax3])
                T_multiply_5[ax0, ax1, ax2, ax3] = T_divide_1[ax0, ax1, ax2, ax3] * T_reshape_6[ax0, ax1, 0, 0]
        # Reshape the shift vector for broadcasting.
        for i0_6, i1_6, i2_6, i3_6 in tir.grid(1, 128, 1, 1):
            with tir.block("T_reshape_3"):
                ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0_6, i1_6, i2_6, i3_6])
                tir.reads(rxplaceholder_7[(ax1 + ax2 + ax3) % 128])
                tir.writes(T_reshape_7[ax0, ax1, ax2, ax3])
                T_reshape_7[ax0, ax1, ax2, ax3] = rxplaceholder_7[(ax0 * 128 + ax1 + ax2 + ax3) % 128]
        # Apply the per-channel shift, producing the main output T_add_2.
        for i0_7, i1_7, i2_7, i3_7 in tir.grid(1, 128, 28, 28):
            with tir.block("T_add_1"):
                ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0_7, i1_7, i2_7, i3_7])
                tir.reads(T_multiply_5[ax0, ax1, ax2, ax3], T_reshape_7[ax0, ax1, 0, 0])
                tir.writes(T_add_2[ax0, ax1, ax2, ax3])
                T_add_2[ax0, ax1, ax2, ax3] = T_multiply_5[ax0, ax1, ax2, ax3] + T_reshape_7[ax0, ax1, 0, 0]
        # Pass-through copies of the mean/variance inputs to the secondary
        # outputs (the multiply in the block names folded to an identity).
        for i0_8 in tir.serial(128):
            with tir.block("T_multiply_1"):
                ax0 = tir.axis.spatial(128, i0_8)
                tir.reads(rxplaceholder_8[ax0])
                tir.writes(T_multiply_3[ax0])
                T_multiply_3[ax0] = rxplaceholder_8[ax0]
        for i0_9 in tir.serial(128):
            with tir.block("T_multiply_2"):
                ax0 = tir.axis.spatial(128, i0_9)
                tir.reads(rxplaceholder_9[ax0])
                tir.writes(T_multiply_4[ax0])
                T_multiply_4[ax0] = rxplaceholder_9[ax0]

    @tir.prim_func
    def conv2d_nchw6(rxplaceholder_2: tir.Buffer[(1, 128, 28, 28), "float32"], rxplaceholder_3: tir.Buffer[(128, 128, 3, 3), "float32"], conv2d_nchw_1: tir.Buffer[(1, 128, 28, 28), "float32"]) -> None:
        # NCHW conv2d: 3x3 kernel, stride 1, zero-padding 1 (28x28 -> 30x30
        # padded), 128 -> 128 channels, spatial size preserved at 28x28.
        # function attr dict
        tir.func_attr({"global_symbol": "conv2d_nchw6", "tir.noalias": True})
        # body
        # with tir.block("root")
        pad_temp_1 = tir.alloc_buffer([1, 128, 30, 30], dtype="float32")
        for i0_3, i1_3, i2_3, i3_3 in tir.grid(1, 128, 30, 30):
            with tir.block("pad_temp"):
                # Zero-pad a 1-pixel border; interior positions copy the input
                # shifted by the padding offset.
                i0_4, i1_4, i2_4, i3_4 = tir.axis.remap("SSSS", [i0_3, i1_3, i2_3, i3_3])
                tir.reads(rxplaceholder_2[i0_4, i1_4, tir.max(i2_4 - 1, 0) : tir.max(i2_4 - 1, 0) + (tir.min(i2_4, 28) - tir.max(i2_4 - 1, 0)), tir.max(i3_4 - 1, 0) : tir.max(i3_4 - 1, 0) + (tir.min(i3_4, 28) - tir.max(i3_4 - 1, 0))])
                tir.writes(pad_temp_1[i0_4, i1_4, i2_4, i3_4])
                pad_temp_1[i0_4, i1_4, i2_4, i3_4] = tir.if_then_else(i2_4 >= 1 and i2_4 < 29 and i3_4 >= 1 and i3_4 < 29, rxplaceholder_2[i0_4, i1_4, i2_4 - 1, i3_4 - 1], tir.float32(0), dtype="float32")
        for i0_5, i1_5, i2_5, i3_5, i4, i5, i6 in tir.grid(1, 128, 28, 28, 128, 3, 3):
            with tir.block("conv2d_nchw"):
                # rc reduces over input channels, r0 / r1 over the 3x3 kernel.
                # i1_6 // 128 * 128 folds to 0 since i1_6 < 128, so the channel
                # index is just rc.
                i0_6, i1_6, i2_6, i3_6, rc, r0, r1 = tir.axis.remap("SSSSRRR", [i0_5, i1_5, i2_5, i3_5, i4, i5, i6])
                tir.reads(conv2d_nchw_1[i0_6, i1_6, i2_6, i3_6], pad_temp_1[i0_6, i1_6 // 128 * 128 + rc, i2_6 + r0, i3_6 + r1], rxplaceholder_3[i1_6, rc, r0, r1])
                tir.writes(conv2d_nchw_1[i0_6, i1_6, i2_6, i3_6])
                with tir.init():
                    conv2d_nchw_1[i0_6, i1_6, i2_6, i3_6] = tir.float32(0)
                conv2d_nchw_1[i0_6, i1_6, i2_6, i3_6] = conv2d_nchw_1[i0_6, i1_6, i2_6, i3_6] + pad_temp_1[i0_6, i1_6 // 128 * 128 + rc, i2_6 + r0, i3_6 + r1] * rxplaceholder_3[i1_6, rc, r0, r1]

    @tir.prim_func
    def relu7(rxplaceholder_1: tir.Buffer[(1, 512, 7, 7), "float32"], T_relu_1: tir.Buffer[(1, 512, 7, 7), "float32"]) -> None:
        # Elementwise ReLU over a (1, 512, 7, 7) tensor: out = max(x, 0).
        # function attr dict
        tir.func_attr({"global_symbol": "relu7", "tir.noalias": True})
        # body
        # with tir.block("root")
        for i0, i1, i2, i3 in tir.grid(1, 512, 7, 7):
            with tir.block("T_relu"):
                ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
                tir.reads(rxplaceholder_1[ax0, ax1, ax2, ax3])
                tir.writes(T_relu_1[ax0, ax1, ax2, ax3])
                T_relu_1[ax0, ax1, ax2, ax3] = tir.max(rxplaceholder_1[ax0, ax1, ax2, ax3], tir.float32(0))

    @tir.prim_func
    def relu8(rxplaceholder_1: tir.Buffer[(1, 2048, 7, 7), "float32"], T_relu_1: tir.Buffer[(1, 2048, 7, 7), "float32"]) -> None:
        # Elementwise ReLU over a (1, 2048, 7, 7) tensor: out = max(x, 0).
        # function attr dict
        tir.func_attr({"global_symbol": "relu8", "tir.noalias": True})
        # body
        # with tir.block("root")
        for i0, i1, i2, i3 in tir.grid(1, 2048, 7, 7):
            with tir.block("T_relu"):
                ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
                tir.reads(rxplaceholder_1[ax0, ax1, ax2, ax3])
                tir.writes(T_relu_1[ax0, ax1, ax2, ax3])
                T_relu_1[ax0, ax1, ax2, ax3] = tir.max(rxplaceholder_1[ax0, ax1, ax2, ax3], tir.float32(0))

    @tir.prim_func
    def relu3(rxplaceholder_1: tir.Buffer[(1, 128, 28, 28), "float32"], T_relu_1: tir.Buffer[(1, 128, 28, 28), "float32"]) -> None:
        # Elementwise ReLU over a (1, 128, 28, 28) tensor: out = max(x, 0).
        # function attr dict
        tir.func_attr({"global_symbol": "relu3", "tir.noalias": True})
        # body
        # with tir.block("root")
        for i0, i1, i2, i3 in tir.grid(1, 128, 28, 28):
            with tir.block("T_relu"):
                ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
                tir.reads(rxplaceholder_1[ax0, ax1, ax2, ax3])
                tir.writes(T_relu_1[ax0, ax1, ax2, ax3])
                T_relu_1[ax0, ax1, ax2, ax3] = tir.max(rxplaceholder_1[ax0, ax1, ax2, ax3], tir.float32(0))

    @tir.prim_func
    def add3(rxplaceholder_2: tir.Buffer[(1, 2048, 7, 7), "float32"], rxplaceholder_3: tir.Buffer[(1, 2048, 7, 7), "float32"], T_add_1: tir.Buffer[(1, 2048, 7, 7), "float32"]) -> None:
        # Elementwise addition of two (1, 2048, 7, 7) tensors (e.g. a residual
        # connection): out = a + b.
        # function attr dict
        tir.func_attr({"global_symbol": "add3", "tir.noalias": True})
        # body
        # with tir.block("root")
        for i0, i1, i2, i3 in tir.grid(1, 2048, 7, 7):
            with tir.block("T_add"):
                ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
                tir.reads(rxplaceholder_2[ax0, ax1, ax2, ax3], rxplaceholder_3[ax0, ax1, ax2, ax3])
                tir.writes(T_add_1[ax0, ax1, ax2, ax3])
                T_add_1[ax0, ax1, ax2, ax3] = rxplaceholder_2[ax0, ax1, ax2, ax3] + rxplaceholder_3[ax0, ax1, ax2, ax3]

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment