GitHub gist YuchenJin/14cb5c8791d47e98203aba32b130d8fc — last active March 13, 2022 07:32. Save it to your computer or open it in GitHub Desktop.
Note: this file may contain hidden or bidirectional Unicode text that could be interpreted or compiled differently than it appears; review it in an editor that reveals hidden Unicode characters before trusting it.
from tvm.script import relax as R | |
from tvm import relax | |
resnet_mod_text = """ | |
@tvm.script.ir_module | |
class Module: | |
@tir.prim_func
def batch_norm2(rxplaceholder_5: tir.Buffer[(1, 64, 56, 56), "float32"], rxplaceholder_6: tir.Buffer[(64,), "float32"], rxplaceholder_7: tir.Buffer[(64,), "float32"], rxplaceholder_8: tir.Buffer[(64,), "float32"], rxplaceholder_9: tir.Buffer[(64,), "float32"], T_add_2: tir.Buffer[(1, 64, 56, 56), "float32"], T_multiply_3: tir.Buffer[(64,), "float32"], T_multiply_4: tir.Buffer[(64,), "float32"]) -> None:
    # Inference batch norm over a (1, 64, 56, 56) activation:
    #   T_add_2 = (x - rxplaceholder_8) / sqrt(rxplaceholder_9 + eps) * rxplaceholder_6 + rxplaceholder_7
    # rxplaceholder_8/9 (the stats that are subtracted / eps-adjusted) are also
    # copied through unchanged into T_multiply_3 / T_multiply_4.
    # function attr dict
    tir.func_attr({"global_symbol": "batch_norm2", "tir.noalias": True})
    # body
    # with tir.block("root")
    T_reshape_4 = tir.alloc_buffer([1, 64, 1, 1], dtype="float32")
    T_subtract_1 = tir.alloc_buffer([1, 64, 56, 56], dtype="float32")
    T_reshape_5 = tir.alloc_buffer([1, 64, 1, 1], dtype="float32")
    T_add_3 = tir.alloc_buffer([1, 64, 1, 1], dtype="float32")
    compute_1 = tir.alloc_buffer([1, 64, 1, 1], dtype="float32")
    T_divide_1 = tir.alloc_buffer([1, 64, 56, 56], dtype="float32")
    T_reshape_6 = tir.alloc_buffer([1, 64, 1, 1], dtype="float32")
    T_multiply_5 = tir.alloc_buffer([1, 64, 56, 56], dtype="float32")
    T_reshape_7 = tir.alloc_buffer([1, 64, 1, 1], dtype="float32")
    # reshape (64,) -> (1, 64, 1, 1) so the stats broadcast over H and W
    for i0, i1, i2, i3 in tir.grid(1, 64, 1, 1):
        with tir.block("T_reshape"):
            ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
            tir.reads(rxplaceholder_8[(ax1 + ax2 + ax3) % 64])
            tir.writes(T_reshape_4[ax0, ax1, ax2, ax3])
            T_reshape_4[ax0, ax1, ax2, ax3] = rxplaceholder_8[(ax0 * 64 + ax1 + ax2 + ax3) % 64]
    # center: x - stats
    for i0, i1, i2, i3 in tir.grid(1, 64, 56, 56):
        with tir.block("T_subtract"):
            ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
            tir.reads(rxplaceholder_5[ax0, ax1, ax2, ax3], T_reshape_4[ax0, ax1, 0, 0])
            tir.writes(T_subtract_1[ax0, ax1, ax2, ax3])
            T_subtract_1[ax0, ax1, ax2, ax3] = rxplaceholder_5[ax0, ax1, ax2, ax3] - T_reshape_4[ax0, ax1, 0, 0]
    for i0, i1, i2, i3 in tir.grid(1, 64, 1, 1):
        with tir.block("T_reshape_1"):
            ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
            tir.reads(rxplaceholder_9[(ax1 + ax2 + ax3) % 64])
            tir.writes(T_reshape_5[ax0, ax1, ax2, ax3])
            T_reshape_5[ax0, ax1, ax2, ax3] = rxplaceholder_9[(ax0 * 64 + ax1 + ax2 + ax3) % 64]
    # denominator: sqrt(stats + eps), eps ~= 2e-5
    for i0, i1, i2, i3 in tir.grid(1, 64, 1, 1):
        with tir.block("T_add"):
            ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
            tir.reads(T_reshape_5[ax0, ax1, ax2, ax3])
            tir.writes(T_add_3[ax0, ax1, ax2, ax3])
            T_add_3[ax0, ax1, ax2, ax3] = T_reshape_5[ax0, ax1, ax2, ax3] + tir.float32(1.9999999494757503e-05)
    for i0, i1, i2, i3 in tir.grid(1, 64, 1, 1):
        with tir.block("compute"):
            i0_2, i1_2, i2_2, i3_2 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
            tir.reads(T_add_3[i0_2, i1_2, i2_2, i3_2])
            tir.writes(compute_1[i0_2, i1_2, i2_2, i3_2])
            compute_1[i0_2, i1_2, i2_2, i3_2] = tir.sqrt(T_add_3[i0_2, i1_2, i2_2, i3_2], dtype="float32")
    for i0_3, i1_3, i2_3, i3_3 in tir.grid(1, 64, 56, 56):
        with tir.block("T_divide"):
            ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0_3, i1_3, i2_3, i3_3])
            tir.reads(T_subtract_1[ax0, ax1, ax2, ax3], compute_1[ax0, ax1, 0, 0])
            tir.writes(T_divide_1[ax0, ax1, ax2, ax3])
            T_divide_1[ax0, ax1, ax2, ax3] = T_subtract_1[ax0, ax1, ax2, ax3] / compute_1[ax0, ax1, 0, 0]
    # scale by rxplaceholder_6, then shift by rxplaceholder_7
    for i0_4, i1_4, i2_4, i3_4 in tir.grid(1, 64, 1, 1):
        with tir.block("T_reshape_2"):
            ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0_4, i1_4, i2_4, i3_4])
            tir.reads(rxplaceholder_6[(ax1 + ax2 + ax3) % 64])
            tir.writes(T_reshape_6[ax0, ax1, ax2, ax3])
            T_reshape_6[ax0, ax1, ax2, ax3] = rxplaceholder_6[(ax0 * 64 + ax1 + ax2 + ax3) % 64]
    for i0_5, i1_5, i2_5, i3_5 in tir.grid(1, 64, 56, 56):
        with tir.block("T_multiply"):
            ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0_5, i1_5, i2_5, i3_5])
            tir.reads(T_divide_1[ax0, ax1, ax2, ax3], T_reshape_6[ax0, ax1, 0, 0])
            tir.writes(T_multiply_5[ax0, ax1, ax2, ax3])
            T_multiply_5[ax0, ax1, ax2, ax3] = T_divide_1[ax0, ax1, ax2, ax3] * T_reshape_6[ax0, ax1, 0, 0]
    for i0_6, i1_6, i2_6, i3_6 in tir.grid(1, 64, 1, 1):
        with tir.block("T_reshape_3"):
            ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0_6, i1_6, i2_6, i3_6])
            tir.reads(rxplaceholder_7[(ax1 + ax2 + ax3) % 64])
            tir.writes(T_reshape_7[ax0, ax1, ax2, ax3])
            T_reshape_7[ax0, ax1, ax2, ax3] = rxplaceholder_7[(ax0 * 64 + ax1 + ax2 + ax3) % 64]
    for i0_7, i1_7, i2_7, i3_7 in tir.grid(1, 64, 56, 56):
        with tir.block("T_add_1"):
            ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0_7, i1_7, i2_7, i3_7])
            tir.reads(T_multiply_5[ax0, ax1, ax2, ax3], T_reshape_7[ax0, ax1, 0, 0])
            tir.writes(T_add_2[ax0, ax1, ax2, ax3])
            T_add_2[ax0, ax1, ax2, ax3] = T_multiply_5[ax0, ax1, ax2, ax3] + T_reshape_7[ax0, ax1, 0, 0]
    # pass-through copies of the per-channel stats
    for i0_8 in tir.serial(64):
        with tir.block("T_multiply_1"):
            ax0 = tir.axis.spatial(64, i0_8)
            tir.reads(rxplaceholder_8[ax0])
            tir.writes(T_multiply_3[ax0])
            T_multiply_3[ax0] = rxplaceholder_8[ax0]
    for i0_9 in tir.serial(64):
        with tir.block("T_multiply_2"):
            ax0 = tir.axis.spatial(64, i0_9)
            tir.reads(rxplaceholder_9[ax0])
            tir.writes(T_multiply_4[ax0])
            T_multiply_4[ax0] = rxplaceholder_9[ax0]
@tir.prim_func
def conv2d_nchw1(rxplaceholder_2: tir.Buffer[(1, 64, 56, 56), "float32"], rxplaceholder_3: tir.Buffer[(64, 64, 1, 1), "float32"], conv2d_nchw_1: tir.Buffer[(1, 64, 56, 56), "float32"]) -> None:
    # 1x1 conv, stride 1, no padding: (1, 64, 56, 56) x (64, 64, 1, 1) -> (1, 64, 56, 56).
    # function attr dict
    tir.func_attr({"global_symbol": "conv2d_nchw1", "tir.noalias": True})
    # body
    # with tir.block("root")
    pad_temp_1 = tir.alloc_buffer([1, 64, 56, 56], dtype="float32")
    # 1x1 kernel needs no padding; pad_temp is a straight copy of the input
    for i0, i1, i2, i3 in tir.grid(1, 64, 56, 56):
        with tir.block("pad_temp"):
            i0_2, i1_2, i2_2, i3_2 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
            tir.reads(rxplaceholder_2[i0_2, i1_2, i2_2, i3_2])
            tir.writes(pad_temp_1[i0_2, i1_2, i2_2, i3_2])
            pad_temp_1[i0_2, i1_2, i2_2, i3_2] = rxplaceholder_2[i0_2, i1_2, i2_2, i3_2]
    for i0_3, i1_3, i2_3, i3_3, i4, i5, i6 in tir.grid(1, 64, 56, 56, 64, 1, 1):
        with tir.block("conv2d_nchw"):
            nn, ff, yy, xx, rc, ry, rx = tir.axis.remap("SSSSRRR", [i0_3, i1_3, i2_3, i3_3, i4, i5, i6])
            tir.reads(conv2d_nchw_1[nn, ff, yy, xx], pad_temp_1[nn, rc, yy + ry, xx + rx], rxplaceholder_3[ff, rc, ry, rx])
            tir.writes(conv2d_nchw_1[nn, ff, yy, xx])
            tir.block_attr({"layout_free_placeholders": [rxplaceholder_3]})
            with tir.init():
                conv2d_nchw_1[nn, ff, yy, xx] = tir.float32(0)
            conv2d_nchw_1[nn, ff, yy, xx] = conv2d_nchw_1[nn, ff, yy, xx] + pad_temp_1[nn, rc, yy + ry, xx + rx] * rxplaceholder_3[ff, rc, ry, rx]
@tir.prim_func
def conv2d_nchw(rxplaceholder_2: tir.Buffer[(1, 3, 224, 224), "float32"], rxplaceholder_3: tir.Buffer[(64, 3, 7, 7), "float32"], conv2d_nchw_1: tir.Buffer[(1, 64, 112, 112), "float32"]) -> None:
    # 7x7 conv, stride 2, padding 3 (ResNet stem): (1, 3, 224, 224) -> (1, 64, 112, 112).
    # function attr dict
    tir.func_attr({"global_symbol": "conv2d_nchw", "tir.noalias": True})
    # body
    # with tir.block("root")
    pad_temp_1 = tir.alloc_buffer([1, 3, 230, 230], dtype="float32")
    # zero-pad the input by 3 on each spatial edge: 224 + 2*3 = 230
    for i0, i1, i2, i3 in tir.grid(1, 3, 230, 230):
        with tir.block("pad_temp"):
            i0_2, i1_2, i2_2, i3_2 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
            tir.reads(rxplaceholder_2[i0_2, i1_2, i2_2 - 3, i3_2 - 3])
            tir.writes(pad_temp_1[i0_2, i1_2, i2_2, i3_2])
            pad_temp_1[i0_2, i1_2, i2_2, i3_2] = tir.if_then_else(i2_2 >= 3 and i2_2 < 227 and i3_2 >= 3 and i3_2 < 227, rxplaceholder_2[i0_2, i1_2, i2_2 - 3, i3_2 - 3], tir.float32(0), dtype="float32")
    for i0_3, i1_3, i2_3, i3_3, i4, i5, i6 in tir.grid(1, 64, 112, 112, 3, 7, 7):
        with tir.block("conv2d_nchw"):
            nn, ff, yy, xx, rc, ry, rx = tir.axis.remap("SSSSRRR", [i0_3, i1_3, i2_3, i3_3, i4, i5, i6])
            tir.reads(conv2d_nchw_1[nn, ff, yy, xx], pad_temp_1[nn, rc, yy * 2 + ry, xx * 2 + rx], rxplaceholder_3[ff, rc, ry, rx])
            tir.writes(conv2d_nchw_1[nn, ff, yy, xx])
            tir.block_attr({"layout_free_placeholders": [rxplaceholder_3]})
            with tir.init():
                conv2d_nchw_1[nn, ff, yy, xx] = tir.float32(0)
            conv2d_nchw_1[nn, ff, yy, xx] = conv2d_nchw_1[nn, ff, yy, xx] + pad_temp_1[nn, rc, yy * 2 + ry, xx * 2 + rx] * rxplaceholder_3[ff, rc, ry, rx]
@tir.prim_func
def add(rxplaceholder_2: tir.Buffer[(1, 256, 56, 56), "float32"], rxplaceholder_3: tir.Buffer[(1, 256, 56, 56), "float32"], T_add_1: tir.Buffer[(1, 256, 56, 56), "float32"]) -> None:
    # Elementwise add of two (1, 256, 56, 56) tensors (residual connection).
    # function attr dict
    tir.func_attr({"global_symbol": "add", "tir.noalias": True})
    # body
    # with tir.block("root")
    for i0, i1, i2, i3 in tir.grid(1, 256, 56, 56):
        with tir.block("T_add"):
            ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
            tir.reads(rxplaceholder_2[ax0, ax1, ax2, ax3], rxplaceholder_3[ax0, ax1, ax2, ax3])
            tir.writes(T_add_1[ax0, ax1, ax2, ax3])
            T_add_1[ax0, ax1, ax2, ax3] = rxplaceholder_2[ax0, ax1, ax2, ax3] + rxplaceholder_3[ax0, ax1, ax2, ax3]
@tir.prim_func
def conv2d_nchw5(rxplaceholder_2: tir.Buffer[(1, 256, 56, 56), "float32"], rxplaceholder_3: tir.Buffer[(128, 256, 1, 1), "float32"], conv2d_nchw_1: tir.Buffer[(1, 128, 28, 28), "float32"]) -> None:
    # 1x1 conv, stride 2, no padding: (1, 256, 56, 56) -> (1, 128, 28, 28).
    # function attr dict
    tir.func_attr({"global_symbol": "conv2d_nchw5", "tir.noalias": True})
    # body
    # with tir.block("root")
    pad_temp_1 = tir.alloc_buffer([1, 256, 56, 56], dtype="float32")
    # 1x1 kernel needs no padding; pad_temp is a straight copy of the input
    for i0, i1, i2, i3 in tir.grid(1, 256, 56, 56):
        with tir.block("pad_temp"):
            i0_2, i1_2, i2_2, i3_2 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
            tir.reads(rxplaceholder_2[i0_2, i1_2, i2_2, i3_2])
            tir.writes(pad_temp_1[i0_2, i1_2, i2_2, i3_2])
            pad_temp_1[i0_2, i1_2, i2_2, i3_2] = rxplaceholder_2[i0_2, i1_2, i2_2, i3_2]
    for i0_3, i1_3, i2_3, i3_3, i4, i5, i6 in tir.grid(1, 128, 28, 28, 256, 1, 1):
        with tir.block("conv2d_nchw"):
            nn, ff, yy, xx, rc, ry, rx = tir.axis.remap("SSSSRRR", [i0_3, i1_3, i2_3, i3_3, i4, i5, i6])
            tir.reads(conv2d_nchw_1[nn, ff, yy, xx], pad_temp_1[nn, rc, yy * 2 + ry, xx * 2 + rx], rxplaceholder_3[ff, rc, ry, rx])
            tir.writes(conv2d_nchw_1[nn, ff, yy, xx])
            tir.block_attr({"layout_free_placeholders": [rxplaceholder_3]})
            with tir.init():
                conv2d_nchw_1[nn, ff, yy, xx] = tir.float32(0)
            conv2d_nchw_1[nn, ff, yy, xx] = conv2d_nchw_1[nn, ff, yy, xx] + pad_temp_1[nn, rc, yy * 2 + ry, xx * 2 + rx] * rxplaceholder_3[ff, rc, ry, rx]
@tir.prim_func
def relu(rxplaceholder_1: tir.Buffer[(1, 64, 112, 112), "float32"], T_relu_1: tir.Buffer[(1, 64, 112, 112), "float32"]) -> None:
    # Elementwise ReLU: T_relu_1 = max(x, 0) over a (1, 64, 112, 112) tensor.
    # function attr dict
    tir.func_attr({"global_symbol": "relu", "tir.noalias": True})
    # body
    # with tir.block("root")
    for i0, i1, i2, i3 in tir.grid(1, 64, 112, 112):
        with tir.block("T_relu"):
            ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
            tir.reads(rxplaceholder_1[ax0, ax1, ax2, ax3])
            tir.writes(T_relu_1[ax0, ax1, ax2, ax3])
            T_relu_1[ax0, ax1, ax2, ax3] = tir.max(rxplaceholder_1[ax0, ax1, ax2, ax3], tir.float32(0))
@tir.prim_func
def batch_norm5(rxplaceholder_5: tir.Buffer[(1, 512, 28, 28), "float32"], rxplaceholder_6: tir.Buffer[(512,), "float32"], rxplaceholder_7: tir.Buffer[(512,), "float32"], rxplaceholder_8: tir.Buffer[(512,), "float32"], rxplaceholder_9: tir.Buffer[(512,), "float32"], T_add_2: tir.Buffer[(1, 512, 28, 28), "float32"], T_multiply_3: tir.Buffer[(512,), "float32"], T_multiply_4: tir.Buffer[(512,), "float32"]) -> None:
    # Inference batch norm over a (1, 512, 28, 28) activation:
    #   T_add_2 = (x - rxplaceholder_8) / sqrt(rxplaceholder_9 + eps) * rxplaceholder_6 + rxplaceholder_7
    # rxplaceholder_8/9 are also copied through unchanged into T_multiply_3 / T_multiply_4.
    # function attr dict
    tir.func_attr({"global_symbol": "batch_norm5", "tir.noalias": True})
    # body
    # with tir.block("root")
    T_reshape_4 = tir.alloc_buffer([1, 512, 1, 1], dtype="float32")
    T_subtract_1 = tir.alloc_buffer([1, 512, 28, 28], dtype="float32")
    T_reshape_5 = tir.alloc_buffer([1, 512, 1, 1], dtype="float32")
    T_add_3 = tir.alloc_buffer([1, 512, 1, 1], dtype="float32")
    compute_1 = tir.alloc_buffer([1, 512, 1, 1], dtype="float32")
    T_divide_1 = tir.alloc_buffer([1, 512, 28, 28], dtype="float32")
    T_reshape_6 = tir.alloc_buffer([1, 512, 1, 1], dtype="float32")
    T_multiply_5 = tir.alloc_buffer([1, 512, 28, 28], dtype="float32")
    T_reshape_7 = tir.alloc_buffer([1, 512, 1, 1], dtype="float32")
    # reshape (512,) -> (1, 512, 1, 1) so the stats broadcast over H and W
    for i0, i1, i2, i3 in tir.grid(1, 512, 1, 1):
        with tir.block("T_reshape"):
            ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
            tir.reads(rxplaceholder_8[(ax1 + ax2 + ax3) % 512])
            tir.writes(T_reshape_4[ax0, ax1, ax2, ax3])
            T_reshape_4[ax0, ax1, ax2, ax3] = rxplaceholder_8[(ax0 * 512 + ax1 + ax2 + ax3) % 512]
    # center: x - stats
    for i0, i1, i2, i3 in tir.grid(1, 512, 28, 28):
        with tir.block("T_subtract"):
            ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
            tir.reads(rxplaceholder_5[ax0, ax1, ax2, ax3], T_reshape_4[ax0, ax1, 0, 0])
            tir.writes(T_subtract_1[ax0, ax1, ax2, ax3])
            T_subtract_1[ax0, ax1, ax2, ax3] = rxplaceholder_5[ax0, ax1, ax2, ax3] - T_reshape_4[ax0, ax1, 0, 0]
    for i0, i1, i2, i3 in tir.grid(1, 512, 1, 1):
        with tir.block("T_reshape_1"):
            ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
            tir.reads(rxplaceholder_9[(ax1 + ax2 + ax3) % 512])
            tir.writes(T_reshape_5[ax0, ax1, ax2, ax3])
            T_reshape_5[ax0, ax1, ax2, ax3] = rxplaceholder_9[(ax0 * 512 + ax1 + ax2 + ax3) % 512]
    # denominator: sqrt(stats + eps), eps ~= 2e-5
    for i0, i1, i2, i3 in tir.grid(1, 512, 1, 1):
        with tir.block("T_add"):
            ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
            tir.reads(T_reshape_5[ax0, ax1, ax2, ax3])
            tir.writes(T_add_3[ax0, ax1, ax2, ax3])
            T_add_3[ax0, ax1, ax2, ax3] = T_reshape_5[ax0, ax1, ax2, ax3] + tir.float32(1.9999999494757503e-05)
    for i0, i1, i2, i3 in tir.grid(1, 512, 1, 1):
        with tir.block("compute"):
            i0_2, i1_2, i2_2, i3_2 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
            tir.reads(T_add_3[i0_2, i1_2, i2_2, i3_2])
            tir.writes(compute_1[i0_2, i1_2, i2_2, i3_2])
            compute_1[i0_2, i1_2, i2_2, i3_2] = tir.sqrt(T_add_3[i0_2, i1_2, i2_2, i3_2], dtype="float32")
    for i0_3, i1_3, i2_3, i3_3 in tir.grid(1, 512, 28, 28):
        with tir.block("T_divide"):
            ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0_3, i1_3, i2_3, i3_3])
            tir.reads(T_subtract_1[ax0, ax1, ax2, ax3], compute_1[ax0, ax1, 0, 0])
            tir.writes(T_divide_1[ax0, ax1, ax2, ax3])
            T_divide_1[ax0, ax1, ax2, ax3] = T_subtract_1[ax0, ax1, ax2, ax3] / compute_1[ax0, ax1, 0, 0]
    # scale by rxplaceholder_6, then shift by rxplaceholder_7
    for i0_4, i1_4, i2_4, i3_4 in tir.grid(1, 512, 1, 1):
        with tir.block("T_reshape_2"):
            ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0_4, i1_4, i2_4, i3_4])
            tir.reads(rxplaceholder_6[(ax1 + ax2 + ax3) % 512])
            tir.writes(T_reshape_6[ax0, ax1, ax2, ax3])
            T_reshape_6[ax0, ax1, ax2, ax3] = rxplaceholder_6[(ax0 * 512 + ax1 + ax2 + ax3) % 512]
    for i0_5, i1_5, i2_5, i3_5 in tir.grid(1, 512, 28, 28):
        with tir.block("T_multiply"):
            ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0_5, i1_5, i2_5, i3_5])
            tir.reads(T_divide_1[ax0, ax1, ax2, ax3], T_reshape_6[ax0, ax1, 0, 0])
            tir.writes(T_multiply_5[ax0, ax1, ax2, ax3])
            T_multiply_5[ax0, ax1, ax2, ax3] = T_divide_1[ax0, ax1, ax2, ax3] * T_reshape_6[ax0, ax1, 0, 0]
    for i0_6, i1_6, i2_6, i3_6 in tir.grid(1, 512, 1, 1):
        with tir.block("T_reshape_3"):
            ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0_6, i1_6, i2_6, i3_6])
            tir.reads(rxplaceholder_7[(ax1 + ax2 + ax3) % 512])
            tir.writes(T_reshape_7[ax0, ax1, ax2, ax3])
            T_reshape_7[ax0, ax1, ax2, ax3] = rxplaceholder_7[(ax0 * 512 + ax1 + ax2 + ax3) % 512]
    for i0_7, i1_7, i2_7, i3_7 in tir.grid(1, 512, 28, 28):
        with tir.block("T_add_1"):
            ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0_7, i1_7, i2_7, i3_7])
            tir.reads(T_multiply_5[ax0, ax1, ax2, ax3], T_reshape_7[ax0, ax1, 0, 0])
            tir.writes(T_add_2[ax0, ax1, ax2, ax3])
            T_add_2[ax0, ax1, ax2, ax3] = T_multiply_5[ax0, ax1, ax2, ax3] + T_reshape_7[ax0, ax1, 0, 0]
    # pass-through copies of the per-channel stats
    for i0_8 in tir.serial(512):
        with tir.block("T_multiply_1"):
            ax0 = tir.axis.spatial(512, i0_8)
            tir.reads(rxplaceholder_8[ax0])
            tir.writes(T_multiply_3[ax0])
            T_multiply_3[ax0] = rxplaceholder_8[ax0]
    for i0_9 in tir.serial(512):
        with tir.block("T_multiply_2"):
            ax0 = tir.axis.spatial(512, i0_9)
            tir.reads(rxplaceholder_9[ax0])
            tir.writes(T_multiply_4[ax0])
            T_multiply_4[ax0] = rxplaceholder_9[ax0]
@tir.prim_func
def relu6(rxplaceholder_1: tir.Buffer[(1, 1024, 14, 14), "float32"], T_relu_1: tir.Buffer[(1, 1024, 14, 14), "float32"]) -> None:
    # Elementwise ReLU over a (1, 1024, 14, 14) tensor. NOTE: despite the name,
    # this is plain max(x, 0) — the "6" is just a generated-function suffix,
    # not a clamp at 6.
    # function attr dict
    tir.func_attr({"global_symbol": "relu6", "tir.noalias": True})
    # body
    # with tir.block("root")
    for i0, i1, i2, i3 in tir.grid(1, 1024, 14, 14):
        with tir.block("T_relu"):
            ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
            tir.reads(rxplaceholder_1[ax0, ax1, ax2, ax3])
            tir.writes(T_relu_1[ax0, ax1, ax2, ax3])
            T_relu_1[ax0, ax1, ax2, ax3] = tir.max(rxplaceholder_1[ax0, ax1, ax2, ax3], tir.float32(0))
@tir.prim_func
def conv2d_nchw18(rxplaceholder_2: tir.Buffer[(1, 1024, 14, 14), "float32"], rxplaceholder_3: tir.Buffer[(2048, 1024, 1, 1), "float32"], conv2d_nchw_1: tir.Buffer[(1, 2048, 7, 7), "float32"]) -> None:
    # 1x1 conv, stride 2, no padding: (1, 1024, 14, 14) -> (1, 2048, 7, 7).
    # function attr dict
    tir.func_attr({"global_symbol": "conv2d_nchw18", "tir.noalias": True})
    # body
    # with tir.block("root")
    pad_temp_1 = tir.alloc_buffer([1, 1024, 14, 14], dtype="float32")
    # 1x1 kernel needs no padding; pad_temp is a straight copy of the input
    for i0, i1, i2, i3 in tir.grid(1, 1024, 14, 14):
        with tir.block("pad_temp"):
            i0_2, i1_2, i2_2, i3_2 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
            tir.reads(rxplaceholder_2[i0_2, i1_2, i2_2, i3_2])
            tir.writes(pad_temp_1[i0_2, i1_2, i2_2, i3_2])
            pad_temp_1[i0_2, i1_2, i2_2, i3_2] = rxplaceholder_2[i0_2, i1_2, i2_2, i3_2]
    for i0_3, i1_3, i2_3, i3_3, i4, i5, i6 in tir.grid(1, 2048, 7, 7, 1024, 1, 1):
        with tir.block("conv2d_nchw"):
            nn, ff, yy, xx, rc, ry, rx = tir.axis.remap("SSSSRRR", [i0_3, i1_3, i2_3, i3_3, i4, i5, i6])
            tir.reads(conv2d_nchw_1[nn, ff, yy, xx], pad_temp_1[nn, rc, yy * 2 + ry, xx * 2 + rx], rxplaceholder_3[ff, rc, ry, rx])
            tir.writes(conv2d_nchw_1[nn, ff, yy, xx])
            tir.block_attr({"layout_free_placeholders": [rxplaceholder_3]})
            with tir.init():
                conv2d_nchw_1[nn, ff, yy, xx] = tir.float32(0)
            conv2d_nchw_1[nn, ff, yy, xx] = conv2d_nchw_1[nn, ff, yy, xx] + pad_temp_1[nn, rc, yy * 2 + ry, xx * 2 + rx] * rxplaceholder_3[ff, rc, ry, rx]
@tir.prim_func
def dense(rxplaceholder_2: tir.Buffer[(1, 2048), "float32"], rxplaceholder_3: tir.Buffer[(1000, 2048), "float32"], T_matmul_NT_1: tir.Buffer[(1, 1000), "float32"]) -> None:
    # Fully-connected layer as matmul with transposed weights (NT):
    # (1, 2048) x (1000, 2048)^T -> (1, 1000).
    # function attr dict
    tir.func_attr({"global_symbol": "dense", "tir.noalias": True})
    # body
    # with tir.block("root")
    for i0, i1, i2 in tir.grid(1, 1000, 2048):
        with tir.block("T_matmul_NT"):
            i, j, k = tir.axis.remap("SSR", [i0, i1, i2])
            tir.reads(T_matmul_NT_1[i, j], rxplaceholder_2[i, k], rxplaceholder_3[j, k])
            tir.writes(T_matmul_NT_1[i, j])
            tir.block_attr({"layout_free_placeholders": [rxplaceholder_3]})
            with tir.init():
                T_matmul_NT_1[i, j] = tir.float32(0)
            T_matmul_NT_1[i, j] = T_matmul_NT_1[i, j] + rxplaceholder_2[i, k] * rxplaceholder_3[j, k]
@tir.prim_func
def conv2d_nchw15(rxplaceholder_2: tir.Buffer[(1, 1024, 14, 14), "float32"], rxplaceholder_3: tir.Buffer[(512, 1024, 1, 1), "float32"], conv2d_nchw_1: tir.Buffer[(1, 512, 7, 7), "float32"]) -> None:
    # 1x1 conv, stride 2, no padding: (1, 1024, 14, 14) -> (1, 512, 7, 7).
    # function attr dict
    tir.func_attr({"global_symbol": "conv2d_nchw15", "tir.noalias": True})
    # body
    # with tir.block("root")
    pad_temp_1 = tir.alloc_buffer([1, 1024, 14, 14], dtype="float32")
    # 1x1 kernel needs no padding; pad_temp is a straight copy of the input
    for i0, i1, i2, i3 in tir.grid(1, 1024, 14, 14):
        with tir.block("pad_temp"):
            i0_2, i1_2, i2_2, i3_2 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
            tir.reads(rxplaceholder_2[i0_2, i1_2, i2_2, i3_2])
            tir.writes(pad_temp_1[i0_2, i1_2, i2_2, i3_2])
            pad_temp_1[i0_2, i1_2, i2_2, i3_2] = rxplaceholder_2[i0_2, i1_2, i2_2, i3_2]
    for i0_3, i1_3, i2_3, i3_3, i4, i5, i6 in tir.grid(1, 512, 7, 7, 1024, 1, 1):
        with tir.block("conv2d_nchw"):
            nn, ff, yy, xx, rc, ry, rx = tir.axis.remap("SSSSRRR", [i0_3, i1_3, i2_3, i3_3, i4, i5, i6])
            tir.reads(conv2d_nchw_1[nn, ff, yy, xx], pad_temp_1[nn, rc, yy * 2 + ry, xx * 2 + rx], rxplaceholder_3[ff, rc, ry, rx])
            tir.writes(conv2d_nchw_1[nn, ff, yy, xx])
            tir.block_attr({"layout_free_placeholders": [rxplaceholder_3]})
            with tir.init():
                conv2d_nchw_1[nn, ff, yy, xx] = tir.float32(0)
            conv2d_nchw_1[nn, ff, yy, xx] = conv2d_nchw_1[nn, ff, yy, xx] + pad_temp_1[nn, rc, yy * 2 + ry, xx * 2 + rx] * rxplaceholder_3[ff, rc, ry, rx]
@tir.prim_func
def conv2d_nchw13(rxplaceholder_2: tir.Buffer[(1, 512, 28, 28), "float32"], rxplaceholder_3: tir.Buffer[(1024, 512, 1, 1), "float32"], conv2d_nchw_1: tir.Buffer[(1, 1024, 14, 14), "float32"]) -> None:
    # 1x1 conv, stride 2, no padding: (1, 512, 28, 28) -> (1, 1024, 14, 14).
    # function attr dict
    tir.func_attr({"global_symbol": "conv2d_nchw13", "tir.noalias": True})
    # body
    # with tir.block("root")
    pad_temp_1 = tir.alloc_buffer([1, 512, 28, 28], dtype="float32")
    # 1x1 kernel needs no padding; pad_temp is a straight copy of the input
    for i0, i1, i2, i3 in tir.grid(1, 512, 28, 28):
        with tir.block("pad_temp"):
            i0_2, i1_2, i2_2, i3_2 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
            tir.reads(rxplaceholder_2[i0_2, i1_2, i2_2, i3_2])
            tir.writes(pad_temp_1[i0_2, i1_2, i2_2, i3_2])
            pad_temp_1[i0_2, i1_2, i2_2, i3_2] = rxplaceholder_2[i0_2, i1_2, i2_2, i3_2]
    for i0_3, i1_3, i2_3, i3_3, i4, i5, i6 in tir.grid(1, 1024, 14, 14, 512, 1, 1):
        with tir.block("conv2d_nchw"):
            nn, ff, yy, xx, rc, ry, rx = tir.axis.remap("SSSSRRR", [i0_3, i1_3, i2_3, i3_3, i4, i5, i6])
            tir.reads(conv2d_nchw_1[nn, ff, yy, xx], pad_temp_1[nn, rc, yy * 2 + ry, xx * 2 + rx], rxplaceholder_3[ff, rc, ry, rx])
            tir.writes(conv2d_nchw_1[nn, ff, yy, xx])
            tir.block_attr({"layout_free_placeholders": [rxplaceholder_3]})
            with tir.init():
                conv2d_nchw_1[nn, ff, yy, xx] = tir.float32(0)
            conv2d_nchw_1[nn, ff, yy, xx] = conv2d_nchw_1[nn, ff, yy, xx] + pad_temp_1[nn, rc, yy * 2 + ry, xx * 2 + rx] * rxplaceholder_3[ff, rc, ry, rx]
@tir.prim_func
def add3(rxplaceholder_2: tir.Buffer[(1, 2048, 7, 7), "float32"], rxplaceholder_3: tir.Buffer[(1, 2048, 7, 7), "float32"], T_add_1: tir.Buffer[(1, 2048, 7, 7), "float32"]) -> None:
    # Elementwise add of two (1, 2048, 7, 7) tensors (residual connection).
    # function attr dict
    tir.func_attr({"global_symbol": "add3", "tir.noalias": True})
    # body
    # with tir.block("root")
    for i0, i1, i2, i3 in tir.grid(1, 2048, 7, 7):
        with tir.block("T_add"):
            ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
            tir.reads(rxplaceholder_2[ax0, ax1, ax2, ax3], rxplaceholder_3[ax0, ax1, ax2, ax3])
            tir.writes(T_add_1[ax0, ax1, ax2, ax3])
            T_add_1[ax0, ax1, ax2, ax3] = rxplaceholder_2[ax0, ax1, ax2, ax3] + rxplaceholder_3[ax0, ax1, ax2, ax3]
@tir.prim_func
def bias_add(rxplaceholder_2: tir.Buffer[(1, 1000), "float32"], rxplaceholder_3: tir.Buffer[(1000,), "float32"], T_add_1: tir.Buffer[(1, 1000), "float32"]) -> None:
    # Add a (1000,) bias vector to each row of a (1, 1000) matrix.
    # function attr dict
    tir.func_attr({"global_symbol": "bias_add", "tir.noalias": True})
    # body
    # with tir.block("root")
    for i0, i1 in tir.grid(1, 1000):
        with tir.block("T_add"):
            ax0, ax1 = tir.axis.remap("SS", [i0, i1])
            tir.reads(rxplaceholder_2[ax0, ax1], rxplaceholder_3[ax1])
            tir.writes(T_add_1[ax0, ax1])
            T_add_1[ax0, ax1] = rxplaceholder_2[ax0, ax1] + rxplaceholder_3[ax1]
@tir.prim_func
def batch_norm(rxplaceholder_5: tir.Buffer[(1, 3, 224, 224), "float32"], rxplaceholder_6: tir.Buffer[(3,), "float32"], rxplaceholder_7: tir.Buffer[(3,), "float32"], rxplaceholder_8: tir.Buffer[(3,), "float32"], rxplaceholder_9: tir.Buffer[(3,), "float32"], T_add_2: tir.Buffer[(1, 3, 224, 224), "float32"], T_multiply_2: tir.Buffer[(3,), "float32"], T_multiply_3: tir.Buffer[(3,), "float32"]) -> None:
    # Inference batch norm over the (1, 3, 224, 224) input image. Unlike the
    # other batch_norm* kernels, this variant has no scale multiply:
    #   T_add_2 = (x - rxplaceholder_8) / sqrt(rxplaceholder_9 + eps) + rxplaceholder_7
    # (rxplaceholder_6 is unused in the output expression). rxplaceholder_8/9
    # are copied through unchanged into T_multiply_2 / T_multiply_3.
    # function attr dict
    tir.func_attr({"global_symbol": "batch_norm", "tir.noalias": True})
    # body
    # with tir.block("root")
    T_reshape_3 = tir.alloc_buffer([1, 3, 1, 1], dtype="float32")
    T_subtract_1 = tir.alloc_buffer([1, 3, 224, 224], dtype="float32")
    T_reshape_4 = tir.alloc_buffer([1, 3, 1, 1], dtype="float32")
    T_add_3 = tir.alloc_buffer([1, 3, 1, 1], dtype="float32")
    compute_1 = tir.alloc_buffer([1, 3, 1, 1], dtype="float32")
    T_divide_1 = tir.alloc_buffer([1, 3, 224, 224], dtype="float32")
    T_reshape_5 = tir.alloc_buffer([1, 3, 1, 1], dtype="float32")
    # reshape (3,) -> (1, 3, 1, 1) so the stats broadcast over H and W
    for i0, i1, i2, i3 in tir.grid(1, 3, 1, 1):
        with tir.block("T_reshape"):
            ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
            tir.reads(rxplaceholder_8[(ax1 + ax2 + ax3) % 3])
            tir.writes(T_reshape_3[ax0, ax1, ax2, ax3])
            T_reshape_3[ax0, ax1, ax2, ax3] = rxplaceholder_8[(ax0 * 3 + ax1 + ax2 + ax3) % 3]
    # center: x - stats
    for i0, i1, i2, i3 in tir.grid(1, 3, 224, 224):
        with tir.block("T_subtract"):
            ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
            tir.reads(rxplaceholder_5[ax0, ax1, ax2, ax3], T_reshape_3[ax0, ax1, 0, 0])
            tir.writes(T_subtract_1[ax0, ax1, ax2, ax3])
            T_subtract_1[ax0, ax1, ax2, ax3] = rxplaceholder_5[ax0, ax1, ax2, ax3] - T_reshape_3[ax0, ax1, 0, 0]
    for i0, i1, i2, i3 in tir.grid(1, 3, 1, 1):
        with tir.block("T_reshape_1"):
            ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
            tir.reads(rxplaceholder_9[(ax1 + ax2 + ax3) % 3])
            tir.writes(T_reshape_4[ax0, ax1, ax2, ax3])
            T_reshape_4[ax0, ax1, ax2, ax3] = rxplaceholder_9[(ax0 * 3 + ax1 + ax2 + ax3) % 3]
    # denominator: sqrt(stats + eps), eps ~= 2e-5
    for i0, i1, i2, i3 in tir.grid(1, 3, 1, 1):
        with tir.block("T_add"):
            ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
            tir.reads(T_reshape_4[ax0, ax1, ax2, ax3])
            tir.writes(T_add_3[ax0, ax1, ax2, ax3])
            T_add_3[ax0, ax1, ax2, ax3] = T_reshape_4[ax0, ax1, ax2, ax3] + tir.float32(1.9999999494757503e-05)
    for i0, i1, i2, i3 in tir.grid(1, 3, 1, 1):
        with tir.block("compute"):
            i0_2, i1_2, i2_2, i3_2 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
            tir.reads(T_add_3[i0_2, i1_2, i2_2, i3_2])
            tir.writes(compute_1[i0_2, i1_2, i2_2, i3_2])
            compute_1[i0_2, i1_2, i2_2, i3_2] = tir.sqrt(T_add_3[i0_2, i1_2, i2_2, i3_2], dtype="float32")
    for i0_3, i1_3, i2_3, i3_3 in tir.grid(1, 3, 224, 224):
        with tir.block("T_divide"):
            ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0_3, i1_3, i2_3, i3_3])
            tir.reads(T_subtract_1[ax0, ax1, ax2, ax3], compute_1[ax0, ax1, 0, 0])
            tir.writes(T_divide_1[ax0, ax1, ax2, ax3])
            T_divide_1[ax0, ax1, ax2, ax3] = T_subtract_1[ax0, ax1, ax2, ax3] / compute_1[ax0, ax1, 0, 0]
    # shift by rxplaceholder_7 (no scale step in this variant)
    for i0_4, i1_4, i2_4, i3_4 in tir.grid(1, 3, 1, 1):
        with tir.block("T_reshape_2"):
            ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0_4, i1_4, i2_4, i3_4])
            tir.reads(rxplaceholder_7[(ax1 + ax2 + ax3) % 3])
            tir.writes(T_reshape_5[ax0, ax1, ax2, ax3])
            T_reshape_5[ax0, ax1, ax2, ax3] = rxplaceholder_7[(ax0 * 3 + ax1 + ax2 + ax3) % 3]
    for i0_5, i1_5, i2_5, i3_5 in tir.grid(1, 3, 224, 224):
        with tir.block("T_add_1"):
            ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0_5, i1_5, i2_5, i3_5])
            tir.reads(T_divide_1[ax0, ax1, ax2, ax3], T_reshape_5[ax0, ax1, 0, 0])
            tir.writes(T_add_2[ax0, ax1, ax2, ax3])
            T_add_2[ax0, ax1, ax2, ax3] = T_divide_1[ax0, ax1, ax2, ax3] + T_reshape_5[ax0, ax1, 0, 0]
    # pass-through copies of the per-channel stats
    for i0_6 in tir.serial(3):
        with tir.block("T_multiply"):
            ax0 = tir.axis.spatial(3, i0_6)
            tir.reads(rxplaceholder_8[ax0])
            tir.writes(T_multiply_2[ax0])
            T_multiply_2[ax0] = rxplaceholder_8[ax0]
    for i0_7 in tir.serial(3):
        with tir.block("T_multiply_1"):
            ax0 = tir.axis.spatial(3, i0_7)
            tir.reads(rxplaceholder_9[ax0])
            tir.writes(T_multiply_3[ax0])
            T_multiply_3[ax0] = rxplaceholder_9[ax0]
@tir.prim_func
def relu5(rxplaceholder_1: tir.Buffer[(1, 256, 14, 14), "float32"], T_relu_1: tir.Buffer[(1, 256, 14, 14), "float32"]) -> None:
    # Elementwise ReLU: T_relu_1 = max(x, 0) over a (1, 256, 14, 14) tensor.
    # function attr dict
    tir.func_attr({"global_symbol": "relu5", "tir.noalias": True})
    # body
    # with tir.block("root")
    for i0, i1, i2, i3 in tir.grid(1, 256, 14, 14):
        with tir.block("T_relu"):
            ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
            tir.reads(rxplaceholder_1[ax0, ax1, ax2, ax3])
            tir.writes(T_relu_1[ax0, ax1, ax2, ax3])
            T_relu_1[ax0, ax1, ax2, ax3] = tir.max(rxplaceholder_1[ax0, ax1, ax2, ax3], tir.float32(0))
@tir.prim_func
def conv2d_nchw8(rxplaceholder_2: tir.Buffer[(1, 256, 56, 56), "float32"], rxplaceholder_3: tir.Buffer[(512, 256, 1, 1), "float32"], conv2d_nchw_1: tir.Buffer[(1, 512, 28, 28), "float32"]) -> None:
    # 1x1 conv, stride 2, no padding: (1, 256, 56, 56) -> (1, 512, 28, 28).
    # function attr dict
    tir.func_attr({"global_symbol": "conv2d_nchw8", "tir.noalias": True})
    # body
    # with tir.block("root")
    pad_temp_1 = tir.alloc_buffer([1, 256, 56, 56], dtype="float32")
    # 1x1 kernel needs no padding; pad_temp is a straight copy of the input
    for i0, i1, i2, i3 in tir.grid(1, 256, 56, 56):
        with tir.block("pad_temp"):
            i0_2, i1_2, i2_2, i3_2 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
            tir.reads(rxplaceholder_2[i0_2, i1_2, i2_2, i3_2])
            tir.writes(pad_temp_1[i0_2, i1_2, i2_2, i3_2])
            pad_temp_1[i0_2, i1_2, i2_2, i3_2] = rxplaceholder_2[i0_2, i1_2, i2_2, i3_2]
    for i0_3, i1_3, i2_3, i3_3, i4, i5, i6 in tir.grid(1, 512, 28, 28, 256, 1, 1):
        with tir.block("conv2d_nchw"):
            nn, ff, yy, xx, rc, ry, rx = tir.axis.remap("SSSSRRR", [i0_3, i1_3, i2_3, i3_3, i4, i5, i6])
            tir.reads(conv2d_nchw_1[nn, ff, yy, xx], pad_temp_1[nn, rc, yy * 2 + ry, xx * 2 + rx], rxplaceholder_3[ff, rc, ry, rx])
            tir.writes(conv2d_nchw_1[nn, ff, yy, xx])
            tir.block_attr({"layout_free_placeholders": [rxplaceholder_3]})
            with tir.init():
                conv2d_nchw_1[nn, ff, yy, xx] = tir.float32(0)
            conv2d_nchw_1[nn, ff, yy, xx] = conv2d_nchw_1[nn, ff, yy, xx] + pad_temp_1[nn, rc, yy * 2 + ry, xx * 2 + rx] * rxplaceholder_3[ff, rc, ry, rx]
@tir.prim_func
def relu1(rxplaceholder_1: tir.Buffer[(1, 64, 56, 56), "float32"], T_relu_1: tir.Buffer[(1, 64, 56, 56), "float32"]) -> None:
    # Elementwise ReLU: T_relu_1 = max(x, 0) over a (1, 64, 56, 56) tensor.
    # function attr dict
    tir.func_attr({"global_symbol": "relu1", "tir.noalias": True})
    # body
    # with tir.block("root")
    for i0, i1, i2, i3 in tir.grid(1, 64, 56, 56):
        with tir.block("T_relu"):
            ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
            tir.reads(rxplaceholder_1[ax0, ax1, ax2, ax3])
            tir.writes(T_relu_1[ax0, ax1, ax2, ax3])
            T_relu_1[ax0, ax1, ax2, ax3] = tir.max(rxplaceholder_1[ax0, ax1, ax2, ax3], tir.float32(0))
@tir.prim_func
def conv2d_nchw10(rxplaceholder_2: tir.Buffer[(1, 512, 28, 28), "float32"], rxplaceholder_3: tir.Buffer[(256, 512, 1, 1), "float32"], conv2d_nchw_1: tir.Buffer[(1, 256, 14, 14), "float32"]) -> None:
    # 1x1 conv, stride 2, no padding: (1, 512, 28, 28) -> (1, 256, 14, 14).
    # function attr dict
    tir.func_attr({"global_symbol": "conv2d_nchw10", "tir.noalias": True})
    # body
    # with tir.block("root")
    pad_temp_1 = tir.alloc_buffer([1, 512, 28, 28], dtype="float32")
    # 1x1 kernel needs no padding; pad_temp is a straight copy of the input
    for i0, i1, i2, i3 in tir.grid(1, 512, 28, 28):
        with tir.block("pad_temp"):
            i0_2, i1_2, i2_2, i3_2 = tir.axis.remap("SSSS", [i0, i1, i2, i3])
            tir.reads(rxplaceholder_2[i0_2, i1_2, i2_2, i3_2])
            tir.writes(pad_temp_1[i0_2, i1_2, i2_2, i3_2])
            pad_temp_1[i0_2, i1_2, i2_2, i3_2] = rxplaceholder_2[i0_2, i1_2, i2_2, i3_2]
    for i0_3, i1_3, i2_3, i3_3, i4, i5, i6 in tir.grid(1, 256, 14, 14, 512, 1, 1):
        with tir.block("conv2d_nchw"):
            nn, ff, yy, xx, rc, ry, rx = tir.axis.remap("SSSSRRR", [i0_3, i1_3, i2_3, i3_3, i4, i5, i6])
            tir.reads(conv2d_nchw_1[nn, ff, yy, xx], pad_temp_1[nn, rc, yy * 2 + ry, xx * 2 + rx], rxplaceholder_3[ff, rc, ry, rx])
            tir.writes(conv2d_nchw_1[nn, ff, yy, xx])
            tir.block_attr({"layout_free_placeholders": [rxplaceholder_3]})
            with tir.init():
                conv2d_nchw_1[nn, ff, yy, xx] = tir.float32(0)
            conv2d_nchw_1[nn, ff, yy, xx] = conv2d_nchw_1[nn, ff, yy, xx] + pad_temp_1[nn, rc, yy * 2 + ry, xx * 2 + rx] * rxplaceholder_3[ff, rc, ry, rx]
# 1x1 conv2d, NCHW layout, stride 1 (input row is yy+ry below):
#   (1, 2048, 7, 7) x weights (512, 2048, 1, 1) -> (1, 512, 7, 7)
@tir.prim_func | |
def conv2d_nchw19(rxplaceholder_2: tir.Buffer[(1, 2048, 7, 7), "float32"], rxplaceholder_3: tir.Buffer[(512, 2048, 1, 1), "float32"], conv2d_nchw_1: tir.Buffer[(1, 512, 7, 7), "float32"]) -> None: | |
# function attr dict | |
tir.func_attr({"global_symbol": "conv2d_nchw19", "tir.noalias": True}) | |
# body | |
# with tir.block("root") | |
# 1x1 kernel needs no halo, so "padding" is a straight copy of the input.
pad_temp_1 = tir.alloc_buffer([1, 2048, 7, 7], dtype="float32") | |
for i0, i1, i2, i3 in tir.grid(1, 2048, 7, 7): | |
with tir.block("pad_temp"): | |
i0_2, i1_2, i2_2, i3_2 = tir.axis.remap("SSSS", [i0, i1, i2, i3]) | |
tir.reads(rxplaceholder_2[i0_2, i1_2, i2_2, i3_2]) | |
tir.writes(pad_temp_1[i0_2, i1_2, i2_2, i3_2]) | |
pad_temp_1[i0_2, i1_2, i2_2, i3_2] = rxplaceholder_2[i0_2, i1_2, i2_2, i3_2] | |
# Reduction axes: rc over 2048 input channels; ry/rx are trivial for a 1x1 kernel.
for i0_3, i1_3, i2_3, i3_3, i4, i5, i6 in tir.grid(1, 512, 7, 7, 2048, 1, 1): | |
with tir.block("conv2d_nchw"): | |
nn, ff, yy, xx, rc, ry, rx = tir.axis.remap("SSSSRRR", [i0_3, i1_3, i2_3, i3_3, i4, i5, i6]) | |
tir.reads(conv2d_nchw_1[nn, ff, yy, xx], pad_temp_1[nn, rc, yy + ry, xx + rx], rxplaceholder_3[ff, rc, ry, rx]) | |
tir.writes(conv2d_nchw_1[nn, ff, yy, xx]) | |
tir.block_attr({"layout_free_placeholders":[rxplaceholder_3]}) | |
with tir.init(): | |
conv2d_nchw_1[nn, ff, yy, xx] = tir.float32(0) | |
conv2d_nchw_1[nn, ff, yy, xx] = conv2d_nchw_1[nn, ff, yy, xx] + pad_temp_1[nn, rc, yy + ry, xx + rx] * rxplaceholder_3[ff, rc, ry, rx] | |
# 3x3 conv2d, NCHW layout, stride 1, zero padding 1 (28x28 -> padded 30x30):
#   (1, 128, 28, 28) x weights (128, 128, 3, 3) -> (1, 128, 28, 28)
@tir.prim_func | |
def conv2d_nchw6(rxplaceholder_2: tir.Buffer[(1, 128, 28, 28), "float32"], rxplaceholder_3: tir.Buffer[(128, 128, 3, 3), "float32"], conv2d_nchw_1: tir.Buffer[(1, 128, 28, 28), "float32"]) -> None: | |
# function attr dict | |
tir.func_attr({"global_symbol": "conv2d_nchw6", "tir.noalias": True}) | |
# body | |
# with tir.block("root") | |
# Padded copy of the input: one-pixel zero border on each spatial side.
pad_temp_1 = tir.alloc_buffer([1, 128, 30, 30], dtype="float32") | |
for i0, i1, i2, i3 in tir.grid(1, 128, 30, 30): | |
with tir.block("pad_temp"): | |
i0_2, i1_2, i2_2, i3_2 = tir.axis.remap("SSSS", [i0, i1, i2, i3]) | |
tir.reads(rxplaceholder_2[i0_2, i1_2, i2_2 - 1, i3_2 - 1]) | |
tir.writes(pad_temp_1[i0_2, i1_2, i2_2, i3_2]) | |
# Interior [1, 29) maps back to the input; the border is filled with zeros.
pad_temp_1[i0_2, i1_2, i2_2, i3_2] = tir.if_then_else(i2_2 >= 1 and i2_2 < 29 and i3_2 >= 1 and i3_2 < 29, rxplaceholder_2[i0_2, i1_2, i2_2 - 1, i3_2 - 1], tir.float32(0), dtype="float32") | |
# Reduction axes: rc over 128 input channels, ry/rx over the 3x3 kernel window.
for i0_3, i1_3, i2_3, i3_3, i4, i5, i6 in tir.grid(1, 128, 28, 28, 128, 3, 3): | |
with tir.block("conv2d_nchw"): | |
nn, ff, yy, xx, rc, ry, rx = tir.axis.remap("SSSSRRR", [i0_3, i1_3, i2_3, i3_3, i4, i5, i6]) | |
tir.reads(conv2d_nchw_1[nn, ff, yy, xx], pad_temp_1[nn, rc, yy + ry, xx + rx], rxplaceholder_3[ff, rc, ry, rx]) | |
tir.writes(conv2d_nchw_1[nn, ff, yy, xx]) | |
tir.block_attr({"layout_free_placeholders":[rxplaceholder_3]}) | |
with tir.init(): | |
conv2d_nchw_1[nn, ff, yy, xx] = tir.float32(0) | |
conv2d_nchw_1[nn, ff, yy, xx] = conv2d_nchw_1[nn, ff, yy, xx] + pad_temp_1[nn, rc, yy + ry, xx + rx] * rxplaceholder_3[ff, rc, ry, rx] | |
# Inference-mode batch norm over (1, 256, 56, 56), eps ~= 2e-5:
#   T_add_2 = (x - rxplaceholder_8) / sqrt(rxplaceholder_9 + eps) * rxplaceholder_6 + rxplaceholder_7
# From the arithmetic below: _8 is subtracted (moving mean), _9 gets eps+sqrt (moving var),
# _6 multiplies (gamma/scale), _7 is added (beta/shift). T_multiply_3/4 return the
# mean/var vectors unchanged as the extra batch_norm outputs.
@tir.prim_func | |
def batch_norm3(rxplaceholder_5: tir.Buffer[(1, 256, 56, 56), "float32"], rxplaceholder_6: tir.Buffer[(256,), "float32"], rxplaceholder_7: tir.Buffer[(256,), "float32"], rxplaceholder_8: tir.Buffer[(256,), "float32"], rxplaceholder_9: tir.Buffer[(256,), "float32"], T_add_2: tir.Buffer[(1, 256, 56, 56), "float32"], T_multiply_3: tir.Buffer[(256,), "float32"], T_multiply_4: tir.Buffer[(256,), "float32"]) -> None: | |
# function attr dict | |
tir.func_attr({"global_symbol": "batch_norm3", "tir.noalias": True}) | |
# body | |
# with tir.block("root") | |
# Intermediates: (1, 256, 1, 1) reshapes broadcast per-channel vectors over H x W.
T_reshape_4 = tir.alloc_buffer([1, 256, 1, 1], dtype="float32") | |
T_subtract_1 = tir.alloc_buffer([1, 256, 56, 56], dtype="float32") | |
T_reshape_5 = tir.alloc_buffer([1, 256, 1, 1], dtype="float32") | |
T_add_3 = tir.alloc_buffer([1, 256, 1, 1], dtype="float32") | |
compute_1 = tir.alloc_buffer([1, 256, 1, 1], dtype="float32") | |
T_divide_1 = tir.alloc_buffer([1, 256, 56, 56], dtype="float32") | |
T_reshape_6 = tir.alloc_buffer([1, 256, 1, 1], dtype="float32") | |
T_multiply_5 = tir.alloc_buffer([1, 256, 56, 56], dtype="float32") | |
T_reshape_7 = tir.alloc_buffer([1, 256, 1, 1], dtype="float32") | |
# Reshape moving mean (256,) -> (1, 256, 1, 1).
for i0, i1, i2, i3 in tir.grid(1, 256, 1, 1): | |
with tir.block("T_reshape"): | |
ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3]) | |
tir.reads(rxplaceholder_8[(ax1 + ax2 + ax3) % 256]) | |
tir.writes(T_reshape_4[ax0, ax1, ax2, ax3]) | |
T_reshape_4[ax0, ax1, ax2, ax3] = rxplaceholder_8[(ax0 * 256 + ax1 + ax2 + ax3) % 256] | |
# x - mean, broadcast over the spatial dims.
for i0, i1, i2, i3 in tir.grid(1, 256, 56, 56): | |
with tir.block("T_subtract"): | |
ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3]) | |
tir.reads(rxplaceholder_5[ax0, ax1, ax2, ax3], T_reshape_4[ax0, ax1, 0, 0]) | |
tir.writes(T_subtract_1[ax0, ax1, ax2, ax3]) | |
T_subtract_1[ax0, ax1, ax2, ax3] = rxplaceholder_5[ax0, ax1, ax2, ax3] - T_reshape_4[ax0, ax1, 0, 0] | |
# Reshape moving variance (256,) -> (1, 256, 1, 1).
for i0, i1, i2, i3 in tir.grid(1, 256, 1, 1): | |
with tir.block("T_reshape_1"): | |
ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3]) | |
tir.reads(rxplaceholder_9[(ax1 + ax2 + ax3) % 256]) | |
tir.writes(T_reshape_5[ax0, ax1, ax2, ax3]) | |
T_reshape_5[ax0, ax1, ax2, ax3] = rxplaceholder_9[(ax0 * 256 + ax1 + ax2 + ax3) % 256] | |
# var + eps (eps constant baked in by codegen).
for i0, i1, i2, i3 in tir.grid(1, 256, 1, 1): | |
with tir.block("T_add"): | |
ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3]) | |
tir.reads(T_reshape_5[ax0, ax1, ax2, ax3]) | |
tir.writes(T_add_3[ax0, ax1, ax2, ax3]) | |
T_add_3[ax0, ax1, ax2, ax3] = T_reshape_5[ax0, ax1, ax2, ax3] + tir.float32(1.9999999494757503e-05) | |
# sqrt(var + eps).
for i0, i1, i2, i3 in tir.grid(1, 256, 1, 1): | |
with tir.block("compute"): | |
i0_2, i1_2, i2_2, i3_2 = tir.axis.remap("SSSS", [i0, i1, i2, i3]) | |
tir.reads(T_add_3[i0_2, i1_2, i2_2, i3_2]) | |
tir.writes(compute_1[i0_2, i1_2, i2_2, i3_2]) | |
compute_1[i0_2, i1_2, i2_2, i3_2] = tir.sqrt(T_add_3[i0_2, i1_2, i2_2, i3_2], dtype="float32") | |
# Normalize: (x - mean) / sqrt(var + eps).
for i0_3, i1_3, i2_3, i3_3 in tir.grid(1, 256, 56, 56): | |
with tir.block("T_divide"): | |
ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0_3, i1_3, i2_3, i3_3]) | |
tir.reads(T_subtract_1[ax0, ax1, ax2, ax3], compute_1[ax0, ax1, 0, 0]) | |
tir.writes(T_divide_1[ax0, ax1, ax2, ax3]) | |
T_divide_1[ax0, ax1, ax2, ax3] = T_subtract_1[ax0, ax1, ax2, ax3] / compute_1[ax0, ax1, 0, 0] | |
# Reshape gamma (256,) -> (1, 256, 1, 1).
for i0_4, i1_4, i2_4, i3_4 in tir.grid(1, 256, 1, 1): | |
with tir.block("T_reshape_2"): | |
ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0_4, i1_4, i2_4, i3_4]) | |
tir.reads(rxplaceholder_6[(ax1 + ax2 + ax3) % 256]) | |
tir.writes(T_reshape_6[ax0, ax1, ax2, ax3]) | |
T_reshape_6[ax0, ax1, ax2, ax3] = rxplaceholder_6[(ax0 * 256 + ax1 + ax2 + ax3) % 256] | |
# Scale by gamma.
for i0_5, i1_5, i2_5, i3_5 in tir.grid(1, 256, 56, 56): | |
with tir.block("T_multiply"): | |
ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0_5, i1_5, i2_5, i3_5]) | |
tir.reads(T_divide_1[ax0, ax1, ax2, ax3], T_reshape_6[ax0, ax1, 0, 0]) | |
tir.writes(T_multiply_5[ax0, ax1, ax2, ax3]) | |
T_multiply_5[ax0, ax1, ax2, ax3] = T_divide_1[ax0, ax1, ax2, ax3] * T_reshape_6[ax0, ax1, 0, 0] | |
# Reshape beta (256,) -> (1, 256, 1, 1).
for i0_6, i1_6, i2_6, i3_6 in tir.grid(1, 256, 1, 1): | |
with tir.block("T_reshape_3"): | |
ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0_6, i1_6, i2_6, i3_6]) | |
tir.reads(rxplaceholder_7[(ax1 + ax2 + ax3) % 256]) | |
tir.writes(T_reshape_7[ax0, ax1, ax2, ax3]) | |
T_reshape_7[ax0, ax1, ax2, ax3] = rxplaceholder_7[(ax0 * 256 + ax1 + ax2 + ax3) % 256] | |
# Shift by beta into the primary output T_add_2.
for i0_7, i1_7, i2_7, i3_7 in tir.grid(1, 256, 56, 56): | |
with tir.block("T_add_1"): | |
ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0_7, i1_7, i2_7, i3_7]) | |
tir.reads(T_multiply_5[ax0, ax1, ax2, ax3], T_reshape_7[ax0, ax1, 0, 0]) | |
tir.writes(T_add_2[ax0, ax1, ax2, ax3]) | |
T_add_2[ax0, ax1, ax2, ax3] = T_multiply_5[ax0, ax1, ax2, ax3] + T_reshape_7[ax0, ax1, 0, 0] | |
# Copy moving mean through unchanged (secondary output).
for i0_8 in tir.serial(256): | |
with tir.block("T_multiply_1"): | |
ax0 = tir.axis.spatial(256, i0_8) | |
tir.reads(rxplaceholder_8[ax0]) | |
tir.writes(T_multiply_3[ax0]) | |
T_multiply_3[ax0] = rxplaceholder_8[ax0] | |
# Copy moving variance through unchanged (tertiary output).
for i0_9 in tir.serial(256): | |
with tir.block("T_multiply_2"): | |
ax0 = tir.axis.spatial(256, i0_9) | |
tir.reads(rxplaceholder_9[ax0]) | |
tir.writes(T_multiply_4[ax0]) | |
T_multiply_4[ax0] = rxplaceholder_9[ax0] | |
# Inference-mode batch norm over (1, 128, 28, 28), eps ~= 2e-5:
#   T_add_2 = (x - rxplaceholder_8) / sqrt(rxplaceholder_9 + eps) * rxplaceholder_6 + rxplaceholder_7
# Same structure as batch_norm3, only with C=128 and 28x28 spatial extent.
# T_multiply_3/4 return the moving mean/var vectors unchanged.
@tir.prim_func | |
def batch_norm4(rxplaceholder_5: tir.Buffer[(1, 128, 28, 28), "float32"], rxplaceholder_6: tir.Buffer[(128,), "float32"], rxplaceholder_7: tir.Buffer[(128,), "float32"], rxplaceholder_8: tir.Buffer[(128,), "float32"], rxplaceholder_9: tir.Buffer[(128,), "float32"], T_add_2: tir.Buffer[(1, 128, 28, 28), "float32"], T_multiply_3: tir.Buffer[(128,), "float32"], T_multiply_4: tir.Buffer[(128,), "float32"]) -> None: | |
# function attr dict | |
tir.func_attr({"global_symbol": "batch_norm4", "tir.noalias": True}) | |
# body | |
# with tir.block("root") | |
# (1, 128, 1, 1) buffers broadcast per-channel vectors over the spatial dims.
T_reshape_4 = tir.alloc_buffer([1, 128, 1, 1], dtype="float32") | |
T_subtract_1 = tir.alloc_buffer([1, 128, 28, 28], dtype="float32") | |
T_reshape_5 = tir.alloc_buffer([1, 128, 1, 1], dtype="float32") | |
T_add_3 = tir.alloc_buffer([1, 128, 1, 1], dtype="float32") | |
compute_1 = tir.alloc_buffer([1, 128, 1, 1], dtype="float32") | |
T_divide_1 = tir.alloc_buffer([1, 128, 28, 28], dtype="float32") | |
T_reshape_6 = tir.alloc_buffer([1, 128, 1, 1], dtype="float32") | |
T_multiply_5 = tir.alloc_buffer([1, 128, 28, 28], dtype="float32") | |
T_reshape_7 = tir.alloc_buffer([1, 128, 1, 1], dtype="float32") | |
# Reshape moving mean (128,) -> (1, 128, 1, 1).
for i0, i1, i2, i3 in tir.grid(1, 128, 1, 1): | |
with tir.block("T_reshape"): | |
ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3]) | |
tir.reads(rxplaceholder_8[(ax1 + ax2 + ax3) % 128]) | |
tir.writes(T_reshape_4[ax0, ax1, ax2, ax3]) | |
T_reshape_4[ax0, ax1, ax2, ax3] = rxplaceholder_8[(ax0 * 128 + ax1 + ax2 + ax3) % 128] | |
# x - mean.
for i0, i1, i2, i3 in tir.grid(1, 128, 28, 28): | |
with tir.block("T_subtract"): | |
ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3]) | |
tir.reads(rxplaceholder_5[ax0, ax1, ax2, ax3], T_reshape_4[ax0, ax1, 0, 0]) | |
tir.writes(T_subtract_1[ax0, ax1, ax2, ax3]) | |
T_subtract_1[ax0, ax1, ax2, ax3] = rxplaceholder_5[ax0, ax1, ax2, ax3] - T_reshape_4[ax0, ax1, 0, 0] | |
# Reshape moving variance (128,) -> (1, 128, 1, 1).
for i0, i1, i2, i3 in tir.grid(1, 128, 1, 1): | |
with tir.block("T_reshape_1"): | |
ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3]) | |
tir.reads(rxplaceholder_9[(ax1 + ax2 + ax3) % 128]) | |
tir.writes(T_reshape_5[ax0, ax1, ax2, ax3]) | |
T_reshape_5[ax0, ax1, ax2, ax3] = rxplaceholder_9[(ax0 * 128 + ax1 + ax2 + ax3) % 128] | |
# var + eps.
for i0, i1, i2, i3 in tir.grid(1, 128, 1, 1): | |
with tir.block("T_add"): | |
ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3]) | |
tir.reads(T_reshape_5[ax0, ax1, ax2, ax3]) | |
tir.writes(T_add_3[ax0, ax1, ax2, ax3]) | |
T_add_3[ax0, ax1, ax2, ax3] = T_reshape_5[ax0, ax1, ax2, ax3] + tir.float32(1.9999999494757503e-05) | |
# sqrt(var + eps).
for i0, i1, i2, i3 in tir.grid(1, 128, 1, 1): | |
with tir.block("compute"): | |
i0_2, i1_2, i2_2, i3_2 = tir.axis.remap("SSSS", [i0, i1, i2, i3]) | |
tir.reads(T_add_3[i0_2, i1_2, i2_2, i3_2]) | |
tir.writes(compute_1[i0_2, i1_2, i2_2, i3_2]) | |
compute_1[i0_2, i1_2, i2_2, i3_2] = tir.sqrt(T_add_3[i0_2, i1_2, i2_2, i3_2], dtype="float32") | |
# Normalize: (x - mean) / sqrt(var + eps).
for i0_3, i1_3, i2_3, i3_3 in tir.grid(1, 128, 28, 28): | |
with tir.block("T_divide"): | |
ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0_3, i1_3, i2_3, i3_3]) | |
tir.reads(T_subtract_1[ax0, ax1, ax2, ax3], compute_1[ax0, ax1, 0, 0]) | |
tir.writes(T_divide_1[ax0, ax1, ax2, ax3]) | |
T_divide_1[ax0, ax1, ax2, ax3] = T_subtract_1[ax0, ax1, ax2, ax3] / compute_1[ax0, ax1, 0, 0] | |
# Reshape gamma (128,) -> (1, 128, 1, 1).
for i0_4, i1_4, i2_4, i3_4 in tir.grid(1, 128, 1, 1): | |
with tir.block("T_reshape_2"): | |
ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0_4, i1_4, i2_4, i3_4]) | |
tir.reads(rxplaceholder_6[(ax1 + ax2 + ax3) % 128]) | |
tir.writes(T_reshape_6[ax0, ax1, ax2, ax3]) | |
T_reshape_6[ax0, ax1, ax2, ax3] = rxplaceholder_6[(ax0 * 128 + ax1 + ax2 + ax3) % 128] | |
# Scale by gamma.
for i0_5, i1_5, i2_5, i3_5 in tir.grid(1, 128, 28, 28): | |
with tir.block("T_multiply"): | |
ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0_5, i1_5, i2_5, i3_5]) | |
tir.reads(T_divide_1[ax0, ax1, ax2, ax3], T_reshape_6[ax0, ax1, 0, 0]) | |
tir.writes(T_multiply_5[ax0, ax1, ax2, ax3]) | |
T_multiply_5[ax0, ax1, ax2, ax3] = T_divide_1[ax0, ax1, ax2, ax3] * T_reshape_6[ax0, ax1, 0, 0] | |
# Reshape beta (128,) -> (1, 128, 1, 1).
for i0_6, i1_6, i2_6, i3_6 in tir.grid(1, 128, 1, 1): | |
with tir.block("T_reshape_3"): | |
ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0_6, i1_6, i2_6, i3_6]) | |
tir.reads(rxplaceholder_7[(ax1 + ax2 + ax3) % 128]) | |
tir.writes(T_reshape_7[ax0, ax1, ax2, ax3]) | |
T_reshape_7[ax0, ax1, ax2, ax3] = rxplaceholder_7[(ax0 * 128 + ax1 + ax2 + ax3) % 128] | |
# Shift by beta into the primary output T_add_2.
for i0_7, i1_7, i2_7, i3_7 in tir.grid(1, 128, 28, 28): | |
with tir.block("T_add_1"): | |
ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0_7, i1_7, i2_7, i3_7]) | |
tir.reads(T_multiply_5[ax0, ax1, ax2, ax3], T_reshape_7[ax0, ax1, 0, 0]) | |
tir.writes(T_add_2[ax0, ax1, ax2, ax3]) | |
T_add_2[ax0, ax1, ax2, ax3] = T_multiply_5[ax0, ax1, ax2, ax3] + T_reshape_7[ax0, ax1, 0, 0] | |
# Copy moving mean through unchanged (secondary output).
for i0_8 in tir.serial(128): | |
with tir.block("T_multiply_1"): | |
ax0 = tir.axis.spatial(128, i0_8) | |
tir.reads(rxplaceholder_8[ax0]) | |
tir.writes(T_multiply_3[ax0]) | |
T_multiply_3[ax0] = rxplaceholder_8[ax0] | |
# Copy moving variance through unchanged (tertiary output).
for i0_9 in tir.serial(128): | |
with tir.block("T_multiply_2"): | |
ax0 = tir.axis.spatial(128, i0_9) | |
tir.reads(rxplaceholder_9[ax0]) | |
tir.writes(T_multiply_4[ax0]) | |
T_multiply_4[ax0] = rxplaceholder_9[ax0] | |
# 1x1 conv2d, NCHW layout, stride 1 (input row is yy+ry below):
#   (1, 64, 56, 56) x weights (256, 64, 1, 1) -> (1, 256, 56, 56)
@tir.prim_func | |
def conv2d_nchw3(rxplaceholder_2: tir.Buffer[(1, 64, 56, 56), "float32"], rxplaceholder_3: tir.Buffer[(256, 64, 1, 1), "float32"], conv2d_nchw_1: tir.Buffer[(1, 256, 56, 56), "float32"]) -> None: | |
# function attr dict | |
tir.func_attr({"global_symbol": "conv2d_nchw3", "tir.noalias": True}) | |
# body | |
# with tir.block("root") | |
# 1x1 kernel needs no halo, so "padding" is a straight copy of the input.
pad_temp_1 = tir.alloc_buffer([1, 64, 56, 56], dtype="float32") | |
for i0, i1, i2, i3 in tir.grid(1, 64, 56, 56): | |
with tir.block("pad_temp"): | |
i0_2, i1_2, i2_2, i3_2 = tir.axis.remap("SSSS", [i0, i1, i2, i3]) | |
tir.reads(rxplaceholder_2[i0_2, i1_2, i2_2, i3_2]) | |
tir.writes(pad_temp_1[i0_2, i1_2, i2_2, i3_2]) | |
pad_temp_1[i0_2, i1_2, i2_2, i3_2] = rxplaceholder_2[i0_2, i1_2, i2_2, i3_2] | |
# Reduction axes: rc over 64 input channels; ry/rx are trivial for a 1x1 kernel.
for i0_3, i1_3, i2_3, i3_3, i4, i5, i6 in tir.grid(1, 256, 56, 56, 64, 1, 1): | |
with tir.block("conv2d_nchw"): | |
nn, ff, yy, xx, rc, ry, rx = tir.axis.remap("SSSSRRR", [i0_3, i1_3, i2_3, i3_3, i4, i5, i6]) | |
tir.reads(conv2d_nchw_1[nn, ff, yy, xx], pad_temp_1[nn, rc, yy + ry, xx + rx], rxplaceholder_3[ff, rc, ry, rx]) | |
tir.writes(conv2d_nchw_1[nn, ff, yy, xx]) | |
tir.block_attr({"layout_free_placeholders":[rxplaceholder_3]}) | |
with tir.init(): | |
conv2d_nchw_1[nn, ff, yy, xx] = tir.float32(0) | |
conv2d_nchw_1[nn, ff, yy, xx] = conv2d_nchw_1[nn, ff, yy, xx] + pad_temp_1[nn, rc, yy + ry, xx + rx] * rxplaceholder_3[ff, rc, ry, rx] | |
# 1x1 conv2d, NCHW layout, stride 1 (input row is yy+ry below):
#   (1, 256, 14, 14) x weights (1024, 256, 1, 1) -> (1, 1024, 14, 14)
@tir.prim_func | |
def conv2d_nchw12(rxplaceholder_2: tir.Buffer[(1, 256, 14, 14), "float32"], rxplaceholder_3: tir.Buffer[(1024, 256, 1, 1), "float32"], conv2d_nchw_1: tir.Buffer[(1, 1024, 14, 14), "float32"]) -> None: | |
# function attr dict | |
tir.func_attr({"global_symbol": "conv2d_nchw12", "tir.noalias": True}) | |
# body | |
# with tir.block("root") | |
# 1x1 kernel needs no halo, so "padding" is a straight copy of the input.
pad_temp_1 = tir.alloc_buffer([1, 256, 14, 14], dtype="float32") | |
for i0, i1, i2, i3 in tir.grid(1, 256, 14, 14): | |
with tir.block("pad_temp"): | |
i0_2, i1_2, i2_2, i3_2 = tir.axis.remap("SSSS", [i0, i1, i2, i3]) | |
tir.reads(rxplaceholder_2[i0_2, i1_2, i2_2, i3_2]) | |
tir.writes(pad_temp_1[i0_2, i1_2, i2_2, i3_2]) | |
pad_temp_1[i0_2, i1_2, i2_2, i3_2] = rxplaceholder_2[i0_2, i1_2, i2_2, i3_2] | |
# Reduction axes: rc over 256 input channels; ry/rx are trivial for a 1x1 kernel.
for i0_3, i1_3, i2_3, i3_3, i4, i5, i6 in tir.grid(1, 1024, 14, 14, 256, 1, 1): | |
with tir.block("conv2d_nchw"): | |
nn, ff, yy, xx, rc, ry, rx = tir.axis.remap("SSSSRRR", [i0_3, i1_3, i2_3, i3_3, i4, i5, i6]) | |
tir.reads(conv2d_nchw_1[nn, ff, yy, xx], pad_temp_1[nn, rc, yy + ry, xx + rx], rxplaceholder_3[ff, rc, ry, rx]) | |
tir.writes(conv2d_nchw_1[nn, ff, yy, xx]) | |
tir.block_attr({"layout_free_placeholders":[rxplaceholder_3]}) | |
with tir.init(): | |
conv2d_nchw_1[nn, ff, yy, xx] = tir.float32(0) | |
conv2d_nchw_1[nn, ff, yy, xx] = conv2d_nchw_1[nn, ff, yy, xx] + pad_temp_1[nn, rc, yy + ry, xx + rx] * rxplaceholder_3[ff, rc, ry, rx] | |
# 1x1 conv2d, NCHW layout, stride 1 (input row is yy+ry below):
#   (1, 1024, 14, 14) x weights (256, 1024, 1, 1) -> (1, 256, 14, 14)
@tir.prim_func | |
def conv2d_nchw14(rxplaceholder_2: tir.Buffer[(1, 1024, 14, 14), "float32"], rxplaceholder_3: tir.Buffer[(256, 1024, 1, 1), "float32"], conv2d_nchw_1: tir.Buffer[(1, 256, 14, 14), "float32"]) -> None: | |
# function attr dict | |
tir.func_attr({"global_symbol": "conv2d_nchw14", "tir.noalias": True}) | |
# body | |
# with tir.block("root") | |
# 1x1 kernel needs no halo, so "padding" is a straight copy of the input.
pad_temp_1 = tir.alloc_buffer([1, 1024, 14, 14], dtype="float32") | |
for i0, i1, i2, i3 in tir.grid(1, 1024, 14, 14): | |
with tir.block("pad_temp"): | |
i0_2, i1_2, i2_2, i3_2 = tir.axis.remap("SSSS", [i0, i1, i2, i3]) | |
tir.reads(rxplaceholder_2[i0_2, i1_2, i2_2, i3_2]) | |
tir.writes(pad_temp_1[i0_2, i1_2, i2_2, i3_2]) | |
pad_temp_1[i0_2, i1_2, i2_2, i3_2] = rxplaceholder_2[i0_2, i1_2, i2_2, i3_2] | |
# Reduction axes: rc over 1024 input channels; ry/rx are trivial for a 1x1 kernel.
for i0_3, i1_3, i2_3, i3_3, i4, i5, i6 in tir.grid(1, 256, 14, 14, 1024, 1, 1): | |
with tir.block("conv2d_nchw"): | |
nn, ff, yy, xx, rc, ry, rx = tir.axis.remap("SSSSRRR", [i0_3, i1_3, i2_3, i3_3, i4, i5, i6]) | |
tir.reads(conv2d_nchw_1[nn, ff, yy, xx], pad_temp_1[nn, rc, yy + ry, xx + rx], rxplaceholder_3[ff, rc, ry, rx]) | |
tir.writes(conv2d_nchw_1[nn, ff, yy, xx]) | |
tir.block_attr({"layout_free_placeholders":[rxplaceholder_3]}) | |
with tir.init(): | |
conv2d_nchw_1[nn, ff, yy, xx] = tir.float32(0) | |
conv2d_nchw_1[nn, ff, yy, xx] = conv2d_nchw_1[nn, ff, yy, xx] + pad_temp_1[nn, rc, yy + ry, xx + rx] * rxplaceholder_3[ff, rc, ry, rx] | |
# 3x3 conv2d, NCHW layout, stride 1, zero padding 1 (7x7 -> padded 9x9):
#   (1, 512, 7, 7) x weights (512, 512, 3, 3) -> (1, 512, 7, 7)
@tir.prim_func | |
def conv2d_nchw16(rxplaceholder_2: tir.Buffer[(1, 512, 7, 7), "float32"], rxplaceholder_3: tir.Buffer[(512, 512, 3, 3), "float32"], conv2d_nchw_1: tir.Buffer[(1, 512, 7, 7), "float32"]) -> None: | |
# function attr dict | |
tir.func_attr({"global_symbol": "conv2d_nchw16", "tir.noalias": True}) | |
# body | |
# with tir.block("root") | |
# Padded copy of the input: one-pixel zero border on each spatial side.
pad_temp_1 = tir.alloc_buffer([1, 512, 9, 9], dtype="float32") | |
for i0, i1, i2, i3 in tir.grid(1, 512, 9, 9): | |
with tir.block("pad_temp"): | |
i0_2, i1_2, i2_2, i3_2 = tir.axis.remap("SSSS", [i0, i1, i2, i3]) | |
tir.reads(rxplaceholder_2[i0_2, i1_2, i2_2 - 1, i3_2 - 1]) | |
tir.writes(pad_temp_1[i0_2, i1_2, i2_2, i3_2]) | |
# Interior [1, 8) maps back to the input; the border is filled with zeros.
pad_temp_1[i0_2, i1_2, i2_2, i3_2] = tir.if_then_else(i2_2 >= 1 and i2_2 < 8 and i3_2 >= 1 and i3_2 < 8, rxplaceholder_2[i0_2, i1_2, i2_2 - 1, i3_2 - 1], tir.float32(0), dtype="float32") | |
# Reduction axes: rc over 512 input channels, ry/rx over the 3x3 kernel window.
for i0_3, i1_3, i2_3, i3_3, i4, i5, i6 in tir.grid(1, 512, 7, 7, 512, 3, 3): | |
with tir.block("conv2d_nchw"): | |
nn, ff, yy, xx, rc, ry, rx = tir.axis.remap("SSSSRRR", [i0_3, i1_3, i2_3, i3_3, i4, i5, i6]) | |
tir.reads(conv2d_nchw_1[nn, ff, yy, xx], pad_temp_1[nn, rc, yy + ry, xx + rx], rxplaceholder_3[ff, rc, ry, rx]) | |
tir.writes(conv2d_nchw_1[nn, ff, yy, xx]) | |
tir.block_attr({"layout_free_placeholders":[rxplaceholder_3]}) | |
with tir.init(): | |
conv2d_nchw_1[nn, ff, yy, xx] = tir.float32(0) | |
conv2d_nchw_1[nn, ff, yy, xx] = conv2d_nchw_1[nn, ff, yy, xx] + pad_temp_1[nn, rc, yy + ry, xx + rx] * rxplaceholder_3[ff, rc, ry, rx] | |
# Inference-mode batch norm over (1, 2048, 7, 7), eps ~= 2e-5:
#   T_add_2 = (x - rxplaceholder_8) / sqrt(rxplaceholder_9 + eps) * rxplaceholder_6 + rxplaceholder_7
# Same structure as batch_norm3, only with C=2048 and 7x7 spatial extent.
# T_multiply_3/4 return the moving mean/var vectors unchanged.
@tir.prim_func | |
def batch_norm9(rxplaceholder_5: tir.Buffer[(1, 2048, 7, 7), "float32"], rxplaceholder_6: tir.Buffer[(2048,), "float32"], rxplaceholder_7: tir.Buffer[(2048,), "float32"], rxplaceholder_8: tir.Buffer[(2048,), "float32"], rxplaceholder_9: tir.Buffer[(2048,), "float32"], T_add_2: tir.Buffer[(1, 2048, 7, 7), "float32"], T_multiply_3: tir.Buffer[(2048,), "float32"], T_multiply_4: tir.Buffer[(2048,), "float32"]) -> None: | |
# function attr dict | |
tir.func_attr({"global_symbol": "batch_norm9", "tir.noalias": True}) | |
# body | |
# with tir.block("root") | |
# (1, 2048, 1, 1) buffers broadcast per-channel vectors over the spatial dims.
T_reshape_4 = tir.alloc_buffer([1, 2048, 1, 1], dtype="float32") | |
T_subtract_1 = tir.alloc_buffer([1, 2048, 7, 7], dtype="float32") | |
T_reshape_5 = tir.alloc_buffer([1, 2048, 1, 1], dtype="float32") | |
T_add_3 = tir.alloc_buffer([1, 2048, 1, 1], dtype="float32") | |
compute_1 = tir.alloc_buffer([1, 2048, 1, 1], dtype="float32") | |
T_divide_1 = tir.alloc_buffer([1, 2048, 7, 7], dtype="float32") | |
T_reshape_6 = tir.alloc_buffer([1, 2048, 1, 1], dtype="float32") | |
T_multiply_5 = tir.alloc_buffer([1, 2048, 7, 7], dtype="float32") | |
T_reshape_7 = tir.alloc_buffer([1, 2048, 1, 1], dtype="float32") | |
# Reshape moving mean (2048,) -> (1, 2048, 1, 1).
for i0, i1, i2, i3 in tir.grid(1, 2048, 1, 1): | |
with tir.block("T_reshape"): | |
ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3]) | |
tir.reads(rxplaceholder_8[(ax1 + ax2 + ax3) % 2048]) | |
tir.writes(T_reshape_4[ax0, ax1, ax2, ax3]) | |
T_reshape_4[ax0, ax1, ax2, ax3] = rxplaceholder_8[(ax0 * 2048 + ax1 + ax2 + ax3) % 2048] | |
# x - mean.
for i0, i1, i2, i3 in tir.grid(1, 2048, 7, 7): | |
with tir.block("T_subtract"): | |
ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3]) | |
tir.reads(rxplaceholder_5[ax0, ax1, ax2, ax3], T_reshape_4[ax0, ax1, 0, 0]) | |
tir.writes(T_subtract_1[ax0, ax1, ax2, ax3]) | |
T_subtract_1[ax0, ax1, ax2, ax3] = rxplaceholder_5[ax0, ax1, ax2, ax3] - T_reshape_4[ax0, ax1, 0, 0] | |
# Reshape moving variance (2048,) -> (1, 2048, 1, 1).
for i0, i1, i2, i3 in tir.grid(1, 2048, 1, 1): | |
with tir.block("T_reshape_1"): | |
ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3]) | |
tir.reads(rxplaceholder_9[(ax1 + ax2 + ax3) % 2048]) | |
tir.writes(T_reshape_5[ax0, ax1, ax2, ax3]) | |
T_reshape_5[ax0, ax1, ax2, ax3] = rxplaceholder_9[(ax0 * 2048 + ax1 + ax2 + ax3) % 2048] | |
# var + eps.
for i0, i1, i2, i3 in tir.grid(1, 2048, 1, 1): | |
with tir.block("T_add"): | |
ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3]) | |
tir.reads(T_reshape_5[ax0, ax1, ax2, ax3]) | |
tir.writes(T_add_3[ax0, ax1, ax2, ax3]) | |
T_add_3[ax0, ax1, ax2, ax3] = T_reshape_5[ax0, ax1, ax2, ax3] + tir.float32(1.9999999494757503e-05) | |
# sqrt(var + eps).
for i0, i1, i2, i3 in tir.grid(1, 2048, 1, 1): | |
with tir.block("compute"): | |
i0_2, i1_2, i2_2, i3_2 = tir.axis.remap("SSSS", [i0, i1, i2, i3]) | |
tir.reads(T_add_3[i0_2, i1_2, i2_2, i3_2]) | |
tir.writes(compute_1[i0_2, i1_2, i2_2, i3_2]) | |
compute_1[i0_2, i1_2, i2_2, i3_2] = tir.sqrt(T_add_3[i0_2, i1_2, i2_2, i3_2], dtype="float32") | |
# Normalize: (x - mean) / sqrt(var + eps).
for i0_3, i1_3, i2_3, i3_3 in tir.grid(1, 2048, 7, 7): | |
with tir.block("T_divide"): | |
ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0_3, i1_3, i2_3, i3_3]) | |
tir.reads(T_subtract_1[ax0, ax1, ax2, ax3], compute_1[ax0, ax1, 0, 0]) | |
tir.writes(T_divide_1[ax0, ax1, ax2, ax3]) | |
T_divide_1[ax0, ax1, ax2, ax3] = T_subtract_1[ax0, ax1, ax2, ax3] / compute_1[ax0, ax1, 0, 0] | |
# Reshape gamma (2048,) -> (1, 2048, 1, 1).
for i0_4, i1_4, i2_4, i3_4 in tir.grid(1, 2048, 1, 1): | |
with tir.block("T_reshape_2"): | |
ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0_4, i1_4, i2_4, i3_4]) | |
tir.reads(rxplaceholder_6[(ax1 + ax2 + ax3) % 2048]) | |
tir.writes(T_reshape_6[ax0, ax1, ax2, ax3]) | |
T_reshape_6[ax0, ax1, ax2, ax3] = rxplaceholder_6[(ax0 * 2048 + ax1 + ax2 + ax3) % 2048] | |
# Scale by gamma.
for i0_5, i1_5, i2_5, i3_5 in tir.grid(1, 2048, 7, 7): | |
with tir.block("T_multiply"): | |
ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0_5, i1_5, i2_5, i3_5]) | |
tir.reads(T_divide_1[ax0, ax1, ax2, ax3], T_reshape_6[ax0, ax1, 0, 0]) | |
tir.writes(T_multiply_5[ax0, ax1, ax2, ax3]) | |
T_multiply_5[ax0, ax1, ax2, ax3] = T_divide_1[ax0, ax1, ax2, ax3] * T_reshape_6[ax0, ax1, 0, 0] | |
# Reshape beta (2048,) -> (1, 2048, 1, 1).
for i0_6, i1_6, i2_6, i3_6 in tir.grid(1, 2048, 1, 1): | |
with tir.block("T_reshape_3"): | |
ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0_6, i1_6, i2_6, i3_6]) | |
tir.reads(rxplaceholder_7[(ax1 + ax2 + ax3) % 2048]) | |
tir.writes(T_reshape_7[ax0, ax1, ax2, ax3]) | |
T_reshape_7[ax0, ax1, ax2, ax3] = rxplaceholder_7[(ax0 * 2048 + ax1 + ax2 + ax3) % 2048] | |
# Shift by beta into the primary output T_add_2.
for i0_7, i1_7, i2_7, i3_7 in tir.grid(1, 2048, 7, 7): | |
with tir.block("T_add_1"): | |
ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0_7, i1_7, i2_7, i3_7]) | |
tir.reads(T_multiply_5[ax0, ax1, ax2, ax3], T_reshape_7[ax0, ax1, 0, 0]) | |
tir.writes(T_add_2[ax0, ax1, ax2, ax3]) | |
T_add_2[ax0, ax1, ax2, ax3] = T_multiply_5[ax0, ax1, ax2, ax3] + T_reshape_7[ax0, ax1, 0, 0] | |
# Copy moving mean through unchanged (secondary output).
for i0_8 in tir.serial(2048): | |
with tir.block("T_multiply_1"): | |
ax0 = tir.axis.spatial(2048, i0_8) | |
tir.reads(rxplaceholder_8[ax0]) | |
tir.writes(T_multiply_3[ax0]) | |
T_multiply_3[ax0] = rxplaceholder_8[ax0] | |
# Copy moving variance through unchanged (tertiary output).
for i0_9 in tir.serial(2048): | |
with tir.block("T_multiply_2"): | |
ax0 = tir.axis.spatial(2048, i0_9) | |
tir.reads(rxplaceholder_9[ax0]) | |
tir.writes(T_multiply_4[ax0]) | |
T_multiply_4[ax0] = rxplaceholder_9[ax0] | |
# Inference-mode batch norm over (1, 64, 112, 112), eps ~= 2e-5:
#   T_add_2 = (x - rxplaceholder_8) / sqrt(rxplaceholder_9 + eps) * rxplaceholder_6 + rxplaceholder_7
# Same structure as batch_norm3, only with C=64 and 112x112 spatial extent.
# T_multiply_3/4 return the moving mean/var vectors unchanged.
@tir.prim_func | |
def batch_norm1(rxplaceholder_5: tir.Buffer[(1, 64, 112, 112), "float32"], rxplaceholder_6: tir.Buffer[(64,), "float32"], rxplaceholder_7: tir.Buffer[(64,), "float32"], rxplaceholder_8: tir.Buffer[(64,), "float32"], rxplaceholder_9: tir.Buffer[(64,), "float32"], T_add_2: tir.Buffer[(1, 64, 112, 112), "float32"], T_multiply_3: tir.Buffer[(64,), "float32"], T_multiply_4: tir.Buffer[(64,), "float32"]) -> None: | |
# function attr dict | |
tir.func_attr({"global_symbol": "batch_norm1", "tir.noalias": True}) | |
# body | |
# with tir.block("root") | |
# (1, 64, 1, 1) buffers broadcast per-channel vectors over the spatial dims.
T_reshape_4 = tir.alloc_buffer([1, 64, 1, 1], dtype="float32") | |
T_subtract_1 = tir.alloc_buffer([1, 64, 112, 112], dtype="float32") | |
T_reshape_5 = tir.alloc_buffer([1, 64, 1, 1], dtype="float32") | |
T_add_3 = tir.alloc_buffer([1, 64, 1, 1], dtype="float32") | |
compute_1 = tir.alloc_buffer([1, 64, 1, 1], dtype="float32") | |
T_divide_1 = tir.alloc_buffer([1, 64, 112, 112], dtype="float32") | |
T_reshape_6 = tir.alloc_buffer([1, 64, 1, 1], dtype="float32") | |
T_multiply_5 = tir.alloc_buffer([1, 64, 112, 112], dtype="float32") | |
T_reshape_7 = tir.alloc_buffer([1, 64, 1, 1], dtype="float32") | |
# Reshape moving mean (64,) -> (1, 64, 1, 1).
for i0, i1, i2, i3 in tir.grid(1, 64, 1, 1): | |
with tir.block("T_reshape"): | |
ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3]) | |
tir.reads(rxplaceholder_8[(ax1 + ax2 + ax3) % 64]) | |
tir.writes(T_reshape_4[ax0, ax1, ax2, ax3]) | |
T_reshape_4[ax0, ax1, ax2, ax3] = rxplaceholder_8[(ax0 * 64 + ax1 + ax2 + ax3) % 64] | |
# x - mean.
for i0, i1, i2, i3 in tir.grid(1, 64, 112, 112): | |
with tir.block("T_subtract"): | |
ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3]) | |
tir.reads(rxplaceholder_5[ax0, ax1, ax2, ax3], T_reshape_4[ax0, ax1, 0, 0]) | |
tir.writes(T_subtract_1[ax0, ax1, ax2, ax3]) | |
T_subtract_1[ax0, ax1, ax2, ax3] = rxplaceholder_5[ax0, ax1, ax2, ax3] - T_reshape_4[ax0, ax1, 0, 0] | |
# Reshape moving variance (64,) -> (1, 64, 1, 1).
for i0, i1, i2, i3 in tir.grid(1, 64, 1, 1): | |
with tir.block("T_reshape_1"): | |
ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3]) | |
tir.reads(rxplaceholder_9[(ax1 + ax2 + ax3) % 64]) | |
tir.writes(T_reshape_5[ax0, ax1, ax2, ax3]) | |
T_reshape_5[ax0, ax1, ax2, ax3] = rxplaceholder_9[(ax0 * 64 + ax1 + ax2 + ax3) % 64] | |
# var + eps.
for i0, i1, i2, i3 in tir.grid(1, 64, 1, 1): | |
with tir.block("T_add"): | |
ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3]) | |
tir.reads(T_reshape_5[ax0, ax1, ax2, ax3]) | |
tir.writes(T_add_3[ax0, ax1, ax2, ax3]) | |
T_add_3[ax0, ax1, ax2, ax3] = T_reshape_5[ax0, ax1, ax2, ax3] + tir.float32(1.9999999494757503e-05) | |
# sqrt(var + eps).
for i0, i1, i2, i3 in tir.grid(1, 64, 1, 1): | |
with tir.block("compute"): | |
i0_2, i1_2, i2_2, i3_2 = tir.axis.remap("SSSS", [i0, i1, i2, i3]) | |
tir.reads(T_add_3[i0_2, i1_2, i2_2, i3_2]) | |
tir.writes(compute_1[i0_2, i1_2, i2_2, i3_2]) | |
compute_1[i0_2, i1_2, i2_2, i3_2] = tir.sqrt(T_add_3[i0_2, i1_2, i2_2, i3_2], dtype="float32") | |
# Normalize: (x - mean) / sqrt(var + eps).
for i0_3, i1_3, i2_3, i3_3 in tir.grid(1, 64, 112, 112): | |
with tir.block("T_divide"): | |
ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0_3, i1_3, i2_3, i3_3]) | |
tir.reads(T_subtract_1[ax0, ax1, ax2, ax3], compute_1[ax0, ax1, 0, 0]) | |
tir.writes(T_divide_1[ax0, ax1, ax2, ax3]) | |
T_divide_1[ax0, ax1, ax2, ax3] = T_subtract_1[ax0, ax1, ax2, ax3] / compute_1[ax0, ax1, 0, 0] | |
# Reshape gamma (64,) -> (1, 64, 1, 1).
for i0_4, i1_4, i2_4, i3_4 in tir.grid(1, 64, 1, 1): | |
with tir.block("T_reshape_2"): | |
ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0_4, i1_4, i2_4, i3_4]) | |
tir.reads(rxplaceholder_6[(ax1 + ax2 + ax3) % 64]) | |
tir.writes(T_reshape_6[ax0, ax1, ax2, ax3]) | |
T_reshape_6[ax0, ax1, ax2, ax3] = rxplaceholder_6[(ax0 * 64 + ax1 + ax2 + ax3) % 64] | |
# Scale by gamma.
for i0_5, i1_5, i2_5, i3_5 in tir.grid(1, 64, 112, 112): | |
with tir.block("T_multiply"): | |
ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0_5, i1_5, i2_5, i3_5]) | |
tir.reads(T_divide_1[ax0, ax1, ax2, ax3], T_reshape_6[ax0, ax1, 0, 0]) | |
tir.writes(T_multiply_5[ax0, ax1, ax2, ax3]) | |
T_multiply_5[ax0, ax1, ax2, ax3] = T_divide_1[ax0, ax1, ax2, ax3] * T_reshape_6[ax0, ax1, 0, 0] | |
# Reshape beta (64,) -> (1, 64, 1, 1).
for i0_6, i1_6, i2_6, i3_6 in tir.grid(1, 64, 1, 1): | |
with tir.block("T_reshape_3"): | |
ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0_6, i1_6, i2_6, i3_6]) | |
tir.reads(rxplaceholder_7[(ax1 + ax2 + ax3) % 64]) | |
tir.writes(T_reshape_7[ax0, ax1, ax2, ax3]) | |
T_reshape_7[ax0, ax1, ax2, ax3] = rxplaceholder_7[(ax0 * 64 + ax1 + ax2 + ax3) % 64] | |
# Shift by beta into the primary output T_add_2.
for i0_7, i1_7, i2_7, i3_7 in tir.grid(1, 64, 112, 112): | |
with tir.block("T_add_1"): | |
ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0_7, i1_7, i2_7, i3_7]) | |
tir.reads(T_multiply_5[ax0, ax1, ax2, ax3], T_reshape_7[ax0, ax1, 0, 0]) | |
tir.writes(T_add_2[ax0, ax1, ax2, ax3]) | |
T_add_2[ax0, ax1, ax2, ax3] = T_multiply_5[ax0, ax1, ax2, ax3] + T_reshape_7[ax0, ax1, 0, 0] | |
# Copy moving mean through unchanged (secondary output).
for i0_8 in tir.serial(64): | |
with tir.block("T_multiply_1"): | |
ax0 = tir.axis.spatial(64, i0_8) | |
tir.reads(rxplaceholder_8[ax0]) | |
tir.writes(T_multiply_3[ax0]) | |
T_multiply_3[ax0] = rxplaceholder_8[ax0] | |
# Copy moving variance through unchanged (tertiary output).
for i0_9 in tir.serial(64): | |
with tir.block("T_multiply_2"): | |
ax0 = tir.axis.spatial(64, i0_9) | |
tir.reads(rxplaceholder_9[ax0]) | |
tir.writes(T_multiply_4[ax0]) | |
T_multiply_4[ax0] = rxplaceholder_9[ax0] | |
@relax.function | |
def main(data: Tensor[(1, 3, 224, 224), "float32"], bn_data_gamma: Tensor[(3,), "float32"], bn_data_beta: Tensor[(3,), "float32"], bn_data_moving_mean: Tensor[(3,), "float32"], bn_data_moving_var: Tensor[(3,), "float32"], conv0_weight: Tensor[(64, 3, 7, 7), "float32"], bn0_gamma: Tensor[(64,), "float32"], bn0_beta: Tensor[(64,), "float32"], bn0_moving_mean: Tensor[(64,), "float32"], bn0_moving_var: Tensor[(64,), "float32"], stage1_unit1_bn1_gamma: Tensor[(64,), "float32"], stage1_unit1_bn1_beta: Tensor[(64,), "float32"], stage1_unit1_bn1_moving_mean: Tensor[(64,), "float32"], stage1_unit1_bn1_moving_var: Tensor[(64,), "float32"], stage1_unit1_conv1_weight: Tensor[(64, 64, 1, 1), "float32"], stage1_unit1_bn2_gamma: Tensor[(64,), "float32"], stage1_unit1_bn2_beta: Tensor[(64,), "float32"], stage1_unit1_bn2_moving_mean: Tensor[(64,), "float32"], stage1_unit1_bn2_moving_var: Tensor[(64,), "float32"], stage1_unit1_conv2_weight: Tensor[(64, 64, 3, 3), "float32"], stage1_unit1_bn3_gamma: Tensor[(64,), "float32"], stage1_unit1_bn3_beta: Tensor[(64,), "float32"], stage1_unit1_bn3_moving_mean: Tensor[(64,), "float32"], stage1_unit1_bn3_moving_var: Tensor[(64,), "float32"], stage1_unit1_conv3_weight: Tensor[(256, 64, 1, 1), "float32"], stage1_unit1_sc_weight: Tensor[(256, 64, 1, 1), "float32"], stage1_unit2_bn1_gamma: Tensor[(256,), "float32"], stage1_unit2_bn1_beta: Tensor[(256,), "float32"], stage1_unit2_bn1_moving_mean: Tensor[(256,), "float32"], stage1_unit2_bn1_moving_var: Tensor[(256,), "float32"], stage1_unit2_conv1_weight: Tensor[(64, 256, 1, 1), "float32"], stage1_unit2_bn2_gamma: Tensor[(64,), "float32"], stage1_unit2_bn2_beta: Tensor[(64,), "float32"], stage1_unit2_bn2_moving_mean: Tensor[(64,), "float32"], stage1_unit2_bn2_moving_var: Tensor[(64,), "float32"], stage1_unit2_conv2_weight: Tensor[(64, 64, 3, 3), "float32"], stage1_unit2_bn3_gamma: Tensor[(64,), "float32"], stage1_unit2_bn3_beta: Tensor[(64,), "float32"], stage1_unit2_bn3_moving_mean: Tensor[(64,), 
"float32"], stage1_unit2_bn3_moving_var: Tensor[(64,), "float32"], stage1_unit2_conv3_weight: Tensor[(256, 64, 1, 1), "float32"], stage1_unit3_bn1_gamma: Tensor[(256,), "float32"], stage1_unit3_bn1_beta: Tensor[(256,), "float32"], stage1_unit3_bn1_moving_mean: Tensor[(256,), "float32"], stage1_unit3_bn1_moving_var: Tensor[(256,), "float32"], stage1_unit3_conv1_weight: Tensor[(64, 256, 1, 1), "float32"], stage1_unit3_bn2_gamma: Tensor[(64,), "float32"], stage1_unit3_bn2_beta: Tensor[(64,), "float32"], stage1_unit3_bn2_moving_mean: Tensor[(64,), "float32"], stage1_unit3_bn2_moving_var: Tensor[(64,), "float32"], stage1_unit3_conv2_weight: Tensor[(64, 64, 3, 3), "float32"], stage1_unit3_bn3_gamma: Tensor[(64,), "float32"], stage1_unit3_bn3_beta: Tensor[(64,), "float32"], stage1_unit3_bn3_moving_mean: Tensor[(64,), "float32"], stage1_unit3_bn3_moving_var: Tensor[(64,), "float32"], stage1_unit3_conv3_weight: Tensor[(256, 64, 1, 1), "float32"], stage2_unit1_bn1_gamma: Tensor[(256,), "float32"], stage2_unit1_bn1_beta: Tensor[(256,), "float32"], stage2_unit1_bn1_moving_mean: Tensor[(256,), "float32"], stage2_unit1_bn1_moving_var: Tensor[(256,), "float32"], stage2_unit1_conv1_weight: Tensor[(128, 256, 1, 1), "float32"], stage2_unit1_bn2_gamma: Tensor[(128,), "float32"], stage2_unit1_bn2_beta: Tensor[(128,), "float32"], stage2_unit1_bn2_moving_mean: Tensor[(128,), "float32"], stage2_unit1_bn2_moving_var: Tensor[(128,), "float32"], stage2_unit1_conv2_weight: Tensor[(128, 128, 3, 3), "float32"], stage2_unit1_bn3_gamma: Tensor[(128,), "float32"], stage2_unit1_bn3_beta: Tensor[(128,), "float32"], stage2_unit1_bn3_moving_mean: Tensor[(128,), "float32"], stage2_unit1_bn3_moving_var: Tensor[(128,), "float32"], stage2_unit1_conv3_weight: Tensor[(512, 128, 1, 1), "float32"], stage2_unit1_sc_weight: Tensor[(512, 256, 1, 1), "float32"], stage2_unit2_bn1_gamma: Tensor[(512,), "float32"], stage2_unit2_bn1_beta: Tensor[(512,), "float32"], stage2_unit2_bn1_moving_mean: Tensor[(512,), 
"float32"], stage2_unit2_bn1_moving_var: Tensor[(512,), "float32"], stage2_unit2_conv1_weight: Tensor[(128, 512, 1, 1), "float32"], stage2_unit2_bn2_gamma: Tensor[(128,), "float32"], stage2_unit2_bn2_beta: Tensor[(128,), "float32"], stage2_unit2_bn2_moving_mean: Tensor[(128,), "float32"], stage2_unit2_bn2_moving_var: Tensor[(128,), "float32"], stage2_unit2_conv2_weight: Tensor[(128, 128, 3, 3), "float32"], stage2_unit2_bn3_gamma: Tensor[(128,), "float32"], stage2_unit2_bn3_beta: Tensor[(128,), "float32"], stage2_unit2_bn3_moving_mean: Tensor[(128,), "float32"], stage2_unit2_bn3_moving_var: Tensor[(128,), "float32"], stage2_unit2_conv3_weight: Tensor[(512, 128, 1, 1), "float32"], stage2_unit3_bn1_gamma: Tensor[(512,), "float32"], stage2_unit3_bn1_beta: Tensor[(512,), "float32"], stage2_unit3_bn1_moving_mean: Tensor[(512,), "float32"], stage2_unit3_bn1_moving_var: Tensor[(512,), "float32"], stage2_unit3_conv1_weight: Tensor[(128, 512, 1, 1), "float32"], stage2_unit3_bn2_gamma: Tensor[(128,), "float32"], stage2_unit3_bn2_beta: Tensor[(128,), "float32"], stage2_unit3_bn2_moving_mean: Tensor[(128,), "float32"], stage2_unit3_bn2_moving_var: Tensor[(128,), "float32"], stage2_unit3_conv2_weight: Tensor[(128, 128, 3, 3), "float32"], stage2_unit3_bn3_gamma: Tensor[(128,), "float32"], stage2_unit3_bn3_beta: Tensor[(128,), "float32"], stage2_unit3_bn3_moving_mean: Tensor[(128,), "float32"], stage2_unit3_bn3_moving_var: Tensor[(128,), "float32"], stage2_unit3_conv3_weight: Tensor[(512, 128, 1, 1), "float32"], stage2_unit4_bn1_gamma: Tensor[(512,), "float32"], stage2_unit4_bn1_beta: Tensor[(512,), "float32"], stage2_unit4_bn1_moving_mean: Tensor[(512,), "float32"], stage2_unit4_bn1_moving_var: Tensor[(512,), "float32"], stage2_unit4_conv1_weight: Tensor[(128, 512, 1, 1), "float32"], stage2_unit4_bn2_gamma: Tensor[(128,), "float32"], stage2_unit4_bn2_beta: Tensor[(128,), "float32"], stage2_unit4_bn2_moving_mean: Tensor[(128,), "float32"], stage2_unit4_bn2_moving_var: 
Tensor[(128,), "float32"], stage2_unit4_conv2_weight: Tensor[(128, 128, 3, 3), "float32"], stage2_unit4_bn3_gamma: Tensor[(128,), "float32"], stage2_unit4_bn3_beta: Tensor[(128,), "float32"], stage2_unit4_bn3_moving_mean: Tensor[(128,), "float32"], stage2_unit4_bn3_moving_var: Tensor[(128,), "float32"], stage2_unit4_conv3_weight: Tensor[(512, 128, 1, 1), "float32"], stage3_unit1_bn1_gamma: Tensor[(512,), "float32"], stage3_unit1_bn1_beta: Tensor[(512,), "float32"], stage3_unit1_bn1_moving_mean: Tensor[(512,), "float32"], stage3_unit1_bn1_moving_var: Tensor[(512,), "float32"], stage3_unit1_conv1_weight: Tensor[(256, 512, 1, 1), "float32"], stage3_unit1_bn2_gamma: Tensor[(256,), "float32"], stage3_unit1_bn2_beta: Tensor[(256,), "float32"], stage3_unit1_bn2_moving_mean: Tensor[(256,), "float32"], stage3_unit1_bn2_moving_var: Tensor[(256,), "float32"], stage3_unit1_conv2_weight: Tensor[(256, 256, 3, 3), "float32"], stage3_unit1_bn3_gamma: Tensor[(256,), "float32"], stage3_unit1_bn3_beta: Tensor[(256,), "float32"], stage3_unit1_bn3_moving_mean: Tensor[(256,), "float32"], stage3_unit1_bn3_moving_var: Tensor[(256,), "float32"], stage3_unit1_conv3_weight: Tensor[(1024, 256, 1, 1), "float32"], stage3_unit1_sc_weight: Tensor[(1024, 512, 1, 1), "float32"], stage3_unit2_bn1_gamma: Tensor[(1024,), "float32"], stage3_unit2_bn1_beta: Tensor[(1024,), "float32"], stage3_unit2_bn1_moving_mean: Tensor[(1024,), "float32"], stage3_unit2_bn1_moving_var: Tensor[(1024,), "float32"], stage3_unit2_conv1_weight: Tensor[(256, 1024, 1, 1), "float32"], stage3_unit2_bn2_gamma: Tensor[(256,), "float32"], stage3_unit2_bn2_beta: Tensor[(256,), "float32"], stage3_unit2_bn2_moving_mean: Tensor[(256,), "float32"], stage3_unit2_bn2_moving_var: Tensor[(256,), "float32"], stage3_unit2_conv2_weight: Tensor[(256, 256, 3, 3), "float32"], stage3_unit2_bn3_gamma: Tensor[(256,), "float32"], stage3_unit2_bn3_beta: Tensor[(256,), "float32"], stage3_unit2_bn3_moving_mean: Tensor[(256,), "float32"], 
stage3_unit2_bn3_moving_var: Tensor[(256,), "float32"], stage3_unit2_conv3_weight: Tensor[(1024, 256, 1, 1), "float32"], stage3_unit3_bn1_gamma: Tensor[(1024,), "float32"], stage3_unit3_bn1_beta: Tensor[(1024,), "float32"], stage3_unit3_bn1_moving_mean: Tensor[(1024,), "float32"], stage3_unit3_bn1_moving_var: Tensor[(1024,), "float32"], stage3_unit3_conv1_weight: Tensor[(256, 1024, 1, 1), "float32"], stage3_unit3_bn2_gamma: Tensor[(256,), "float32"], stage3_unit3_bn2_beta: Tensor[(256,), "float32"], stage3_unit3_bn2_moving_mean: Tensor[(256,), "float32"], stage3_unit3_bn2_moving_var: Tensor[(256,), "float32"], stage3_unit3_conv2_weight: Tensor[(256, 256, 3, 3), "float32"], stage3_unit3_bn3_gamma: Tensor[(256,), "float32"], stage3_unit3_bn3_beta: Tensor[(256,), "float32"], stage3_unit3_bn3_moving_mean: Tensor[(256,), "float32"], stage3_unit3_bn3_moving_var: Tensor[(256,), "float32"], stage3_unit3_conv3_weight: Tensor[(1024, 256, 1, 1), "float32"], stage3_unit4_bn1_gamma: Tensor[(1024,), "float32"], stage3_unit4_bn1_beta: Tensor[(1024,), "float32"], stage3_unit4_bn1_moving_mean: Tensor[(1024,), "float32"], stage3_unit4_bn1_moving_var: Tensor[(1024,), "float32"], stage3_unit4_conv1_weight: Tensor[(256, 1024, 1, 1), "float32"], stage3_unit4_bn2_gamma: Tensor[(256,), "float32"], stage3_unit4_bn2_beta: Tensor[(256,), "float32"], stage3_unit4_bn2_moving_mean: Tensor[(256,), "float32"], stage3_unit4_bn2_moving_var: Tensor[(256,), "float32"], stage3_unit4_conv2_weight: Tensor[(256, 256, 3, 3), "float32"], stage3_unit4_bn3_gamma: Tensor[(256,), "float32"], stage3_unit4_bn3_beta: Tensor[(256,), "float32"], stage3_unit4_bn3_moving_mean: Tensor[(256,), "float32"], stage3_unit4_bn3_moving_var: Tensor[(256,), "float32"], stage3_unit4_conv3_weight: Tensor[(1024, 256, 1, 1), "float32"], stage3_unit5_bn1_gamma: Tensor[(1024,), "float32"], stage3_unit5_bn1_beta: Tensor[(1024,), "float32"], stage3_unit5_bn1_moving_mean: Tensor[(1024,), "float32"], stage3_unit5_bn1_moving_var: 
Tensor[(1024,), "float32"], stage3_unit5_conv1_weight: Tensor[(256, 1024, 1, 1), "float32"], stage3_unit5_bn2_gamma: Tensor[(256,), "float32"], stage3_unit5_bn2_beta: Tensor[(256,), "float32"], stage3_unit5_bn2_moving_mean: Tensor[(256,), "float32"], stage3_unit5_bn2_moving_var: Tensor[(256,), "float32"], stage3_unit5_conv2_weight: Tensor[(256, 256, 3, 3), "float32"], stage3_unit5_bn3_gamma: Tensor[(256,), "float32"], stage3_unit5_bn3_beta: Tensor[(256,), "float32"], stage3_unit5_bn3_moving_mean: Tensor[(256,), "float32"], stage3_unit5_bn3_moving_var: Tensor[(256,), "float32"], stage3_unit5_conv3_weight: Tensor[(1024, 256, 1, 1), "float32"], stage3_unit6_bn1_gamma: Tensor[(1024,), "float32"], stage3_unit6_bn1_beta: Tensor[(1024,), "float32"], stage3_unit6_bn1_moving_mean: Tensor[(1024,), "float32"], stage3_unit6_bn1_moving_var: Tensor[(1024,), "float32"], stage3_unit6_conv1_weight: Tensor[(256, 1024, 1, 1), "float32"], stage3_unit6_bn2_gamma: Tensor[(256,), "float32"], stage3_unit6_bn2_beta: Tensor[(256,), "float32"], stage3_unit6_bn2_moving_mean: Tensor[(256,), "float32"], stage3_unit6_bn2_moving_var: Tensor[(256,), "float32"], stage3_unit6_conv2_weight: Tensor[(256, 256, 3, 3), "float32"], stage3_unit6_bn3_gamma: Tensor[(256,), "float32"], stage3_unit6_bn3_beta: Tensor[(256,), "float32"], stage3_unit6_bn3_moving_mean: Tensor[(256,), "float32"], stage3_unit6_bn3_moving_var: Tensor[(256,), "float32"], stage3_unit6_conv3_weight: Tensor[(1024, 256, 1, 1), "float32"], stage4_unit1_bn1_gamma: Tensor[(1024,), "float32"], stage4_unit1_bn1_beta: Tensor[(1024,), "float32"], stage4_unit1_bn1_moving_mean: Tensor[(1024,), "float32"], stage4_unit1_bn1_moving_var: Tensor[(1024,), "float32"], stage4_unit1_conv1_weight: Tensor[(512, 1024, 1, 1), "float32"], stage4_unit1_bn2_gamma: Tensor[(512,), "float32"], stage4_unit1_bn2_beta: Tensor[(512,), "float32"], stage4_unit1_bn2_moving_mean: Tensor[(512,), "float32"], stage4_unit1_bn2_moving_var: Tensor[(512,), "float32"], 
stage4_unit1_conv2_weight: Tensor[(512, 512, 3, 3), "float32"], stage4_unit1_bn3_gamma: Tensor[(512,), "float32"], stage4_unit1_bn3_beta: Tensor[(512,), "float32"], stage4_unit1_bn3_moving_mean: Tensor[(512,), "float32"], stage4_unit1_bn3_moving_var: Tensor[(512,), "float32"], stage4_unit1_conv3_weight: Tensor[(2048, 512, 1, 1), "float32"], stage4_unit1_sc_weight: Tensor[(2048, 1024, 1, 1), "float32"], stage4_unit2_bn1_gamma: Tensor[(2048,), "float32"], stage4_unit2_bn1_beta: Tensor[(2048,), "float32"], stage4_unit2_bn1_moving_mean: Tensor[(2048,), "float32"], stage4_unit2_bn1_moving_var: Tensor[(2048,), "float32"], stage4_unit2_conv1_weight: Tensor[(512, 2048, 1, 1), "float32"], stage4_unit2_bn2_gamma: Tensor[(512,), "float32"], stage4_unit2_bn2_beta: Tensor[(512,), "float32"], stage4_unit2_bn2_moving_mean: Tensor[(512,), "float32"], stage4_unit2_bn2_moving_var: Tensor[(512,), "float32"], stage4_unit2_conv2_weight: Tensor[(512, 512, 3, 3), "float32"], stage4_unit2_bn3_gamma: Tensor[(512,), "float32"], stage4_unit2_bn3_beta: Tensor[(512,), "float32"], stage4_unit2_bn3_moving_mean: Tensor[(512,), "float32"], stage4_unit2_bn3_moving_var: Tensor[(512,), "float32"], stage4_unit2_conv3_weight: Tensor[(2048, 512, 1, 1), "float32"], stage4_unit3_bn1_gamma: Tensor[(2048,), "float32"], stage4_unit3_bn1_beta: Tensor[(2048,), "float32"], stage4_unit3_bn1_moving_mean: Tensor[(2048,), "float32"], stage4_unit3_bn1_moving_var: Tensor[(2048,), "float32"], stage4_unit3_conv1_weight: Tensor[(512, 2048, 1, 1), "float32"], stage4_unit3_bn2_gamma: Tensor[(512,), "float32"], stage4_unit3_bn2_beta: Tensor[(512,), "float32"], stage4_unit3_bn2_moving_mean: Tensor[(512,), "float32"], stage4_unit3_bn2_moving_var: Tensor[(512,), "float32"], stage4_unit3_conv2_weight: Tensor[(512, 512, 3, 3), "float32"], stage4_unit3_bn3_gamma: Tensor[(512,), "float32"], stage4_unit3_bn3_beta: Tensor[(512,), "float32"], stage4_unit3_bn3_moving_mean: Tensor[(512,), "float32"], stage4_unit3_bn3_moving_var: 
Tensor[(512,), "float32"], stage4_unit3_conv3_weight: Tensor[(2048, 512, 1, 1), "float32"], bn1_gamma: Tensor[(2048,), "float32"], bn1_beta: Tensor[(2048,), "float32"], bn1_moving_mean: Tensor[(2048,), "float32"], bn1_moving_var: Tensor[(2048,), "float32"], fc1_weight: Tensor[(1000, 2048), "float32"], fc1_bias: Tensor[(1000,), "float32"]) -> Tensor[_, "float32"]: | |
# block 0 | |
with relax.dataflow(): | |
gv: Tuple[Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"]] = relax.call_tir(((1, 3, 224, 224), (3,), (3,)), batch_norm, (data, bn_data_gamma, bn_data_beta, bn_data_moving_mean, bn_data_moving_var)) | |
gv1: Tensor[(1, 3, 224, 224), "float32"] = gv[0] | |
gv2: Tensor[(1, 64, 112, 112), "float32"] = relax.call_tir((1, 64, 112, 112), conv2d_nchw, (gv1, conv0_weight)) | |
gv3: Tuple[Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"]] = relax.call_tir(((1, 64, 112, 112), (64,), (64,)), batch_norm1, (gv2, bn0_gamma, bn0_beta, bn0_moving_mean, bn0_moving_var)) | |
gv4: Tensor[(1, 64, 112, 112), "float32"] = gv3[0] | |
gv5: Tensor[(1, 64, 112, 112), "float32"] = relax.call_tir((1, 64, 112, 112), relu, (gv4,)) | |
gv6: Tensor[(1, 64, 56, 56), "float32"] = relax.call_tir((1, 64, 56, 56), max_pool2d, (gv5,)) | |
gv7: Tuple[Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"]] = relax.call_tir(((1, 64, 56, 56), (64,), (64,)), batch_norm2, (gv6, stage1_unit1_bn1_gamma, stage1_unit1_bn1_beta, stage1_unit1_bn1_moving_mean, stage1_unit1_bn1_moving_var)) | |
gv8: Tensor[(1, 64, 56, 56), "float32"] = gv7[0] | |
gv9: Tensor[(1, 64, 56, 56), "float32"] = relax.call_tir((1, 64, 56, 56), relu1, (gv8,)) | |
gv10: Tensor[(1, 64, 56, 56), "float32"] = relax.call_tir((1, 64, 56, 56), conv2d_nchw1, (gv9, stage1_unit1_conv1_weight)) | |
gv11: Tuple[Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"]] = relax.call_tir(((1, 64, 56, 56), (64,), (64,)), batch_norm2, (gv10, stage1_unit1_bn2_gamma, stage1_unit1_bn2_beta, stage1_unit1_bn2_moving_mean, stage1_unit1_bn2_moving_var)) | |
gv12: Tensor[(1, 64, 56, 56), "float32"] = gv11[0] | |
gv13: Tensor[(1, 64, 56, 56), "float32"] = relax.call_tir((1, 64, 56, 56), relu1, (gv12,)) | |
gv14: Tensor[(1, 64, 56, 56), "float32"] = relax.call_tir((1, 64, 56, 56), conv2d_nchw2, (gv13, stage1_unit1_conv2_weight)) | |
gv15: Tuple[Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"]] = relax.call_tir(((1, 64, 56, 56), (64,), (64,)), batch_norm2, (gv14, stage1_unit1_bn3_gamma, stage1_unit1_bn3_beta, stage1_unit1_bn3_moving_mean, stage1_unit1_bn3_moving_var)) | |
gv16: Tensor[(1, 64, 56, 56), "float32"] = gv15[0] | |
gv17: Tensor[(1, 64, 56, 56), "float32"] = relax.call_tir((1, 64, 56, 56), relu1, (gv16,)) | |
gv18: Tensor[(1, 256, 56, 56), "float32"] = relax.call_tir((1, 256, 56, 56), conv2d_nchw3, (gv17, stage1_unit1_conv3_weight)) | |
gv19: Tensor[(1, 256, 56, 56), "float32"] = relax.call_tir((1, 256, 56, 56), conv2d_nchw3, (gv9, stage1_unit1_sc_weight)) | |
gv20: Tensor[(1, 256, 56, 56), "float32"] = relax.call_tir((1, 256, 56, 56), add, (gv18, gv19)) | |
gv21: Tuple[Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"]] = relax.call_tir(((1, 256, 56, 56), (256,), (256,)), batch_norm3, (gv20, stage1_unit2_bn1_gamma, stage1_unit2_bn1_beta, stage1_unit2_bn1_moving_mean, stage1_unit2_bn1_moving_var)) | |
gv22: Tensor[(1, 256, 56, 56), "float32"] = gv21[0] | |
gv23: Tensor[(1, 256, 56, 56), "float32"] = relax.call_tir((1, 256, 56, 56), relu2, (gv22,)) | |
gv24: Tensor[(1, 64, 56, 56), "float32"] = relax.call_tir((1, 64, 56, 56), conv2d_nchw4, (gv23, stage1_unit2_conv1_weight)) | |
gv25: Tuple[Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"]] = relax.call_tir(((1, 64, 56, 56), (64,), (64,)), batch_norm2, (gv24, stage1_unit2_bn2_gamma, stage1_unit2_bn2_beta, stage1_unit2_bn2_moving_mean, stage1_unit2_bn2_moving_var)) | |
gv26: Tensor[(1, 64, 56, 56), "float32"] = gv25[0] | |
gv27: Tensor[(1, 64, 56, 56), "float32"] = relax.call_tir((1, 64, 56, 56), relu1, (gv26,)) | |
gv28: Tensor[(1, 64, 56, 56), "float32"] = relax.call_tir((1, 64, 56, 56), conv2d_nchw2, (gv27, stage1_unit2_conv2_weight)) | |
gv29: Tuple[Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"]] = relax.call_tir(((1, 64, 56, 56), (64,), (64,)), batch_norm2, (gv28, stage1_unit2_bn3_gamma, stage1_unit2_bn3_beta, stage1_unit2_bn3_moving_mean, stage1_unit2_bn3_moving_var)) | |
gv30: Tensor[(1, 64, 56, 56), "float32"] = gv29[0] | |
gv31: Tensor[(1, 64, 56, 56), "float32"] = relax.call_tir((1, 64, 56, 56), relu1, (gv30,)) | |
gv32: Tensor[(1, 256, 56, 56), "float32"] = relax.call_tir((1, 256, 56, 56), conv2d_nchw3, (gv31, stage1_unit2_conv3_weight)) | |
gv33: Tensor[(1, 256, 56, 56), "float32"] = relax.call_tir((1, 256, 56, 56), add, (gv32, gv20)) | |
gv34: Tuple[Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"]] = relax.call_tir(((1, 256, 56, 56), (256,), (256,)), batch_norm3, (gv33, stage1_unit3_bn1_gamma, stage1_unit3_bn1_beta, stage1_unit3_bn1_moving_mean, stage1_unit3_bn1_moving_var)) | |
gv35: Tensor[(1, 256, 56, 56), "float32"] = gv34[0] | |
gv36: Tensor[(1, 256, 56, 56), "float32"] = relax.call_tir((1, 256, 56, 56), relu2, (gv35,)) | |
gv37: Tensor[(1, 64, 56, 56), "float32"] = relax.call_tir((1, 64, 56, 56), conv2d_nchw4, (gv36, stage1_unit3_conv1_weight)) | |
gv38: Tuple[Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"]] = relax.call_tir(((1, 64, 56, 56), (64,), (64,)), batch_norm2, (gv37, stage1_unit3_bn2_gamma, stage1_unit3_bn2_beta, stage1_unit3_bn2_moving_mean, stage1_unit3_bn2_moving_var)) | |
gv39: Tensor[(1, 64, 56, 56), "float32"] = gv38[0] | |
gv40: Tensor[(1, 64, 56, 56), "float32"] = relax.call_tir((1, 64, 56, 56), relu1, (gv39,)) | |
gv41: Tensor[(1, 64, 56, 56), "float32"] = relax.call_tir((1, 64, 56, 56), conv2d_nchw2, (gv40, stage1_unit3_conv2_weight)) | |
gv42: Tuple[Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"]] = relax.call_tir(((1, 64, 56, 56), (64,), (64,)), batch_norm2, (gv41, stage1_unit3_bn3_gamma, stage1_unit3_bn3_beta, stage1_unit3_bn3_moving_mean, stage1_unit3_bn3_moving_var)) | |
gv43: Tensor[(1, 64, 56, 56), "float32"] = gv42[0] | |
gv44: Tensor[(1, 64, 56, 56), "float32"] = relax.call_tir((1, 64, 56, 56), relu1, (gv43,)) | |
gv45: Tensor[(1, 256, 56, 56), "float32"] = relax.call_tir((1, 256, 56, 56), conv2d_nchw3, (gv44, stage1_unit3_conv3_weight)) | |
gv46: Tensor[(1, 256, 56, 56), "float32"] = relax.call_tir((1, 256, 56, 56), add, (gv45, gv33)) | |
gv47: Tuple[Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"]] = relax.call_tir(((1, 256, 56, 56), (256,), (256,)), batch_norm3, (gv46, stage2_unit1_bn1_gamma, stage2_unit1_bn1_beta, stage2_unit1_bn1_moving_mean, stage2_unit1_bn1_moving_var)) | |
gv48: Tensor[(1, 256, 56, 56), "float32"] = gv47[0] | |
gv49: Tensor[(1, 256, 56, 56), "float32"] = relax.call_tir((1, 256, 56, 56), relu2, (gv48,)) | |
gv50: Tensor[(1, 128, 28, 28), "float32"] = relax.call_tir((1, 128, 28, 28), conv2d_nchw5, (gv49, stage2_unit1_conv1_weight)) | |
gv51: Tuple[Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"]] = relax.call_tir(((1, 128, 28, 28), (128,), (128,)), batch_norm4, (gv50, stage2_unit1_bn2_gamma, stage2_unit1_bn2_beta, stage2_unit1_bn2_moving_mean, stage2_unit1_bn2_moving_var)) | |
gv52: Tensor[(1, 128, 28, 28), "float32"] = gv51[0] | |
gv53: Tensor[(1, 128, 28, 28), "float32"] = relax.call_tir((1, 128, 28, 28), relu3, (gv52,)) | |
gv54: Tensor[(1, 128, 28, 28), "float32"] = relax.call_tir((1, 128, 28, 28), conv2d_nchw6, (gv53, stage2_unit1_conv2_weight)) | |
gv55: Tuple[Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"]] = relax.call_tir(((1, 128, 28, 28), (128,), (128,)), batch_norm4, (gv54, stage2_unit1_bn3_gamma, stage2_unit1_bn3_beta, stage2_unit1_bn3_moving_mean, stage2_unit1_bn3_moving_var)) | |
gv56: Tensor[(1, 128, 28, 28), "float32"] = gv55[0] | |
gv57: Tensor[(1, 128, 28, 28), "float32"] = relax.call_tir((1, 128, 28, 28), relu3, (gv56,)) | |
gv58: Tensor[(1, 512, 28, 28), "float32"] = relax.call_tir((1, 512, 28, 28), conv2d_nchw7, (gv57, stage2_unit1_conv3_weight)) | |
gv59: Tensor[(1, 512, 28, 28), "float32"] = relax.call_tir((1, 512, 28, 28), conv2d_nchw8, (gv49, stage2_unit1_sc_weight)) | |
gv60: Tensor[(1, 512, 28, 28), "float32"] = relax.call_tir((1, 512, 28, 28), add1, (gv58, gv59)) | |
gv61: Tuple[Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"]] = relax.call_tir(((1, 512, 28, 28), (512,), (512,)), batch_norm5, (gv60, stage2_unit2_bn1_gamma, stage2_unit2_bn1_beta, stage2_unit2_bn1_moving_mean, stage2_unit2_bn1_moving_var)) | |
gv62: Tensor[(1, 512, 28, 28), "float32"] = gv61[0] | |
gv63: Tensor[(1, 512, 28, 28), "float32"] = relax.call_tir((1, 512, 28, 28), relu4, (gv62,)) | |
gv64: Tensor[(1, 128, 28, 28), "float32"] = relax.call_tir((1, 128, 28, 28), conv2d_nchw9, (gv63, stage2_unit2_conv1_weight)) | |
gv65: Tuple[Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"]] = relax.call_tir(((1, 128, 28, 28), (128,), (128,)), batch_norm4, (gv64, stage2_unit2_bn2_gamma, stage2_unit2_bn2_beta, stage2_unit2_bn2_moving_mean, stage2_unit2_bn2_moving_var)) | |
gv66: Tensor[(1, 128, 28, 28), "float32"] = gv65[0] | |
gv67: Tensor[(1, 128, 28, 28), "float32"] = relax.call_tir((1, 128, 28, 28), relu3, (gv66,)) | |
gv68: Tensor[(1, 128, 28, 28), "float32"] = relax.call_tir((1, 128, 28, 28), conv2d_nchw6, (gv67, stage2_unit2_conv2_weight)) | |
gv69: Tuple[Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"]] = relax.call_tir(((1, 128, 28, 28), (128,), (128,)), batch_norm4, (gv68, stage2_unit2_bn3_gamma, stage2_unit2_bn3_beta, stage2_unit2_bn3_moving_mean, stage2_unit2_bn3_moving_var)) | |
gv70: Tensor[(1, 128, 28, 28), "float32"] = gv69[0] | |
gv71: Tensor[(1, 128, 28, 28), "float32"] = relax.call_tir((1, 128, 28, 28), relu3, (gv70,)) | |
gv72: Tensor[(1, 512, 28, 28), "float32"] = relax.call_tir((1, 512, 28, 28), conv2d_nchw7, (gv71, stage2_unit2_conv3_weight)) | |
gv73: Tensor[(1, 512, 28, 28), "float32"] = relax.call_tir((1, 512, 28, 28), add1, (gv72, gv60)) | |
gv74: Tuple[Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"]] = relax.call_tir(((1, 512, 28, 28), (512,), (512,)), batch_norm5, (gv73, stage2_unit3_bn1_gamma, stage2_unit3_bn1_beta, stage2_unit3_bn1_moving_mean, stage2_unit3_bn1_moving_var)) | |
gv75: Tensor[(1, 512, 28, 28), "float32"] = gv74[0] | |
gv76: Tensor[(1, 512, 28, 28), "float32"] = relax.call_tir((1, 512, 28, 28), relu4, (gv75,)) | |
gv77: Tensor[(1, 128, 28, 28), "float32"] = relax.call_tir((1, 128, 28, 28), conv2d_nchw9, (gv76, stage2_unit3_conv1_weight)) | |
gv78: Tuple[Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"]] = relax.call_tir(((1, 128, 28, 28), (128,), (128,)), batch_norm4, (gv77, stage2_unit3_bn2_gamma, stage2_unit3_bn2_beta, stage2_unit3_bn2_moving_mean, stage2_unit3_bn2_moving_var)) | |
gv79: Tensor[(1, 128, 28, 28), "float32"] = gv78[0] | |
gv80: Tensor[(1, 128, 28, 28), "float32"] = relax.call_tir((1, 128, 28, 28), relu3, (gv79,)) | |
gv81: Tensor[(1, 128, 28, 28), "float32"] = relax.call_tir((1, 128, 28, 28), conv2d_nchw6, (gv80, stage2_unit3_conv2_weight)) | |
gv82: Tuple[Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"]] = relax.call_tir(((1, 128, 28, 28), (128,), (128,)), batch_norm4, (gv81, stage2_unit3_bn3_gamma, stage2_unit3_bn3_beta, stage2_unit3_bn3_moving_mean, stage2_unit3_bn3_moving_var)) | |
gv83: Tensor[(1, 128, 28, 28), "float32"] = gv82[0] | |
gv84: Tensor[(1, 128, 28, 28), "float32"] = relax.call_tir((1, 128, 28, 28), relu3, (gv83,)) | |
gv85: Tensor[(1, 512, 28, 28), "float32"] = relax.call_tir((1, 512, 28, 28), conv2d_nchw7, (gv84, stage2_unit3_conv3_weight)) | |
gv86: Tensor[(1, 512, 28, 28), "float32"] = relax.call_tir((1, 512, 28, 28), add1, (gv85, gv73)) | |
gv87: Tuple[Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"]] = relax.call_tir(((1, 512, 28, 28), (512,), (512,)), batch_norm5, (gv86, stage2_unit4_bn1_gamma, stage2_unit4_bn1_beta, stage2_unit4_bn1_moving_mean, stage2_unit4_bn1_moving_var)) | |
gv88: Tensor[(1, 512, 28, 28), "float32"] = gv87[0] | |
gv89: Tensor[(1, 512, 28, 28), "float32"] = relax.call_tir((1, 512, 28, 28), relu4, (gv88,)) | |
gv90: Tensor[(1, 128, 28, 28), "float32"] = relax.call_tir((1, 128, 28, 28), conv2d_nchw9, (gv89, stage2_unit4_conv1_weight)) | |
gv91: Tuple[Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"]] = relax.call_tir(((1, 128, 28, 28), (128,), (128,)), batch_norm4, (gv90, stage2_unit4_bn2_gamma, stage2_unit4_bn2_beta, stage2_unit4_bn2_moving_mean, stage2_unit4_bn2_moving_var)) | |
gv92: Tensor[(1, 128, 28, 28), "float32"] = gv91[0] | |
gv93: Tensor[(1, 128, 28, 28), "float32"] = relax.call_tir((1, 128, 28, 28), relu3, (gv92,)) | |
gv94: Tensor[(1, 128, 28, 28), "float32"] = relax.call_tir((1, 128, 28, 28), conv2d_nchw6, (gv93, stage2_unit4_conv2_weight)) | |
gv95: Tuple[Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"]] = relax.call_tir(((1, 128, 28, 28), (128,), (128,)), batch_norm4, (gv94, stage2_unit4_bn3_gamma, stage2_unit4_bn3_beta, stage2_unit4_bn3_moving_mean, stage2_unit4_bn3_moving_var)) | |
gv96: Tensor[(1, 128, 28, 28), "float32"] = gv95[0] | |
gv97: Tensor[(1, 128, 28, 28), "float32"] = relax.call_tir((1, 128, 28, 28), relu3, (gv96,)) | |
gv98: Tensor[(1, 512, 28, 28), "float32"] = relax.call_tir((1, 512, 28, 28), conv2d_nchw7, (gv97, stage2_unit4_conv3_weight)) | |
gv99: Tensor[(1, 512, 28, 28), "float32"] = relax.call_tir((1, 512, 28, 28), add1, (gv98, gv86)) | |
gv100: Tuple[Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"]] = relax.call_tir(((1, 512, 28, 28), (512,), (512,)), batch_norm5, (gv99, stage3_unit1_bn1_gamma, stage3_unit1_bn1_beta, stage3_unit1_bn1_moving_mean, stage3_unit1_bn1_moving_var)) | |
gv101: Tensor[(1, 512, 28, 28), "float32"] = gv100[0] | |
gv102: Tensor[(1, 512, 28, 28), "float32"] = relax.call_tir((1, 512, 28, 28), relu4, (gv101,)) | |
gv103: Tensor[(1, 256, 14, 14), "float32"] = relax.call_tir((1, 256, 14, 14), conv2d_nchw10, (gv102, stage3_unit1_conv1_weight)) | |
gv104: Tuple[Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"]] = relax.call_tir(((1, 256, 14, 14), (256,), (256,)), batch_norm6, (gv103, stage3_unit1_bn2_gamma, stage3_unit1_bn2_beta, stage3_unit1_bn2_moving_mean, stage3_unit1_bn2_moving_var)) | |
gv105: Tensor[(1, 256, 14, 14), "float32"] = gv104[0] | |
gv106: Tensor[(1, 256, 14, 14), "float32"] = relax.call_tir((1, 256, 14, 14), relu5, (gv105,)) | |
gv107: Tensor[(1, 256, 14, 14), "float32"] = relax.call_tir((1, 256, 14, 14), conv2d_nchw11, (gv106, stage3_unit1_conv2_weight)) | |
gv108: Tuple[Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"]] = relax.call_tir(((1, 256, 14, 14), (256,), (256,)), batch_norm6, (gv107, stage3_unit1_bn3_gamma, stage3_unit1_bn3_beta, stage3_unit1_bn3_moving_mean, stage3_unit1_bn3_moving_var)) | |
gv109: Tensor[(1, 256, 14, 14), "float32"] = gv108[0] | |
gv110: Tensor[(1, 256, 14, 14), "float32"] = relax.call_tir((1, 256, 14, 14), relu5, (gv109,)) | |
gv111: Tensor[(1, 1024, 14, 14), "float32"] = relax.call_tir((1, 1024, 14, 14), conv2d_nchw12, (gv110, stage3_unit1_conv3_weight)) | |
gv112: Tensor[(1, 1024, 14, 14), "float32"] = relax.call_tir((1, 1024, 14, 14), conv2d_nchw13, (gv102, stage3_unit1_sc_weight)) | |
gv113: Tensor[(1, 1024, 14, 14), "float32"] = relax.call_tir((1, 1024, 14, 14), add2, (gv111, gv112)) | |
gv114: Tuple[Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"]] = relax.call_tir(((1, 1024, 14, 14), (1024,), (1024,)), batch_norm7, (gv113, stage3_unit2_bn1_gamma, stage3_unit2_bn1_beta, stage3_unit2_bn1_moving_mean, stage3_unit2_bn1_moving_var)) | |
gv115: Tensor[(1, 1024, 14, 14), "float32"] = gv114[0] | |
gv116: Tensor[(1, 1024, 14, 14), "float32"] = relax.call_tir((1, 1024, 14, 14), relu6, (gv115,)) | |
gv117: Tensor[(1, 256, 14, 14), "float32"] = relax.call_tir((1, 256, 14, 14), conv2d_nchw14, (gv116, stage3_unit2_conv1_weight)) | |
gv118: Tuple[Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"]] = relax.call_tir(((1, 256, 14, 14), (256,), (256,)), batch_norm6, (gv117, stage3_unit2_bn2_gamma, stage3_unit2_bn2_beta, stage3_unit2_bn2_moving_mean, stage3_unit2_bn2_moving_var)) | |
gv119: Tensor[(1, 256, 14, 14), "float32"] = gv118[0] | |
gv120: Tensor[(1, 256, 14, 14), "float32"] = relax.call_tir((1, 256, 14, 14), relu5, (gv119,)) | |
gv121: Tensor[(1, 256, 14, 14), "float32"] = relax.call_tir((1, 256, 14, 14), conv2d_nchw11, (gv120, stage3_unit2_conv2_weight)) | |
gv122: Tuple[Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"]] = relax.call_tir(((1, 256, 14, 14), (256,), (256,)), batch_norm6, (gv121, stage3_unit2_bn3_gamma, stage3_unit2_bn3_beta, stage3_unit2_bn3_moving_mean, stage3_unit2_bn3_moving_var)) | |
gv123: Tensor[(1, 256, 14, 14), "float32"] = gv122[0] | |
gv124: Tensor[(1, 256, 14, 14), "float32"] = relax.call_tir((1, 256, 14, 14), relu5, (gv123,)) | |
gv125: Tensor[(1, 1024, 14, 14), "float32"] = relax.call_tir((1, 1024, 14, 14), conv2d_nchw12, (gv124, stage3_unit2_conv3_weight)) | |
gv126: Tensor[(1, 1024, 14, 14), "float32"] = relax.call_tir((1, 1024, 14, 14), add2, (gv125, gv113)) | |
gv127: Tuple[Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"]] = relax.call_tir(((1, 1024, 14, 14), (1024,), (1024,)), batch_norm7, (gv126, stage3_unit3_bn1_gamma, stage3_unit3_bn1_beta, stage3_unit3_bn1_moving_mean, stage3_unit3_bn1_moving_var)) | |
gv128: Tensor[(1, 1024, 14, 14), "float32"] = gv127[0] | |
gv129: Tensor[(1, 1024, 14, 14), "float32"] = relax.call_tir((1, 1024, 14, 14), relu6, (gv128,)) | |
gv130: Tensor[(1, 256, 14, 14), "float32"] = relax.call_tir((1, 256, 14, 14), conv2d_nchw14, (gv129, stage3_unit3_conv1_weight)) | |
gv131: Tuple[Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"]] = relax.call_tir(((1, 256, 14, 14), (256,), (256,)), batch_norm6, (gv130, stage3_unit3_bn2_gamma, stage3_unit3_bn2_beta, stage3_unit3_bn2_moving_mean, stage3_unit3_bn2_moving_var)) | |
gv132: Tensor[(1, 256, 14, 14), "float32"] = gv131[0] | |
gv133: Tensor[(1, 256, 14, 14), "float32"] = relax.call_tir((1, 256, 14, 14), relu5, (gv132,)) | |
gv134: Tensor[(1, 256, 14, 14), "float32"] = relax.call_tir((1, 256, 14, 14), conv2d_nchw11, (gv133, stage3_unit3_conv2_weight)) | |
gv135: Tuple[Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"]] = relax.call_tir(((1, 256, 14, 14), (256,), (256,)), batch_norm6, (gv134, stage3_unit3_bn3_gamma, stage3_unit3_bn3_beta, stage3_unit3_bn3_moving_mean, stage3_unit3_bn3_moving_var)) | |
gv136: Tensor[(1, 256, 14, 14), "float32"] = gv135[0] | |
gv137: Tensor[(1, 256, 14, 14), "float32"] = relax.call_tir((1, 256, 14, 14), relu5, (gv136,)) | |
gv138: Tensor[(1, 1024, 14, 14), "float32"] = relax.call_tir((1, 1024, 14, 14), conv2d_nchw12, (gv137, stage3_unit3_conv3_weight)) | |
gv139: Tensor[(1, 1024, 14, 14), "float32"] = relax.call_tir((1, 1024, 14, 14), add2, (gv138, gv126)) | |
gv140: Tuple[Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"]] = relax.call_tir(((1, 1024, 14, 14), (1024,), (1024,)), batch_norm7, (gv139, stage3_unit4_bn1_gamma, stage3_unit4_bn1_beta, stage3_unit4_bn1_moving_mean, stage3_unit4_bn1_moving_var)) | |
gv141: Tensor[(1, 1024, 14, 14), "float32"] = gv140[0] | |
gv142: Tensor[(1, 1024, 14, 14), "float32"] = relax.call_tir((1, 1024, 14, 14), relu6, (gv141,)) | |
gv143: Tensor[(1, 256, 14, 14), "float32"] = relax.call_tir((1, 256, 14, 14), conv2d_nchw14, (gv142, stage3_unit4_conv1_weight)) | |
gv144: Tuple[Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"]] = relax.call_tir(((1, 256, 14, 14), (256,), (256,)), batch_norm6, (gv143, stage3_unit4_bn2_gamma, stage3_unit4_bn2_beta, stage3_unit4_bn2_moving_mean, stage3_unit4_bn2_moving_var)) | |
gv145: Tensor[(1, 256, 14, 14), "float32"] = gv144[0] | |
gv146: Tensor[(1, 256, 14, 14), "float32"] = relax.call_tir((1, 256, 14, 14), relu5, (gv145,)) | |
gv147: Tensor[(1, 256, 14, 14), "float32"] = relax.call_tir((1, 256, 14, 14), conv2d_nchw11, (gv146, stage3_unit4_conv2_weight)) | |
gv148: Tuple[Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"]] = relax.call_tir(((1, 256, 14, 14), (256,), (256,)), batch_norm6, (gv147, stage3_unit4_bn3_gamma, stage3_unit4_bn3_beta, stage3_unit4_bn3_moving_mean, stage3_unit4_bn3_moving_var)) | |
gv149: Tensor[(1, 256, 14, 14), "float32"] = gv148[0] | |
gv150: Tensor[(1, 256, 14, 14), "float32"] = relax.call_tir((1, 256, 14, 14), relu5, (gv149,)) | |
gv151: Tensor[(1, 1024, 14, 14), "float32"] = relax.call_tir((1, 1024, 14, 14), conv2d_nchw12, (gv150, stage3_unit4_conv3_weight)) | |
gv152: Tensor[(1, 1024, 14, 14), "float32"] = relax.call_tir((1, 1024, 14, 14), add2, (gv151, gv139)) | |
gv153: Tuple[Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"]] = relax.call_tir(((1, 1024, 14, 14), (1024,), (1024,)), batch_norm7, (gv152, stage3_unit5_bn1_gamma, stage3_unit5_bn1_beta, stage3_unit5_bn1_moving_mean, stage3_unit5_bn1_moving_var)) | |
gv154: Tensor[(1, 1024, 14, 14), "float32"] = gv153[0] | |
gv155: Tensor[(1, 1024, 14, 14), "float32"] = relax.call_tir((1, 1024, 14, 14), relu6, (gv154,)) | |
gv156: Tensor[(1, 256, 14, 14), "float32"] = relax.call_tir((1, 256, 14, 14), conv2d_nchw14, (gv155, stage3_unit5_conv1_weight)) | |
gv157: Tuple[Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"]] = relax.call_tir(((1, 256, 14, 14), (256,), (256,)), batch_norm6, (gv156, stage3_unit5_bn2_gamma, stage3_unit5_bn2_beta, stage3_unit5_bn2_moving_mean, stage3_unit5_bn2_moving_var)) | |
gv158: Tensor[(1, 256, 14, 14), "float32"] = gv157[0] | |
gv159: Tensor[(1, 256, 14, 14), "float32"] = relax.call_tir((1, 256, 14, 14), relu5, (gv158,)) | |
gv160: Tensor[(1, 256, 14, 14), "float32"] = relax.call_tir((1, 256, 14, 14), conv2d_nchw11, (gv159, stage3_unit5_conv2_weight)) | |
gv161: Tuple[Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"]] = relax.call_tir(((1, 256, 14, 14), (256,), (256,)), batch_norm6, (gv160, stage3_unit5_bn3_gamma, stage3_unit5_bn3_beta, stage3_unit5_bn3_moving_mean, stage3_unit5_bn3_moving_var)) | |
gv162: Tensor[(1, 256, 14, 14), "float32"] = gv161[0] | |
gv163: Tensor[(1, 256, 14, 14), "float32"] = relax.call_tir((1, 256, 14, 14), relu5, (gv162,)) | |
gv164: Tensor[(1, 1024, 14, 14), "float32"] = relax.call_tir((1, 1024, 14, 14), conv2d_nchw12, (gv163, stage3_unit5_conv3_weight)) | |
gv165: Tensor[(1, 1024, 14, 14), "float32"] = relax.call_tir((1, 1024, 14, 14), add2, (gv164, gv152)) | |
gv166: Tuple[Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"]] = relax.call_tir(((1, 1024, 14, 14), (1024,), (1024,)), batch_norm7, (gv165, stage3_unit6_bn1_gamma, stage3_unit6_bn1_beta, stage3_unit6_bn1_moving_mean, stage3_unit6_bn1_moving_var)) | |
gv167: Tensor[(1, 1024, 14, 14), "float32"] = gv166[0] | |
gv168: Tensor[(1, 1024, 14, 14), "float32"] = relax.call_tir((1, 1024, 14, 14), relu6, (gv167,)) | |
gv169: Tensor[(1, 256, 14, 14), "float32"] = relax.call_tir((1, 256, 14, 14), conv2d_nchw14, (gv168, stage3_unit6_conv1_weight)) | |
gv170: Tuple[Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"]] = relax.call_tir(((1, 256, 14, 14), (256,), (256,)), batch_norm6, (gv169, stage3_unit6_bn2_gamma, stage3_unit6_bn2_beta, stage3_unit6_bn2_moving_mean, stage3_unit6_bn2_moving_var)) | |
gv171: Tensor[(1, 256, 14, 14), "float32"] = gv170[0] | |
gv172: Tensor[(1, 256, 14, 14), "float32"] = relax.call_tir((1, 256, 14, 14), relu5, (gv171,)) | |
gv173: Tensor[(1, 256, 14, 14), "float32"] = relax.call_tir((1, 256, 14, 14), conv2d_nchw11, (gv172, stage3_unit6_conv2_weight)) | |
gv174: Tuple[Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"]] = relax.call_tir(((1, 256, 14, 14), (256,), (256,)), batch_norm6, (gv173, stage3_unit6_bn3_gamma, stage3_unit6_bn3_beta, stage3_unit6_bn3_moving_mean, stage3_unit6_bn3_moving_var)) | |
gv175: Tensor[(1, 256, 14, 14), "float32"] = gv174[0] | |
gv176: Tensor[(1, 256, 14, 14), "float32"] = relax.call_tir((1, 256, 14, 14), relu5, (gv175,)) | |
gv177: Tensor[(1, 1024, 14, 14), "float32"] = relax.call_tir((1, 1024, 14, 14), conv2d_nchw12, (gv176, stage3_unit6_conv3_weight)) | |
gv178: Tensor[(1, 1024, 14, 14), "float32"] = relax.call_tir((1, 1024, 14, 14), add2, (gv177, gv165)) | |
gv179: Tuple[Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"]] = relax.call_tir(((1, 1024, 14, 14), (1024,), (1024,)), batch_norm7, (gv178, stage4_unit1_bn1_gamma, stage4_unit1_bn1_beta, stage4_unit1_bn1_moving_mean, stage4_unit1_bn1_moving_var)) | |
gv180: Tensor[(1, 1024, 14, 14), "float32"] = gv179[0] | |
gv181: Tensor[(1, 1024, 14, 14), "float32"] = relax.call_tir((1, 1024, 14, 14), relu6, (gv180,)) | |
gv182: Tensor[(1, 512, 7, 7), "float32"] = relax.call_tir((1, 512, 7, 7), conv2d_nchw15, (gv181, stage4_unit1_conv1_weight)) | |
gv183: Tuple[Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"]] = relax.call_tir(((1, 512, 7, 7), (512,), (512,)), batch_norm8, (gv182, stage4_unit1_bn2_gamma, stage4_unit1_bn2_beta, stage4_unit1_bn2_moving_mean, stage4_unit1_bn2_moving_var)) | |
gv184: Tensor[(1, 512, 7, 7), "float32"] = gv183[0] | |
gv185: Tensor[(1, 512, 7, 7), "float32"] = relax.call_tir((1, 512, 7, 7), relu7, (gv184,)) | |
gv186: Tensor[(1, 512, 7, 7), "float32"] = relax.call_tir((1, 512, 7, 7), conv2d_nchw16, (gv185, stage4_unit1_conv2_weight)) | |
gv187: Tuple[Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"]] = relax.call_tir(((1, 512, 7, 7), (512,), (512,)), batch_norm8, (gv186, stage4_unit1_bn3_gamma, stage4_unit1_bn3_beta, stage4_unit1_bn3_moving_mean, stage4_unit1_bn3_moving_var)) | |
gv188: Tensor[(1, 512, 7, 7), "float32"] = gv187[0] | |
gv189: Tensor[(1, 512, 7, 7), "float32"] = relax.call_tir((1, 512, 7, 7), relu7, (gv188,)) | |
gv190: Tensor[(1, 2048, 7, 7), "float32"] = relax.call_tir((1, 2048, 7, 7), conv2d_nchw17, (gv189, stage4_unit1_conv3_weight)) | |
gv191: Tensor[(1, 2048, 7, 7), "float32"] = relax.call_tir((1, 2048, 7, 7), conv2d_nchw18, (gv181, stage4_unit1_sc_weight)) | |
gv192: Tensor[(1, 2048, 7, 7), "float32"] = relax.call_tir((1, 2048, 7, 7), add3, (gv190, gv191)) | |
gv193: Tuple[Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"]] = relax.call_tir(((1, 2048, 7, 7), (2048,), (2048,)), batch_norm9, (gv192, stage4_unit2_bn1_gamma, stage4_unit2_bn1_beta, stage4_unit2_bn1_moving_mean, stage4_unit2_bn1_moving_var)) | |
gv194: Tensor[(1, 2048, 7, 7), "float32"] = gv193[0] | |
gv195: Tensor[(1, 2048, 7, 7), "float32"] = relax.call_tir((1, 2048, 7, 7), relu8, (gv194,)) | |
gv196: Tensor[(1, 512, 7, 7), "float32"] = relax.call_tir((1, 512, 7, 7), conv2d_nchw19, (gv195, stage4_unit2_conv1_weight)) | |
gv197: Tuple[Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"]] = relax.call_tir(((1, 512, 7, 7), (512,), (512,)), batch_norm8, (gv196, stage4_unit2_bn2_gamma, stage4_unit2_bn2_beta, stage4_unit2_bn2_moving_mean, stage4_unit2_bn2_moving_var)) | |
gv198: Tensor[(1, 512, 7, 7), "float32"] = gv197[0] | |
gv199: Tensor[(1, 512, 7, 7), "float32"] = relax.call_tir((1, 512, 7, 7), relu7, (gv198,)) | |
gv200: Tensor[(1, 512, 7, 7), "float32"] = relax.call_tir((1, 512, 7, 7), conv2d_nchw16, (gv199, stage4_unit2_conv2_weight)) | |
gv201: Tuple[Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"]] = relax.call_tir(((1, 512, 7, 7), (512,), (512,)), batch_norm8, (gv200, stage4_unit2_bn3_gamma, stage4_unit2_bn3_beta, stage4_unit2_bn3_moving_mean, stage4_unit2_bn3_moving_var)) | |
gv202: Tensor[(1, 512, 7, 7), "float32"] = gv201[0] | |
gv203: Tensor[(1, 512, 7, 7), "float32"] = relax.call_tir((1, 512, 7, 7), relu7, (gv202,)) | |
gv204: Tensor[(1, 2048, 7, 7), "float32"] = relax.call_tir((1, 2048, 7, 7), conv2d_nchw17, (gv203, stage4_unit2_conv3_weight)) | |
gv205: Tensor[(1, 2048, 7, 7), "float32"] = relax.call_tir((1, 2048, 7, 7), add3, (gv204, gv192)) | |
gv206: Tuple[Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"]] = relax.call_tir(((1, 2048, 7, 7), (2048,), (2048,)), batch_norm9, (gv205, stage4_unit3_bn1_gamma, stage4_unit3_bn1_beta, stage4_unit3_bn1_moving_mean, stage4_unit3_bn1_moving_var)) | |
gv207: Tensor[(1, 2048, 7, 7), "float32"] = gv206[0] | |
gv208: Tensor[(1, 2048, 7, 7), "float32"] = relax.call_tir((1, 2048, 7, 7), relu8, (gv207,)) | |
gv209: Tensor[(1, 512, 7, 7), "float32"] = relax.call_tir((1, 512, 7, 7), conv2d_nchw19, (gv208, stage4_unit3_conv1_weight)) | |
gv210: Tuple[Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"]] = relax.call_tir(((1, 512, 7, 7), (512,), (512,)), batch_norm8, (gv209, stage4_unit3_bn2_gamma, stage4_unit3_bn2_beta, stage4_unit3_bn2_moving_mean, stage4_unit3_bn2_moving_var)) | |
gv211: Tensor[(1, 512, 7, 7), "float32"] = gv210[0] | |
gv212: Tensor[(1, 512, 7, 7), "float32"] = relax.call_tir((1, 512, 7, 7), relu7, (gv211,)) | |
gv213: Tensor[(1, 512, 7, 7), "float32"] = relax.call_tir((1, 512, 7, 7), conv2d_nchw16, (gv212, stage4_unit3_conv2_weight)) | |
gv214: Tuple[Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"]] = relax.call_tir(((1, 512, 7, 7), (512,), (512,)), batch_norm8, (gv213, stage4_unit3_bn3_gamma, stage4_unit3_bn3_beta, stage4_unit3_bn3_moving_mean, stage4_unit3_bn3_moving_var)) | |
gv215: Tensor[(1, 512, 7, 7), "float32"] = gv214[0] | |
gv216: Tensor[(1, 512, 7, 7), "float32"] = relax.call_tir((1, 512, 7, 7), relu7, (gv215,)) | |
gv217: Tensor[(1, 2048, 7, 7), "float32"] = relax.call_tir((1, 2048, 7, 7), conv2d_nchw17, (gv216, stage4_unit3_conv3_weight)) | |
gv218: Tensor[(1, 2048, 7, 7), "float32"] = relax.call_tir((1, 2048, 7, 7), add3, (gv217, gv205)) | |
gv219: Tuple[Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"], Tensor[(_, _, _, _), "float32"]] = relax.call_tir(((1, 2048, 7, 7), (2048,), (2048,)), batch_norm9, (gv218, bn1_gamma, bn1_beta, bn1_moving_mean, bn1_moving_var)) | |
gv220: Tensor[(1, 2048, 7, 7), "float32"] = gv219[0] | |
gv221: Tensor[(1, 2048, 7, 7), "float32"] = relax.call_tir((1, 2048, 7, 7), relu8, (gv220,)) | |
gv222: Tensor[(1, 2048, 1, 1), "float32"] = relax.call_tir((1, 2048, 1, 1), global_avg_pool2d, (gv221,)) | |
gv223: Tensor[(1, 2048), "float32"] = relax.call_tir((1, 2048), batch_flatten, (gv222,)) | |
gv224: Tensor[(1, 1000), "float32"] = relax.call_tir((1, 1000), dense, (gv223, fc1_weight)) | |
gv225: Tensor[(1, 1000), "float32"] = relax.call_tir((1, 1000), bias_add, (gv224, fc1_bias)) | |
gv226: Tensor[(1, 1000), "float32"] = relax.call_tir((1, 1000), softmax, (gv225,)) | |
relax.output(gv226) | |
return gv226 | |
@tir.prim_func | |
def relu8(rxplaceholder_1: tir.Buffer[(1, 2048, 7, 7), "float32"], T_relu_1: tir.Buffer[(1, 2048, 7, 7), "float32"]) -> None: | |
# function attr dict | |
tir.func_attr({"global_symbol": "relu8", "tir.noalias": True}) | |
# body | |
# with tir.block("root") | |
for i0, i1, i2, i3 in tir.grid(1, 2048, 7, 7): | |
with tir.block("T_relu"): | |
ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3]) | |
tir.reads(rxplaceholder_1[ax0, ax1, ax2, ax3]) | |
tir.writes(T_relu_1[ax0, ax1, ax2, ax3]) | |
T_relu_1[ax0, ax1, ax2, ax3] = tir.max(rxplaceholder_1[ax0, ax1, ax2, ax3], tir.float32(0)) | |
@tir.prim_func | |
def batch_flatten(rxplaceholder_1: tir.Buffer[(1, 2048, 1, 1), "float32"], tensor_1: tir.Buffer[(1, 2048), "float32"]) -> None: | |
# function attr dict | |
tir.func_attr({"global_symbol": "batch_flatten", "tir.noalias": True}) | |
# body | |
# with tir.block("root") | |
for i0, i1 in tir.grid(1, 2048): | |
with tir.block("tensor"): | |
ax0, ax1 = tir.axis.remap("SS", [i0, i1]) | |
tir.reads(rxplaceholder_1[ax0, ax1 % 2048, 0, 0]) | |
tir.writes(tensor_1[ax0, ax1]) | |
tensor_1[ax0, ax1] = rxplaceholder_1[ax0, ax1 % 2048, 0, 0] | |
@tir.prim_func | |
def conv2d_nchw2(rxplaceholder_2: tir.Buffer[(1, 64, 56, 56), "float32"], rxplaceholder_3: tir.Buffer[(64, 64, 3, 3), "float32"], conv2d_nchw_1: tir.Buffer[(1, 64, 56, 56), "float32"]) -> None: | |
# function attr dict | |
tir.func_attr({"global_symbol": "conv2d_nchw2", "tir.noalias": True}) | |
# body | |
# with tir.block("root") | |
pad_temp_1 = tir.alloc_buffer([1, 64, 58, 58], dtype="float32") | |
for i0, i1, i2, i3 in tir.grid(1, 64, 58, 58): | |
with tir.block("pad_temp"): | |
i0_2, i1_2, i2_2, i3_2 = tir.axis.remap("SSSS", [i0, i1, i2, i3]) | |
tir.reads(rxplaceholder_2[i0_2, i1_2, i2_2 - 1, i3_2 - 1]) | |
tir.writes(pad_temp_1[i0_2, i1_2, i2_2, i3_2]) | |
pad_temp_1[i0_2, i1_2, i2_2, i3_2] = tir.if_then_else(i2_2 >= 1 and i2_2 < 57 and i3_2 >= 1 and i3_2 < 57, rxplaceholder_2[i0_2, i1_2, i2_2 - 1, i3_2 - 1], tir.float32(0), dtype="float32") | |
for i0_3, i1_3, i2_3, i3_3, i4, i5, i6 in tir.grid(1, 64, 56, 56, 64, 3, 3): | |
with tir.block("conv2d_nchw"): | |
nn, ff, yy, xx, rc, ry, rx = tir.axis.remap("SSSSRRR", [i0_3, i1_3, i2_3, i3_3, i4, i5, i6]) | |
tir.reads(conv2d_nchw_1[nn, ff, yy, xx], pad_temp_1[nn, rc, yy + ry, xx + rx], rxplaceholder_3[ff, rc, ry, rx]) | |
tir.writes(conv2d_nchw_1[nn, ff, yy, xx]) | |
tir.block_attr({"layout_free_placeholders":[rxplaceholder_3]}) | |
with tir.init(): | |
conv2d_nchw_1[nn, ff, yy, xx] = tir.float32(0) | |
conv2d_nchw_1[nn, ff, yy, xx] = conv2d_nchw_1[nn, ff, yy, xx] + pad_temp_1[nn, rc, yy + ry, xx + rx] * rxplaceholder_3[ff, rc, ry, rx] | |
@tir.prim_func | |
def conv2d_nchw7(rxplaceholder_2: tir.Buffer[(1, 128, 28, 28), "float32"], rxplaceholder_3: tir.Buffer[(512, 128, 1, 1), "float32"], conv2d_nchw_1: tir.Buffer[(1, 512, 28, 28), "float32"]) -> None: | |
# function attr dict | |
tir.func_attr({"global_symbol": "conv2d_nchw7", "tir.noalias": True}) | |
# body | |
# with tir.block("root") | |
pad_temp_1 = tir.alloc_buffer([1, 128, 28, 28], dtype="float32") | |
for i0, i1, i2, i3 in tir.grid(1, 128, 28, 28): | |
with tir.block("pad_temp"): | |
i0_2, i1_2, i2_2, i3_2 = tir.axis.remap("SSSS", [i0, i1, i2, i3]) | |
tir.reads(rxplaceholder_2[i0_2, i1_2, i2_2, i3_2]) | |
tir.writes(pad_temp_1[i0_2, i1_2, i2_2, i3_2]) | |
pad_temp_1[i0_2, i1_2, i2_2, i3_2] = rxplaceholder_2[i0_2, i1_2, i2_2, i3_2] | |
for i0_3, i1_3, i2_3, i3_3, i4, i5, i6 in tir.grid(1, 512, 28, 28, 128, 1, 1): | |
with tir.block("conv2d_nchw"): | |
nn, ff, yy, xx, rc, ry, rx = tir.axis.remap("SSSSRRR", [i0_3, i1_3, i2_3, i3_3, i4, i5, i6]) | |
tir.reads(conv2d_nchw_1[nn, ff, yy, xx], pad_temp_1[nn, rc, yy + ry, xx + rx], rxplaceholder_3[ff, rc, ry, rx]) | |
tir.writes(conv2d_nchw_1[nn, ff, yy, xx]) | |
tir.block_attr({"layout_free_placeholders":[rxplaceholder_3]}) | |
with tir.init(): | |
conv2d_nchw_1[nn, ff, yy, xx] = tir.float32(0) | |
conv2d_nchw_1[nn, ff, yy, xx] = conv2d_nchw_1[nn, ff, yy, xx] + pad_temp_1[nn, rc, yy + ry, xx + rx] * rxplaceholder_3[ff, rc, ry, rx] | |
@tir.prim_func | |
def conv2d_nchw11(rxplaceholder_2: tir.Buffer[(1, 256, 14, 14), "float32"], rxplaceholder_3: tir.Buffer[(256, 256, 3, 3), "float32"], conv2d_nchw_1: tir.Buffer[(1, 256, 14, 14), "float32"]) -> None: | |
# function attr dict | |
tir.func_attr({"global_symbol": "conv2d_nchw11", "tir.noalias": True}) | |
# body | |
# with tir.block("root") | |
pad_temp_1 = tir.alloc_buffer([1, 256, 16, 16], dtype="float32") | |
for i0, i1, i2, i3 in tir.grid(1, 256, 16, 16): | |
with tir.block("pad_temp"): | |
i0_2, i1_2, i2_2, i3_2 = tir.axis.remap("SSSS", [i0, i1, i2, i3]) | |
tir.reads(rxplaceholder_2[i0_2, i1_2, i2_2 - 1, i3_2 - 1]) | |
tir.writes(pad_temp_1[i0_2, i1_2, i2_2, i3_2]) | |
pad_temp_1[i0_2, i1_2, i2_2, i3_2] = tir.if_then_else(i2_2 >= 1 and i2_2 < 15 and i3_2 >= 1 and i3_2 < 15, rxplaceholder_2[i0_2, i1_2, i2_2 - 1, i3_2 - 1], tir.float32(0), dtype="float32") | |
for i0_3, i1_3, i2_3, i3_3, i4, i5, i6 in tir.grid(1, 256, 14, 14, 256, 3, 3): | |
with tir.block("conv2d_nchw"): | |
nn, ff, yy, xx, rc, ry, rx = tir.axis.remap("SSSSRRR", [i0_3, i1_3, i2_3, i3_3, i4, i5, i6]) | |
tir.reads(conv2d_nchw_1[nn, ff, yy, xx], pad_temp_1[nn, rc, yy + ry, xx + rx], rxplaceholder_3[ff, rc, ry, rx]) | |
tir.writes(conv2d_nchw_1[nn, ff, yy, xx]) | |
tir.block_attr({"layout_free_placeholders":[rxplaceholder_3]}) | |
with tir.init(): | |
conv2d_nchw_1[nn, ff, yy, xx] = tir.float32(0) | |
conv2d_nchw_1[nn, ff, yy, xx] = conv2d_nchw_1[nn, ff, yy, xx] + pad_temp_1[nn, rc, yy + ry, xx + rx] * rxplaceholder_3[ff, rc, ry, rx] | |
@tir.prim_func | |
def relu2(rxplaceholder_1: tir.Buffer[(1, 256, 56, 56), "float32"], T_relu_1: tir.Buffer[(1, 256, 56, 56), "float32"]) -> None: | |
# function attr dict | |
tir.func_attr({"global_symbol": "relu2", "tir.noalias": True}) | |
# body | |
# with tir.block("root") | |
for i0, i1, i2, i3 in tir.grid(1, 256, 56, 56): | |
with tir.block("T_relu"): | |
ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3]) | |
tir.reads(rxplaceholder_1[ax0, ax1, ax2, ax3]) | |
tir.writes(T_relu_1[ax0, ax1, ax2, ax3]) | |
T_relu_1[ax0, ax1, ax2, ax3] = tir.max(rxplaceholder_1[ax0, ax1, ax2, ax3], tir.float32(0)) | |
@tir.prim_func | |
def relu3(rxplaceholder_1: tir.Buffer[(1, 128, 28, 28), "float32"], T_relu_1: tir.Buffer[(1, 128, 28, 28), "float32"]) -> None: | |
# function attr dict | |
tir.func_attr({"global_symbol": "relu3", "tir.noalias": True}) | |
# body | |
# with tir.block("root") | |
for i0, i1, i2, i3 in tir.grid(1, 128, 28, 28): | |
with tir.block("T_relu"): | |
ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3]) | |
tir.reads(rxplaceholder_1[ax0, ax1, ax2, ax3]) | |
tir.writes(T_relu_1[ax0, ax1, ax2, ax3]) | |
T_relu_1[ax0, ax1, ax2, ax3] = tir.max(rxplaceholder_1[ax0, ax1, ax2, ax3], tir.float32(0)) | |
@tir.prim_func | |
def add1(rxplaceholder_2: tir.Buffer[(1, 512, 28, 28), "float32"], rxplaceholder_3: tir.Buffer[(1, 512, 28, 28), "float32"], T_add_1: tir.Buffer[(1, 512, 28, 28), "float32"]) -> None: | |
# function attr dict | |
tir.func_attr({"global_symbol": "add1", "tir.noalias": True}) | |
# body | |
# with tir.block("root") | |
for i0, i1, i2, i3 in tir.grid(1, 512, 28, 28): | |
with tir.block("T_add"): | |
ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3]) | |
tir.reads(rxplaceholder_2[ax0, ax1, ax2, ax3], rxplaceholder_3[ax0, ax1, ax2, ax3]) | |
tir.writes(T_add_1[ax0, ax1, ax2, ax3]) | |
T_add_1[ax0, ax1, ax2, ax3] = rxplaceholder_2[ax0, ax1, ax2, ax3] + rxplaceholder_3[ax0, ax1, ax2, ax3] | |
@tir.prim_func | |
def conv2d_nchw4(rxplaceholder_2: tir.Buffer[(1, 256, 56, 56), "float32"], rxplaceholder_3: tir.Buffer[(64, 256, 1, 1), "float32"], conv2d_nchw_1: tir.Buffer[(1, 64, 56, 56), "float32"]) -> None: | |
# function attr dict | |
tir.func_attr({"global_symbol": "conv2d_nchw4", "tir.noalias": True}) | |
# body | |
# with tir.block("root") | |
pad_temp_1 = tir.alloc_buffer([1, 256, 56, 56], dtype="float32") | |
for i0, i1, i2, i3 in tir.grid(1, 256, 56, 56): | |
with tir.block("pad_temp"): | |
i0_2, i1_2, i2_2, i3_2 = tir.axis.remap("SSSS", [i0, i1, i2, i3]) | |
tir.reads(rxplaceholder_2[i0_2, i1_2, i2_2, i3_2]) | |
tir.writes(pad_temp_1[i0_2, i1_2, i2_2, i3_2]) | |
pad_temp_1[i0_2, i1_2, i2_2, i3_2] = rxplaceholder_2[i0_2, i1_2, i2_2, i3_2] | |
for i0_3, i1_3, i2_3, i3_3, i4, i5, i6 in tir.grid(1, 64, 56, 56, 256, 1, 1): | |
with tir.block("conv2d_nchw"): | |
nn, ff, yy, xx, rc, ry, rx = tir.axis.remap("SSSSRRR", [i0_3, i1_3, i2_3, i3_3, i4, i5, i6]) | |
tir.reads(conv2d_nchw_1[nn, ff, yy, xx], pad_temp_1[nn, rc, yy + ry, xx + rx], rxplaceholder_3[ff, rc, ry, rx]) | |
tir.writes(conv2d_nchw_1[nn, ff, yy, xx]) | |
tir.block_attr({"layout_free_placeholders":[rxplaceholder_3]}) | |
with tir.init(): | |
conv2d_nchw_1[nn, ff, yy, xx] = tir.float32(0) | |
conv2d_nchw_1[nn, ff, yy, xx] = conv2d_nchw_1[nn, ff, yy, xx] + pad_temp_1[nn, rc, yy + ry, xx + rx] * rxplaceholder_3[ff, rc, ry, rx] | |
@tir.prim_func | |
def batch_norm6(rxplaceholder_5: tir.Buffer[(1, 256, 14, 14), "float32"], rxplaceholder_6: tir.Buffer[(256,), "float32"], rxplaceholder_7: tir.Buffer[(256,), "float32"], rxplaceholder_8: tir.Buffer[(256,), "float32"], rxplaceholder_9: tir.Buffer[(256,), "float32"], T_add_2: tir.Buffer[(1, 256, 14, 14), "float32"], T_multiply_3: tir.Buffer[(256,), "float32"], T_multiply_4: tir.Buffer[(256,), "float32"]) -> None: | |
# function attr dict | |
tir.func_attr({"global_symbol": "batch_norm6", "tir.noalias": True}) | |
# body | |
# with tir.block("root") | |
T_reshape_4 = tir.alloc_buffer([1, 256, 1, 1], dtype="float32") | |
T_subtract_1 = tir.alloc_buffer([1, 256, 14, 14], dtype="float32") | |
T_reshape_5 = tir.alloc_buffer([1, 256, 1, 1], dtype="float32") | |
T_add_3 = tir.alloc_buffer([1, 256, 1, 1], dtype="float32") | |
compute_1 = tir.alloc_buffer([1, 256, 1, 1], dtype="float32") | |
T_divide_1 = tir.alloc_buffer([1, 256, 14, 14], dtype="float32") | |
T_reshape_6 = tir.alloc_buffer([1, 256, 1, 1], dtype="float32") | |
T_multiply_5 = tir.alloc_buffer([1, 256, 14, 14], dtype="float32") | |
T_reshape_7 = tir.alloc_buffer([1, 256, 1, 1], dtype="float32") | |
for i0, i1, i2, i3 in tir.grid(1, 256, 1, 1): | |
with tir.block("T_reshape"): | |
ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3]) | |
tir.reads(rxplaceholder_8[(ax1 + ax2 + ax3) % 256]) | |
tir.writes(T_reshape_4[ax0, ax1, ax2, ax3]) | |
T_reshape_4[ax0, ax1, ax2, ax3] = rxplaceholder_8[(ax0 * 256 + ax1 + ax2 + ax3) % 256] | |
for i0, i1, i2, i3 in tir.grid(1, 256, 14, 14): | |
with tir.block("T_subtract"): | |
ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3]) | |
tir.reads(rxplaceholder_5[ax0, ax1, ax2, ax3], T_reshape_4[ax0, ax1, 0, 0]) | |
tir.writes(T_subtract_1[ax0, ax1, ax2, ax3]) | |
T_subtract_1[ax0, ax1, ax2, ax3] = rxplaceholder_5[ax0, ax1, ax2, ax3] - T_reshape_4[ax0, ax1, 0, 0] | |
for i0, i1, i2, i3 in tir.grid(1, 256, 1, 1): | |
with tir.block("T_reshape_1"): | |
ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3]) | |
tir.reads(rxplaceholder_9[(ax1 + ax2 + ax3) % 256]) | |
tir.writes(T_reshape_5[ax0, ax1, ax2, ax3]) | |
T_reshape_5[ax0, ax1, ax2, ax3] = rxplaceholder_9[(ax0 * 256 + ax1 + ax2 + ax3) % 256] | |
for i0, i1, i2, i3 in tir.grid(1, 256, 1, 1): | |
with tir.block("T_add"): | |
ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3]) | |
tir.reads(T_reshape_5[ax0, ax1, ax2, ax3]) | |
tir.writes(T_add_3[ax0, ax1, ax2, ax3]) | |
T_add_3[ax0, ax1, ax2, ax3] = T_reshape_5[ax0, ax1, ax2, ax3] + tir.float32(1.9999999494757503e-05) | |
for i0, i1, i2, i3 in tir.grid(1, 256, 1, 1): | |
with tir.block("compute"): | |
i0_2, i1_2, i2_2, i3_2 = tir.axis.remap("SSSS", [i0, i1, i2, i3]) | |
tir.reads(T_add_3[i0_2, i1_2, i2_2, i3_2]) | |
tir.writes(compute_1[i0_2, i1_2, i2_2, i3_2]) | |
compute_1[i0_2, i1_2, i2_2, i3_2] = tir.sqrt(T_add_3[i0_2, i1_2, i2_2, i3_2], dtype="float32") | |
for i0_3, i1_3, i2_3, i3_3 in tir.grid(1, 256, 14, 14): | |
with tir.block("T_divide"): | |
ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0_3, i1_3, i2_3, i3_3]) | |
tir.reads(T_subtract_1[ax0, ax1, ax2, ax3], compute_1[ax0, ax1, 0, 0]) | |
tir.writes(T_divide_1[ax0, ax1, ax2, ax3]) | |
T_divide_1[ax0, ax1, ax2, ax3] = T_subtract_1[ax0, ax1, ax2, ax3] / compute_1[ax0, ax1, 0, 0] | |
for i0_4, i1_4, i2_4, i3_4 in tir.grid(1, 256, 1, 1): | |
with tir.block("T_reshape_2"): | |
ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0_4, i1_4, i2_4, i3_4]) | |
tir.reads(rxplaceholder_6[(ax1 + ax2 + ax3) % 256]) | |
tir.writes(T_reshape_6[ax0, ax1, ax2, ax3]) | |
T_reshape_6[ax0, ax1, ax2, ax3] = rxplaceholder_6[(ax0 * 256 + ax1 + ax2 + ax3) % 256] | |
for i0_5, i1_5, i2_5, i3_5 in tir.grid(1, 256, 14, 14): | |
with tir.block("T_multiply"): | |
ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0_5, i1_5, i2_5, i3_5]) | |
tir.reads(T_divide_1[ax0, ax1, ax2, ax3], T_reshape_6[ax0, ax1, 0, 0]) | |
tir.writes(T_multiply_5[ax0, ax1, ax2, ax3]) | |
T_multiply_5[ax0, ax1, ax2, ax3] = T_divide_1[ax0, ax1, ax2, ax3] * T_reshape_6[ax0, ax1, 0, 0] | |
for i0_6, i1_6, i2_6, i3_6 in tir.grid(1, 256, 1, 1): | |
with tir.block("T_reshape_3"): | |
ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0_6, i1_6, i2_6, i3_6]) | |
tir.reads(rxplaceholder_7[(ax1 + ax2 + ax3) % 256]) | |
tir.writes(T_reshape_7[ax0, ax1, ax2, ax3]) | |
T_reshape_7[ax0, ax1, ax2, ax3] = rxplaceholder_7[(ax0 * 256 + ax1 + ax2 + ax3) % 256] | |
for i0_7, i1_7, i2_7, i3_7 in tir.grid(1, 256, 14, 14): | |
with tir.block("T_add_1"): | |
ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0_7, i1_7, i2_7, i3_7]) | |
tir.reads(T_multiply_5[ax0, ax1, ax2, ax3], T_reshape_7[ax0, ax1, 0, 0]) | |
tir.writes(T_add_2[ax0, ax1, ax2, ax3]) | |
T_add_2[ax0, ax1, ax2, ax3] = T_multiply_5[ax0, ax1, ax2, ax3] + T_reshape_7[ax0, ax1, 0, 0] | |
for i0_8 in tir.serial(256): | |
with tir.block("T_multiply_1"): | |
ax0 = tir.axis.spatial(256, i0_8) | |
tir.reads(rxplaceholder_8[ax0]) | |
tir.writes(T_multiply_3[ax0]) | |
T_multiply_3[ax0] = rxplaceholder_8[ax0] | |
for i0_9 in tir.serial(256): | |
with tir.block("T_multiply_2"): | |
ax0 = tir.axis.spatial(256, i0_9) | |
tir.reads(rxplaceholder_9[ax0]) | |
tir.writes(T_multiply_4[ax0]) | |
T_multiply_4[ax0] = rxplaceholder_9[ax0] | |
@tir.prim_func | |
def conv2d_nchw17(rxplaceholder_2: tir.Buffer[(1, 512, 7, 7), "float32"], rxplaceholder_3: tir.Buffer[(2048, 512, 1, 1), "float32"], conv2d_nchw_1: tir.Buffer[(1, 2048, 7, 7), "float32"]) -> None: | |
# function attr dict | |
tir.func_attr({"global_symbol": "conv2d_nchw17", "tir.noalias": True}) | |
# body | |
# with tir.block("root") | |
pad_temp_1 = tir.alloc_buffer([1, 512, 7, 7], dtype="float32") | |
for i0, i1, i2, i3 in tir.grid(1, 512, 7, 7): | |
with tir.block("pad_temp"): | |
i0_2, i1_2, i2_2, i3_2 = tir.axis.remap("SSSS", [i0, i1, i2, i3]) | |
tir.reads(rxplaceholder_2[i0_2, i1_2, i2_2, i3_2]) | |
tir.writes(pad_temp_1[i0_2, i1_2, i2_2, i3_2]) | |
pad_temp_1[i0_2, i1_2, i2_2, i3_2] = rxplaceholder_2[i0_2, i1_2, i2_2, i3_2] | |
for i0_3, i1_3, i2_3, i3_3, i4, i5, i6 in tir.grid(1, 2048, 7, 7, 512, 1, 1): | |
with tir.block("conv2d_nchw"): | |
nn, ff, yy, xx, rc, ry, rx = tir.axis.remap("SSSSRRR", [i0_3, i1_3, i2_3, i3_3, i4, i5, i6]) | |
tir.reads(conv2d_nchw_1[nn, ff, yy, xx], pad_temp_1[nn, rc, yy + ry, xx + rx], rxplaceholder_3[ff, rc, ry, rx]) | |
tir.writes(conv2d_nchw_1[nn, ff, yy, xx]) | |
tir.block_attr({"layout_free_placeholders":[rxplaceholder_3]}) | |
with tir.init(): | |
conv2d_nchw_1[nn, ff, yy, xx] = tir.float32(0) | |
conv2d_nchw_1[nn, ff, yy, xx] = conv2d_nchw_1[nn, ff, yy, xx] + pad_temp_1[nn, rc, yy + ry, xx + rx] * rxplaceholder_3[ff, rc, ry, rx] | |
@tir.prim_func | |
def batch_norm7(rxplaceholder_5: tir.Buffer[(1, 1024, 14, 14), "float32"], rxplaceholder_6: tir.Buffer[(1024,), "float32"], rxplaceholder_7: tir.Buffer[(1024,), "float32"], rxplaceholder_8: tir.Buffer[(1024,), "float32"], rxplaceholder_9: tir.Buffer[(1024,), "float32"], T_add_2: tir.Buffer[(1, 1024, 14, 14), "float32"], T_multiply_3: tir.Buffer[(1024,), "float32"], T_multiply_4: tir.Buffer[(1024,), "float32"]) -> None: | |
# function attr dict | |
tir.func_attr({"global_symbol": "batch_norm7", "tir.noalias": True}) | |
# body | |
# with tir.block("root") | |
T_reshape_4 = tir.alloc_buffer([1, 1024, 1, 1], dtype="float32") | |
T_subtract_1 = tir.alloc_buffer([1, 1024, 14, 14], dtype="float32") | |
T_reshape_5 = tir.alloc_buffer([1, 1024, 1, 1], dtype="float32") | |
T_add_3 = tir.alloc_buffer([1, 1024, 1, 1], dtype="float32") | |
compute_1 = tir.alloc_buffer([1, 1024, 1, 1], dtype="float32") | |
T_divide_1 = tir.alloc_buffer([1, 1024, 14, 14], dtype="float32") | |
T_reshape_6 = tir.alloc_buffer([1, 1024, 1, 1], dtype="float32") | |
T_multiply_5 = tir.alloc_buffer([1, 1024, 14, 14], dtype="float32") | |
T_reshape_7 = tir.alloc_buffer([1, 1024, 1, 1], dtype="float32") | |
for i0, i1, i2, i3 in tir.grid(1, 1024, 1, 1): | |
with tir.block("T_reshape"): | |
ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3]) | |
tir.reads(rxplaceholder_8[(ax1 + ax2 + ax3) % 1024]) | |
tir.writes(T_reshape_4[ax0, ax1, ax2, ax3]) | |
T_reshape_4[ax0, ax1, ax2, ax3] = rxplaceholder_8[(ax0 * 1024 + ax1 + ax2 + ax3) % 1024] | |
for i0, i1, i2, i3 in tir.grid(1, 1024, 14, 14): | |
with tir.block("T_subtract"): | |
ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3]) | |
tir.reads(rxplaceholder_5[ax0, ax1, ax2, ax3], T_reshape_4[ax0, ax1, 0, 0]) | |
tir.writes(T_subtract_1[ax0, ax1, ax2, ax3]) | |
T_subtract_1[ax0, ax1, ax2, ax3] = rxplaceholder_5[ax0, ax1, ax2, ax3] - T_reshape_4[ax0, ax1, 0, 0] | |
for i0, i1, i2, i3 in tir.grid(1, 1024, 1, 1): | |
with tir.block("T_reshape_1"): | |
ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3]) | |
tir.reads(rxplaceholder_9[(ax1 + ax2 + ax3) % 1024]) | |
tir.writes(T_reshape_5[ax0, ax1, ax2, ax3]) | |
T_reshape_5[ax0, ax1, ax2, ax3] = rxplaceholder_9[(ax0 * 1024 + ax1 + ax2 + ax3) % 1024] | |
for i0, i1, i2, i3 in tir.grid(1, 1024, 1, 1): | |
with tir.block("T_add"): | |
ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3]) | |
tir.reads(T_reshape_5[ax0, ax1, ax2, ax3]) | |
tir.writes(T_add_3[ax0, ax1, ax2, ax3]) | |
T_add_3[ax0, ax1, ax2, ax3] = T_reshape_5[ax0, ax1, ax2, ax3] + tir.float32(1.9999999494757503e-05) | |
for i0, i1, i2, i3 in tir.grid(1, 1024, 1, 1): | |
with tir.block("compute"): | |
i0_2, i1_2, i2_2, i3_2 = tir.axis.remap("SSSS", [i0, i1, i2, i3]) | |
tir.reads(T_add_3[i0_2, i1_2, i2_2, i3_2]) | |
tir.writes(compute_1[i0_2, i1_2, i2_2, i3_2]) | |
compute_1[i0_2, i1_2, i2_2, i3_2] = tir.sqrt(T_add_3[i0_2, i1_2, i2_2, i3_2], dtype="float32") | |
for i0_3, i1_3, i2_3, i3_3 in tir.grid(1, 1024, 14, 14): | |
with tir.block("T_divide"): | |
ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0_3, i1_3, i2_3, i3_3]) | |
tir.reads(T_subtract_1[ax0, ax1, ax2, ax3], compute_1[ax0, ax1, 0, 0]) | |
tir.writes(T_divide_1[ax0, ax1, ax2, ax3]) | |
T_divide_1[ax0, ax1, ax2, ax3] = T_subtract_1[ax0, ax1, ax2, ax3] / compute_1[ax0, ax1, 0, 0] | |
for i0_4, i1_4, i2_4, i3_4 in tir.grid(1, 1024, 1, 1): | |
with tir.block("T_reshape_2"): | |
ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0_4, i1_4, i2_4, i3_4]) | |
tir.reads(rxplaceholder_6[(ax1 + ax2 + ax3) % 1024]) | |
tir.writes(T_reshape_6[ax0, ax1, ax2, ax3]) | |
T_reshape_6[ax0, ax1, ax2, ax3] = rxplaceholder_6[(ax0 * 1024 + ax1 + ax2 + ax3) % 1024] | |
for i0_5, i1_5, i2_5, i3_5 in tir.grid(1, 1024, 14, 14): | |
with tir.block("T_multiply"): | |
ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0_5, i1_5, i2_5, i3_5]) | |
tir.reads(T_divide_1[ax0, ax1, ax2, ax3], T_reshape_6[ax0, ax1, 0, 0]) | |
tir.writes(T_multiply_5[ax0, ax1, ax2, ax3]) | |
T_multiply_5[ax0, ax1, ax2, ax3] = T_divide_1[ax0, ax1, ax2, ax3] * T_reshape_6[ax0, ax1, 0, 0] | |
for i0_6, i1_6, i2_6, i3_6 in tir.grid(1, 1024, 1, 1): | |
with tir.block("T_reshape_3"): | |
ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0_6, i1_6, i2_6, i3_6]) | |
tir.reads(rxplaceholder_7[(ax1 + ax2 + ax3) % 1024]) | |
tir.writes(T_reshape_7[ax0, ax1, ax2, ax3]) | |
T_reshape_7[ax0, ax1, ax2, ax3] = rxplaceholder_7[(ax0 * 1024 + ax1 + ax2 + ax3) % 1024] | |
for i0_7, i1_7, i2_7, i3_7 in tir.grid(1, 1024, 14, 14): | |
with tir.block("T_add_1"): | |
ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0_7, i1_7, i2_7, i3_7]) | |
tir.reads(T_multiply_5[ax0, ax1, ax2, ax3], T_reshape_7[ax0, ax1, 0, 0]) | |
tir.writes(T_add_2[ax0, ax1, ax2, ax3]) | |
T_add_2[ax0, ax1, ax2, ax3] = T_multiply_5[ax0, ax1, ax2, ax3] + T_reshape_7[ax0, ax1, 0, 0] | |
for i0_8 in tir.serial(1024): | |
with tir.block("T_multiply_1"): | |
ax0 = tir.axis.spatial(1024, i0_8) | |
tir.reads(rxplaceholder_8[ax0]) | |
tir.writes(T_multiply_3[ax0]) | |
T_multiply_3[ax0] = rxplaceholder_8[ax0] | |
for i0_9 in tir.serial(1024): | |
with tir.block("T_multiply_2"): | |
ax0 = tir.axis.spatial(1024, i0_9) | |
tir.reads(rxplaceholder_9[ax0]) | |
tir.writes(T_multiply_4[ax0]) | |
T_multiply_4[ax0] = rxplaceholder_9[ax0] | |
@tir.prim_func | |
def conv2d_nchw9(rxplaceholder_2: tir.Buffer[(1, 512, 28, 28), "float32"], rxplaceholder_3: tir.Buffer[(128, 512, 1, 1), "float32"], conv2d_nchw_1: tir.Buffer[(1, 128, 28, 28), "float32"]) -> None: | |
# function attr dict | |
tir.func_attr({"global_symbol": "conv2d_nchw9", "tir.noalias": True}) | |
# body | |
# with tir.block("root") | |
pad_temp_1 = tir.alloc_buffer([1, 512, 28, 28], dtype="float32") | |
for i0, i1, i2, i3 in tir.grid(1, 512, 28, 28): | |
with tir.block("pad_temp"): | |
i0_2, i1_2, i2_2, i3_2 = tir.axis.remap("SSSS", [i0, i1, i2, i3]) | |
tir.reads(rxplaceholder_2[i0_2, i1_2, i2_2, i3_2]) | |
tir.writes(pad_temp_1[i0_2, i1_2, i2_2, i3_2]) | |
pad_temp_1[i0_2, i1_2, i2_2, i3_2] = rxplaceholder_2[i0_2, i1_2, i2_2, i3_2] | |
for i0_3, i1_3, i2_3, i3_3, i4, i5, i6 in tir.grid(1, 128, 28, 28, 512, 1, 1): | |
with tir.block("conv2d_nchw"): | |
nn, ff, yy, xx, rc, ry, rx = tir.axis.remap("SSSSRRR", [i0_3, i1_3, i2_3, i3_3, i4, i5, i6]) | |
tir.reads(conv2d_nchw_1[nn, ff, yy, xx], pad_temp_1[nn, rc, yy + ry, xx + rx], rxplaceholder_3[ff, rc, ry, rx]) | |
tir.writes(conv2d_nchw_1[nn, ff, yy, xx]) | |
tir.block_attr({"layout_free_placeholders":[rxplaceholder_3]}) | |
with tir.init(): | |
conv2d_nchw_1[nn, ff, yy, xx] = tir.float32(0) | |
conv2d_nchw_1[nn, ff, yy, xx] = conv2d_nchw_1[nn, ff, yy, xx] + pad_temp_1[nn, rc, yy + ry, xx + rx] * rxplaceholder_3[ff, rc, ry, rx] | |
@tir.prim_func | |
def global_avg_pool2d(rxplaceholder_1: tir.Buffer[(1, 2048, 7, 7), "float32"], tensor_2: tir.Buffer[(1, 2048, 1, 1), "float32"]) -> None: | |
# function attr dict | |
tir.func_attr({"global_symbol": "global_avg_pool2d", "tir.noalias": True}) | |
# body | |
# with tir.block("root") | |
tensor_3 = tir.alloc_buffer([1, 2048, 1, 1], dtype="float32") | |
for i0, i1, i2, i3, i4, i5 in tir.grid(1, 2048, 1, 1, 7, 7): | |
with tir.block("tensor"): | |
ax0, ax1, ax2, ax3, rv0, rv1 = tir.axis.remap("SSSSRR", [i0, i1, i2, i3, i4, i5]) | |
tir.reads(tensor_3[ax0, ax1, ax2, ax3], rxplaceholder_1[ax0, ax1, ax2 * 7 + rv0, ax3 * 7 + rv1]) | |
tir.writes(tensor_3[ax0, ax1, ax2, ax3]) | |
with tir.init(): | |
tensor_3[ax0, ax1, ax2, ax3] = tir.float32(0) | |
tensor_3[ax0, ax1, ax2, ax3] = tensor_3[ax0, ax1, ax2, ax3] + rxplaceholder_1[ax0, ax1, ax2 * 7 + rv0, ax3 * 7 + rv1] | |
for i0, i1, i2, i3 in tir.grid(1, 2048, 1, 1): | |
with tir.block("tensor_1"): | |
ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3]) | |
tir.reads(tensor_3[ax0, ax1, ax2, ax3]) | |
tir.writes(tensor_2[ax0, ax1, ax2, ax3]) | |
tensor_2[ax0, ax1, ax2, ax3] = tensor_3[ax0, ax1, ax2, ax3] / (tir.cast(tir.Select(True, (ax2 + 1) * 7, (ax2 + 1) * 7 + 1) - ax2 * 7, "float32") * tir.cast(tir.Select(True, (ax3 + 1) * 7, (ax3 + 1) * 7 + 1) - ax3 * 7, "float32")) | |
@tir.prim_func | |
def relu7(rxplaceholder_1: tir.Buffer[(1, 512, 7, 7), "float32"], T_relu_1: tir.Buffer[(1, 512, 7, 7), "float32"]) -> None: | |
# function attr dict | |
tir.func_attr({"global_symbol": "relu7", "tir.noalias": True}) | |
# body | |
# with tir.block("root") | |
for i0, i1, i2, i3 in tir.grid(1, 512, 7, 7): | |
with tir.block("T_relu"): | |
ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3]) | |
tir.reads(rxplaceholder_1[ax0, ax1, ax2, ax3]) | |
tir.writes(T_relu_1[ax0, ax1, ax2, ax3]) | |
T_relu_1[ax0, ax1, ax2, ax3] = tir.max(rxplaceholder_1[ax0, ax1, ax2, ax3], tir.float32(0)) | |
@tir.prim_func | |
def softmax(rxplaceholder_1: tir.Buffer[(1, 1000), "float32"], T_softmax_norm_1: tir.Buffer[(1, 1000), "float32"]) -> None: | |
# function attr dict | |
tir.func_attr({"global_symbol": "softmax", "tir.noalias": True}) | |
# body | |
# with tir.block("root") | |
T_softmax_maxelem_1 = tir.alloc_buffer([1], dtype="float32") | |
T_softmax_exp_1 = tir.alloc_buffer([1, 1000], dtype="float32") | |
T_softmax_expsum_1 = tir.alloc_buffer([1], dtype="float32") | |
for i0_7, i1_3 in tir.grid(1, 1000): | |
with tir.block("T_softmax_maxelem"): | |
i0_8, k = tir.axis.remap("SR", [i0_7, i1_3]) | |
tir.reads(T_softmax_maxelem_1[i0_8], rxplaceholder_1[i0_8, k]) | |
tir.writes(T_softmax_maxelem_1[i0_8]) | |
with tir.init(): | |
T_softmax_maxelem_1[i0_8] = tir.float32(-3.4028234663852886e+38) | |
T_softmax_maxelem_1[i0_8] = tir.max(T_softmax_maxelem_1[i0_8], rxplaceholder_1[i0_8, k]) | |
for i0_9, i1_4 in tir.grid(1, 1000): | |
with tir.block("T_softmax_exp"): | |
i0_10, i1_5 = tir.axis.remap("SS", [i0_9, i1_4]) | |
tir.reads(rxplaceholder_1[i0_10, i1_5], T_softmax_maxelem_1[i0_10]) | |
tir.writes(T_softmax_exp_1[i0_10, i1_5]) | |
T_softmax_exp_1[i0_10, i1_5] = tir.exp(rxplaceholder_1[i0_10, i1_5] - T_softmax_maxelem_1[i0_10], dtype="float32") | |
for i0_11, i1_6 in tir.grid(1, 1000): | |
with tir.block("T_softmax_expsum"): | |
i0_12, k = tir.axis.remap("SR", [i0_11, i1_6]) | |
tir.reads(T_softmax_expsum_1[i0_12], T_softmax_exp_1[i0_12, k]) | |
tir.writes(T_softmax_expsum_1[i0_12]) | |
with tir.init(): | |
T_softmax_expsum_1[i0_12] = tir.float32(0) | |
T_softmax_expsum_1[i0_12] = T_softmax_expsum_1[i0_12] + T_softmax_exp_1[i0_12, k] | |
for i0_13, i1_7 in tir.grid(1, 1000): | |
with tir.block("T_softmax_norm"): | |
i0_14, i1_8 = tir.axis.remap("SS", [i0_13, i1_7]) | |
tir.reads(T_softmax_exp_1[i0_14, i1_8], T_softmax_expsum_1[i0_14]) | |
tir.writes(T_softmax_norm_1[i0_14, i1_8]) | |
tir.block_attr({"axis":1}) | |
T_softmax_norm_1[i0_14, i1_8] = T_softmax_exp_1[i0_14, i1_8] / T_softmax_expsum_1[i0_14] | |
@tir.prim_func | |
def relu4(rxplaceholder_1: tir.Buffer[(1, 512, 28, 28), "float32"], T_relu_1: tir.Buffer[(1, 512, 28, 28), "float32"]) -> None: | |
# function attr dict | |
tir.func_attr({"global_symbol": "relu4", "tir.noalias": True}) | |
# body | |
# with tir.block("root") | |
for i0, i1, i2, i3 in tir.grid(1, 512, 28, 28): | |
with tir.block("T_relu"): | |
ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3]) | |
tir.reads(rxplaceholder_1[ax0, ax1, ax2, ax3]) | |
tir.writes(T_relu_1[ax0, ax1, ax2, ax3]) | |
T_relu_1[ax0, ax1, ax2, ax3] = tir.max(rxplaceholder_1[ax0, ax1, ax2, ax3], tir.float32(0)) | |
@tir.prim_func | |
def max_pool2d(rxplaceholder_1: tir.Buffer[(1, 64, 112, 112), "float32"], tensor_1: tir.Buffer[(1, 64, 56, 56), "float32"]) -> None: | |
# function attr dict | |
tir.func_attr({"global_symbol": "max_pool2d", "tir.noalias": True}) | |
# body | |
# with tir.block("root") | |
pad_temp_1 = tir.alloc_buffer([1, 64, 114, 114], dtype="float32") | |
for i0, i1, i2, i3 in tir.grid(1, 64, 114, 114): | |
with tir.block("pad_temp"): | |
ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3]) | |
tir.reads(rxplaceholder_1[ax0, ax1, ax2 - 1, ax3 - 1]) | |
tir.writes(pad_temp_1[ax0, ax1, ax2, ax3]) | |
pad_temp_1[ax0, ax1, ax2, ax3] = tir.if_then_else(ax2 >= 1 and ax2 < 113 and ax3 >= 1 and ax3 < 113, rxplaceholder_1[ax0, ax1, ax2 - 1, ax3 - 1], tir.float32(-3.4028234663852886e+38), dtype="float32") | |
for i0, i1, i2, i3, i4, i5 in tir.grid(1, 64, 56, 56, 3, 3): | |
with tir.block("tensor"): | |
ax0, ax1, ax2, ax3, rv0, rv1 = tir.axis.remap("SSSSRR", [i0, i1, i2, i3, i4, i5]) | |
tir.reads(tensor_1[ax0, ax1, ax2, ax3], pad_temp_1[ax0, ax1, ax2 * 2 + rv0, ax3 * 2 + rv1]) | |
tir.writes(tensor_1[ax0, ax1, ax2, ax3]) | |
with tir.init(): | |
tensor_1[ax0, ax1, ax2, ax3] = tir.float32(-3.4028234663852886e+38) | |
tensor_1[ax0, ax1, ax2, ax3] = tir.max(tensor_1[ax0, ax1, ax2, ax3], pad_temp_1[ax0, ax1, ax2 * 2 + rv0, ax3 * 2 + rv1]) | |
@tir.prim_func | |
def batch_norm8(rxplaceholder_5: tir.Buffer[(1, 512, 7, 7), "float32"], rxplaceholder_6: tir.Buffer[(512,), "float32"], rxplaceholder_7: tir.Buffer[(512,), "float32"], rxplaceholder_8: tir.Buffer[(512,), "float32"], rxplaceholder_9: tir.Buffer[(512,), "float32"], T_add_2: tir.Buffer[(1, 512, 7, 7), "float32"], T_multiply_3: tir.Buffer[(512,), "float32"], T_multiply_4: tir.Buffer[(512,), "float32"]) -> None: | |
# function attr dict | |
tir.func_attr({"global_symbol": "batch_norm8", "tir.noalias": True}) | |
# body | |
# with tir.block("root") | |
T_reshape_4 = tir.alloc_buffer([1, 512, 1, 1], dtype="float32") | |
T_subtract_1 = tir.alloc_buffer([1, 512, 7, 7], dtype="float32") | |
T_reshape_5 = tir.alloc_buffer([1, 512, 1, 1], dtype="float32") | |
T_add_3 = tir.alloc_buffer([1, 512, 1, 1], dtype="float32") | |
compute_1 = tir.alloc_buffer([1, 512, 1, 1], dtype="float32") | |
T_divide_1 = tir.alloc_buffer([1, 512, 7, 7], dtype="float32") | |
T_reshape_6 = tir.alloc_buffer([1, 512, 1, 1], dtype="float32") | |
T_multiply_5 = tir.alloc_buffer([1, 512, 7, 7], dtype="float32") | |
T_reshape_7 = tir.alloc_buffer([1, 512, 1, 1], dtype="float32") | |
for i0, i1, i2, i3 in tir.grid(1, 512, 1, 1): | |
with tir.block("T_reshape"): | |
ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3]) | |
tir.reads(rxplaceholder_8[(ax1 + ax2 + ax3) % 512]) | |
tir.writes(T_reshape_4[ax0, ax1, ax2, ax3]) | |
T_reshape_4[ax0, ax1, ax2, ax3] = rxplaceholder_8[(ax0 * 512 + ax1 + ax2 + ax3) % 512] | |
for i0, i1, i2, i3 in tir.grid(1, 512, 7, 7): | |
with tir.block("T_subtract"): | |
ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3]) | |
tir.reads(rxplaceholder_5[ax0, ax1, ax2, ax3], T_reshape_4[ax0, ax1, 0, 0]) | |
tir.writes(T_subtract_1[ax0, ax1, ax2, ax3]) | |
T_subtract_1[ax0, ax1, ax2, ax3] = rxplaceholder_5[ax0, ax1, ax2, ax3] - T_reshape_4[ax0, ax1, 0, 0] | |
for i0, i1, i2, i3 in tir.grid(1, 512, 1, 1): | |
with tir.block("T_reshape_1"): | |
ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3]) | |
tir.reads(rxplaceholder_9[(ax1 + ax2 + ax3) % 512]) | |
tir.writes(T_reshape_5[ax0, ax1, ax2, ax3]) | |
T_reshape_5[ax0, ax1, ax2, ax3] = rxplaceholder_9[(ax0 * 512 + ax1 + ax2 + ax3) % 512] | |
for i0, i1, i2, i3 in tir.grid(1, 512, 1, 1): | |
with tir.block("T_add"): | |
ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3]) | |
tir.reads(T_reshape_5[ax0, ax1, ax2, ax3]) | |
tir.writes(T_add_3[ax0, ax1, ax2, ax3]) | |
T_add_3[ax0, ax1, ax2, ax3] = T_reshape_5[ax0, ax1, ax2, ax3] + tir.float32(1.9999999494757503e-05) | |
for i0, i1, i2, i3 in tir.grid(1, 512, 1, 1): | |
with tir.block("compute"): | |
i0_2, i1_2, i2_2, i3_2 = tir.axis.remap("SSSS", [i0, i1, i2, i3]) | |
tir.reads(T_add_3[i0_2, i1_2, i2_2, i3_2]) | |
tir.writes(compute_1[i0_2, i1_2, i2_2, i3_2]) | |
compute_1[i0_2, i1_2, i2_2, i3_2] = tir.sqrt(T_add_3[i0_2, i1_2, i2_2, i3_2], dtype="float32") | |
for i0_3, i1_3, i2_3, i3_3 in tir.grid(1, 512, 7, 7): | |
with tir.block("T_divide"): | |
ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0_3, i1_3, i2_3, i3_3]) | |
tir.reads(T_subtract_1[ax0, ax1, ax2, ax3], compute_1[ax0, ax1, 0, 0]) | |
tir.writes(T_divide_1[ax0, ax1, ax2, ax3]) | |
T_divide_1[ax0, ax1, ax2, ax3] = T_subtract_1[ax0, ax1, ax2, ax3] / compute_1[ax0, ax1, 0, 0] | |
for i0_4, i1_4, i2_4, i3_4 in tir.grid(1, 512, 1, 1): | |
with tir.block("T_reshape_2"): | |
ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0_4, i1_4, i2_4, i3_4]) | |
tir.reads(rxplaceholder_6[(ax1 + ax2 + ax3) % 512]) | |
tir.writes(T_reshape_6[ax0, ax1, ax2, ax3]) | |
T_reshape_6[ax0, ax1, ax2, ax3] = rxplaceholder_6[(ax0 * 512 + ax1 + ax2 + ax3) % 512] | |
for i0_5, i1_5, i2_5, i3_5 in tir.grid(1, 512, 7, 7): | |
with tir.block("T_multiply"): | |
ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0_5, i1_5, i2_5, i3_5]) | |
tir.reads(T_divide_1[ax0, ax1, ax2, ax3], T_reshape_6[ax0, ax1, 0, 0]) | |
tir.writes(T_multiply_5[ax0, ax1, ax2, ax3]) | |
T_multiply_5[ax0, ax1, ax2, ax3] = T_divide_1[ax0, ax1, ax2, ax3] * T_reshape_6[ax0, ax1, 0, 0] | |
for i0_6, i1_6, i2_6, i3_6 in tir.grid(1, 512, 1, 1): | |
with tir.block("T_reshape_3"): | |
ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0_6, i1_6, i2_6, i3_6]) | |
tir.reads(rxplaceholder_7[(ax1 + ax2 + ax3) % 512]) | |
tir.writes(T_reshape_7[ax0, ax1, ax2, ax3]) | |
T_reshape_7[ax0, ax1, ax2, ax3] = rxplaceholder_7[(ax0 * 512 + ax1 + ax2 + ax3) % 512] | |
for i0_7, i1_7, i2_7, i3_7 in tir.grid(1, 512, 7, 7): | |
with tir.block("T_add_1"): | |
ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0_7, i1_7, i2_7, i3_7]) | |
tir.reads(T_multiply_5[ax0, ax1, ax2, ax3], T_reshape_7[ax0, ax1, 0, 0]) | |
tir.writes(T_add_2[ax0, ax1, ax2, ax3]) | |
T_add_2[ax0, ax1, ax2, ax3] = T_multiply_5[ax0, ax1, ax2, ax3] + T_reshape_7[ax0, ax1, 0, 0] | |
for i0_8 in tir.serial(512): | |
with tir.block("T_multiply_1"): | |
ax0 = tir.axis.spatial(512, i0_8) | |
tir.reads(rxplaceholder_8[ax0]) | |
tir.writes(T_multiply_3[ax0]) | |
T_multiply_3[ax0] = rxplaceholder_8[ax0] | |
for i0_9 in tir.serial(512): | |
with tir.block("T_multiply_2"): | |
ax0 = tir.axis.spatial(512, i0_9) | |
tir.reads(rxplaceholder_9[ax0]) | |
tir.writes(T_multiply_4[ax0]) | |
T_multiply_4[ax0] = rxplaceholder_9[ax0] | |
@tir.prim_func | |
def add2(rxplaceholder_2: tir.Buffer[(1, 1024, 14, 14), "float32"], rxplaceholder_3: tir.Buffer[(1, 1024, 14, 14), "float32"], T_add_1: tir.Buffer[(1, 1024, 14, 14), "float32"]) -> None: | |
# function attr dict | |
tir.func_attr({"global_symbol": "add2", "tir.noalias": True}) | |
# body | |
# with tir.block("root") | |
for i0, i1, i2, i3 in tir.grid(1, 1024, 14, 14): | |
with tir.block("T_add"): | |
ax0, ax1, ax2, ax3 = tir.axis.remap("SSSS", [i0, i1, i2, i3]) | |
tir.reads(rxplaceholder_2[ax0, ax1, ax2, ax3], rxplaceholder_3[ax0, ax1, ax2, ax3]) | |
tir.writes(T_add_1[ax0, ax1, ax2, ax3]) | |
T_add_1[ax0, ax1, ax2, ax3] = rxplaceholder_2[ax0, ax1, ax2, ax3] + rxplaceholder_3[ax0, ax1, ax2, ax3] | |
""" | |
# Parse the TVMScript ResNet module text above into a Relax IRModule.
relax_mod = R.parser.from_source(resnet_mod_text)

# Pretty-print the module's "main" Relax function.
main_fn = relax_mod["main"]
R.parser.pretty_print(main_fn)

# Inspect the types inferred for variables bound in the first binding block
# of "main": full checked type, the user annotation, and tensor ranks.
bindings = main_fn.body.blocks[0].bindings
print(bindings[1].var.checked_type)
print(bindings[1].var.type_annotation)
print(bindings[1].var.checked_type.rank)
print(bindings[2].var.checked_type.rank)
Hzfengsy
commented
Mar 13, 2022
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment