{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "d5783aee",
"metadata": {},
"outputs": [],
"source": [
"from __future__ import annotations # must import to defer parsing of annotations\n",
"import os\n",
"import numpy as np\n",
"import tvm\n",
"from tvm import relax, tir, topi\n",
"from tvm.runtime import container\n",
"from tvm.target.target import Target\n",
"from tvm.relax.testing import nn\n",
"\n",
"import tvm.script\n",
"from tvm.script import tir as T, relax as R"
]
},
{
"cell_type": "markdown",
"id": "62b4198a",
"metadata": {},
"source": [
"## Build and run a neural network in Relax"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "fd0cc6e9",
"metadata": {},
"outputs": [],
"source": [
"builder = relax.BlockBuilder()\n",
"\n",
"input_size = 784\n",
"hidden_sizes = [128, 64]\n",
"output_size = 10"
]
},
{
"cell_type": "markdown",
"id": "068cded9",
"metadata": {},
"source": [
"Build a three linear-layer neural network for a classification task.\n",
"A neural network is a `nn.Module` that can consist of other modules (layers). This nested structure allows for building complex models easily."
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "a12e45da",
"metadata": {},
"outputs": [],
"source": [
"with builder.function(name=\"main\"):\n",
" model = nn.Sequential(\n",
" nn.Linear(input_size, hidden_sizes[0]),\n",
" nn.ReLU(),\n",
" nn.Linear(hidden_sizes[0], hidden_sizes[1]),\n",
" nn.ReLU(),\n",
" nn.Linear(hidden_sizes[1], output_size),\n",
" nn.LogSoftmax(),\n",
" )\n",
" \n",
" # n is a symbolic variable to represent a dynamic batch size\n",
" n = tir.Var(\"n\", \"int64\")\n",
" data = nn.Placeholder((n, input_size), name=\"data\")\n",
" output = model(data)\n",
" params = [data] + model.parameters()\n",
" builder.emit_func_output(output, params=params) "
]
},
{
"cell_type": "markdown",
"id": "abde5a76",
"metadata": {},
"source": [
"Get and print the IRModule being built."
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "54f1bb9e",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"@tvm.script.ir_module\n",
"class Module:\n",
" @tir.prim_func\n",
" def matmul1(var_rxplaceholder_1: tir.handle, rxplaceholder_2: tir.Buffer[(128, 64), \"float32\"], var_T_matmul_1: tir.handle) -> None:\n",
" # function attr dict\n",
" tir.func_attr({\"global_symbol\": \"matmul1\", \"tir.noalias\": True})\n",
" n_1 = tir.var(\"int64\")\n",
" rxplaceholder_3 = tir.match_buffer(var_rxplaceholder_1, [n_1, 128], dtype=\"float32\")\n",
" T_matmul_1 = tir.match_buffer(var_T_matmul_1, [n_1, 64], dtype=\"float32\")\n",
" # body\n",
" # with tir.block(\"root\")\n",
" for i0, i1, i2 in tir.grid(n_1, 64, 128):\n",
" with tir.block(\"T_matmul\"):\n",
" ax0 = tir.axis.spatial(n_1, i0)\n",
" ax1, k = tir.axis.remap(\"SR\", [i1, i2])\n",
" tir.reads(rxplaceholder_3[ax0, k], rxplaceholder_2[k, ax1])\n",
" tir.writes(T_matmul_1[ax0, ax1])\n",
" with tir.init():\n",
" T_matmul_1[ax0, ax1] = tir.float32(0)\n",
" T_matmul_1[ax0, ax1] = T_matmul_1[ax0, ax1] + rxplaceholder_3[ax0, k] * rxplaceholder_2[k, ax1]\n",
" \n",
" @tir.prim_func\n",
" def matmul(var_rxplaceholder_1: tir.handle, rxplaceholder_2: tir.Buffer[(784, 128), \"float32\"], var_T_matmul_1: tir.handle) -> None:\n",
" # function attr dict\n",
" tir.func_attr({\"global_symbol\": \"matmul\", \"tir.noalias\": True})\n",
" n_1 = tir.var(\"int64\")\n",
" rxplaceholder_3 = tir.match_buffer(var_rxplaceholder_1, [n_1, 784], dtype=\"float32\")\n",
" T_matmul_1 = tir.match_buffer(var_T_matmul_1, [n_1, 128], dtype=\"float32\")\n",
" # body\n",
" # with tir.block(\"root\")\n",
" for i0, i1, i2 in tir.grid(n_1, 128, 784):\n",
" with tir.block(\"T_matmul\"):\n",
" ax0 = tir.axis.spatial(n_1, i0)\n",
" ax1, k = tir.axis.remap(\"SR\", [i1, i2])\n",
" tir.reads(rxplaceholder_3[ax0, k], rxplaceholder_2[k, ax1])\n",
" tir.writes(T_matmul_1[ax0, ax1])\n",
" with tir.init():\n",
" T_matmul_1[ax0, ax1] = tir.float32(0)\n",
" T_matmul_1[ax0, ax1] = T_matmul_1[ax0, ax1] + rxplaceholder_3[ax0, k] * rxplaceholder_2[k, ax1]\n",
" \n",
" @tir.prim_func\n",
" def log_softmax(var_rxplaceholder_1: tir.handle, var_compute_1: tir.handle) -> None:\n",
" # function attr dict\n",
" tir.func_attr({\"global_symbol\": \"log_softmax\", \"tir.noalias\": True})\n",
" n_1 = tir.var(\"int64\")\n",
" rxplaceholder_1 = tir.match_buffer(var_rxplaceholder_1, [n_1, 10], dtype=\"float32\")\n",
" compute_3 = tir.match_buffer(var_compute_1, [n_1, 10], dtype=\"float32\")\n",
" # body\n",
" # with tir.block(\"root\")\n",
" compute_4 = tir.alloc_buffer([n_1], dtype=\"float32\")\n",
" compute_5 = tir.alloc_buffer([n_1], dtype=\"float32\")\n",
" for i0, i1 in tir.grid(n_1, 10):\n",
" with tir.block(\"compute\"):\n",
" i = tir.axis.spatial(n_1, i0)\n",
" k = tir.axis.reduce(10, i1)\n",
" tir.reads(rxplaceholder_1[i, k])\n",
" tir.writes(compute_4[i])\n",
" with tir.init():\n",
" compute_4[i] = tir.float32(-3.4028234663852886e+38)\n",
" compute_4[i] = tir.max(compute_4[i], rxplaceholder_1[i, k])\n",
" for i0, i1 in tir.grid(n_1, 10):\n",
" with tir.block(\"compute_1\"):\n",
" i = tir.axis.spatial(n_1, i0)\n",
" k = tir.axis.reduce(10, i1)\n",
" tir.reads(rxplaceholder_1[i, k], compute_4[i])\n",
" tir.writes(compute_5[i])\n",
" with tir.init():\n",
" compute_5[i] = tir.float32(0)\n",
" compute_5[i] = compute_5[i] + tir.exp(rxplaceholder_1[i, k] - compute_4[i], dtype=\"float32\")\n",
" for i0, i1 in tir.grid(n_1, 10):\n",
" with tir.block(\"compute_2\"):\n",
" i = tir.axis.spatial(n_1, i0)\n",
" j = tir.axis.spatial(10, i1)\n",
" tir.reads(rxplaceholder_1[i, j], compute_4[i], compute_5[i])\n",
" tir.writes(compute_3[i, j])\n",
" compute_3[i, j] = rxplaceholder_1[i, j] - compute_4[i] - tir.log(compute_5[i], dtype=\"float32\")\n",
" \n",
" @relax.function\n",
" def main(data: Tensor((n, 784), \"float32\"), linear_weight: Tensor((784, 128), \"float32\"), linear_bias: Tensor((128,), \"float32\"), linear_weight1: Tensor((128, 64), \"float32\"), linear_bias1: Tensor((64,), \"float32\"), linear_weight2: Tensor((64, 10), \"float32\"), linear_bias2: Tensor((10,), \"float32\")) -> Tensor(_, \"float32\"):\n",
" # block 0\n",
" gv = relax.call_tir(matmul, (data, linear_weight), (n, 128), dtype=\"float32\")\n",
" gv1 = relax.call_tir(add, (gv, linear_bias), (n, 128), dtype=\"float32\")\n",
" gv2 = relax.call_tir(relu, (gv1,), (n, 128), dtype=\"float32\")\n",
" gv3 = relax.call_tir(matmul1, (gv2, linear_weight1), (n, 64), dtype=\"float32\")\n",
" gv4 = relax.call_tir(add1, (gv3, linear_bias1), (n, 64), dtype=\"float32\")\n",
" gv5 = relax.call_tir(relu1, (gv4,), (n, 64), dtype=\"float32\")\n",
" gv6 = relax.call_tir(matmul2, (gv5, linear_weight2), (n, 10), dtype=\"float32\")\n",
" gv7 = relax.call_tir(add2, (gv6, linear_bias2), (n, 10), dtype=\"float32\")\n",
" gv8 = relax.call_tir(log_softmax, (gv7,), (n, 10), dtype=\"float32\")\n",
" return gv8\n",
" \n",
" @tir.prim_func\n",
" def relu1(var_rxplaceholder_1: tir.handle, var_compute_1: tir.handle) -> None:\n",
" # function attr dict\n",
" tir.func_attr({\"global_symbol\": \"relu1\", \"tir.noalias\": True})\n",
" n_1 = tir.var(\"int64\")\n",
" rxplaceholder_1 = tir.match_buffer(var_rxplaceholder_1, [n_1, 64], dtype=\"float32\")\n",
" compute_1 = tir.match_buffer(var_compute_1, [n_1, 64], dtype=\"float32\")\n",
" # body\n",
" # with tir.block(\"root\")\n",
" for i0, i1 in tir.grid(n_1, 64):\n",
" with tir.block(\"compute\"):\n",
" i0_2 = tir.axis.spatial(n_1, i0)\n",
" i1_2 = tir.axis.spatial(64, i1)\n",
" tir.reads(rxplaceholder_1[i0_2, i1_2])\n",
" tir.writes(compute_1[i0_2, i1_2])\n",
" compute_1[i0_2, i1_2] = tir.max(rxplaceholder_1[i0_2, i1_2], tir.float32(0))\n",
" \n",
" @tir.prim_func\n",
" def add2(var_rxplaceholder_1: tir.handle, rxplaceholder_2: tir.Buffer[(10,), \"float32\"], var_T_add_1: tir.handle) -> None:\n",
" # function attr dict\n",
" tir.func_attr({\"global_symbol\": \"add2\", \"tir.noalias\": True})\n",
" n_1 = tir.var(\"int64\")\n",
" rxplaceholder_3 = tir.match_buffer(var_rxplaceholder_1, [n_1, 10], dtype=\"float32\")\n",
" T_add_1 = tir.match_buffer(var_T_add_1, [n_1, 10], dtype=\"float32\")\n",
" # body\n",
" # with tir.block(\"root\")\n",
" for i0, i1 in tir.grid(n_1, 10):\n",
" with tir.block(\"T_add\"):\n",
" ax0 = tir.axis.spatial(n_1, i0)\n",
" ax1 = tir.axis.spatial(10, i1)\n",
" tir.reads(rxplaceholder_3[ax0, ax1], rxplaceholder_2[ax1])\n",
" tir.writes(T_add_1[ax0, ax1])\n",
" T_add_1[ax0, ax1] = rxplaceholder_3[ax0, ax1] + rxplaceholder_2[ax1]\n",
" \n",
" @tir.prim_func\n",
" def add1(var_rxplaceholder_1: tir.handle, rxplaceholder_2: tir.Buffer[(64,), \"float32\"], var_T_add_1: tir.handle) -> None:\n",
" # function attr dict\n",
" tir.func_attr({\"global_symbol\": \"add1\", \"tir.noalias\": True})\n",
" n_1 = tir.var(\"int64\")\n",
" rxplaceholder_3 = tir.match_buffer(var_rxplaceholder_1, [n_1, 64], dtype=\"float32\")\n",
" T_add_1 = tir.match_buffer(var_T_add_1, [n_1, 64], dtype=\"float32\")\n",
" # body\n",
" # with tir.block(\"root\")\n",
" for i0, i1 in tir.grid(n_1, 64):\n",
" with tir.block(\"T_add\"):\n",
" ax0 = tir.axis.spatial(n_1, i0)\n",
" ax1 = tir.axis.spatial(64, i1)\n",
" tir.reads(rxplaceholder_3[ax0, ax1], rxplaceholder_2[ax1])\n",
" tir.writes(T_add_1[ax0, ax1])\n",
" T_add_1[ax0, ax1] = rxplaceholder_3[ax0, ax1] + rxplaceholder_2[ax1]\n",
" \n",
" @tir.prim_func\n",
" def add(var_rxplaceholder_1: tir.handle, rxplaceholder_2: tir.Buffer[(128,), \"float32\"], var_T_add_1: tir.handle) -> None:\n",
" # function attr dict\n",
" tir.func_attr({\"global_symbol\": \"add\", \"tir.noalias\": True})\n",
" n_1 = tir.var(\"int64\")\n",
" rxplaceholder_3 = tir.match_buffer(var_rxplaceholder_1, [n_1, 128], dtype=\"float32\")\n",
" T_add_1 = tir.match_buffer(var_T_add_1, [n_1, 128], dtype=\"float32\")\n",
" # body\n",
" # with tir.block(\"root\")\n",
" for i0, i1 in tir.grid(n_1, 128):\n",
" with tir.block(\"T_add\"):\n",
" ax0 = tir.axis.spatial(n_1, i0)\n",
" ax1 = tir.axis.spatial(128, i1)\n",
" tir.reads(rxplaceholder_3[ax0, ax1], rxplaceholder_2[ax1])\n",
" tir.writes(T_add_1[ax0, ax1])\n",
" T_add_1[ax0, ax1] = rxplaceholder_3[ax0, ax1] + rxplaceholder_2[ax1]\n",
" \n",
" @tir.prim_func\n",
" def relu(var_rxplaceholder_1: tir.handle, var_compute_1: tir.handle) -> None:\n",
" # function attr dict\n",
" tir.func_attr({\"global_symbol\": \"relu\", \"tir.noalias\": True})\n",
" n_1 = tir.var(\"int64\")\n",
" rxplaceholder_1 = tir.match_buffer(var_rxplaceholder_1, [n_1, 128], dtype=\"float32\")\n",
" compute_1 = tir.match_buffer(var_compute_1, [n_1, 128], dtype=\"float32\")\n",
" # body\n",
" # with tir.block(\"root\")\n",
" for i0, i1 in tir.grid(n_1, 128):\n",
" with tir.block(\"compute\"):\n",
" i0_2 = tir.axis.spatial(n_1, i0)\n",
" i1_2 = tir.axis.spatial(128, i1)\n",
" tir.reads(rxplaceholder_1[i0_2, i1_2])\n",
" tir.writes(compute_1[i0_2, i1_2])\n",
" compute_1[i0_2, i1_2] = tir.max(rxplaceholder_1[i0_2, i1_2], tir.float32(0))\n",
" \n",
" @tir.prim_func\n",
" def matmul2(var_rxplaceholder_1: tir.handle, rxplaceholder_2: tir.Buffer[(64, 10), \"float32\"], var_T_matmul_1: tir.handle) -> None:\n",
" # function attr dict\n",
" tir.func_attr({\"global_symbol\": \"matmul2\", \"tir.noalias\": True})\n",
" n_1 = tir.var(\"int64\")\n",
" rxplaceholder_3 = tir.match_buffer(var_rxplaceholder_1, [n_1, 64], dtype=\"float32\")\n",
" T_matmul_1 = tir.match_buffer(var_T_matmul_1, [n_1, 10], dtype=\"float32\")\n",
" # body\n",
" # with tir.block(\"root\")\n",
" for i0, i1, i2 in tir.grid(n_1, 10, 64):\n",
" with tir.block(\"T_matmul\"):\n",
" ax0 = tir.axis.spatial(n_1, i0)\n",
" ax1, k = tir.axis.remap(\"SR\", [i1, i2])\n",
" tir.reads(rxplaceholder_3[ax0, k], rxplaceholder_2[k, ax1])\n",
" tir.writes(T_matmul_1[ax0, ax1])\n",
" with tir.init():\n",
" T_matmul_1[ax0, ax1] = tir.float32(0)\n",
" T_matmul_1[ax0, ax1] = T_matmul_1[ax0, ax1] + rxplaceholder_3[ax0, k] * rxplaceholder_2[k, ax1]\n",
" \n"
]
}
],
"source": [
"mod = builder.get()\n",
"R.parser.pretty_print(mod)"
]
},
{
"cell_type": "markdown",
"id": "158875f7",
"metadata": {},
"source": [
"Build the model and create Relax VM."
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "b6e7a8cd",
"metadata": {},
"outputs": [],
"source": [
"# build and create vm executor\n",
"target = tvm.target.Target(\"llvm\")\n",
"ex = relax.vm.build(mod, target)\n",
"vm = relax.VirtualMachine(ex, tvm.cpu())"
]
},
{
"cell_type": "markdown",
"id": "7679f077",
"metadata": {},
"source": [
"Run the model on Relax VM."
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "1e530019",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[-109873.59 -49647. -41688.25 -94398.22 -81715.94 -111547.\n",
" 0. -63041.28 -18236.625 -98630.66 ]\n",
" [-110056.69 -49696.03 -41759.594 -94529.28 -81796.625 -111659.78\n",
" 0. -63112.375 -18241.969 -98737.47 ]\n",
" [-109510.03 -49518.72 -41588.156 -94110.03 -81418.28 -111148.59\n",
" 0. -62861.28 -18204.406 -98300.47 ]]\n"
]
}
],
"source": [
"# init parameters\n",
"params = nn.init_params(mod)\n",
"# the input data has a minibatch size of 3\n",
"data = tvm.nd.array(np.random.rand(3, input_size).astype(np.float32))\n",
"\n",
"res = vm[\"main\"](data, *params)\n",
"print(res)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "dfe290fb",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[-113639.97 -51407.094 -43109.22 -97656.59 -84459.34 -115332.\n",
" 0. -65198.156 -18852.219 -102011.34 ]\n",
" [-110885.53 -50159.812 -42090. -95289.5 -82462.25 -112593.28\n",
" 0. -63644.656 -18419.344 -99547.97 ]\n",
" [-111952.81 -50647.22 -42490.312 -96203.5 -83224.56 -113655.34\n",
" 0. -64285.72 -18587.406 -100505.625]\n",
" [-114726.34 -51872.094 -43565.125 -98605.84 -85311.03 -116471.19\n",
" 0. -65849.66 -19040.094 -102954.72 ]\n",
" [-108532.47 -49076.97 -41242.25 -93290.25 -80707.25 -110214.56\n",
" 0. -62300.625 -18015.938 -97434.66 ]]\n"
]
}
],
"source": [
"# the input data has a minibatch size of 5\n",
"data = tvm.nd.array(np.random.rand(5, input_size).astype(np.float32))\n",
"\n",
"res = vm[\"main\"](data, *params)\n",
"print(res)"
]
},
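{
"cell_type": "markdown",
"id": "3f9a1b2c",
"metadata": {},
"source": [
"As a sanity check, we can reproduce the forward pass in NumPy and compare it against the VM result. This is a minimal sketch assuming `params` are ordered as in `main`'s signature: `[w0, b0, w1, b1, w2, b2]`."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4a0b2c3d",
"metadata": {},
"outputs": [],
"source": [
"# NumPy reference for the three-layer MLP built above\n",
"def np_mlp(x, params):\n",
"    w0, b0, w1, b1, w2, b2 = [p.numpy() for p in params]\n",
"    h = np.maximum(x @ w0 + b0, 0)  # linear + relu\n",
"    h = np.maximum(h @ w1 + b1, 0)  # linear + relu\n",
"    logits = h @ w2 + b2\n",
"    # numerically-stable log-softmax, matching the generated log_softmax kernel\n",
"    m = logits.max(axis=1, keepdims=True)\n",
"    return logits - m - np.log(np.exp(logits - m).sum(axis=1, keepdims=True))\n",
"\n",
"np.testing.assert_allclose(res.numpy(), np_mlp(data.numpy(), params), rtol=1e-4, atol=1e-3)"
]
},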
{
"cell_type": "markdown",
"id": "23b4c322",
"metadata": {},
"source": [
"Define a layer/network with emit_te using the nn.Module interface."
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "eb687277",
"metadata": {},
"outputs": [],
"source": [
"class Linear(nn.Module):\n",
" \"\"\"Applies a linear transformation to the input data: :math:`y = xA + b`.\"\"\"\n",
"\n",
" def __init__(self, in_features, out_features, bias=True):\n",
" self.in_features = in_features\n",
" self.out_features = out_features\n",
" self.weight = Parameter((in_features, out_features), name=\"linear_weight\")\n",
" if bias:\n",
" self.bias = Parameter((out_features,), name=\"linear_bias\")\n",
" else:\n",
" self.bias = None\n",
"\n",
" def forward(self, input: relax.Expr) -> relax.Var:\n",
" y = emit_te(topi.matmul, input, self.weight)\n",
" if self.bias is not None:\n",
" y = emit_te(topi.add, y, self.bias)\n",
" return y"
]
},
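{
"cell_type": "markdown",
"id": "5b1c3d4e",
"metadata": {},
"source": [
"A minimal usage sketch of the custom `Linear` above: emit it into a fresh `BlockBuilder` function exactly like the built-in layers. (This assumes `Parameter` and `emit_te` are the ones exposed by `tvm.relax.testing.nn`, hence the `nn.` qualification in the class.)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6c2d4e5f",
"metadata": {},
"outputs": [],
"source": [
"bb = relax.BlockBuilder()\n",
"with bb.function(name=\"linear_main\"):\n",
"    layer = Linear(32, 16)\n",
"    # symbolic batch dimension, as in the MLP example above\n",
"    b = tir.Var(\"b\", \"int64\")\n",
"    x = nn.Placeholder((b, 32), name=\"x\")\n",
"    y = layer(x)\n",
"    bb.emit_func_output(y, params=[x] + layer.parameters())\n",
"R.parser.pretty_print(bb.get())"
]
},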
{
"cell_type": "markdown",
"id": "89e095a9",
"metadata": {},
"source": [
"## TE/TOPI Integration"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "7ca3cfdb",
"metadata": {},
"outputs": [],
"source": [
"def build_mlp(data, weight):\n",
" bb = relax.BlockBuilder()\n",
"\n",
" with bb.function(\"mlp\", [data, weight]):\n",
" gv0 = bb.emit_te(tvm.contrib.cblas.matmul, data, weight, transa=False, transb=False)\n",
" gv1 = bb.emit_te(topi.nn.relu, gv0)\n",
" bb.emit_func_output(gv1)\n",
"\n",
" mod = bb.get()\n",
" return mod"
]
},
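{
"cell_type": "markdown",
"id": "7d3e5f6a",
"metadata": {},
"source": [
"`emit_te` is not limited to TOPI or extern kernels: any function that maps `te.Tensor` inputs to a `te.Tensor` output can be lowered the same way. Below is a hedged sketch with a hand-written `te.compute` expression (the `te_sigmoid` helper is ours, not part of the API)."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8e4f6a7b",
"metadata": {},
"outputs": [],
"source": [
"from tvm import te\n",
"\n",
"def te_sigmoid(x: te.Tensor) -> te.Tensor:\n",
"    # element-wise sigmoid written directly in TE\n",
"    return te.compute(x.shape, lambda *i: 1.0 / (1.0 + te.exp(-x(*i))), name=\"sigmoid\")\n",
"\n",
"bb = relax.BlockBuilder()\n",
"h, w = tir.Var(\"h\", \"int64\"), tir.Var(\"w\", \"int64\")\n",
"x = relax.Var(\"x\", [h, w], relax.DynTensorType(2, \"float32\"))\n",
"with bb.function(\"run_sigmoid\", [x]):\n",
"    out = bb.emit_te(te_sigmoid, x)\n",
"    bb.emit_func_output(out)\n",
"R.parser.pretty_print(bb.get())"
]
},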
{
"cell_type": "code",
"execution_count": 10,
"id": "c5c1e8f2",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"@tvm.script.ir_module\n",
"class Module:\n",
" @tir.prim_func\n",
" def matmul(var_rxplaceholder_2: tir.handle, var_rxplaceholder_3: tir.handle, var_C_1: tir.handle) -> None:\n",
" # function attr dict\n",
" tir.func_attr({\"global_symbol\": \"matmul\", \"tir.noalias\": True})\n",
" m_1 = tir.var(\"int64\")\n",
" n_1 = tir.var(\"int64\")\n",
" rxplaceholder_2 = tir.match_buffer(var_rxplaceholder_2, [n_1, m_1], dtype=\"float32\")\n",
" rxplaceholder_3 = tir.match_buffer(var_rxplaceholder_3, [m_1, n_1], dtype=\"float32\")\n",
" C_1 = tir.match_buffer(var_C_1, [n_1, n_1], dtype=\"float32\")\n",
" # body\n",
" # with tir.block(\"root\")\n",
" tir.evaluate(tir.tvm_call_packed(\"tvm.contrib.cblas.matmul\", tir.tvm_stack_make_array(rxplaceholder_2.data, tir.tvm_stack_make_shape(n_1, m_1, dtype=\"handle\"), 0, 2, tir.float32(0), tir.int64(0), dtype=\"handle\"), tir.tvm_stack_make_array(rxplaceholder_3.data, tir.tvm_stack_make_shape(m_1, n_1, dtype=\"handle\"), 0, 2, tir.float32(0), tir.int64(0), dtype=\"handle\"), tir.tvm_stack_make_array(C_1.data, tir.tvm_stack_make_shape(n_1, n_1, dtype=\"handle\"), 0, 2, tir.float32(0), tir.int64(0), dtype=\"handle\"), False, False, dtype=\"int32\"))\n",
" \n",
" @tir.prim_func\n",
" def relu(var_rxplaceholder_1: tir.handle, var_compute_1: tir.handle) -> None:\n",
" # function attr dict\n",
" tir.func_attr({\"global_symbol\": \"relu\", \"tir.noalias\": True})\n",
" n_1 = tir.var(\"int64\")\n",
" rxplaceholder_1 = tir.match_buffer(var_rxplaceholder_1, [n_1, n_1], dtype=\"float32\")\n",
" compute_1 = tir.match_buffer(var_compute_1, [n_1, n_1], dtype=\"float32\")\n",
" # body\n",
" # with tir.block(\"root\")\n",
" for i0, i1 in tir.grid(n_1, n_1):\n",
" with tir.block(\"compute\"):\n",
" i0_2 = tir.axis.spatial(n_1, i0)\n",
" i1_2 = tir.axis.spatial(n_1, i1)\n",
" tir.reads(rxplaceholder_1[i0_2, i1_2])\n",
" tir.writes(compute_1[i0_2, i1_2])\n",
" compute_1[i0_2, i1_2] = tir.max(rxplaceholder_1[i0_2, i1_2], tir.float32(0))\n",
" \n",
" @relax.function\n",
" def mlp(data: Tensor((n, m), \"float32\"), weight: Tensor((m, n), \"float32\")) -> Tensor(_, \"float32\"):\n",
" # block 0\n",
" gv = relax.call_tir(matmul, (data, weight), (n, n), dtype=\"float32\")\n",
" gv1 = relax.call_tir(relu, (gv,), (n, n), dtype=\"float32\")\n",
" return gv1\n",
" \n"
]
}
],
"source": [
"# symbolic dimensions\n",
"n, m = tir.Var(\"n\", \"int64\"), tir.Var(\"m\", \"int64\")\n",
"\n",
"# create data and weight variables\n",
"data = relax.Var(\"data\", [n, m], relax.DynTensorType(2, \"float32\"))\n",
"weight = relax.Var(\"weight\", [m, n], relax.DynTensorType(2, \"float32\"))\n",
"\n",
"# construct a mlp model\n",
"mod = build_mlp(data, weight)\n",
"R.parser.pretty_print(mod)\n",
"\n",
"# build and create vm executor\n",
"target = tvm.target.Target(\"llvm\", host=\"llvm\")\n",
"ex = relax.vm.build(mod, target)\n",
"vm = relax.VirtualMachine(ex, tvm.cpu())"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "8102fabf",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[ 8.095244 7.7235837 7.2910867 8.443724 8.958878 9.012167\n",
" 6.2120357 10.1474695 8.469089 7.993631 7.2318916 7.1488924\n",
" 9.181463 9.888946 7.4828773 7.393664 ]\n",
" [ 6.051395 7.281189 7.8760085 8.753442 6.9901752 7.286955\n",
" 5.8768206 9.506951 7.4303255 7.2454524 6.817876 6.3926682\n",
" 7.413857 8.555092 7.938551 6.3793116]\n",
" [ 7.050069 8.146633 7.577522 9.216323 7.63874 8.516393\n",
" 6.0171905 9.377093 8.423015 7.169124 6.4793515 5.9576015\n",
" 8.156402 9.256693 7.9188857 7.347392 ]\n",
" [ 6.8375144 7.0806613 7.232546 6.8441987 6.581527 8.293472\n",
" 5.9637613 8.6074705 7.3677044 7.7873907 6.419702 6.1378307\n",
" 6.6358232 8.800402 7.43657 7.4201636]\n",
" [ 6.958605 8.005278 6.393196 8.5577755 7.973779 9.274346\n",
" 7.1687098 9.766083 7.716219 7.32963 7.0473433 6.7898445\n",
" 8.318464 8.446748 7.0219917 6.123752 ]\n",
" [ 7.7625294 7.0056577 7.2986197 7.4696107 6.741804 8.562313\n",
" 6.763641 9.269506 7.983401 6.6070294 6.7254267 6.6040297\n",
" 8.351605 9.130937 7.5254893 7.160763 ]\n",
" [ 7.156014 7.827703 8.12125 7.753513 8.445888 8.61097\n",
" 6.417519 8.823943 8.172567 7.9088287 6.9773703 6.460421\n",
" 7.808111 8.590609 8.104791 7.838943 ]\n",
" [ 6.3050046 7.45814 6.4376473 7.2996607 7.829011 8.281724\n",
" 5.924798 8.665511 7.338019 7.0315704 6.467948 5.5320587\n",
" 7.4525414 8.788908 6.5465045 6.375387 ]\n",
" [ 7.203224 6.936556 6.9815955 7.9408994 6.9658127 8.797333\n",
" 6.0595207 8.919445 7.4469514 7.177114 6.276354 7.216568\n",
" 8.093876 8.271761 7.5531983 6.2297573]\n",
" [ 7.9077554 8.393251 7.7091074 7.876217 9.050663 9.165333\n",
" 7.29719 10.020016 9.2527 8.996158 8.42391 6.766323\n",
" 8.467502 10.105639 7.681791 7.9979396]\n",
" [ 7.722364 7.502222 7.16021 7.320424 6.938996 8.806945\n",
" 6.0865664 8.607358 7.289708 7.509741 6.8619833 6.346017\n",
" 6.9190187 9.129631 7.799539 7.013547 ]\n",
" [ 6.0149918 6.927125 7.1344876 7.1830897 7.465643 7.3225746\n",
" 5.3835173 7.936848 6.925822 7.061758 6.644145 5.838758\n",
" 6.670918 7.5502677 7.6008825 6.3464193]\n",
" [ 7.827011 7.927222 8.246812 7.77723 7.6312366 9.066271\n",
" 7.001202 9.651561 8.513117 8.172654 7.569078 6.1455894\n",
" 8.610018 9.841372 7.924851 8.502279 ]\n",
" [ 7.1759543 7.4548283 8.092802 8.589085 6.612156 7.836399\n",
" 6.296878 9.386656 7.850534 7.32955 7.614632 7.229883\n",
" 7.568119 9.086787 7.3559017 7.266826 ]\n",
" [ 6.887543 7.3447785 7.154263 6.1262264 6.5849376 7.7508464\n",
" 5.4122353 7.795231 7.67428 6.178569 7.021037 5.7336097\n",
" 6.4718122 6.8764296 6.5011544 5.798081 ]\n",
" [ 6.4761295 7.764174 6.3661256 7.272473 7.787151 8.531428\n",
" 5.380861 8.1289835 7.5046635 7.178151 5.9490094 5.8559055\n",
" 7.414933 7.15778 6.905289 6.256718 ]]\n"
]
}
],
"source": [
"# run the mlp model on relax vm\n",
"data = tvm.nd.array(np.random.rand(16, 32).astype(np.float32))\n",
"weight = tvm.nd.array(np.random.rand(32, 16).astype(np.float32))\n",
"res = vm[\"mlp\"](data, weight)\n",
"print(res)"
]
},
{
"cell_type": "markdown",
"id": "5bad13ca",
"metadata": {},
"source": [
"## Relay to Relax translator"
]
},
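{
"cell_type": "markdown",
"id": "9f5a7b8c",
"metadata": {},
"source": [
"The cell below prints a Relax module translated from a Relay ResNet-50 workload (batch size 1). As a hedged sketch, the translation is typically invoked like this, assuming the `relay_translator` helper from `tvm.relax.testing`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0a6b8c9d",
"metadata": {},
"outputs": [],
"source": [
"from tvm import relay\n",
"from tvm.relay import testing as relay_testing\n",
"from tvm.relax.testing import relay_translator\n",
"\n",
"# a Relay ResNet-50 workload at batch size 1, matching the module printed below\n",
"relay_mod, _ = relay_testing.resnet.get_workload(num_layers=50, batch_size=1, dtype=\"float32\")\n",
"relax_mod = relay_translator.from_relay(relay_mod[\"main\"])\n",
"R.parser.pretty_print(relax_mod)"
]
},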
{
"cell_type": "code",
"execution_count": 12,
"id": "58ecdee8",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"@tvm.script.ir_module\n",
"class Module:\n",
" @tir.prim_func\n",
" def conv2d_nchw11(rxplaceholder_2: tir.Buffer[(1, 256, 14, 14), \"float32\"], rxplaceholder_3: tir.Buffer[(256, 256, 3, 3), \"float32\"], conv2d_nchw_1: tir.Buffer[(1, 256, 14, 14), \"float32\"]) -> None:\n",
" # function attr dict\n",
" tir.func_attr({\"global_symbol\": \"conv2d_nchw11\", \"tir.noalias\": True})\n",
" # body\n",
" # with tir.block(\"root\")\n",
" pad_temp_1 = tir.alloc_buffer([1, 256, 16, 16], dtype=\"float32\")\n",
" for i0, i1, i2, i3 in tir.grid(1, 256, 16, 16):\n",
" with tir.block(\"pad_temp\"):\n",
" i0_2, i1_2, i2_2, i3_2 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n",
" tir.reads(rxplaceholder_2[i0_2, i1_2, i2_2 - 1, i3_2 - 1])\n",
" tir.writes(pad_temp_1[i0_2, i1_2, i2_2, i3_2])\n",
" pad_temp_1[i0_2, i1_2, i2_2, i3_2] = tir.if_then_else(1 <= i2_2 and i2_2 < 15 and 1 <= i3_2 and i3_2 < 15, rxplaceholder_2[i0_2, i1_2, i2_2 - 1, i3_2 - 1], tir.float32(0), dtype=\"float32\")\n",
" for i0_3, i1_3, i2_3, i3_3, i4, i5, i6 in tir.grid(1, 256, 14, 14, 256, 3, 3):\n",
" with tir.block(\"conv2d_nchw\"):\n",
" nn, ff, yy, xx, rc, ry, rx = tir.axis.remap(\"SSSSRRR\", [i0_3, i1_3, i2_3, i3_3, i4, i5, i6])\n",
" tir.reads(pad_temp_1[nn, rc, yy + ry, xx + rx], rxplaceholder_3[ff, rc, ry, rx])\n",
" tir.writes(conv2d_nchw_1[nn, ff, yy, xx])\n",
" with tir.init():\n",
" conv2d_nchw_1[nn, ff, yy, xx] = tir.float32(0)\n",
" conv2d_nchw_1[nn, ff, yy, xx] = conv2d_nchw_1[nn, ff, yy, xx] + pad_temp_1[nn, rc, yy + ry, xx + rx] * rxplaceholder_3[ff, rc, ry, rx]\n",
" \n",
" @tir.prim_func\n",
" def dense(rxplaceholder_2: tir.Buffer[(1, 2048), \"float32\"], rxplaceholder_3: tir.Buffer[(1000, 2048), \"float32\"], T_matmul_NT_1: tir.Buffer[(1, 1000), \"float32\"]) -> None:\n",
" # function attr dict\n",
" tir.func_attr({\"global_symbol\": \"dense\", \"tir.noalias\": True})\n",
" # body\n",
" # with tir.block(\"root\")\n",
" for i0, i1, i2 in tir.grid(1, 1000, 2048):\n",
" with tir.block(\"T_matmul_NT\"):\n",
" i, j, k = tir.axis.remap(\"SSR\", [i0, i1, i2])\n",
" tir.reads(rxplaceholder_2[i, k], rxplaceholder_3[j, k])\n",
" tir.writes(T_matmul_NT_1[i, j])\n",
" tir.block_attr({\"layout_free_placeholders\":[rxplaceholder_3]})\n",
" with tir.init():\n",
" T_matmul_NT_1[i, j] = tir.float32(0)\n",
" T_matmul_NT_1[i, j] = T_matmul_NT_1[i, j] + rxplaceholder_2[i, k] * rxplaceholder_3[j, k]\n",
" \n",
" @tir.prim_func\n",
" def global_avg_pool2d(rxplaceholder_1: tir.Buffer[(1, 2048, 7, 7), \"float32\"], tensor_2: tir.Buffer[(1, 2048, 1, 1), \"float32\"]) -> None:\n",
" # function attr dict\n",
" tir.func_attr({\"global_symbol\": \"global_avg_pool2d\", \"tir.noalias\": True})\n",
" # body\n",
" # with tir.block(\"root\")\n",
" tensor_3 = tir.alloc_buffer([1, 2048, 1, 1], dtype=\"float32\")\n",
" for i0, i1, i2, i3, i4, i5 in tir.grid(1, 2048, 1, 1, 7, 7):\n",
" with tir.block(\"tensor\"):\n",
" ax0, ax1, ax2, ax3, rv0, rv1 = tir.axis.remap(\"SSSSRR\", [i0, i1, i2, i3, i4, i5])\n",
" tir.reads(rxplaceholder_1[ax0, ax1, ax2 * 7 + rv0, ax3 * 7 + rv1])\n",
" tir.writes(tensor_3[ax0, ax1, ax2, ax3])\n",
" with tir.init():\n",
" tensor_3[ax0, ax1, ax2, ax3] = tir.float32(0)\n",
" tensor_3[ax0, ax1, ax2, ax3] = tensor_3[ax0, ax1, ax2, ax3] + rxplaceholder_1[ax0, ax1, ax2 * 7 + rv0, ax3 * 7 + rv1]\n",
" for i0, i1, i2, i3 in tir.grid(1, 2048, 1, 1):\n",
" with tir.block(\"tensor_1\"):\n",
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n",
" tir.reads(tensor_3[ax0, ax1, ax2, ax3])\n",
" tir.writes(tensor_2[ax0, ax1, ax2, ax3])\n",
" tensor_2[ax0, ax1, ax2, ax3] = tensor_3[ax0, ax1, ax2, ax3] * tir.float32(0.020408163265306121)\n",
" \n",
" @tir.prim_func\n",
" def batch_norm1(rxplaceholder_5: tir.Buffer[(1, 64, 112, 112), \"float32\"], rxplaceholder_6: tir.Buffer[(64,), \"float32\"], rxplaceholder_7: tir.Buffer[(64,), \"float32\"], rxplaceholder_8: tir.Buffer[(64,), \"float32\"], rxplaceholder_9: tir.Buffer[(64,), \"float32\"], T_add_2: tir.Buffer[(1, 64, 112, 112), \"float32\"], T_multiply_3: tir.Buffer[(64,), \"float32\"], T_multiply_4: tir.Buffer[(64,), \"float32\"]) -> None:\n",
" # function attr dict\n",
" tir.func_attr({\"global_symbol\": \"batch_norm1\", \"tir.noalias\": True})\n",
" # body\n",
" # with tir.block(\"root\")\n",
" T_reshape_4 = tir.alloc_buffer([1, 64, 1, 1], dtype=\"float32\")\n",
" T_subtract_1 = tir.alloc_buffer([1, 64, 112, 112], dtype=\"float32\")\n",
" T_reshape_5 = tir.alloc_buffer([1, 64, 1, 1], dtype=\"float32\")\n",
" T_add_3 = tir.alloc_buffer([1, 64, 1, 1], dtype=\"float32\")\n",
" compute_1 = tir.alloc_buffer([1, 64, 1, 1], dtype=\"float32\")\n",
" T_divide_1 = tir.alloc_buffer([1, 64, 112, 112], dtype=\"float32\")\n",
" T_reshape_6 = tir.alloc_buffer([1, 64, 1, 1], dtype=\"float32\")\n",
" T_multiply_5 = tir.alloc_buffer([1, 64, 112, 112], dtype=\"float32\")\n",
" T_reshape_7 = tir.alloc_buffer([1, 64, 1, 1], dtype=\"float32\")\n",
" for i0, i1, i2, i3 in tir.grid(1, 64, 1, 1):\n",
" with tir.block(\"T_reshape\"):\n",
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n",
" tir.reads(rxplaceholder_8[(ax1 + ax2 + ax3) % 64])\n",
" tir.writes(T_reshape_4[ax0, ax1, ax2, ax3])\n",
" T_reshape_4[ax0, ax1, ax2, ax3] = rxplaceholder_8[(ax1 + ax2 + ax3) % 64]\n",
" for i0, i1, i2, i3 in tir.grid(1, 64, 112, 112):\n",
" with tir.block(\"T_subtract\"):\n",
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n",
" tir.reads(rxplaceholder_5[ax0, ax1, ax2, ax3], T_reshape_4[ax0, ax1, 0, 0])\n",
" tir.writes(T_subtract_1[ax0, ax1, ax2, ax3])\n",
" T_subtract_1[ax0, ax1, ax2, ax3] = rxplaceholder_5[ax0, ax1, ax2, ax3] - T_reshape_4[ax0, ax1, 0, 0]\n",
" for i0, i1, i2, i3 in tir.grid(1, 64, 1, 1):\n",
" with tir.block(\"T_reshape_1\"):\n",
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n",
" tir.reads(rxplaceholder_9[(ax1 + ax2 + ax3) % 64])\n",
" tir.writes(T_reshape_5[ax0, ax1, ax2, ax3])\n",
" T_reshape_5[ax0, ax1, ax2, ax3] = rxplaceholder_9[(ax1 + ax2 + ax3) % 64]\n",
" for i0, i1, i2, i3 in tir.grid(1, 64, 1, 1):\n",
" with tir.block(\"T_add\"):\n",
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n",
" tir.reads(T_reshape_5[ax0, ax1, ax2, ax3])\n",
" tir.writes(T_add_3[ax0, ax1, ax2, ax3])\n",
" T_add_3[ax0, ax1, ax2, ax3] = T_reshape_5[ax0, ax1, ax2, ax3] + tir.float32(1.9999999494757503e-05)\n",
" for i0, i1, i2, i3 in tir.grid(1, 64, 1, 1):\n",
" with tir.block(\"compute\"):\n",
" i0_2, i1_2, i2_2, i3_2 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n",
" tir.reads(T_add_3[i0_2, i1_2, i2_2, i3_2])\n",
" tir.writes(compute_1[i0_2, i1_2, i2_2, i3_2])\n",
" compute_1[i0_2, i1_2, i2_2, i3_2] = tir.sqrt(T_add_3[i0_2, i1_2, i2_2, i3_2], dtype=\"float32\")\n",
" for i0_3, i1_3, i2_3, i3_3 in tir.grid(1, 64, 112, 112):\n",
" with tir.block(\"T_divide\"):\n",
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0_3, i1_3, i2_3, i3_3])\n",
" tir.reads(T_subtract_1[ax0, ax1, ax2, ax3], compute_1[ax0, ax1, 0, 0])\n",
" tir.writes(T_divide_1[ax0, ax1, ax2, ax3])\n",
" T_divide_1[ax0, ax1, ax2, ax3] = T_subtract_1[ax0, ax1, ax2, ax3] / compute_1[ax0, ax1, 0, 0]\n",
" for i0_4, i1_4, i2_4, i3_4 in tir.grid(1, 64, 1, 1):\n",
" with tir.block(\"T_reshape_2\"):\n",
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0_4, i1_4, i2_4, i3_4])\n",
" tir.reads(rxplaceholder_6[(ax1 + ax2 + ax3) % 64])\n",
" tir.writes(T_reshape_6[ax0, ax1, ax2, ax3])\n",
" T_reshape_6[ax0, ax1, ax2, ax3] = rxplaceholder_6[(ax1 + ax2 + ax3) % 64]\n",
" for i0_5, i1_5, i2_5, i3_5 in tir.grid(1, 64, 112, 112):\n",
" with tir.block(\"T_multiply\"):\n",
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0_5, i1_5, i2_5, i3_5])\n",
" tir.reads(T_divide_1[ax0, ax1, ax2, ax3], T_reshape_6[ax0, ax1, 0, 0])\n",
" tir.writes(T_multiply_5[ax0, ax1, ax2, ax3])\n",
" T_multiply_5[ax0, ax1, ax2, ax3] = T_divide_1[ax0, ax1, ax2, ax3] * T_reshape_6[ax0, ax1, 0, 0]\n",
" for i0_6, i1_6, i2_6, i3_6 in tir.grid(1, 64, 1, 1):\n",
" with tir.block(\"T_reshape_3\"):\n",
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0_6, i1_6, i2_6, i3_6])\n",
" tir.reads(rxplaceholder_7[(ax1 + ax2 + ax3) % 64])\n",
" tir.writes(T_reshape_7[ax0, ax1, ax2, ax3])\n",
" T_reshape_7[ax0, ax1, ax2, ax3] = rxplaceholder_7[(ax1 + ax2 + ax3) % 64]\n",
" for i0_7, i1_7, i2_7, i3_7 in tir.grid(1, 64, 112, 112):\n",
" with tir.block(\"T_add_1\"):\n",
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0_7, i1_7, i2_7, i3_7])\n",
" tir.reads(T_multiply_5[ax0, ax1, ax2, ax3], T_reshape_7[ax0, ax1, 0, 0])\n",
" tir.writes(T_add_2[ax0, ax1, ax2, ax3])\n",
" T_add_2[ax0, ax1, ax2, ax3] = T_multiply_5[ax0, ax1, ax2, ax3] + T_reshape_7[ax0, ax1, 0, 0]\n",
" for i0_8 in tir.serial(64):\n",
" with tir.block(\"T_multiply_1\"):\n",
" ax0 = tir.axis.spatial(64, i0_8)\n",
" tir.reads(rxplaceholder_8[ax0])\n",
" tir.writes(T_multiply_3[ax0])\n",
" T_multiply_3[ax0] = rxplaceholder_8[ax0]\n",
" for i0_9 in tir.serial(64):\n",
" with tir.block(\"T_multiply_2\"):\n",
" ax0 = tir.axis.spatial(64, i0_9)\n",
" tir.reads(rxplaceholder_9[ax0])\n",
" tir.writes(T_multiply_4[ax0])\n",
" T_multiply_4[ax0] = rxplaceholder_9[ax0]\n",
" \n",
" @tir.prim_func\n",
" def add3(rxplaceholder_2: tir.Buffer[(1, 2048, 7, 7), \"float32\"], rxplaceholder_3: tir.Buffer[(1, 2048, 7, 7), \"float32\"], T_add_1: tir.Buffer[(1, 2048, 7, 7), \"float32\"]) -> None:\n",
" # function attr dict\n",
" tir.func_attr({\"global_symbol\": \"add3\", \"tir.noalias\": True})\n",
" # body\n",
" # with tir.block(\"root\")\n",
" for i0, i1, i2, i3 in tir.grid(1, 2048, 7, 7):\n",
" with tir.block(\"T_add\"):\n",
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n",
" tir.reads(rxplaceholder_2[ax0, ax1, ax2, ax3], rxplaceholder_3[ax0, ax1, ax2, ax3])\n",
" tir.writes(T_add_1[ax0, ax1, ax2, ax3])\n",
" T_add_1[ax0, ax1, ax2, ax3] = rxplaceholder_2[ax0, ax1, ax2, ax3] + rxplaceholder_3[ax0, ax1, ax2, ax3]\n",
" \n",
" @tir.prim_func\n",
" def conv2d_nchw4(rxplaceholder_2: tir.Buffer[(1, 256, 56, 56), \"float32\"], rxplaceholder_3: tir.Buffer[(64, 256, 1, 1), \"float32\"], conv2d_nchw_1: tir.Buffer[(1, 64, 56, 56), \"float32\"]) -> None:\n",
" # function attr dict\n",
" tir.func_attr({\"global_symbol\": \"conv2d_nchw4\", \"tir.noalias\": True})\n",
" # body\n",
" # with tir.block(\"root\")\n",
" pad_temp_1 = tir.alloc_buffer([1, 256, 56, 56], dtype=\"float32\")\n",
" for i0, i1, i2, i3 in tir.grid(1, 256, 56, 56):\n",
" with tir.block(\"pad_temp\"):\n",
" i0_2, i1_2, i2_2, i3_2 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n",
" tir.reads(rxplaceholder_2[i0_2, i1_2, i2_2, i3_2])\n",
" tir.writes(pad_temp_1[i0_2, i1_2, i2_2, i3_2])\n",
" pad_temp_1[i0_2, i1_2, i2_2, i3_2] = rxplaceholder_2[i0_2, i1_2, i2_2, i3_2]\n",
" for i0_3, i1_3, i2_3, i3_3, i4, i5, i6 in tir.grid(1, 64, 56, 56, 256, 1, 1):\n",
" with tir.block(\"conv2d_nchw\"):\n",
" nn, ff, yy, xx, rc, ry, rx = tir.axis.remap(\"SSSSRRR\", [i0_3, i1_3, i2_3, i3_3, i4, i5, i6])\n",
" tir.reads(pad_temp_1[nn, rc, yy + ry, xx + rx], rxplaceholder_3[ff, rc, ry, rx])\n",
" tir.writes(conv2d_nchw_1[nn, ff, yy, xx])\n",
" with tir.init():\n",
" conv2d_nchw_1[nn, ff, yy, xx] = tir.float32(0)\n",
" conv2d_nchw_1[nn, ff, yy, xx] = conv2d_nchw_1[nn, ff, yy, xx] + pad_temp_1[nn, rc, yy + ry, xx + rx] * rxplaceholder_3[ff, rc, ry, rx]\n",
" \n",
" @tir.prim_func\n",
" def add2(rxplaceholder_2: tir.Buffer[(1, 1024, 14, 14), \"float32\"], rxplaceholder_3: tir.Buffer[(1, 1024, 14, 14), \"float32\"], T_add_1: tir.Buffer[(1, 1024, 14, 14), \"float32\"]) -> None:\n",
" # function attr dict\n",
" tir.func_attr({\"global_symbol\": \"add2\", \"tir.noalias\": True})\n",
" # body\n",
" # with tir.block(\"root\")\n",
" for i0, i1, i2, i3 in tir.grid(1, 1024, 14, 14):\n",
" with tir.block(\"T_add\"):\n",
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n",
" tir.reads(rxplaceholder_2[ax0, ax1, ax2, ax3], rxplaceholder_3[ax0, ax1, ax2, ax3])\n",
" tir.writes(T_add_1[ax0, ax1, ax2, ax3])\n",
" T_add_1[ax0, ax1, ax2, ax3] = rxplaceholder_2[ax0, ax1, ax2, ax3] + rxplaceholder_3[ax0, ax1, ax2, ax3]\n",
" \n",
" @tir.prim_func\n",
" def conv2d_nchw(rxplaceholder_2: tir.Buffer[(1, 3, 224, 224), \"float32\"], rxplaceholder_3: tir.Buffer[(64, 3, 7, 7), \"float32\"], conv2d_nchw_1: tir.Buffer[(1, 64, 112, 112), \"float32\"]) -> None:\n",
" # function attr dict\n",
" tir.func_attr({\"global_symbol\": \"conv2d_nchw\", \"tir.noalias\": True})\n",
" # body\n",
" # with tir.block(\"root\")\n",
" pad_temp_1 = tir.alloc_buffer([1, 3, 230, 230], dtype=\"float32\")\n",
" for i0, i1, i2, i3 in tir.grid(1, 3, 230, 230):\n",
" with tir.block(\"pad_temp\"):\n",
" i0_2, i1_2, i2_2, i3_2 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n",
" tir.reads(rxplaceholder_2[i0_2, i1_2, i2_2 - 3, i3_2 - 3])\n",
" tir.writes(pad_temp_1[i0_2, i1_2, i2_2, i3_2])\n",
" pad_temp_1[i0_2, i1_2, i2_2, i3_2] = tir.if_then_else(3 <= i2_2 and i2_2 < 227 and 3 <= i3_2 and i3_2 < 227, rxplaceholder_2[i0_2, i1_2, i2_2 - 3, i3_2 - 3], tir.float32(0), dtype=\"float32\")\n",
" for i0_3, i1_3, i2_3, i3_3, i4, i5, i6 in tir.grid(1, 64, 112, 112, 3, 7, 7):\n",
" with tir.block(\"conv2d_nchw\"):\n",
" nn, ff, yy, xx, rc, ry, rx = tir.axis.remap(\"SSSSRRR\", [i0_3, i1_3, i2_3, i3_3, i4, i5, i6])\n",
" tir.reads(pad_temp_1[nn, rc, yy * 2 + ry, xx * 2 + rx], rxplaceholder_3[ff, rc, ry, rx])\n",
" tir.writes(conv2d_nchw_1[nn, ff, yy, xx])\n",
" with tir.init():\n",
" conv2d_nchw_1[nn, ff, yy, xx] = tir.float32(0)\n",
" conv2d_nchw_1[nn, ff, yy, xx] = conv2d_nchw_1[nn, ff, yy, xx] + pad_temp_1[nn, rc, yy * 2 + ry, xx * 2 + rx] * rxplaceholder_3[ff, rc, ry, rx]\n",
" \n",
" @tir.prim_func\n",
" def conv2d_nchw13(rxplaceholder_2: tir.Buffer[(1, 512, 28, 28), \"float32\"], rxplaceholder_3: tir.Buffer[(1024, 512, 1, 1), \"float32\"], conv2d_nchw_1: tir.Buffer[(1, 1024, 14, 14), \"float32\"]) -> None:\n",
" # function attr dict\n",
" tir.func_attr({\"global_symbol\": \"conv2d_nchw13\", \"tir.noalias\": True})\n",
" # body\n",
" # with tir.block(\"root\")\n",
" pad_temp_1 = tir.alloc_buffer([1, 512, 28, 28], dtype=\"float32\")\n",
" for i0, i1, i2, i3 in tir.grid(1, 512, 28, 28):\n",
" with tir.block(\"pad_temp\"):\n",
" i0_2, i1_2, i2_2, i3_2 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n",
" tir.reads(rxplaceholder_2[i0_2, i1_2, i2_2, i3_2])\n",
" tir.writes(pad_temp_1[i0_2, i1_2, i2_2, i3_2])\n",
" pad_temp_1[i0_2, i1_2, i2_2, i3_2] = rxplaceholder_2[i0_2, i1_2, i2_2, i3_2]\n",
" for i0_3, i1_3, i2_3, i3_3, i4, i5, i6 in tir.grid(1, 1024, 14, 14, 512, 1, 1):\n",
" with tir.block(\"conv2d_nchw\"):\n",
" nn, ff, yy, xx, rc, ry, rx = tir.axis.remap(\"SSSSRRR\", [i0_3, i1_3, i2_3, i3_3, i4, i5, i6])\n",
" tir.reads(pad_temp_1[nn, rc, yy * 2 + ry, xx * 2 + rx], rxplaceholder_3[ff, rc, ry, rx])\n",
" tir.writes(conv2d_nchw_1[nn, ff, yy, xx])\n",
" with tir.init():\n",
" conv2d_nchw_1[nn, ff, yy, xx] = tir.float32(0)\n",
" conv2d_nchw_1[nn, ff, yy, xx] = conv2d_nchw_1[nn, ff, yy, xx] + pad_temp_1[nn, rc, yy * 2 + ry, xx * 2 + rx] * rxplaceholder_3[ff, rc, ry, rx]\n",
" \n",
" @tir.prim_func\n",
" def batch_norm7(rxplaceholder_5: tir.Buffer[(1, 1024, 14, 14), \"float32\"], rxplaceholder_6: tir.Buffer[(1024,), \"float32\"], rxplaceholder_7: tir.Buffer[(1024,), \"float32\"], rxplaceholder_8: tir.Buffer[(1024,), \"float32\"], rxplaceholder_9: tir.Buffer[(1024,), \"float32\"], T_add_2: tir.Buffer[(1, 1024, 14, 14), \"float32\"], T_multiply_3: tir.Buffer[(1024,), \"float32\"], T_multiply_4: tir.Buffer[(1024,), \"float32\"]) -> None:\n",
" # function attr dict\n",
" tir.func_attr({\"global_symbol\": \"batch_norm7\", \"tir.noalias\": True})\n",
" # body\n",
" # with tir.block(\"root\")\n",
" T_reshape_4 = tir.alloc_buffer([1, 1024, 1, 1], dtype=\"float32\")\n",
" T_subtract_1 = tir.alloc_buffer([1, 1024, 14, 14], dtype=\"float32\")\n",
" T_reshape_5 = tir.alloc_buffer([1, 1024, 1, 1], dtype=\"float32\")\n",
" T_add_3 = tir.alloc_buffer([1, 1024, 1, 1], dtype=\"float32\")\n",
" compute_1 = tir.alloc_buffer([1, 1024, 1, 1], dtype=\"float32\")\n",
" T_divide_1 = tir.alloc_buffer([1, 1024, 14, 14], dtype=\"float32\")\n",
" T_reshape_6 = tir.alloc_buffer([1, 1024, 1, 1], dtype=\"float32\")\n",
" T_multiply_5 = tir.alloc_buffer([1, 1024, 14, 14], dtype=\"float32\")\n",
" T_reshape_7 = tir.alloc_buffer([1, 1024, 1, 1], dtype=\"float32\")\n",
" for i0, i1, i2, i3 in tir.grid(1, 1024, 1, 1):\n",
" with tir.block(\"T_reshape\"):\n",
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n",
" tir.reads(rxplaceholder_8[(ax1 + ax2 + ax3) % 1024])\n",
" tir.writes(T_reshape_4[ax0, ax1, ax2, ax3])\n",
" T_reshape_4[ax0, ax1, ax2, ax3] = rxplaceholder_8[(ax1 + ax2 + ax3) % 1024]\n",
" for i0, i1, i2, i3 in tir.grid(1, 1024, 14, 14):\n",
" with tir.block(\"T_subtract\"):\n",
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n",
" tir.reads(rxplaceholder_5[ax0, ax1, ax2, ax3], T_reshape_4[ax0, ax1, 0, 0])\n",
" tir.writes(T_subtract_1[ax0, ax1, ax2, ax3])\n",
" T_subtract_1[ax0, ax1, ax2, ax3] = rxplaceholder_5[ax0, ax1, ax2, ax3] - T_reshape_4[ax0, ax1, 0, 0]\n",
" for i0, i1, i2, i3 in tir.grid(1, 1024, 1, 1):\n",
" with tir.block(\"T_reshape_1\"):\n",
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n",
" tir.reads(rxplaceholder_9[(ax1 + ax2 + ax3) % 1024])\n",
" tir.writes(T_reshape_5[ax0, ax1, ax2, ax3])\n",
" T_reshape_5[ax0, ax1, ax2, ax3] = rxplaceholder_9[(ax1 + ax2 + ax3) % 1024]\n",
" for i0, i1, i2, i3 in tir.grid(1, 1024, 1, 1):\n",
" with tir.block(\"T_add\"):\n",
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n",
" tir.reads(T_reshape_5[ax0, ax1, ax2, ax3])\n",
" tir.writes(T_add_3[ax0, ax1, ax2, ax3])\n",
" T_add_3[ax0, ax1, ax2, ax3] = T_reshape_5[ax0, ax1, ax2, ax3] + tir.float32(1.9999999494757503e-05)\n",
" for i0, i1, i2, i3 in tir.grid(1, 1024, 1, 1):\n",
" with tir.block(\"compute\"):\n",
" i0_2, i1_2, i2_2, i3_2 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n",
" tir.reads(T_add_3[i0_2, i1_2, i2_2, i3_2])\n",
" tir.writes(compute_1[i0_2, i1_2, i2_2, i3_2])\n",
" compute_1[i0_2, i1_2, i2_2, i3_2] = tir.sqrt(T_add_3[i0_2, i1_2, i2_2, i3_2], dtype=\"float32\")\n",
" for i0_3, i1_3, i2_3, i3_3 in tir.grid(1, 1024, 14, 14):\n",
" with tir.block(\"T_divide\"):\n",
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0_3, i1_3, i2_3, i3_3])\n",
" tir.reads(T_subtract_1[ax0, ax1, ax2, ax3], compute_1[ax0, ax1, 0, 0])\n",
" tir.writes(T_divide_1[ax0, ax1, ax2, ax3])\n",
" T_divide_1[ax0, ax1, ax2, ax3] = T_subtract_1[ax0, ax1, ax2, ax3] / compute_1[ax0, ax1, 0, 0]\n",
" for i0_4, i1_4, i2_4, i3_4 in tir.grid(1, 1024, 1, 1):\n",
" with tir.block(\"T_reshape_2\"):\n",
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0_4, i1_4, i2_4, i3_4])\n",
" tir.reads(rxplaceholder_6[(ax1 + ax2 + ax3) % 1024])\n",
" tir.writes(T_reshape_6[ax0, ax1, ax2, ax3])\n",
" T_reshape_6[ax0, ax1, ax2, ax3] = rxplaceholder_6[(ax1 + ax2 + ax3) % 1024]\n",
" for i0_5, i1_5, i2_5, i3_5 in tir.grid(1, 1024, 14, 14):\n",
" with tir.block(\"T_multiply\"):\n",
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0_5, i1_5, i2_5, i3_5])\n",
" tir.reads(T_divide_1[ax0, ax1, ax2, ax3], T_reshape_6[ax0, ax1, 0, 0])\n",
" tir.writes(T_multiply_5[ax0, ax1, ax2, ax3])\n",
" T_multiply_5[ax0, ax1, ax2, ax3] = T_divide_1[ax0, ax1, ax2, ax3] * T_reshape_6[ax0, ax1, 0, 0]\n",
" for i0_6, i1_6, i2_6, i3_6 in tir.grid(1, 1024, 1, 1):\n",
" with tir.block(\"T_reshape_3\"):\n",
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0_6, i1_6, i2_6, i3_6])\n",
" tir.reads(rxplaceholder_7[(ax1 + ax2 + ax3) % 1024])\n",
" tir.writes(T_reshape_7[ax0, ax1, ax2, ax3])\n",
" T_reshape_7[ax0, ax1, ax2, ax3] = rxplaceholder_7[(ax1 + ax2 + ax3) % 1024]\n",
" for i0_7, i1_7, i2_7, i3_7 in tir.grid(1, 1024, 14, 14):\n",
" with tir.block(\"T_add_1\"):\n",
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0_7, i1_7, i2_7, i3_7])\n",
" tir.reads(T_multiply_5[ax0, ax1, ax2, ax3], T_reshape_7[ax0, ax1, 0, 0])\n",
" tir.writes(T_add_2[ax0, ax1, ax2, ax3])\n",
" T_add_2[ax0, ax1, ax2, ax3] = T_multiply_5[ax0, ax1, ax2, ax3] + T_reshape_7[ax0, ax1, 0, 0]\n",
" for i0_8 in tir.serial(1024):\n",
" with tir.block(\"T_multiply_1\"):\n",
" ax0 = tir.axis.spatial(1024, i0_8)\n",
" tir.reads(rxplaceholder_8[ax0])\n",
" tir.writes(T_multiply_3[ax0])\n",
" T_multiply_3[ax0] = rxplaceholder_8[ax0]\n",
" for i0_9 in tir.serial(1024):\n",
" with tir.block(\"T_multiply_2\"):\n",
" ax0 = tir.axis.spatial(1024, i0_9)\n",
" tir.reads(rxplaceholder_9[ax0])\n",
" tir.writes(T_multiply_4[ax0])\n",
" T_multiply_4[ax0] = rxplaceholder_9[ax0]\n",
" \n",
" @tir.prim_func\n",
" def batch_norm8(rxplaceholder_5: tir.Buffer[(1, 512, 7, 7), \"float32\"], rxplaceholder_6: tir.Buffer[(512,), \"float32\"], rxplaceholder_7: tir.Buffer[(512,), \"float32\"], rxplaceholder_8: tir.Buffer[(512,), \"float32\"], rxplaceholder_9: tir.Buffer[(512,), \"float32\"], T_add_2: tir.Buffer[(1, 512, 7, 7), \"float32\"], T_multiply_3: tir.Buffer[(512,), \"float32\"], T_multiply_4: tir.Buffer[(512,), \"float32\"]) -> None:\n",
" # function attr dict\n",
" tir.func_attr({\"global_symbol\": \"batch_norm8\", \"tir.noalias\": True})\n",
" # body\n",
" # with tir.block(\"root\")\n",
" T_reshape_4 = tir.alloc_buffer([1, 512, 1, 1], dtype=\"float32\")\n",
" T_subtract_1 = tir.alloc_buffer([1, 512, 7, 7], dtype=\"float32\")\n",
" T_reshape_5 = tir.alloc_buffer([1, 512, 1, 1], dtype=\"float32\")\n",
" T_add_3 = tir.alloc_buffer([1, 512, 1, 1], dtype=\"float32\")\n",
" compute_1 = tir.alloc_buffer([1, 512, 1, 1], dtype=\"float32\")\n",
" T_divide_1 = tir.alloc_buffer([1, 512, 7, 7], dtype=\"float32\")\n",
" T_reshape_6 = tir.alloc_buffer([1, 512, 1, 1], dtype=\"float32\")\n",
" T_multiply_5 = tir.alloc_buffer([1, 512, 7, 7], dtype=\"float32\")\n",
" T_reshape_7 = tir.alloc_buffer([1, 512, 1, 1], dtype=\"float32\")\n",
" for i0, i1, i2, i3 in tir.grid(1, 512, 1, 1):\n",
" with tir.block(\"T_reshape\"):\n",
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n",
" tir.reads(rxplaceholder_8[(ax1 + ax2 + ax3) % 512])\n",
" tir.writes(T_reshape_4[ax0, ax1, ax2, ax3])\n",
" T_reshape_4[ax0, ax1, ax2, ax3] = rxplaceholder_8[(ax1 + ax2 + ax3) % 512]\n",
" for i0, i1, i2, i3 in tir.grid(1, 512, 7, 7):\n",
" with tir.block(\"T_subtract\"):\n",
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n",
" tir.reads(rxplaceholder_5[ax0, ax1, ax2, ax3], T_reshape_4[ax0, ax1, 0, 0])\n",
" tir.writes(T_subtract_1[ax0, ax1, ax2, ax3])\n",
" T_subtract_1[ax0, ax1, ax2, ax3] = rxplaceholder_5[ax0, ax1, ax2, ax3] - T_reshape_4[ax0, ax1, 0, 0]\n",
" for i0, i1, i2, i3 in tir.grid(1, 512, 1, 1):\n",
" with tir.block(\"T_reshape_1\"):\n",
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n",
" tir.reads(rxplaceholder_9[(ax1 + ax2 + ax3) % 512])\n",
" tir.writes(T_reshape_5[ax0, ax1, ax2, ax3])\n",
" T_reshape_5[ax0, ax1, ax2, ax3] = rxplaceholder_9[(ax1 + ax2 + ax3) % 512]\n",
" for i0, i1, i2, i3 in tir.grid(1, 512, 1, 1):\n",
" with tir.block(\"T_add\"):\n",
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n",
" tir.reads(T_reshape_5[ax0, ax1, ax2, ax3])\n",
" tir.writes(T_add_3[ax0, ax1, ax2, ax3])\n",
" T_add_3[ax0, ax1, ax2, ax3] = T_reshape_5[ax0, ax1, ax2, ax3] + tir.float32(1.9999999494757503e-05)\n",
" for i0, i1, i2, i3 in tir.grid(1, 512, 1, 1):\n",
" with tir.block(\"compute\"):\n",
" i0_2, i1_2, i2_2, i3_2 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n",
" tir.reads(T_add_3[i0_2, i1_2, i2_2, i3_2])\n",
" tir.writes(compute_1[i0_2, i1_2, i2_2, i3_2])\n",
" compute_1[i0_2, i1_2, i2_2, i3_2] = tir.sqrt(T_add_3[i0_2, i1_2, i2_2, i3_2], dtype=\"float32\")\n",
" for i0_3, i1_3, i2_3, i3_3 in tir.grid(1, 512, 7, 7):\n",
" with tir.block(\"T_divide\"):\n",
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0_3, i1_3, i2_3, i3_3])\n",
" tir.reads(T_subtract_1[ax0, ax1, ax2, ax3], compute_1[ax0, ax1, 0, 0])\n",
" tir.writes(T_divide_1[ax0, ax1, ax2, ax3])\n",
" T_divide_1[ax0, ax1, ax2, ax3] = T_subtract_1[ax0, ax1, ax2, ax3] / compute_1[ax0, ax1, 0, 0]\n",
" for i0_4, i1_4, i2_4, i3_4 in tir.grid(1, 512, 1, 1):\n",
" with tir.block(\"T_reshape_2\"):\n",
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0_4, i1_4, i2_4, i3_4])\n",
" tir.reads(rxplaceholder_6[(ax1 + ax2 + ax3) % 512])\n",
" tir.writes(T_reshape_6[ax0, ax1, ax2, ax3])\n",
" T_reshape_6[ax0, ax1, ax2, ax3] = rxplaceholder_6[(ax1 + ax2 + ax3) % 512]\n",
" for i0_5, i1_5, i2_5, i3_5 in tir.grid(1, 512, 7, 7):\n",
" with tir.block(\"T_multiply\"):\n",
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0_5, i1_5, i2_5, i3_5])\n",
" tir.reads(T_divide_1[ax0, ax1, ax2, ax3], T_reshape_6[ax0, ax1, 0, 0])\n",
" tir.writes(T_multiply_5[ax0, ax1, ax2, ax3])\n",
" T_multiply_5[ax0, ax1, ax2, ax3] = T_divide_1[ax0, ax1, ax2, ax3] * T_reshape_6[ax0, ax1, 0, 0]\n",
" for i0_6, i1_6, i2_6, i3_6 in tir.grid(1, 512, 1, 1):\n",
" with tir.block(\"T_reshape_3\"):\n",
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0_6, i1_6, i2_6, i3_6])\n",
" tir.reads(rxplaceholder_7[(ax1 + ax2 + ax3) % 512])\n",
" tir.writes(T_reshape_7[ax0, ax1, ax2, ax3])\n",
" T_reshape_7[ax0, ax1, ax2, ax3] = rxplaceholder_7[(ax1 + ax2 + ax3) % 512]\n",
" for i0_7, i1_7, i2_7, i3_7 in tir.grid(1, 512, 7, 7):\n",
" with tir.block(\"T_add_1\"):\n",
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0_7, i1_7, i2_7, i3_7])\n",
" tir.reads(T_multiply_5[ax0, ax1, ax2, ax3], T_reshape_7[ax0, ax1, 0, 0])\n",
" tir.writes(T_add_2[ax0, ax1, ax2, ax3])\n",
" T_add_2[ax0, ax1, ax2, ax3] = T_multiply_5[ax0, ax1, ax2, ax3] + T_reshape_7[ax0, ax1, 0, 0]\n",
" for i0_8 in tir.serial(512):\n",
" with tir.block(\"T_multiply_1\"):\n",
" ax0 = tir.axis.spatial(512, i0_8)\n",
" tir.reads(rxplaceholder_8[ax0])\n",
" tir.writes(T_multiply_3[ax0])\n",
" T_multiply_3[ax0] = rxplaceholder_8[ax0]\n",
" for i0_9 in tir.serial(512):\n",
" with tir.block(\"T_multiply_2\"):\n",
" ax0 = tir.axis.spatial(512, i0_9)\n",
" tir.reads(rxplaceholder_9[ax0])\n",
" tir.writes(T_multiply_4[ax0])\n",
" T_multiply_4[ax0] = rxplaceholder_9[ax0]\n",
" \n",
" @tir.prim_func\n",
" def relu3(rxplaceholder_1: tir.Buffer[(1, 128, 28, 28), \"float32\"], T_relu_1: tir.Buffer[(1, 128, 28, 28), \"float32\"]) -> None:\n",
" # function attr dict\n",
" tir.func_attr({\"global_symbol\": \"relu3\", \"tir.noalias\": True})\n",
" # body\n",
" # with tir.block(\"root\")\n",
" for i0, i1, i2, i3 in tir.grid(1, 128, 28, 28):\n",
" with tir.block(\"T_relu\"):\n",
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n",
" tir.reads(rxplaceholder_1[ax0, ax1, ax2, ax3])\n",
" tir.writes(T_relu_1[ax0, ax1, ax2, ax3])\n",
" T_relu_1[ax0, ax1, ax2, ax3] = tir.max(rxplaceholder_1[ax0, ax1, ax2, ax3], tir.float32(0))\n",
" \n",
" @tir.prim_func\n",
" def relu1(rxplaceholder_1: tir.Buffer[(1, 64, 56, 56), \"float32\"], T_relu_1: tir.Buffer[(1, 64, 56, 56), \"float32\"]) -> None:\n",
" # function attr dict\n",
" tir.func_attr({\"global_symbol\": \"relu1\", \"tir.noalias\": True})\n",
" # body\n",
" # with tir.block(\"root\")\n",
" for i0, i1, i2, i3 in tir.grid(1, 64, 56, 56):\n",
" with tir.block(\"T_relu\"):\n",
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n",
" tir.reads(rxplaceholder_1[ax0, ax1, ax2, ax3])\n",
" tir.writes(T_relu_1[ax0, ax1, ax2, ax3])\n",
" T_relu_1[ax0, ax1, ax2, ax3] = tir.max(rxplaceholder_1[ax0, ax1, ax2, ax3], tir.float32(0))\n",
" \n",
" @relax.function\n",
" def main(data: Tensor((1, 3, 224, 224), \"float32\"), bn_data_gamma: Tensor((3,), \"float32\"), bn_data_beta: Tensor((3,), \"float32\"), bn_data_moving_mean: Tensor((3,), \"float32\"), bn_data_moving_var: Tensor((3,), \"float32\"), conv0_weight: Tensor((64, 3, 7, 7), \"float32\"), bn0_gamma: Tensor((64,), \"float32\"), bn0_beta: Tensor((64,), \"float32\"), bn0_moving_mean: Tensor((64,), \"float32\"), bn0_moving_var: Tensor((64,), \"float32\"), stage1_unit1_bn1_gamma: Tensor((64,), \"float32\"), stage1_unit1_bn1_beta: Tensor((64,), \"float32\"), stage1_unit1_bn1_moving_mean: Tensor((64,), \"float32\"), stage1_unit1_bn1_moving_var: Tensor((64,), \"float32\"), stage1_unit1_conv1_weight: Tensor((64, 64, 1, 1), \"float32\"), stage1_unit1_bn2_gamma: Tensor((64,), \"float32\"), stage1_unit1_bn2_beta: Tensor((64,), \"float32\"), stage1_unit1_bn2_moving_mean: Tensor((64,), \"float32\"), stage1_unit1_bn2_moving_var: Tensor((64,), \"float32\"), stage1_unit1_conv2_weight: Tensor((64, 64, 3, 3), \"float32\"), stage1_unit1_bn3_gamma: Tensor((64,), \"float32\"), stage1_unit1_bn3_beta: Tensor((64,), \"float32\"), stage1_unit1_bn3_moving_mean: Tensor((64,), \"float32\"), stage1_unit1_bn3_moving_var: Tensor((64,), \"float32\"), stage1_unit1_conv3_weight: Tensor((256, 64, 1, 1), \"float32\"), stage1_unit1_sc_weight: Tensor((256, 64, 1, 1), \"float32\"), stage1_unit2_bn1_gamma: Tensor((256,), \"float32\"), stage1_unit2_bn1_beta: Tensor((256,), \"float32\"), stage1_unit2_bn1_moving_mean: Tensor((256,), \"float32\"), stage1_unit2_bn1_moving_var: Tensor((256,), \"float32\"), stage1_unit2_conv1_weight: Tensor((64, 256, 1, 1), \"float32\"), stage1_unit2_bn2_gamma: Tensor((64,), \"float32\"), stage1_unit2_bn2_beta: Tensor((64,), \"float32\"), stage1_unit2_bn2_moving_mean: Tensor((64,), \"float32\"), stage1_unit2_bn2_moving_var: Tensor((64,), \"float32\"), stage1_unit2_conv2_weight: Tensor((64, 64, 3, 3), \"float32\"), stage1_unit2_bn3_gamma: Tensor((64,), \"float32\"), stage1_unit2_bn3_beta: Tensor((64,), \"float32\"), stage1_unit2_bn3_moving_mean: Tensor((64,), \"float32\"), stage1_unit2_bn3_moving_var: Tensor((64,), \"float32\"), stage1_unit2_conv3_weight: Tensor((256, 64, 1, 1), \"float32\"), stage1_unit3_bn1_gamma: Tensor((256,), \"float32\"), stage1_unit3_bn1_beta: Tensor((256,), \"float32\"), stage1_unit3_bn1_moving_mean: Tensor((256,), \"float32\"), stage1_unit3_bn1_moving_var: Tensor((256,), \"float32\"), stage1_unit3_conv1_weight: Tensor((64, 256, 1, 1), \"float32\"), stage1_unit3_bn2_gamma: Tensor((64,), \"float32\"), stage1_unit3_bn2_beta: Tensor((64,), \"float32\"), stage1_unit3_bn2_moving_mean: Tensor((64,), \"float32\"), stage1_unit3_bn2_moving_var: Tensor((64,), \"float32\"), stage1_unit3_conv2_weight: Tensor((64, 64, 3, 3), \"float32\"), stage1_unit3_bn3_gamma: Tensor((64,), \"float32\"), stage1_unit3_bn3_beta: Tensor((64,), \"float32\"), stage1_unit3_bn3_moving_mean: Tensor((64,), \"float32\"), stage1_unit3_bn3_moving_var: Tensor((64,), \"float32\"), stage1_unit3_conv3_weight: Tensor((256, 64, 1, 1), \"float32\"), stage2_unit1_bn1_gamma: Tensor((256,), \"float32\"), stage2_unit1_bn1_beta: Tensor((256,), \"float32\"), stage2_unit1_bn1_moving_mean: Tensor((256,), \"float32\"), stage2_unit1_bn1_moving_var: Tensor((256,), \"float32\"), stage2_unit1_conv1_weight: Tensor((128, 256, 1, 1), \"float32\"), stage2_unit1_bn2_gamma: Tensor((128,), \"float32\"), stage2_unit1_bn2_beta: Tensor((128,), \"float32\"), stage2_unit1_bn2_moving_mean: Tensor((128,), \"float32\"), stage2_unit1_bn2_moving_var: 
Tensor((128,), \"float32\"), stage2_unit1_conv2_weight: Tensor((128, 128, 3, 3), \"float32\"), stage2_unit1_bn3_gamma: Tensor((128,), \"float32\"), stage2_unit1_bn3_beta: Tensor((128,), \"float32\"), stage2_unit1_bn3_moving_mean: Tensor((128,), \"float32\"), stage2_unit1_bn3_moving_var: Tensor((128,), \"float32\"), stage2_unit1_conv3_weight: Tensor((512, 128, 1, 1), \"float32\"), stage2_unit1_sc_weight: Tensor((512, 256, 1, 1), \"float32\"), stage2_unit2_bn1_gamma: Tensor((512,), \"float32\"), stage2_unit2_bn1_beta: Tensor((512,), \"float32\"), stage2_unit2_bn1_moving_mean: Tensor((512,), \"float32\"), stage2_unit2_bn1_moving_var: Tensor((512,), \"float32\"), stage2_unit2_conv1_weight: Tensor((128, 512, 1, 1), \"float32\"), stage2_unit2_bn2_gamma: Tensor((128,), \"float32\"), stage2_unit2_bn2_beta: Tensor((128,), \"float32\"), stage2_unit2_bn2_moving_mean: Tensor((128,), \"float32\"), stage2_unit2_bn2_moving_var: Tensor((128,), \"float32\"), stage2_unit2_conv2_weight: Tensor((128, 128, 3, 3), \"float32\"), stage2_unit2_bn3_gamma: Tensor((128,), \"float32\"), stage2_unit2_bn3_beta: Tensor((128,), \"float32\"), stage2_unit2_bn3_moving_mean: Tensor((128,), \"float32\"), stage2_unit2_bn3_moving_var: Tensor((128,), \"float32\"), stage2_unit2_conv3_weight: Tensor((512, 128, 1, 1), \"float32\"), stage2_unit3_bn1_gamma: Tensor((512,), \"float32\"), stage2_unit3_bn1_beta: Tensor((512,), \"float32\"), stage2_unit3_bn1_moving_mean: Tensor((512,), \"float32\"), stage2_unit3_bn1_moving_var: Tensor((512,), \"float32\"), stage2_unit3_conv1_weight: Tensor((128, 512, 1, 1), \"float32\"), stage2_unit3_bn2_gamma: Tensor((128,), \"float32\"), stage2_unit3_bn2_beta: Tensor((128,), \"float32\"), stage2_unit3_bn2_moving_mean: Tensor((128,), \"float32\"), stage2_unit3_bn2_moving_var: Tensor((128,), \"float32\"), stage2_unit3_conv2_weight: Tensor((128, 128, 3, 3), \"float32\"), stage2_unit3_bn3_gamma: Tensor((128,), \"float32\"), stage2_unit3_bn3_beta: Tensor((128,), \"float32\"), stage2_unit3_bn3_moving_mean: Tensor((128,), \"float32\"), stage2_unit3_bn3_moving_var: Tensor((128,), \"float32\"), stage2_unit3_conv3_weight: Tensor((512, 128, 1, 1), \"float32\"), stage2_unit4_bn1_gamma: Tensor((512,), \"float32\"), stage2_unit4_bn1_beta: Tensor((512,), \"float32\"), stage2_unit4_bn1_moving_mean: Tensor((512,), \"float32\"), stage2_unit4_bn1_moving_var: Tensor((512,), \"float32\"), stage2_unit4_conv1_weight: Tensor((128, 512, 1, 1), \"float32\"), stage2_unit4_bn2_gamma: Tensor((128,), \"float32\"), stage2_unit4_bn2_beta: Tensor((128,), \"float32\"), stage2_unit4_bn2_moving_mean: Tensor((128,), \"float32\"), stage2_unit4_bn2_moving_var: Tensor((128,), \"float32\"), stage2_unit4_conv2_weight: Tensor((128, 128, 3, 3), \"float32\"), stage2_unit4_bn3_gamma: Tensor((128,), \"float32\"), stage2_unit4_bn3_beta: Tensor((128,), \"float32\"), stage2_unit4_bn3_moving_mean: Tensor((128,), \"float32\"), stage2_unit4_bn3_moving_var: Tensor((128,), \"float32\"), stage2_unit4_conv3_weight: Tensor((512, 128, 1, 1), \"float32\"), stage3_unit1_bn1_gamma: Tensor((512,), \"float32\"), stage3_unit1_bn1_beta: Tensor((512,), \"float32\"), stage3_unit1_bn1_moving_mean: Tensor((512,), \"float32\"), stage3_unit1_bn1_moving_var: Tensor((512,), \"float32\"), stage3_unit1_conv1_weight: Tensor((256, 512, 1, 1), \"float32\"), stage3_unit1_bn2_gamma: Tensor((256,), \"float32\"), stage3_unit1_bn2_beta: Tensor((256,), \"float32\"), stage3_unit1_bn2_moving_mean: Tensor((256,), \"float32\"), stage3_unit1_bn2_moving_var: Tensor((256,), \"float32\"), 
stage3_unit1_conv2_weight: Tensor((256, 256, 3, 3), \"float32\"), stage3_unit1_bn3_gamma: Tensor((256,), \"float32\"), stage3_unit1_bn3_beta: Tensor((256,), \"float32\"), stage3_unit1_bn3_moving_mean: Tensor((256,), \"float32\"), stage3_unit1_bn3_moving_var: Tensor((256,), \"float32\"), stage3_unit1_conv3_weight: Tensor((1024, 256, 1, 1), \"float32\"), stage3_unit1_sc_weight: Tensor((1024, 512, 1, 1), \"float32\"), stage3_unit2_bn1_gamma: Tensor((1024,), \"float32\"), stage3_unit2_bn1_beta: Tensor((1024,), \"float32\"), stage3_unit2_bn1_moving_mean: Tensor((1024,), \"float32\"), stage3_unit2_bn1_moving_var: Tensor((1024,), \"float32\"), stage3_unit2_conv1_weight: Tensor((256, 1024, 1, 1), \"float32\"), stage3_unit2_bn2_gamma: Tensor((256,), \"float32\"), stage3_unit2_bn2_beta: Tensor((256,), \"float32\"), stage3_unit2_bn2_moving_mean: Tensor((256,), \"float32\"), stage3_unit2_bn2_moving_var: Tensor((256,), \"float32\"), stage3_unit2_conv2_weight: Tensor((256, 256, 3, 3), \"float32\"), stage3_unit2_bn3_gamma: Tensor((256,), \"float32\"), stage3_unit2_bn3_beta: Tensor((256,), \"float32\"), stage3_unit2_bn3_moving_mean: Tensor((256,), \"float32\"), stage3_unit2_bn3_moving_var: Tensor((256,), \"float32\"), stage3_unit2_conv3_weight: Tensor((1024, 256, 1, 1), \"float32\"), stage3_unit3_bn1_gamma: Tensor((1024,), \"float32\"), stage3_unit3_bn1_beta: Tensor((1024,), \"float32\"), stage3_unit3_bn1_moving_mean: Tensor((1024,), \"float32\"), stage3_unit3_bn1_moving_var: Tensor((1024,), \"float32\"), stage3_unit3_conv1_weight: Tensor((256, 1024, 1, 1), \"float32\"), stage3_unit3_bn2_gamma: Tensor((256,), \"float32\"), stage3_unit3_bn2_beta: Tensor((256,), \"float32\"), stage3_unit3_bn2_moving_mean: Tensor((256,), \"float32\"), stage3_unit3_bn2_moving_var: Tensor((256,), \"float32\"), stage3_unit3_conv2_weight: Tensor((256, 256, 3, 3), \"float32\"), stage3_unit3_bn3_gamma: Tensor((256,), \"float32\"), stage3_unit3_bn3_beta: Tensor((256,), \"float32\"), stage3_unit3_bn3_moving_mean: Tensor((256,), \"float32\"), stage3_unit3_bn3_moving_var: Tensor((256,), \"float32\"), stage3_unit3_conv3_weight: Tensor((1024, 256, 1, 1), \"float32\"), stage3_unit4_bn1_gamma: Tensor((1024,), \"float32\"), stage3_unit4_bn1_beta: Tensor((1024,), \"float32\"), stage3_unit4_bn1_moving_mean: Tensor((1024,), \"float32\"), stage3_unit4_bn1_moving_var: Tensor((1024,), \"float32\"), stage3_unit4_conv1_weight: Tensor((256, 1024, 1, 1), \"float32\"), stage3_unit4_bn2_gamma: Tensor((256,), \"float32\"), stage3_unit4_bn2_beta: Tensor((256,), \"float32\"), stage3_unit4_bn2_moving_mean: Tensor((256,), \"float32\"), stage3_unit4_bn2_moving_var: Tensor((256,), \"float32\"), stage3_unit4_conv2_weight: Tensor((256, 256, 3, 3), \"float32\"), stage3_unit4_bn3_gamma: Tensor((256,), \"float32\"), stage3_unit4_bn3_beta: Tensor((256,), \"float32\"), stage3_unit4_bn3_moving_mean: Tensor((256,), \"float32\"), stage3_unit4_bn3_moving_var: Tensor((256,), \"float32\"), stage3_unit4_conv3_weight: Tensor((1024, 256, 1, 1), \"float32\"), stage3_unit5_bn1_gamma: Tensor((1024,), \"float32\"), stage3_unit5_bn1_beta: Tensor((1024,), \"float32\"), stage3_unit5_bn1_moving_mean: Tensor((1024,), \"float32\"), stage3_unit5_bn1_moving_var: Tensor((1024,), \"float32\"), stage3_unit5_conv1_weight: Tensor((256, 1024, 1, 1), \"float32\"), stage3_unit5_bn2_gamma: Tensor((256,), \"float32\"), stage3_unit5_bn2_beta: Tensor((256,), \"float32\"), stage3_unit5_bn2_moving_mean: Tensor((256,), \"float32\"), stage3_unit5_bn2_moving_var: Tensor((256,), \"float32\"), 
stage3_unit5_conv2_weight: Tensor((256, 256, 3, 3), \"float32\"), stage3_unit5_bn3_gamma: Tensor((256,), \"float32\"), stage3_unit5_bn3_beta: Tensor((256,), \"float32\"), stage3_unit5_bn3_moving_mean: Tensor((256,), \"float32\"), stage3_unit5_bn3_moving_var: Tensor((256,), \"float32\"), stage3_unit5_conv3_weight: Tensor((1024, 256, 1, 1), \"float32\"), stage3_unit6_bn1_gamma: Tensor((1024,), \"float32\"), stage3_unit6_bn1_beta: Tensor((1024,), \"float32\"), stage3_unit6_bn1_moving_mean: Tensor((1024,), \"float32\"), stage3_unit6_bn1_moving_var: Tensor((1024,), \"float32\"), stage3_unit6_conv1_weight: Tensor((256, 1024, 1, 1), \"float32\"), stage3_unit6_bn2_gamma: Tensor((256,), \"float32\"), stage3_unit6_bn2_beta: Tensor((256,), \"float32\"), stage3_unit6_bn2_moving_mean: Tensor((256,), \"float32\"), stage3_unit6_bn2_moving_var: Tensor((256,), \"float32\"), stage3_unit6_conv2_weight: Tensor((256, 256, 3, 3), \"float32\"), stage3_unit6_bn3_gamma: Tensor((256,), \"float32\"), stage3_unit6_bn3_beta: Tensor((256,), \"float32\"), stage3_unit6_bn3_moving_mean: Tensor((256,), \"float32\"), stage3_unit6_bn3_moving_var: Tensor((256,), \"float32\"), stage3_unit6_conv3_weight: Tensor((1024, 256, 1, 1), \"float32\"), stage4_unit1_bn1_gamma: Tensor((1024,), \"float32\"), stage4_unit1_bn1_beta: Tensor((1024,), \"float32\"), stage4_unit1_bn1_moving_mean: Tensor((1024,), \"float32\"), stage4_unit1_bn1_moving_var: Tensor((1024,), \"float32\"), stage4_unit1_conv1_weight: Tensor((512, 1024, 1, 1), \"float32\"), stage4_unit1_bn2_gamma: Tensor((512,), \"float32\"), stage4_unit1_bn2_beta: Tensor((512,), \"float32\"), stage4_unit1_bn2_moving_mean: Tensor((512,), \"float32\"), stage4_unit1_bn2_moving_var: Tensor((512,), \"float32\"), stage4_unit1_conv2_weight: Tensor((512, 512, 3, 3), \"float32\"), stage4_unit1_bn3_gamma: Tensor((512,), \"float32\"), stage4_unit1_bn3_beta: Tensor((512,), \"float32\"), stage4_unit1_bn3_moving_mean: Tensor((512,), \"float32\"), stage4_unit1_bn3_moving_var: Tensor((512,), \"float32\"), stage4_unit1_conv3_weight: Tensor((2048, 512, 1, 1), \"float32\"), stage4_unit1_sc_weight: Tensor((2048, 1024, 1, 1), \"float32\"), stage4_unit2_bn1_gamma: Tensor((2048,), \"float32\"), stage4_unit2_bn1_beta: Tensor((2048,), \"float32\"), stage4_unit2_bn1_moving_mean: Tensor((2048,), \"float32\"), stage4_unit2_bn1_moving_var: Tensor((2048,), \"float32\"), stage4_unit2_conv1_weight: Tensor((512, 2048, 1, 1), \"float32\"), stage4_unit2_bn2_gamma: Tensor((512,), \"float32\"), stage4_unit2_bn2_beta: Tensor((512,), \"float32\"), stage4_unit2_bn2_moving_mean: Tensor((512,), \"float32\"), stage4_unit2_bn2_moving_var: Tensor((512,), \"float32\"), stage4_unit2_conv2_weight: Tensor((512, 512, 3, 3), \"float32\"), stage4_unit2_bn3_gamma: Tensor((512,), \"float32\"), stage4_unit2_bn3_beta: Tensor((512,), \"float32\"), stage4_unit2_bn3_moving_mean: Tensor((512,), \"float32\"), stage4_unit2_bn3_moving_var: Tensor((512,), \"float32\"), stage4_unit2_conv3_weight: Tensor((2048, 512, 1, 1), \"float32\"), stage4_unit3_bn1_gamma: Tensor((2048,), \"float32\"), stage4_unit3_bn1_beta: Tensor((2048,), \"float32\"), stage4_unit3_bn1_moving_mean: Tensor((2048,), \"float32\"), stage4_unit3_bn1_moving_var: Tensor((2048,), \"float32\"), stage4_unit3_conv1_weight: Tensor((512, 2048, 1, 1), \"float32\"), stage4_unit3_bn2_gamma: Tensor((512,), \"float32\"), stage4_unit3_bn2_beta: Tensor((512,), \"float32\"), stage4_unit3_bn2_moving_mean: Tensor((512,), \"float32\"), stage4_unit3_bn2_moving_var: Tensor((512,), \"float32\"), 
stage4_unit3_conv2_weight: Tensor((512, 512, 3, 3), \"float32\"), stage4_unit3_bn3_gamma: Tensor((512,), \"float32\"), stage4_unit3_bn3_beta: Tensor((512,), \"float32\"), stage4_unit3_bn3_moving_mean: Tensor((512,), \"float32\"), stage4_unit3_bn3_moving_var: Tensor((512,), \"float32\"), stage4_unit3_conv3_weight: Tensor((2048, 512, 1, 1), \"float32\"), bn1_gamma: Tensor((2048,), \"float32\"), bn1_beta: Tensor((2048,), \"float32\"), bn1_moving_mean: Tensor((2048,), \"float32\"), bn1_moving_var: Tensor((2048,), \"float32\"), fc1_weight: Tensor((1000, 2048), \"float32\"), fc1_bias: Tensor((1000,), \"float32\")) -> Tensor(_, \"float32\"):\n",
" # block 0\n",
" gv = relax.call_tir(batch_norm, (data, bn_data_gamma, bn_data_beta, bn_data_moving_mean, bn_data_moving_var), ((1, 3, 224, 224), (3,), (3,)), dtype=(\"float32\", \"float32\", \"float32\"))\n",
" gv1: Tensor((1, 3, 224, 224), \"float32\") = gv[0]\n",
" gv2 = relax.call_tir(conv2d_nchw, (gv1, conv0_weight), (1, 64, 112, 112), dtype=\"float32\")\n",
" gv3 = relax.call_tir(batch_norm1, (gv2, bn0_gamma, bn0_beta, bn0_moving_mean, bn0_moving_var), ((1, 64, 112, 112), (64,), (64,)), dtype=(\"float32\", \"float32\", \"float32\"))\n",
" gv4: Tensor((1, 64, 112, 112), \"float32\") = gv3[0]\n",
" gv5 = relax.call_tir(relu, (gv4,), (1, 64, 112, 112), dtype=\"float32\")\n",
" gv6 = relax.call_tir(max_pool2d, (gv5,), (1, 64, 56, 56), dtype=\"float32\")\n",
" gv7 = relax.call_tir(batch_norm2, (gv6, stage1_unit1_bn1_gamma, stage1_unit1_bn1_beta, stage1_unit1_bn1_moving_mean, stage1_unit1_bn1_moving_var), ((1, 64, 56, 56), (64,), (64,)), dtype=(\"float32\", \"float32\", \"float32\"))\n",
" gv8: Tensor((1, 64, 56, 56), \"float32\") = gv7[0]\n",
" gv9 = relax.call_tir(relu1, (gv8,), (1, 64, 56, 56), dtype=\"float32\")\n",
" gv10 = relax.call_tir(conv2d_nchw1, (gv9, stage1_unit1_conv1_weight), (1, 64, 56, 56), dtype=\"float32\")\n",
" gv11 = relax.call_tir(batch_norm2, (gv10, stage1_unit1_bn2_gamma, stage1_unit1_bn2_beta, stage1_unit1_bn2_moving_mean, stage1_unit1_bn2_moving_var), ((1, 64, 56, 56), (64,), (64,)), dtype=(\"float32\", \"float32\", \"float32\"))\n",
" gv12: Tensor((1, 64, 56, 56), \"float32\") = gv11[0]\n",
" gv13 = relax.call_tir(relu1, (gv12,), (1, 64, 56, 56), dtype=\"float32\")\n",
" gv14 = relax.call_tir(conv2d_nchw2, (gv13, stage1_unit1_conv2_weight), (1, 64, 56, 56), dtype=\"float32\")\n",
" gv15 = relax.call_tir(batch_norm2, (gv14, stage1_unit1_bn3_gamma, stage1_unit1_bn3_beta, stage1_unit1_bn3_moving_mean, stage1_unit1_bn3_moving_var), ((1, 64, 56, 56), (64,), (64,)), dtype=(\"float32\", \"float32\", \"float32\"))\n",
" gv16: Tensor((1, 64, 56, 56), \"float32\") = gv15[0]\n",
" gv17 = relax.call_tir(relu1, (gv16,), (1, 64, 56, 56), dtype=\"float32\")\n",
" gv18 = relax.call_tir(conv2d_nchw3, (gv17, stage1_unit1_conv3_weight), (1, 256, 56, 56), dtype=\"float32\")\n",
" gv19 = relax.call_tir(conv2d_nchw3, (gv9, stage1_unit1_sc_weight), (1, 256, 56, 56), dtype=\"float32\")\n",
" gv20 = relax.call_tir(add, (gv18, gv19), (1, 256, 56, 56), dtype=\"float32\")\n",
" gv21 = relax.call_tir(batch_norm3, (gv20, stage1_unit2_bn1_gamma, stage1_unit2_bn1_beta, stage1_unit2_bn1_moving_mean, stage1_unit2_bn1_moving_var), ((1, 256, 56, 56), (256,), (256,)), dtype=(\"float32\", \"float32\", \"float32\"))\n",
" gv22: Tensor((1, 256, 56, 56), \"float32\") = gv21[0]\n",
" gv23 = relax.call_tir(relu2, (gv22,), (1, 256, 56, 56), dtype=\"float32\")\n",
" gv24 = relax.call_tir(conv2d_nchw4, (gv23, stage1_unit2_conv1_weight), (1, 64, 56, 56), dtype=\"float32\")\n",
" gv25 = relax.call_tir(batch_norm2, (gv24, stage1_unit2_bn2_gamma, stage1_unit2_bn2_beta, stage1_unit2_bn2_moving_mean, stage1_unit2_bn2_moving_var), ((1, 64, 56, 56), (64,), (64,)), dtype=(\"float32\", \"float32\", \"float32\"))\n",
" gv26: Tensor((1, 64, 56, 56), \"float32\") = gv25[0]\n",
" gv27 = relax.call_tir(relu1, (gv26,), (1, 64, 56, 56), dtype=\"float32\")\n",
" gv28 = relax.call_tir(conv2d_nchw2, (gv27, stage1_unit2_conv2_weight), (1, 64, 56, 56), dtype=\"float32\")\n",
" gv29 = relax.call_tir(batch_norm2, (gv28, stage1_unit2_bn3_gamma, stage1_unit2_bn3_beta, stage1_unit2_bn3_moving_mean, stage1_unit2_bn3_moving_var), ((1, 64, 56, 56), (64,), (64,)), dtype=(\"float32\", \"float32\", \"float32\"))\n",
" gv30: Tensor((1, 64, 56, 56), \"float32\") = gv29[0]\n",
" gv31 = relax.call_tir(relu1, (gv30,), (1, 64, 56, 56), dtype=\"float32\")\n",
" gv32 = relax.call_tir(conv2d_nchw3, (gv31, stage1_unit2_conv3_weight), (1, 256, 56, 56), dtype=\"float32\")\n",
" gv33 = relax.call_tir(add, (gv32, gv20), (1, 256, 56, 56), dtype=\"float32\")\n",
" gv34 = relax.call_tir(batch_norm3, (gv33, stage1_unit3_bn1_gamma, stage1_unit3_bn1_beta, stage1_unit3_bn1_moving_mean, stage1_unit3_bn1_moving_var), ((1, 256, 56, 56), (256,), (256,)), dtype=(\"float32\", \"float32\", \"float32\"))\n",
" gv35: Tensor((1, 256, 56, 56), \"float32\") = gv34[0]\n",
" gv36 = relax.call_tir(relu2, (gv35,), (1, 256, 56, 56), dtype=\"float32\")\n",
" gv37 = relax.call_tir(conv2d_nchw4, (gv36, stage1_unit3_conv1_weight), (1, 64, 56, 56), dtype=\"float32\")\n",
" gv38 = relax.call_tir(batch_norm2, (gv37, stage1_unit3_bn2_gamma, stage1_unit3_bn2_beta, stage1_unit3_bn2_moving_mean, stage1_unit3_bn2_moving_var), ((1, 64, 56, 56), (64,), (64,)), dtype=(\"float32\", \"float32\", \"float32\"))\n",
" gv39: Tensor((1, 64, 56, 56), \"float32\") = gv38[0]\n",
" gv40 = relax.call_tir(relu1, (gv39,), (1, 64, 56, 56), dtype=\"float32\")\n",
" gv41 = relax.call_tir(conv2d_nchw2, (gv40, stage1_unit3_conv2_weight), (1, 64, 56, 56), dtype=\"float32\")\n",
" gv42 = relax.call_tir(batch_norm2, (gv41, stage1_unit3_bn3_gamma, stage1_unit3_bn3_beta, stage1_unit3_bn3_moving_mean, stage1_unit3_bn3_moving_var), ((1, 64, 56, 56), (64,), (64,)), dtype=(\"float32\", \"float32\", \"float32\"))\n",
" gv43: Tensor((1, 64, 56, 56), \"float32\") = gv42[0]\n",
" gv44 = relax.call_tir(relu1, (gv43,), (1, 64, 56, 56), dtype=\"float32\")\n",
" gv45 = relax.call_tir(conv2d_nchw3, (gv44, stage1_unit3_conv3_weight), (1, 256, 56, 56), dtype=\"float32\")\n",
" gv46 = relax.call_tir(add, (gv45, gv33), (1, 256, 56, 56), dtype=\"float32\")\n",
" gv47 = relax.call_tir(batch_norm3, (gv46, stage2_unit1_bn1_gamma, stage2_unit1_bn1_beta, stage2_unit1_bn1_moving_mean, stage2_unit1_bn1_moving_var), ((1, 256, 56, 56), (256,), (256,)), dtype=(\"float32\", \"float32\", \"float32\"))\n",
" gv48: Tensor((1, 256, 56, 56), \"float32\") = gv47[0]\n",
" gv49 = relax.call_tir(relu2, (gv48,), (1, 256, 56, 56), dtype=\"float32\")\n",
" gv50 = relax.call_tir(conv2d_nchw5, (gv49, stage2_unit1_conv1_weight), (1, 128, 28, 28), dtype=\"float32\")\n",
" gv51 = relax.call_tir(batch_norm4, (gv50, stage2_unit1_bn2_gamma, stage2_unit1_bn2_beta, stage2_unit1_bn2_moving_mean, stage2_unit1_bn2_moving_var), ((1, 128, 28, 28), (128,), (128,)), dtype=(\"float32\", \"float32\", \"float32\"))\n",
" gv52: Tensor((1, 128, 28, 28), \"float32\") = gv51[0]\n",
" gv53 = relax.call_tir(relu3, (gv52,), (1, 128, 28, 28), dtype=\"float32\")\n",
" gv54 = relax.call_tir(conv2d_nchw6, (gv53, stage2_unit1_conv2_weight), (1, 128, 28, 28), dtype=\"float32\")\n",
" gv55 = relax.call_tir(batch_norm4, (gv54, stage2_unit1_bn3_gamma, stage2_unit1_bn3_beta, stage2_unit1_bn3_moving_mean, stage2_unit1_bn3_moving_var), ((1, 128, 28, 28), (128,), (128,)), dtype=(\"float32\", \"float32\", \"float32\"))\n",
" gv56: Tensor((1, 128, 28, 28), \"float32\") = gv55[0]\n",
" gv57 = relax.call_tir(relu3, (gv56,), (1, 128, 28, 28), dtype=\"float32\")\n",
" gv58 = relax.call_tir(conv2d_nchw7, (gv57, stage2_unit1_conv3_weight), (1, 512, 28, 28), dtype=\"float32\")\n",
" gv59 = relax.call_tir(conv2d_nchw8, (gv49, stage2_unit1_sc_weight), (1, 512, 28, 28), dtype=\"float32\")\n",
" gv60 = relax.call_tir(add1, (gv58, gv59), (1, 512, 28, 28), dtype=\"float32\")\n",
" gv61 = relax.call_tir(batch_norm5, (gv60, stage2_unit2_bn1_gamma, stage2_unit2_bn1_beta, stage2_unit2_bn1_moving_mean, stage2_unit2_bn1_moving_var), ((1, 512, 28, 28), (512,), (512,)), dtype=(\"float32\", \"float32\", \"float32\"))\n",
" gv62: Tensor((1, 512, 28, 28), \"float32\") = gv61[0]\n",
" gv63 = relax.call_tir(relu4, (gv62,), (1, 512, 28, 28), dtype=\"float32\")\n",
" gv64 = relax.call_tir(conv2d_nchw9, (gv63, stage2_unit2_conv1_weight), (1, 128, 28, 28), dtype=\"float32\")\n",
" gv65 = relax.call_tir(batch_norm4, (gv64, stage2_unit2_bn2_gamma, stage2_unit2_bn2_beta, stage2_unit2_bn2_moving_mean, stage2_unit2_bn2_moving_var), ((1, 128, 28, 28), (128,), (128,)), dtype=(\"float32\", \"float32\", \"float32\"))\n",
" gv66: Tensor((1, 128, 28, 28), \"float32\") = gv65[0]\n",
" gv67 = relax.call_tir(relu3, (gv66,), (1, 128, 28, 28), dtype=\"float32\")\n",
" gv68 = relax.call_tir(conv2d_nchw6, (gv67, stage2_unit2_conv2_weight), (1, 128, 28, 28), dtype=\"float32\")\n",
" gv69 = relax.call_tir(batch_norm4, (gv68, stage2_unit2_bn3_gamma, stage2_unit2_bn3_beta, stage2_unit2_bn3_moving_mean, stage2_unit2_bn3_moving_var), ((1, 128, 28, 28), (128,), (128,)), dtype=(\"float32\", \"float32\", \"float32\"))\n",
" gv70: Tensor((1, 128, 28, 28), \"float32\") = gv69[0]\n",
" gv71 = relax.call_tir(relu3, (gv70,), (1, 128, 28, 28), dtype=\"float32\")\n",
" gv72 = relax.call_tir(conv2d_nchw7, (gv71, stage2_unit2_conv3_weight), (1, 512, 28, 28), dtype=\"float32\")\n",
" gv73 = relax.call_tir(add1, (gv72, gv60), (1, 512, 28, 28), dtype=\"float32\")\n",
" gv74 = relax.call_tir(batch_norm5, (gv73, stage2_unit3_bn1_gamma, stage2_unit3_bn1_beta, stage2_unit3_bn1_moving_mean, stage2_unit3_bn1_moving_var), ((1, 512, 28, 28), (512,), (512,)), dtype=(\"float32\", \"float32\", \"float32\"))\n",
" gv75: Tensor((1, 512, 28, 28), \"float32\") = gv74[0]\n",
" gv76 = relax.call_tir(relu4, (gv75,), (1, 512, 28, 28), dtype=\"float32\")\n",
" gv77 = relax.call_tir(conv2d_nchw9, (gv76, stage2_unit3_conv1_weight), (1, 128, 28, 28), dtype=\"float32\")\n",
" gv78 = relax.call_tir(batch_norm4, (gv77, stage2_unit3_bn2_gamma, stage2_unit3_bn2_beta, stage2_unit3_bn2_moving_mean, stage2_unit3_bn2_moving_var), ((1, 128, 28, 28), (128,), (128,)), dtype=(\"float32\", \"float32\", \"float32\"))\n",
" gv79: Tensor((1, 128, 28, 28), \"float32\") = gv78[0]\n",
" gv80 = relax.call_tir(relu3, (gv79,), (1, 128, 28, 28), dtype=\"float32\")\n",
" gv81 = relax.call_tir(conv2d_nchw6, (gv80, stage2_unit3_conv2_weight), (1, 128, 28, 28), dtype=\"float32\")\n",
" gv82 = relax.call_tir(batch_norm4, (gv81, stage2_unit3_bn3_gamma, stage2_unit3_bn3_beta, stage2_unit3_bn3_moving_mean, stage2_unit3_bn3_moving_var), ((1, 128, 28, 28), (128,), (128,)), dtype=(\"float32\", \"float32\", \"float32\"))\n",
" gv83: Tensor((1, 128, 28, 28), \"float32\") = gv82[0]\n",
" gv84 = relax.call_tir(relu3, (gv83,), (1, 128, 28, 28), dtype=\"float32\")\n",
" gv85 = relax.call_tir(conv2d_nchw7, (gv84, stage2_unit3_conv3_weight), (1, 512, 28, 28), dtype=\"float32\")\n",
" gv86 = relax.call_tir(add1, (gv85, gv73), (1, 512, 28, 28), dtype=\"float32\")\n",
" gv87 = relax.call_tir(batch_norm5, (gv86, stage2_unit4_bn1_gamma, stage2_unit4_bn1_beta, stage2_unit4_bn1_moving_mean, stage2_unit4_bn1_moving_var), ((1, 512, 28, 28), (512,), (512,)), dtype=(\"float32\", \"float32\", \"float32\"))\n",
" gv88: Tensor((1, 512, 28, 28), \"float32\") = gv87[0]\n",
" gv89 = relax.call_tir(relu4, (gv88,), (1, 512, 28, 28), dtype=\"float32\")\n",
" gv90 = relax.call_tir(conv2d_nchw9, (gv89, stage2_unit4_conv1_weight), (1, 128, 28, 28), dtype=\"float32\")\n",
" gv91 = relax.call_tir(batch_norm4, (gv90, stage2_unit4_bn2_gamma, stage2_unit4_bn2_beta, stage2_unit4_bn2_moving_mean, stage2_unit4_bn2_moving_var), ((1, 128, 28, 28), (128,), (128,)), dtype=(\"float32\", \"float32\", \"float32\"))\n",
" gv92: Tensor((1, 128, 28, 28), \"float32\") = gv91[0]\n",
" gv93 = relax.call_tir(relu3, (gv92,), (1, 128, 28, 28), dtype=\"float32\")\n",
" gv94 = relax.call_tir(conv2d_nchw6, (gv93, stage2_unit4_conv2_weight), (1, 128, 28, 28), dtype=\"float32\")\n",
" gv95 = relax.call_tir(batch_norm4, (gv94, stage2_unit4_bn3_gamma, stage2_unit4_bn3_beta, stage2_unit4_bn3_moving_mean, stage2_unit4_bn3_moving_var), ((1, 128, 28, 28), (128,), (128,)), dtype=(\"float32\", \"float32\", \"float32\"))\n",
" gv96: Tensor((1, 128, 28, 28), \"float32\") = gv95[0]\n",
" gv97 = relax.call_tir(relu3, (gv96,), (1, 128, 28, 28), dtype=\"float32\")\n",
" gv98 = relax.call_tir(conv2d_nchw7, (gv97, stage2_unit4_conv3_weight), (1, 512, 28, 28), dtype=\"float32\")\n",
" gv99 = relax.call_tir(add1, (gv98, gv86), (1, 512, 28, 28), dtype=\"float32\")\n",
" gv100 = relax.call_tir(batch_norm5, (gv99, stage3_unit1_bn1_gamma, stage3_unit1_bn1_beta, stage3_unit1_bn1_moving_mean, stage3_unit1_bn1_moving_var), ((1, 512, 28, 28), (512,), (512,)), dtype=(\"float32\", \"float32\", \"float32\"))\n",
" gv101: Tensor((1, 512, 28, 28), \"float32\") = gv100[0]\n",
" gv102 = relax.call_tir(relu4, (gv101,), (1, 512, 28, 28), dtype=\"float32\")\n",
" gv103 = relax.call_tir(conv2d_nchw10, (gv102, stage3_unit1_conv1_weight), (1, 256, 14, 14), dtype=\"float32\")\n",
" gv104 = relax.call_tir(batch_norm6, (gv103, stage3_unit1_bn2_gamma, stage3_unit1_bn2_beta, stage3_unit1_bn2_moving_mean, stage3_unit1_bn2_moving_var), ((1, 256, 14, 14), (256,), (256,)), dtype=(\"float32\", \"float32\", \"float32\"))\n",
" gv105: Tensor((1, 256, 14, 14), \"float32\") = gv104[0]\n",
" gv106 = relax.call_tir(relu5, (gv105,), (1, 256, 14, 14), dtype=\"float32\")\n",
" gv107 = relax.call_tir(conv2d_nchw11, (gv106, stage3_unit1_conv2_weight), (1, 256, 14, 14), dtype=\"float32\")\n",
" gv108 = relax.call_tir(batch_norm6, (gv107, stage3_unit1_bn3_gamma, stage3_unit1_bn3_beta, stage3_unit1_bn3_moving_mean, stage3_unit1_bn3_moving_var), ((1, 256, 14, 14), (256,), (256,)), dtype=(\"float32\", \"float32\", \"float32\"))\n",
" gv109: Tensor((1, 256, 14, 14), \"float32\") = gv108[0]\n",
" gv110 = relax.call_tir(relu5, (gv109,), (1, 256, 14, 14), dtype=\"float32\")\n",
" gv111 = relax.call_tir(conv2d_nchw12, (gv110, stage3_unit1_conv3_weight), (1, 1024, 14, 14), dtype=\"float32\")\n",
" gv112 = relax.call_tir(conv2d_nchw13, (gv102, stage3_unit1_sc_weight), (1, 1024, 14, 14), dtype=\"float32\")\n",
" gv113 = relax.call_tir(add2, (gv111, gv112), (1, 1024, 14, 14), dtype=\"float32\")\n",
" gv114 = relax.call_tir(batch_norm7, (gv113, stage3_unit2_bn1_gamma, stage3_unit2_bn1_beta, stage3_unit2_bn1_moving_mean, stage3_unit2_bn1_moving_var), ((1, 1024, 14, 14), (1024,), (1024,)), dtype=(\"float32\", \"float32\", \"float32\"))\n",
" gv115: Tensor((1, 1024, 14, 14), \"float32\") = gv114[0]\n",
" gv116 = relax.call_tir(relu6, (gv115,), (1, 1024, 14, 14), dtype=\"float32\")\n",
" gv117 = relax.call_tir(conv2d_nchw14, (gv116, stage3_unit2_conv1_weight), (1, 256, 14, 14), dtype=\"float32\")\n",
" gv118 = relax.call_tir(batch_norm6, (gv117, stage3_unit2_bn2_gamma, stage3_unit2_bn2_beta, stage3_unit2_bn2_moving_mean, stage3_unit2_bn2_moving_var), ((1, 256, 14, 14), (256,), (256,)), dtype=(\"float32\", \"float32\", \"float32\"))\n",
" gv119: Tensor((1, 256, 14, 14), \"float32\") = gv118[0]\n",
" gv120 = relax.call_tir(relu5, (gv119,), (1, 256, 14, 14), dtype=\"float32\")\n",
" gv121 = relax.call_tir(conv2d_nchw11, (gv120, stage3_unit2_conv2_weight), (1, 256, 14, 14), dtype=\"float32\")\n",
" gv122 = relax.call_tir(batch_norm6, (gv121, stage3_unit2_bn3_gamma, stage3_unit2_bn3_beta, stage3_unit2_bn3_moving_mean, stage3_unit2_bn3_moving_var), ((1, 256, 14, 14), (256,), (256,)), dtype=(\"float32\", \"float32\", \"float32\"))\n",
" gv123: Tensor((1, 256, 14, 14), \"float32\") = gv122[0]\n",
" gv124 = relax.call_tir(relu5, (gv123,), (1, 256, 14, 14), dtype=\"float32\")\n",
" gv125 = relax.call_tir(conv2d_nchw12, (gv124, stage3_unit2_conv3_weight), (1, 1024, 14, 14), dtype=\"float32\")\n",
" gv126 = relax.call_tir(add2, (gv125, gv113), (1, 1024, 14, 14), dtype=\"float32\")\n",
" gv127 = relax.call_tir(batch_norm7, (gv126, stage3_unit3_bn1_gamma, stage3_unit3_bn1_beta, stage3_unit3_bn1_moving_mean, stage3_unit3_bn1_moving_var), ((1, 1024, 14, 14), (1024,), (1024,)), dtype=(\"float32\", \"float32\", \"float32\"))\n",
" gv128: Tensor((1, 1024, 14, 14), \"float32\") = gv127[0]\n",
" gv129 = relax.call_tir(relu6, (gv128,), (1, 1024, 14, 14), dtype=\"float32\")\n",
" gv130 = relax.call_tir(conv2d_nchw14, (gv129, stage3_unit3_conv1_weight), (1, 256, 14, 14), dtype=\"float32\")\n",
" gv131 = relax.call_tir(batch_norm6, (gv130, stage3_unit3_bn2_gamma, stage3_unit3_bn2_beta, stage3_unit3_bn2_moving_mean, stage3_unit3_bn2_moving_var), ((1, 256, 14, 14), (256,), (256,)), dtype=(\"float32\", \"float32\", \"float32\"))\n",
" gv132: Tensor((1, 256, 14, 14), \"float32\") = gv131[0]\n",
" gv133 = relax.call_tir(relu5, (gv132,), (1, 256, 14, 14), dtype=\"float32\")\n",
" gv134 = relax.call_tir(conv2d_nchw11, (gv133, stage3_unit3_conv2_weight), (1, 256, 14, 14), dtype=\"float32\")\n",
" gv135 = relax.call_tir(batch_norm6, (gv134, stage3_unit3_bn3_gamma, stage3_unit3_bn3_beta, stage3_unit3_bn3_moving_mean, stage3_unit3_bn3_moving_var), ((1, 256, 14, 14), (256,), (256,)), dtype=(\"float32\", \"float32\", \"float32\"))\n",
" gv136: Tensor((1, 256, 14, 14), \"float32\") = gv135[0]\n",
" gv137 = relax.call_tir(relu5, (gv136,), (1, 256, 14, 14), dtype=\"float32\")\n",
" gv138 = relax.call_tir(conv2d_nchw12, (gv137, stage3_unit3_conv3_weight), (1, 1024, 14, 14), dtype=\"float32\")\n",
" gv139 = relax.call_tir(add2, (gv138, gv126), (1, 1024, 14, 14), dtype=\"float32\")\n",
" gv140 = relax.call_tir(batch_norm7, (gv139, stage3_unit4_bn1_gamma, stage3_unit4_bn1_beta, stage3_unit4_bn1_moving_mean, stage3_unit4_bn1_moving_var), ((1, 1024, 14, 14), (1024,), (1024,)), dtype=(\"float32\", \"float32\", \"float32\"))\n",
" gv141: Tensor((1, 1024, 14, 14), \"float32\") = gv140[0]\n",
" gv142 = relax.call_tir(relu6, (gv141,), (1, 1024, 14, 14), dtype=\"float32\")\n",
" gv143 = relax.call_tir(conv2d_nchw14, (gv142, stage3_unit4_conv1_weight), (1, 256, 14, 14), dtype=\"float32\")\n",
" gv144 = relax.call_tir(batch_norm6, (gv143, stage3_unit4_bn2_gamma, stage3_unit4_bn2_beta, stage3_unit4_bn2_moving_mean, stage3_unit4_bn2_moving_var), ((1, 256, 14, 14), (256,), (256,)), dtype=(\"float32\", \"float32\", \"float32\"))\n",
" gv145: Tensor((1, 256, 14, 14), \"float32\") = gv144[0]\n",
" gv146 = relax.call_tir(relu5, (gv145,), (1, 256, 14, 14), dtype=\"float32\")\n",
" gv147 = relax.call_tir(conv2d_nchw11, (gv146, stage3_unit4_conv2_weight), (1, 256, 14, 14), dtype=\"float32\")\n",
" gv148 = relax.call_tir(batch_norm6, (gv147, stage3_unit4_bn3_gamma, stage3_unit4_bn3_beta, stage3_unit4_bn3_moving_mean, stage3_unit4_bn3_moving_var), ((1, 256, 14, 14), (256,), (256,)), dtype=(\"float32\", \"float32\", \"float32\"))\n",
" gv149: Tensor((1, 256, 14, 14), \"float32\") = gv148[0]\n",
" gv150 = relax.call_tir(relu5, (gv149,), (1, 256, 14, 14), dtype=\"float32\")\n",
" gv151 = relax.call_tir(conv2d_nchw12, (gv150, stage3_unit4_conv3_weight), (1, 1024, 14, 14), dtype=\"float32\")\n",
" gv152 = relax.call_tir(add2, (gv151, gv139), (1, 1024, 14, 14), dtype=\"float32\")\n",
" gv153 = relax.call_tir(batch_norm7, (gv152, stage3_unit5_bn1_gamma, stage3_unit5_bn1_beta, stage3_unit5_bn1_moving_mean, stage3_unit5_bn1_moving_var), ((1, 1024, 14, 14), (1024,), (1024,)), dtype=(\"float32\", \"float32\", \"float32\"))\n",
" gv154: Tensor((1, 1024, 14, 14), \"float32\") = gv153[0]\n",
" gv155 = relax.call_tir(relu6, (gv154,), (1, 1024, 14, 14), dtype=\"float32\")\n",
" gv156 = relax.call_tir(conv2d_nchw14, (gv155, stage3_unit5_conv1_weight), (1, 256, 14, 14), dtype=\"float32\")\n",
" gv157 = relax.call_tir(batch_norm6, (gv156, stage3_unit5_bn2_gamma, stage3_unit5_bn2_beta, stage3_unit5_bn2_moving_mean, stage3_unit5_bn2_moving_var), ((1, 256, 14, 14), (256,), (256,)), dtype=(\"float32\", \"float32\", \"float32\"))\n",
" gv158: Tensor((1, 256, 14, 14), \"float32\") = gv157[0]\n",
" gv159 = relax.call_tir(relu5, (gv158,), (1, 256, 14, 14), dtype=\"float32\")\n",
" gv160 = relax.call_tir(conv2d_nchw11, (gv159, stage3_unit5_conv2_weight), (1, 256, 14, 14), dtype=\"float32\")\n",
" gv161 = relax.call_tir(batch_norm6, (gv160, stage3_unit5_bn3_gamma, stage3_unit5_bn3_beta, stage3_unit5_bn3_moving_mean, stage3_unit5_bn3_moving_var), ((1, 256, 14, 14), (256,), (256,)), dtype=(\"float32\", \"float32\", \"float32\"))\n",
" gv162: Tensor((1, 256, 14, 14), \"float32\") = gv161[0]\n",
" gv163 = relax.call_tir(relu5, (gv162,), (1, 256, 14, 14), dtype=\"float32\")\n",
" gv164 = relax.call_tir(conv2d_nchw12, (gv163, stage3_unit5_conv3_weight), (1, 1024, 14, 14), dtype=\"float32\")\n",
" gv165 = relax.call_tir(add2, (gv164, gv152), (1, 1024, 14, 14), dtype=\"float32\")\n",
" gv166 = relax.call_tir(batch_norm7, (gv165, stage3_unit6_bn1_gamma, stage3_unit6_bn1_beta, stage3_unit6_bn1_moving_mean, stage3_unit6_bn1_moving_var), ((1, 1024, 14, 14), (1024,), (1024,)), dtype=(\"float32\", \"float32\", \"float32\"))\n",
" gv167: Tensor((1, 1024, 14, 14), \"float32\") = gv166[0]\n",
" gv168 = relax.call_tir(relu6, (gv167,), (1, 1024, 14, 14), dtype=\"float32\")\n",
" gv169 = relax.call_tir(conv2d_nchw14, (gv168, stage3_unit6_conv1_weight), (1, 256, 14, 14), dtype=\"float32\")\n",
" gv170 = relax.call_tir(batch_norm6, (gv169, stage3_unit6_bn2_gamma, stage3_unit6_bn2_beta, stage3_unit6_bn2_moving_mean, stage3_unit6_bn2_moving_var), ((1, 256, 14, 14), (256,), (256,)), dtype=(\"float32\", \"float32\", \"float32\"))\n",
" gv171: Tensor((1, 256, 14, 14), \"float32\") = gv170[0]\n",
" gv172 = relax.call_tir(relu5, (gv171,), (1, 256, 14, 14), dtype=\"float32\")\n",
" gv173 = relax.call_tir(conv2d_nchw11, (gv172, stage3_unit6_conv2_weight), (1, 256, 14, 14), dtype=\"float32\")\n",
" gv174 = relax.call_tir(batch_norm6, (gv173, stage3_unit6_bn3_gamma, stage3_unit6_bn3_beta, stage3_unit6_bn3_moving_mean, stage3_unit6_bn3_moving_var), ((1, 256, 14, 14), (256,), (256,)), dtype=(\"float32\", \"float32\", \"float32\"))\n",
" gv175: Tensor((1, 256, 14, 14), \"float32\") = gv174[0]\n",
" gv176 = relax.call_tir(relu5, (gv175,), (1, 256, 14, 14), dtype=\"float32\")\n",
" gv177 = relax.call_tir(conv2d_nchw12, (gv176, stage3_unit6_conv3_weight), (1, 1024, 14, 14), dtype=\"float32\")\n",
" gv178 = relax.call_tir(add2, (gv177, gv165), (1, 1024, 14, 14), dtype=\"float32\")\n",
" gv179 = relax.call_tir(batch_norm7, (gv178, stage4_unit1_bn1_gamma, stage4_unit1_bn1_beta, stage4_unit1_bn1_moving_mean, stage4_unit1_bn1_moving_var), ((1, 1024, 14, 14), (1024,), (1024,)), dtype=(\"float32\", \"float32\", \"float32\"))\n",
" gv180: Tensor((1, 1024, 14, 14), \"float32\") = gv179[0]\n",
" gv181 = relax.call_tir(relu6, (gv180,), (1, 1024, 14, 14), dtype=\"float32\")\n",
" gv182 = relax.call_tir(conv2d_nchw15, (gv181, stage4_unit1_conv1_weight), (1, 512, 7, 7), dtype=\"float32\")\n",
" gv183 = relax.call_tir(batch_norm8, (gv182, stage4_unit1_bn2_gamma, stage4_unit1_bn2_beta, stage4_unit1_bn2_moving_mean, stage4_unit1_bn2_moving_var), ((1, 512, 7, 7), (512,), (512,)), dtype=(\"float32\", \"float32\", \"float32\"))\n",
" gv184: Tensor((1, 512, 7, 7), \"float32\") = gv183[0]\n",
" gv185 = relax.call_tir(relu7, (gv184,), (1, 512, 7, 7), dtype=\"float32\")\n",
" gv186 = relax.call_tir(conv2d_nchw16, (gv185, stage4_unit1_conv2_weight), (1, 512, 7, 7), dtype=\"float32\")\n",
" gv187 = relax.call_tir(batch_norm8, (gv186, stage4_unit1_bn3_gamma, stage4_unit1_bn3_beta, stage4_unit1_bn3_moving_mean, stage4_unit1_bn3_moving_var), ((1, 512, 7, 7), (512,), (512,)), dtype=(\"float32\", \"float32\", \"float32\"))\n",
" gv188: Tensor((1, 512, 7, 7), \"float32\") = gv187[0]\n",
" gv189 = relax.call_tir(relu7, (gv188,), (1, 512, 7, 7), dtype=\"float32\")\n",
" gv190 = relax.call_tir(conv2d_nchw17, (gv189, stage4_unit1_conv3_weight), (1, 2048, 7, 7), dtype=\"float32\")\n",
" gv191 = relax.call_tir(conv2d_nchw18, (gv181, stage4_unit1_sc_weight), (1, 2048, 7, 7), dtype=\"float32\")\n",
" gv192 = relax.call_tir(add3, (gv190, gv191), (1, 2048, 7, 7), dtype=\"float32\")\n",
" gv193 = relax.call_tir(batch_norm9, (gv192, stage4_unit2_bn1_gamma, stage4_unit2_bn1_beta, stage4_unit2_bn1_moving_mean, stage4_unit2_bn1_moving_var), ((1, 2048, 7, 7), (2048,), (2048,)), dtype=(\"float32\", \"float32\", \"float32\"))\n",
" gv194: Tensor((1, 2048, 7, 7), \"float32\") = gv193[0]\n",
" gv195 = relax.call_tir(relu8, (gv194,), (1, 2048, 7, 7), dtype=\"float32\")\n",
" gv196 = relax.call_tir(conv2d_nchw19, (gv195, stage4_unit2_conv1_weight), (1, 512, 7, 7), dtype=\"float32\")\n",
" gv197 = relax.call_tir(batch_norm8, (gv196, stage4_unit2_bn2_gamma, stage4_unit2_bn2_beta, stage4_unit2_bn2_moving_mean, stage4_unit2_bn2_moving_var), ((1, 512, 7, 7), (512,), (512,)), dtype=(\"float32\", \"float32\", \"float32\"))\n",
" gv198: Tensor((1, 512, 7, 7), \"float32\") = gv197[0]\n",
" gv199 = relax.call_tir(relu7, (gv198,), (1, 512, 7, 7), dtype=\"float32\")\n",
" gv200 = relax.call_tir(conv2d_nchw16, (gv199, stage4_unit2_conv2_weight), (1, 512, 7, 7), dtype=\"float32\")\n",
" gv201 = relax.call_tir(batch_norm8, (gv200, stage4_unit2_bn3_gamma, stage4_unit2_bn3_beta, stage4_unit2_bn3_moving_mean, stage4_unit2_bn3_moving_var), ((1, 512, 7, 7), (512,), (512,)), dtype=(\"float32\", \"float32\", \"float32\"))\n",
" gv202: Tensor((1, 512, 7, 7), \"float32\") = gv201[0]\n",
" gv203 = relax.call_tir(relu7, (gv202,), (1, 512, 7, 7), dtype=\"float32\")\n",
" gv204 = relax.call_tir(conv2d_nchw17, (gv203, stage4_unit2_conv3_weight), (1, 2048, 7, 7), dtype=\"float32\")\n",
" gv205 = relax.call_tir(add3, (gv204, gv192), (1, 2048, 7, 7), dtype=\"float32\")\n",
" gv206 = relax.call_tir(batch_norm9, (gv205, stage4_unit3_bn1_gamma, stage4_unit3_bn1_beta, stage4_unit3_bn1_moving_mean, stage4_unit3_bn1_moving_var), ((1, 2048, 7, 7), (2048,), (2048,)), dtype=(\"float32\", \"float32\", \"float32\"))\n",
" gv207: Tensor((1, 2048, 7, 7), \"float32\") = gv206[0]\n",
" gv208 = relax.call_tir(relu8, (gv207,), (1, 2048, 7, 7), dtype=\"float32\")\n",
" gv209 = relax.call_tir(conv2d_nchw19, (gv208, stage4_unit3_conv1_weight), (1, 512, 7, 7), dtype=\"float32\")\n",
" gv210 = relax.call_tir(batch_norm8, (gv209, stage4_unit3_bn2_gamma, stage4_unit3_bn2_beta, stage4_unit3_bn2_moving_mean, stage4_unit3_bn2_moving_var), ((1, 512, 7, 7), (512,), (512,)), dtype=(\"float32\", \"float32\", \"float32\"))\n",
" gv211: Tensor((1, 512, 7, 7), \"float32\") = gv210[0]\n",
" gv212 = relax.call_tir(relu7, (gv211,), (1, 512, 7, 7), dtype=\"float32\")\n",
" gv213 = relax.call_tir(conv2d_nchw16, (gv212, stage4_unit3_conv2_weight), (1, 512, 7, 7), dtype=\"float32\")\n",
" gv214 = relax.call_tir(batch_norm8, (gv213, stage4_unit3_bn3_gamma, stage4_unit3_bn3_beta, stage4_unit3_bn3_moving_mean, stage4_unit3_bn3_moving_var), ((1, 512, 7, 7), (512,), (512,)), dtype=(\"float32\", \"float32\", \"float32\"))\n",
" gv215: Tensor((1, 512, 7, 7), \"float32\") = gv214[0]\n",
" gv216 = relax.call_tir(relu7, (gv215,), (1, 512, 7, 7), dtype=\"float32\")\n",
" gv217 = relax.call_tir(conv2d_nchw17, (gv216, stage4_unit3_conv3_weight), (1, 2048, 7, 7), dtype=\"float32\")\n",
" gv218 = relax.call_tir(add3, (gv217, gv205), (1, 2048, 7, 7), dtype=\"float32\")\n",
" gv219 = relax.call_tir(batch_norm9, (gv218, bn1_gamma, bn1_beta, bn1_moving_mean, bn1_moving_var), ((1, 2048, 7, 7), (2048,), (2048,)), dtype=(\"float32\", \"float32\", \"float32\"))\n",
" gv220: Tensor((1, 2048, 7, 7), \"float32\") = gv219[0]\n",
" gv221 = relax.call_tir(relu8, (gv220,), (1, 2048, 7, 7), dtype=\"float32\")\n",
" gv222 = relax.call_tir(global_avg_pool2d, (gv221,), (1, 2048, 1, 1), dtype=\"float32\")\n",
" gv223 = relax.call_tir(batch_flatten, (gv222,), (1, 2048), dtype=\"float32\")\n",
" gv224 = relax.call_tir(dense, (gv223, fc1_weight), (1, 1000), dtype=\"float32\")\n",
" gv225 = relax.call_tir(bias_add, (gv224, fc1_bias), (1, 1000), dtype=\"float32\")\n",
" gv226 = relax.call_tir(softmax, (gv225,), (1, 1000), dtype=\"float32\")\n",
" return gv226\n",
" \n",
" @tir.prim_func\n",
" def add1(rxplaceholder_2: tir.Buffer[(1, 512, 28, 28), \"float32\"], rxplaceholder_3: tir.Buffer[(1, 512, 28, 28), \"float32\"], T_add_1: tir.Buffer[(1, 512, 28, 28), \"float32\"]) -> None:\n",
" # function attr dict\n",
" tir.func_attr({\"global_symbol\": \"add1\", \"tir.noalias\": True})\n",
" # body\n",
" # with tir.block(\"root\")\n",
" for i0, i1, i2, i3 in tir.grid(1, 512, 28, 28):\n",
" with tir.block(\"T_add\"):\n",
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n",
" tir.reads(rxplaceholder_2[ax0, ax1, ax2, ax3], rxplaceholder_3[ax0, ax1, ax2, ax3])\n",
" tir.writes(T_add_1[ax0, ax1, ax2, ax3])\n",
" T_add_1[ax0, ax1, ax2, ax3] = rxplaceholder_2[ax0, ax1, ax2, ax3] + rxplaceholder_3[ax0, ax1, ax2, ax3]\n",
" \n",
" @tir.prim_func\n",
" def conv2d_nchw3(rxplaceholder_2: tir.Buffer[(1, 64, 56, 56), \"float32\"], rxplaceholder_3: tir.Buffer[(256, 64, 1, 1), \"float32\"], conv2d_nchw_1: tir.Buffer[(1, 256, 56, 56), \"float32\"]) -> None:\n",
" # function attr dict\n",
" tir.func_attr({\"global_symbol\": \"conv2d_nchw3\", \"tir.noalias\": True})\n",
" # body\n",
" # with tir.block(\"root\")\n",
" pad_temp_1 = tir.alloc_buffer([1, 64, 56, 56], dtype=\"float32\")\n",
" for i0, i1, i2, i3 in tir.grid(1, 64, 56, 56):\n",
" with tir.block(\"pad_temp\"):\n",
" i0_2, i1_2, i2_2, i3_2 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n",
" tir.reads(rxplaceholder_2[i0_2, i1_2, i2_2, i3_2])\n",
" tir.writes(pad_temp_1[i0_2, i1_2, i2_2, i3_2])\n",
" pad_temp_1[i0_2, i1_2, i2_2, i3_2] = rxplaceholder_2[i0_2, i1_2, i2_2, i3_2]\n",
" for i0_3, i1_3, i2_3, i3_3, i4, i5, i6 in tir.grid(1, 256, 56, 56, 64, 1, 1):\n",
" with tir.block(\"conv2d_nchw\"):\n",
" nn, ff, yy, xx, rc, ry, rx = tir.axis.remap(\"SSSSRRR\", [i0_3, i1_3, i2_3, i3_3, i4, i5, i6])\n",
" tir.reads(pad_temp_1[nn, rc, yy + ry, xx + rx], rxplaceholder_3[ff, rc, ry, rx])\n",
" tir.writes(conv2d_nchw_1[nn, ff, yy, xx])\n",
" with tir.init():\n",
" conv2d_nchw_1[nn, ff, yy, xx] = tir.float32(0)\n",
" conv2d_nchw_1[nn, ff, yy, xx] = conv2d_nchw_1[nn, ff, yy, xx] + pad_temp_1[nn, rc, yy + ry, xx + rx] * rxplaceholder_3[ff, rc, ry, rx]\n",
" \n",
" @tir.prim_func\n",
" def conv2d_nchw19(rxplaceholder_2: tir.Buffer[(1, 2048, 7, 7), \"float32\"], rxplaceholder_3: tir.Buffer[(512, 2048, 1, 1), \"float32\"], conv2d_nchw_1: tir.Buffer[(1, 512, 7, 7), \"float32\"]) -> None:\n",
" # function attr dict\n",
" tir.func_attr({\"global_symbol\": \"conv2d_nchw19\", \"tir.noalias\": True})\n",
" # body\n",
" # with tir.block(\"root\")\n",
" pad_temp_1 = tir.alloc_buffer([1, 2048, 7, 7], dtype=\"float32\")\n",
" for i0, i1, i2, i3 in tir.grid(1, 2048, 7, 7):\n",
" with tir.block(\"pad_temp\"):\n",
" i0_2, i1_2, i2_2, i3_2 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n",
" tir.reads(rxplaceholder_2[i0_2, i1_2, i2_2, i3_2])\n",
" tir.writes(pad_temp_1[i0_2, i1_2, i2_2, i3_2])\n",
" pad_temp_1[i0_2, i1_2, i2_2, i3_2] = rxplaceholder_2[i0_2, i1_2, i2_2, i3_2]\n",
" for i0_3, i1_3, i2_3, i3_3, i4, i5, i6 in tir.grid(1, 512, 7, 7, 2048, 1, 1):\n",
" with tir.block(\"conv2d_nchw\"):\n",
" nn, ff, yy, xx, rc, ry, rx = tir.axis.remap(\"SSSSRRR\", [i0_3, i1_3, i2_3, i3_3, i4, i5, i6])\n",
" tir.reads(pad_temp_1[nn, rc, yy + ry, xx + rx], rxplaceholder_3[ff, rc, ry, rx])\n",
" tir.writes(conv2d_nchw_1[nn, ff, yy, xx])\n",
" with tir.init():\n",
" conv2d_nchw_1[nn, ff, yy, xx] = tir.float32(0)\n",
" conv2d_nchw_1[nn, ff, yy, xx] = conv2d_nchw_1[nn, ff, yy, xx] + pad_temp_1[nn, rc, yy + ry, xx + rx] * rxplaceholder_3[ff, rc, ry, rx]\n",
" \n",
" @tir.prim_func\n",
" def softmax(rxplaceholder_1: tir.Buffer[(1, 1000), \"float32\"], T_softmax_norm_1: tir.Buffer[(1, 1000), \"float32\"]) -> None:\n",
" # function attr dict\n",
" tir.func_attr({\"global_symbol\": \"softmax\", \"tir.noalias\": True})\n",
" # body\n",
" # with tir.block(\"root\")\n",
" T_softmax_maxelem_1 = tir.alloc_buffer([1], dtype=\"float32\")\n",
" T_softmax_exp_1 = tir.alloc_buffer([1, 1000], dtype=\"float32\")\n",
" T_softmax_expsum_1 = tir.alloc_buffer([1], dtype=\"float32\")\n",
" for i0_7, i1_3 in tir.grid(1, 1000):\n",
" with tir.block(\"T_softmax_maxelem\"):\n",
" i0_8, k = tir.axis.remap(\"SR\", [i0_7, i1_3])\n",
" tir.reads(rxplaceholder_1[i0_8, k])\n",
" tir.writes(T_softmax_maxelem_1[i0_8])\n",
" with tir.init():\n",
" T_softmax_maxelem_1[i0_8] = tir.float32(-3.4028234663852886e+38)\n",
" T_softmax_maxelem_1[i0_8] = tir.max(T_softmax_maxelem_1[i0_8], rxplaceholder_1[i0_8, k])\n",
" for i0_9, i1_4 in tir.grid(1, 1000):\n",
" with tir.block(\"T_softmax_exp\"):\n",
" i0_10, i1_5 = tir.axis.remap(\"SS\", [i0_9, i1_4])\n",
" tir.reads(rxplaceholder_1[i0_10, i1_5], T_softmax_maxelem_1[i0_10])\n",
" tir.writes(T_softmax_exp_1[i0_10, i1_5])\n",
" T_softmax_exp_1[i0_10, i1_5] = tir.exp(rxplaceholder_1[i0_10, i1_5] - T_softmax_maxelem_1[i0_10], dtype=\"float32\")\n",
" for i0_11, i1_6 in tir.grid(1, 1000):\n",
" with tir.block(\"T_softmax_expsum\"):\n",
" i0_12, k = tir.axis.remap(\"SR\", [i0_11, i1_6])\n",
" tir.reads(T_softmax_exp_1[i0_12, k])\n",
" tir.writes(T_softmax_expsum_1[i0_12])\n",
" with tir.init():\n",
" T_softmax_expsum_1[i0_12] = tir.float32(0)\n",
" T_softmax_expsum_1[i0_12] = T_softmax_expsum_1[i0_12] + T_softmax_exp_1[i0_12, k]\n",
" for i0_13, i1_7 in tir.grid(1, 1000):\n",
" with tir.block(\"T_softmax_norm\"):\n",
" i0_14, i1_8 = tir.axis.remap(\"SS\", [i0_13, i1_7])\n",
" tir.reads(T_softmax_exp_1[i0_14, i1_8], T_softmax_expsum_1[i0_14])\n",
" tir.writes(T_softmax_norm_1[i0_14, i1_8])\n",
" tir.block_attr({\"axis\":1})\n",
" T_softmax_norm_1[i0_14, i1_8] = T_softmax_exp_1[i0_14, i1_8] / T_softmax_expsum_1[i0_14]\n",
" \n",
" @tir.prim_func\n",
" def batch_norm9(rxplaceholder_5: tir.Buffer[(1, 2048, 7, 7), \"float32\"], rxplaceholder_6: tir.Buffer[(2048,), \"float32\"], rxplaceholder_7: tir.Buffer[(2048,), \"float32\"], rxplaceholder_8: tir.Buffer[(2048,), \"float32\"], rxplaceholder_9: tir.Buffer[(2048,), \"float32\"], T_add_2: tir.Buffer[(1, 2048, 7, 7), \"float32\"], T_multiply_3: tir.Buffer[(2048,), \"float32\"], T_multiply_4: tir.Buffer[(2048,), \"float32\"]) -> None:\n",
" # function attr dict\n",
" tir.func_attr({\"global_symbol\": \"batch_norm9\", \"tir.noalias\": True})\n",
" # body\n",
" # with tir.block(\"root\")\n",
" T_reshape_4 = tir.alloc_buffer([1, 2048, 1, 1], dtype=\"float32\")\n",
" T_subtract_1 = tir.alloc_buffer([1, 2048, 7, 7], dtype=\"float32\")\n",
" T_reshape_5 = tir.alloc_buffer([1, 2048, 1, 1], dtype=\"float32\")\n",
" T_add_3 = tir.alloc_buffer([1, 2048, 1, 1], dtype=\"float32\")\n",
" compute_1 = tir.alloc_buffer([1, 2048, 1, 1], dtype=\"float32\")\n",
" T_divide_1 = tir.alloc_buffer([1, 2048, 7, 7], dtype=\"float32\")\n",
" T_reshape_6 = tir.alloc_buffer([1, 2048, 1, 1], dtype=\"float32\")\n",
" T_multiply_5 = tir.alloc_buffer([1, 2048, 7, 7], dtype=\"float32\")\n",
" T_reshape_7 = tir.alloc_buffer([1, 2048, 1, 1], dtype=\"float32\")\n",
" for i0, i1, i2, i3 in tir.grid(1, 2048, 1, 1):\n",
" with tir.block(\"T_reshape\"):\n",
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n",
" tir.reads(rxplaceholder_8[(ax1 + ax2 + ax3) % 2048])\n",
" tir.writes(T_reshape_4[ax0, ax1, ax2, ax3])\n",
" T_reshape_4[ax0, ax1, ax2, ax3] = rxplaceholder_8[(ax1 + ax2 + ax3) % 2048]\n",
" for i0, i1, i2, i3 in tir.grid(1, 2048, 7, 7):\n",
" with tir.block(\"T_subtract\"):\n",
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n",
" tir.reads(rxplaceholder_5[ax0, ax1, ax2, ax3], T_reshape_4[ax0, ax1, 0, 0])\n",
" tir.writes(T_subtract_1[ax0, ax1, ax2, ax3])\n",
" T_subtract_1[ax0, ax1, ax2, ax3] = rxplaceholder_5[ax0, ax1, ax2, ax3] - T_reshape_4[ax0, ax1, 0, 0]\n",
" for i0, i1, i2, i3 in tir.grid(1, 2048, 1, 1):\n",
" with tir.block(\"T_reshape_1\"):\n",
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n",
" tir.reads(rxplaceholder_9[(ax1 + ax2 + ax3) % 2048])\n",
" tir.writes(T_reshape_5[ax0, ax1, ax2, ax3])\n",
" T_reshape_5[ax0, ax1, ax2, ax3] = rxplaceholder_9[(ax1 + ax2 + ax3) % 2048]\n",
" for i0, i1, i2, i3 in tir.grid(1, 2048, 1, 1):\n",
" with tir.block(\"T_add\"):\n",
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n",
" tir.reads(T_reshape_5[ax0, ax1, ax2, ax3])\n",
" tir.writes(T_add_3[ax0, ax1, ax2, ax3])\n",
" T_add_3[ax0, ax1, ax2, ax3] = T_reshape_5[ax0, ax1, ax2, ax3] + tir.float32(1.9999999494757503e-05)\n",
" for i0, i1, i2, i3 in tir.grid(1, 2048, 1, 1):\n",
" with tir.block(\"compute\"):\n",
" i0_2, i1_2, i2_2, i3_2 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n",
" tir.reads(T_add_3[i0_2, i1_2, i2_2, i3_2])\n",
" tir.writes(compute_1[i0_2, i1_2, i2_2, i3_2])\n",
" compute_1[i0_2, i1_2, i2_2, i3_2] = tir.sqrt(T_add_3[i0_2, i1_2, i2_2, i3_2], dtype=\"float32\")\n",
" for i0_3, i1_3, i2_3, i3_3 in tir.grid(1, 2048, 7, 7):\n",
" with tir.block(\"T_divide\"):\n",
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0_3, i1_3, i2_3, i3_3])\n",
" tir.reads(T_subtract_1[ax0, ax1, ax2, ax3], compute_1[ax0, ax1, 0, 0])\n",
" tir.writes(T_divide_1[ax0, ax1, ax2, ax3])\n",
" T_divide_1[ax0, ax1, ax2, ax3] = T_subtract_1[ax0, ax1, ax2, ax3] / compute_1[ax0, ax1, 0, 0]\n",
" for i0_4, i1_4, i2_4, i3_4 in tir.grid(1, 2048, 1, 1):\n",
" with tir.block(\"T_reshape_2\"):\n",
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0_4, i1_4, i2_4, i3_4])\n",
" tir.reads(rxplaceholder_6[(ax1 + ax2 + ax3) % 2048])\n",
" tir.writes(T_reshape_6[ax0, ax1, ax2, ax3])\n",
" T_reshape_6[ax0, ax1, ax2, ax3] = rxplaceholder_6[(ax1 + ax2 + ax3) % 2048]\n",
" for i0_5, i1_5, i2_5, i3_5 in tir.grid(1, 2048, 7, 7):\n",
" with tir.block(\"T_multiply\"):\n",
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0_5, i1_5, i2_5, i3_5])\n",
" tir.reads(T_divide_1[ax0, ax1, ax2, ax3], T_reshape_6[ax0, ax1, 0, 0])\n",
" tir.writes(T_multiply_5[ax0, ax1, ax2, ax3])\n",
" T_multiply_5[ax0, ax1, ax2, ax3] = T_divide_1[ax0, ax1, ax2, ax3] * T_reshape_6[ax0, ax1, 0, 0]\n",
" for i0_6, i1_6, i2_6, i3_6 in tir.grid(1, 2048, 1, 1):\n",
" with tir.block(\"T_reshape_3\"):\n",
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0_6, i1_6, i2_6, i3_6])\n",
" tir.reads(rxplaceholder_7[(ax1 + ax2 + ax3) % 2048])\n",
" tir.writes(T_reshape_7[ax0, ax1, ax2, ax3])\n",
" T_reshape_7[ax0, ax1, ax2, ax3] = rxplaceholder_7[(ax1 + ax2 + ax3) % 2048]\n",
" for i0_7, i1_7, i2_7, i3_7 in tir.grid(1, 2048, 7, 7):\n",
" with tir.block(\"T_add_1\"):\n",
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0_7, i1_7, i2_7, i3_7])\n",
" tir.reads(T_multiply_5[ax0, ax1, ax2, ax3], T_reshape_7[ax0, ax1, 0, 0])\n",
" tir.writes(T_add_2[ax0, ax1, ax2, ax3])\n",
" T_add_2[ax0, ax1, ax2, ax3] = T_multiply_5[ax0, ax1, ax2, ax3] + T_reshape_7[ax0, ax1, 0, 0]\n",
" for i0_8 in tir.serial(2048):\n",
" with tir.block(\"T_multiply_1\"):\n",
" ax0 = tir.axis.spatial(2048, i0_8)\n",
" tir.reads(rxplaceholder_8[ax0])\n",
" tir.writes(T_multiply_3[ax0])\n",
" T_multiply_3[ax0] = rxplaceholder_8[ax0]\n",
" for i0_9 in tir.serial(2048):\n",
" with tir.block(\"T_multiply_2\"):\n",
" ax0 = tir.axis.spatial(2048, i0_9)\n",
" tir.reads(rxplaceholder_9[ax0])\n",
" tir.writes(T_multiply_4[ax0])\n",
" T_multiply_4[ax0] = rxplaceholder_9[ax0]\n",
" \n",
" @tir.prim_func\n",
" def batch_norm6(rxplaceholder_5: tir.Buffer[(1, 256, 14, 14), \"float32\"], rxplaceholder_6: tir.Buffer[(256,), \"float32\"], rxplaceholder_7: tir.Buffer[(256,), \"float32\"], rxplaceholder_8: tir.Buffer[(256,), \"float32\"], rxplaceholder_9: tir.Buffer[(256,), \"float32\"], T_add_2: tir.Buffer[(1, 256, 14, 14), \"float32\"], T_multiply_3: tir.Buffer[(256,), \"float32\"], T_multiply_4: tir.Buffer[(256,), \"float32\"]) -> None:\n",
" # function attr dict\n",
" tir.func_attr({\"global_symbol\": \"batch_norm6\", \"tir.noalias\": True})\n",
" # body\n",
" # with tir.block(\"root\")\n",
" T_reshape_4 = tir.alloc_buffer([1, 256, 1, 1], dtype=\"float32\")\n",
" T_subtract_1 = tir.alloc_buffer([1, 256, 14, 14], dtype=\"float32\")\n",
" T_reshape_5 = tir.alloc_buffer([1, 256, 1, 1], dtype=\"float32\")\n",
" T_add_3 = tir.alloc_buffer([1, 256, 1, 1], dtype=\"float32\")\n",
" compute_1 = tir.alloc_buffer([1, 256, 1, 1], dtype=\"float32\")\n",
" T_divide_1 = tir.alloc_buffer([1, 256, 14, 14], dtype=\"float32\")\n",
" T_reshape_6 = tir.alloc_buffer([1, 256, 1, 1], dtype=\"float32\")\n",
" T_multiply_5 = tir.alloc_buffer([1, 256, 14, 14], dtype=\"float32\")\n",
" T_reshape_7 = tir.alloc_buffer([1, 256, 1, 1], dtype=\"float32\")\n",
" for i0, i1, i2, i3 in tir.grid(1, 256, 1, 1):\n",
" with tir.block(\"T_reshape\"):\n",
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n",
" tir.reads(rxplaceholder_8[(ax1 + ax2 + ax3) % 256])\n",
" tir.writes(T_reshape_4[ax0, ax1, ax2, ax3])\n",
" T_reshape_4[ax0, ax1, ax2, ax3] = rxplaceholder_8[(ax1 + ax2 + ax3) % 256]\n",
" for i0, i1, i2, i3 in tir.grid(1, 256, 14, 14):\n",
" with tir.block(\"T_subtract\"):\n",
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n",
" tir.reads(rxplaceholder_5[ax0, ax1, ax2, ax3], T_reshape_4[ax0, ax1, 0, 0])\n",
" tir.writes(T_subtract_1[ax0, ax1, ax2, ax3])\n",
" T_subtract_1[ax0, ax1, ax2, ax3] = rxplaceholder_5[ax0, ax1, ax2, ax3] - T_reshape_4[ax0, ax1, 0, 0]\n",
" for i0, i1, i2, i3 in tir.grid(1, 256, 1, 1):\n",
" with tir.block(\"T_reshape_1\"):\n",
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n",
" tir.reads(rxplaceholder_9[(ax1 + ax2 + ax3) % 256])\n",
" tir.writes(T_reshape_5[ax0, ax1, ax2, ax3])\n",
" T_reshape_5[ax0, ax1, ax2, ax3] = rxplaceholder_9[(ax1 + ax2 + ax3) % 256]\n",
" for i0, i1, i2, i3 in tir.grid(1, 256, 1, 1):\n",
" with tir.block(\"T_add\"):\n",
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n",
" tir.reads(T_reshape_5[ax0, ax1, ax2, ax3])\n",
" tir.writes(T_add_3[ax0, ax1, ax2, ax3])\n",
" T_add_3[ax0, ax1, ax2, ax3] = T_reshape_5[ax0, ax1, ax2, ax3] + tir.float32(1.9999999494757503e-05)\n",
" for i0, i1, i2, i3 in tir.grid(1, 256, 1, 1):\n",
" with tir.block(\"compute\"):\n",
" i0_2, i1_2, i2_2, i3_2 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n",
" tir.reads(T_add_3[i0_2, i1_2, i2_2, i3_2])\n",
" tir.writes(compute_1[i0_2, i1_2, i2_2, i3_2])\n",
" compute_1[i0_2, i1_2, i2_2, i3_2] = tir.sqrt(T_add_3[i0_2, i1_2, i2_2, i3_2], dtype=\"float32\")\n",
" for i0_3, i1_3, i2_3, i3_3 in tir.grid(1, 256, 14, 14):\n",
" with tir.block(\"T_divide\"):\n",
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0_3, i1_3, i2_3, i3_3])\n",
" tir.reads(T_subtract_1[ax0, ax1, ax2, ax3], compute_1[ax0, ax1, 0, 0])\n",
" tir.writes(T_divide_1[ax0, ax1, ax2, ax3])\n",
" T_divide_1[ax0, ax1, ax2, ax3] = T_subtract_1[ax0, ax1, ax2, ax3] / compute_1[ax0, ax1, 0, 0]\n",
" for i0_4, i1_4, i2_4, i3_4 in tir.grid(1, 256, 1, 1):\n",
" with tir.block(\"T_reshape_2\"):\n",
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0_4, i1_4, i2_4, i3_4])\n",
" tir.reads(rxplaceholder_6[(ax1 + ax2 + ax3) % 256])\n",
" tir.writes(T_reshape_6[ax0, ax1, ax2, ax3])\n",
" T_reshape_6[ax0, ax1, ax2, ax3] = rxplaceholder_6[(ax1 + ax2 + ax3) % 256]\n",
" for i0_5, i1_5, i2_5, i3_5 in tir.grid(1, 256, 14, 14):\n",
" with tir.block(\"T_multiply\"):\n",
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0_5, i1_5, i2_5, i3_5])\n",
" tir.reads(T_divide_1[ax0, ax1, ax2, ax3], T_reshape_6[ax0, ax1, 0, 0])\n",
" tir.writes(T_multiply_5[ax0, ax1, ax2, ax3])\n",
" T_multiply_5[ax0, ax1, ax2, ax3] = T_divide_1[ax0, ax1, ax2, ax3] * T_reshape_6[ax0, ax1, 0, 0]\n",
" for i0_6, i1_6, i2_6, i3_6 in tir.grid(1, 256, 1, 1):\n",
" with tir.block(\"T_reshape_3\"):\n",
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0_6, i1_6, i2_6, i3_6])\n",
" tir.reads(rxplaceholder_7[(ax1 + ax2 + ax3) % 256])\n",
" tir.writes(T_reshape_7[ax0, ax1, ax2, ax3])\n",
" T_reshape_7[ax0, ax1, ax2, ax3] = rxplaceholder_7[(ax1 + ax2 + ax3) % 256]\n",
" for i0_7, i1_7, i2_7, i3_7 in tir.grid(1, 256, 14, 14):\n",
" with tir.block(\"T_add_1\"):\n",
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0_7, i1_7, i2_7, i3_7])\n",
" tir.reads(T_multiply_5[ax0, ax1, ax2, ax3], T_reshape_7[ax0, ax1, 0, 0])\n",
" tir.writes(T_add_2[ax0, ax1, ax2, ax3])\n",
" T_add_2[ax0, ax1, ax2, ax3] = T_multiply_5[ax0, ax1, ax2, ax3] + T_reshape_7[ax0, ax1, 0, 0]\n",
" for i0_8 in tir.serial(256):\n",
" with tir.block(\"T_multiply_1\"):\n",
" ax0 = tir.axis.spatial(256, i0_8)\n",
" tir.reads(rxplaceholder_8[ax0])\n",
" tir.writes(T_multiply_3[ax0])\n",
" T_multiply_3[ax0] = rxplaceholder_8[ax0]\n",
" for i0_9 in tir.serial(256):\n",
" with tir.block(\"T_multiply_2\"):\n",
" ax0 = tir.axis.spatial(256, i0_9)\n",
" tir.reads(rxplaceholder_9[ax0])\n",
" tir.writes(T_multiply_4[ax0])\n",
" T_multiply_4[ax0] = rxplaceholder_9[ax0]\n",
" \n",
" @tir.prim_func\n",
" def conv2d_nchw7(rxplaceholder_2: tir.Buffer[(1, 128, 28, 28), \"float32\"], rxplaceholder_3: tir.Buffer[(512, 128, 1, 1), \"float32\"], conv2d_nchw_1: tir.Buffer[(1, 512, 28, 28), \"float32\"]) -> None:\n",
" # function attr dict\n",
" tir.func_attr({\"global_symbol\": \"conv2d_nchw7\", \"tir.noalias\": True})\n",
" # body\n",
" # with tir.block(\"root\")\n",
" pad_temp_1 = tir.alloc_buffer([1, 128, 28, 28], dtype=\"float32\")\n",
" for i0, i1, i2, i3 in tir.grid(1, 128, 28, 28):\n",
" with tir.block(\"pad_temp\"):\n",
" i0_2, i1_2, i2_2, i3_2 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n",
" tir.reads(rxplaceholder_2[i0_2, i1_2, i2_2, i3_2])\n",
" tir.writes(pad_temp_1[i0_2, i1_2, i2_2, i3_2])\n",
" pad_temp_1[i0_2, i1_2, i2_2, i3_2] = rxplaceholder_2[i0_2, i1_2, i2_2, i3_2]\n",
" for i0_3, i1_3, i2_3, i3_3, i4, i5, i6 in tir.grid(1, 512, 28, 28, 128, 1, 1):\n",
" with tir.block(\"conv2d_nchw\"):\n",
" nn, ff, yy, xx, rc, ry, rx = tir.axis.remap(\"SSSSRRR\", [i0_3, i1_3, i2_3, i3_3, i4, i5, i6])\n",
" tir.reads(pad_temp_1[nn, rc, yy + ry, xx + rx], rxplaceholder_3[ff, rc, ry, rx])\n",
" tir.writes(conv2d_nchw_1[nn, ff, yy, xx])\n",
" with tir.init():\n",
" conv2d_nchw_1[nn, ff, yy, xx] = tir.float32(0)\n",
" conv2d_nchw_1[nn, ff, yy, xx] = conv2d_nchw_1[nn, ff, yy, xx] + pad_temp_1[nn, rc, yy + ry, xx + rx] * rxplaceholder_3[ff, rc, ry, rx]\n",
" \n",
" @tir.prim_func\n",
" def max_pool2d(rxplaceholder_1: tir.Buffer[(1, 64, 112, 112), \"float32\"], tensor_1: tir.Buffer[(1, 64, 56, 56), \"float32\"]) -> None:\n",
" # function attr dict\n",
" tir.func_attr({\"global_symbol\": \"max_pool2d\", \"tir.noalias\": True})\n",
" # body\n",
" # with tir.block(\"root\")\n",
" pad_temp_1 = tir.alloc_buffer([1, 64, 114, 114], dtype=\"float32\")\n",
" for i0, i1, i2, i3 in tir.grid(1, 64, 114, 114):\n",
" with tir.block(\"pad_temp\"):\n",
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n",
" tir.reads(rxplaceholder_1[ax0, ax1, ax2 - 1, ax3 - 1])\n",
" tir.writes(pad_temp_1[ax0, ax1, ax2, ax3])\n",
" pad_temp_1[ax0, ax1, ax2, ax3] = tir.if_then_else(1 <= ax2 and ax2 < 113 and 1 <= ax3 and ax3 < 113, rxplaceholder_1[ax0, ax1, ax2 - 1, ax3 - 1], tir.float32(-3.4028234663852886e+38), dtype=\"float32\")\n",
" for i0, i1, i2, i3, i4, i5 in tir.grid(1, 64, 56, 56, 3, 3):\n",
" with tir.block(\"tensor\"):\n",
" ax0, ax1, ax2, ax3, rv0, rv1 = tir.axis.remap(\"SSSSRR\", [i0, i1, i2, i3, i4, i5])\n",
" tir.reads(pad_temp_1[ax0, ax1, ax2 * 2 + rv0, ax3 * 2 + rv1])\n",
" tir.writes(tensor_1[ax0, ax1, ax2, ax3])\n",
" with tir.init():\n",
" tensor_1[ax0, ax1, ax2, ax3] = tir.float32(-3.4028234663852886e+38)\n",
" tensor_1[ax0, ax1, ax2, ax3] = tir.max(tensor_1[ax0, ax1, ax2, ax3], pad_temp_1[ax0, ax1, ax2 * 2 + rv0, ax3 * 2 + rv1])\n",
" \n",
" @tir.prim_func\n",
" def conv2d_nchw6(rxplaceholder_2: tir.Buffer[(1, 128, 28, 28), \"float32\"], rxplaceholder_3: tir.Buffer[(128, 128, 3, 3), \"float32\"], conv2d_nchw_1: tir.Buffer[(1, 128, 28, 28), \"float32\"]) -> None:\n",
" # function attr dict\n",
" tir.func_attr({\"global_symbol\": \"conv2d_nchw6\", \"tir.noalias\": True})\n",
" # body\n",
" # with tir.block(\"root\")\n",
" pad_temp_1 = tir.alloc_buffer([1, 128, 30, 30], dtype=\"float32\")\n",
" for i0, i1, i2, i3 in tir.grid(1, 128, 30, 30):\n",
" with tir.block(\"pad_temp\"):\n",
" i0_2, i1_2, i2_2, i3_2 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n",
" tir.reads(rxplaceholder_2[i0_2, i1_2, i2_2 - 1, i3_2 - 1])\n",
" tir.writes(pad_temp_1[i0_2, i1_2, i2_2, i3_2])\n",
" pad_temp_1[i0_2, i1_2, i2_2, i3_2] = tir.if_then_else(1 <= i2_2 and i2_2 < 29 and 1 <= i3_2 and i3_2 < 29, rxplaceholder_2[i0_2, i1_2, i2_2 - 1, i3_2 - 1], tir.float32(0), dtype=\"float32\")\n",
" for i0_3, i1_3, i2_3, i3_3, i4, i5, i6 in tir.grid(1, 128, 28, 28, 128, 3, 3):\n",
" with tir.block(\"conv2d_nchw\"):\n",
" nn, ff, yy, xx, rc, ry, rx = tir.axis.remap(\"SSSSRRR\", [i0_3, i1_3, i2_3, i3_3, i4, i5, i6])\n",
" tir.reads(pad_temp_1[nn, rc, yy + ry, xx + rx], rxplaceholder_3[ff, rc, ry, rx])\n",
" tir.writes(conv2d_nchw_1[nn, ff, yy, xx])\n",
" with tir.init():\n",
" conv2d_nchw_1[nn, ff, yy, xx] = tir.float32(0)\n",
" conv2d_nchw_1[nn, ff, yy, xx] = conv2d_nchw_1[nn, ff, yy, xx] + pad_temp_1[nn, rc, yy + ry, xx + rx] * rxplaceholder_3[ff, rc, ry, rx]\n",
" \n",
" @tir.prim_func\n",
" def conv2d_nchw1(rxplaceholder_2: tir.Buffer[(1, 64, 56, 56), \"float32\"], rxplaceholder_3: tir.Buffer[(64, 64, 1, 1), \"float32\"], conv2d_nchw_1: tir.Buffer[(1, 64, 56, 56), \"float32\"]) -> None:\n",
" # function attr dict\n",
" tir.func_attr({\"global_symbol\": \"conv2d_nchw1\", \"tir.noalias\": True})\n",
" # body\n",
" # with tir.block(\"root\")\n",
" pad_temp_1 = tir.alloc_buffer([1, 64, 56, 56], dtype=\"float32\")\n",
" for i0, i1, i2, i3 in tir.grid(1, 64, 56, 56):\n",
" with tir.block(\"pad_temp\"):\n",
" i0_2, i1_2, i2_2, i3_2 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n",
" tir.reads(rxplaceholder_2[i0_2, i1_2, i2_2, i3_2])\n",
" tir.writes(pad_temp_1[i0_2, i1_2, i2_2, i3_2])\n",
" pad_temp_1[i0_2, i1_2, i2_2, i3_2] = rxplaceholder_2[i0_2, i1_2, i2_2, i3_2]\n",
" for i0_3, i1_3, i2_3, i3_3, i4, i5, i6 in tir.grid(1, 64, 56, 56, 64, 1, 1):\n",
" with tir.block(\"conv2d_nchw\"):\n",
" nn, ff, yy, xx, rc, ry, rx = tir.axis.remap(\"SSSSRRR\", [i0_3, i1_3, i2_3, i3_3, i4, i5, i6])\n",
" tir.reads(pad_temp_1[nn, rc, yy + ry, xx + rx], rxplaceholder_3[ff, rc, ry, rx])\n",
" tir.writes(conv2d_nchw_1[nn, ff, yy, xx])\n",
" with tir.init():\n",
" conv2d_nchw_1[nn, ff, yy, xx] = tir.float32(0)\n",
" conv2d_nchw_1[nn, ff, yy, xx] = conv2d_nchw_1[nn, ff, yy, xx] + pad_temp_1[nn, rc, yy + ry, xx + rx] * rxplaceholder_3[ff, rc, ry, rx]\n",
" \n",
" @tir.prim_func\n",
" def batch_norm5(rxplaceholder_5: tir.Buffer[(1, 512, 28, 28), \"float32\"], rxplaceholder_6: tir.Buffer[(512,), \"float32\"], rxplaceholder_7: tir.Buffer[(512,), \"float32\"], rxplaceholder_8: tir.Buffer[(512,), \"float32\"], rxplaceholder_9: tir.Buffer[(512,), \"float32\"], T_add_2: tir.Buffer[(1, 512, 28, 28), \"float32\"], T_multiply_3: tir.Buffer[(512,), \"float32\"], T_multiply_4: tir.Buffer[(512,), \"float32\"]) -> None:\n",
" # function attr dict\n",
" tir.func_attr({\"global_symbol\": \"batch_norm5\", \"tir.noalias\": True})\n",
" # body\n",
" # with tir.block(\"root\")\n",
" T_reshape_4 = tir.alloc_buffer([1, 512, 1, 1], dtype=\"float32\")\n",
" T_subtract_1 = tir.alloc_buffer([1, 512, 28, 28], dtype=\"float32\")\n",
" T_reshape_5 = tir.alloc_buffer([1, 512, 1, 1], dtype=\"float32\")\n",
" T_add_3 = tir.alloc_buffer([1, 512, 1, 1], dtype=\"float32\")\n",
" compute_1 = tir.alloc_buffer([1, 512, 1, 1], dtype=\"float32\")\n",
" T_divide_1 = tir.alloc_buffer([1, 512, 28, 28], dtype=\"float32\")\n",
" T_reshape_6 = tir.alloc_buffer([1, 512, 1, 1], dtype=\"float32\")\n",
" T_multiply_5 = tir.alloc_buffer([1, 512, 28, 28], dtype=\"float32\")\n",
" T_reshape_7 = tir.alloc_buffer([1, 512, 1, 1], dtype=\"float32\")\n",
" for i0, i1, i2, i3 in tir.grid(1, 512, 1, 1):\n",
" with tir.block(\"T_reshape\"):\n",
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n",
" tir.reads(rxplaceholder_8[(ax1 + ax2 + ax3) % 512])\n",
" tir.writes(T_reshape_4[ax0, ax1, ax2, ax3])\n",
" T_reshape_4[ax0, ax1, ax2, ax3] = rxplaceholder_8[(ax1 + ax2 + ax3) % 512]\n",
" for i0, i1, i2, i3 in tir.grid(1, 512, 28, 28):\n",
" with tir.block(\"T_subtract\"):\n",
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n",
" tir.reads(rxplaceholder_5[ax0, ax1, ax2, ax3], T_reshape_4[ax0, ax1, 0, 0])\n",
" tir.writes(T_subtract_1[ax0, ax1, ax2, ax3])\n",
" T_subtract_1[ax0, ax1, ax2, ax3] = rxplaceholder_5[ax0, ax1, ax2, ax3] - T_reshape_4[ax0, ax1, 0, 0]\n",
" for i0, i1, i2, i3 in tir.grid(1, 512, 1, 1):\n",
" with tir.block(\"T_reshape_1\"):\n",
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n",
" tir.reads(rxplaceholder_9[(ax1 + ax2 + ax3) % 512])\n",
" tir.writes(T_reshape_5[ax0, ax1, ax2, ax3])\n",
" T_reshape_5[ax0, ax1, ax2, ax3] = rxplaceholder_9[(ax1 + ax2 + ax3) % 512]\n",
" for i0, i1, i2, i3 in tir.grid(1, 512, 1, 1):\n",
" with tir.block(\"T_add\"):\n",
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n",
" tir.reads(T_reshape_5[ax0, ax1, ax2, ax3])\n",
" tir.writes(T_add_3[ax0, ax1, ax2, ax3])\n",
" T_add_3[ax0, ax1, ax2, ax3] = T_reshape_5[ax0, ax1, ax2, ax3] + tir.float32(1.9999999494757503e-05)\n",
" for i0, i1, i2, i3 in tir.grid(1, 512, 1, 1):\n",
" with tir.block(\"compute\"):\n",
" i0_2, i1_2, i2_2, i3_2 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n",
" tir.reads(T_add_3[i0_2, i1_2, i2_2, i3_2])\n",
" tir.writes(compute_1[i0_2, i1_2, i2_2, i3_2])\n",
" compute_1[i0_2, i1_2, i2_2, i3_2] = tir.sqrt(T_add_3[i0_2, i1_2, i2_2, i3_2], dtype=\"float32\")\n",
" for i0_3, i1_3, i2_3, i3_3 in tir.grid(1, 512, 28, 28):\n",
" with tir.block(\"T_divide\"):\n",
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0_3, i1_3, i2_3, i3_3])\n",
" tir.reads(T_subtract_1[ax0, ax1, ax2, ax3], compute_1[ax0, ax1, 0, 0])\n",
" tir.writes(T_divide_1[ax0, ax1, ax2, ax3])\n",
" T_divide_1[ax0, ax1, ax2, ax3] = T_subtract_1[ax0, ax1, ax2, ax3] / compute_1[ax0, ax1, 0, 0]\n",
" for i0_4, i1_4, i2_4, i3_4 in tir.grid(1, 512, 1, 1):\n",
" with tir.block(\"T_reshape_2\"):\n",
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0_4, i1_4, i2_4, i3_4])\n",
" tir.reads(rxplaceholder_6[(ax1 + ax2 + ax3) % 512])\n",
" tir.writes(T_reshape_6[ax0, ax1, ax2, ax3])\n",
" T_reshape_6[ax0, ax1, ax2, ax3] = rxplaceholder_6[(ax1 + ax2 + ax3) % 512]\n",
" for i0_5, i1_5, i2_5, i3_5 in tir.grid(1, 512, 28, 28):\n",
" with tir.block(\"T_multiply\"):\n",
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0_5, i1_5, i2_5, i3_5])\n",
" tir.reads(T_divide_1[ax0, ax1, ax2, ax3], T_reshape_6[ax0, ax1, 0, 0])\n",
" tir.writes(T_multiply_5[ax0, ax1, ax2, ax3])\n",
" T_multiply_5[ax0, ax1, ax2, ax3] = T_divide_1[ax0, ax1, ax2, ax3] * T_reshape_6[ax0, ax1, 0, 0]\n",
" for i0_6, i1_6, i2_6, i3_6 in tir.grid(1, 512, 1, 1):\n",
" with tir.block(\"T_reshape_3\"):\n",
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0_6, i1_6, i2_6, i3_6])\n",
" tir.reads(rxplaceholder_7[(ax1 + ax2 + ax3) % 512])\n",
" tir.writes(T_reshape_7[ax0, ax1, ax2, ax3])\n",
" T_reshape_7[ax0, ax1, ax2, ax3] = rxplaceholder_7[(ax1 + ax2 + ax3) % 512]\n",
" for i0_7, i1_7, i2_7, i3_7 in tir.grid(1, 512, 28, 28):\n",
" with tir.block(\"T_add_1\"):\n",
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0_7, i1_7, i2_7, i3_7])\n",
" tir.reads(T_multiply_5[ax0, ax1, ax2, ax3], T_reshape_7[ax0, ax1, 0, 0])\n",
" tir.writes(T_add_2[ax0, ax1, ax2, ax3])\n",
" T_add_2[ax0, ax1, ax2, ax3] = T_multiply_5[ax0, ax1, ax2, ax3] + T_reshape_7[ax0, ax1, 0, 0]\n",
" for i0_8 in tir.serial(512):\n",
" with tir.block(\"T_multiply_1\"):\n",
" ax0 = tir.axis.spatial(512, i0_8)\n",
" tir.reads(rxplaceholder_8[ax0])\n",
" tir.writes(T_multiply_3[ax0])\n",
" T_multiply_3[ax0] = rxplaceholder_8[ax0]\n",
" for i0_9 in tir.serial(512):\n",
" with tir.block(\"T_multiply_2\"):\n",
" ax0 = tir.axis.spatial(512, i0_9)\n",
" tir.reads(rxplaceholder_9[ax0])\n",
" tir.writes(T_multiply_4[ax0])\n",
" T_multiply_4[ax0] = rxplaceholder_9[ax0]\n",
" \n",
" @tir.prim_func\n",
" def batch_norm3(rxplaceholder_5: tir.Buffer[(1, 256, 56, 56), \"float32\"], rxplaceholder_6: tir.Buffer[(256,), \"float32\"], rxplaceholder_7: tir.Buffer[(256,), \"float32\"], rxplaceholder_8: tir.Buffer[(256,), \"float32\"], rxplaceholder_9: tir.Buffer[(256,), \"float32\"], T_add_2: tir.Buffer[(1, 256, 56, 56), \"float32\"], T_multiply_3: tir.Buffer[(256,), \"float32\"], T_multiply_4: tir.Buffer[(256,), \"float32\"]) -> None:\n",
" # function attr dict\n",
" tir.func_attr({\"global_symbol\": \"batch_norm3\", \"tir.noalias\": True})\n",
" # body\n",
" # with tir.block(\"root\")\n",
" T_reshape_4 = tir.alloc_buffer([1, 256, 1, 1], dtype=\"float32\")\n",
" T_subtract_1 = tir.alloc_buffer([1, 256, 56, 56], dtype=\"float32\")\n",
" T_reshape_5 = tir.alloc_buffer([1, 256, 1, 1], dtype=\"float32\")\n",
" T_add_3 = tir.alloc_buffer([1, 256, 1, 1], dtype=\"float32\")\n",
" compute_1 = tir.alloc_buffer([1, 256, 1, 1], dtype=\"float32\")\n",
" T_divide_1 = tir.alloc_buffer([1, 256, 56, 56], dtype=\"float32\")\n",
" T_reshape_6 = tir.alloc_buffer([1, 256, 1, 1], dtype=\"float32\")\n",
" T_multiply_5 = tir.alloc_buffer([1, 256, 56, 56], dtype=\"float32\")\n",
" T_reshape_7 = tir.alloc_buffer([1, 256, 1, 1], dtype=\"float32\")\n",
" for i0, i1, i2, i3 in tir.grid(1, 256, 1, 1):\n",
" with tir.block(\"T_reshape\"):\n",
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n",
" tir.reads(rxplaceholder_8[(ax1 + ax2 + ax3) % 256])\n",
" tir.writes(T_reshape_4[ax0, ax1, ax2, ax3])\n",
" T_reshape_4[ax0, ax1, ax2, ax3] = rxplaceholder_8[(ax1 + ax2 + ax3) % 256]\n",
" for i0, i1, i2, i3 in tir.grid(1, 256, 56, 56):\n",
" with tir.block(\"T_subtract\"):\n",
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n",
" tir.reads(rxplaceholder_5[ax0, ax1, ax2, ax3], T_reshape_4[ax0, ax1, 0, 0])\n",
" tir.writes(T_subtract_1[ax0, ax1, ax2, ax3])\n",
" T_subtract_1[ax0, ax1, ax2, ax3] = rxplaceholder_5[ax0, ax1, ax2, ax3] - T_reshape_4[ax0, ax1, 0, 0]\n",
" for i0, i1, i2, i3 in tir.grid(1, 256, 1, 1):\n",
" with tir.block(\"T_reshape_1\"):\n",
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n",
" tir.reads(rxplaceholder_9[(ax1 + ax2 + ax3) % 256])\n",
" tir.writes(T_reshape_5[ax0, ax1, ax2, ax3])\n",
" T_reshape_5[ax0, ax1, ax2, ax3] = rxplaceholder_9[(ax1 + ax2 + ax3) % 256]\n",
" for i0, i1, i2, i3 in tir.grid(1, 256, 1, 1):\n",
" with tir.block(\"T_add\"):\n",
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n",
" tir.reads(T_reshape_5[ax0, ax1, ax2, ax3])\n",
" tir.writes(T_add_3[ax0, ax1, ax2, ax3])\n",
" T_add_3[ax0, ax1, ax2, ax3] = T_reshape_5[ax0, ax1, ax2, ax3] + tir.float32(1.9999999494757503e-05)\n",
" for i0, i1, i2, i3 in tir.grid(1, 256, 1, 1):\n",
" with tir.block(\"compute\"):\n",
" i0_2, i1_2, i2_2, i3_2 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n",
" tir.reads(T_add_3[i0_2, i1_2, i2_2, i3_2])\n",
" tir.writes(compute_1[i0_2, i1_2, i2_2, i3_2])\n",
" compute_1[i0_2, i1_2, i2_2, i3_2] = tir.sqrt(T_add_3[i0_2, i1_2, i2_2, i3_2], dtype=\"float32\")\n",
" for i0_3, i1_3, i2_3, i3_3 in tir.grid(1, 256, 56, 56):\n",
" with tir.block(\"T_divide\"):\n",
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0_3, i1_3, i2_3, i3_3])\n",
" tir.reads(T_subtract_1[ax0, ax1, ax2, ax3], compute_1[ax0, ax1, 0, 0])\n",
" tir.writes(T_divide_1[ax0, ax1, ax2, ax3])\n",
" T_divide_1[ax0, ax1, ax2, ax3] = T_subtract_1[ax0, ax1, ax2, ax3] / compute_1[ax0, ax1, 0, 0]\n",
" for i0_4, i1_4, i2_4, i3_4 in tir.grid(1, 256, 1, 1):\n",
" with tir.block(\"T_reshape_2\"):\n",
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0_4, i1_4, i2_4, i3_4])\n",
" tir.reads(rxplaceholder_6[(ax1 + ax2 + ax3) % 256])\n",
" tir.writes(T_reshape_6[ax0, ax1, ax2, ax3])\n",
" T_reshape_6[ax0, ax1, ax2, ax3] = rxplaceholder_6[(ax1 + ax2 + ax3) % 256]\n",
" for i0_5, i1_5, i2_5, i3_5 in tir.grid(1, 256, 56, 56):\n",
" with tir.block(\"T_multiply\"):\n",
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0_5, i1_5, i2_5, i3_5])\n",
" tir.reads(T_divide_1[ax0, ax1, ax2, ax3], T_reshape_6[ax0, ax1, 0, 0])\n",
" tir.writes(T_multiply_5[ax0, ax1, ax2, ax3])\n",
" T_multiply_5[ax0, ax1, ax2, ax3] = T_divide_1[ax0, ax1, ax2, ax3] * T_reshape_6[ax0, ax1, 0, 0]\n",
" for i0_6, i1_6, i2_6, i3_6 in tir.grid(1, 256, 1, 1):\n",
" with tir.block(\"T_reshape_3\"):\n",
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0_6, i1_6, i2_6, i3_6])\n",
" tir.reads(rxplaceholder_7[(ax1 + ax2 + ax3) % 256])\n",
" tir.writes(T_reshape_7[ax0, ax1, ax2, ax3])\n",
" T_reshape_7[ax0, ax1, ax2, ax3] = rxplaceholder_7[(ax1 + ax2 + ax3) % 256]\n",
" for i0_7, i1_7, i2_7, i3_7 in tir.grid(1, 256, 56, 56):\n",
" with tir.block(\"T_add_1\"):\n",
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0_7, i1_7, i2_7, i3_7])\n",
" tir.reads(T_multiply_5[ax0, ax1, ax2, ax3], T_reshape_7[ax0, ax1, 0, 0])\n",
" tir.writes(T_add_2[ax0, ax1, ax2, ax3])\n",
" T_add_2[ax0, ax1, ax2, ax3] = T_multiply_5[ax0, ax1, ax2, ax3] + T_reshape_7[ax0, ax1, 0, 0]\n",
" for i0_8 in tir.serial(256):\n",
" with tir.block(\"T_multiply_1\"):\n",
" ax0 = tir.axis.spatial(256, i0_8)\n",
" tir.reads(rxplaceholder_8[ax0])\n",
" tir.writes(T_multiply_3[ax0])\n",
" T_multiply_3[ax0] = rxplaceholder_8[ax0]\n",
" for i0_9 in tir.serial(256):\n",
" with tir.block(\"T_multiply_2\"):\n",
" ax0 = tir.axis.spatial(256, i0_9)\n",
" tir.reads(rxplaceholder_9[ax0])\n",
" tir.writes(T_multiply_4[ax0])\n",
" T_multiply_4[ax0] = rxplaceholder_9[ax0]\n",
" \n",
" @tir.prim_func\n",
" def relu(rxplaceholder_1: tir.Buffer[(1, 64, 112, 112), \"float32\"], T_relu_1: tir.Buffer[(1, 64, 112, 112), \"float32\"]) -> None:\n",
" # function attr dict\n",
" tir.func_attr({\"global_symbol\": \"relu\", \"tir.noalias\": True})\n",
" # body\n",
" # with tir.block(\"root\")\n",
" for i0, i1, i2, i3 in tir.grid(1, 64, 112, 112):\n",
" with tir.block(\"T_relu\"):\n",
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n",
" tir.reads(rxplaceholder_1[ax0, ax1, ax2, ax3])\n",
" tir.writes(T_relu_1[ax0, ax1, ax2, ax3])\n",
" T_relu_1[ax0, ax1, ax2, ax3] = tir.max(rxplaceholder_1[ax0, ax1, ax2, ax3], tir.float32(0))\n",
" \n",
" @tir.prim_func\n",
" def conv2d_nchw10(rxplaceholder_2: tir.Buffer[(1, 512, 28, 28), \"float32\"], rxplaceholder_3: tir.Buffer[(256, 512, 1, 1), \"float32\"], conv2d_nchw_1: tir.Buffer[(1, 256, 14, 14), \"float32\"]) -> None:\n",
" # function attr dict\n",
" tir.func_attr({\"global_symbol\": \"conv2d_nchw10\", \"tir.noalias\": True})\n",
" # body\n",
" # with tir.block(\"root\")\n",
" pad_temp_1 = tir.alloc_buffer([1, 512, 28, 28], dtype=\"float32\")\n",
" for i0, i1, i2, i3 in tir.grid(1, 512, 28, 28):\n",
" with tir.block(\"pad_temp\"):\n",
" i0_2, i1_2, i2_2, i3_2 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n",
" tir.reads(rxplaceholder_2[i0_2, i1_2, i2_2, i3_2])\n",
" tir.writes(pad_temp_1[i0_2, i1_2, i2_2, i3_2])\n",
" pad_temp_1[i0_2, i1_2, i2_2, i3_2] = rxplaceholder_2[i0_2, i1_2, i2_2, i3_2]\n",
" for i0_3, i1_3, i2_3, i3_3, i4, i5, i6 in tir.grid(1, 256, 14, 14, 512, 1, 1):\n",
" with tir.block(\"conv2d_nchw\"):\n",
" nn, ff, yy, xx, rc, ry, rx = tir.axis.remap(\"SSSSRRR\", [i0_3, i1_3, i2_3, i3_3, i4, i5, i6])\n",
" tir.reads(pad_temp_1[nn, rc, yy * 2 + ry, xx * 2 + rx], rxplaceholder_3[ff, rc, ry, rx])\n",
" tir.writes(conv2d_nchw_1[nn, ff, yy, xx])\n",
" with tir.init():\n",
" conv2d_nchw_1[nn, ff, yy, xx] = tir.float32(0)\n",
" conv2d_nchw_1[nn, ff, yy, xx] = conv2d_nchw_1[nn, ff, yy, xx] + pad_temp_1[nn, rc, yy * 2 + ry, xx * 2 + rx] * rxplaceholder_3[ff, rc, ry, rx]\n",
" \n",
" @tir.prim_func\n",
" def conv2d_nchw5(rxplaceholder_2: tir.Buffer[(1, 256, 56, 56), \"float32\"], rxplaceholder_3: tir.Buffer[(128, 256, 1, 1), \"float32\"], conv2d_nchw_1: tir.Buffer[(1, 128, 28, 28), \"float32\"]) -> None:\n",
" # function attr dict\n",
" tir.func_attr({\"global_symbol\": \"conv2d_nchw5\", \"tir.noalias\": True})\n",
" # body\n",
" # with tir.block(\"root\")\n",
" pad_temp_1 = tir.alloc_buffer([1, 256, 56, 56], dtype=\"float32\")\n",
" for i0, i1, i2, i3 in tir.grid(1, 256, 56, 56):\n",
" with tir.block(\"pad_temp\"):\n",
" i0_2, i1_2, i2_2, i3_2 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n",
" tir.reads(rxplaceholder_2[i0_2, i1_2, i2_2, i3_2])\n",
" tir.writes(pad_temp_1[i0_2, i1_2, i2_2, i3_2])\n",
" pad_temp_1[i0_2, i1_2, i2_2, i3_2] = rxplaceholder_2[i0_2, i1_2, i2_2, i3_2]\n",
" for i0_3, i1_3, i2_3, i3_3, i4, i5, i6 in tir.grid(1, 128, 28, 28, 256, 1, 1):\n",
" with tir.block(\"conv2d_nchw\"):\n",
" nn, ff, yy, xx, rc, ry, rx = tir.axis.remap(\"SSSSRRR\", [i0_3, i1_3, i2_3, i3_3, i4, i5, i6])\n",
" tir.reads(pad_temp_1[nn, rc, yy * 2 + ry, xx * 2 + rx], rxplaceholder_3[ff, rc, ry, rx])\n",
" tir.writes(conv2d_nchw_1[nn, ff, yy, xx])\n",
" with tir.init():\n",
" conv2d_nchw_1[nn, ff, yy, xx] = tir.float32(0)\n",
" conv2d_nchw_1[nn, ff, yy, xx] = conv2d_nchw_1[nn, ff, yy, xx] + pad_temp_1[nn, rc, yy * 2 + ry, xx * 2 + rx] * rxplaceholder_3[ff, rc, ry, rx]\n",
" \n",
" @tir.prim_func\n",
" def relu5(rxplaceholder_1: tir.Buffer[(1, 256, 14, 14), \"float32\"], T_relu_1: tir.Buffer[(1, 256, 14, 14), \"float32\"]) -> None:\n",
" # function attr dict\n",
" tir.func_attr({\"global_symbol\": \"relu5\", \"tir.noalias\": True})\n",
" # body\n",
" # with tir.block(\"root\")\n",
" for i0, i1, i2, i3 in tir.grid(1, 256, 14, 14):\n",
" with tir.block(\"T_relu\"):\n",
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n",
" tir.reads(rxplaceholder_1[ax0, ax1, ax2, ax3])\n",
" tir.writes(T_relu_1[ax0, ax1, ax2, ax3])\n",
" T_relu_1[ax0, ax1, ax2, ax3] = tir.max(rxplaceholder_1[ax0, ax1, ax2, ax3], tir.float32(0))\n",
" \n",
" @tir.prim_func\n",
" def conv2d_nchw15(rxplaceholder_2: tir.Buffer[(1, 1024, 14, 14), \"float32\"], rxplaceholder_3: tir.Buffer[(512, 1024, 1, 1), \"float32\"], conv2d_nchw_1: tir.Buffer[(1, 512, 7, 7), \"float32\"]) -> None:\n",
" # function attr dict\n",
" tir.func_attr({\"global_symbol\": \"conv2d_nchw15\", \"tir.noalias\": True})\n",
" # body\n",
" # with tir.block(\"root\")\n",
" pad_temp_1 = tir.alloc_buffer([1, 1024, 14, 14], dtype=\"float32\")\n",
" for i0, i1, i2, i3 in tir.grid(1, 1024, 14, 14):\n",
" with tir.block(\"pad_temp\"):\n",
" i0_2, i1_2, i2_2, i3_2 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n",
" tir.reads(rxplaceholder_2[i0_2, i1_2, i2_2, i3_2])\n",
" tir.writes(pad_temp_1[i0_2, i1_2, i2_2, i3_2])\n",
" pad_temp_1[i0_2, i1_2, i2_2, i3_2] = rxplaceholder_2[i0_2, i1_2, i2_2, i3_2]\n",
" for i0_3, i1_3, i2_3, i3_3, i4, i5, i6 in tir.grid(1, 512, 7, 7, 1024, 1, 1):\n",
" with tir.block(\"conv2d_nchw\"):\n",
" nn, ff, yy, xx, rc, ry, rx = tir.axis.remap(\"SSSSRRR\", [i0_3, i1_3, i2_3, i3_3, i4, i5, i6])\n",
" tir.reads(pad_temp_1[nn, rc, yy * 2 + ry, xx * 2 + rx], rxplaceholder_3[ff, rc, ry, rx])\n",
" tir.writes(conv2d_nchw_1[nn, ff, yy, xx])\n",
" with tir.init():\n",
" conv2d_nchw_1[nn, ff, yy, xx] = tir.float32(0)\n",
" conv2d_nchw_1[nn, ff, yy, xx] = conv2d_nchw_1[nn, ff, yy, xx] + pad_temp_1[nn, rc, yy * 2 + ry, xx * 2 + rx] * rxplaceholder_3[ff, rc, ry, rx]\n",
" \n",
" @tir.prim_func\n",
" def conv2d_nchw9(rxplaceholder_2: tir.Buffer[(1, 512, 28, 28), \"float32\"], rxplaceholder_3: tir.Buffer[(128, 512, 1, 1), \"float32\"], conv2d_nchw_1: tir.Buffer[(1, 128, 28, 28), \"float32\"]) -> None:\n",
" # function attr dict\n",
" tir.func_attr({\"global_symbol\": \"conv2d_nchw9\", \"tir.noalias\": True})\n",
" # body\n",
" # with tir.block(\"root\")\n",
" pad_temp_1 = tir.alloc_buffer([1, 512, 28, 28], dtype=\"float32\")\n",
" for i0, i1, i2, i3 in tir.grid(1, 512, 28, 28):\n",
" with tir.block(\"pad_temp\"):\n",
" i0_2, i1_2, i2_2, i3_2 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n",
" tir.reads(rxplaceholder_2[i0_2, i1_2, i2_2, i3_2])\n",
" tir.writes(pad_temp_1[i0_2, i1_2, i2_2, i3_2])\n",
" pad_temp_1[i0_2, i1_2, i2_2, i3_2] = rxplaceholder_2[i0_2, i1_2, i2_2, i3_2]\n",
" for i0_3, i1_3, i2_3, i3_3, i4, i5, i6 in tir.grid(1, 128, 28, 28, 512, 1, 1):\n",
" with tir.block(\"conv2d_nchw\"):\n",
" nn, ff, yy, xx, rc, ry, rx = tir.axis.remap(\"SSSSRRR\", [i0_3, i1_3, i2_3, i3_3, i4, i5, i6])\n",
" tir.reads(pad_temp_1[nn, rc, yy + ry, xx + rx], rxplaceholder_3[ff, rc, ry, rx])\n",
" tir.writes(conv2d_nchw_1[nn, ff, yy, xx])\n",
" with tir.init():\n",
" conv2d_nchw_1[nn, ff, yy, xx] = tir.float32(0)\n",
" conv2d_nchw_1[nn, ff, yy, xx] = conv2d_nchw_1[nn, ff, yy, xx] + pad_temp_1[nn, rc, yy + ry, xx + rx] * rxplaceholder_3[ff, rc, ry, rx]\n",
" \n",
" @tir.prim_func\n",
" def conv2d_nchw17(rxplaceholder_2: tir.Buffer[(1, 512, 7, 7), \"float32\"], rxplaceholder_3: tir.Buffer[(2048, 512, 1, 1), \"float32\"], conv2d_nchw_1: tir.Buffer[(1, 2048, 7, 7), \"float32\"]) -> None:\n",
" # function attr dict\n",
" tir.func_attr({\"global_symbol\": \"conv2d_nchw17\", \"tir.noalias\": True})\n",
" # body\n",
" # with tir.block(\"root\")\n",
" pad_temp_1 = tir.alloc_buffer([1, 512, 7, 7], dtype=\"float32\")\n",
" for i0, i1, i2, i3 in tir.grid(1, 512, 7, 7):\n",
" with tir.block(\"pad_temp\"):\n",
" i0_2, i1_2, i2_2, i3_2 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n",
" tir.reads(rxplaceholder_2[i0_2, i1_2, i2_2, i3_2])\n",
" tir.writes(pad_temp_1[i0_2, i1_2, i2_2, i3_2])\n",
" pad_temp_1[i0_2, i1_2, i2_2, i3_2] = rxplaceholder_2[i0_2, i1_2, i2_2, i3_2]\n",
" for i0_3, i1_3, i2_3, i3_3, i4, i5, i6 in tir.grid(1, 2048, 7, 7, 512, 1, 1):\n",
" with tir.block(\"conv2d_nchw\"):\n",
" nn, ff, yy, xx, rc, ry, rx = tir.axis.remap(\"SSSSRRR\", [i0_3, i1_3, i2_3, i3_3, i4, i5, i6])\n",
" tir.reads(pad_temp_1[nn, rc, yy + ry, xx + rx], rxplaceholder_3[ff, rc, ry, rx])\n",
" tir.writes(conv2d_nchw_1[nn, ff, yy, xx])\n",
" with tir.init():\n",
" conv2d_nchw_1[nn, ff, yy, xx] = tir.float32(0)\n",
" conv2d_nchw_1[nn, ff, yy, xx] = conv2d_nchw_1[nn, ff, yy, xx] + pad_temp_1[nn, rc, yy + ry, xx + rx] * rxplaceholder_3[ff, rc, ry, rx]\n",
" \n",
" @tir.prim_func\n",
" def batch_norm(rxplaceholder_5: tir.Buffer[(1, 3, 224, 224), \"float32\"], rxplaceholder_6: tir.Buffer[(3,), \"float32\"], rxplaceholder_7: tir.Buffer[(3,), \"float32\"], rxplaceholder_8: tir.Buffer[(3,), \"float32\"], rxplaceholder_9: tir.Buffer[(3,), \"float32\"], T_add_2: tir.Buffer[(1, 3, 224, 224), \"float32\"], T_multiply_2: tir.Buffer[(3,), \"float32\"], T_multiply_3: tir.Buffer[(3,), \"float32\"]) -> None:\n",
" # function attr dict\n",
" tir.func_attr({\"global_symbol\": \"batch_norm\", \"tir.noalias\": True})\n",
" # body\n",
" # with tir.block(\"root\")\n",
" T_reshape_3 = tir.alloc_buffer([1, 3, 1, 1], dtype=\"float32\")\n",
" T_subtract_1 = tir.alloc_buffer([1, 3, 224, 224], dtype=\"float32\")\n",
" T_reshape_4 = tir.alloc_buffer([1, 3, 1, 1], dtype=\"float32\")\n",
" T_add_3 = tir.alloc_buffer([1, 3, 1, 1], dtype=\"float32\")\n",
" compute_1 = tir.alloc_buffer([1, 3, 1, 1], dtype=\"float32\")\n",
" T_divide_1 = tir.alloc_buffer([1, 3, 224, 224], dtype=\"float32\")\n",
" T_reshape_5 = tir.alloc_buffer([1, 3, 1, 1], dtype=\"float32\")\n",
" for i0, i1, i2, i3 in tir.grid(1, 3, 1, 1):\n",
" with tir.block(\"T_reshape\"):\n",
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n",
" tir.reads(rxplaceholder_8[(ax1 + ax2 + ax3) % 3])\n",
" tir.writes(T_reshape_3[ax0, ax1, ax2, ax3])\n",
" T_reshape_3[ax0, ax1, ax2, ax3] = rxplaceholder_8[(ax1 + ax2 + ax3) % 3]\n",
" for i0, i1, i2, i3 in tir.grid(1, 3, 224, 224):\n",
" with tir.block(\"T_subtract\"):\n",
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n",
" tir.reads(rxplaceholder_5[ax0, ax1, ax2, ax3], T_reshape_3[ax0, ax1, 0, 0])\n",
" tir.writes(T_subtract_1[ax0, ax1, ax2, ax3])\n",
" T_subtract_1[ax0, ax1, ax2, ax3] = rxplaceholder_5[ax0, ax1, ax2, ax3] - T_reshape_3[ax0, ax1, 0, 0]\n",
" for i0, i1, i2, i3 in tir.grid(1, 3, 1, 1):\n",
" with tir.block(\"T_reshape_1\"):\n",
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n",
" tir.reads(rxplaceholder_9[(ax1 + ax2 + ax3) % 3])\n",
" tir.writes(T_reshape_4[ax0, ax1, ax2, ax3])\n",
" T_reshape_4[ax0, ax1, ax2, ax3] = rxplaceholder_9[(ax1 + ax2 + ax3) % 3]\n",
" for i0, i1, i2, i3 in tir.grid(1, 3, 1, 1):\n",
" with tir.block(\"T_add\"):\n",
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n",
" tir.reads(T_reshape_4[ax0, ax1, ax2, ax3])\n",
" tir.writes(T_add_3[ax0, ax1, ax2, ax3])\n",
" T_add_3[ax0, ax1, ax2, ax3] = T_reshape_4[ax0, ax1, ax2, ax3] + tir.float32(1.9999999494757503e-05)\n",
" for i0, i1, i2, i3 in tir.grid(1, 3, 1, 1):\n",
" with tir.block(\"compute\"):\n",
" i0_2, i1_2, i2_2, i3_2 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n",
" tir.reads(T_add_3[i0_2, i1_2, i2_2, i3_2])\n",
" tir.writes(compute_1[i0_2, i1_2, i2_2, i3_2])\n",
" compute_1[i0_2, i1_2, i2_2, i3_2] = tir.sqrt(T_add_3[i0_2, i1_2, i2_2, i3_2], dtype=\"float32\")\n",
" for i0_3, i1_3, i2_3, i3_3 in tir.grid(1, 3, 224, 224):\n",
" with tir.block(\"T_divide\"):\n",
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0_3, i1_3, i2_3, i3_3])\n",
" tir.reads(T_subtract_1[ax0, ax1, ax2, ax3], compute_1[ax0, ax1, 0, 0])\n",
" tir.writes(T_divide_1[ax0, ax1, ax2, ax3])\n",
" T_divide_1[ax0, ax1, ax2, ax3] = T_subtract_1[ax0, ax1, ax2, ax3] / compute_1[ax0, ax1, 0, 0]\n",
" for i0_4, i1_4, i2_4, i3_4 in tir.grid(1, 3, 1, 1):\n",
" with tir.block(\"T_reshape_2\"):\n",
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0_4, i1_4, i2_4, i3_4])\n",
" tir.reads(rxplaceholder_7[(ax1 + ax2 + ax3) % 3])\n",
" tir.writes(T_reshape_5[ax0, ax1, ax2, ax3])\n",
" T_reshape_5[ax0, ax1, ax2, ax3] = rxplaceholder_7[(ax1 + ax2 + ax3) % 3]\n",
" for i0_5, i1_5, i2_5, i3_5 in tir.grid(1, 3, 224, 224):\n",
" with tir.block(\"T_add_1\"):\n",
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0_5, i1_5, i2_5, i3_5])\n",
" tir.reads(T_divide_1[ax0, ax1, ax2, ax3], T_reshape_5[ax0, ax1, 0, 0])\n",
" tir.writes(T_add_2[ax0, ax1, ax2, ax3])\n",
" T_add_2[ax0, ax1, ax2, ax3] = T_divide_1[ax0, ax1, ax2, ax3] + T_reshape_5[ax0, ax1, 0, 0]\n",
" for i0_6 in tir.serial(3):\n",
" with tir.block(\"T_multiply\"):\n",
" ax0 = tir.axis.spatial(3, i0_6)\n",
" tir.reads(rxplaceholder_8[ax0])\n",
" tir.writes(T_multiply_2[ax0])\n",
" T_multiply_2[ax0] = rxplaceholder_8[ax0]\n",
" for i0_7 in tir.serial(3):\n",
" with tir.block(\"T_multiply_1\"):\n",
" ax0 = tir.axis.spatial(3, i0_7)\n",
" tir.reads(rxplaceholder_9[ax0])\n",
" tir.writes(T_multiply_3[ax0])\n",
" T_multiply_3[ax0] = rxplaceholder_9[ax0]\n",
" \n",
" @tir.prim_func\n",
" def conv2d_nchw14(rxplaceholder_2: tir.Buffer[(1, 1024, 14, 14), \"float32\"], rxplaceholder_3: tir.Buffer[(256, 1024, 1, 1), \"float32\"], conv2d_nchw_1: tir.Buffer[(1, 256, 14, 14), \"float32\"]) -> None:\n",
" # function attr dict\n",
" tir.func_attr({\"global_symbol\": \"conv2d_nchw14\", \"tir.noalias\": True})\n",
" # body\n",
" # with tir.block(\"root\")\n",
" pad_temp_1 = tir.alloc_buffer([1, 1024, 14, 14], dtype=\"float32\")\n",
" for i0, i1, i2, i3 in tir.grid(1, 1024, 14, 14):\n",
" with tir.block(\"pad_temp\"):\n",
" i0_2, i1_2, i2_2, i3_2 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n",
" tir.reads(rxplaceholder_2[i0_2, i1_2, i2_2, i3_2])\n",
" tir.writes(pad_temp_1[i0_2, i1_2, i2_2, i3_2])\n",
" pad_temp_1[i0_2, i1_2, i2_2, i3_2] = rxplaceholder_2[i0_2, i1_2, i2_2, i3_2]\n",
" for i0_3, i1_3, i2_3, i3_3, i4, i5, i6 in tir.grid(1, 256, 14, 14, 1024, 1, 1):\n",
" with tir.block(\"conv2d_nchw\"):\n",
" nn, ff, yy, xx, rc, ry, rx = tir.axis.remap(\"SSSSRRR\", [i0_3, i1_3, i2_3, i3_3, i4, i5, i6])\n",
" tir.reads(pad_temp_1[nn, rc, yy + ry, xx + rx], rxplaceholder_3[ff, rc, ry, rx])\n",
" tir.writes(conv2d_nchw_1[nn, ff, yy, xx])\n",
" with tir.init():\n",
" conv2d_nchw_1[nn, ff, yy, xx] = tir.float32(0)\n",
" conv2d_nchw_1[nn, ff, yy, xx] = conv2d_nchw_1[nn, ff, yy, xx] + pad_temp_1[nn, rc, yy + ry, xx + rx] * rxplaceholder_3[ff, rc, ry, rx]\n",
" \n",
" @tir.prim_func\n",
" def relu8(rxplaceholder_1: tir.Buffer[(1, 2048, 7, 7), \"float32\"], T_relu_1: tir.Buffer[(1, 2048, 7, 7), \"float32\"]) -> None:\n",
" # function attr dict\n",
" tir.func_attr({\"global_symbol\": \"relu8\", \"tir.noalias\": True})\n",
" # body\n",
" # with tir.block(\"root\")\n",
" for i0, i1, i2, i3 in tir.grid(1, 2048, 7, 7):\n",
" with tir.block(\"T_relu\"):\n",
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n",
" tir.reads(rxplaceholder_1[ax0, ax1, ax2, ax3])\n",
" tir.writes(T_relu_1[ax0, ax1, ax2, ax3])\n",
" T_relu_1[ax0, ax1, ax2, ax3] = tir.max(rxplaceholder_1[ax0, ax1, ax2, ax3], tir.float32(0))\n",
" \n",
" @tir.prim_func\n",
" def relu4(rxplaceholder_1: tir.Buffer[(1, 512, 28, 28), \"float32\"], T_relu_1: tir.Buffer[(1, 512, 28, 28), \"float32\"]) -> None:\n",
" # function attr dict\n",
" tir.func_attr({\"global_symbol\": \"relu4\", \"tir.noalias\": True})\n",
" # body\n",
" # with tir.block(\"root\")\n",
" for i0, i1, i2, i3 in tir.grid(1, 512, 28, 28):\n",
" with tir.block(\"T_relu\"):\n",
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n",
" tir.reads(rxplaceholder_1[ax0, ax1, ax2, ax3])\n",
" tir.writes(T_relu_1[ax0, ax1, ax2, ax3])\n",
" T_relu_1[ax0, ax1, ax2, ax3] = tir.max(rxplaceholder_1[ax0, ax1, ax2, ax3], tir.float32(0))\n",
" \n",
" @tir.prim_func\n",
" def batch_flatten(rxplaceholder_1: tir.Buffer[(1, 2048, 1, 1), \"float32\"], tensor_1: tir.Buffer[(1, 2048), \"float32\"]) -> None:\n",
" # function attr dict\n",
" tir.func_attr({\"global_symbol\": \"batch_flatten\", \"tir.noalias\": True})\n",
" # body\n",
" # with tir.block(\"root\")\n",
" for i0, i1 in tir.grid(1, 2048):\n",
" with tir.block(\"tensor\"):\n",
" ax0, ax1 = tir.axis.remap(\"SS\", [i0, i1])\n",
" tir.reads(rxplaceholder_1[ax0, ax1 % 2048, 0, 0])\n",
" tir.writes(tensor_1[ax0, ax1])\n",
" tensor_1[ax0, ax1] = rxplaceholder_1[ax0, ax1 % 2048, 0, 0]\n",
" \n",
" @tir.prim_func\n",
" def relu6(rxplaceholder_1: tir.Buffer[(1, 1024, 14, 14), \"float32\"], T_relu_1: tir.Buffer[(1, 1024, 14, 14), \"float32\"]) -> None:\n",
" # function attr dict\n",
" tir.func_attr({\"global_symbol\": \"relu6\", \"tir.noalias\": True})\n",
" # body\n",
" # with tir.block(\"root\")\n",
" for i0, i1, i2, i3 in tir.grid(1, 1024, 14, 14):\n",
" with tir.block(\"T_relu\"):\n",
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n",
" tir.reads(rxplaceholder_1[ax0, ax1, ax2, ax3])\n",
" tir.writes(T_relu_1[ax0, ax1, ax2, ax3])\n",
" T_relu_1[ax0, ax1, ax2, ax3] = tir.max(rxplaceholder_1[ax0, ax1, ax2, ax3], tir.float32(0))\n",
" \n",
" @tir.prim_func\n",
" def bias_add(rxplaceholder_2: tir.Buffer[(1, 1000), \"float32\"], rxplaceholder_3: tir.Buffer[(1000,), \"float32\"], T_add_1: tir.Buffer[(1, 1000), \"float32\"]) -> None:\n",
" # function attr dict\n",
" tir.func_attr({\"global_symbol\": \"bias_add\", \"tir.noalias\": True})\n",
" # body\n",
" # with tir.block(\"root\")\n",
" for i0, i1 in tir.grid(1, 1000):\n",
" with tir.block(\"T_add\"):\n",
" ax0, ax1 = tir.axis.remap(\"SS\", [i0, i1])\n",
" tir.reads(rxplaceholder_2[ax0, ax1], rxplaceholder_3[ax1])\n",
" tir.writes(T_add_1[ax0, ax1])\n",
" T_add_1[ax0, ax1] = rxplaceholder_2[ax0, ax1] + rxplaceholder_3[ax1]\n",
" \n",
" @tir.prim_func\n",
" def relu2(rxplaceholder_1: tir.Buffer[(1, 256, 56, 56), \"float32\"], T_relu_1: tir.Buffer[(1, 256, 56, 56), \"float32\"]) -> None:\n",
" # function attr dict\n",
" tir.func_attr({\"global_symbol\": \"relu2\", \"tir.noalias\": True})\n",
" # body\n",
" # with tir.block(\"root\")\n",
" for i0, i1, i2, i3 in tir.grid(1, 256, 56, 56):\n",
" with tir.block(\"T_relu\"):\n",
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n",
" tir.reads(rxplaceholder_1[ax0, ax1, ax2, ax3])\n",
" tir.writes(T_relu_1[ax0, ax1, ax2, ax3])\n",
" T_relu_1[ax0, ax1, ax2, ax3] = tir.max(rxplaceholder_1[ax0, ax1, ax2, ax3], tir.float32(0))\n",
" \n",
" @tir.prim_func\n",
" def conv2d_nchw12(rxplaceholder_2: tir.Buffer[(1, 256, 14, 14), \"float32\"], rxplaceholder_3: tir.Buffer[(1024, 256, 1, 1), \"float32\"], conv2d_nchw_1: tir.Buffer[(1, 1024, 14, 14), \"float32\"]) -> None:\n",
" # function attr dict\n",
" tir.func_attr({\"global_symbol\": \"conv2d_nchw12\", \"tir.noalias\": True})\n",
" # body\n",
" # with tir.block(\"root\")\n",
" pad_temp_1 = tir.alloc_buffer([1, 256, 14, 14], dtype=\"float32\")\n",
" for i0, i1, i2, i3 in tir.grid(1, 256, 14, 14):\n",
" with tir.block(\"pad_temp\"):\n",
" i0_2, i1_2, i2_2, i3_2 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n",
" tir.reads(rxplaceholder_2[i0_2, i1_2, i2_2, i3_2])\n",
" tir.writes(pad_temp_1[i0_2, i1_2, i2_2, i3_2])\n",
" pad_temp_1[i0_2, i1_2, i2_2, i3_2] = rxplaceholder_2[i0_2, i1_2, i2_2, i3_2]\n",
" for i0_3, i1_3, i2_3, i3_3, i4, i5, i6 in tir.grid(1, 1024, 14, 14, 256, 1, 1):\n",
" with tir.block(\"conv2d_nchw\"):\n",
" nn, ff, yy, xx, rc, ry, rx = tir.axis.remap(\"SSSSRRR\", [i0_3, i1_3, i2_3, i3_3, i4, i5, i6])\n",
" tir.reads(pad_temp_1[nn, rc, yy + ry, xx + rx], rxplaceholder_3[ff, rc, ry, rx])\n",
" tir.writes(conv2d_nchw_1[nn, ff, yy, xx])\n",
" with tir.init():\n",
" conv2d_nchw_1[nn, ff, yy, xx] = tir.float32(0)\n",
" conv2d_nchw_1[nn, ff, yy, xx] = conv2d_nchw_1[nn, ff, yy, xx] + pad_temp_1[nn, rc, yy + ry, xx + rx] * rxplaceholder_3[ff, rc, ry, rx]\n",
" \n",
" @tir.prim_func\n",
" def conv2d_nchw16(rxplaceholder_2: tir.Buffer[(1, 512, 7, 7), \"float32\"], rxplaceholder_3: tir.Buffer[(512, 512, 3, 3), \"float32\"], conv2d_nchw_1: tir.Buffer[(1, 512, 7, 7), \"float32\"]) -> None:\n",
" # function attr dict\n",
" tir.func_attr({\"global_symbol\": \"conv2d_nchw16\", \"tir.noalias\": True})\n",
" # body\n",
" # with tir.block(\"root\")\n",
" pad_temp_1 = tir.alloc_buffer([1, 512, 9, 9], dtype=\"float32\")\n",
" for i0, i1, i2, i3 in tir.grid(1, 512, 9, 9):\n",
" with tir.block(\"pad_temp\"):\n",
" i0_2, i1_2, i2_2, i3_2 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n",
" tir.reads(rxplaceholder_2[i0_2, i1_2, i2_2 - 1, i3_2 - 1])\n",
" tir.writes(pad_temp_1[i0_2, i1_2, i2_2, i3_2])\n",
" pad_temp_1[i0_2, i1_2, i2_2, i3_2] = tir.if_then_else(1 <= i2_2 and i2_2 < 8 and 1 <= i3_2 and i3_2 < 8, rxplaceholder_2[i0_2, i1_2, i2_2 - 1, i3_2 - 1], tir.float32(0), dtype=\"float32\")\n",
" for i0_3, i1_3, i2_3, i3_3, i4, i5, i6 in tir.grid(1, 512, 7, 7, 512, 3, 3):\n",
" with tir.block(\"conv2d_nchw\"):\n",
" nn, ff, yy, xx, rc, ry, rx = tir.axis.remap(\"SSSSRRR\", [i0_3, i1_3, i2_3, i3_3, i4, i5, i6])\n",
" tir.reads(pad_temp_1[nn, rc, yy + ry, xx + rx], rxplaceholder_3[ff, rc, ry, rx])\n",
" tir.writes(conv2d_nchw_1[nn, ff, yy, xx])\n",
" with tir.init():\n",
" conv2d_nchw_1[nn, ff, yy, xx] = tir.float32(0)\n",
" conv2d_nchw_1[nn, ff, yy, xx] = conv2d_nchw_1[nn, ff, yy, xx] + pad_temp_1[nn, rc, yy + ry, xx + rx] * rxplaceholder_3[ff, rc, ry, rx]\n",
" \n",
" @tir.prim_func\n",
" def conv2d_nchw8(rxplaceholder_2: tir.Buffer[(1, 256, 56, 56), \"float32\"], rxplaceholder_3: tir.Buffer[(512, 256, 1, 1), \"float32\"], conv2d_nchw_1: tir.Buffer[(1, 512, 28, 28), \"float32\"]) -> None:\n",
" # function attr dict\n",
" tir.func_attr({\"global_symbol\": \"conv2d_nchw8\", \"tir.noalias\": True})\n",
" # body\n",
" # with tir.block(\"root\")\n",
" pad_temp_1 = tir.alloc_buffer([1, 256, 56, 56], dtype=\"float32\")\n",
" for i0, i1, i2, i3 in tir.grid(1, 256, 56, 56):\n",
" with tir.block(\"pad_temp\"):\n",
" i0_2, i1_2, i2_2, i3_2 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n",
" tir.reads(rxplaceholder_2[i0_2, i1_2, i2_2, i3_2])\n",
" tir.writes(pad_temp_1[i0_2, i1_2, i2_2, i3_2])\n",
" pad_temp_1[i0_2, i1_2, i2_2, i3_2] = rxplaceholder_2[i0_2, i1_2, i2_2, i3_2]\n",
" for i0_3, i1_3, i2_3, i3_3, i4, i5, i6 in tir.grid(1, 512, 28, 28, 256, 1, 1):\n",
" with tir.block(\"conv2d_nchw\"):\n",
" nn, ff, yy, xx, rc, ry, rx = tir.axis.remap(\"SSSSRRR\", [i0_3, i1_3, i2_3, i3_3, i4, i5, i6])\n",
" tir.reads(pad_temp_1[nn, rc, yy * 2 + ry, xx * 2 + rx], rxplaceholder_3[ff, rc, ry, rx])\n",
" tir.writes(conv2d_nchw_1[nn, ff, yy, xx])\n",
" with tir.init():\n",
" conv2d_nchw_1[nn, ff, yy, xx] = tir.float32(0)\n",
" conv2d_nchw_1[nn, ff, yy, xx] = conv2d_nchw_1[nn, ff, yy, xx] + pad_temp_1[nn, rc, yy * 2 + ry, xx * 2 + rx] * rxplaceholder_3[ff, rc, ry, rx]\n",
" \n",
" @tir.prim_func\n",
" def batch_norm4(rxplaceholder_5: tir.Buffer[(1, 128, 28, 28), \"float32\"], rxplaceholder_6: tir.Buffer[(128,), \"float32\"], rxplaceholder_7: tir.Buffer[(128,), \"float32\"], rxplaceholder_8: tir.Buffer[(128,), \"float32\"], rxplaceholder_9: tir.Buffer[(128,), \"float32\"], T_add_2: tir.Buffer[(1, 128, 28, 28), \"float32\"], T_multiply_3: tir.Buffer[(128,), \"float32\"], T_multiply_4: tir.Buffer[(128,), \"float32\"]) -> None:\n",
" # function attr dict\n",
" tir.func_attr({\"global_symbol\": \"batch_norm4\", \"tir.noalias\": True})\n",
" # body\n",
" # with tir.block(\"root\")\n",
" T_reshape_4 = tir.alloc_buffer([1, 128, 1, 1], dtype=\"float32\")\n",
" T_subtract_1 = tir.alloc_buffer([1, 128, 28, 28], dtype=\"float32\")\n",
" T_reshape_5 = tir.alloc_buffer([1, 128, 1, 1], dtype=\"float32\")\n",
" T_add_3 = tir.alloc_buffer([1, 128, 1, 1], dtype=\"float32\")\n",
" compute_1 = tir.alloc_buffer([1, 128, 1, 1], dtype=\"float32\")\n",
" T_divide_1 = tir.alloc_buffer([1, 128, 28, 28], dtype=\"float32\")\n",
" T_reshape_6 = tir.alloc_buffer([1, 128, 1, 1], dtype=\"float32\")\n",
" T_multiply_5 = tir.alloc_buffer([1, 128, 28, 28], dtype=\"float32\")\n",
" T_reshape_7 = tir.alloc_buffer([1, 128, 1, 1], dtype=\"float32\")\n",
" for i0, i1, i2, i3 in tir.grid(1, 128, 1, 1):\n",
" with tir.block(\"T_reshape\"):\n",
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n",
" tir.reads(rxplaceholder_8[(ax1 + ax2 + ax3) % 128])\n",
" tir.writes(T_reshape_4[ax0, ax1, ax2, ax3])\n",
" T_reshape_4[ax0, ax1, ax2, ax3] = rxplaceholder_8[(ax1 + ax2 + ax3) % 128]\n",
" for i0, i1, i2, i3 in tir.grid(1, 128, 28, 28):\n",
" with tir.block(\"T_subtract\"):\n",
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n",
" tir.reads(rxplaceholder_5[ax0, ax1, ax2, ax3], T_reshape_4[ax0, ax1, 0, 0])\n",
" tir.writes(T_subtract_1[ax0, ax1, ax2, ax3])\n",
" T_subtract_1[ax0, ax1, ax2, ax3] = rxplaceholder_5[ax0, ax1, ax2, ax3] - T_reshape_4[ax0, ax1, 0, 0]\n",
" for i0, i1, i2, i3 in tir.grid(1, 128, 1, 1):\n",
" with tir.block(\"T_reshape_1\"):\n",
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n",
" tir.reads(rxplaceholder_9[(ax1 + ax2 + ax3) % 128])\n",
" tir.writes(T_reshape_5[ax0, ax1, ax2, ax3])\n",
" T_reshape_5[ax0, ax1, ax2, ax3] = rxplaceholder_9[(ax1 + ax2 + ax3) % 128]\n",
" for i0, i1, i2, i3 in tir.grid(1, 128, 1, 1):\n",
" with tir.block(\"T_add\"):\n",
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n",
" tir.reads(T_reshape_5[ax0, ax1, ax2, ax3])\n",
" tir.writes(T_add_3[ax0, ax1, ax2, ax3])\n",
" T_add_3[ax0, ax1, ax2, ax3] = T_reshape_5[ax0, ax1, ax2, ax3] + tir.float32(1.9999999494757503e-05)\n",
" for i0, i1, i2, i3 in tir.grid(1, 128, 1, 1):\n",
" with tir.block(\"compute\"):\n",
" i0_2, i1_2, i2_2, i3_2 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n",
" tir.reads(T_add_3[i0_2, i1_2, i2_2, i3_2])\n",
" tir.writes(compute_1[i0_2, i1_2, i2_2, i3_2])\n",
" compute_1[i0_2, i1_2, i2_2, i3_2] = tir.sqrt(T_add_3[i0_2, i1_2, i2_2, i3_2], dtype=\"float32\")\n",
" for i0_3, i1_3, i2_3, i3_3 in tir.grid(1, 128, 28, 28):\n",
" with tir.block(\"T_divide\"):\n",
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0_3, i1_3, i2_3, i3_3])\n",
" tir.reads(T_subtract_1[ax0, ax1, ax2, ax3], compute_1[ax0, ax1, 0, 0])\n",
" tir.writes(T_divide_1[ax0, ax1, ax2, ax3])\n",
" T_divide_1[ax0, ax1, ax2, ax3] = T_subtract_1[ax0, ax1, ax2, ax3] / compute_1[ax0, ax1, 0, 0]\n",
" for i0_4, i1_4, i2_4, i3_4 in tir.grid(1, 128, 1, 1):\n",
" with tir.block(\"T_reshape_2\"):\n",
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0_4, i1_4, i2_4, i3_4])\n",
" tir.reads(rxplaceholder_6[(ax1 + ax2 + ax3) % 128])\n",
" tir.writes(T_reshape_6[ax0, ax1, ax2, ax3])\n",
" T_reshape_6[ax0, ax1, ax2, ax3] = rxplaceholder_6[(ax1 + ax2 + ax3) % 128]\n",
" for i0_5, i1_5, i2_5, i3_5 in tir.grid(1, 128, 28, 28):\n",
" with tir.block(\"T_multiply\"):\n",
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0_5, i1_5, i2_5, i3_5])\n",
" tir.reads(T_divide_1[ax0, ax1, ax2, ax3], T_reshape_6[ax0, ax1, 0, 0])\n",
" tir.writes(T_multiply_5[ax0, ax1, ax2, ax3])\n",
" T_multiply_5[ax0, ax1, ax2, ax3] = T_divide_1[ax0, ax1, ax2, ax3] * T_reshape_6[ax0, ax1, 0, 0]\n",
" for i0_6, i1_6, i2_6, i3_6 in tir.grid(1, 128, 1, 1):\n",
" with tir.block(\"T_reshape_3\"):\n",
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0_6, i1_6, i2_6, i3_6])\n",
" tir.reads(rxplaceholder_7[(ax1 + ax2 + ax3) % 128])\n",
" tir.writes(T_reshape_7[ax0, ax1, ax2, ax3])\n",
" T_reshape_7[ax0, ax1, ax2, ax3] = rxplaceholder_7[(ax1 + ax2 + ax3) % 128]\n",
" for i0_7, i1_7, i2_7, i3_7 in tir.grid(1, 128, 28, 28):\n",
" with tir.block(\"T_add_1\"):\n",
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0_7, i1_7, i2_7, i3_7])\n",
" tir.reads(T_multiply_5[ax0, ax1, ax2, ax3], T_reshape_7[ax0, ax1, 0, 0])\n",
" tir.writes(T_add_2[ax0, ax1, ax2, ax3])\n",
" T_add_2[ax0, ax1, ax2, ax3] = T_multiply_5[ax0, ax1, ax2, ax3] + T_reshape_7[ax0, ax1, 0, 0]\n",
" for i0_8 in tir.serial(128):\n",
" with tir.block(\"T_multiply_1\"):\n",
" ax0 = tir.axis.spatial(128, i0_8)\n",
" tir.reads(rxplaceholder_8[ax0])\n",
" tir.writes(T_multiply_3[ax0])\n",
" T_multiply_3[ax0] = rxplaceholder_8[ax0]\n",
" for i0_9 in tir.serial(128):\n",
" with tir.block(\"T_multiply_2\"):\n",
" ax0 = tir.axis.spatial(128, i0_9)\n",
" tir.reads(rxplaceholder_9[ax0])\n",
" tir.writes(T_multiply_4[ax0])\n",
" T_multiply_4[ax0] = rxplaceholder_9[ax0]\n",
" \n",
" @tir.prim_func\n",
" def add(rxplaceholder_2: tir.Buffer[(1, 256, 56, 56), \"float32\"], rxplaceholder_3: tir.Buffer[(1, 256, 56, 56), \"float32\"], T_add_1: tir.Buffer[(1, 256, 56, 56), \"float32\"]) -> None:\n",
" # function attr dict\n",
" tir.func_attr({\"global_symbol\": \"add\", \"tir.noalias\": True})\n",
" # body\n",
" # with tir.block(\"root\")\n",
" for i0, i1, i2, i3 in tir.grid(1, 256, 56, 56):\n",
" with tir.block(\"T_add\"):\n",
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n",
" tir.reads(rxplaceholder_2[ax0, ax1, ax2, ax3], rxplaceholder_3[ax0, ax1, ax2, ax3])\n",
" tir.writes(T_add_1[ax0, ax1, ax2, ax3])\n",
" T_add_1[ax0, ax1, ax2, ax3] = rxplaceholder_2[ax0, ax1, ax2, ax3] + rxplaceholder_3[ax0, ax1, ax2, ax3]\n",
" \n",
" @tir.prim_func\n",
" def batch_norm2(rxplaceholder_5: tir.Buffer[(1, 64, 56, 56), \"float32\"], rxplaceholder_6: tir.Buffer[(64,), \"float32\"], rxplaceholder_7: tir.Buffer[(64,), \"float32\"], rxplaceholder_8: tir.Buffer[(64,), \"float32\"], rxplaceholder_9: tir.Buffer[(64,), \"float32\"], T_add_2: tir.Buffer[(1, 64, 56, 56), \"float32\"], T_multiply_3: tir.Buffer[(64,), \"float32\"], T_multiply_4: tir.Buffer[(64,), \"float32\"]) -> None:\n",
" # function attr dict\n",
" tir.func_attr({\"global_symbol\": \"batch_norm2\", \"tir.noalias\": True})\n",
" # body\n",
" # with tir.block(\"root\")\n",
" T_reshape_4 = tir.alloc_buffer([1, 64, 1, 1], dtype=\"float32\")\n",
" T_subtract_1 = tir.alloc_buffer([1, 64, 56, 56], dtype=\"float32\")\n",
" T_reshape_5 = tir.alloc_buffer([1, 64, 1, 1], dtype=\"float32\")\n",
" T_add_3 = tir.alloc_buffer([1, 64, 1, 1], dtype=\"float32\")\n",
" compute_1 = tir.alloc_buffer([1, 64, 1, 1], dtype=\"float32\")\n",
" T_divide_1 = tir.alloc_buffer([1, 64, 56, 56], dtype=\"float32\")\n",
" T_reshape_6 = tir.alloc_buffer([1, 64, 1, 1], dtype=\"float32\")\n",
" T_multiply_5 = tir.alloc_buffer([1, 64, 56, 56], dtype=\"float32\")\n",
" T_reshape_7 = tir.alloc_buffer([1, 64, 1, 1], dtype=\"float32\")\n",
" for i0, i1, i2, i3 in tir.grid(1, 64, 1, 1):\n",
" with tir.block(\"T_reshape\"):\n",
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n",
" tir.reads(rxplaceholder_8[(ax1 + ax2 + ax3) % 64])\n",
" tir.writes(T_reshape_4[ax0, ax1, ax2, ax3])\n",
" T_reshape_4[ax0, ax1, ax2, ax3] = rxplaceholder_8[(ax1 + ax2 + ax3) % 64]\n",
" for i0, i1, i2, i3 in tir.grid(1, 64, 56, 56):\n",
" with tir.block(\"T_subtract\"):\n",
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n",
" tir.reads(rxplaceholder_5[ax0, ax1, ax2, ax3], T_reshape_4[ax0, ax1, 0, 0])\n",
" tir.writes(T_subtract_1[ax0, ax1, ax2, ax3])\n",
" T_subtract_1[ax0, ax1, ax2, ax3] = rxplaceholder_5[ax0, ax1, ax2, ax3] - T_reshape_4[ax0, ax1, 0, 0]\n",
" for i0, i1, i2, i3 in tir.grid(1, 64, 1, 1):\n",
" with tir.block(\"T_reshape_1\"):\n",
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n",
" tir.reads(rxplaceholder_9[(ax1 + ax2 + ax3) % 64])\n",
" tir.writes(T_reshape_5[ax0, ax1, ax2, ax3])\n",
" T_reshape_5[ax0, ax1, ax2, ax3] = rxplaceholder_9[(ax1 + ax2 + ax3) % 64]\n",
" for i0, i1, i2, i3 in tir.grid(1, 64, 1, 1):\n",
" with tir.block(\"T_add\"):\n",
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n",
" tir.reads(T_reshape_5[ax0, ax1, ax2, ax3])\n",
" tir.writes(T_add_3[ax0, ax1, ax2, ax3])\n",
" T_add_3[ax0, ax1, ax2, ax3] = T_reshape_5[ax0, ax1, ax2, ax3] + tir.float32(1.9999999494757503e-05)\n",
" for i0, i1, i2, i3 in tir.grid(1, 64, 1, 1):\n",
" with tir.block(\"compute\"):\n",
" i0_2, i1_2, i2_2, i3_2 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n",
" tir.reads(T_add_3[i0_2, i1_2, i2_2, i3_2])\n",
" tir.writes(compute_1[i0_2, i1_2, i2_2, i3_2])\n",
" compute_1[i0_2, i1_2, i2_2, i3_2] = tir.sqrt(T_add_3[i0_2, i1_2, i2_2, i3_2], dtype=\"float32\")\n",
" for i0_3, i1_3, i2_3, i3_3 in tir.grid(1, 64, 56, 56):\n",
" with tir.block(\"T_divide\"):\n",
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0_3, i1_3, i2_3, i3_3])\n",
" tir.reads(T_subtract_1[ax0, ax1, ax2, ax3], compute_1[ax0, ax1, 0, 0])\n",
" tir.writes(T_divide_1[ax0, ax1, ax2, ax3])\n",
" T_divide_1[ax0, ax1, ax2, ax3] = T_subtract_1[ax0, ax1, ax2, ax3] / compute_1[ax0, ax1, 0, 0]\n",
" for i0_4, i1_4, i2_4, i3_4 in tir.grid(1, 64, 1, 1):\n",
" with tir.block(\"T_reshape_2\"):\n",
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0_4, i1_4, i2_4, i3_4])\n",
" tir.reads(rxplaceholder_6[(ax1 + ax2 + ax3) % 64])\n",
" tir.writes(T_reshape_6[ax0, ax1, ax2, ax3])\n",
" T_reshape_6[ax0, ax1, ax2, ax3] = rxplaceholder_6[(ax1 + ax2 + ax3) % 64]\n",
" for i0_5, i1_5, i2_5, i3_5 in tir.grid(1, 64, 56, 56):\n",
" with tir.block(\"T_multiply\"):\n",
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0_5, i1_5, i2_5, i3_5])\n",
" tir.reads(T_divide_1[ax0, ax1, ax2, ax3], T_reshape_6[ax0, ax1, 0, 0])\n",
" tir.writes(T_multiply_5[ax0, ax1, ax2, ax3])\n",
" T_multiply_5[ax0, ax1, ax2, ax3] = T_divide_1[ax0, ax1, ax2, ax3] * T_reshape_6[ax0, ax1, 0, 0]\n",
" for i0_6, i1_6, i2_6, i3_6 in tir.grid(1, 64, 1, 1):\n",
" with tir.block(\"T_reshape_3\"):\n",
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0_6, i1_6, i2_6, i3_6])\n",
" tir.reads(rxplaceholder_7[(ax1 + ax2 + ax3) % 64])\n",
" tir.writes(T_reshape_7[ax0, ax1, ax2, ax3])\n",
" T_reshape_7[ax0, ax1, ax2, ax3] = rxplaceholder_7[(ax1 + ax2 + ax3) % 64]\n",
" for i0_7, i1_7, i2_7, i3_7 in tir.grid(1, 64, 56, 56):\n",
" with tir.block(\"T_add_1\"):\n",
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0_7, i1_7, i2_7, i3_7])\n",
" tir.reads(T_multiply_5[ax0, ax1, ax2, ax3], T_reshape_7[ax0, ax1, 0, 0])\n",
" tir.writes(T_add_2[ax0, ax1, ax2, ax3])\n",
" T_add_2[ax0, ax1, ax2, ax3] = T_multiply_5[ax0, ax1, ax2, ax3] + T_reshape_7[ax0, ax1, 0, 0]\n",
" for i0_8 in tir.serial(64):\n",
" with tir.block(\"T_multiply_1\"):\n",
" ax0 = tir.axis.spatial(64, i0_8)\n",
" tir.reads(rxplaceholder_8[ax0])\n",
" tir.writes(T_multiply_3[ax0])\n",
" T_multiply_3[ax0] = rxplaceholder_8[ax0]\n",
" for i0_9 in tir.serial(64):\n",
" with tir.block(\"T_multiply_2\"):\n",
" ax0 = tir.axis.spatial(64, i0_9)\n",
" tir.reads(rxplaceholder_9[ax0])\n",
" tir.writes(T_multiply_4[ax0])\n",
" T_multiply_4[ax0] = rxplaceholder_9[ax0]\n",
" \n",
" @tir.prim_func\n",
" def conv2d_nchw2(rxplaceholder_2: tir.Buffer[(1, 64, 56, 56), \"float32\"], rxplaceholder_3: tir.Buffer[(64, 64, 3, 3), \"float32\"], conv2d_nchw_1: tir.Buffer[(1, 64, 56, 56), \"float32\"]) -> None:\n",
" # function attr dict\n",
" tir.func_attr({\"global_symbol\": \"conv2d_nchw2\", \"tir.noalias\": True})\n",
" # body\n",
" # with tir.block(\"root\")\n",
" pad_temp_1 = tir.alloc_buffer([1, 64, 58, 58], dtype=\"float32\")\n",
" for i0, i1, i2, i3 in tir.grid(1, 64, 58, 58):\n",
" with tir.block(\"pad_temp\"):\n",
" i0_2, i1_2, i2_2, i3_2 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n",
" tir.reads(rxplaceholder_2[i0_2, i1_2, i2_2 - 1, i3_2 - 1])\n",
" tir.writes(pad_temp_1[i0_2, i1_2, i2_2, i3_2])\n",
" pad_temp_1[i0_2, i1_2, i2_2, i3_2] = tir.if_then_else(1 <= i2_2 and i2_2 < 57 and 1 <= i3_2 and i3_2 < 57, rxplaceholder_2[i0_2, i1_2, i2_2 - 1, i3_2 - 1], tir.float32(0), dtype=\"float32\")\n",
" for i0_3, i1_3, i2_3, i3_3, i4, i5, i6 in tir.grid(1, 64, 56, 56, 64, 3, 3):\n",
" with tir.block(\"conv2d_nchw\"):\n",
" nn, ff, yy, xx, rc, ry, rx = tir.axis.remap(\"SSSSRRR\", [i0_3, i1_3, i2_3, i3_3, i4, i5, i6])\n",
" tir.reads(pad_temp_1[nn, rc, yy + ry, xx + rx], rxplaceholder_3[ff, rc, ry, rx])\n",
" tir.writes(conv2d_nchw_1[nn, ff, yy, xx])\n",
" with tir.init():\n",
" conv2d_nchw_1[nn, ff, yy, xx] = tir.float32(0)\n",
" conv2d_nchw_1[nn, ff, yy, xx] = conv2d_nchw_1[nn, ff, yy, xx] + pad_temp_1[nn, rc, yy + ry, xx + rx] * rxplaceholder_3[ff, rc, ry, rx]\n",
" \n",
" @tir.prim_func\n",
" def relu7(rxplaceholder_1: tir.Buffer[(1, 512, 7, 7), \"float32\"], T_relu_1: tir.Buffer[(1, 512, 7, 7), \"float32\"]) -> None:\n",
" # function attr dict\n",
" tir.func_attr({\"global_symbol\": \"relu7\", \"tir.noalias\": True})\n",
" # body\n",
" # with tir.block(\"root\")\n",
" for i0, i1, i2, i3 in tir.grid(1, 512, 7, 7):\n",
" with tir.block(\"T_relu\"):\n",
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n",
" tir.reads(rxplaceholder_1[ax0, ax1, ax2, ax3])\n",
" tir.writes(T_relu_1[ax0, ax1, ax2, ax3])\n",
" T_relu_1[ax0, ax1, ax2, ax3] = tir.max(rxplaceholder_1[ax0, ax1, ax2, ax3], tir.float32(0))\n",
" \n",
" @tir.prim_func\n",
" def conv2d_nchw18(rxplaceholder_2: tir.Buffer[(1, 1024, 14, 14), \"float32\"], rxplaceholder_3: tir.Buffer[(2048, 1024, 1, 1), \"float32\"], conv2d_nchw_1: tir.Buffer[(1, 2048, 7, 7), \"float32\"]) -> None:\n",
" # function attr dict\n",
" tir.func_attr({\"global_symbol\": \"conv2d_nchw18\", \"tir.noalias\": True})\n",
" # body\n",
" # with tir.block(\"root\")\n",
" pad_temp_1 = tir.alloc_buffer([1, 1024, 14, 14], dtype=\"float32\")\n",
" for i0, i1, i2, i3 in tir.grid(1, 1024, 14, 14):\n",
" with tir.block(\"pad_temp\"):\n",
" i0_2, i1_2, i2_2, i3_2 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n",
" tir.reads(rxplaceholder_2[i0_2, i1_2, i2_2, i3_2])\n",
" tir.writes(pad_temp_1[i0_2, i1_2, i2_2, i3_2])\n",
" pad_temp_1[i0_2, i1_2, i2_2, i3_2] = rxplaceholder_2[i0_2, i1_2, i2_2, i3_2]\n",
" for i0_3, i1_3, i2_3, i3_3, i4, i5, i6 in tir.grid(1, 2048, 7, 7, 1024, 1, 1):\n",
" with tir.block(\"conv2d_nchw\"):\n",
" nn, ff, yy, xx, rc, ry, rx = tir.axis.remap(\"SSSSRRR\", [i0_3, i1_3, i2_3, i3_3, i4, i5, i6])\n",
" tir.reads(pad_temp_1[nn, rc, yy * 2 + ry, xx * 2 + rx], rxplaceholder_3[ff, rc, ry, rx])\n",
" tir.writes(conv2d_nchw_1[nn, ff, yy, xx])\n",
" with tir.init():\n",
" conv2d_nchw_1[nn, ff, yy, xx] = tir.float32(0)\n",
" conv2d_nchw_1[nn, ff, yy, xx] = conv2d_nchw_1[nn, ff, yy, xx] + pad_temp_1[nn, rc, yy * 2 + ry, xx * 2 + rx] * rxplaceholder_3[ff, rc, ry, rx]\n",
" \n"
]
}
],
"source": [
"from tvm import relay\n",
"import tvm.relay.testing\n",
"from tvm.relax.testing import relay_translator\n",
"\n",
"relay_mod, _ = relay.testing.resnet.get_workload(num_layers=50, batch_size=1, dtype=\"float32\")\n",
"\n",
"# translate the ResNet model from Relay to Relax\n",
"relax_mod = relay_translator.from_relay(relay_mod[\"main\"])\n",
"\n",
"# print the ResNet IRmodule got translated\n",
"R.parser.pretty_print(relax_mod)"
]
},
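{
"cell_type": "markdown",
"id": "b3f1c2aa",
"metadata": {},
"source": [
"The translator lowers each Relay operator to a TIR `prim_func` (one per distinct shape, hence the numbered `conv2d_nchw*`, `batch_norm*`, and `relu*` kernels above) and emits a Relax `main` function that calls into them.\n",
"\n",
"A minimal sketch of how the translated module could then be compiled and run, assuming an LLVM CPU target, random input data, and that `relax.vm.build` / `relax.VirtualMachine` are the VM entry points on this branch (not executed here):\n",
"\n",
"```python\n",
"# Sketch under the assumptions above: build for LLVM and run on the Relax VM.\n",
"ex = relax.vm.build(relax_mod, target=\"llvm\")\n",
"vm = relax.VirtualMachine(ex, tvm.cpu())\n",
"data = tvm.nd.array(np.random.rand(1, 3, 224, 224).astype(\"float32\"))\n",
"# res = vm[\"main\"](data, ...)  # the remaining arguments are the model parameters\n",
"```"
]
},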
{
"cell_type": "markdown",
"id": "df9a585e",
"metadata": {},
"source": [
"## Tuning with MetaSchedule"
]
},
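{
"cell_type": "markdown",
"id": "e5f6a7b8",
"metadata": {},
"source": [
"MetaSchedule explores a space of candidate schedules for each tuning task. In the log below, every \"design space\" is printed twice: as the transformed TVMScript module, and as the schedule trace (the sequence of `sch.*` calls, with the sampled tiling decisions) that produces it.\n",
"\n",
"Such a trace can also be replayed by hand. The following is a minimal, self-contained sketch (not executed here) that applies the tiling from \"Design space #0\" to `MatmulModule`, a hypothetical stand-in mirroring the printed `tir_matmul`, using `tir.Schedule`:\n",
"\n",
"```python\n",
"import tvm\n",
"from tvm import tir\n",
"from tvm.script import tir as T\n",
"\n",
"@tvm.script.ir_module\n",
"class MatmulModule:\n",
"    @T.prim_func\n",
"    def main(A: T.Buffer[(32, 32), \"float32\"], B: T.Buffer[(32, 32), \"float32\"], C: T.Buffer[(32, 32), \"float32\"]) -> None:\n",
"        T.func_attr({\"global_symbol\": \"tir_matmul\"})\n",
"        for i, j, k in T.grid(32, 32, 32):\n",
"            with T.block(\"matmul\"):\n",
"                vi, vj, vk = T.axis.remap(\"SSR\", [i, j, k])\n",
"                with T.init():\n",
"                    C[vi, vj] = T.float32(0)\n",
"                C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]\n",
"\n",
"sch = tir.Schedule(MatmulModule)\n",
"block = sch.get_block(\"matmul\")\n",
"i, j, k = sch.get_loops(block)\n",
"# Split each loop with the sampled factors [1, 4, 8, 1], [2, 2, 1, 8], [4, 8]\n",
"# and reorder into the SSRSRS tiling structure, as in the trace.\n",
"i0, i1, i2, i3 = sch.split(i, factors=[1, 4, 8, 1])\n",
"j0, j1, j2, j3 = sch.split(j, factors=[2, 2, 1, 8])\n",
"k0, k1 = sch.split(k, factors=[4, 8])\n",
"sch.reorder(i0, j0, i1, j1, k0, i2, j2, k1, i3, j3)\n",
"print(sch.mod.script())\n",
"```"
]
},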
{
"cell_type": "code",
"execution_count": 13,
"id": "a8c19a40",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"[09:22:10] /home/yuchenj/dup/relax/src/meta_schedule/task_scheduler/task_scheduler.cc:96: Initializing Task #0: \"tir_matmul\"\n",
"[09:22:10] /home/yuchenj/dup/relax/src/meta_schedule/task_scheduler/task_scheduler.cc:102: \n",
"@tvm.script.ir_module\n",
"class Module:\n",
" @tir.prim_func\n",
" def main(A_1: tir.Buffer[(32, 32), \"float32\"], B_1: tir.Buffer[(32, 32), \"float32\"], C_1: tir.Buffer[(32, 32), \"float32\"]) -> None:\n",
" # function attr dict\n",
" tir.func_attr({\"global_symbol\": \"tir_matmul\"})\n",
" # body\n",
" # with tir.block(\"root\")\n",
" for i0, j0, k0 in tir.grid(32, 32, 32):\n",
" with tir.block():\n",
" i, j, k = tir.axis.remap(\"SSR\", [i0, j0, k0])\n",
" tir.reads(A_1[i, k], B_1[j, k])\n",
" tir.writes(C_1[i, j])\n",
" with tir.init():\n",
" C_1[i, j] = tir.float32(0)\n",
" C_1[i, j] = C_1[i, j] + A_1[i, k] * B_1[j, k]\n",
" \n",
"[09:22:11] /home/yuchenj/dup/relax/src/meta_schedule/task_scheduler/task_scheduler.cc:106: Total 3 design space(s) generated\n",
"[09:22:11] /home/yuchenj/dup/relax/src/meta_schedule/task_scheduler/task_scheduler.cc:111: Design space #0:\n",
"@tvm.script.ir_module\n",
"class Module:\n",
" @tir.prim_func\n",
" def main(A_1: tir.Buffer[(32, 32), \"float32\"], B_1: tir.Buffer[(32, 32), \"float32\"], C_1: tir.Buffer[(32, 32), \"float32\"]) -> None:\n",
" # function attr dict\n",
" tir.func_attr({\"global_symbol\": \"tir_matmul\"})\n",
" # body\n",
" with tir.block(\"root\"):\n",
" tir.reads()\n",
" tir.writes()\n",
" tir.block_attr({\"meta_schedule.parallel\":256, \"meta_schedule.unroll_explicit\":64, \"meta_schedule.vectorize\":64})\n",
" C_global_1 = tir.alloc_buffer([32, 32], dtype=\"float32\")\n",
" for i0_0, j0_0, i0_1, j0_1 in tir.grid(1, 2, 4, 2):\n",
" for k0_0, i0_2, j0_2, k0_1, i0_3, j0_3 in tir.grid(4, 8, 1, 8, 1, 8):\n",
" with tir.block():\n",
" i = tir.axis.spatial(32, i0_1 * 8 + i0_2)\n",
" j = tir.axis.spatial(32, j0_0 * 16 + j0_1 * 8 + j0_3)\n",
" k = tir.axis.reduce(32, k0_0 * 8 + k0_1)\n",
" tir.reads(A_1[i, k], B_1[j, k])\n",
" tir.writes(C_global_1[i, j])\n",
" tir.block_attr({\"meta_schedule.tiling_structure\":\"SSRSRS\"})\n",
" with tir.init():\n",
" C_global_1[i, j] = tir.float32(0)\n",
" C_global_1[i, j] = C_global_1[i, j] + A_1[i, k] * B_1[j, k]\n",
" for ax0, ax1 in tir.grid(8, 8):\n",
" with tir.block(\"C_global\"):\n",
" v0 = tir.axis.spatial(32, i0_1 * 8 + ax0)\n",
" v1 = tir.axis.spatial(32, j0_0 * 16 + j0_1 * 8 + ax1)\n",
" tir.reads(C_global_1[v0, v1])\n",
" tir.writes(C_1[v0, v1])\n",
" C_1[v0, v1] = C_global_1[v0, v1]\n",
" \n",
"b0 = sch.get_block(name=\"\", func_name=\"main\")\n",
"b1 = sch.get_block(name=\"root\", func_name=\"main\")\n",
"sch.annotate(block_or_loop=b0, ann_key=\"meta_schedule.tiling_structure\", ann_val=\"SSRSRS\")\n",
"l2, l3, l4 = sch.get_loops(block=b0)\n",
"v5, v6, v7, v8 = sch.sample_perfect_tile(loop=l2, n=4, max_innermost_factor=64, decision=[1, 4, 8, 1])\n",
"l9, l10, l11, l12 = sch.split(loop=l2, factors=[v5, v6, v7, v8])\n",
"v13, v14, v15, v16 = sch.sample_perfect_tile(loop=l3, n=4, max_innermost_factor=64, decision=[2, 2, 1, 8])\n",
"l17, l18, l19, l20 = sch.split(loop=l3, factors=[v13, v14, v15, v16])\n",
"v21, v22 = sch.sample_perfect_tile(loop=l4, n=2, max_innermost_factor=64, decision=[4, 8])\n",
"l23, l24 = sch.split(loop=l4, factors=[v21, v22])\n",
"sch.reorder(l9, l17, l10, l18, l23, l11, l19, l24, l12, l20)\n",
"b25 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope=\"global\")\n",
"sch.reverse_compute_at(block=b25, loop=l18, preserve_unit_loops=True)\n",
"sch.annotate(block_or_loop=b1, ann_key=\"meta_schedule.parallel\", ann_val=256)\n",
"sch.annotate(block_or_loop=b1, ann_key=\"meta_schedule.vectorize\", ann_val=64)\n",
"v26 = sch.sample_categorical(candidates=[0, 16, 64, 512], probs=[0.25, 0.25, 0.25, 0.25], decision=2)\n",
"sch.annotate(block_or_loop=b1, ann_key=\"meta_schedule.unroll_explicit\", ann_val=v26)\n",
"[09:22:11] /home/yuchenj/dup/relax/src/meta_schedule/task_scheduler/task_scheduler.cc:111: Design space #1:\n",
"@tvm.script.ir_module\n",
"class Module:\n",
" @tir.prim_func\n",
" def main(A_1: tir.Buffer[(32, 32), \"float32\"], B_1: tir.Buffer[(32, 32), \"float32\"], C_1: tir.Buffer[(32, 32), \"float32\"]) -> None:\n",
" # function attr dict\n",
" tir.func_attr({\"global_symbol\": \"tir_matmul\"})\n",
" # body\n",
" with tir.block(\"root\"):\n",
" tir.reads()\n",
" tir.writes()\n",
" tir.block_attr({\"meta_schedule.parallel\":256, \"meta_schedule.unroll_explicit\":0, \"meta_schedule.vectorize\":64})\n",
" C_global_1 = tir.alloc_buffer([32, 32], dtype=\"float32\")\n",
" for i0_0, j0_0 in tir.grid(1, 2):\n",
" for i0_1, j0_1, k0_0, i0_2, j0_2, k0_1, i0_3, j0_3 in tir.grid(4, 2, 4, 8, 1, 8, 1, 8):\n",
" with tir.block():\n",
" i = tir.axis.spatial(32, i0_1 * 8 + i0_2)\n",
" j = tir.axis.spatial(32, j0_0 * 16 + j0_1 * 8 + j0_3)\n",
" k = tir.axis.reduce(32, k0_0 * 8 + k0_1)\n",
" tir.reads(A_1[i, k], B_1[j, k])\n",
" tir.writes(C_global_1[i, j])\n",
" tir.block_attr({\"meta_schedule.tiling_structure\":\"SSRSRS\"})\n",
" with tir.init():\n",
" C_global_1[i, j] = tir.float32(0)\n",
" C_global_1[i, j] = C_global_1[i, j] + A_1[i, k] * B_1[j, k]\n",
" for ax0, ax1 in tir.grid(32, 16):\n",
" with tir.block(\"C_global\"):\n",
" v0 = tir.axis.spatial(32, ax0)\n",
" v1 = tir.axis.spatial(32, j0_0 * 16 + ax1)\n",
" tir.reads(C_global_1[v0, v1])\n",
" tir.writes(C_1[v0, v1])\n",
" C_1[v0, v1] = C_global_1[v0, v1]\n",
" \n",
"b0 = sch.get_block(name=\"\", func_name=\"main\")\n",
"b1 = sch.get_block(name=\"root\", func_name=\"main\")\n",
"sch.annotate(block_or_loop=b0, ann_key=\"meta_schedule.tiling_structure\", ann_val=\"SSRSRS\")\n",
"l2, l3, l4 = sch.get_loops(block=b0)\n",
"v5, v6, v7, v8 = sch.sample_perfect_tile(loop=l2, n=4, max_innermost_factor=64, decision=[1, 4, 8, 1])\n",
"l9, l10, l11, l12 = sch.split(loop=l2, factors=[v5, v6, v7, v8])\n",
"v13, v14, v15, v16 = sch.sample_perfect_tile(loop=l3, n=4, max_innermost_factor=64, decision=[2, 2, 1, 8])\n",
"l17, l18, l19, l20 = sch.split(loop=l3, factors=[v13, v14, v15, v16])\n",
"v21, v22 = sch.sample_perfect_tile(loop=l4, n=2, max_innermost_factor=64, decision=[4, 8])\n",
"l23, l24 = sch.split(loop=l4, factors=[v21, v22])\n",
"sch.reorder(l9, l17, l10, l18, l23, l11, l19, l24, l12, l20)\n",
"b25 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope=\"global\")\n",
"sch.reverse_compute_at(block=b25, loop=l17, preserve_unit_loops=True)\n",
"sch.annotate(block_or_loop=b1, ann_key=\"meta_schedule.parallel\", ann_val=256)\n",
"sch.annotate(block_or_loop=b1, ann_key=\"meta_schedule.vectorize\", ann_val=64)\n",
"v26 = sch.sample_categorical(candidates=[0, 16, 64, 512], probs=[0.25, 0.25, 0.25, 0.25], decision=0)\n",
"sch.annotate(block_or_loop=b1, ann_key=\"meta_schedule.unroll_explicit\", ann_val=v26)\n",
"[09:22:11] /home/yuchenj/dup/relax/src/meta_schedule/task_scheduler/task_scheduler.cc:111: Design space #2:\n",
"@tvm.script.ir_module\n",
"class Module:\n",
" @tir.prim_func\n",
" def main(A_1: tir.Buffer[(32, 32), \"float32\"], B_1: tir.Buffer[(32, 32), \"float32\"], C_1: tir.Buffer[(32, 32), \"float32\"]) -> None:\n",
" # function attr dict\n",
" tir.func_attr({\"global_symbol\": \"tir_matmul\"})\n",
" # body\n",
" with tir.block(\"root\"):\n",
" tir.reads()\n",
" tir.writes()\n",
" tir.block_attr({\"meta_schedule.parallel\":256, \"meta_schedule.unroll_explicit\":512, \"meta_schedule.vectorize\":64})\n",
" for i0_0, j0_0, i0_1, j0_1, k0_0, i0_2, j0_2, k0_1, i0_3, j0_3 in tir.grid(1, 2, 4, 2, 4, 8, 1, 8, 1, 8):\n",
" with tir.block():\n",
" i = tir.axis.spatial(32, i0_1 * 8 + i0_2)\n",
" j = tir.axis.spatial(32, j0_0 * 16 + j0_1 * 8 + j0_3)\n",
" k = tir.axis.reduce(32, k0_0 * 8 + k0_1)\n",
" tir.reads(A_1[i, k], B_1[j, k])\n",
" tir.writes(C_1[i, j])\n",
" tir.block_attr({\"meta_schedule.tiling_structure\":\"SSRSRS\"})\n",
" with tir.init():\n",
" C_1[i, j] = tir.float32(0)\n",
" C_1[i, j] = C_1[i, j] + A_1[i, k] * B_1[j, k]\n",
" \n",
"b0 = sch.get_block(name=\"\", func_name=\"main\")\n",
"b1 = sch.get_block(name=\"root\", func_name=\"main\")\n",
"sch.annotate(block_or_loop=b0, ann_key=\"meta_schedule.tiling_structure\", ann_val=\"SSRSRS\")\n",
"l2, l3, l4 = sch.get_loops(block=b0)\n",
"v5, v6, v7, v8 = sch.sample_perfect_tile(loop=l2, n=4, max_innermost_factor=64, decision=[1, 4, 8, 1])\n",
"l9, l10, l11, l12 = sch.split(loop=l2, factors=[v5, v6, v7, v8])\n",
"v13, v14, v15, v16 = sch.sample_perfect_tile(loop=l3, n=4, max_innermost_factor=64, decision=[2, 2, 1, 8])\n",
"l17, l18, l19, l20 = sch.split(loop=l3, factors=[v13, v14, v15, v16])\n",
"v21, v22 = sch.sample_perfect_tile(loop=l4, n=2, max_innermost_factor=64, decision=[4, 8])\n",
"l23, l24 = sch.split(loop=l4, factors=[v21, v22])\n",
"sch.reorder(l9, l17, l10, l18, l23, l11, l19, l24, l12, l20)\n",
"sch.annotate(block_or_loop=b1, ann_key=\"meta_schedule.parallel\", ann_val=256)\n",
"sch.annotate(block_or_loop=b1, ann_key=\"meta_schedule.vectorize\", ann_val=64)\n",
"v25 = sch.sample_categorical(candidates=[0, 16, 64, 512], probs=[0.25, 0.25, 0.25, 0.25], decision=3)\n",
"sch.annotate(block_or_loop=b1, ann_key=\"meta_schedule.unroll_explicit\", ann_val=v25)\n",
"[09:22:11] /home/yuchenj/dup/relax/src/meta_schedule/task_scheduler/task_scheduler.cc:96: Initializing Task #1: \"tir_relu\"\n",
"[09:22:11] /home/yuchenj/dup/relax/src/meta_schedule/task_scheduler/task_scheduler.cc:102: \n",
"@tvm.script.ir_module\n",
"class Module:\n",
" @tir.prim_func\n",
" def main(A_1: tir.Buffer[(32, 32), \"float32\"], B_1: tir.Buffer[(32, 32), \"float32\"]) -> None:\n",
" # function attr dict\n",
" tir.func_attr({\"global_symbol\": \"tir_relu\"})\n",
" # body\n",
" # with tir.block(\"root\")\n",
" for i, j in tir.grid(32, 32):\n",
" with tir.block():\n",
" vi, vj = tir.axis.remap(\"SS\", [i, j])\n",
" tir.reads(A_1[vi, vj])\n",
" tir.writes(B_1[vi, vj])\n",
" B_1[vi, vj] = tir.max(A_1[vi, vj], tir.float32(0))\n",
" \n",
"[09:22:11] /home/yuchenj/dup/relax/src/meta_schedule/task_scheduler/task_scheduler.cc:106: Total 1 design space(s) generated\n",
"[09:22:11] /home/yuchenj/dup/relax/src/meta_schedule/task_scheduler/task_scheduler.cc:111: Design space #0:\n",
"@tvm.script.ir_module\n",
"class Module:\n",
" @tir.prim_func\n",
" def main(A_1: tir.Buffer[(32, 32), \"float32\"], B_1: tir.Buffer[(32, 32), \"float32\"]) -> None:\n",
" # function attr dict\n",
" tir.func_attr({\"global_symbol\": \"tir_relu\"})\n",
" # body\n",
" with tir.block(\"root\"):\n",
" tir.reads()\n",
" tir.writes()\n",
" tir.block_attr({\"meta_schedule.parallel\":256, \"meta_schedule.unroll_explicit\":512, \"meta_schedule.vectorize\":64})\n",
" for i, j in tir.grid(32, 32):\n",
" with tir.block():\n",
" vi, vj = tir.axis.remap(\"SS\", [i, j])\n",
" tir.reads(A_1[vi, vj])\n",
" tir.writes(B_1[vi, vj])\n",
" B_1[vi, vj] = tir.max(A_1[vi, vj], tir.float32(0))\n",
" \n",
"b0 = sch.get_block(name=\"root\", func_name=\"main\")\n",
"sch.annotate(block_or_loop=b0, ann_key=\"meta_schedule.parallel\", ann_val=256)\n",
"sch.annotate(block_or_loop=b0, ann_key=\"meta_schedule.vectorize\", ann_val=64)\n",
"v1 = sch.sample_categorical(candidates=[0, 16, 64, 512], probs=[0.25, 0.25, 0.25, 0.25], decision=3)\n",
"sch.annotate(block_or_loop=b0, ann_key=\"meta_schedule.unroll_explicit\", ann_val=v1)\n",
"[09:22:11] /home/yuchenj/dup/relax/src/meta_schedule/task_scheduler/gradient_based.cc:111: \n",
" ID | Name | FLOP | Weight | Speed (GFLOPS) | Latency (us) | Weighted Latency (us) | Trials | Terminated \n",
"----------------------------------------------------------------------------------------------------------------\n",
" 0 | tir_matmul | 65536 | 1 | N/A | N/A | N/A | 0 | \n",
" 1 | tir_relu | 1024 | 1 | N/A | N/A | N/A | 0 | \n",
"----------------------------------------------------------------------------------------------------------------\n",
"Total trials: 0\n",
"Total latency (us): 0\n",
"\n",
"[09:22:11] /home/yuchenj/dup/relax/src/meta_schedule/task_scheduler/task_scheduler.cc:125: Scheduler picks Task #0: \"tir_matmul\"\n",
"[09:22:11] /home/yuchenj/dup/relax/src/meta_schedule/search_strategy/evolutionary_search.cc:656: Generating candidates......\n",
"[09:22:11] /home/yuchenj/dup/relax/src/meta_schedule/search_strategy/evolutionary_search.cc:658: Picked top 0 candidate(s) from database\n",
"[09:22:11] /home/yuchenj/dup/relax/src/meta_schedule/search_strategy/evolutionary_search.cc:494: Sample-Init-Population summary:\n",
"Postproc #0 [meta_schedule.DisallowDynamicLoop(0x562e5c06ff48)]: 0 failure(s)\n",
"Postproc #1 [meta_schedule.RewriteParallelVectorizeUnroll(0x562e5bf55038)]: 0 failure(s)\n",
"Postproc #2 [meta_schedule.RewriteReductionBlock(0x562e5bd8d6e8)]: 0 failure(s)\n",
"[09:22:11] /home/yuchenj/dup/relax/src/meta_schedule/search_strategy/evolutionary_search.cc:660: Sampled 2048 candidate(s)\n",
"[09:22:11] /home/yuchenj/dup/relax/src/meta_schedule/search_strategy/evolutionary_search.cc:571: Evolve iter #0 done. Summary:\n",
"Postproc #0 [meta_schedule.DisallowDynamicLoop(0x562e5c06ff48)]: 0 failure(s)\n",
"Postproc #1 [meta_schedule.RewriteParallelVectorizeUnroll(0x562e5bf55038)]: 0 failure(s)\n",
"Postproc #2 [meta_schedule.RewriteReductionBlock(0x562e5bd8d6e8)]: 0 failure(s)\n",
"[09:22:12] /home/yuchenj/dup/relax/src/meta_schedule/search_strategy/evolutionary_search.cc:571: Evolve iter #1 done. Summary:\n",
"Postproc #0 [meta_schedule.DisallowDynamicLoop(0x562e5c06ff48)]: 0 failure(s)\n",
"Postproc #1 [meta_schedule.RewriteParallelVectorizeUnroll(0x562e5bf55038)]: 0 failure(s)\n",
"Postproc #2 [meta_schedule.RewriteReductionBlock(0x562e5bd8d6e8)]: 0 failure(s)\n",
"[09:22:12] /home/yuchenj/dup/relax/src/meta_schedule/search_strategy/evolutionary_search.cc:571: Evolve iter #2 done. Summary:\n",
"Postproc #0 [meta_schedule.DisallowDynamicLoop(0x562e5c06ff48)]: 0 failure(s)\n",
"Postproc #1 [meta_schedule.RewriteParallelVectorizeUnroll(0x562e5bf55038)]: 0 failure(s)\n",
"Postproc #2 [meta_schedule.RewriteReductionBlock(0x562e5bd8d6e8)]: 0 failure(s)\n",
"[09:22:13] /home/yuchenj/dup/relax/src/meta_schedule/search_strategy/evolutionary_search.cc:571: Evolve iter #3 done. Summary:\n",
"Postproc #0 [meta_schedule.DisallowDynamicLoop(0x562e5c06ff48)]: 0 failure(s)\n",
"Postproc #1 [meta_schedule.RewriteParallelVectorizeUnroll(0x562e5bf55038)]: 0 failure(s)\n",
"Postproc #2 [meta_schedule.RewriteReductionBlock(0x562e5bd8d6e8)]: 0 failure(s)\n",
"[09:22:13] /home/yuchenj/dup/relax/src/meta_schedule/search_strategy/evolutionary_search.cc:595: Scores of the best 2 candidates:\n",
"[1 : 2]:\t0.9999 0.9999\n",
"[09:22:13] /home/yuchenj/dup/relax/src/meta_schedule/search_strategy/evolutionary_search.cc:664: Got 2 candidate(s) with evolutionary search\n",
"[09:22:13] /home/yuchenj/dup/relax/src/meta_schedule/search_strategy/evolutionary_search.cc:666: Sending 2 candidates(s) for measurement\n",
"[09:22:13] /home/yuchenj/dup/relax/src/meta_schedule/task_scheduler/task_scheduler.cc:32: Sending 2 sample(s) to builder\n",
"[09:22:13] /home/yuchenj/dup/relax/src/meta_schedule/task_scheduler/task_scheduler.cc:53: Sending 2 sample(s) to runner\n",
"[09:22:14] /home/yuchenj/dup/relax/src/meta_schedule/task_scheduler/task_scheduler.cc:125: Scheduler picks Task #1: \"tir_relu\"\n",
"[09:22:14] /home/yuchenj/dup/relax/src/meta_schedule/search_strategy/evolutionary_search.cc:656: Generating candidates......\n",
"[09:22:14] /home/yuchenj/dup/relax/src/meta_schedule/search_strategy/evolutionary_search.cc:658: Picked top 0 candidate(s) from database\n",
"[09:22:14] /home/yuchenj/dup/relax/src/meta_schedule/search_strategy/evolutionary_search.cc:494: Sample-Init-Population summary:\n",
"Postproc #0 [meta_schedule.DisallowDynamicLoop(0x562e5bc75bf8)]: 0 failure(s)\n",
"Postproc #1 [meta_schedule.RewriteParallelVectorizeUnroll(0x562e5bdae618)]: 0 failure(s)\n",
"Postproc #2 [meta_schedule.RewriteReductionBlock(0x562e5bd8c008)]: 0 failure(s)\n",
"[09:22:14] /home/yuchenj/dup/relax/src/meta_schedule/search_strategy/evolutionary_search.cc:660: Sampled 2048 candidate(s)\n",
"[09:22:15] /home/yuchenj/dup/relax/src/meta_schedule/search_strategy/evolutionary_search.cc:571: Evolve iter #0 done. Summary:\n",
"Postproc #0 [meta_schedule.DisallowDynamicLoop(0x562e5bc75bf8)]: 0 failure(s)\n",
"Postproc #1 [meta_schedule.RewriteParallelVectorizeUnroll(0x562e5bdae618)]: 0 failure(s)\n",
"Postproc #2 [meta_schedule.RewriteReductionBlock(0x562e5bd8c008)]: 0 failure(s)\n",
"[09:22:15] /home/yuchenj/dup/relax/src/meta_schedule/search_strategy/evolutionary_search.cc:571: Evolve iter #1 done. Summary:\n",
"Postproc #0 [meta_schedule.DisallowDynamicLoop(0x562e5bc75bf8)]: 0 failure(s)\n",
"Postproc #1 [meta_schedule.RewriteParallelVectorizeUnroll(0x562e5bdae618)]: 0 failure(s)\n",
"Postproc #2 [meta_schedule.RewriteReductionBlock(0x562e5bd8c008)]: 0 failure(s)\n",
"[09:22:16] /home/yuchenj/dup/relax/src/meta_schedule/search_strategy/evolutionary_search.cc:571: Evolve iter #2 done. Summary:\n",
"Postproc #0 [meta_schedule.DisallowDynamicLoop(0x562e5bc75bf8)]: 0 failure(s)\n",
"Postproc #1 [meta_schedule.RewriteParallelVectorizeUnroll(0x562e5bdae618)]: 0 failure(s)\n",
"Postproc #2 [meta_schedule.RewriteReductionBlock(0x562e5bd8c008)]: 0 failure(s)\n",
"[09:22:17] /home/yuchenj/dup/relax/src/meta_schedule/search_strategy/evolutionary_search.cc:571: Evolve iter #3 done. Summary:\n",
"Postproc #0 [meta_schedule.DisallowDynamicLoop(0x562e5bc75bf8)]: 0 failure(s)\n",
"Postproc #1 [meta_schedule.RewriteParallelVectorizeUnroll(0x562e5bdae618)]: 0 failure(s)\n",
"Postproc #2 [meta_schedule.RewriteReductionBlock(0x562e5bd8c008)]: 0 failure(s)\n",
"[09:22:17] /home/yuchenj/dup/relax/src/meta_schedule/search_strategy/evolutionary_search.cc:595: Scores of the best 2 candidates:\n",
"[1 : 2]:\t0.8936 0.6800\n",
"[09:22:17] /home/yuchenj/dup/relax/src/meta_schedule/search_strategy/evolutionary_search.cc:664: Got 2 candidate(s) with evolutionary search\n",
"[09:22:17] /home/yuchenj/dup/relax/src/meta_schedule/search_strategy/evolutionary_search.cc:666: Sending 2 candidates(s) for measurement\n",
"[09:22:17] /home/yuchenj/dup/relax/src/meta_schedule/task_scheduler/task_scheduler.cc:32: Sending 2 sample(s) to builder\n",
"[09:22:17] /home/yuchenj/dup/relax/src/meta_schedule/task_scheduler/task_scheduler.cc:53: Sending 2 sample(s) to runner\n",
"[09:22:17] /home/yuchenj/dup/relax/src/meta_schedule/measure_callback/echo_statistics.cc:52: [Task #0: \"tir_matmul\"] Trial #0: GFLOPs: 3.4616. Time: 0.0189 ms. Best GFLOPs: 3.4616\n",
"[09:22:17] /home/yuchenj/dup/relax/src/meta_schedule/measure_callback/echo_statistics.cc:52: [Task #0: \"tir_matmul\"] Trial #1: GFLOPs: 4.3685. Time: 0.0150 ms. Best GFLOPs: 4.3685\n",
"/home/yuchenj/.local/lib/python3.7/site-packages/pkg_resources/__init__.py:119: PkgResourcesDeprecationWarning: 0.18ubuntu0.18.04.1 is an invalid version and will not be supported in a future release\n",
" PkgResourcesDeprecationWarning,\n",
"/home/yuchenj/.local/lib/python3.7/site-packages/xgboost/training.py:17: UserWarning: Old style callback is deprecated. See: https://xgboost.readthedocs.io/en/latest/python/callbacks.html\n",
" warnings.warn(f'Old style callback is deprecated. See: {link}', UserWarning)\n",
"[09:22:18] /home/yuchenj/dup/relax/src/meta_schedule/task_scheduler/gradient_based.cc:172: [Updated] Task #0: \"tir_matmul\"\n",
" ID | Name | FLOP | Weight | Speed (GFLOPS) | Latency (us) | Weighted Latency (us) | Trials | Terminated \n",
"----------------------------------------------------------------------------------------------------------------\n",
" 0 | tir_matmul | 65536 | 1 | 4.3685 | 15.0020 | 15.0020 | 2 | \n",
" 1 | tir_relu | 1024 | 1 | N/A | N/A | N/A | 0 | \n",
"----------------------------------------------------------------------------------------------------------------\n",
"Total trials: 2\n",
"Total latency (us): 15.002\n",
"\n",
"[09:22:18] /home/yuchenj/dup/relax/src/meta_schedule/task_scheduler/task_scheduler.cc:149: Task #0 has finished. Remaining task(s): 1\n",
"[09:22:18] /home/yuchenj/dup/relax/src/meta_schedule/measure_callback/echo_statistics.cc:52: [Task #1: \"tir_relu\"] Trial #0: GFLOPs: 0.0400. Time: 0.0256 ms. Best GFLOPs: 0.0400\n",
"[09:22:18] /home/yuchenj/dup/relax/src/meta_schedule/measure_callback/echo_statistics.cc:52: [Task #1: \"tir_relu\"] Trial #1: GFLOPs: 0.0836. Time: 0.0122 ms. Best GFLOPs: 0.0836\n",
"[09:22:18] /home/yuchenj/dup/relax/src/meta_schedule/task_scheduler/gradient_based.cc:172: [Updated] Task #1: \"tir_relu\"\n",
" ID | Name | FLOP | Weight | Speed (GFLOPS) | Latency (us) | Weighted Latency (us) | Trials | Terminated \n",
"----------------------------------------------------------------------------------------------------------------\n",
" 0 | tir_matmul | 65536 | 1 | 4.3685 | 15.0020 | 15.0020 | 2 | Y \n",
" 1 | tir_relu | 1024 | 1 | 0.0836 | 12.2439 | 12.2439 | 2 | \n",
"----------------------------------------------------------------------------------------------------------------\n",
"Total trials: 4\n",
"Total latency (us): 27.2459\n",
"\n",
"[09:22:18] /home/yuchenj/dup/relax/src/meta_schedule/task_scheduler/task_scheduler.cc:149: Task #1 has finished. Remaining task(s): 0\n"
]
},
{
"data": {
"text/plain": [
"<tvm.nd.NDArray shape=(32, 32), cpu(0)>\n",
"array([[ 9.073223 , 9.206748 , 8.944179 , ..., 10.142985 , 9.609654 ,\n",
" 9.918516 ],\n",
" [ 8.463486 , 9.127991 , 8.411135 , ..., 8.769029 , 8.403664 ,\n",
" 9.8124075],\n",
" [ 7.24664 , 7.115847 , 7.343585 , ..., 8.223137 , 7.9694247,\n",
" 9.62473 ],\n",
" ...,\n",
" [ 8.291355 , 8.538074 , 7.597375 , ..., 8.382829 , 8.619025 ,\n",
" 9.292543 ],\n",
" [ 8.903806 , 9.442456 , 7.942852 , ..., 8.686344 , 7.853592 ,\n",
" 10.73006 ],\n",
" [ 7.6947894, 8.574124 , 7.977135 , ..., 8.067622 , 7.8615704,\n",
" 9.567776 ]], dtype=float32)"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from tvm import meta_schedule as ms\n",
"import tempfile\n",
"from tvm.meta_schedule.testing import DummyDatabase\n",
"\n",
"\n",
"database = DummyDatabase()\n",
"\n",
"@tvm.script.ir_module\n",
"class InputModule:\n",
" @T.prim_func\n",
" def tir_matmul(x: T.handle, y: T.handle, z: T.handle) -> None:\n",
" T.func_attr({\"global_symbol\": \"tir_matmul\"})\n",
" m = T.var(\"int32\")\n",
" n = T.var(\"int32\")\n",
" k = T.var(\"int32\")\n",
" A = T.match_buffer(x, (32, 32))\n",
" B = T.match_buffer(y, (32, 32))\n",
" C = T.match_buffer(z, (32, 32))\n",
"\n",
" for (i0, j0, k0) in T.grid(32, 32, 32):\n",
" with T.block():\n",
" i, j, k = T.axis.remap(\"SSR\", [i0, j0, k0])\n",
" with T.init():\n",
" C[i, j] = 0.0\n",
" C[i, j] += A[i, k] * B[j, k]\n",
"\n",
" @T.prim_func\n",
" def tir_relu(x: T.handle, y: T.handle):\n",
" T.func_attr({\"global_symbol\": \"tir_relu\"})\n",
" m = T.var(\"int32\")\n",
" n = T.var(\"int32\")\n",
" A = T.match_buffer(x, (32, 32))\n",
" B = T.match_buffer(y, (32, 32))\n",
" for (i, j) in T.grid(32, 32):\n",
" with T.block():\n",
" vi, vj = T.axis.remap(\"SS\", [i, j])\n",
" B[vi, vj] = T.max(A[vi, vj], 0.0)\n",
"\n",
" @R.function\n",
" def main(x: Tensor((32, 32), \"float32\"), w: Tensor((32, 32), \"float32\")) -> Tensor:\n",
" with R.dataflow():\n",
" lv0 = R.call_tir(tir_matmul, (x, w), (32, 32), dtype=\"float32\")\n",
" lv1 = R.call_tir(tir_relu, (lv0), (32, 32), dtype=\"float32\")\n",
" relax.output(lv1)\n",
" return lv1\n",
"\n",
"mod = InputModule\n",
"target = Target(\"llvm --num-cores=16\")\n",
"dev = tvm.cpu()\n",
"database = DummyDatabase()\n",
"\n",
"with tempfile.TemporaryDirectory() as work_dir:\n",
" relax_ex = ms.tune_relax(\n",
" mod=mod,\n",
" target=target,\n",
" config=ms.EvolutionarySearchConfig(\n",
" num_trials_per_iter=2,\n",
" max_trials_per_task=4,\n",
" max_trials_global=4,\n",
" ),\n",
" work_dir=work_dir,\n",
" database=database,\n",
" )\n",
"\n",
"vm = relax.VirtualMachine(relax_ex, dev)\n",
"data = tvm.nd.array(np.random.rand(32, 32).astype(np.float32), dev)\n",
"weight = tvm.nd.array(np.random.rand(32, 32).astype(np.float32), dev)\n",
"vm[\"main\"](data, weight)"
]
},
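{
"cell_type": "markdown",
"id": "f3a91c2e",
"metadata": {},
"source": [
"A quick sanity check on the task table above: the scheduler counts 65536 FLOP for `tir_matmul` (2 * 32 * 32 * 32 multiply-adds) and 1024 FLOP for `tir_relu` (one `max` per element of a 32x32 buffer), and the reported speed is `FLOP / latency`. A minimal sketch that reproduces the logged GFLOPS numbers:\n",
"\n",
"```python\n",
"matmul_flop = 2 * 32 * 32 * 32   # 2*m*n*k multiply-adds = 65536\n",
"relu_flop = 32 * 32              # one max() per element = 1024\n",
"\n",
"matmul_latency_s = 15.0020e-6    # best latencies from the table (us -> s)\n",
"relu_latency_s = 12.2439e-6\n",
"\n",
"print(matmul_flop / matmul_latency_s / 1e9)  # ~4.3685 GFLOPS\n",
"print(relu_flop / relu_latency_s / 1e9)      # ~0.0836 GFLOPS\n",
"```"
]
},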
{
"cell_type": "markdown",
"id": "cc0e909f",
"metadata": {},
"source": [
"## Relax minimum build"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "b3ab5492",
"metadata": {},
"outputs": [],
"source": [
"@tvm.register_func(\"test.vm.tile\")\n",
"def tile_packed(a, b):\n",
" b[:] = tvm.nd.array(np.tile(a.asnumpy(), (1, 2)))"
]
},
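{
"cell_type": "markdown",
"id": "b7d42e19",
"metadata": {},
"source": [
"The packed function above follows the destination-passing convention used throughout this section: it receives both the input `a` and a pre-allocated output `b`, and writes its result in place. `np.tile(a, (1, 2))` repeats `a` twice along the column axis, so an `(n, m)` input yields an `(n, 2m)` result. A plain NumPy illustration:\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"a = np.arange(6).reshape(3, 2)\n",
"print(np.tile(a, (1, 2)).shape)  # (3, 4): columns doubled, rows unchanged\n",
"```"
]
},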
{
"cell_type": "code",
"execution_count": 15,
"id": "dfedf857",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"======================\n",
"Original Relax Program\n",
"\n",
"@tvm.script.ir_module\n",
"class Module:\n",
" @relax.function\n",
" def foo(x: Tensor(_, \"float32\")) -> Tensor(_, _):\n",
" # block 0\n",
" with relax.dataflow():\n",
" relax.match_shape(x, (n, m))\n",
" y = relax.call_tir(\"test.vm.tile\", x, (n, (m * 2)), dtype=\"float32\")\n",
" relax.output(y)\n",
" return y\n",
" \n"
]
}
],
"source": [
"@tvm.script.ir_module\n",
"class MyIRModule:\n",
" @R.function\n",
" def foo(x: Tensor(_, \"float32\")) -> Tensor:\n",
" with R.dataflow():\n",
" R.match_shape(x, (n, m))\n",
" y = R.call_tir(\"test.vm.tile\", (x), (n, m * 2), dtype=\"float32\")\n",
" R.output(y)\n",
" return y\n",
"\n",
"# Original Relax Program\n",
"print(\"======================\")\n",
"print(\"Original Relax Program\\n\")\n",
"mod = MyIRModule\n",
"code = R.parser.astext(mod)\n",
"print(code)"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "9fb478f3",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"======================\n",
"PASS0: To Non Dataflow\n",
"\n",
"@tvm.script.ir_module\n",
"class Module:\n",
" @relax.function\n",
" def foo(x: Tensor(_, \"float32\")) -> Tensor(_, _):\n",
" # block 0\n",
" relax.match_shape(x, (n, m))\n",
" y = relax.call_tir(\"test.vm.tile\", x, (n, (m * 2)), dtype=\"float32\")\n",
" return y\n",
" \n"
]
}
],
"source": [
"# ToNonDataflow Pass\n",
"print(\"======================\")\n",
"print(\"PASS0: To Non Dataflow\\n\")\n",
"mod = relax.transform.ToNonDataflow()(mod)\n",
"print(R.parser.astext(mod))"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "4221356d",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"======================\n",
"PASS1: CallDPS Rewrite\n",
"\n",
"@tvm.script.ir_module\n",
"class Module:\n",
" @relax.function\n",
" def foo(x: Tensor(_, \"float32\")) -> Tensor(_, _):\n",
" # block 0\n",
" relax.match_shape(x, (n, m))\n",
" alloc = relax.builtin.alloc_tensor((n, (m * 2)), dtype=\"float32\", runtime_device_index=0, attrs_type_key=\"relax.attrs.AllocTensorAttrs\")\n",
" _ = relax.call_packed(\"test.vm.tile\", x, alloc)\n",
" y = alloc\n",
" return y\n",
" \n"
]
}
],
"source": [
"# CallDPS Rewrite\n",
"print(\"======================\")\n",
"print(\"PASS1: CallDPS Rewrite\\n\")\n",
"mod = relax.transform.CallTIRRewrite()(mod)\n",
"print(R.parser.astext(mod))"
]
},
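{
"cell_type": "markdown",
"id": "9c5a1f3d",
"metadata": {},
"source": [
"In plain Python terms, `CallTIRRewrite` replaces a value-returning `call_tir` with an explicit allocation followed by a destination-passing call, exactly the `alloc_tensor` + `call_packed` pair printed above. A rough NumPy analogue of the rewrite (a sketch, not the actual pass):\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def tile_packed(a, b):   # destination-passing style: writes into b\n",
"    b[:] = np.tile(a, (1, 2))\n",
"\n",
"x = np.random.rand(3, 4).astype(\"float32\")\n",
"\n",
"# Before (conceptually): y = call_tir(\"test.vm.tile\", x, (3, 8))\n",
"# After the rewrite:\n",
"alloc = np.empty((3, 8), dtype=\"float32\")  # alloc_tensor analogue\n",
"tile_packed(x, alloc)                      # call_packed analogue\n",
"y = alloc\n",
"```"
]
},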
{
"cell_type": "code",
"execution_count": 18,
"id": "f13f5dc5",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"======================\n",
"PASS2: Memory Lower\n",
"\n",
"@tvm.script.ir_module\n",
"class Module:\n",
" @relax.function\n",
" def foo(x: Tensor(_, \"float32\")) -> Tensor(_, _):\n",
" # block 0\n",
" relax.match_shape(x, (n, m))\n",
" storage = relax.vm.builtin.alloc_storage((((n * (m * 2)) * 4),), dtype=\"float32\", runtime_device_index=0, attrs_type_key=\"relax.attrs.VMAllocStorageAttrs\")\n",
" tensor = relax.vm.builtin.alloc_tensor(storage, (n, (m * 2)), offset=0, dtype=\"float32\", attrs_type_key=\"relax.attrs.VMAllocTensorAttrs\")\n",
" alloc = tensor\n",
" _ = relax.call_packed(\"test.vm.tile\", x, alloc)\n",
" y = alloc\n",
" return y\n",
" \n"
]
}
],
"source": [
"# Memory Lower\n",
"print(\"======================\")\n",
"print(\"PASS2: Memory Lower\\n\")\n",
"mod = relax.transform.VMMemoryLower()(mod)\n",
"print(R.parser.astext(mod))"
]
},
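{
"cell_type": "markdown",
"id": "d4e8b2a7",
"metadata": {},
"source": [
"`VMMemoryLower` splits the single `alloc_tensor` into two VM builtins: allocate a flat storage of `n * (m * 2) * 4` bytes (4 bytes per `float32` element), then create a tensor view over that storage at offset 0. The analogous NumPy idiom, as a sketch:\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"n, m = 3, 4\n",
"storage = np.empty(n * (m * 2) * 4, dtype=\"uint8\")  # flat byte buffer (96 bytes)\n",
"tensor = storage.view(\"float32\").reshape(n, m * 2)  # (3, 8) float32 view at offset 0\n",
"```"
]
},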
{
"cell_type": "code",
"execution_count": 19,
"id": "82592858",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"======================\n",
"PASS3: Shape Lower\n",
"\n",
"@tvm.script.ir_module\n",
"class Module:\n",
" @tir.prim_func\n",
" def shape_func(H_1: tir.Buffer[(tir.int64(4),), \"int64\"]) -> None:\n",
" # function attr dict\n",
" tir.func_attr({\"global_symbol\": \"shape_func\"})\n",
" # body\n",
" H_1[2] = H_1[0] * (H_1[1] * tir.int64(2)) * tir.int64(4)\n",
" \n",
" @tir.prim_func\n",
" def shape_func1(H_1: tir.Buffer[(tir.int64(4),), \"int64\"]) -> None:\n",
" # function attr dict\n",
" tir.func_attr({\"global_symbol\": \"shape_func1\"})\n",
" # body\n",
" H_1[0] = H_1[0]\n",
" H_1[3] = H_1[1] * tir.int64(2)\n",
" \n",
" @relax.function\n",
" def foo(x: Tensor(_, \"float32\")) -> Tensor(_, _):\n",
" # block 0\n",
" shape_heap: Tensor((4,), \"int64\") = relax.call_packed(\"vm.builtin.alloc_shape_heap\", (4,))\n",
" # block 1\n",
" sh = relax.call_packed(\"vm.builtin.shape_of\", x)\n",
" gv = relax.vm.builtin.store_shape(sh, shape_heap, indices=[0, 1], attrs_type_key=\"relax.attrs.ShapeHeapAttrs\")\n",
" _ = shape_func(shape_heap)\n",
" sh1 = relax.vm.builtin.load_shape(shape_heap, indices=[2], attrs_type_key=\"relax.attrs.ShapeHeapAttrs\")\n",
" storage = relax.vm.builtin.alloc_storage(sh1, dtype=\"float32\", runtime_device_index=0, attrs_type_key=\"relax.attrs.VMAllocStorageAttrs\")\n",
" _1 = shape_func1(shape_heap)\n",
" sh2 = relax.vm.builtin.load_shape(shape_heap, indices=[0, 3], attrs_type_key=\"relax.attrs.ShapeHeapAttrs\")\n",
" tensor = relax.vm.builtin.alloc_tensor(storage, sh2, offset=0, dtype=\"float32\", attrs_type_key=\"relax.attrs.VMAllocTensorAttrs\")\n",
" alloc = tensor\n",
" _2 = relax.call_packed(\"test.vm.tile\", x, alloc)\n",
" y = alloc\n",
" return y\n",
" \n"
]
}
],
"source": [
"# Shape Lower\n",
"print(\"======================\")\n",
"print(\"PASS3: Shape Lower\\n\")\n",
"mod = relax.transform.VMShapeLower()(mod)\n",
"print(R.parser.astext(mod))"
]
},
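{
"cell_type": "markdown",
"id": "a6c3f9e1",
"metadata": {},
"source": [
"The two generated shape functions operate on a small `int64` shape heap: `store_shape` writes the runtime shape of `x` into slots 0 and 1, `shape_func` fills slot 2 with the storage size in bytes, and `shape_func1` fills slot 3 with the output's second dimension. Emulated in plain Python for a `(3, 4)` input:\n",
"\n",
"```python\n",
"heap = [0] * 4             # vm.builtin.alloc_shape_heap((4,))\n",
"heap[0], heap[1] = 3, 4    # store_shape: slots [0, 1] <- shape of x\n",
"\n",
"heap[2] = heap[0] * (heap[1] * 2) * 4  # shape_func: storage size in bytes\n",
"heap[0] = heap[0]                      # shape_func1\n",
"heap[3] = heap[1] * 2\n",
"\n",
"print(heap[2])             # 96: byte size passed to alloc_storage\n",
"print((heap[0], heap[3]))  # (3, 8): shape passed to alloc_tensor\n",
"```"
]
},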
{
"cell_type": "code",
"execution_count": 20,
"id": "63e9ad0a",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"======================\n",
"Build & Execute\n",
"input: [[0.0434758 0.54549795 0.01151199 0.26422125]\n",
" [0.06118515 0.20250504 0.9074851 0.0011075 ]\n",
" [0.09346122 0.71868443 0.22026683 0.20938304]]\n",
"output: [[0.0434758 0.54549795 0.01151199 0.26422125 0.0434758 0.54549795\n",
" 0.01151199 0.26422125]\n",
" [0.06118515 0.20250504 0.9074851 0.0011075 0.06118515 0.20250504\n",
" 0.9074851 0.0011075 ]\n",
" [0.09346122 0.71868443 0.22026683 0.20938304 0.09346122 0.71868443\n",
" 0.22026683 0.20938304]]\n"
]
}
],
"source": [
"# Build & Execute\n",
"print(\"======================\")\n",
"print(\"Build & Execute\")\n",
"\n",
"target = tvm.target.Target(\"llvm\", host=\"llvm\")\n",
"ex = relax.vm.build(mod, target)\n",
"vm = relax.VirtualMachine(ex, tvm.cpu())\n",
"\n",
"shape = (3, 4)\n",
"inp = tvm.nd.array(np.random.rand(*shape).astype(np.float32))\n",
"out = vm[\"foo\"](inp)\n",
"print(\"input: \", inp)\n",
"print(\"output: \", out)\n",
"np.testing.assert_allclose(np.tile(inp.asnumpy(), (1, 2)), out.asnumpy())"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.13"
}
},
"nbformat": 4,
"nbformat_minor": 5
}