{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "d5783aee",
"metadata": {},
"outputs": [],
"source": [
"from __future__ import annotations # must import to defer parsing of annotations\n",
"import os\n",
"import numpy as np\n",
"import tvm\n",
"from tvm import relax, tir, topi\n",
"from tvm.runtime import container\n",
"from tvm.target.target import Target\n",
"from tvm.relax.testing import nn\n",
"\n",
"import tvm.script\n",
"from tvm.script import tir as T, relax as R"
]
},
{
"cell_type": "markdown",
"id": "62b4198a",
"metadata": {},
"source": [
"## Build and run a neural network in Relax"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "fd0cc6e9",
"metadata": {},
"outputs": [],
"source": [
"builder = relax.BlockBuilder()\n",
"\n",
"input_size = 784\n",
"hidden_sizes = [128, 64]\n",
"output_size = 10"
]
},
{
"cell_type": "markdown",
"id": "068cded9",
"metadata": {},
"source": [
"Build a neural network with three linear layers for a classification task.\n",
"A neural network is an `nn.Module` that can be composed of other modules (layers). This nested structure makes it easy to build complex models."
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "a12e45da",
"metadata": {},
"outputs": [],
"source": [
"with builder.function(name=\"main\"):\n",
"    model = nn.Sequential(\n",
"        nn.Linear(input_size, hidden_sizes[0]),\n",
"        nn.ReLU(),\n",
"        nn.Linear(hidden_sizes[0], hidden_sizes[1]),\n",
"        nn.ReLU(),\n",
"        nn.Linear(hidden_sizes[1], output_size),\n",
"        nn.LogSoftmax(),\n",
"    )\n",
"\n",
"    # n is a symbolic variable to represent a dynamic batch size\n",
"    n = tir.Var(\"n\", \"int64\")\n",
"    data = nn.Placeholder((n, input_size), name=\"data\")\n",
"    output = model(data)\n",
"    params = [data] + model.parameters()\n",
"    builder.emit_func_output(output, params=params)"
]
},
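{
"cell_type": "markdown",
"id": "3f2a9c1e",
"metadata": {},
"source": [
"Aside: beyond `nn.Sequential`, modules can be nested by assigning submodules as attributes. A minimal sketch (the `MLPBlock` name is hypothetical, assuming `nn.Module` collects parameters from submodule attributes and dispatches calls to `forward`):"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3f2a9c1f",
"metadata": {},
"outputs": [],
"source": [
"# a minimal sketch (hypothetical MLPBlock): nn.Module instances can be\n",
"# nested as attributes, and parameters() collects from submodules\n",
"class MLPBlock(nn.Module):\n",
"    def __init__(self, in_features, out_features):\n",
"        self.linear = nn.Linear(in_features, out_features)\n",
"        self.act = nn.ReLU()\n",
"\n",
"    def forward(self, x):\n",
"        # compose the nested layers\n",
"        return self.act(self.linear(x))"
]
},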
{
"cell_type": "markdown",
"id": "abde5a76",
"metadata": {},
"source": [
"Get and print the IRModule being built."
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "54f1bb9e",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"@tvm.script.ir_module\n", | |
"class Module:\n", | |
" @tir.prim_func\n", | |
" def matmul1(var_rxplaceholder_1: tir.handle, rxplaceholder_2: tir.Buffer[(128, 64), \"float32\"], var_T_matmul_1: tir.handle) -> None:\n", | |
" # function attr dict\n", | |
" tir.func_attr({\"global_symbol\": \"matmul1\", \"tir.noalias\": True})\n", | |
" n_1 = tir.var(\"int64\")\n", | |
" rxplaceholder_3 = tir.match_buffer(var_rxplaceholder_1, [n_1, 128], dtype=\"float32\")\n", | |
" T_matmul_1 = tir.match_buffer(var_T_matmul_1, [n_1, 64], dtype=\"float32\")\n", | |
" # body\n", | |
" # with tir.block(\"root\")\n", | |
" for i0, i1, i2 in tir.grid(n_1, 64, 128):\n", | |
" with tir.block(\"T_matmul\"):\n", | |
" ax0 = tir.axis.spatial(n_1, i0)\n", | |
" ax1, k = tir.axis.remap(\"SR\", [i1, i2])\n", | |
" tir.reads(rxplaceholder_3[ax0, k], rxplaceholder_2[k, ax1])\n", | |
" tir.writes(T_matmul_1[ax0, ax1])\n", | |
" with tir.init():\n", | |
" T_matmul_1[ax0, ax1] = tir.float32(0)\n", | |
" T_matmul_1[ax0, ax1] = T_matmul_1[ax0, ax1] + rxplaceholder_3[ax0, k] * rxplaceholder_2[k, ax1]\n", | |
" \n", | |
" @tir.prim_func\n", | |
" def matmul(var_rxplaceholder_1: tir.handle, rxplaceholder_2: tir.Buffer[(784, 128), \"float32\"], var_T_matmul_1: tir.handle) -> None:\n", | |
" # function attr dict\n", | |
" tir.func_attr({\"global_symbol\": \"matmul\", \"tir.noalias\": True})\n", | |
" n_1 = tir.var(\"int64\")\n", | |
" rxplaceholder_3 = tir.match_buffer(var_rxplaceholder_1, [n_1, 784], dtype=\"float32\")\n", | |
" T_matmul_1 = tir.match_buffer(var_T_matmul_1, [n_1, 128], dtype=\"float32\")\n", | |
" # body\n", | |
" # with tir.block(\"root\")\n", | |
" for i0, i1, i2 in tir.grid(n_1, 128, 784):\n", | |
" with tir.block(\"T_matmul\"):\n", | |
" ax0 = tir.axis.spatial(n_1, i0)\n", | |
" ax1, k = tir.axis.remap(\"SR\", [i1, i2])\n", | |
" tir.reads(rxplaceholder_3[ax0, k], rxplaceholder_2[k, ax1])\n", | |
" tir.writes(T_matmul_1[ax0, ax1])\n", | |
" with tir.init():\n", | |
" T_matmul_1[ax0, ax1] = tir.float32(0)\n", | |
" T_matmul_1[ax0, ax1] = T_matmul_1[ax0, ax1] + rxplaceholder_3[ax0, k] * rxplaceholder_2[k, ax1]\n", | |
" \n", | |
" @tir.prim_func\n", | |
" def log_softmax(var_rxplaceholder_1: tir.handle, var_compute_1: tir.handle) -> None:\n", | |
" # function attr dict\n", | |
" tir.func_attr({\"global_symbol\": \"log_softmax\", \"tir.noalias\": True})\n", | |
" n_1 = tir.var(\"int64\")\n", | |
" rxplaceholder_1 = tir.match_buffer(var_rxplaceholder_1, [n_1, 10], dtype=\"float32\")\n", | |
" compute_3 = tir.match_buffer(var_compute_1, [n_1, 10], dtype=\"float32\")\n", | |
" # body\n", | |
" # with tir.block(\"root\")\n", | |
" compute_4 = tir.alloc_buffer([n_1], dtype=\"float32\")\n", | |
" compute_5 = tir.alloc_buffer([n_1], dtype=\"float32\")\n", | |
" for i0, i1 in tir.grid(n_1, 10):\n", | |
" with tir.block(\"compute\"):\n", | |
" i = tir.axis.spatial(n_1, i0)\n", | |
" k = tir.axis.reduce(10, i1)\n", | |
" tir.reads(rxplaceholder_1[i, k])\n", | |
" tir.writes(compute_4[i])\n", | |
" with tir.init():\n", | |
" compute_4[i] = tir.float32(-3.4028234663852886e+38)\n", | |
" compute_4[i] = tir.max(compute_4[i], rxplaceholder_1[i, k])\n", | |
" for i0, i1 in tir.grid(n_1, 10):\n", | |
" with tir.block(\"compute_1\"):\n", | |
" i = tir.axis.spatial(n_1, i0)\n", | |
" k = tir.axis.reduce(10, i1)\n", | |
" tir.reads(rxplaceholder_1[i, k], compute_4[i])\n", | |
" tir.writes(compute_5[i])\n", | |
" with tir.init():\n", | |
" compute_5[i] = tir.float32(0)\n", | |
" compute_5[i] = compute_5[i] + tir.exp(rxplaceholder_1[i, k] - compute_4[i], dtype=\"float32\")\n", | |
" for i0, i1 in tir.grid(n_1, 10):\n", | |
" with tir.block(\"compute_2\"):\n", | |
" i = tir.axis.spatial(n_1, i0)\n", | |
" j = tir.axis.spatial(10, i1)\n", | |
" tir.reads(rxplaceholder_1[i, j], compute_4[i], compute_5[i])\n", | |
" tir.writes(compute_3[i, j])\n", | |
" compute_3[i, j] = rxplaceholder_1[i, j] - compute_4[i] - tir.log(compute_5[i], dtype=\"float32\")\n", | |
" \n", | |
" @relax.function\n", | |
" def main(data: Tensor((n, 784), \"float32\"), linear_weight: Tensor((784, 128), \"float32\"), linear_bias: Tensor((128,), \"float32\"), linear_weight1: Tensor((128, 64), \"float32\"), linear_bias1: Tensor((64,), \"float32\"), linear_weight2: Tensor((64, 10), \"float32\"), linear_bias2: Tensor((10,), \"float32\")) -> Tensor(_, \"float32\"):\n", | |
" # block 0\n", | |
" gv = relax.call_tir(matmul, (data, linear_weight), (n, 128), dtype=\"float32\")\n", | |
" gv1 = relax.call_tir(add, (gv, linear_bias), (n, 128), dtype=\"float32\")\n", | |
" gv2 = relax.call_tir(relu, (gv1,), (n, 128), dtype=\"float32\")\n", | |
" gv3 = relax.call_tir(matmul1, (gv2, linear_weight1), (n, 64), dtype=\"float32\")\n", | |
" gv4 = relax.call_tir(add1, (gv3, linear_bias1), (n, 64), dtype=\"float32\")\n", | |
" gv5 = relax.call_tir(relu1, (gv4,), (n, 64), dtype=\"float32\")\n", | |
" gv6 = relax.call_tir(matmul2, (gv5, linear_weight2), (n, 10), dtype=\"float32\")\n", | |
" gv7 = relax.call_tir(add2, (gv6, linear_bias2), (n, 10), dtype=\"float32\")\n", | |
" gv8 = relax.call_tir(log_softmax, (gv7,), (n, 10), dtype=\"float32\")\n", | |
" return gv8\n", | |
" \n", | |
" @tir.prim_func\n", | |
" def relu1(var_rxplaceholder_1: tir.handle, var_compute_1: tir.handle) -> None:\n", | |
" # function attr dict\n", | |
" tir.func_attr({\"global_symbol\": \"relu1\", \"tir.noalias\": True})\n", | |
" n_1 = tir.var(\"int64\")\n", | |
" rxplaceholder_1 = tir.match_buffer(var_rxplaceholder_1, [n_1, 64], dtype=\"float32\")\n", | |
" compute_1 = tir.match_buffer(var_compute_1, [n_1, 64], dtype=\"float32\")\n", | |
" # body\n", | |
" # with tir.block(\"root\")\n", | |
" for i0, i1 in tir.grid(n_1, 64):\n", | |
" with tir.block(\"compute\"):\n", | |
" i0_2 = tir.axis.spatial(n_1, i0)\n", | |
" i1_2 = tir.axis.spatial(64, i1)\n", | |
" tir.reads(rxplaceholder_1[i0_2, i1_2])\n", | |
" tir.writes(compute_1[i0_2, i1_2])\n", | |
" compute_1[i0_2, i1_2] = tir.max(rxplaceholder_1[i0_2, i1_2], tir.float32(0))\n", | |
" \n", | |
" @tir.prim_func\n", | |
" def add2(var_rxplaceholder_1: tir.handle, rxplaceholder_2: tir.Buffer[(10,), \"float32\"], var_T_add_1: tir.handle) -> None:\n", | |
" # function attr dict\n", | |
" tir.func_attr({\"global_symbol\": \"add2\", \"tir.noalias\": True})\n", | |
" n_1 = tir.var(\"int64\")\n", | |
" rxplaceholder_3 = tir.match_buffer(var_rxplaceholder_1, [n_1, 10], dtype=\"float32\")\n", | |
" T_add_1 = tir.match_buffer(var_T_add_1, [n_1, 10], dtype=\"float32\")\n", | |
" # body\n", | |
" # with tir.block(\"root\")\n", | |
" for i0, i1 in tir.grid(n_1, 10):\n", | |
" with tir.block(\"T_add\"):\n", | |
" ax0 = tir.axis.spatial(n_1, i0)\n", | |
" ax1 = tir.axis.spatial(10, i1)\n", | |
" tir.reads(rxplaceholder_3[ax0, ax1], rxplaceholder_2[ax1])\n", | |
" tir.writes(T_add_1[ax0, ax1])\n", | |
" T_add_1[ax0, ax1] = rxplaceholder_3[ax0, ax1] + rxplaceholder_2[ax1]\n", | |
" \n", | |
" @tir.prim_func\n", | |
" def add1(var_rxplaceholder_1: tir.handle, rxplaceholder_2: tir.Buffer[(64,), \"float32\"], var_T_add_1: tir.handle) -> None:\n", | |
" # function attr dict\n", | |
" tir.func_attr({\"global_symbol\": \"add1\", \"tir.noalias\": True})\n", | |
" n_1 = tir.var(\"int64\")\n", | |
" rxplaceholder_3 = tir.match_buffer(var_rxplaceholder_1, [n_1, 64], dtype=\"float32\")\n", | |
" T_add_1 = tir.match_buffer(var_T_add_1, [n_1, 64], dtype=\"float32\")\n", | |
" # body\n", | |
" # with tir.block(\"root\")\n", | |
" for i0, i1 in tir.grid(n_1, 64):\n", | |
" with tir.block(\"T_add\"):\n", | |
" ax0 = tir.axis.spatial(n_1, i0)\n", | |
" ax1 = tir.axis.spatial(64, i1)\n", | |
" tir.reads(rxplaceholder_3[ax0, ax1], rxplaceholder_2[ax1])\n", | |
" tir.writes(T_add_1[ax0, ax1])\n", | |
" T_add_1[ax0, ax1] = rxplaceholder_3[ax0, ax1] + rxplaceholder_2[ax1]\n", | |
" \n", | |
" @tir.prim_func\n", | |
" def add(var_rxplaceholder_1: tir.handle, rxplaceholder_2: tir.Buffer[(128,), \"float32\"], var_T_add_1: tir.handle) -> None:\n", | |
" # function attr dict\n", | |
" tir.func_attr({\"global_symbol\": \"add\", \"tir.noalias\": True})\n", | |
" n_1 = tir.var(\"int64\")\n", | |
" rxplaceholder_3 = tir.match_buffer(var_rxplaceholder_1, [n_1, 128], dtype=\"float32\")\n", | |
" T_add_1 = tir.match_buffer(var_T_add_1, [n_1, 128], dtype=\"float32\")\n", | |
" # body\n", | |
" # with tir.block(\"root\")\n", | |
" for i0, i1 in tir.grid(n_1, 128):\n", | |
" with tir.block(\"T_add\"):\n", | |
" ax0 = tir.axis.spatial(n_1, i0)\n", | |
" ax1 = tir.axis.spatial(128, i1)\n", | |
" tir.reads(rxplaceholder_3[ax0, ax1], rxplaceholder_2[ax1])\n", | |
" tir.writes(T_add_1[ax0, ax1])\n", | |
" T_add_1[ax0, ax1] = rxplaceholder_3[ax0, ax1] + rxplaceholder_2[ax1]\n", | |
" \n", | |
" @tir.prim_func\n", | |
" def relu(var_rxplaceholder_1: tir.handle, var_compute_1: tir.handle) -> None:\n", | |
" # function attr dict\n", | |
" tir.func_attr({\"global_symbol\": \"relu\", \"tir.noalias\": True})\n", | |
" n_1 = tir.var(\"int64\")\n", | |
" rxplaceholder_1 = tir.match_buffer(var_rxplaceholder_1, [n_1, 128], dtype=\"float32\")\n", | |
" compute_1 = tir.match_buffer(var_compute_1, [n_1, 128], dtype=\"float32\")\n", | |
" # body\n", | |
" # with tir.block(\"root\")\n", | |
" for i0, i1 in tir.grid(n_1, 128):\n", | |
" with tir.block(\"compute\"):\n", | |
" i0_2 = tir.axis.spatial(n_1, i0)\n", | |
" i1_2 = tir.axis.spatial(128, i1)\n", | |
" tir.reads(rxplaceholder_1[i0_2, i1_2])\n", | |
" tir.writes(compute_1[i0_2, i1_2])\n", | |
" compute_1[i0_2, i1_2] = tir.max(rxplaceholder_1[i0_2, i1_2], tir.float32(0))\n", | |
" \n", | |
" @tir.prim_func\n", | |
" def matmul2(var_rxplaceholder_1: tir.handle, rxplaceholder_2: tir.Buffer[(64, 10), \"float32\"], var_T_matmul_1: tir.handle) -> None:\n", | |
" # function attr dict\n", | |
" tir.func_attr({\"global_symbol\": \"matmul2\", \"tir.noalias\": True})\n", | |
" n_1 = tir.var(\"int64\")\n", | |
" rxplaceholder_3 = tir.match_buffer(var_rxplaceholder_1, [n_1, 64], dtype=\"float32\")\n", | |
" T_matmul_1 = tir.match_buffer(var_T_matmul_1, [n_1, 10], dtype=\"float32\")\n", | |
" # body\n", | |
" # with tir.block(\"root\")\n", | |
" for i0, i1, i2 in tir.grid(n_1, 10, 64):\n", | |
" with tir.block(\"T_matmul\"):\n", | |
" ax0 = tir.axis.spatial(n_1, i0)\n", | |
" ax1, k = tir.axis.remap(\"SR\", [i1, i2])\n", | |
" tir.reads(rxplaceholder_3[ax0, k], rxplaceholder_2[k, ax1])\n", | |
" tir.writes(T_matmul_1[ax0, ax1])\n", | |
" with tir.init():\n", | |
" T_matmul_1[ax0, ax1] = tir.float32(0)\n", | |
" T_matmul_1[ax0, ax1] = T_matmul_1[ax0, ax1] + rxplaceholder_3[ax0, k] * rxplaceholder_2[k, ax1]\n", | |
" \n" | |
]
}
],
"source": [
"mod = builder.get()\n",
"R.parser.pretty_print(mod)"
]
},
{
"cell_type": "markdown",
"id": "158875f7",
"metadata": {},
"source": [
"Build the model and create the Relax VM."
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "b6e7a8cd",
"metadata": {},
"outputs": [],
"source": [
"# build and create a VM executor\n",
"target = tvm.target.Target(\"llvm\")\n",
"ex = relax.vm.build(mod, target)\n",
"vm = relax.VirtualMachine(ex, tvm.cpu())"
]
},
{
"cell_type": "markdown",
"id": "7679f077",
"metadata": {},
"source": [
"Run the model on the Relax VM."
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "1e530019",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[-109873.59 -49647. -41688.25 -94398.22 -81715.94 -111547.\n", | |
" 0. -63041.28 -18236.625 -98630.66 ]\n", | |
" [-110056.69 -49696.03 -41759.594 -94529.28 -81796.625 -111659.78\n", | |
" 0. -63112.375 -18241.969 -98737.47 ]\n", | |
" [-109510.03 -49518.72 -41588.156 -94110.03 -81418.28 -111148.59\n", | |
" 0. -62861.28 -18204.406 -98300.47 ]]\n" | |
]
}
],
"source": [
"# init parameters\n",
"params = nn.init_params(mod)\n",
"# the input data has a minibatch size of 3\n",
"data = tvm.nd.array(np.random.rand(3, input_size).astype(np.float32))\n",
"\n",
"res = vm[\"main\"](data, *params)\n",
"print(res)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "dfe290fb",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[-113639.97 -51407.094 -43109.22 -97656.59 -84459.34 -115332.\n", | |
" 0. -65198.156 -18852.219 -102011.34 ]\n", | |
" [-110885.53 -50159.812 -42090. -95289.5 -82462.25 -112593.28\n", | |
" 0. -63644.656 -18419.344 -99547.97 ]\n", | |
" [-111952.81 -50647.22 -42490.312 -96203.5 -83224.56 -113655.34\n", | |
" 0. -64285.72 -18587.406 -100505.625]\n", | |
" [-114726.34 -51872.094 -43565.125 -98605.84 -85311.03 -116471.19\n", | |
" 0. -65849.66 -19040.094 -102954.72 ]\n", | |
" [-108532.47 -49076.97 -41242.25 -93290.25 -80707.25 -110214.56\n", | |
" 0. -62300.625 -18015.938 -97434.66 ]]\n" | |
]
}
],
"source": [
"# the input data has a minibatch size of 5\n",
"data = tvm.nd.array(np.random.rand(5, input_size).astype(np.float32))\n",
"\n",
"res = vm[\"main\"](data, *params)\n",
"print(res)"
]
},
{
"cell_type": "markdown",
"id": "23b4c322",
"metadata": {},
"source": [
"Define a layer/network with `emit_te`, using the `nn.Module` interface."
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "eb687277",
"metadata": {},
"outputs": [],
"source": [
"class Linear(nn.Module):\n",
"    \"\"\"Applies a linear transformation to the input data: :math:`y = xA + b`.\"\"\"\n",
"\n",
"    def __init__(self, in_features, out_features, bias=True):\n",
"        self.in_features = in_features\n",
"        self.out_features = out_features\n",
"        self.weight = nn.Parameter((in_features, out_features), name=\"linear_weight\")\n",
"        if bias:\n",
"            self.bias = nn.Parameter((out_features,), name=\"linear_bias\")\n",
"        else:\n",
"            self.bias = None\n",
"\n",
"    def forward(self, input: relax.Expr) -> relax.Var:\n",
"        y = nn.emit_te(topi.matmul, input, self.weight)\n",
"        if self.bias is not None:\n",
"            y = nn.emit_te(topi.add, y, self.bias)\n",
"        return y"
]
},
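{
"cell_type": "markdown",
"id": "b4c5d6e7",
"metadata": {},
"source": [
"A minimal usage sketch for the custom `Linear`, mirroring the earlier `main` function (the `linear_main` name and the symbolic batch variable `b` are illustrative):"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b4c5d6e8",
"metadata": {},
"outputs": [],
"source": [
"# sketch: emit the custom Linear into a new Relax function\n",
"bb2 = relax.BlockBuilder()\n",
"with bb2.function(name=\"linear_main\"):\n",
"    layer = Linear(784, 128)\n",
"    b = tir.Var(\"b\", \"int64\")  # symbolic batch size\n",
"    x = nn.Placeholder((b, 784), name=\"x\")\n",
"    out = layer(x)\n",
"    bb2.emit_func_output(out, params=[x] + layer.parameters())\n",
"R.parser.pretty_print(bb2.get())"
]
},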
{
"cell_type": "markdown",
"id": "89e095a9",
"metadata": {},
"source": [
"## TE/TOPI Integration"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "7ca3cfdb",
"metadata": {},
"outputs": [],
"source": [
"def build_mlp(data, weight):\n",
"    bb = relax.BlockBuilder()\n",
"\n",
"    with bb.function(\"mlp\", [data, weight]):\n",
"        gv0 = bb.emit_te(tvm.contrib.cblas.matmul, data, weight, transa=False, transb=False)\n",
"        gv1 = bb.emit_te(topi.nn.relu, gv0)\n",
"        bb.emit_func_output(gv1)\n",
"\n",
"    mod = bb.get()\n",
"    return mod"
]
},
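{
"cell_type": "markdown",
"id": "c7d8e9f0",
"metadata": {},
"source": [
"`emit_te` accepts any function from `te.Tensor`s to a `te.Tensor`, not just TOPI or cblas kernels. A minimal sketch with a hand-written `te.compute` (the `te_double` helper and `double_fn` name are hypothetical):"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c7d8e9f1",
"metadata": {},
"outputs": [],
"source": [
"from tvm import te\n",
"\n",
"# a hand-written TE computation: element-wise doubling (hypothetical helper)\n",
"def te_double(x):\n",
"    return te.compute(x.shape, lambda *i: x(*i) * 2.0, name=\"double\")\n",
"\n",
"bb = relax.BlockBuilder()\n",
"p, q = tir.Var(\"p\", \"int64\"), tir.Var(\"q\", \"int64\")\n",
"x = relax.Var(\"x\", [p, q], relax.DynTensorType(2, \"float32\"))\n",
"with bb.function(\"double_fn\", [x]):\n",
"    y = bb.emit_te(te_double, x)\n",
"    bb.emit_func_output(y)\n",
"R.parser.pretty_print(bb.get())"
]
},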
{
"cell_type": "code",
"execution_count": 10,
"id": "c5c1e8f2",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"@tvm.script.ir_module\n", | |
"class Module:\n", | |
" @tir.prim_func\n", | |
" def matmul(var_rxplaceholder_2: tir.handle, var_rxplaceholder_3: tir.handle, var_C_1: tir.handle) -> None:\n", | |
" # function attr dict\n", | |
" tir.func_attr({\"global_symbol\": \"matmul\", \"tir.noalias\": True})\n", | |
" m_1 = tir.var(\"int64\")\n", | |
" n_1 = tir.var(\"int64\")\n", | |
" rxplaceholder_2 = tir.match_buffer(var_rxplaceholder_2, [n_1, m_1], dtype=\"float32\")\n", | |
" rxplaceholder_3 = tir.match_buffer(var_rxplaceholder_3, [m_1, n_1], dtype=\"float32\")\n", | |
" C_1 = tir.match_buffer(var_C_1, [n_1, n_1], dtype=\"float32\")\n", | |
" # body\n", | |
" # with tir.block(\"root\")\n", | |
" tir.evaluate(tir.tvm_call_packed(\"tvm.contrib.cblas.matmul\", tir.tvm_stack_make_array(rxplaceholder_2.data, tir.tvm_stack_make_shape(n_1, m_1, dtype=\"handle\"), 0, 2, tir.float32(0), tir.int64(0), dtype=\"handle\"), tir.tvm_stack_make_array(rxplaceholder_3.data, tir.tvm_stack_make_shape(m_1, n_1, dtype=\"handle\"), 0, 2, tir.float32(0), tir.int64(0), dtype=\"handle\"), tir.tvm_stack_make_array(C_1.data, tir.tvm_stack_make_shape(n_1, n_1, dtype=\"handle\"), 0, 2, tir.float32(0), tir.int64(0), dtype=\"handle\"), False, False, dtype=\"int32\"))\n", | |
" \n", | |
" @tir.prim_func\n", | |
" def relu(var_rxplaceholder_1: tir.handle, var_compute_1: tir.handle) -> None:\n", | |
" # function attr dict\n", | |
" tir.func_attr({\"global_symbol\": \"relu\", \"tir.noalias\": True})\n", | |
" n_1 = tir.var(\"int64\")\n", | |
" rxplaceholder_1 = tir.match_buffer(var_rxplaceholder_1, [n_1, n_1], dtype=\"float32\")\n", | |
" compute_1 = tir.match_buffer(var_compute_1, [n_1, n_1], dtype=\"float32\")\n", | |
" # body\n", | |
" # with tir.block(\"root\")\n", | |
" for i0, i1 in tir.grid(n_1, n_1):\n", | |
" with tir.block(\"compute\"):\n", | |
" i0_2 = tir.axis.spatial(n_1, i0)\n", | |
" i1_2 = tir.axis.spatial(n_1, i1)\n", | |
" tir.reads(rxplaceholder_1[i0_2, i1_2])\n", | |
" tir.writes(compute_1[i0_2, i1_2])\n", | |
" compute_1[i0_2, i1_2] = tir.max(rxplaceholder_1[i0_2, i1_2], tir.float32(0))\n", | |
" \n", | |
" @relax.function\n", | |
" def mlp(data: Tensor((n, m), \"float32\"), weight: Tensor((m, n), \"float32\")) -> Tensor(_, \"float32\"):\n", | |
" # block 0\n", | |
" gv = relax.call_tir(matmul, (data, weight), (n, n), dtype=\"float32\")\n", | |
" gv1 = relax.call_tir(relu, (gv,), (n, n), dtype=\"float32\")\n", | |
" return gv1\n", | |
" \n" | |
]
}
],
"source": [
"# symbolic dimensions\n",
"n, m = tir.Var(\"n\", \"int64\"), tir.Var(\"m\", \"int64\")\n",
"\n",
"# create data and weight variables\n",
"data = relax.Var(\"data\", [n, m], relax.DynTensorType(2, \"float32\"))\n",
"weight = relax.Var(\"weight\", [m, n], relax.DynTensorType(2, \"float32\"))\n",
"\n",
"# construct an MLP model\n",
"mod = build_mlp(data, weight)\n",
"R.parser.pretty_print(mod)\n",
"\n",
"# build and create a VM executor\n",
"target = tvm.target.Target(\"llvm\", host=\"llvm\")\n",
"ex = relax.vm.build(mod, target)\n",
"vm = relax.VirtualMachine(ex, tvm.cpu())"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "8102fabf",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[ 8.095244 7.7235837 7.2910867 8.443724 8.958878 9.012167\n", | |
" 6.2120357 10.1474695 8.469089 7.993631 7.2318916 7.1488924\n", | |
" 9.181463 9.888946 7.4828773 7.393664 ]\n", | |
" [ 6.051395 7.281189 7.8760085 8.753442 6.9901752 7.286955\n", | |
" 5.8768206 9.506951 7.4303255 7.2454524 6.817876 6.3926682\n", | |
" 7.413857 8.555092 7.938551 6.3793116]\n", | |
" [ 7.050069 8.146633 7.577522 9.216323 7.63874 8.516393\n", | |
" 6.0171905 9.377093 8.423015 7.169124 6.4793515 5.9576015\n", | |
" 8.156402 9.256693 7.9188857 7.347392 ]\n", | |
" [ 6.8375144 7.0806613 7.232546 6.8441987 6.581527 8.293472\n", | |
" 5.9637613 8.6074705 7.3677044 7.7873907 6.419702 6.1378307\n", | |
" 6.6358232 8.800402 7.43657 7.4201636]\n", | |
" [ 6.958605 8.005278 6.393196 8.5577755 7.973779 9.274346\n", | |
" 7.1687098 9.766083 7.716219 7.32963 7.0473433 6.7898445\n", | |
" 8.318464 8.446748 7.0219917 6.123752 ]\n", | |
" [ 7.7625294 7.0056577 7.2986197 7.4696107 6.741804 8.562313\n", | |
" 6.763641 9.269506 7.983401 6.6070294 6.7254267 6.6040297\n", | |
" 8.351605 9.130937 7.5254893 7.160763 ]\n", | |
" [ 7.156014 7.827703 8.12125 7.753513 8.445888 8.61097\n", | |
" 6.417519 8.823943 8.172567 7.9088287 6.9773703 6.460421\n", | |
" 7.808111 8.590609 8.104791 7.838943 ]\n", | |
" [ 6.3050046 7.45814 6.4376473 7.2996607 7.829011 8.281724\n", | |
" 5.924798 8.665511 7.338019 7.0315704 6.467948 5.5320587\n", | |
" 7.4525414 8.788908 6.5465045 6.375387 ]\n", | |
" [ 7.203224 6.936556 6.9815955 7.9408994 6.9658127 8.797333\n", | |
" 6.0595207 8.919445 7.4469514 7.177114 6.276354 7.216568\n", | |
" 8.093876 8.271761 7.5531983 6.2297573]\n", | |
" [ 7.9077554 8.393251 7.7091074 7.876217 9.050663 9.165333\n", | |
" 7.29719 10.020016 9.2527 8.996158 8.42391 6.766323\n", | |
" 8.467502 10.105639 7.681791 7.9979396]\n", | |
" [ 7.722364 7.502222 7.16021 7.320424 6.938996 8.806945\n", | |
" 6.0865664 8.607358 7.289708 7.509741 6.8619833 6.346017\n", | |
" 6.9190187 9.129631 7.799539 7.013547 ]\n", | |
" [ 6.0149918 6.927125 7.1344876 7.1830897 7.465643 7.3225746\n", | |
" 5.3835173 7.936848 6.925822 7.061758 6.644145 5.838758\n", | |
" 6.670918 7.5502677 7.6008825 6.3464193]\n", | |
" [ 7.827011 7.927222 8.246812 7.77723 7.6312366 9.066271\n", | |
" 7.001202 9.651561 8.513117 8.172654 7.569078 6.1455894\n", | |
" 8.610018 9.841372 7.924851 8.502279 ]\n", | |
" [ 7.1759543 7.4548283 8.092802 8.589085 6.612156 7.836399\n", | |
" 6.296878 9.386656 7.850534 7.32955 7.614632 7.229883\n", | |
" 7.568119 9.086787 7.3559017 7.266826 ]\n", | |
" [ 6.887543 7.3447785 7.154263 6.1262264 6.5849376 7.7508464\n", | |
" 5.4122353 7.795231 7.67428 6.178569 7.021037 5.7336097\n", | |
" 6.4718122 6.8764296 6.5011544 5.798081 ]\n", | |
" [ 6.4761295 7.764174 6.3661256 7.272473 7.787151 8.531428\n", | |
" 5.380861 8.1289835 7.5046635 7.178151 5.9490094 5.8559055\n", | |
" 7.414933 7.15778 6.905289 6.256718 ]]\n" | |
]
}
],
"source": [
"# run the MLP model on the Relax VM\n",
"data = tvm.nd.array(np.random.rand(16, 32).astype(np.float32))\n",
"weight = tvm.nd.array(np.random.rand(32, 16).astype(np.float32))\n",
"res = vm[\"mlp\"](data, weight)\n",
"print(res)"
]
},
{
"cell_type": "markdown",
"id": "5bad13ca",
"metadata": {},
"source": [
"## Relay to Relax translator"
]
},
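{
"cell_type": "markdown",
"id": "d1e2f3a4",
"metadata": {},
"source": [
"The output below comes from translating a Relay ResNet-50 workload to Relax. A hedged reconstruction of the translation call, assuming the `relay_translator.from_relay` helper from `tvm.relax.testing` and the standard `tvm.relay.testing` ResNet workload:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d1e2f3a5",
"metadata": {},
"outputs": [],
"source": [
"from tvm.relay import testing as relay_testing\n",
"from tvm.relax.testing import relay_translator\n",
"\n",
"# assumed workload: ResNet-50, batch size 1, matching the shapes printed below\n",
"relay_mod, relay_params = relay_testing.resnet.get_workload(\n",
"    num_layers=50, batch_size=1, dtype=\"float32\"\n",
")\n",
"\n",
"# translate the Relay main function into a Relax IRModule\n",
"relax_mod = relay_translator.from_relay(relay_mod[\"main\"])\n",
"R.parser.pretty_print(relax_mod)"
]
},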
{
"cell_type": "code",
"execution_count": 12,
"id": "58ecdee8",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"@tvm.script.ir_module\n", | |
"class Module:\n", | |
" @tir.prim_func\n", | |
" def conv2d_nchw11(rxplaceholder_2: tir.Buffer[(1, 256, 14, 14), \"float32\"], rxplaceholder_3: tir.Buffer[(256, 256, 3, 3), \"float32\"], conv2d_nchw_1: tir.Buffer[(1, 256, 14, 14), \"float32\"]) -> None:\n", | |
" # function attr dict\n", | |
" tir.func_attr({\"global_symbol\": \"conv2d_nchw11\", \"tir.noalias\": True})\n", | |
" # body\n", | |
" # with tir.block(\"root\")\n", | |
" pad_temp_1 = tir.alloc_buffer([1, 256, 16, 16], dtype=\"float32\")\n", | |
" for i0, i1, i2, i3 in tir.grid(1, 256, 16, 16):\n", | |
" with tir.block(\"pad_temp\"):\n", | |
" i0_2, i1_2, i2_2, i3_2 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n", | |
" tir.reads(rxplaceholder_2[i0_2, i1_2, i2_2 - 1, i3_2 - 1])\n", | |
" tir.writes(pad_temp_1[i0_2, i1_2, i2_2, i3_2])\n", | |
" pad_temp_1[i0_2, i1_2, i2_2, i3_2] = tir.if_then_else(1 <= i2_2 and i2_2 < 15 and 1 <= i3_2 and i3_2 < 15, rxplaceholder_2[i0_2, i1_2, i2_2 - 1, i3_2 - 1], tir.float32(0), dtype=\"float32\")\n", | |
" for i0_3, i1_3, i2_3, i3_3, i4, i5, i6 in tir.grid(1, 256, 14, 14, 256, 3, 3):\n", | |
" with tir.block(\"conv2d_nchw\"):\n", | |
" nn, ff, yy, xx, rc, ry, rx = tir.axis.remap(\"SSSSRRR\", [i0_3, i1_3, i2_3, i3_3, i4, i5, i6])\n", | |
" tir.reads(pad_temp_1[nn, rc, yy + ry, xx + rx], rxplaceholder_3[ff, rc, ry, rx])\n", | |
" tir.writes(conv2d_nchw_1[nn, ff, yy, xx])\n", | |
" with tir.init():\n", | |
" conv2d_nchw_1[nn, ff, yy, xx] = tir.float32(0)\n", | |
" conv2d_nchw_1[nn, ff, yy, xx] = conv2d_nchw_1[nn, ff, yy, xx] + pad_temp_1[nn, rc, yy + ry, xx + rx] * rxplaceholder_3[ff, rc, ry, rx]\n", | |
" \n", | |
" @tir.prim_func\n", | |
" def dense(rxplaceholder_2: tir.Buffer[(1, 2048), \"float32\"], rxplaceholder_3: tir.Buffer[(1000, 2048), \"float32\"], T_matmul_NT_1: tir.Buffer[(1, 1000), \"float32\"]) -> None:\n", | |
" # function attr dict\n", | |
" tir.func_attr({\"global_symbol\": \"dense\", \"tir.noalias\": True})\n", | |
" # body\n", | |
" # with tir.block(\"root\")\n", | |
" for i0, i1, i2 in tir.grid(1, 1000, 2048):\n", | |
" with tir.block(\"T_matmul_NT\"):\n", | |
" i, j, k = tir.axis.remap(\"SSR\", [i0, i1, i2])\n", | |
" tir.reads(rxplaceholder_2[i, k], rxplaceholder_3[j, k])\n", | |
" tir.writes(T_matmul_NT_1[i, j])\n", | |
" tir.block_attr({\"layout_free_placeholders\":[rxplaceholder_3]})\n", | |
" with tir.init():\n", | |
" T_matmul_NT_1[i, j] = tir.float32(0)\n", | |
" T_matmul_NT_1[i, j] = T_matmul_NT_1[i, j] + rxplaceholder_2[i, k] * rxplaceholder_3[j, k]\n", | |
" \n", | |
" @tir.prim_func\n", | |
" def global_avg_pool2d(rxplaceholder_1: tir.Buffer[(1, 2048, 7, 7), \"float32\"], tensor_2: tir.Buffer[(1, 2048, 1, 1), \"float32\"]) -> None:\n", | |
" # function attr dict\n", | |
" tir.func_attr({\"global_symbol\": \"global_avg_pool2d\", \"tir.noalias\": True})\n", | |
" # body\n", | |
" # with tir.block(\"root\")\n", | |
" tensor_3 = tir.alloc_buffer([1, 2048, 1, 1], dtype=\"float32\")\n", | |
" for i0, i1, i2, i3, i4, i5 in tir.grid(1, 2048, 1, 1, 7, 7):\n", | |
" with tir.block(\"tensor\"):\n", | |
" ax0, ax1, ax2, ax3, rv0, rv1 = tir.axis.remap(\"SSSSRR\", [i0, i1, i2, i3, i4, i5])\n", | |
" tir.reads(rxplaceholder_1[ax0, ax1, ax2 * 7 + rv0, ax3 * 7 + rv1])\n", | |
" tir.writes(tensor_3[ax0, ax1, ax2, ax3])\n", | |
" with tir.init():\n", | |
" tensor_3[ax0, ax1, ax2, ax3] = tir.float32(0)\n", | |
" tensor_3[ax0, ax1, ax2, ax3] = tensor_3[ax0, ax1, ax2, ax3] + rxplaceholder_1[ax0, ax1, ax2 * 7 + rv0, ax3 * 7 + rv1]\n", | |
" for i0, i1, i2, i3 in tir.grid(1, 2048, 1, 1):\n", | |
" with tir.block(\"tensor_1\"):\n", | |
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n", | |
" tir.reads(tensor_3[ax0, ax1, ax2, ax3])\n", | |
" tir.writes(tensor_2[ax0, ax1, ax2, ax3])\n", | |
" tensor_2[ax0, ax1, ax2, ax3] = tensor_3[ax0, ax1, ax2, ax3] * tir.float32(0.020408163265306121)\n", | |
" \n", | |
" @tir.prim_func\n", | |
" def batch_norm1(rxplaceholder_5: tir.Buffer[(1, 64, 112, 112), \"float32\"], rxplaceholder_6: tir.Buffer[(64,), \"float32\"], rxplaceholder_7: tir.Buffer[(64,), \"float32\"], rxplaceholder_8: tir.Buffer[(64,), \"float32\"], rxplaceholder_9: tir.Buffer[(64,), \"float32\"], T_add_2: tir.Buffer[(1, 64, 112, 112), \"float32\"], T_multiply_3: tir.Buffer[(64,), \"float32\"], T_multiply_4: tir.Buffer[(64,), \"float32\"]) -> None:\n", | |
" # function attr dict\n", | |
" tir.func_attr({\"global_symbol\": \"batch_norm1\", \"tir.noalias\": True})\n", | |
" # body\n", | |
" # with tir.block(\"root\")\n", | |
" T_reshape_4 = tir.alloc_buffer([1, 64, 1, 1], dtype=\"float32\")\n", | |
" T_subtract_1 = tir.alloc_buffer([1, 64, 112, 112], dtype=\"float32\")\n", | |
" T_reshape_5 = tir.alloc_buffer([1, 64, 1, 1], dtype=\"float32\")\n", | |
" T_add_3 = tir.alloc_buffer([1, 64, 1, 1], dtype=\"float32\")\n", | |
" compute_1 = tir.alloc_buffer([1, 64, 1, 1], dtype=\"float32\")\n", | |
" T_divide_1 = tir.alloc_buffer([1, 64, 112, 112], dtype=\"float32\")\n", | |
" T_reshape_6 = tir.alloc_buffer([1, 64, 1, 1], dtype=\"float32\")\n", | |
" T_multiply_5 = tir.alloc_buffer([1, 64, 112, 112], dtype=\"float32\")\n", | |
" T_reshape_7 = tir.alloc_buffer([1, 64, 1, 1], dtype=\"float32\")\n", | |
" for i0, i1, i2, i3 in tir.grid(1, 64, 1, 1):\n", | |
" with tir.block(\"T_reshape\"):\n", | |
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n", | |
" tir.reads(rxplaceholder_8[(ax1 + ax2 + ax3) % 64])\n", | |
" tir.writes(T_reshape_4[ax0, ax1, ax2, ax3])\n", | |
" T_reshape_4[ax0, ax1, ax2, ax3] = rxplaceholder_8[(ax1 + ax2 + ax3) % 64]\n", | |
" for i0, i1, i2, i3 in tir.grid(1, 64, 112, 112):\n", | |
" with tir.block(\"T_subtract\"):\n", | |
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n", | |
" tir.reads(rxplaceholder_5[ax0, ax1, ax2, ax3], T_reshape_4[ax0, ax1, 0, 0])\n", | |
" tir.writes(T_subtract_1[ax0, ax1, ax2, ax3])\n", | |
" T_subtract_1[ax0, ax1, ax2, ax3] = rxplaceholder_5[ax0, ax1, ax2, ax3] - T_reshape_4[ax0, ax1, 0, 0]\n", | |
" for i0, i1, i2, i3 in tir.grid(1, 64, 1, 1):\n", | |
" with tir.block(\"T_reshape_1\"):\n", | |
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n", | |
" tir.reads(rxplaceholder_9[(ax1 + ax2 + ax3) % 64])\n", | |
" tir.writes(T_reshape_5[ax0, ax1, ax2, ax3])\n", | |
" T_reshape_5[ax0, ax1, ax2, ax3] = rxplaceholder_9[(ax1 + ax2 + ax3) % 64]\n", | |
" for i0, i1, i2, i3 in tir.grid(1, 64, 1, 1):\n", | |
" with tir.block(\"T_add\"):\n", | |
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n", | |
" tir.reads(T_reshape_5[ax0, ax1, ax2, ax3])\n", | |
" tir.writes(T_add_3[ax0, ax1, ax2, ax3])\n", | |
" T_add_3[ax0, ax1, ax2, ax3] = T_reshape_5[ax0, ax1, ax2, ax3] + tir.float32(1.9999999494757503e-05)\n", | |
" for i0, i1, i2, i3 in tir.grid(1, 64, 1, 1):\n", | |
" with tir.block(\"compute\"):\n", | |
" i0_2, i1_2, i2_2, i3_2 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n", | |
" tir.reads(T_add_3[i0_2, i1_2, i2_2, i3_2])\n", | |
" tir.writes(compute_1[i0_2, i1_2, i2_2, i3_2])\n", | |
" compute_1[i0_2, i1_2, i2_2, i3_2] = tir.sqrt(T_add_3[i0_2, i1_2, i2_2, i3_2], dtype=\"float32\")\n", | |
" for i0_3, i1_3, i2_3, i3_3 in tir.grid(1, 64, 112, 112):\n", | |
" with tir.block(\"T_divide\"):\n", | |
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0_3, i1_3, i2_3, i3_3])\n", | |
" tir.reads(T_subtract_1[ax0, ax1, ax2, ax3], compute_1[ax0, ax1, 0, 0])\n", | |
" tir.writes(T_divide_1[ax0, ax1, ax2, ax3])\n", | |
" T_divide_1[ax0, ax1, ax2, ax3] = T_subtract_1[ax0, ax1, ax2, ax3] / compute_1[ax0, ax1, 0, 0]\n", | |
" for i0_4, i1_4, i2_4, i3_4 in tir.grid(1, 64, 1, 1):\n", | |
" with tir.block(\"T_reshape_2\"):\n", | |
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0_4, i1_4, i2_4, i3_4])\n", | |
" tir.reads(rxplaceholder_6[(ax1 + ax2 + ax3) % 64])\n", | |
" tir.writes(T_reshape_6[ax0, ax1, ax2, ax3])\n", | |
" T_reshape_6[ax0, ax1, ax2, ax3] = rxplaceholder_6[(ax1 + ax2 + ax3) % 64]\n", | |
" for i0_5, i1_5, i2_5, i3_5 in tir.grid(1, 64, 112, 112):\n", | |
" with tir.block(\"T_multiply\"):\n", | |
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0_5, i1_5, i2_5, i3_5])\n", | |
" tir.reads(T_divide_1[ax0, ax1, ax2, ax3], T_reshape_6[ax0, ax1, 0, 0])\n", | |
" tir.writes(T_multiply_5[ax0, ax1, ax2, ax3])\n", | |
" T_multiply_5[ax0, ax1, ax2, ax3] = T_divide_1[ax0, ax1, ax2, ax3] * T_reshape_6[ax0, ax1, 0, 0]\n", | |
" for i0_6, i1_6, i2_6, i3_6 in tir.grid(1, 64, 1, 1):\n", | |
" with tir.block(\"T_reshape_3\"):\n", | |
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0_6, i1_6, i2_6, i3_6])\n", | |
" tir.reads(rxplaceholder_7[(ax1 + ax2 + ax3) % 64])\n", | |
" tir.writes(T_reshape_7[ax0, ax1, ax2, ax3])\n", | |
" T_reshape_7[ax0, ax1, ax2, ax3] = rxplaceholder_7[(ax1 + ax2 + ax3) % 64]\n", | |
" for i0_7, i1_7, i2_7, i3_7 in tir.grid(1, 64, 112, 112):\n", | |
" with tir.block(\"T_add_1\"):\n", | |
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0_7, i1_7, i2_7, i3_7])\n", | |
" tir.reads(T_multiply_5[ax0, ax1, ax2, ax3], T_reshape_7[ax0, ax1, 0, 0])\n", | |
" tir.writes(T_add_2[ax0, ax1, ax2, ax3])\n", | |
" T_add_2[ax0, ax1, ax2, ax3] = T_multiply_5[ax0, ax1, ax2, ax3] + T_reshape_7[ax0, ax1, 0, 0]\n", | |
" for i0_8 in tir.serial(64):\n", | |
" with tir.block(\"T_multiply_1\"):\n", | |
" ax0 = tir.axis.spatial(64, i0_8)\n", | |
" tir.reads(rxplaceholder_8[ax0])\n", | |
" tir.writes(T_multiply_3[ax0])\n", | |
" T_multiply_3[ax0] = rxplaceholder_8[ax0]\n", | |
" for i0_9 in tir.serial(64):\n", | |
" with tir.block(\"T_multiply_2\"):\n", | |
" ax0 = tir.axis.spatial(64, i0_9)\n", | |
" tir.reads(rxplaceholder_9[ax0])\n", | |
" tir.writes(T_multiply_4[ax0])\n", | |
" T_multiply_4[ax0] = rxplaceholder_9[ax0]\n", | |
" \n", | |
" @tir.prim_func\n", | |
" def add3(rxplaceholder_2: tir.Buffer[(1, 2048, 7, 7), \"float32\"], rxplaceholder_3: tir.Buffer[(1, 2048, 7, 7), \"float32\"], T_add_1: tir.Buffer[(1, 2048, 7, 7), \"float32\"]) -> None:\n", | |
" # function attr dict\n", | |
" tir.func_attr({\"global_symbol\": \"add3\", \"tir.noalias\": True})\n", | |
" # body\n", | |
" # with tir.block(\"root\")\n", | |
" for i0, i1, i2, i3 in tir.grid(1, 2048, 7, 7):\n", | |
" with tir.block(\"T_add\"):\n", | |
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n", | |
" tir.reads(rxplaceholder_2[ax0, ax1, ax2, ax3], rxplaceholder_3[ax0, ax1, ax2, ax3])\n", | |
" tir.writes(T_add_1[ax0, ax1, ax2, ax3])\n", | |
" T_add_1[ax0, ax1, ax2, ax3] = rxplaceholder_2[ax0, ax1, ax2, ax3] + rxplaceholder_3[ax0, ax1, ax2, ax3]\n", | |
" \n", | |
" @tir.prim_func\n", | |
" def conv2d_nchw4(rxplaceholder_2: tir.Buffer[(1, 256, 56, 56), \"float32\"], rxplaceholder_3: tir.Buffer[(64, 256, 1, 1), \"float32\"], conv2d_nchw_1: tir.Buffer[(1, 64, 56, 56), \"float32\"]) -> None:\n", | |
" # function attr dict\n", | |
" tir.func_attr({\"global_symbol\": \"conv2d_nchw4\", \"tir.noalias\": True})\n", | |
" # body\n", | |
" # with tir.block(\"root\")\n", | |
" pad_temp_1 = tir.alloc_buffer([1, 256, 56, 56], dtype=\"float32\")\n", | |
" for i0, i1, i2, i3 in tir.grid(1, 256, 56, 56):\n", | |
" with tir.block(\"pad_temp\"):\n", | |
" i0_2, i1_2, i2_2, i3_2 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n", | |
" tir.reads(rxplaceholder_2[i0_2, i1_2, i2_2, i3_2])\n", | |
" tir.writes(pad_temp_1[i0_2, i1_2, i2_2, i3_2])\n", | |
" pad_temp_1[i0_2, i1_2, i2_2, i3_2] = rxplaceholder_2[i0_2, i1_2, i2_2, i3_2]\n", | |
" for i0_3, i1_3, i2_3, i3_3, i4, i5, i6 in tir.grid(1, 64, 56, 56, 256, 1, 1):\n", | |
" with tir.block(\"conv2d_nchw\"):\n", | |
" nn, ff, yy, xx, rc, ry, rx = tir.axis.remap(\"SSSSRRR\", [i0_3, i1_3, i2_3, i3_3, i4, i5, i6])\n", | |
" tir.reads(pad_temp_1[nn, rc, yy + ry, xx + rx], rxplaceholder_3[ff, rc, ry, rx])\n", | |
" tir.writes(conv2d_nchw_1[nn, ff, yy, xx])\n", | |
" with tir.init():\n", | |
" conv2d_nchw_1[nn, ff, yy, xx] = tir.float32(0)\n", | |
" conv2d_nchw_1[nn, ff, yy, xx] = conv2d_nchw_1[nn, ff, yy, xx] + pad_temp_1[nn, rc, yy + ry, xx + rx] * rxplaceholder_3[ff, rc, ry, rx]\n", | |
" \n", | |
" @tir.prim_func\n", | |
" def add2(rxplaceholder_2: tir.Buffer[(1, 1024, 14, 14), \"float32\"], rxplaceholder_3: tir.Buffer[(1, 1024, 14, 14), \"float32\"], T_add_1: tir.Buffer[(1, 1024, 14, 14), \"float32\"]) -> None:\n", | |
" # function attr dict\n", | |
" tir.func_attr({\"global_symbol\": \"add2\", \"tir.noalias\": True})\n", | |
" # body\n", | |
" # with tir.block(\"root\")\n", | |
" for i0, i1, i2, i3 in tir.grid(1, 1024, 14, 14):\n", | |
" with tir.block(\"T_add\"):\n", | |
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n", | |
" tir.reads(rxplaceholder_2[ax0, ax1, ax2, ax3], rxplaceholder_3[ax0, ax1, ax2, ax3])\n", | |
" tir.writes(T_add_1[ax0, ax1, ax2, ax3])\n", | |
" T_add_1[ax0, ax1, ax2, ax3] = rxplaceholder_2[ax0, ax1, ax2, ax3] + rxplaceholder_3[ax0, ax1, ax2, ax3]\n", | |
" \n", | |
" @tir.prim_func\n", | |
" def conv2d_nchw(rxplaceholder_2: tir.Buffer[(1, 3, 224, 224), \"float32\"], rxplaceholder_3: tir.Buffer[(64, 3, 7, 7), \"float32\"], conv2d_nchw_1: tir.Buffer[(1, 64, 112, 112), \"float32\"]) -> None:\n", | |
" # function attr dict\n", | |
" tir.func_attr({\"global_symbol\": \"conv2d_nchw\", \"tir.noalias\": True})\n", | |
" # body\n", | |
" # with tir.block(\"root\")\n", | |
" pad_temp_1 = tir.alloc_buffer([1, 3, 230, 230], dtype=\"float32\")\n", | |
" for i0, i1, i2, i3 in tir.grid(1, 3, 230, 230):\n", | |
" with tir.block(\"pad_temp\"):\n", | |
" i0_2, i1_2, i2_2, i3_2 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n", | |
" tir.reads(rxplaceholder_2[i0_2, i1_2, i2_2 - 3, i3_2 - 3])\n", | |
" tir.writes(pad_temp_1[i0_2, i1_2, i2_2, i3_2])\n", | |
" pad_temp_1[i0_2, i1_2, i2_2, i3_2] = tir.if_then_else(3 <= i2_2 and i2_2 < 227 and 3 <= i3_2 and i3_2 < 227, rxplaceholder_2[i0_2, i1_2, i2_2 - 3, i3_2 - 3], tir.float32(0), dtype=\"float32\")\n", | |
" for i0_3, i1_3, i2_3, i3_3, i4, i5, i6 in tir.grid(1, 64, 112, 112, 3, 7, 7):\n", | |
" with tir.block(\"conv2d_nchw\"):\n", | |
" nn, ff, yy, xx, rc, ry, rx = tir.axis.remap(\"SSSSRRR\", [i0_3, i1_3, i2_3, i3_3, i4, i5, i6])\n", | |
" tir.reads(pad_temp_1[nn, rc, yy * 2 + ry, xx * 2 + rx], rxplaceholder_3[ff, rc, ry, rx])\n", | |
" tir.writes(conv2d_nchw_1[nn, ff, yy, xx])\n", | |
" with tir.init():\n", | |
" conv2d_nchw_1[nn, ff, yy, xx] = tir.float32(0)\n", | |
" conv2d_nchw_1[nn, ff, yy, xx] = conv2d_nchw_1[nn, ff, yy, xx] + pad_temp_1[nn, rc, yy * 2 + ry, xx * 2 + rx] * rxplaceholder_3[ff, rc, ry, rx]\n", | |
" \n", | |
" @tir.prim_func\n", | |
" def conv2d_nchw13(rxplaceholder_2: tir.Buffer[(1, 512, 28, 28), \"float32\"], rxplaceholder_3: tir.Buffer[(1024, 512, 1, 1), \"float32\"], conv2d_nchw_1: tir.Buffer[(1, 1024, 14, 14), \"float32\"]) -> None:\n", | |
" # function attr dict\n", | |
" tir.func_attr({\"global_symbol\": \"conv2d_nchw13\", \"tir.noalias\": True})\n", | |
" # body\n", | |
" # with tir.block(\"root\")\n", | |
" pad_temp_1 = tir.alloc_buffer([1, 512, 28, 28], dtype=\"float32\")\n", | |
" for i0, i1, i2, i3 in tir.grid(1, 512, 28, 28):\n", | |
" with tir.block(\"pad_temp\"):\n", | |
" i0_2, i1_2, i2_2, i3_2 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n", | |
" tir.reads(rxplaceholder_2[i0_2, i1_2, i2_2, i3_2])\n", | |
" tir.writes(pad_temp_1[i0_2, i1_2, i2_2, i3_2])\n", | |
" pad_temp_1[i0_2, i1_2, i2_2, i3_2] = rxplaceholder_2[i0_2, i1_2, i2_2, i3_2]\n", | |
" for i0_3, i1_3, i2_3, i3_3, i4, i5, i6 in tir.grid(1, 1024, 14, 14, 512, 1, 1):\n", | |
" with tir.block(\"conv2d_nchw\"):\n", | |
" nn, ff, yy, xx, rc, ry, rx = tir.axis.remap(\"SSSSRRR\", [i0_3, i1_3, i2_3, i3_3, i4, i5, i6])\n", | |
" tir.reads(pad_temp_1[nn, rc, yy * 2 + ry, xx * 2 + rx], rxplaceholder_3[ff, rc, ry, rx])\n", | |
" tir.writes(conv2d_nchw_1[nn, ff, yy, xx])\n", | |
" with tir.init():\n", | |
" conv2d_nchw_1[nn, ff, yy, xx] = tir.float32(0)\n", | |
" conv2d_nchw_1[nn, ff, yy, xx] = conv2d_nchw_1[nn, ff, yy, xx] + pad_temp_1[nn, rc, yy * 2 + ry, xx * 2 + rx] * rxplaceholder_3[ff, rc, ry, rx]\n", | |
" \n", | |
" @tir.prim_func\n", | |
" def batch_norm7(rxplaceholder_5: tir.Buffer[(1, 1024, 14, 14), \"float32\"], rxplaceholder_6: tir.Buffer[(1024,), \"float32\"], rxplaceholder_7: tir.Buffer[(1024,), \"float32\"], rxplaceholder_8: tir.Buffer[(1024,), \"float32\"], rxplaceholder_9: tir.Buffer[(1024,), \"float32\"], T_add_2: tir.Buffer[(1, 1024, 14, 14), \"float32\"], T_multiply_3: tir.Buffer[(1024,), \"float32\"], T_multiply_4: tir.Buffer[(1024,), \"float32\"]) -> None:\n", | |
" # function attr dict\n", | |
" tir.func_attr({\"global_symbol\": \"batch_norm7\", \"tir.noalias\": True})\n", | |
" # body\n", | |
" # with tir.block(\"root\")\n", | |
" T_reshape_4 = tir.alloc_buffer([1, 1024, 1, 1], dtype=\"float32\")\n", | |
" T_subtract_1 = tir.alloc_buffer([1, 1024, 14, 14], dtype=\"float32\")\n", | |
" T_reshape_5 = tir.alloc_buffer([1, 1024, 1, 1], dtype=\"float32\")\n", | |
" T_add_3 = tir.alloc_buffer([1, 1024, 1, 1], dtype=\"float32\")\n", | |
" compute_1 = tir.alloc_buffer([1, 1024, 1, 1], dtype=\"float32\")\n", | |
" T_divide_1 = tir.alloc_buffer([1, 1024, 14, 14], dtype=\"float32\")\n", | |
" T_reshape_6 = tir.alloc_buffer([1, 1024, 1, 1], dtype=\"float32\")\n", | |
" T_multiply_5 = tir.alloc_buffer([1, 1024, 14, 14], dtype=\"float32\")\n", | |
" T_reshape_7 = tir.alloc_buffer([1, 1024, 1, 1], dtype=\"float32\")\n", | |
" for i0, i1, i2, i3 in tir.grid(1, 1024, 1, 1):\n", | |
" with tir.block(\"T_reshape\"):\n", | |
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n", | |
" tir.reads(rxplaceholder_8[(ax1 + ax2 + ax3) % 1024])\n", | |
" tir.writes(T_reshape_4[ax0, ax1, ax2, ax3])\n", | |
" T_reshape_4[ax0, ax1, ax2, ax3] = rxplaceholder_8[(ax1 + ax2 + ax3) % 1024]\n", | |
" for i0, i1, i2, i3 in tir.grid(1, 1024, 14, 14):\n", | |
" with tir.block(\"T_subtract\"):\n", | |
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n", | |
" tir.reads(rxplaceholder_5[ax0, ax1, ax2, ax3], T_reshape_4[ax0, ax1, 0, 0])\n", | |
" tir.writes(T_subtract_1[ax0, ax1, ax2, ax3])\n", | |
" T_subtract_1[ax0, ax1, ax2, ax3] = rxplaceholder_5[ax0, ax1, ax2, ax3] - T_reshape_4[ax0, ax1, 0, 0]\n", | |
" for i0, i1, i2, i3 in tir.grid(1, 1024, 1, 1):\n", | |
" with tir.block(\"T_reshape_1\"):\n", | |
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n", | |
" tir.reads(rxplaceholder_9[(ax1 + ax2 + ax3) % 1024])\n", | |
" tir.writes(T_reshape_5[ax0, ax1, ax2, ax3])\n", | |
" T_reshape_5[ax0, ax1, ax2, ax3] = rxplaceholder_9[(ax1 + ax2 + ax3) % 1024]\n", | |
" for i0, i1, i2, i3 in tir.grid(1, 1024, 1, 1):\n", | |
" with tir.block(\"T_add\"):\n", | |
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n", | |
" tir.reads(T_reshape_5[ax0, ax1, ax2, ax3])\n", | |
" tir.writes(T_add_3[ax0, ax1, ax2, ax3])\n", | |
" T_add_3[ax0, ax1, ax2, ax3] = T_reshape_5[ax0, ax1, ax2, ax3] + tir.float32(1.9999999494757503e-05)\n", | |
" for i0, i1, i2, i3 in tir.grid(1, 1024, 1, 1):\n", | |
" with tir.block(\"compute\"):\n", | |
" i0_2, i1_2, i2_2, i3_2 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n", | |
" tir.reads(T_add_3[i0_2, i1_2, i2_2, i3_2])\n", | |
" tir.writes(compute_1[i0_2, i1_2, i2_2, i3_2])\n", | |
" compute_1[i0_2, i1_2, i2_2, i3_2] = tir.sqrt(T_add_3[i0_2, i1_2, i2_2, i3_2], dtype=\"float32\")\n", | |
" for i0_3, i1_3, i2_3, i3_3 in tir.grid(1, 1024, 14, 14):\n", | |
" with tir.block(\"T_divide\"):\n", | |
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0_3, i1_3, i2_3, i3_3])\n", | |
" tir.reads(T_subtract_1[ax0, ax1, ax2, ax3], compute_1[ax0, ax1, 0, 0])\n", | |
" tir.writes(T_divide_1[ax0, ax1, ax2, ax3])\n", | |
" T_divide_1[ax0, ax1, ax2, ax3] = T_subtract_1[ax0, ax1, ax2, ax3] / compute_1[ax0, ax1, 0, 0]\n", | |
" for i0_4, i1_4, i2_4, i3_4 in tir.grid(1, 1024, 1, 1):\n", | |
" with tir.block(\"T_reshape_2\"):\n", | |
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0_4, i1_4, i2_4, i3_4])\n", | |
" tir.reads(rxplaceholder_6[(ax1 + ax2 + ax3) % 1024])\n", | |
" tir.writes(T_reshape_6[ax0, ax1, ax2, ax3])\n", | |
" T_reshape_6[ax0, ax1, ax2, ax3] = rxplaceholder_6[(ax1 + ax2 + ax3) % 1024]\n", | |
" for i0_5, i1_5, i2_5, i3_5 in tir.grid(1, 1024, 14, 14):\n", | |
" with tir.block(\"T_multiply\"):\n", | |
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0_5, i1_5, i2_5, i3_5])\n", | |
" tir.reads(T_divide_1[ax0, ax1, ax2, ax3], T_reshape_6[ax0, ax1, 0, 0])\n", | |
" tir.writes(T_multiply_5[ax0, ax1, ax2, ax3])\n", | |
" T_multiply_5[ax0, ax1, ax2, ax3] = T_divide_1[ax0, ax1, ax2, ax3] * T_reshape_6[ax0, ax1, 0, 0]\n", | |
" for i0_6, i1_6, i2_6, i3_6 in tir.grid(1, 1024, 1, 1):\n", | |
" with tir.block(\"T_reshape_3\"):\n", | |
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0_6, i1_6, i2_6, i3_6])\n", | |
" tir.reads(rxplaceholder_7[(ax1 + ax2 + ax3) % 1024])\n", | |
" tir.writes(T_reshape_7[ax0, ax1, ax2, ax3])\n", | |
" T_reshape_7[ax0, ax1, ax2, ax3] = rxplaceholder_7[(ax1 + ax2 + ax3) % 1024]\n", | |
" for i0_7, i1_7, i2_7, i3_7 in tir.grid(1, 1024, 14, 14):\n", | |
" with tir.block(\"T_add_1\"):\n", | |
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0_7, i1_7, i2_7, i3_7])\n", | |
" tir.reads(T_multiply_5[ax0, ax1, ax2, ax3], T_reshape_7[ax0, ax1, 0, 0])\n", | |
" tir.writes(T_add_2[ax0, ax1, ax2, ax3])\n", | |
" T_add_2[ax0, ax1, ax2, ax3] = T_multiply_5[ax0, ax1, ax2, ax3] + T_reshape_7[ax0, ax1, 0, 0]\n", | |
" for i0_8 in tir.serial(1024):\n", | |
" with tir.block(\"T_multiply_1\"):\n", | |
" ax0 = tir.axis.spatial(1024, i0_8)\n", | |
" tir.reads(rxplaceholder_8[ax0])\n", | |
" tir.writes(T_multiply_3[ax0])\n", | |
" T_multiply_3[ax0] = rxplaceholder_8[ax0]\n", | |
" for i0_9 in tir.serial(1024):\n", | |
" with tir.block(\"T_multiply_2\"):\n", | |
" ax0 = tir.axis.spatial(1024, i0_9)\n", | |
" tir.reads(rxplaceholder_9[ax0])\n", | |
" tir.writes(T_multiply_4[ax0])\n", | |
" T_multiply_4[ax0] = rxplaceholder_9[ax0]\n", | |
" \n", | |
" @tir.prim_func\n", | |
" def batch_norm8(rxplaceholder_5: tir.Buffer[(1, 512, 7, 7), \"float32\"], rxplaceholder_6: tir.Buffer[(512,), \"float32\"], rxplaceholder_7: tir.Buffer[(512,), \"float32\"], rxplaceholder_8: tir.Buffer[(512,), \"float32\"], rxplaceholder_9: tir.Buffer[(512,), \"float32\"], T_add_2: tir.Buffer[(1, 512, 7, 7), \"float32\"], T_multiply_3: tir.Buffer[(512,), \"float32\"], T_multiply_4: tir.Buffer[(512,), \"float32\"]) -> None:\n", | |
" # function attr dict\n", | |
" tir.func_attr({\"global_symbol\": \"batch_norm8\", \"tir.noalias\": True})\n", | |
" # body\n", | |
" # with tir.block(\"root\")\n", | |
" T_reshape_4 = tir.alloc_buffer([1, 512, 1, 1], dtype=\"float32\")\n", | |
" T_subtract_1 = tir.alloc_buffer([1, 512, 7, 7], dtype=\"float32\")\n", | |
" T_reshape_5 = tir.alloc_buffer([1, 512, 1, 1], dtype=\"float32\")\n", | |
" T_add_3 = tir.alloc_buffer([1, 512, 1, 1], dtype=\"float32\")\n", | |
" compute_1 = tir.alloc_buffer([1, 512, 1, 1], dtype=\"float32\")\n", | |
" T_divide_1 = tir.alloc_buffer([1, 512, 7, 7], dtype=\"float32\")\n", | |
" T_reshape_6 = tir.alloc_buffer([1, 512, 1, 1], dtype=\"float32\")\n", | |
" T_multiply_5 = tir.alloc_buffer([1, 512, 7, 7], dtype=\"float32\")\n", | |
" T_reshape_7 = tir.alloc_buffer([1, 512, 1, 1], dtype=\"float32\")\n", | |
" for i0, i1, i2, i3 in tir.grid(1, 512, 1, 1):\n", | |
" with tir.block(\"T_reshape\"):\n", | |
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n", | |
" tir.reads(rxplaceholder_8[(ax1 + ax2 + ax3) % 512])\n", | |
" tir.writes(T_reshape_4[ax0, ax1, ax2, ax3])\n", | |
" T_reshape_4[ax0, ax1, ax2, ax3] = rxplaceholder_8[(ax1 + ax2 + ax3) % 512]\n", | |
" for i0, i1, i2, i3 in tir.grid(1, 512, 7, 7):\n", | |
" with tir.block(\"T_subtract\"):\n", | |
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n", | |
" tir.reads(rxplaceholder_5[ax0, ax1, ax2, ax3], T_reshape_4[ax0, ax1, 0, 0])\n", | |
" tir.writes(T_subtract_1[ax0, ax1, ax2, ax3])\n", | |
" T_subtract_1[ax0, ax1, ax2, ax3] = rxplaceholder_5[ax0, ax1, ax2, ax3] - T_reshape_4[ax0, ax1, 0, 0]\n", | |
" for i0, i1, i2, i3 in tir.grid(1, 512, 1, 1):\n", | |
" with tir.block(\"T_reshape_1\"):\n", | |
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n", | |
" tir.reads(rxplaceholder_9[(ax1 + ax2 + ax3) % 512])\n", | |
" tir.writes(T_reshape_5[ax0, ax1, ax2, ax3])\n", | |
" T_reshape_5[ax0, ax1, ax2, ax3] = rxplaceholder_9[(ax1 + ax2 + ax3) % 512]\n", | |
" for i0, i1, i2, i3 in tir.grid(1, 512, 1, 1):\n", | |
" with tir.block(\"T_add\"):\n", | |
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n", | |
" tir.reads(T_reshape_5[ax0, ax1, ax2, ax3])\n", | |
" tir.writes(T_add_3[ax0, ax1, ax2, ax3])\n", | |
" T_add_3[ax0, ax1, ax2, ax3] = T_reshape_5[ax0, ax1, ax2, ax3] + tir.float32(1.9999999494757503e-05)\n", | |
" for i0, i1, i2, i3 in tir.grid(1, 512, 1, 1):\n", | |
" with tir.block(\"compute\"):\n", | |
" i0_2, i1_2, i2_2, i3_2 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n", | |
" tir.reads(T_add_3[i0_2, i1_2, i2_2, i3_2])\n", | |
" tir.writes(compute_1[i0_2, i1_2, i2_2, i3_2])\n", | |
" compute_1[i0_2, i1_2, i2_2, i3_2] = tir.sqrt(T_add_3[i0_2, i1_2, i2_2, i3_2], dtype=\"float32\")\n", | |
" for i0_3, i1_3, i2_3, i3_3 in tir.grid(1, 512, 7, 7):\n", | |
" with tir.block(\"T_divide\"):\n", | |
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0_3, i1_3, i2_3, i3_3])\n", | |
" tir.reads(T_subtract_1[ax0, ax1, ax2, ax3], compute_1[ax0, ax1, 0, 0])\n", | |
" tir.writes(T_divide_1[ax0, ax1, ax2, ax3])\n", | |
" T_divide_1[ax0, ax1, ax2, ax3] = T_subtract_1[ax0, ax1, ax2, ax3] / compute_1[ax0, ax1, 0, 0]\n", | |
" for i0_4, i1_4, i2_4, i3_4 in tir.grid(1, 512, 1, 1):\n", | |
" with tir.block(\"T_reshape_2\"):\n", | |
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0_4, i1_4, i2_4, i3_4])\n", | |
" tir.reads(rxplaceholder_6[(ax1 + ax2 + ax3) % 512])\n", | |
" tir.writes(T_reshape_6[ax0, ax1, ax2, ax3])\n", | |
" T_reshape_6[ax0, ax1, ax2, ax3] = rxplaceholder_6[(ax1 + ax2 + ax3) % 512]\n", | |
" for i0_5, i1_5, i2_5, i3_5 in tir.grid(1, 512, 7, 7):\n", | |
" with tir.block(\"T_multiply\"):\n", | |
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0_5, i1_5, i2_5, i3_5])\n", | |
" tir.reads(T_divide_1[ax0, ax1, ax2, ax3], T_reshape_6[ax0, ax1, 0, 0])\n", | |
" tir.writes(T_multiply_5[ax0, ax1, ax2, ax3])\n", | |
" T_multiply_5[ax0, ax1, ax2, ax3] = T_divide_1[ax0, ax1, ax2, ax3] * T_reshape_6[ax0, ax1, 0, 0]\n", | |
" for i0_6, i1_6, i2_6, i3_6 in tir.grid(1, 512, 1, 1):\n", | |
" with tir.block(\"T_reshape_3\"):\n", | |
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0_6, i1_6, i2_6, i3_6])\n", | |
" tir.reads(rxplaceholder_7[(ax1 + ax2 + ax3) % 512])\n", | |
" tir.writes(T_reshape_7[ax0, ax1, ax2, ax3])\n", | |
" T_reshape_7[ax0, ax1, ax2, ax3] = rxplaceholder_7[(ax1 + ax2 + ax3) % 512]\n", | |
" for i0_7, i1_7, i2_7, i3_7 in tir.grid(1, 512, 7, 7):\n", | |
" with tir.block(\"T_add_1\"):\n", | |
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0_7, i1_7, i2_7, i3_7])\n", | |
" tir.reads(T_multiply_5[ax0, ax1, ax2, ax3], T_reshape_7[ax0, ax1, 0, 0])\n", | |
" tir.writes(T_add_2[ax0, ax1, ax2, ax3])\n", | |
" T_add_2[ax0, ax1, ax2, ax3] = T_multiply_5[ax0, ax1, ax2, ax3] + T_reshape_7[ax0, ax1, 0, 0]\n", | |
" for i0_8 in tir.serial(512):\n", | |
" with tir.block(\"T_multiply_1\"):\n", | |
" ax0 = tir.axis.spatial(512, i0_8)\n", | |
" tir.reads(rxplaceholder_8[ax0])\n", | |
" tir.writes(T_multiply_3[ax0])\n", | |
" T_multiply_3[ax0] = rxplaceholder_8[ax0]\n", | |
" for i0_9 in tir.serial(512):\n", | |
" with tir.block(\"T_multiply_2\"):\n", | |
" ax0 = tir.axis.spatial(512, i0_9)\n", | |
" tir.reads(rxplaceholder_9[ax0])\n", | |
" tir.writes(T_multiply_4[ax0])\n", | |
" T_multiply_4[ax0] = rxplaceholder_9[ax0]\n", | |
" \n", | |
" @tir.prim_func\n", | |
" def relu3(rxplaceholder_1: tir.Buffer[(1, 128, 28, 28), \"float32\"], T_relu_1: tir.Buffer[(1, 128, 28, 28), \"float32\"]) -> None:\n", | |
" # function attr dict\n", | |
" tir.func_attr({\"global_symbol\": \"relu3\", \"tir.noalias\": True})\n", | |
" # body\n", | |
" # with tir.block(\"root\")\n", | |
" for i0, i1, i2, i3 in tir.grid(1, 128, 28, 28):\n", | |
" with tir.block(\"T_relu\"):\n", | |
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n", | |
" tir.reads(rxplaceholder_1[ax0, ax1, ax2, ax3])\n", | |
" tir.writes(T_relu_1[ax0, ax1, ax2, ax3])\n", | |
" T_relu_1[ax0, ax1, ax2, ax3] = tir.max(rxplaceholder_1[ax0, ax1, ax2, ax3], tir.float32(0))\n", | |
" \n", | |
" @tir.prim_func\n", | |
" def relu1(rxplaceholder_1: tir.Buffer[(1, 64, 56, 56), \"float32\"], T_relu_1: tir.Buffer[(1, 64, 56, 56), \"float32\"]) -> None:\n", | |
" # function attr dict\n", | |
" tir.func_attr({\"global_symbol\": \"relu1\", \"tir.noalias\": True})\n", | |
" # body\n", | |
" # with tir.block(\"root\")\n", | |
" for i0, i1, i2, i3 in tir.grid(1, 64, 56, 56):\n", | |
" with tir.block(\"T_relu\"):\n", | |
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n", | |
" tir.reads(rxplaceholder_1[ax0, ax1, ax2, ax3])\n", | |
" tir.writes(T_relu_1[ax0, ax1, ax2, ax3])\n", | |
" T_relu_1[ax0, ax1, ax2, ax3] = tir.max(rxplaceholder_1[ax0, ax1, ax2, ax3], tir.float32(0))\n", | |
" \n", | |
" @relax.function\n", | |
" def main(data: Tensor((1, 3, 224, 224), \"float32\"), bn_data_gamma: Tensor((3,), \"float32\"), bn_data_beta: Tensor((3,), \"float32\"), bn_data_moving_mean: Tensor((3,), \"float32\"), bn_data_moving_var: Tensor((3,), \"float32\"), conv0_weight: Tensor((64, 3, 7, 7), \"float32\"), bn0_gamma: Tensor((64,), \"float32\"), bn0_beta: Tensor((64,), \"float32\"), bn0_moving_mean: Tensor((64,), \"float32\"), bn0_moving_var: Tensor((64,), \"float32\"), stage1_unit1_bn1_gamma: Tensor((64,), \"float32\"), stage1_unit1_bn1_beta: Tensor((64,), \"float32\"), stage1_unit1_bn1_moving_mean: Tensor((64,), \"float32\"), stage1_unit1_bn1_moving_var: Tensor((64,), \"float32\"), stage1_unit1_conv1_weight: Tensor((64, 64, 1, 1), \"float32\"), stage1_unit1_bn2_gamma: Tensor((64,), \"float32\"), stage1_unit1_bn2_beta: Tensor((64,), \"float32\"), stage1_unit1_bn2_moving_mean: Tensor((64,), \"float32\"), stage1_unit1_bn2_moving_var: Tensor((64,), \"float32\"), stage1_unit1_conv2_weight: Tensor((64, 64, 3, 3), \"float32\"), stage1_unit1_bn3_gamma: Tensor((64,), \"float32\"), stage1_unit1_bn3_beta: Tensor((64,), \"float32\"), stage1_unit1_bn3_moving_mean: Tensor((64,), \"float32\"), stage1_unit1_bn3_moving_var: Tensor((64,), \"float32\"), stage1_unit1_conv3_weight: Tensor((256, 64, 1, 1), \"float32\"), stage1_unit1_sc_weight: Tensor((256, 64, 1, 1), \"float32\"), stage1_unit2_bn1_gamma: Tensor((256,), \"float32\"), stage1_unit2_bn1_beta: Tensor((256,), \"float32\"), stage1_unit2_bn1_moving_mean: Tensor((256,), \"float32\"), stage1_unit2_bn1_moving_var: Tensor((256,), \"float32\"), stage1_unit2_conv1_weight: Tensor((64, 256, 1, 1), \"float32\"), stage1_unit2_bn2_gamma: Tensor((64,), \"float32\"), stage1_unit2_bn2_beta: Tensor((64,), \"float32\"), stage1_unit2_bn2_moving_mean: Tensor((64,), \"float32\"), stage1_unit2_bn2_moving_var: Tensor((64,), \"float32\"), stage1_unit2_conv2_weight: Tensor((64, 64, 3, 3), \"float32\"), stage1_unit2_bn3_gamma: Tensor((64,), \"float32\"), stage1_unit2_bn3_beta: Tensor((64,), \"float32\"), stage1_unit2_bn3_moving_mean: Tensor((64,), \"float32\"), stage1_unit2_bn3_moving_var: Tensor((64,), \"float32\"), stage1_unit2_conv3_weight: Tensor((256, 64, 1, 1), \"float32\"), stage1_unit3_bn1_gamma: Tensor((256,), \"float32\"), stage1_unit3_bn1_beta: Tensor((256,), \"float32\"), stage1_unit3_bn1_moving_mean: Tensor((256,), \"float32\"), stage1_unit3_bn1_moving_var: Tensor((256,), \"float32\"), stage1_unit3_conv1_weight: Tensor((64, 256, 1, 1), \"float32\"), stage1_unit3_bn2_gamma: Tensor((64,), \"float32\"), stage1_unit3_bn2_beta: Tensor((64,), \"float32\"), stage1_unit3_bn2_moving_mean: Tensor((64,), \"float32\"), stage1_unit3_bn2_moving_var: Tensor((64,), \"float32\"), stage1_unit3_conv2_weight: Tensor((64, 64, 3, 3), \"float32\"), stage1_unit3_bn3_gamma: Tensor((64,), \"float32\"), stage1_unit3_bn3_beta: Tensor((64,), \"float32\"), stage1_unit3_bn3_moving_mean: Tensor((64,), \"float32\"), stage1_unit3_bn3_moving_var: Tensor((64,), \"float32\"), stage1_unit3_conv3_weight: Tensor((256, 64, 1, 1), \"float32\"), stage2_unit1_bn1_gamma: Tensor((256,), \"float32\"), stage2_unit1_bn1_beta: Tensor((256,), \"float32\"), stage2_unit1_bn1_moving_mean: Tensor((256,), \"float32\"), stage2_unit1_bn1_moving_var: Tensor((256,), \"float32\"), stage2_unit1_conv1_weight: Tensor((128, 256, 1, 1), \"float32\"), stage2_unit1_bn2_gamma: Tensor((128,), \"float32\"), stage2_unit1_bn2_beta: Tensor((128,), \"float32\"), stage2_unit1_bn2_moving_mean: Tensor((128,), \"float32\"), stage2_unit1_bn2_moving_var: 
Tensor((128,), \"float32\"), stage2_unit1_conv2_weight: Tensor((128, 128, 3, 3), \"float32\"), stage2_unit1_bn3_gamma: Tensor((128,), \"float32\"), stage2_unit1_bn3_beta: Tensor((128,), \"float32\"), stage2_unit1_bn3_moving_mean: Tensor((128,), \"float32\"), stage2_unit1_bn3_moving_var: Tensor((128,), \"float32\"), stage2_unit1_conv3_weight: Tensor((512, 128, 1, 1), \"float32\"), stage2_unit1_sc_weight: Tensor((512, 256, 1, 1), \"float32\"), stage2_unit2_bn1_gamma: Tensor((512,), \"float32\"), stage2_unit2_bn1_beta: Tensor((512,), \"float32\"), stage2_unit2_bn1_moving_mean: Tensor((512,), \"float32\"), stage2_unit2_bn1_moving_var: Tensor((512,), \"float32\"), stage2_unit2_conv1_weight: Tensor((128, 512, 1, 1), \"float32\"), stage2_unit2_bn2_gamma: Tensor((128,), \"float32\"), stage2_unit2_bn2_beta: Tensor((128,), \"float32\"), stage2_unit2_bn2_moving_mean: Tensor((128,), \"float32\"), stage2_unit2_bn2_moving_var: Tensor((128,), \"float32\"), stage2_unit2_conv2_weight: Tensor((128, 128, 3, 3), \"float32\"), stage2_unit2_bn3_gamma: Tensor((128,), \"float32\"), stage2_unit2_bn3_beta: Tensor((128,), \"float32\"), stage2_unit2_bn3_moving_mean: Tensor((128,), \"float32\"), stage2_unit2_bn3_moving_var: Tensor((128,), \"float32\"), stage2_unit2_conv3_weight: Tensor((512, 128, 1, 1), \"float32\"), stage2_unit3_bn1_gamma: Tensor((512,), \"float32\"), stage2_unit3_bn1_beta: Tensor((512,), \"float32\"), stage2_unit3_bn1_moving_mean: Tensor((512,), \"float32\"), stage2_unit3_bn1_moving_var: Tensor((512,), \"float32\"), stage2_unit3_conv1_weight: Tensor((128, 512, 1, 1), \"float32\"), stage2_unit3_bn2_gamma: Tensor((128,), \"float32\"), stage2_unit3_bn2_beta: Tensor((128,), \"float32\"), stage2_unit3_bn2_moving_mean: Tensor((128,), \"float32\"), stage2_unit3_bn2_moving_var: Tensor((128,), \"float32\"), stage2_unit3_conv2_weight: Tensor((128, 128, 3, 3), \"float32\"), stage2_unit3_bn3_gamma: Tensor((128,), \"float32\"), stage2_unit3_bn3_beta: Tensor((128,), \"float32\"), stage2_unit3_bn3_moving_mean: Tensor((128,), \"float32\"), stage2_unit3_bn3_moving_var: Tensor((128,), \"float32\"), stage2_unit3_conv3_weight: Tensor((512, 128, 1, 1), \"float32\"), stage2_unit4_bn1_gamma: Tensor((512,), \"float32\"), stage2_unit4_bn1_beta: Tensor((512,), \"float32\"), stage2_unit4_bn1_moving_mean: Tensor((512,), \"float32\"), stage2_unit4_bn1_moving_var: Tensor((512,), \"float32\"), stage2_unit4_conv1_weight: Tensor((128, 512, 1, 1), \"float32\"), stage2_unit4_bn2_gamma: Tensor((128,), \"float32\"), stage2_unit4_bn2_beta: Tensor((128,), \"float32\"), stage2_unit4_bn2_moving_mean: Tensor((128,), \"float32\"), stage2_unit4_bn2_moving_var: Tensor((128,), \"float32\"), stage2_unit4_conv2_weight: Tensor((128, 128, 3, 3), \"float32\"), stage2_unit4_bn3_gamma: Tensor((128,), \"float32\"), stage2_unit4_bn3_beta: Tensor((128,), \"float32\"), stage2_unit4_bn3_moving_mean: Tensor((128,), \"float32\"), stage2_unit4_bn3_moving_var: Tensor((128,), \"float32\"), stage2_unit4_conv3_weight: Tensor((512, 128, 1, 1), \"float32\"), stage3_unit1_bn1_gamma: Tensor((512,), \"float32\"), stage3_unit1_bn1_beta: Tensor((512,), \"float32\"), stage3_unit1_bn1_moving_mean: Tensor((512,), \"float32\"), stage3_unit1_bn1_moving_var: Tensor((512,), \"float32\"), stage3_unit1_conv1_weight: Tensor((256, 512, 1, 1), \"float32\"), stage3_unit1_bn2_gamma: Tensor((256,), \"float32\"), stage3_unit1_bn2_beta: Tensor((256,), \"float32\"), stage3_unit1_bn2_moving_mean: Tensor((256,), \"float32\"), stage3_unit1_bn2_moving_var: Tensor((256,), \"float32\"), 
stage3_unit1_conv2_weight: Tensor((256, 256, 3, 3), \"float32\"), stage3_unit1_bn3_gamma: Tensor((256,), \"float32\"), stage3_unit1_bn3_beta: Tensor((256,), \"float32\"), stage3_unit1_bn3_moving_mean: Tensor((256,), \"float32\"), stage3_unit1_bn3_moving_var: Tensor((256,), \"float32\"), stage3_unit1_conv3_weight: Tensor((1024, 256, 1, 1), \"float32\"), stage3_unit1_sc_weight: Tensor((1024, 512, 1, 1), \"float32\"), stage3_unit2_bn1_gamma: Tensor((1024,), \"float32\"), stage3_unit2_bn1_beta: Tensor((1024,), \"float32\"), stage3_unit2_bn1_moving_mean: Tensor((1024,), \"float32\"), stage3_unit2_bn1_moving_var: Tensor((1024,), \"float32\"), stage3_unit2_conv1_weight: Tensor((256, 1024, 1, 1), \"float32\"), stage3_unit2_bn2_gamma: Tensor((256,), \"float32\"), stage3_unit2_bn2_beta: Tensor((256,), \"float32\"), stage3_unit2_bn2_moving_mean: Tensor((256,), \"float32\"), stage3_unit2_bn2_moving_var: Tensor((256,), \"float32\"), stage3_unit2_conv2_weight: Tensor((256, 256, 3, 3), \"float32\"), stage3_unit2_bn3_gamma: Tensor((256,), \"float32\"), stage3_unit2_bn3_beta: Tensor((256,), \"float32\"), stage3_unit2_bn3_moving_mean: Tensor((256,), \"float32\"), stage3_unit2_bn3_moving_var: Tensor((256,), \"float32\"), stage3_unit2_conv3_weight: Tensor((1024, 256, 1, 1), \"float32\"), stage3_unit3_bn1_gamma: Tensor((1024,), \"float32\"), stage3_unit3_bn1_beta: Tensor((1024,), \"float32\"), stage3_unit3_bn1_moving_mean: Tensor((1024,), \"float32\"), stage3_unit3_bn1_moving_var: Tensor((1024,), \"float32\"), stage3_unit3_conv1_weight: Tensor((256, 1024, 1, 1), \"float32\"), stage3_unit3_bn2_gamma: Tensor((256,), \"float32\"), stage3_unit3_bn2_beta: Tensor((256,), \"float32\"), stage3_unit3_bn2_moving_mean: Tensor((256,), \"float32\"), stage3_unit3_bn2_moving_var: Tensor((256,), \"float32\"), stage3_unit3_conv2_weight: Tensor((256, 256, 3, 3), \"float32\"), stage3_unit3_bn3_gamma: Tensor((256,), \"float32\"), stage3_unit3_bn3_beta: Tensor((256,), \"float32\"), stage3_unit3_bn3_moving_mean: Tensor((256,), \"float32\"), stage3_unit3_bn3_moving_var: Tensor((256,), \"float32\"), stage3_unit3_conv3_weight: Tensor((1024, 256, 1, 1), \"float32\"), stage3_unit4_bn1_gamma: Tensor((1024,), \"float32\"), stage3_unit4_bn1_beta: Tensor((1024,), \"float32\"), stage3_unit4_bn1_moving_mean: Tensor((1024,), \"float32\"), stage3_unit4_bn1_moving_var: Tensor((1024,), \"float32\"), stage3_unit4_conv1_weight: Tensor((256, 1024, 1, 1), \"float32\"), stage3_unit4_bn2_gamma: Tensor((256,), \"float32\"), stage3_unit4_bn2_beta: Tensor((256,), \"float32\"), stage3_unit4_bn2_moving_mean: Tensor((256,), \"float32\"), stage3_unit4_bn2_moving_var: Tensor((256,), \"float32\"), stage3_unit4_conv2_weight: Tensor((256, 256, 3, 3), \"float32\"), stage3_unit4_bn3_gamma: Tensor((256,), \"float32\"), stage3_unit4_bn3_beta: Tensor((256,), \"float32\"), stage3_unit4_bn3_moving_mean: Tensor((256,), \"float32\"), stage3_unit4_bn3_moving_var: Tensor((256,), \"float32\"), stage3_unit4_conv3_weight: Tensor((1024, 256, 1, 1), \"float32\"), stage3_unit5_bn1_gamma: Tensor((1024,), \"float32\"), stage3_unit5_bn1_beta: Tensor((1024,), \"float32\"), stage3_unit5_bn1_moving_mean: Tensor((1024,), \"float32\"), stage3_unit5_bn1_moving_var: Tensor((1024,), \"float32\"), stage3_unit5_conv1_weight: Tensor((256, 1024, 1, 1), \"float32\"), stage3_unit5_bn2_gamma: Tensor((256,), \"float32\"), stage3_unit5_bn2_beta: Tensor((256,), \"float32\"), stage3_unit5_bn2_moving_mean: Tensor((256,), \"float32\"), stage3_unit5_bn2_moving_var: Tensor((256,), \"float32\"), 
stage3_unit5_conv2_weight: Tensor((256, 256, 3, 3), \"float32\"), stage3_unit5_bn3_gamma: Tensor((256,), \"float32\"), stage3_unit5_bn3_beta: Tensor((256,), \"float32\"), stage3_unit5_bn3_moving_mean: Tensor((256,), \"float32\"), stage3_unit5_bn3_moving_var: Tensor((256,), \"float32\"), stage3_unit5_conv3_weight: Tensor((1024, 256, 1, 1), \"float32\"), stage3_unit6_bn1_gamma: Tensor((1024,), \"float32\"), stage3_unit6_bn1_beta: Tensor((1024,), \"float32\"), stage3_unit6_bn1_moving_mean: Tensor((1024,), \"float32\"), stage3_unit6_bn1_moving_var: Tensor((1024,), \"float32\"), stage3_unit6_conv1_weight: Tensor((256, 1024, 1, 1), \"float32\"), stage3_unit6_bn2_gamma: Tensor((256,), \"float32\"), stage3_unit6_bn2_beta: Tensor((256,), \"float32\"), stage3_unit6_bn2_moving_mean: Tensor((256,), \"float32\"), stage3_unit6_bn2_moving_var: Tensor((256,), \"float32\"), stage3_unit6_conv2_weight: Tensor((256, 256, 3, 3), \"float32\"), stage3_unit6_bn3_gamma: Tensor((256,), \"float32\"), stage3_unit6_bn3_beta: Tensor((256,), \"float32\"), stage3_unit6_bn3_moving_mean: Tensor((256,), \"float32\"), stage3_unit6_bn3_moving_var: Tensor((256,), \"float32\"), stage3_unit6_conv3_weight: Tensor((1024, 256, 1, 1), \"float32\"), stage4_unit1_bn1_gamma: Tensor((1024,), \"float32\"), stage4_unit1_bn1_beta: Tensor((1024,), \"float32\"), stage4_unit1_bn1_moving_mean: Tensor((1024,), \"float32\"), stage4_unit1_bn1_moving_var: Tensor((1024,), \"float32\"), stage4_unit1_conv1_weight: Tensor((512, 1024, 1, 1), \"float32\"), stage4_unit1_bn2_gamma: Tensor((512,), \"float32\"), stage4_unit1_bn2_beta: Tensor((512,), \"float32\"), stage4_unit1_bn2_moving_mean: Tensor((512,), \"float32\"), stage4_unit1_bn2_moving_var: Tensor((512,), \"float32\"), stage4_unit1_conv2_weight: Tensor((512, 512, 3, 3), \"float32\"), stage4_unit1_bn3_gamma: Tensor((512,), \"float32\"), stage4_unit1_bn3_beta: Tensor((512,), \"float32\"), stage4_unit1_bn3_moving_mean: Tensor((512,), \"float32\"), stage4_unit1_bn3_moving_var: Tensor((512,), \"float32\"), stage4_unit1_conv3_weight: Tensor((2048, 512, 1, 1), \"float32\"), stage4_unit1_sc_weight: Tensor((2048, 1024, 1, 1), \"float32\"), stage4_unit2_bn1_gamma: Tensor((2048,), \"float32\"), stage4_unit2_bn1_beta: Tensor((2048,), \"float32\"), stage4_unit2_bn1_moving_mean: Tensor((2048,), \"float32\"), stage4_unit2_bn1_moving_var: Tensor((2048,), \"float32\"), stage4_unit2_conv1_weight: Tensor((512, 2048, 1, 1), \"float32\"), stage4_unit2_bn2_gamma: Tensor((512,), \"float32\"), stage4_unit2_bn2_beta: Tensor((512,), \"float32\"), stage4_unit2_bn2_moving_mean: Tensor((512,), \"float32\"), stage4_unit2_bn2_moving_var: Tensor((512,), \"float32\"), stage4_unit2_conv2_weight: Tensor((512, 512, 3, 3), \"float32\"), stage4_unit2_bn3_gamma: Tensor((512,), \"float32\"), stage4_unit2_bn3_beta: Tensor((512,), \"float32\"), stage4_unit2_bn3_moving_mean: Tensor((512,), \"float32\"), stage4_unit2_bn3_moving_var: Tensor((512,), \"float32\"), stage4_unit2_conv3_weight: Tensor((2048, 512, 1, 1), \"float32\"), stage4_unit3_bn1_gamma: Tensor((2048,), \"float32\"), stage4_unit3_bn1_beta: Tensor((2048,), \"float32\"), stage4_unit3_bn1_moving_mean: Tensor((2048,), \"float32\"), stage4_unit3_bn1_moving_var: Tensor((2048,), \"float32\"), stage4_unit3_conv1_weight: Tensor((512, 2048, 1, 1), \"float32\"), stage4_unit3_bn2_gamma: Tensor((512,), \"float32\"), stage4_unit3_bn2_beta: Tensor((512,), \"float32\"), stage4_unit3_bn2_moving_mean: Tensor((512,), \"float32\"), stage4_unit3_bn2_moving_var: Tensor((512,), \"float32\"), 
stage4_unit3_conv2_weight: Tensor((512, 512, 3, 3), \"float32\"), stage4_unit3_bn3_gamma: Tensor((512,), \"float32\"), stage4_unit3_bn3_beta: Tensor((512,), \"float32\"), stage4_unit3_bn3_moving_mean: Tensor((512,), \"float32\"), stage4_unit3_bn3_moving_var: Tensor((512,), \"float32\"), stage4_unit3_conv3_weight: Tensor((2048, 512, 1, 1), \"float32\"), bn1_gamma: Tensor((2048,), \"float32\"), bn1_beta: Tensor((2048,), \"float32\"), bn1_moving_mean: Tensor((2048,), \"float32\"), bn1_moving_var: Tensor((2048,), \"float32\"), fc1_weight: Tensor((1000, 2048), \"float32\"), fc1_bias: Tensor((1000,), \"float32\")) -> Tensor(_, \"float32\"):\n", | |
" # block 0\n", | |
" gv = relax.call_tir(batch_norm, (data, bn_data_gamma, bn_data_beta, bn_data_moving_mean, bn_data_moving_var), ((1, 3, 224, 224), (3,), (3,)), dtype=(\"float32\", \"float32\", \"float32\"))\n", | |
" gv1: Tensor((1, 3, 224, 224), \"float32\") = gv[0]\n", | |
" gv2 = relax.call_tir(conv2d_nchw, (gv1, conv0_weight), (1, 64, 112, 112), dtype=\"float32\")\n", | |
" gv3 = relax.call_tir(batch_norm1, (gv2, bn0_gamma, bn0_beta, bn0_moving_mean, bn0_moving_var), ((1, 64, 112, 112), (64,), (64,)), dtype=(\"float32\", \"float32\", \"float32\"))\n", | |
" gv4: Tensor((1, 64, 112, 112), \"float32\") = gv3[0]\n", | |
" gv5 = relax.call_tir(relu, (gv4,), (1, 64, 112, 112), dtype=\"float32\")\n", | |
" gv6 = relax.call_tir(max_pool2d, (gv5,), (1, 64, 56, 56), dtype=\"float32\")\n", | |
" gv7 = relax.call_tir(batch_norm2, (gv6, stage1_unit1_bn1_gamma, stage1_unit1_bn1_beta, stage1_unit1_bn1_moving_mean, stage1_unit1_bn1_moving_var), ((1, 64, 56, 56), (64,), (64,)), dtype=(\"float32\", \"float32\", \"float32\"))\n", | |
" gv8: Tensor((1, 64, 56, 56), \"float32\") = gv7[0]\n", | |
" gv9 = relax.call_tir(relu1, (gv8,), (1, 64, 56, 56), dtype=\"float32\")\n", | |
" gv10 = relax.call_tir(conv2d_nchw1, (gv9, stage1_unit1_conv1_weight), (1, 64, 56, 56), dtype=\"float32\")\n", | |
" gv11 = relax.call_tir(batch_norm2, (gv10, stage1_unit1_bn2_gamma, stage1_unit1_bn2_beta, stage1_unit1_bn2_moving_mean, stage1_unit1_bn2_moving_var), ((1, 64, 56, 56), (64,), (64,)), dtype=(\"float32\", \"float32\", \"float32\"))\n", | |
" gv12: Tensor((1, 64, 56, 56), \"float32\") = gv11[0]\n", | |
" gv13 = relax.call_tir(relu1, (gv12,), (1, 64, 56, 56), dtype=\"float32\")\n", | |
" gv14 = relax.call_tir(conv2d_nchw2, (gv13, stage1_unit1_conv2_weight), (1, 64, 56, 56), dtype=\"float32\")\n", | |
" gv15 = relax.call_tir(batch_norm2, (gv14, stage1_unit1_bn3_gamma, stage1_unit1_bn3_beta, stage1_unit1_bn3_moving_mean, stage1_unit1_bn3_moving_var), ((1, 64, 56, 56), (64,), (64,)), dtype=(\"float32\", \"float32\", \"float32\"))\n", | |
" gv16: Tensor((1, 64, 56, 56), \"float32\") = gv15[0]\n", | |
" gv17 = relax.call_tir(relu1, (gv16,), (1, 64, 56, 56), dtype=\"float32\")\n", | |
" gv18 = relax.call_tir(conv2d_nchw3, (gv17, stage1_unit1_conv3_weight), (1, 256, 56, 56), dtype=\"float32\")\n", | |
" gv19 = relax.call_tir(conv2d_nchw3, (gv9, stage1_unit1_sc_weight), (1, 256, 56, 56), dtype=\"float32\")\n", | |
" gv20 = relax.call_tir(add, (gv18, gv19), (1, 256, 56, 56), dtype=\"float32\")\n", | |
" gv21 = relax.call_tir(batch_norm3, (gv20, stage1_unit2_bn1_gamma, stage1_unit2_bn1_beta, stage1_unit2_bn1_moving_mean, stage1_unit2_bn1_moving_var), ((1, 256, 56, 56), (256,), (256,)), dtype=(\"float32\", \"float32\", \"float32\"))\n", | |
" gv22: Tensor((1, 256, 56, 56), \"float32\") = gv21[0]\n", | |
" gv23 = relax.call_tir(relu2, (gv22,), (1, 256, 56, 56), dtype=\"float32\")\n", | |
" gv24 = relax.call_tir(conv2d_nchw4, (gv23, stage1_unit2_conv1_weight), (1, 64, 56, 56), dtype=\"float32\")\n", | |
" gv25 = relax.call_tir(batch_norm2, (gv24, stage1_unit2_bn2_gamma, stage1_unit2_bn2_beta, stage1_unit2_bn2_moving_mean, stage1_unit2_bn2_moving_var), ((1, 64, 56, 56), (64,), (64,)), dtype=(\"float32\", \"float32\", \"float32\"))\n", | |
" gv26: Tensor((1, 64, 56, 56), \"float32\") = gv25[0]\n", | |
" gv27 = relax.call_tir(relu1, (gv26,), (1, 64, 56, 56), dtype=\"float32\")\n", | |
" gv28 = relax.call_tir(conv2d_nchw2, (gv27, stage1_unit2_conv2_weight), (1, 64, 56, 56), dtype=\"float32\")\n", | |
" gv29 = relax.call_tir(batch_norm2, (gv28, stage1_unit2_bn3_gamma, stage1_unit2_bn3_beta, stage1_unit2_bn3_moving_mean, stage1_unit2_bn3_moving_var), ((1, 64, 56, 56), (64,), (64,)), dtype=(\"float32\", \"float32\", \"float32\"))\n", | |
" gv30: Tensor((1, 64, 56, 56), \"float32\") = gv29[0]\n", | |
" gv31 = relax.call_tir(relu1, (gv30,), (1, 64, 56, 56), dtype=\"float32\")\n", | |
" gv32 = relax.call_tir(conv2d_nchw3, (gv31, stage1_unit2_conv3_weight), (1, 256, 56, 56), dtype=\"float32\")\n", | |
" gv33 = relax.call_tir(add, (gv32, gv20), (1, 256, 56, 56), dtype=\"float32\")\n", | |
" gv34 = relax.call_tir(batch_norm3, (gv33, stage1_unit3_bn1_gamma, stage1_unit3_bn1_beta, stage1_unit3_bn1_moving_mean, stage1_unit3_bn1_moving_var), ((1, 256, 56, 56), (256,), (256,)), dtype=(\"float32\", \"float32\", \"float32\"))\n", | |
" gv35: Tensor((1, 256, 56, 56), \"float32\") = gv34[0]\n", | |
" gv36 = relax.call_tir(relu2, (gv35,), (1, 256, 56, 56), dtype=\"float32\")\n", | |
" gv37 = relax.call_tir(conv2d_nchw4, (gv36, stage1_unit3_conv1_weight), (1, 64, 56, 56), dtype=\"float32\")\n", | |
" gv38 = relax.call_tir(batch_norm2, (gv37, stage1_unit3_bn2_gamma, stage1_unit3_bn2_beta, stage1_unit3_bn2_moving_mean, stage1_unit3_bn2_moving_var), ((1, 64, 56, 56), (64,), (64,)), dtype=(\"float32\", \"float32\", \"float32\"))\n", | |
" gv39: Tensor((1, 64, 56, 56), \"float32\") = gv38[0]\n", | |
" gv40 = relax.call_tir(relu1, (gv39,), (1, 64, 56, 56), dtype=\"float32\")\n", | |
" gv41 = relax.call_tir(conv2d_nchw2, (gv40, stage1_unit3_conv2_weight), (1, 64, 56, 56), dtype=\"float32\")\n", | |
" gv42 = relax.call_tir(batch_norm2, (gv41, stage1_unit3_bn3_gamma, stage1_unit3_bn3_beta, stage1_unit3_bn3_moving_mean, stage1_unit3_bn3_moving_var), ((1, 64, 56, 56), (64,), (64,)), dtype=(\"float32\", \"float32\", \"float32\"))\n", | |
" gv43: Tensor((1, 64, 56, 56), \"float32\") = gv42[0]\n", | |
" gv44 = relax.call_tir(relu1, (gv43,), (1, 64, 56, 56), dtype=\"float32\")\n", | |
" gv45 = relax.call_tir(conv2d_nchw3, (gv44, stage1_unit3_conv3_weight), (1, 256, 56, 56), dtype=\"float32\")\n", | |
" gv46 = relax.call_tir(add, (gv45, gv33), (1, 256, 56, 56), dtype=\"float32\")\n", | |
" gv47 = relax.call_tir(batch_norm3, (gv46, stage2_unit1_bn1_gamma, stage2_unit1_bn1_beta, stage2_unit1_bn1_moving_mean, stage2_unit1_bn1_moving_var), ((1, 256, 56, 56), (256,), (256,)), dtype=(\"float32\", \"float32\", \"float32\"))\n", | |
" gv48: Tensor((1, 256, 56, 56), \"float32\") = gv47[0]\n", | |
" gv49 = relax.call_tir(relu2, (gv48,), (1, 256, 56, 56), dtype=\"float32\")\n", | |
" gv50 = relax.call_tir(conv2d_nchw5, (gv49, stage2_unit1_conv1_weight), (1, 128, 28, 28), dtype=\"float32\")\n", | |
" gv51 = relax.call_tir(batch_norm4, (gv50, stage2_unit1_bn2_gamma, stage2_unit1_bn2_beta, stage2_unit1_bn2_moving_mean, stage2_unit1_bn2_moving_var), ((1, 128, 28, 28), (128,), (128,)), dtype=(\"float32\", \"float32\", \"float32\"))\n", | |
" gv52: Tensor((1, 128, 28, 28), \"float32\") = gv51[0]\n", | |
" gv53 = relax.call_tir(relu3, (gv52,), (1, 128, 28, 28), dtype=\"float32\")\n", | |
" gv54 = relax.call_tir(conv2d_nchw6, (gv53, stage2_unit1_conv2_weight), (1, 128, 28, 28), dtype=\"float32\")\n", | |
" gv55 = relax.call_tir(batch_norm4, (gv54, stage2_unit1_bn3_gamma, stage2_unit1_bn3_beta, stage2_unit1_bn3_moving_mean, stage2_unit1_bn3_moving_var), ((1, 128, 28, 28), (128,), (128,)), dtype=(\"float32\", \"float32\", \"float32\"))\n", | |
" gv56: Tensor((1, 128, 28, 28), \"float32\") = gv55[0]\n", | |
" gv57 = relax.call_tir(relu3, (gv56,), (1, 128, 28, 28), dtype=\"float32\")\n", | |
" gv58 = relax.call_tir(conv2d_nchw7, (gv57, stage2_unit1_conv3_weight), (1, 512, 28, 28), dtype=\"float32\")\n", | |
" gv59 = relax.call_tir(conv2d_nchw8, (gv49, stage2_unit1_sc_weight), (1, 512, 28, 28), dtype=\"float32\")\n", | |
" gv60 = relax.call_tir(add1, (gv58, gv59), (1, 512, 28, 28), dtype=\"float32\")\n", | |
" gv61 = relax.call_tir(batch_norm5, (gv60, stage2_unit2_bn1_gamma, stage2_unit2_bn1_beta, stage2_unit2_bn1_moving_mean, stage2_unit2_bn1_moving_var), ((1, 512, 28, 28), (512,), (512,)), dtype=(\"float32\", \"float32\", \"float32\"))\n", | |
" gv62: Tensor((1, 512, 28, 28), \"float32\") = gv61[0]\n", | |
" gv63 = relax.call_tir(relu4, (gv62,), (1, 512, 28, 28), dtype=\"float32\")\n", | |
" gv64 = relax.call_tir(conv2d_nchw9, (gv63, stage2_unit2_conv1_weight), (1, 128, 28, 28), dtype=\"float32\")\n", | |
" gv65 = relax.call_tir(batch_norm4, (gv64, stage2_unit2_bn2_gamma, stage2_unit2_bn2_beta, stage2_unit2_bn2_moving_mean, stage2_unit2_bn2_moving_var), ((1, 128, 28, 28), (128,), (128,)), dtype=(\"float32\", \"float32\", \"float32\"))\n", | |
" gv66: Tensor((1, 128, 28, 28), \"float32\") = gv65[0]\n", | |
" gv67 = relax.call_tir(relu3, (gv66,), (1, 128, 28, 28), dtype=\"float32\")\n", | |
" gv68 = relax.call_tir(conv2d_nchw6, (gv67, stage2_unit2_conv2_weight), (1, 128, 28, 28), dtype=\"float32\")\n", | |
" gv69 = relax.call_tir(batch_norm4, (gv68, stage2_unit2_bn3_gamma, stage2_unit2_bn3_beta, stage2_unit2_bn3_moving_mean, stage2_unit2_bn3_moving_var), ((1, 128, 28, 28), (128,), (128,)), dtype=(\"float32\", \"float32\", \"float32\"))\n", | |
" gv70: Tensor((1, 128, 28, 28), \"float32\") = gv69[0]\n", | |
" gv71 = relax.call_tir(relu3, (gv70,), (1, 128, 28, 28), dtype=\"float32\")\n", | |
" gv72 = relax.call_tir(conv2d_nchw7, (gv71, stage2_unit2_conv3_weight), (1, 512, 28, 28), dtype=\"float32\")\n", | |
" gv73 = relax.call_tir(add1, (gv72, gv60), (1, 512, 28, 28), dtype=\"float32\")\n", | |
" gv74 = relax.call_tir(batch_norm5, (gv73, stage2_unit3_bn1_gamma, stage2_unit3_bn1_beta, stage2_unit3_bn1_moving_mean, stage2_unit3_bn1_moving_var), ((1, 512, 28, 28), (512,), (512,)), dtype=(\"float32\", \"float32\", \"float32\"))\n", | |
" gv75: Tensor((1, 512, 28, 28), \"float32\") = gv74[0]\n", | |
" gv76 = relax.call_tir(relu4, (gv75,), (1, 512, 28, 28), dtype=\"float32\")\n", | |
" gv77 = relax.call_tir(conv2d_nchw9, (gv76, stage2_unit3_conv1_weight), (1, 128, 28, 28), dtype=\"float32\")\n", | |
" gv78 = relax.call_tir(batch_norm4, (gv77, stage2_unit3_bn2_gamma, stage2_unit3_bn2_beta, stage2_unit3_bn2_moving_mean, stage2_unit3_bn2_moving_var), ((1, 128, 28, 28), (128,), (128,)), dtype=(\"float32\", \"float32\", \"float32\"))\n", | |
" gv79: Tensor((1, 128, 28, 28), \"float32\") = gv78[0]\n", | |
" gv80 = relax.call_tir(relu3, (gv79,), (1, 128, 28, 28), dtype=\"float32\")\n", | |
" gv81 = relax.call_tir(conv2d_nchw6, (gv80, stage2_unit3_conv2_weight), (1, 128, 28, 28), dtype=\"float32\")\n", | |
" gv82 = relax.call_tir(batch_norm4, (gv81, stage2_unit3_bn3_gamma, stage2_unit3_bn3_beta, stage2_unit3_bn3_moving_mean, stage2_unit3_bn3_moving_var), ((1, 128, 28, 28), (128,), (128,)), dtype=(\"float32\", \"float32\", \"float32\"))\n", | |
" gv83: Tensor((1, 128, 28, 28), \"float32\") = gv82[0]\n", | |
" gv84 = relax.call_tir(relu3, (gv83,), (1, 128, 28, 28), dtype=\"float32\")\n", | |
" gv85 = relax.call_tir(conv2d_nchw7, (gv84, stage2_unit3_conv3_weight), (1, 512, 28, 28), dtype=\"float32\")\n", | |
" gv86 = relax.call_tir(add1, (gv85, gv73), (1, 512, 28, 28), dtype=\"float32\")\n", | |
" gv87 = relax.call_tir(batch_norm5, (gv86, stage2_unit4_bn1_gamma, stage2_unit4_bn1_beta, stage2_unit4_bn1_moving_mean, stage2_unit4_bn1_moving_var), ((1, 512, 28, 28), (512,), (512,)), dtype=(\"float32\", \"float32\", \"float32\"))\n", | |
" gv88: Tensor((1, 512, 28, 28), \"float32\") = gv87[0]\n", | |
" gv89 = relax.call_tir(relu4, (gv88,), (1, 512, 28, 28), dtype=\"float32\")\n", | |
" gv90 = relax.call_tir(conv2d_nchw9, (gv89, stage2_unit4_conv1_weight), (1, 128, 28, 28), dtype=\"float32\")\n", | |
" gv91 = relax.call_tir(batch_norm4, (gv90, stage2_unit4_bn2_gamma, stage2_unit4_bn2_beta, stage2_unit4_bn2_moving_mean, stage2_unit4_bn2_moving_var), ((1, 128, 28, 28), (128,), (128,)), dtype=(\"float32\", \"float32\", \"float32\"))\n", | |
" gv92: Tensor((1, 128, 28, 28), \"float32\") = gv91[0]\n", | |
" gv93 = relax.call_tir(relu3, (gv92,), (1, 128, 28, 28), dtype=\"float32\")\n", | |
" gv94 = relax.call_tir(conv2d_nchw6, (gv93, stage2_unit4_conv2_weight), (1, 128, 28, 28), dtype=\"float32\")\n", | |
" gv95 = relax.call_tir(batch_norm4, (gv94, stage2_unit4_bn3_gamma, stage2_unit4_bn3_beta, stage2_unit4_bn3_moving_mean, stage2_unit4_bn3_moving_var), ((1, 128, 28, 28), (128,), (128,)), dtype=(\"float32\", \"float32\", \"float32\"))\n", | |
" gv96: Tensor((1, 128, 28, 28), \"float32\") = gv95[0]\n", | |
" gv97 = relax.call_tir(relu3, (gv96,), (1, 128, 28, 28), dtype=\"float32\")\n", | |
" gv98 = relax.call_tir(conv2d_nchw7, (gv97, stage2_unit4_conv3_weight), (1, 512, 28, 28), dtype=\"float32\")\n", | |
" gv99 = relax.call_tir(add1, (gv98, gv86), (1, 512, 28, 28), dtype=\"float32\")\n", | |
" gv100 = relax.call_tir(batch_norm5, (gv99, stage3_unit1_bn1_gamma, stage3_unit1_bn1_beta, stage3_unit1_bn1_moving_mean, stage3_unit1_bn1_moving_var), ((1, 512, 28, 28), (512,), (512,)), dtype=(\"float32\", \"float32\", \"float32\"))\n", | |
" gv101: Tensor((1, 512, 28, 28), \"float32\") = gv100[0]\n", | |
" gv102 = relax.call_tir(relu4, (gv101,), (1, 512, 28, 28), dtype=\"float32\")\n", | |
" gv103 = relax.call_tir(conv2d_nchw10, (gv102, stage3_unit1_conv1_weight), (1, 256, 14, 14), dtype=\"float32\")\n", | |
" gv104 = relax.call_tir(batch_norm6, (gv103, stage3_unit1_bn2_gamma, stage3_unit1_bn2_beta, stage3_unit1_bn2_moving_mean, stage3_unit1_bn2_moving_var), ((1, 256, 14, 14), (256,), (256,)), dtype=(\"float32\", \"float32\", \"float32\"))\n", | |
" gv105: Tensor((1, 256, 14, 14), \"float32\") = gv104[0]\n", | |
" gv106 = relax.call_tir(relu5, (gv105,), (1, 256, 14, 14), dtype=\"float32\")\n", | |
" gv107 = relax.call_tir(conv2d_nchw11, (gv106, stage3_unit1_conv2_weight), (1, 256, 14, 14), dtype=\"float32\")\n", | |
" gv108 = relax.call_tir(batch_norm6, (gv107, stage3_unit1_bn3_gamma, stage3_unit1_bn3_beta, stage3_unit1_bn3_moving_mean, stage3_unit1_bn3_moving_var), ((1, 256, 14, 14), (256,), (256,)), dtype=(\"float32\", \"float32\", \"float32\"))\n", | |
" gv109: Tensor((1, 256, 14, 14), \"float32\") = gv108[0]\n", | |
" gv110 = relax.call_tir(relu5, (gv109,), (1, 256, 14, 14), dtype=\"float32\")\n", | |
" gv111 = relax.call_tir(conv2d_nchw12, (gv110, stage3_unit1_conv3_weight), (1, 1024, 14, 14), dtype=\"float32\")\n", | |
" gv112 = relax.call_tir(conv2d_nchw13, (gv102, stage3_unit1_sc_weight), (1, 1024, 14, 14), dtype=\"float32\")\n", | |
" gv113 = relax.call_tir(add2, (gv111, gv112), (1, 1024, 14, 14), dtype=\"float32\")\n", | |
" gv114 = relax.call_tir(batch_norm7, (gv113, stage3_unit2_bn1_gamma, stage3_unit2_bn1_beta, stage3_unit2_bn1_moving_mean, stage3_unit2_bn1_moving_var), ((1, 1024, 14, 14), (1024,), (1024,)), dtype=(\"float32\", \"float32\", \"float32\"))\n", | |
" gv115: Tensor((1, 1024, 14, 14), \"float32\") = gv114[0]\n", | |
" gv116 = relax.call_tir(relu6, (gv115,), (1, 1024, 14, 14), dtype=\"float32\")\n", | |
" gv117 = relax.call_tir(conv2d_nchw14, (gv116, stage3_unit2_conv1_weight), (1, 256, 14, 14), dtype=\"float32\")\n", | |
" gv118 = relax.call_tir(batch_norm6, (gv117, stage3_unit2_bn2_gamma, stage3_unit2_bn2_beta, stage3_unit2_bn2_moving_mean, stage3_unit2_bn2_moving_var), ((1, 256, 14, 14), (256,), (256,)), dtype=(\"float32\", \"float32\", \"float32\"))\n", | |
" gv119: Tensor((1, 256, 14, 14), \"float32\") = gv118[0]\n", | |
" gv120 = relax.call_tir(relu5, (gv119,), (1, 256, 14, 14), dtype=\"float32\")\n", | |
" gv121 = relax.call_tir(conv2d_nchw11, (gv120, stage3_unit2_conv2_weight), (1, 256, 14, 14), dtype=\"float32\")\n", | |
" gv122 = relax.call_tir(batch_norm6, (gv121, stage3_unit2_bn3_gamma, stage3_unit2_bn3_beta, stage3_unit2_bn3_moving_mean, stage3_unit2_bn3_moving_var), ((1, 256, 14, 14), (256,), (256,)), dtype=(\"float32\", \"float32\", \"float32\"))\n", | |
" gv123: Tensor((1, 256, 14, 14), \"float32\") = gv122[0]\n", | |
" gv124 = relax.call_tir(relu5, (gv123,), (1, 256, 14, 14), dtype=\"float32\")\n", | |
" gv125 = relax.call_tir(conv2d_nchw12, (gv124, stage3_unit2_conv3_weight), (1, 1024, 14, 14), dtype=\"float32\")\n", | |
" gv126 = relax.call_tir(add2, (gv125, gv113), (1, 1024, 14, 14), dtype=\"float32\")\n", | |
" gv127 = relax.call_tir(batch_norm7, (gv126, stage3_unit3_bn1_gamma, stage3_unit3_bn1_beta, stage3_unit3_bn1_moving_mean, stage3_unit3_bn1_moving_var), ((1, 1024, 14, 14), (1024,), (1024,)), dtype=(\"float32\", \"float32\", \"float32\"))\n", | |
" gv128: Tensor((1, 1024, 14, 14), \"float32\") = gv127[0]\n", | |
" gv129 = relax.call_tir(relu6, (gv128,), (1, 1024, 14, 14), dtype=\"float32\")\n", | |
" gv130 = relax.call_tir(conv2d_nchw14, (gv129, stage3_unit3_conv1_weight), (1, 256, 14, 14), dtype=\"float32\")\n", | |
" gv131 = relax.call_tir(batch_norm6, (gv130, stage3_unit3_bn2_gamma, stage3_unit3_bn2_beta, stage3_unit3_bn2_moving_mean, stage3_unit3_bn2_moving_var), ((1, 256, 14, 14), (256,), (256,)), dtype=(\"float32\", \"float32\", \"float32\"))\n", | |
" gv132: Tensor((1, 256, 14, 14), \"float32\") = gv131[0]\n", | |
" gv133 = relax.call_tir(relu5, (gv132,), (1, 256, 14, 14), dtype=\"float32\")\n", | |
" gv134 = relax.call_tir(conv2d_nchw11, (gv133, stage3_unit3_conv2_weight), (1, 256, 14, 14), dtype=\"float32\")\n", | |
" gv135 = relax.call_tir(batch_norm6, (gv134, stage3_unit3_bn3_gamma, stage3_unit3_bn3_beta, stage3_unit3_bn3_moving_mean, stage3_unit3_bn3_moving_var), ((1, 256, 14, 14), (256,), (256,)), dtype=(\"float32\", \"float32\", \"float32\"))\n", | |
" gv136: Tensor((1, 256, 14, 14), \"float32\") = gv135[0]\n", | |
" gv137 = relax.call_tir(relu5, (gv136,), (1, 256, 14, 14), dtype=\"float32\")\n", | |
" gv138 = relax.call_tir(conv2d_nchw12, (gv137, stage3_unit3_conv3_weight), (1, 1024, 14, 14), dtype=\"float32\")\n", | |
" gv139 = relax.call_tir(add2, (gv138, gv126), (1, 1024, 14, 14), dtype=\"float32\")\n", | |
" gv140 = relax.call_tir(batch_norm7, (gv139, stage3_unit4_bn1_gamma, stage3_unit4_bn1_beta, stage3_unit4_bn1_moving_mean, stage3_unit4_bn1_moving_var), ((1, 1024, 14, 14), (1024,), (1024,)), dtype=(\"float32\", \"float32\", \"float32\"))\n", | |
" gv141: Tensor((1, 1024, 14, 14), \"float32\") = gv140[0]\n", | |
" gv142 = relax.call_tir(relu6, (gv141,), (1, 1024, 14, 14), dtype=\"float32\")\n", | |
" gv143 = relax.call_tir(conv2d_nchw14, (gv142, stage3_unit4_conv1_weight), (1, 256, 14, 14), dtype=\"float32\")\n", | |
" gv144 = relax.call_tir(batch_norm6, (gv143, stage3_unit4_bn2_gamma, stage3_unit4_bn2_beta, stage3_unit4_bn2_moving_mean, stage3_unit4_bn2_moving_var), ((1, 256, 14, 14), (256,), (256,)), dtype=(\"float32\", \"float32\", \"float32\"))\n", | |
" gv145: Tensor((1, 256, 14, 14), \"float32\") = gv144[0]\n", | |
" gv146 = relax.call_tir(relu5, (gv145,), (1, 256, 14, 14), dtype=\"float32\")\n", | |
" gv147 = relax.call_tir(conv2d_nchw11, (gv146, stage3_unit4_conv2_weight), (1, 256, 14, 14), dtype=\"float32\")\n", | |
" gv148 = relax.call_tir(batch_norm6, (gv147, stage3_unit4_bn3_gamma, stage3_unit4_bn3_beta, stage3_unit4_bn3_moving_mean, stage3_unit4_bn3_moving_var), ((1, 256, 14, 14), (256,), (256,)), dtype=(\"float32\", \"float32\", \"float32\"))\n", | |
" gv149: Tensor((1, 256, 14, 14), \"float32\") = gv148[0]\n", | |
" gv150 = relax.call_tir(relu5, (gv149,), (1, 256, 14, 14), dtype=\"float32\")\n", | |
" gv151 = relax.call_tir(conv2d_nchw12, (gv150, stage3_unit4_conv3_weight), (1, 1024, 14, 14), dtype=\"float32\")\n", | |
" gv152 = relax.call_tir(add2, (gv151, gv139), (1, 1024, 14, 14), dtype=\"float32\")\n", | |
" gv153 = relax.call_tir(batch_norm7, (gv152, stage3_unit5_bn1_gamma, stage3_unit5_bn1_beta, stage3_unit5_bn1_moving_mean, stage3_unit5_bn1_moving_var), ((1, 1024, 14, 14), (1024,), (1024,)), dtype=(\"float32\", \"float32\", \"float32\"))\n", | |
" gv154: Tensor((1, 1024, 14, 14), \"float32\") = gv153[0]\n", | |
" gv155 = relax.call_tir(relu6, (gv154,), (1, 1024, 14, 14), dtype=\"float32\")\n", | |
" gv156 = relax.call_tir(conv2d_nchw14, (gv155, stage3_unit5_conv1_weight), (1, 256, 14, 14), dtype=\"float32\")\n", | |
" gv157 = relax.call_tir(batch_norm6, (gv156, stage3_unit5_bn2_gamma, stage3_unit5_bn2_beta, stage3_unit5_bn2_moving_mean, stage3_unit5_bn2_moving_var), ((1, 256, 14, 14), (256,), (256,)), dtype=(\"float32\", \"float32\", \"float32\"))\n", | |
" gv158: Tensor((1, 256, 14, 14), \"float32\") = gv157[0]\n", | |
" gv159 = relax.call_tir(relu5, (gv158,), (1, 256, 14, 14), dtype=\"float32\")\n", | |
" gv160 = relax.call_tir(conv2d_nchw11, (gv159, stage3_unit5_conv2_weight), (1, 256, 14, 14), dtype=\"float32\")\n", | |
" gv161 = relax.call_tir(batch_norm6, (gv160, stage3_unit5_bn3_gamma, stage3_unit5_bn3_beta, stage3_unit5_bn3_moving_mean, stage3_unit5_bn3_moving_var), ((1, 256, 14, 14), (256,), (256,)), dtype=(\"float32\", \"float32\", \"float32\"))\n", | |
" gv162: Tensor((1, 256, 14, 14), \"float32\") = gv161[0]\n", | |
" gv163 = relax.call_tir(relu5, (gv162,), (1, 256, 14, 14), dtype=\"float32\")\n", | |
" gv164 = relax.call_tir(conv2d_nchw12, (gv163, stage3_unit5_conv3_weight), (1, 1024, 14, 14), dtype=\"float32\")\n", | |
" gv165 = relax.call_tir(add2, (gv164, gv152), (1, 1024, 14, 14), dtype=\"float32\")\n", | |
" gv166 = relax.call_tir(batch_norm7, (gv165, stage3_unit6_bn1_gamma, stage3_unit6_bn1_beta, stage3_unit6_bn1_moving_mean, stage3_unit6_bn1_moving_var), ((1, 1024, 14, 14), (1024,), (1024,)), dtype=(\"float32\", \"float32\", \"float32\"))\n", | |
" gv167: Tensor((1, 1024, 14, 14), \"float32\") = gv166[0]\n", | |
" gv168 = relax.call_tir(relu6, (gv167,), (1, 1024, 14, 14), dtype=\"float32\")\n", | |
" gv169 = relax.call_tir(conv2d_nchw14, (gv168, stage3_unit6_conv1_weight), (1, 256, 14, 14), dtype=\"float32\")\n", | |
" gv170 = relax.call_tir(batch_norm6, (gv169, stage3_unit6_bn2_gamma, stage3_unit6_bn2_beta, stage3_unit6_bn2_moving_mean, stage3_unit6_bn2_moving_var), ((1, 256, 14, 14), (256,), (256,)), dtype=(\"float32\", \"float32\", \"float32\"))\n", | |
" gv171: Tensor((1, 256, 14, 14), \"float32\") = gv170[0]\n", | |
" gv172 = relax.call_tir(relu5, (gv171,), (1, 256, 14, 14), dtype=\"float32\")\n", | |
" gv173 = relax.call_tir(conv2d_nchw11, (gv172, stage3_unit6_conv2_weight), (1, 256, 14, 14), dtype=\"float32\")\n", | |
" gv174 = relax.call_tir(batch_norm6, (gv173, stage3_unit6_bn3_gamma, stage3_unit6_bn3_beta, stage3_unit6_bn3_moving_mean, stage3_unit6_bn3_moving_var), ((1, 256, 14, 14), (256,), (256,)), dtype=(\"float32\", \"float32\", \"float32\"))\n", | |
" gv175: Tensor((1, 256, 14, 14), \"float32\") = gv174[0]\n", | |
" gv176 = relax.call_tir(relu5, (gv175,), (1, 256, 14, 14), dtype=\"float32\")\n", | |
" gv177 = relax.call_tir(conv2d_nchw12, (gv176, stage3_unit6_conv3_weight), (1, 1024, 14, 14), dtype=\"float32\")\n", | |
" gv178 = relax.call_tir(add2, (gv177, gv165), (1, 1024, 14, 14), dtype=\"float32\")\n", | |
" gv179 = relax.call_tir(batch_norm7, (gv178, stage4_unit1_bn1_gamma, stage4_unit1_bn1_beta, stage4_unit1_bn1_moving_mean, stage4_unit1_bn1_moving_var), ((1, 1024, 14, 14), (1024,), (1024,)), dtype=(\"float32\", \"float32\", \"float32\"))\n", | |
" gv180: Tensor((1, 1024, 14, 14), \"float32\") = gv179[0]\n", | |
" gv181 = relax.call_tir(relu6, (gv180,), (1, 1024, 14, 14), dtype=\"float32\")\n", | |
" gv182 = relax.call_tir(conv2d_nchw15, (gv181, stage4_unit1_conv1_weight), (1, 512, 7, 7), dtype=\"float32\")\n", | |
" gv183 = relax.call_tir(batch_norm8, (gv182, stage4_unit1_bn2_gamma, stage4_unit1_bn2_beta, stage4_unit1_bn2_moving_mean, stage4_unit1_bn2_moving_var), ((1, 512, 7, 7), (512,), (512,)), dtype=(\"float32\", \"float32\", \"float32\"))\n", | |
" gv184: Tensor((1, 512, 7, 7), \"float32\") = gv183[0]\n", | |
" gv185 = relax.call_tir(relu7, (gv184,), (1, 512, 7, 7), dtype=\"float32\")\n", | |
" gv186 = relax.call_tir(conv2d_nchw16, (gv185, stage4_unit1_conv2_weight), (1, 512, 7, 7), dtype=\"float32\")\n", | |
" gv187 = relax.call_tir(batch_norm8, (gv186, stage4_unit1_bn3_gamma, stage4_unit1_bn3_beta, stage4_unit1_bn3_moving_mean, stage4_unit1_bn3_moving_var), ((1, 512, 7, 7), (512,), (512,)), dtype=(\"float32\", \"float32\", \"float32\"))\n", | |
" gv188: Tensor((1, 512, 7, 7), \"float32\") = gv187[0]\n", | |
" gv189 = relax.call_tir(relu7, (gv188,), (1, 512, 7, 7), dtype=\"float32\")\n", | |
" gv190 = relax.call_tir(conv2d_nchw17, (gv189, stage4_unit1_conv3_weight), (1, 2048, 7, 7), dtype=\"float32\")\n", | |
" gv191 = relax.call_tir(conv2d_nchw18, (gv181, stage4_unit1_sc_weight), (1, 2048, 7, 7), dtype=\"float32\")\n", | |
" gv192 = relax.call_tir(add3, (gv190, gv191), (1, 2048, 7, 7), dtype=\"float32\")\n", | |
" gv193 = relax.call_tir(batch_norm9, (gv192, stage4_unit2_bn1_gamma, stage4_unit2_bn1_beta, stage4_unit2_bn1_moving_mean, stage4_unit2_bn1_moving_var), ((1, 2048, 7, 7), (2048,), (2048,)), dtype=(\"float32\", \"float32\", \"float32\"))\n", | |
" gv194: Tensor((1, 2048, 7, 7), \"float32\") = gv193[0]\n", | |
" gv195 = relax.call_tir(relu8, (gv194,), (1, 2048, 7, 7), dtype=\"float32\")\n", | |
" gv196 = relax.call_tir(conv2d_nchw19, (gv195, stage4_unit2_conv1_weight), (1, 512, 7, 7), dtype=\"float32\")\n", | |
" gv197 = relax.call_tir(batch_norm8, (gv196, stage4_unit2_bn2_gamma, stage4_unit2_bn2_beta, stage4_unit2_bn2_moving_mean, stage4_unit2_bn2_moving_var), ((1, 512, 7, 7), (512,), (512,)), dtype=(\"float32\", \"float32\", \"float32\"))\n", | |
" gv198: Tensor((1, 512, 7, 7), \"float32\") = gv197[0]\n", | |
" gv199 = relax.call_tir(relu7, (gv198,), (1, 512, 7, 7), dtype=\"float32\")\n", | |
" gv200 = relax.call_tir(conv2d_nchw16, (gv199, stage4_unit2_conv2_weight), (1, 512, 7, 7), dtype=\"float32\")\n", | |
" gv201 = relax.call_tir(batch_norm8, (gv200, stage4_unit2_bn3_gamma, stage4_unit2_bn3_beta, stage4_unit2_bn3_moving_mean, stage4_unit2_bn3_moving_var), ((1, 512, 7, 7), (512,), (512,)), dtype=(\"float32\", \"float32\", \"float32\"))\n", | |
" gv202: Tensor((1, 512, 7, 7), \"float32\") = gv201[0]\n", | |
" gv203 = relax.call_tir(relu7, (gv202,), (1, 512, 7, 7), dtype=\"float32\")\n", | |
" gv204 = relax.call_tir(conv2d_nchw17, (gv203, stage4_unit2_conv3_weight), (1, 2048, 7, 7), dtype=\"float32\")\n", | |
" gv205 = relax.call_tir(add3, (gv204, gv192), (1, 2048, 7, 7), dtype=\"float32\")\n", | |
" gv206 = relax.call_tir(batch_norm9, (gv205, stage4_unit3_bn1_gamma, stage4_unit3_bn1_beta, stage4_unit3_bn1_moving_mean, stage4_unit3_bn1_moving_var), ((1, 2048, 7, 7), (2048,), (2048,)), dtype=(\"float32\", \"float32\", \"float32\"))\n", | |
" gv207: Tensor((1, 2048, 7, 7), \"float32\") = gv206[0]\n", | |
" gv208 = relax.call_tir(relu8, (gv207,), (1, 2048, 7, 7), dtype=\"float32\")\n", | |
" gv209 = relax.call_tir(conv2d_nchw19, (gv208, stage4_unit3_conv1_weight), (1, 512, 7, 7), dtype=\"float32\")\n", | |
" gv210 = relax.call_tir(batch_norm8, (gv209, stage4_unit3_bn2_gamma, stage4_unit3_bn2_beta, stage4_unit3_bn2_moving_mean, stage4_unit3_bn2_moving_var), ((1, 512, 7, 7), (512,), (512,)), dtype=(\"float32\", \"float32\", \"float32\"))\n", | |
" gv211: Tensor((1, 512, 7, 7), \"float32\") = gv210[0]\n", | |
" gv212 = relax.call_tir(relu7, (gv211,), (1, 512, 7, 7), dtype=\"float32\")\n", | |
" gv213 = relax.call_tir(conv2d_nchw16, (gv212, stage4_unit3_conv2_weight), (1, 512, 7, 7), dtype=\"float32\")\n", | |
" gv214 = relax.call_tir(batch_norm8, (gv213, stage4_unit3_bn3_gamma, stage4_unit3_bn3_beta, stage4_unit3_bn3_moving_mean, stage4_unit3_bn3_moving_var), ((1, 512, 7, 7), (512,), (512,)), dtype=(\"float32\", \"float32\", \"float32\"))\n", | |
" gv215: Tensor((1, 512, 7, 7), \"float32\") = gv214[0]\n", | |
" gv216 = relax.call_tir(relu7, (gv215,), (1, 512, 7, 7), dtype=\"float32\")\n", | |
" gv217 = relax.call_tir(conv2d_nchw17, (gv216, stage4_unit3_conv3_weight), (1, 2048, 7, 7), dtype=\"float32\")\n", | |
" gv218 = relax.call_tir(add3, (gv217, gv205), (1, 2048, 7, 7), dtype=\"float32\")\n", | |
" gv219 = relax.call_tir(batch_norm9, (gv218, bn1_gamma, bn1_beta, bn1_moving_mean, bn1_moving_var), ((1, 2048, 7, 7), (2048,), (2048,)), dtype=(\"float32\", \"float32\", \"float32\"))\n", | |
" gv220: Tensor((1, 2048, 7, 7), \"float32\") = gv219[0]\n", | |
" gv221 = relax.call_tir(relu8, (gv220,), (1, 2048, 7, 7), dtype=\"float32\")\n", | |
" gv222 = relax.call_tir(global_avg_pool2d, (gv221,), (1, 2048, 1, 1), dtype=\"float32\")\n", | |
" gv223 = relax.call_tir(batch_flatten, (gv222,), (1, 2048), dtype=\"float32\")\n", | |
" gv224 = relax.call_tir(dense, (gv223, fc1_weight), (1, 1000), dtype=\"float32\")\n", | |
" gv225 = relax.call_tir(bias_add, (gv224, fc1_bias), (1, 1000), dtype=\"float32\")\n", | |
" gv226 = relax.call_tir(softmax, (gv225,), (1, 1000), dtype=\"float32\")\n", | |
" return gv226\n", | |
" \n", | |
" @tir.prim_func\n", | |
" def add1(rxplaceholder_2: tir.Buffer[(1, 512, 28, 28), \"float32\"], rxplaceholder_3: tir.Buffer[(1, 512, 28, 28), \"float32\"], T_add_1: tir.Buffer[(1, 512, 28, 28), \"float32\"]) -> None:\n", | |
" # function attr dict\n", | |
" tir.func_attr({\"global_symbol\": \"add1\", \"tir.noalias\": True})\n", | |
" # body\n", | |
" # with tir.block(\"root\")\n", | |
" for i0, i1, i2, i3 in tir.grid(1, 512, 28, 28):\n", | |
" with tir.block(\"T_add\"):\n", | |
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n", | |
" tir.reads(rxplaceholder_2[ax0, ax1, ax2, ax3], rxplaceholder_3[ax0, ax1, ax2, ax3])\n", | |
" tir.writes(T_add_1[ax0, ax1, ax2, ax3])\n", | |
" T_add_1[ax0, ax1, ax2, ax3] = rxplaceholder_2[ax0, ax1, ax2, ax3] + rxplaceholder_3[ax0, ax1, ax2, ax3]\n", | |
" \n", | |
" @tir.prim_func\n", | |
" def conv2d_nchw3(rxplaceholder_2: tir.Buffer[(1, 64, 56, 56), \"float32\"], rxplaceholder_3: tir.Buffer[(256, 64, 1, 1), \"float32\"], conv2d_nchw_1: tir.Buffer[(1, 256, 56, 56), \"float32\"]) -> None:\n", | |
" # function attr dict\n", | |
" tir.func_attr({\"global_symbol\": \"conv2d_nchw3\", \"tir.noalias\": True})\n", | |
" # body\n", | |
" # with tir.block(\"root\")\n", | |
" pad_temp_1 = tir.alloc_buffer([1, 64, 56, 56], dtype=\"float32\")\n", | |
" for i0, i1, i2, i3 in tir.grid(1, 64, 56, 56):\n", | |
" with tir.block(\"pad_temp\"):\n", | |
" i0_2, i1_2, i2_2, i3_2 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n", | |
" tir.reads(rxplaceholder_2[i0_2, i1_2, i2_2, i3_2])\n", | |
" tir.writes(pad_temp_1[i0_2, i1_2, i2_2, i3_2])\n", | |
" pad_temp_1[i0_2, i1_2, i2_2, i3_2] = rxplaceholder_2[i0_2, i1_2, i2_2, i3_2]\n", | |
" for i0_3, i1_3, i2_3, i3_3, i4, i5, i6 in tir.grid(1, 256, 56, 56, 64, 1, 1):\n", | |
" with tir.block(\"conv2d_nchw\"):\n", | |
" nn, ff, yy, xx, rc, ry, rx = tir.axis.remap(\"SSSSRRR\", [i0_3, i1_3, i2_3, i3_3, i4, i5, i6])\n", | |
" tir.reads(pad_temp_1[nn, rc, yy + ry, xx + rx], rxplaceholder_3[ff, rc, ry, rx])\n", | |
" tir.writes(conv2d_nchw_1[nn, ff, yy, xx])\n", | |
" with tir.init():\n", | |
" conv2d_nchw_1[nn, ff, yy, xx] = tir.float32(0)\n", | |
" conv2d_nchw_1[nn, ff, yy, xx] = conv2d_nchw_1[nn, ff, yy, xx] + pad_temp_1[nn, rc, yy + ry, xx + rx] * rxplaceholder_3[ff, rc, ry, rx]\n", | |
" \n", | |
" @tir.prim_func\n", | |
" def conv2d_nchw19(rxplaceholder_2: tir.Buffer[(1, 2048, 7, 7), \"float32\"], rxplaceholder_3: tir.Buffer[(512, 2048, 1, 1), \"float32\"], conv2d_nchw_1: tir.Buffer[(1, 512, 7, 7), \"float32\"]) -> None:\n", | |
" # function attr dict\n", | |
" tir.func_attr({\"global_symbol\": \"conv2d_nchw19\", \"tir.noalias\": True})\n", | |
" # body\n", | |
" # with tir.block(\"root\")\n", | |
" pad_temp_1 = tir.alloc_buffer([1, 2048, 7, 7], dtype=\"float32\")\n", | |
" for i0, i1, i2, i3 in tir.grid(1, 2048, 7, 7):\n", | |
" with tir.block(\"pad_temp\"):\n", | |
" i0_2, i1_2, i2_2, i3_2 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n", | |
" tir.reads(rxplaceholder_2[i0_2, i1_2, i2_2, i3_2])\n", | |
" tir.writes(pad_temp_1[i0_2, i1_2, i2_2, i3_2])\n", | |
" pad_temp_1[i0_2, i1_2, i2_2, i3_2] = rxplaceholder_2[i0_2, i1_2, i2_2, i3_2]\n", | |
" for i0_3, i1_3, i2_3, i3_3, i4, i5, i6 in tir.grid(1, 512, 7, 7, 2048, 1, 1):\n", | |
" with tir.block(\"conv2d_nchw\"):\n", | |
" nn, ff, yy, xx, rc, ry, rx = tir.axis.remap(\"SSSSRRR\", [i0_3, i1_3, i2_3, i3_3, i4, i5, i6])\n", | |
" tir.reads(pad_temp_1[nn, rc, yy + ry, xx + rx], rxplaceholder_3[ff, rc, ry, rx])\n", | |
" tir.writes(conv2d_nchw_1[nn, ff, yy, xx])\n", | |
" with tir.init():\n", | |
" conv2d_nchw_1[nn, ff, yy, xx] = tir.float32(0)\n", | |
" conv2d_nchw_1[nn, ff, yy, xx] = conv2d_nchw_1[nn, ff, yy, xx] + pad_temp_1[nn, rc, yy + ry, xx + rx] * rxplaceholder_3[ff, rc, ry, rx]\n", | |
" \n", | |
" @tir.prim_func\n", | |
" def softmax(rxplaceholder_1: tir.Buffer[(1, 1000), \"float32\"], T_softmax_norm_1: tir.Buffer[(1, 1000), \"float32\"]) -> None:\n", | |
" # function attr dict\n", | |
" tir.func_attr({\"global_symbol\": \"softmax\", \"tir.noalias\": True})\n", | |
" # body\n", | |
" # with tir.block(\"root\")\n", | |
" T_softmax_maxelem_1 = tir.alloc_buffer([1], dtype=\"float32\")\n", | |
" T_softmax_exp_1 = tir.alloc_buffer([1, 1000], dtype=\"float32\")\n", | |
" T_softmax_expsum_1 = tir.alloc_buffer([1], dtype=\"float32\")\n", | |
" for i0_7, i1_3 in tir.grid(1, 1000):\n", | |
" with tir.block(\"T_softmax_maxelem\"):\n", | |
" i0_8, k = tir.axis.remap(\"SR\", [i0_7, i1_3])\n", | |
" tir.reads(rxplaceholder_1[i0_8, k])\n", | |
" tir.writes(T_softmax_maxelem_1[i0_8])\n", | |
" with tir.init():\n", | |
" T_softmax_maxelem_1[i0_8] = tir.float32(-3.4028234663852886e+38)\n", | |
" T_softmax_maxelem_1[i0_8] = tir.max(T_softmax_maxelem_1[i0_8], rxplaceholder_1[i0_8, k])\n", | |
" for i0_9, i1_4 in tir.grid(1, 1000):\n", | |
" with tir.block(\"T_softmax_exp\"):\n", | |
" i0_10, i1_5 = tir.axis.remap(\"SS\", [i0_9, i1_4])\n", | |
" tir.reads(rxplaceholder_1[i0_10, i1_5], T_softmax_maxelem_1[i0_10])\n", | |
" tir.writes(T_softmax_exp_1[i0_10, i1_5])\n", | |
" T_softmax_exp_1[i0_10, i1_5] = tir.exp(rxplaceholder_1[i0_10, i1_5] - T_softmax_maxelem_1[i0_10], dtype=\"float32\")\n", | |
" for i0_11, i1_6 in tir.grid(1, 1000):\n", | |
" with tir.block(\"T_softmax_expsum\"):\n", | |
" i0_12, k = tir.axis.remap(\"SR\", [i0_11, i1_6])\n", | |
" tir.reads(T_softmax_exp_1[i0_12, k])\n", | |
" tir.writes(T_softmax_expsum_1[i0_12])\n", | |
" with tir.init():\n", | |
" T_softmax_expsum_1[i0_12] = tir.float32(0)\n", | |
" T_softmax_expsum_1[i0_12] = T_softmax_expsum_1[i0_12] + T_softmax_exp_1[i0_12, k]\n", | |
" for i0_13, i1_7 in tir.grid(1, 1000):\n", | |
" with tir.block(\"T_softmax_norm\"):\n", | |
" i0_14, i1_8 = tir.axis.remap(\"SS\", [i0_13, i1_7])\n", | |
" tir.reads(T_softmax_exp_1[i0_14, i1_8], T_softmax_expsum_1[i0_14])\n", | |
" tir.writes(T_softmax_norm_1[i0_14, i1_8])\n", | |
" tir.block_attr({\"axis\":1})\n", | |
" T_softmax_norm_1[i0_14, i1_8] = T_softmax_exp_1[i0_14, i1_8] / T_softmax_expsum_1[i0_14]\n", | |
" \n", | |
" @tir.prim_func\n", | |
" def batch_norm9(rxplaceholder_5: tir.Buffer[(1, 2048, 7, 7), \"float32\"], rxplaceholder_6: tir.Buffer[(2048,), \"float32\"], rxplaceholder_7: tir.Buffer[(2048,), \"float32\"], rxplaceholder_8: tir.Buffer[(2048,), \"float32\"], rxplaceholder_9: tir.Buffer[(2048,), \"float32\"], T_add_2: tir.Buffer[(1, 2048, 7, 7), \"float32\"], T_multiply_3: tir.Buffer[(2048,), \"float32\"], T_multiply_4: tir.Buffer[(2048,), \"float32\"]) -> None:\n", | |
" # function attr dict\n", | |
" tir.func_attr({\"global_symbol\": \"batch_norm9\", \"tir.noalias\": True})\n", | |
" # body\n", | |
" # with tir.block(\"root\")\n", | |
" T_reshape_4 = tir.alloc_buffer([1, 2048, 1, 1], dtype=\"float32\")\n", | |
" T_subtract_1 = tir.alloc_buffer([1, 2048, 7, 7], dtype=\"float32\")\n", | |
" T_reshape_5 = tir.alloc_buffer([1, 2048, 1, 1], dtype=\"float32\")\n", | |
" T_add_3 = tir.alloc_buffer([1, 2048, 1, 1], dtype=\"float32\")\n", | |
" compute_1 = tir.alloc_buffer([1, 2048, 1, 1], dtype=\"float32\")\n", | |
" T_divide_1 = tir.alloc_buffer([1, 2048, 7, 7], dtype=\"float32\")\n", | |
" T_reshape_6 = tir.alloc_buffer([1, 2048, 1, 1], dtype=\"float32\")\n", | |
" T_multiply_5 = tir.alloc_buffer([1, 2048, 7, 7], dtype=\"float32\")\n", | |
" T_reshape_7 = tir.alloc_buffer([1, 2048, 1, 1], dtype=\"float32\")\n", | |
" for i0, i1, i2, i3 in tir.grid(1, 2048, 1, 1):\n", | |
" with tir.block(\"T_reshape\"):\n", | |
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n", | |
" tir.reads(rxplaceholder_8[(ax1 + ax2 + ax3) % 2048])\n", | |
" tir.writes(T_reshape_4[ax0, ax1, ax2, ax3])\n", | |
" T_reshape_4[ax0, ax1, ax2, ax3] = rxplaceholder_8[(ax1 + ax2 + ax3) % 2048]\n", | |
" for i0, i1, i2, i3 in tir.grid(1, 2048, 7, 7):\n", | |
" with tir.block(\"T_subtract\"):\n", | |
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n", | |
" tir.reads(rxplaceholder_5[ax0, ax1, ax2, ax3], T_reshape_4[ax0, ax1, 0, 0])\n", | |
" tir.writes(T_subtract_1[ax0, ax1, ax2, ax3])\n", | |
" T_subtract_1[ax0, ax1, ax2, ax3] = rxplaceholder_5[ax0, ax1, ax2, ax3] - T_reshape_4[ax0, ax1, 0, 0]\n", | |
" for i0, i1, i2, i3 in tir.grid(1, 2048, 1, 1):\n", | |
" with tir.block(\"T_reshape_1\"):\n", | |
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n", | |
" tir.reads(rxplaceholder_9[(ax1 + ax2 + ax3) % 2048])\n", | |
" tir.writes(T_reshape_5[ax0, ax1, ax2, ax3])\n", | |
" T_reshape_5[ax0, ax1, ax2, ax3] = rxplaceholder_9[(ax1 + ax2 + ax3) % 2048]\n", | |
" for i0, i1, i2, i3 in tir.grid(1, 2048, 1, 1):\n", | |
" with tir.block(\"T_add\"):\n", | |
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n", | |
" tir.reads(T_reshape_5[ax0, ax1, ax2, ax3])\n", | |
" tir.writes(T_add_3[ax0, ax1, ax2, ax3])\n", | |
" T_add_3[ax0, ax1, ax2, ax3] = T_reshape_5[ax0, ax1, ax2, ax3] + tir.float32(1.9999999494757503e-05)\n", | |
" for i0, i1, i2, i3 in tir.grid(1, 2048, 1, 1):\n", | |
" with tir.block(\"compute\"):\n", | |
" i0_2, i1_2, i2_2, i3_2 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n", | |
" tir.reads(T_add_3[i0_2, i1_2, i2_2, i3_2])\n", | |
" tir.writes(compute_1[i0_2, i1_2, i2_2, i3_2])\n", | |
" compute_1[i0_2, i1_2, i2_2, i3_2] = tir.sqrt(T_add_3[i0_2, i1_2, i2_2, i3_2], dtype=\"float32\")\n", | |
" for i0_3, i1_3, i2_3, i3_3 in tir.grid(1, 2048, 7, 7):\n", | |
" with tir.block(\"T_divide\"):\n", | |
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0_3, i1_3, i2_3, i3_3])\n", | |
" tir.reads(T_subtract_1[ax0, ax1, ax2, ax3], compute_1[ax0, ax1, 0, 0])\n", | |
" tir.writes(T_divide_1[ax0, ax1, ax2, ax3])\n", | |
" T_divide_1[ax0, ax1, ax2, ax3] = T_subtract_1[ax0, ax1, ax2, ax3] / compute_1[ax0, ax1, 0, 0]\n", | |
" for i0_4, i1_4, i2_4, i3_4 in tir.grid(1, 2048, 1, 1):\n", | |
" with tir.block(\"T_reshape_2\"):\n", | |
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0_4, i1_4, i2_4, i3_4])\n", | |
" tir.reads(rxplaceholder_6[(ax1 + ax2 + ax3) % 2048])\n", | |
" tir.writes(T_reshape_6[ax0, ax1, ax2, ax3])\n", | |
" T_reshape_6[ax0, ax1, ax2, ax3] = rxplaceholder_6[(ax1 + ax2 + ax3) % 2048]\n", | |
" for i0_5, i1_5, i2_5, i3_5 in tir.grid(1, 2048, 7, 7):\n", | |
" with tir.block(\"T_multiply\"):\n", | |
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0_5, i1_5, i2_5, i3_5])\n", | |
" tir.reads(T_divide_1[ax0, ax1, ax2, ax3], T_reshape_6[ax0, ax1, 0, 0])\n", | |
" tir.writes(T_multiply_5[ax0, ax1, ax2, ax3])\n", | |
" T_multiply_5[ax0, ax1, ax2, ax3] = T_divide_1[ax0, ax1, ax2, ax3] * T_reshape_6[ax0, ax1, 0, 0]\n", | |
" for i0_6, i1_6, i2_6, i3_6 in tir.grid(1, 2048, 1, 1):\n", | |
" with tir.block(\"T_reshape_3\"):\n", | |
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0_6, i1_6, i2_6, i3_6])\n", | |
" tir.reads(rxplaceholder_7[(ax1 + ax2 + ax3) % 2048])\n", | |
" tir.writes(T_reshape_7[ax0, ax1, ax2, ax3])\n", | |
" T_reshape_7[ax0, ax1, ax2, ax3] = rxplaceholder_7[(ax1 + ax2 + ax3) % 2048]\n", | |
" for i0_7, i1_7, i2_7, i3_7 in tir.grid(1, 2048, 7, 7):\n", | |
" with tir.block(\"T_add_1\"):\n", | |
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0_7, i1_7, i2_7, i3_7])\n", | |
" tir.reads(T_multiply_5[ax0, ax1, ax2, ax3], T_reshape_7[ax0, ax1, 0, 0])\n", | |
" tir.writes(T_add_2[ax0, ax1, ax2, ax3])\n", | |
" T_add_2[ax0, ax1, ax2, ax3] = T_multiply_5[ax0, ax1, ax2, ax3] + T_reshape_7[ax0, ax1, 0, 0]\n", | |
" for i0_8 in tir.serial(2048):\n", | |
" with tir.block(\"T_multiply_1\"):\n", | |
" ax0 = tir.axis.spatial(2048, i0_8)\n", | |
" tir.reads(rxplaceholder_8[ax0])\n", | |
" tir.writes(T_multiply_3[ax0])\n", | |
" T_multiply_3[ax0] = rxplaceholder_8[ax0]\n", | |
" for i0_9 in tir.serial(2048):\n", | |
" with tir.block(\"T_multiply_2\"):\n", | |
" ax0 = tir.axis.spatial(2048, i0_9)\n", | |
" tir.reads(rxplaceholder_9[ax0])\n", | |
" tir.writes(T_multiply_4[ax0])\n", | |
" T_multiply_4[ax0] = rxplaceholder_9[ax0]\n", | |
" \n", | |
" @tir.prim_func\n", | |
" def batch_norm6(rxplaceholder_5: tir.Buffer[(1, 256, 14, 14), \"float32\"], rxplaceholder_6: tir.Buffer[(256,), \"float32\"], rxplaceholder_7: tir.Buffer[(256,), \"float32\"], rxplaceholder_8: tir.Buffer[(256,), \"float32\"], rxplaceholder_9: tir.Buffer[(256,), \"float32\"], T_add_2: tir.Buffer[(1, 256, 14, 14), \"float32\"], T_multiply_3: tir.Buffer[(256,), \"float32\"], T_multiply_4: tir.Buffer[(256,), \"float32\"]) -> None:\n", | |
" # function attr dict\n", | |
" tir.func_attr({\"global_symbol\": \"batch_norm6\", \"tir.noalias\": True})\n", | |
" # body\n", | |
" # with tir.block(\"root\")\n", | |
" T_reshape_4 = tir.alloc_buffer([1, 256, 1, 1], dtype=\"float32\")\n", | |
" T_subtract_1 = tir.alloc_buffer([1, 256, 14, 14], dtype=\"float32\")\n", | |
" T_reshape_5 = tir.alloc_buffer([1, 256, 1, 1], dtype=\"float32\")\n", | |
" T_add_3 = tir.alloc_buffer([1, 256, 1, 1], dtype=\"float32\")\n", | |
" compute_1 = tir.alloc_buffer([1, 256, 1, 1], dtype=\"float32\")\n", | |
" T_divide_1 = tir.alloc_buffer([1, 256, 14, 14], dtype=\"float32\")\n", | |
" T_reshape_6 = tir.alloc_buffer([1, 256, 1, 1], dtype=\"float32\")\n", | |
" T_multiply_5 = tir.alloc_buffer([1, 256, 14, 14], dtype=\"float32\")\n", | |
" T_reshape_7 = tir.alloc_buffer([1, 256, 1, 1], dtype=\"float32\")\n", | |
" for i0, i1, i2, i3 in tir.grid(1, 256, 1, 1):\n", | |
" with tir.block(\"T_reshape\"):\n", | |
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n", | |
" tir.reads(rxplaceholder_8[(ax1 + ax2 + ax3) % 256])\n", | |
" tir.writes(T_reshape_4[ax0, ax1, ax2, ax3])\n", | |
" T_reshape_4[ax0, ax1, ax2, ax3] = rxplaceholder_8[(ax1 + ax2 + ax3) % 256]\n", | |
" for i0, i1, i2, i3 in tir.grid(1, 256, 14, 14):\n", | |
" with tir.block(\"T_subtract\"):\n", | |
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n", | |
" tir.reads(rxplaceholder_5[ax0, ax1, ax2, ax3], T_reshape_4[ax0, ax1, 0, 0])\n", | |
" tir.writes(T_subtract_1[ax0, ax1, ax2, ax3])\n", | |
" T_subtract_1[ax0, ax1, ax2, ax3] = rxplaceholder_5[ax0, ax1, ax2, ax3] - T_reshape_4[ax0, ax1, 0, 0]\n", | |
" for i0, i1, i2, i3 in tir.grid(1, 256, 1, 1):\n", | |
" with tir.block(\"T_reshape_1\"):\n", | |
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n", | |
" tir.reads(rxplaceholder_9[(ax1 + ax2 + ax3) % 256])\n", | |
" tir.writes(T_reshape_5[ax0, ax1, ax2, ax3])\n", | |
" T_reshape_5[ax0, ax1, ax2, ax3] = rxplaceholder_9[(ax1 + ax2 + ax3) % 256]\n", | |
" for i0, i1, i2, i3 in tir.grid(1, 256, 1, 1):\n", | |
" with tir.block(\"T_add\"):\n", | |
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n", | |
" tir.reads(T_reshape_5[ax0, ax1, ax2, ax3])\n", | |
" tir.writes(T_add_3[ax0, ax1, ax2, ax3])\n", | |
" T_add_3[ax0, ax1, ax2, ax3] = T_reshape_5[ax0, ax1, ax2, ax3] + tir.float32(1.9999999494757503e-05)\n", | |
" for i0, i1, i2, i3 in tir.grid(1, 256, 1, 1):\n", | |
" with tir.block(\"compute\"):\n", | |
" i0_2, i1_2, i2_2, i3_2 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n", | |
" tir.reads(T_add_3[i0_2, i1_2, i2_2, i3_2])\n", | |
" tir.writes(compute_1[i0_2, i1_2, i2_2, i3_2])\n", | |
" compute_1[i0_2, i1_2, i2_2, i3_2] = tir.sqrt(T_add_3[i0_2, i1_2, i2_2, i3_2], dtype=\"float32\")\n", | |
" for i0_3, i1_3, i2_3, i3_3 in tir.grid(1, 256, 14, 14):\n", | |
" with tir.block(\"T_divide\"):\n", | |
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0_3, i1_3, i2_3, i3_3])\n", | |
" tir.reads(T_subtract_1[ax0, ax1, ax2, ax3], compute_1[ax0, ax1, 0, 0])\n", | |
" tir.writes(T_divide_1[ax0, ax1, ax2, ax3])\n", | |
" T_divide_1[ax0, ax1, ax2, ax3] = T_subtract_1[ax0, ax1, ax2, ax3] / compute_1[ax0, ax1, 0, 0]\n", | |
" for i0_4, i1_4, i2_4, i3_4 in tir.grid(1, 256, 1, 1):\n", | |
" with tir.block(\"T_reshape_2\"):\n", | |
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0_4, i1_4, i2_4, i3_4])\n", | |
" tir.reads(rxplaceholder_6[(ax1 + ax2 + ax3) % 256])\n", | |
" tir.writes(T_reshape_6[ax0, ax1, ax2, ax3])\n", | |
" T_reshape_6[ax0, ax1, ax2, ax3] = rxplaceholder_6[(ax1 + ax2 + ax3) % 256]\n", | |
" for i0_5, i1_5, i2_5, i3_5 in tir.grid(1, 256, 14, 14):\n", | |
" with tir.block(\"T_multiply\"):\n", | |
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0_5, i1_5, i2_5, i3_5])\n", | |
" tir.reads(T_divide_1[ax0, ax1, ax2, ax3], T_reshape_6[ax0, ax1, 0, 0])\n", | |
" tir.writes(T_multiply_5[ax0, ax1, ax2, ax3])\n", | |
" T_multiply_5[ax0, ax1, ax2, ax3] = T_divide_1[ax0, ax1, ax2, ax3] * T_reshape_6[ax0, ax1, 0, 0]\n", | |
" for i0_6, i1_6, i2_6, i3_6 in tir.grid(1, 256, 1, 1):\n", | |
" with tir.block(\"T_reshape_3\"):\n", | |
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0_6, i1_6, i2_6, i3_6])\n", | |
" tir.reads(rxplaceholder_7[(ax1 + ax2 + ax3) % 256])\n", | |
" tir.writes(T_reshape_7[ax0, ax1, ax2, ax3])\n", | |
" T_reshape_7[ax0, ax1, ax2, ax3] = rxplaceholder_7[(ax1 + ax2 + ax3) % 256]\n", | |
" for i0_7, i1_7, i2_7, i3_7 in tir.grid(1, 256, 14, 14):\n", | |
" with tir.block(\"T_add_1\"):\n", | |
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0_7, i1_7, i2_7, i3_7])\n", | |
" tir.reads(T_multiply_5[ax0, ax1, ax2, ax3], T_reshape_7[ax0, ax1, 0, 0])\n", | |
" tir.writes(T_add_2[ax0, ax1, ax2, ax3])\n", | |
" T_add_2[ax0, ax1, ax2, ax3] = T_multiply_5[ax0, ax1, ax2, ax3] + T_reshape_7[ax0, ax1, 0, 0]\n", | |
" for i0_8 in tir.serial(256):\n", | |
" with tir.block(\"T_multiply_1\"):\n", | |
" ax0 = tir.axis.spatial(256, i0_8)\n", | |
" tir.reads(rxplaceholder_8[ax0])\n", | |
" tir.writes(T_multiply_3[ax0])\n", | |
" T_multiply_3[ax0] = rxplaceholder_8[ax0]\n", | |
" for i0_9 in tir.serial(256):\n", | |
" with tir.block(\"T_multiply_2\"):\n", | |
" ax0 = tir.axis.spatial(256, i0_9)\n", | |
" tir.reads(rxplaceholder_9[ax0])\n", | |
" tir.writes(T_multiply_4[ax0])\n", | |
" T_multiply_4[ax0] = rxplaceholder_9[ax0]\n", | |
" \n", | |
" @tir.prim_func\n", | |
" def conv2d_nchw7(rxplaceholder_2: tir.Buffer[(1, 128, 28, 28), \"float32\"], rxplaceholder_3: tir.Buffer[(512, 128, 1, 1), \"float32\"], conv2d_nchw_1: tir.Buffer[(1, 512, 28, 28), \"float32\"]) -> None:\n", | |
" # function attr dict\n", | |
" tir.func_attr({\"global_symbol\": \"conv2d_nchw7\", \"tir.noalias\": True})\n", | |
" # body\n", | |
" # with tir.block(\"root\")\n", | |
" pad_temp_1 = tir.alloc_buffer([1, 128, 28, 28], dtype=\"float32\")\n", | |
" for i0, i1, i2, i3 in tir.grid(1, 128, 28, 28):\n", | |
" with tir.block(\"pad_temp\"):\n", | |
" i0_2, i1_2, i2_2, i3_2 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n", | |
" tir.reads(rxplaceholder_2[i0_2, i1_2, i2_2, i3_2])\n", | |
" tir.writes(pad_temp_1[i0_2, i1_2, i2_2, i3_2])\n", | |
" pad_temp_1[i0_2, i1_2, i2_2, i3_2] = rxplaceholder_2[i0_2, i1_2, i2_2, i3_2]\n", | |
" for i0_3, i1_3, i2_3, i3_3, i4, i5, i6 in tir.grid(1, 512, 28, 28, 128, 1, 1):\n", | |
" with tir.block(\"conv2d_nchw\"):\n", | |
" nn, ff, yy, xx, rc, ry, rx = tir.axis.remap(\"SSSSRRR\", [i0_3, i1_3, i2_3, i3_3, i4, i5, i6])\n", | |
" tir.reads(pad_temp_1[nn, rc, yy + ry, xx + rx], rxplaceholder_3[ff, rc, ry, rx])\n", | |
" tir.writes(conv2d_nchw_1[nn, ff, yy, xx])\n", | |
" with tir.init():\n", | |
" conv2d_nchw_1[nn, ff, yy, xx] = tir.float32(0)\n", | |
" conv2d_nchw_1[nn, ff, yy, xx] = conv2d_nchw_1[nn, ff, yy, xx] + pad_temp_1[nn, rc, yy + ry, xx + rx] * rxplaceholder_3[ff, rc, ry, rx]\n", | |
" \n", | |
" @tir.prim_func\n", | |
" def max_pool2d(rxplaceholder_1: tir.Buffer[(1, 64, 112, 112), \"float32\"], tensor_1: tir.Buffer[(1, 64, 56, 56), \"float32\"]) -> None:\n", | |
" # function attr dict\n", | |
" tir.func_attr({\"global_symbol\": \"max_pool2d\", \"tir.noalias\": True})\n", | |
" # body\n", | |
" # with tir.block(\"root\")\n", | |
" pad_temp_1 = tir.alloc_buffer([1, 64, 114, 114], dtype=\"float32\")\n", | |
" for i0, i1, i2, i3 in tir.grid(1, 64, 114, 114):\n", | |
" with tir.block(\"pad_temp\"):\n", | |
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n", | |
" tir.reads(rxplaceholder_1[ax0, ax1, ax2 - 1, ax3 - 1])\n", | |
" tir.writes(pad_temp_1[ax0, ax1, ax2, ax3])\n", | |
" pad_temp_1[ax0, ax1, ax2, ax3] = tir.if_then_else(1 <= ax2 and ax2 < 113 and 1 <= ax3 and ax3 < 113, rxplaceholder_1[ax0, ax1, ax2 - 1, ax3 - 1], tir.float32(-3.4028234663852886e+38), dtype=\"float32\")\n", | |
" for i0, i1, i2, i3, i4, i5 in tir.grid(1, 64, 56, 56, 3, 3):\n", | |
" with tir.block(\"tensor\"):\n", | |
" ax0, ax1, ax2, ax3, rv0, rv1 = tir.axis.remap(\"SSSSRR\", [i0, i1, i2, i3, i4, i5])\n", | |
" tir.reads(pad_temp_1[ax0, ax1, ax2 * 2 + rv0, ax3 * 2 + rv1])\n", | |
" tir.writes(tensor_1[ax0, ax1, ax2, ax3])\n", | |
" with tir.init():\n", | |
" tensor_1[ax0, ax1, ax2, ax3] = tir.float32(-3.4028234663852886e+38)\n", | |
" tensor_1[ax0, ax1, ax2, ax3] = tir.max(tensor_1[ax0, ax1, ax2, ax3], pad_temp_1[ax0, ax1, ax2 * 2 + rv0, ax3 * 2 + rv1])\n", | |
" \n", | |
" @tir.prim_func\n", | |
" def conv2d_nchw6(rxplaceholder_2: tir.Buffer[(1, 128, 28, 28), \"float32\"], rxplaceholder_3: tir.Buffer[(128, 128, 3, 3), \"float32\"], conv2d_nchw_1: tir.Buffer[(1, 128, 28, 28), \"float32\"]) -> None:\n", | |
" # function attr dict\n", | |
" tir.func_attr({\"global_symbol\": \"conv2d_nchw6\", \"tir.noalias\": True})\n", | |
" # body\n", | |
" # with tir.block(\"root\")\n", | |
" pad_temp_1 = tir.alloc_buffer([1, 128, 30, 30], dtype=\"float32\")\n", | |
" for i0, i1, i2, i3 in tir.grid(1, 128, 30, 30):\n", | |
" with tir.block(\"pad_temp\"):\n", | |
" i0_2, i1_2, i2_2, i3_2 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n", | |
" tir.reads(rxplaceholder_2[i0_2, i1_2, i2_2 - 1, i3_2 - 1])\n", | |
" tir.writes(pad_temp_1[i0_2, i1_2, i2_2, i3_2])\n", | |
" pad_temp_1[i0_2, i1_2, i2_2, i3_2] = tir.if_then_else(1 <= i2_2 and i2_2 < 29 and 1 <= i3_2 and i3_2 < 29, rxplaceholder_2[i0_2, i1_2, i2_2 - 1, i3_2 - 1], tir.float32(0), dtype=\"float32\")\n", | |
" for i0_3, i1_3, i2_3, i3_3, i4, i5, i6 in tir.grid(1, 128, 28, 28, 128, 3, 3):\n", | |
" with tir.block(\"conv2d_nchw\"):\n", | |
" nn, ff, yy, xx, rc, ry, rx = tir.axis.remap(\"SSSSRRR\", [i0_3, i1_3, i2_3, i3_3, i4, i5, i6])\n", | |
" tir.reads(pad_temp_1[nn, rc, yy + ry, xx + rx], rxplaceholder_3[ff, rc, ry, rx])\n", | |
" tir.writes(conv2d_nchw_1[nn, ff, yy, xx])\n", | |
" with tir.init():\n", | |
" conv2d_nchw_1[nn, ff, yy, xx] = tir.float32(0)\n", | |
" conv2d_nchw_1[nn, ff, yy, xx] = conv2d_nchw_1[nn, ff, yy, xx] + pad_temp_1[nn, rc, yy + ry, xx + rx] * rxplaceholder_3[ff, rc, ry, rx]\n", | |
" \n", | |
" @tir.prim_func\n", | |
" def conv2d_nchw1(rxplaceholder_2: tir.Buffer[(1, 64, 56, 56), \"float32\"], rxplaceholder_3: tir.Buffer[(64, 64, 1, 1), \"float32\"], conv2d_nchw_1: tir.Buffer[(1, 64, 56, 56), \"float32\"]) -> None:\n", | |
" # function attr dict\n", | |
" tir.func_attr({\"global_symbol\": \"conv2d_nchw1\", \"tir.noalias\": True})\n", | |
" # body\n", | |
" # with tir.block(\"root\")\n", | |
" pad_temp_1 = tir.alloc_buffer([1, 64, 56, 56], dtype=\"float32\")\n", | |
" for i0, i1, i2, i3 in tir.grid(1, 64, 56, 56):\n", | |
" with tir.block(\"pad_temp\"):\n", | |
" i0_2, i1_2, i2_2, i3_2 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n", | |
" tir.reads(rxplaceholder_2[i0_2, i1_2, i2_2, i3_2])\n", | |
" tir.writes(pad_temp_1[i0_2, i1_2, i2_2, i3_2])\n", | |
" pad_temp_1[i0_2, i1_2, i2_2, i3_2] = rxplaceholder_2[i0_2, i1_2, i2_2, i3_2]\n", | |
" for i0_3, i1_3, i2_3, i3_3, i4, i5, i6 in tir.grid(1, 64, 56, 56, 64, 1, 1):\n", | |
" with tir.block(\"conv2d_nchw\"):\n", | |
" nn, ff, yy, xx, rc, ry, rx = tir.axis.remap(\"SSSSRRR\", [i0_3, i1_3, i2_3, i3_3, i4, i5, i6])\n", | |
" tir.reads(pad_temp_1[nn, rc, yy + ry, xx + rx], rxplaceholder_3[ff, rc, ry, rx])\n", | |
" tir.writes(conv2d_nchw_1[nn, ff, yy, xx])\n", | |
" with tir.init():\n", | |
" conv2d_nchw_1[nn, ff, yy, xx] = tir.float32(0)\n", | |
" conv2d_nchw_1[nn, ff, yy, xx] = conv2d_nchw_1[nn, ff, yy, xx] + pad_temp_1[nn, rc, yy + ry, xx + rx] * rxplaceholder_3[ff, rc, ry, rx]\n", | |
" \n", | |
" @tir.prim_func\n", | |
" def batch_norm5(rxplaceholder_5: tir.Buffer[(1, 512, 28, 28), \"float32\"], rxplaceholder_6: tir.Buffer[(512,), \"float32\"], rxplaceholder_7: tir.Buffer[(512,), \"float32\"], rxplaceholder_8: tir.Buffer[(512,), \"float32\"], rxplaceholder_9: tir.Buffer[(512,), \"float32\"], T_add_2: tir.Buffer[(1, 512, 28, 28), \"float32\"], T_multiply_3: tir.Buffer[(512,), \"float32\"], T_multiply_4: tir.Buffer[(512,), \"float32\"]) -> None:\n", | |
" # function attr dict\n", | |
" tir.func_attr({\"global_symbol\": \"batch_norm5\", \"tir.noalias\": True})\n", | |
" # body\n", | |
" # with tir.block(\"root\")\n", | |
" T_reshape_4 = tir.alloc_buffer([1, 512, 1, 1], dtype=\"float32\")\n", | |
" T_subtract_1 = tir.alloc_buffer([1, 512, 28, 28], dtype=\"float32\")\n", | |
" T_reshape_5 = tir.alloc_buffer([1, 512, 1, 1], dtype=\"float32\")\n", | |
" T_add_3 = tir.alloc_buffer([1, 512, 1, 1], dtype=\"float32\")\n", | |
" compute_1 = tir.alloc_buffer([1, 512, 1, 1], dtype=\"float32\")\n", | |
" T_divide_1 = tir.alloc_buffer([1, 512, 28, 28], dtype=\"float32\")\n", | |
" T_reshape_6 = tir.alloc_buffer([1, 512, 1, 1], dtype=\"float32\")\n", | |
" T_multiply_5 = tir.alloc_buffer([1, 512, 28, 28], dtype=\"float32\")\n", | |
" T_reshape_7 = tir.alloc_buffer([1, 512, 1, 1], dtype=\"float32\")\n", | |
" for i0, i1, i2, i3 in tir.grid(1, 512, 1, 1):\n", | |
" with tir.block(\"T_reshape\"):\n", | |
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n", | |
" tir.reads(rxplaceholder_8[(ax1 + ax2 + ax3) % 512])\n", | |
" tir.writes(T_reshape_4[ax0, ax1, ax2, ax3])\n", | |
" T_reshape_4[ax0, ax1, ax2, ax3] = rxplaceholder_8[(ax1 + ax2 + ax3) % 512]\n", | |
" for i0, i1, i2, i3 in tir.grid(1, 512, 28, 28):\n", | |
" with tir.block(\"T_subtract\"):\n", | |
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n", | |
" tir.reads(rxplaceholder_5[ax0, ax1, ax2, ax3], T_reshape_4[ax0, ax1, 0, 0])\n", | |
" tir.writes(T_subtract_1[ax0, ax1, ax2, ax3])\n", | |
" T_subtract_1[ax0, ax1, ax2, ax3] = rxplaceholder_5[ax0, ax1, ax2, ax3] - T_reshape_4[ax0, ax1, 0, 0]\n", | |
" for i0, i1, i2, i3 in tir.grid(1, 512, 1, 1):\n", | |
" with tir.block(\"T_reshape_1\"):\n", | |
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n", | |
" tir.reads(rxplaceholder_9[(ax1 + ax2 + ax3) % 512])\n", | |
" tir.writes(T_reshape_5[ax0, ax1, ax2, ax3])\n", | |
" T_reshape_5[ax0, ax1, ax2, ax3] = rxplaceholder_9[(ax1 + ax2 + ax3) % 512]\n", | |
" for i0, i1, i2, i3 in tir.grid(1, 512, 1, 1):\n", | |
" with tir.block(\"T_add\"):\n", | |
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n", | |
" tir.reads(T_reshape_5[ax0, ax1, ax2, ax3])\n", | |
" tir.writes(T_add_3[ax0, ax1, ax2, ax3])\n", | |
" T_add_3[ax0, ax1, ax2, ax3] = T_reshape_5[ax0, ax1, ax2, ax3] + tir.float32(1.9999999494757503e-05)\n", | |
" for i0, i1, i2, i3 in tir.grid(1, 512, 1, 1):\n", | |
" with tir.block(\"compute\"):\n", | |
" i0_2, i1_2, i2_2, i3_2 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n", | |
" tir.reads(T_add_3[i0_2, i1_2, i2_2, i3_2])\n", | |
" tir.writes(compute_1[i0_2, i1_2, i2_2, i3_2])\n", | |
" compute_1[i0_2, i1_2, i2_2, i3_2] = tir.sqrt(T_add_3[i0_2, i1_2, i2_2, i3_2], dtype=\"float32\")\n", | |
" for i0_3, i1_3, i2_3, i3_3 in tir.grid(1, 512, 28, 28):\n", | |
" with tir.block(\"T_divide\"):\n", | |
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0_3, i1_3, i2_3, i3_3])\n", | |
" tir.reads(T_subtract_1[ax0, ax1, ax2, ax3], compute_1[ax0, ax1, 0, 0])\n", | |
" tir.writes(T_divide_1[ax0, ax1, ax2, ax3])\n", | |
" T_divide_1[ax0, ax1, ax2, ax3] = T_subtract_1[ax0, ax1, ax2, ax3] / compute_1[ax0, ax1, 0, 0]\n", | |
" for i0_4, i1_4, i2_4, i3_4 in tir.grid(1, 512, 1, 1):\n", | |
" with tir.block(\"T_reshape_2\"):\n", | |
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0_4, i1_4, i2_4, i3_4])\n", | |
" tir.reads(rxplaceholder_6[(ax1 + ax2 + ax3) % 512])\n", | |
" tir.writes(T_reshape_6[ax0, ax1, ax2, ax3])\n", | |
" T_reshape_6[ax0, ax1, ax2, ax3] = rxplaceholder_6[(ax1 + ax2 + ax3) % 512]\n", | |
" for i0_5, i1_5, i2_5, i3_5 in tir.grid(1, 512, 28, 28):\n", | |
" with tir.block(\"T_multiply\"):\n", | |
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0_5, i1_5, i2_5, i3_5])\n", | |
" tir.reads(T_divide_1[ax0, ax1, ax2, ax3], T_reshape_6[ax0, ax1, 0, 0])\n", | |
" tir.writes(T_multiply_5[ax0, ax1, ax2, ax3])\n", | |
" T_multiply_5[ax0, ax1, ax2, ax3] = T_divide_1[ax0, ax1, ax2, ax3] * T_reshape_6[ax0, ax1, 0, 0]\n", | |
" for i0_6, i1_6, i2_6, i3_6 in tir.grid(1, 512, 1, 1):\n", | |
" with tir.block(\"T_reshape_3\"):\n", | |
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0_6, i1_6, i2_6, i3_6])\n", | |
" tir.reads(rxplaceholder_7[(ax1 + ax2 + ax3) % 512])\n", | |
" tir.writes(T_reshape_7[ax0, ax1, ax2, ax3])\n", | |
" T_reshape_7[ax0, ax1, ax2, ax3] = rxplaceholder_7[(ax1 + ax2 + ax3) % 512]\n", | |
" for i0_7, i1_7, i2_7, i3_7 in tir.grid(1, 512, 28, 28):\n", | |
" with tir.block(\"T_add_1\"):\n", | |
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0_7, i1_7, i2_7, i3_7])\n", | |
" tir.reads(T_multiply_5[ax0, ax1, ax2, ax3], T_reshape_7[ax0, ax1, 0, 0])\n", | |
" tir.writes(T_add_2[ax0, ax1, ax2, ax3])\n", | |
" T_add_2[ax0, ax1, ax2, ax3] = T_multiply_5[ax0, ax1, ax2, ax3] + T_reshape_7[ax0, ax1, 0, 0]\n", | |
" for i0_8 in tir.serial(512):\n", | |
" with tir.block(\"T_multiply_1\"):\n", | |
" ax0 = tir.axis.spatial(512, i0_8)\n", | |
" tir.reads(rxplaceholder_8[ax0])\n", | |
" tir.writes(T_multiply_3[ax0])\n", | |
" T_multiply_3[ax0] = rxplaceholder_8[ax0]\n", | |
" for i0_9 in tir.serial(512):\n", | |
" with tir.block(\"T_multiply_2\"):\n", | |
" ax0 = tir.axis.spatial(512, i0_9)\n", | |
" tir.reads(rxplaceholder_9[ax0])\n", | |
" tir.writes(T_multiply_4[ax0])\n", | |
" T_multiply_4[ax0] = rxplaceholder_9[ax0]\n", | |
" \n", | |
" @tir.prim_func\n", | |
" def batch_norm3(rxplaceholder_5: tir.Buffer[(1, 256, 56, 56), \"float32\"], rxplaceholder_6: tir.Buffer[(256,), \"float32\"], rxplaceholder_7: tir.Buffer[(256,), \"float32\"], rxplaceholder_8: tir.Buffer[(256,), \"float32\"], rxplaceholder_9: tir.Buffer[(256,), \"float32\"], T_add_2: tir.Buffer[(1, 256, 56, 56), \"float32\"], T_multiply_3: tir.Buffer[(256,), \"float32\"], T_multiply_4: tir.Buffer[(256,), \"float32\"]) -> None:\n", | |
" # function attr dict\n", | |
" tir.func_attr({\"global_symbol\": \"batch_norm3\", \"tir.noalias\": True})\n", | |
" # body\n", | |
" # with tir.block(\"root\")\n", | |
" T_reshape_4 = tir.alloc_buffer([1, 256, 1, 1], dtype=\"float32\")\n", | |
" T_subtract_1 = tir.alloc_buffer([1, 256, 56, 56], dtype=\"float32\")\n", | |
" T_reshape_5 = tir.alloc_buffer([1, 256, 1, 1], dtype=\"float32\")\n", | |
" T_add_3 = tir.alloc_buffer([1, 256, 1, 1], dtype=\"float32\")\n", | |
" compute_1 = tir.alloc_buffer([1, 256, 1, 1], dtype=\"float32\")\n", | |
" T_divide_1 = tir.alloc_buffer([1, 256, 56, 56], dtype=\"float32\")\n", | |
" T_reshape_6 = tir.alloc_buffer([1, 256, 1, 1], dtype=\"float32\")\n", | |
" T_multiply_5 = tir.alloc_buffer([1, 256, 56, 56], dtype=\"float32\")\n", | |
" T_reshape_7 = tir.alloc_buffer([1, 256, 1, 1], dtype=\"float32\")\n", | |
" for i0, i1, i2, i3 in tir.grid(1, 256, 1, 1):\n", | |
" with tir.block(\"T_reshape\"):\n", | |
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n", | |
" tir.reads(rxplaceholder_8[(ax1 + ax2 + ax3) % 256])\n", | |
" tir.writes(T_reshape_4[ax0, ax1, ax2, ax3])\n", | |
" T_reshape_4[ax0, ax1, ax2, ax3] = rxplaceholder_8[(ax1 + ax2 + ax3) % 256]\n", | |
" for i0, i1, i2, i3 in tir.grid(1, 256, 56, 56):\n", | |
" with tir.block(\"T_subtract\"):\n", | |
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n", | |
" tir.reads(rxplaceholder_5[ax0, ax1, ax2, ax3], T_reshape_4[ax0, ax1, 0, 0])\n", | |
" tir.writes(T_subtract_1[ax0, ax1, ax2, ax3])\n", | |
" T_subtract_1[ax0, ax1, ax2, ax3] = rxplaceholder_5[ax0, ax1, ax2, ax3] - T_reshape_4[ax0, ax1, 0, 0]\n", | |
" for i0, i1, i2, i3 in tir.grid(1, 256, 1, 1):\n", | |
" with tir.block(\"T_reshape_1\"):\n", | |
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n", | |
" tir.reads(rxplaceholder_9[(ax1 + ax2 + ax3) % 256])\n", | |
" tir.writes(T_reshape_5[ax0, ax1, ax2, ax3])\n", | |
" T_reshape_5[ax0, ax1, ax2, ax3] = rxplaceholder_9[(ax1 + ax2 + ax3) % 256]\n", | |
" for i0, i1, i2, i3 in tir.grid(1, 256, 1, 1):\n", | |
" with tir.block(\"T_add\"):\n", | |
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n", | |
" tir.reads(T_reshape_5[ax0, ax1, ax2, ax3])\n", | |
" tir.writes(T_add_3[ax0, ax1, ax2, ax3])\n", | |
" T_add_3[ax0, ax1, ax2, ax3] = T_reshape_5[ax0, ax1, ax2, ax3] + tir.float32(1.9999999494757503e-05)\n", | |
" for i0, i1, i2, i3 in tir.grid(1, 256, 1, 1):\n", | |
" with tir.block(\"compute\"):\n", | |
" i0_2, i1_2, i2_2, i3_2 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n", | |
" tir.reads(T_add_3[i0_2, i1_2, i2_2, i3_2])\n", | |
" tir.writes(compute_1[i0_2, i1_2, i2_2, i3_2])\n", | |
" compute_1[i0_2, i1_2, i2_2, i3_2] = tir.sqrt(T_add_3[i0_2, i1_2, i2_2, i3_2], dtype=\"float32\")\n", | |
" for i0_3, i1_3, i2_3, i3_3 in tir.grid(1, 256, 56, 56):\n", | |
" with tir.block(\"T_divide\"):\n", | |
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0_3, i1_3, i2_3, i3_3])\n", | |
" tir.reads(T_subtract_1[ax0, ax1, ax2, ax3], compute_1[ax0, ax1, 0, 0])\n", | |
" tir.writes(T_divide_1[ax0, ax1, ax2, ax3])\n", | |
" T_divide_1[ax0, ax1, ax2, ax3] = T_subtract_1[ax0, ax1, ax2, ax3] / compute_1[ax0, ax1, 0, 0]\n", | |
" for i0_4, i1_4, i2_4, i3_4 in tir.grid(1, 256, 1, 1):\n", | |
" with tir.block(\"T_reshape_2\"):\n", | |
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0_4, i1_4, i2_4, i3_4])\n", | |
" tir.reads(rxplaceholder_6[(ax1 + ax2 + ax3) % 256])\n", | |
" tir.writes(T_reshape_6[ax0, ax1, ax2, ax3])\n", | |
" T_reshape_6[ax0, ax1, ax2, ax3] = rxplaceholder_6[(ax1 + ax2 + ax3) % 256]\n", | |
" for i0_5, i1_5, i2_5, i3_5 in tir.grid(1, 256, 56, 56):\n", | |
" with tir.block(\"T_multiply\"):\n", | |
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0_5, i1_5, i2_5, i3_5])\n", | |
" tir.reads(T_divide_1[ax0, ax1, ax2, ax3], T_reshape_6[ax0, ax1, 0, 0])\n", | |
" tir.writes(T_multiply_5[ax0, ax1, ax2, ax3])\n", | |
" T_multiply_5[ax0, ax1, ax2, ax3] = T_divide_1[ax0, ax1, ax2, ax3] * T_reshape_6[ax0, ax1, 0, 0]\n", | |
" for i0_6, i1_6, i2_6, i3_6 in tir.grid(1, 256, 1, 1):\n", | |
" with tir.block(\"T_reshape_3\"):\n", | |
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0_6, i1_6, i2_6, i3_6])\n", | |
" tir.reads(rxplaceholder_7[(ax1 + ax2 + ax3) % 256])\n", | |
" tir.writes(T_reshape_7[ax0, ax1, ax2, ax3])\n", | |
" T_reshape_7[ax0, ax1, ax2, ax3] = rxplaceholder_7[(ax1 + ax2 + ax3) % 256]\n", | |
" for i0_7, i1_7, i2_7, i3_7 in tir.grid(1, 256, 56, 56):\n", | |
" with tir.block(\"T_add_1\"):\n", | |
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0_7, i1_7, i2_7, i3_7])\n", | |
" tir.reads(T_multiply_5[ax0, ax1, ax2, ax3], T_reshape_7[ax0, ax1, 0, 0])\n", | |
" tir.writes(T_add_2[ax0, ax1, ax2, ax3])\n", | |
" T_add_2[ax0, ax1, ax2, ax3] = T_multiply_5[ax0, ax1, ax2, ax3] + T_reshape_7[ax0, ax1, 0, 0]\n", | |
" for i0_8 in tir.serial(256):\n", | |
" with tir.block(\"T_multiply_1\"):\n", | |
" ax0 = tir.axis.spatial(256, i0_8)\n", | |
" tir.reads(rxplaceholder_8[ax0])\n", | |
" tir.writes(T_multiply_3[ax0])\n", | |
" T_multiply_3[ax0] = rxplaceholder_8[ax0]\n", | |
" for i0_9 in tir.serial(256):\n", | |
" with tir.block(\"T_multiply_2\"):\n", | |
" ax0 = tir.axis.spatial(256, i0_9)\n", | |
" tir.reads(rxplaceholder_9[ax0])\n", | |
" tir.writes(T_multiply_4[ax0])\n", | |
" T_multiply_4[ax0] = rxplaceholder_9[ax0]\n", | |
" \n", | |
" @tir.prim_func\n", | |
" def relu(rxplaceholder_1: tir.Buffer[(1, 64, 112, 112), \"float32\"], T_relu_1: tir.Buffer[(1, 64, 112, 112), \"float32\"]) -> None:\n", | |
" # function attr dict\n", | |
" tir.func_attr({\"global_symbol\": \"relu\", \"tir.noalias\": True})\n", | |
" # body\n", | |
" # with tir.block(\"root\")\n", | |
" for i0, i1, i2, i3 in tir.grid(1, 64, 112, 112):\n", | |
" with tir.block(\"T_relu\"):\n", | |
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n", | |
" tir.reads(rxplaceholder_1[ax0, ax1, ax2, ax3])\n", | |
" tir.writes(T_relu_1[ax0, ax1, ax2, ax3])\n", | |
" T_relu_1[ax0, ax1, ax2, ax3] = tir.max(rxplaceholder_1[ax0, ax1, ax2, ax3], tir.float32(0))\n", | |
" \n", | |
" @tir.prim_func\n", | |
" def conv2d_nchw10(rxplaceholder_2: tir.Buffer[(1, 512, 28, 28), \"float32\"], rxplaceholder_3: tir.Buffer[(256, 512, 1, 1), \"float32\"], conv2d_nchw_1: tir.Buffer[(1, 256, 14, 14), \"float32\"]) -> None:\n", | |
" # function attr dict\n", | |
" tir.func_attr({\"global_symbol\": \"conv2d_nchw10\", \"tir.noalias\": True})\n", | |
" # body\n", | |
" # with tir.block(\"root\")\n", | |
" pad_temp_1 = tir.alloc_buffer([1, 512, 28, 28], dtype=\"float32\")\n", | |
" for i0, i1, i2, i3 in tir.grid(1, 512, 28, 28):\n", | |
" with tir.block(\"pad_temp\"):\n", | |
" i0_2, i1_2, i2_2, i3_2 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n", | |
" tir.reads(rxplaceholder_2[i0_2, i1_2, i2_2, i3_2])\n", | |
" tir.writes(pad_temp_1[i0_2, i1_2, i2_2, i3_2])\n", | |
" pad_temp_1[i0_2, i1_2, i2_2, i3_2] = rxplaceholder_2[i0_2, i1_2, i2_2, i3_2]\n", | |
" for i0_3, i1_3, i2_3, i3_3, i4, i5, i6 in tir.grid(1, 256, 14, 14, 512, 1, 1):\n", | |
" with tir.block(\"conv2d_nchw\"):\n", | |
" nn, ff, yy, xx, rc, ry, rx = tir.axis.remap(\"SSSSRRR\", [i0_3, i1_3, i2_3, i3_3, i4, i5, i6])\n", | |
" tir.reads(pad_temp_1[nn, rc, yy * 2 + ry, xx * 2 + rx], rxplaceholder_3[ff, rc, ry, rx])\n", | |
" tir.writes(conv2d_nchw_1[nn, ff, yy, xx])\n", | |
" with tir.init():\n", | |
" conv2d_nchw_1[nn, ff, yy, xx] = tir.float32(0)\n", | |
" conv2d_nchw_1[nn, ff, yy, xx] = conv2d_nchw_1[nn, ff, yy, xx] + pad_temp_1[nn, rc, yy * 2 + ry, xx * 2 + rx] * rxplaceholder_3[ff, rc, ry, rx]\n", | |
" \n", | |
" @tir.prim_func\n", | |
" def conv2d_nchw5(rxplaceholder_2: tir.Buffer[(1, 256, 56, 56), \"float32\"], rxplaceholder_3: tir.Buffer[(128, 256, 1, 1), \"float32\"], conv2d_nchw_1: tir.Buffer[(1, 128, 28, 28), \"float32\"]) -> None:\n", | |
" # function attr dict\n", | |
" tir.func_attr({\"global_symbol\": \"conv2d_nchw5\", \"tir.noalias\": True})\n", | |
" # body\n", | |
" # with tir.block(\"root\")\n", | |
" pad_temp_1 = tir.alloc_buffer([1, 256, 56, 56], dtype=\"float32\")\n", | |
" for i0, i1, i2, i3 in tir.grid(1, 256, 56, 56):\n", | |
" with tir.block(\"pad_temp\"):\n", | |
" i0_2, i1_2, i2_2, i3_2 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n", | |
" tir.reads(rxplaceholder_2[i0_2, i1_2, i2_2, i3_2])\n", | |
" tir.writes(pad_temp_1[i0_2, i1_2, i2_2, i3_2])\n", | |
" pad_temp_1[i0_2, i1_2, i2_2, i3_2] = rxplaceholder_2[i0_2, i1_2, i2_2, i3_2]\n", | |
" for i0_3, i1_3, i2_3, i3_3, i4, i5, i6 in tir.grid(1, 128, 28, 28, 256, 1, 1):\n", | |
" with tir.block(\"conv2d_nchw\"):\n", | |
" nn, ff, yy, xx, rc, ry, rx = tir.axis.remap(\"SSSSRRR\", [i0_3, i1_3, i2_3, i3_3, i4, i5, i6])\n", | |
" tir.reads(pad_temp_1[nn, rc, yy * 2 + ry, xx * 2 + rx], rxplaceholder_3[ff, rc, ry, rx])\n", | |
" tir.writes(conv2d_nchw_1[nn, ff, yy, xx])\n", | |
" with tir.init():\n", | |
" conv2d_nchw_1[nn, ff, yy, xx] = tir.float32(0)\n", | |
" conv2d_nchw_1[nn, ff, yy, xx] = conv2d_nchw_1[nn, ff, yy, xx] + pad_temp_1[nn, rc, yy * 2 + ry, xx * 2 + rx] * rxplaceholder_3[ff, rc, ry, rx]\n", | |
" \n", | |
" @tir.prim_func\n", | |
" def relu5(rxplaceholder_1: tir.Buffer[(1, 256, 14, 14), \"float32\"], T_relu_1: tir.Buffer[(1, 256, 14, 14), \"float32\"]) -> None:\n", | |
" # function attr dict\n", | |
" tir.func_attr({\"global_symbol\": \"relu5\", \"tir.noalias\": True})\n", | |
" # body\n", | |
" # with tir.block(\"root\")\n", | |
" for i0, i1, i2, i3 in tir.grid(1, 256, 14, 14):\n", | |
" with tir.block(\"T_relu\"):\n", | |
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n", | |
" tir.reads(rxplaceholder_1[ax0, ax1, ax2, ax3])\n", | |
" tir.writes(T_relu_1[ax0, ax1, ax2, ax3])\n", | |
" T_relu_1[ax0, ax1, ax2, ax3] = tir.max(rxplaceholder_1[ax0, ax1, ax2, ax3], tir.float32(0))\n", | |
" \n", | |
" @tir.prim_func\n", | |
" def conv2d_nchw15(rxplaceholder_2: tir.Buffer[(1, 1024, 14, 14), \"float32\"], rxplaceholder_3: tir.Buffer[(512, 1024, 1, 1), \"float32\"], conv2d_nchw_1: tir.Buffer[(1, 512, 7, 7), \"float32\"]) -> None:\n", | |
" # function attr dict\n", | |
" tir.func_attr({\"global_symbol\": \"conv2d_nchw15\", \"tir.noalias\": True})\n", | |
" # body\n", | |
" # with tir.block(\"root\")\n", | |
" pad_temp_1 = tir.alloc_buffer([1, 1024, 14, 14], dtype=\"float32\")\n", | |
" for i0, i1, i2, i3 in tir.grid(1, 1024, 14, 14):\n", | |
" with tir.block(\"pad_temp\"):\n", | |
" i0_2, i1_2, i2_2, i3_2 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n", | |
" tir.reads(rxplaceholder_2[i0_2, i1_2, i2_2, i3_2])\n", | |
" tir.writes(pad_temp_1[i0_2, i1_2, i2_2, i3_2])\n", | |
" pad_temp_1[i0_2, i1_2, i2_2, i3_2] = rxplaceholder_2[i0_2, i1_2, i2_2, i3_2]\n", | |
" for i0_3, i1_3, i2_3, i3_3, i4, i5, i6 in tir.grid(1, 512, 7, 7, 1024, 1, 1):\n", | |
" with tir.block(\"conv2d_nchw\"):\n", | |
" nn, ff, yy, xx, rc, ry, rx = tir.axis.remap(\"SSSSRRR\", [i0_3, i1_3, i2_3, i3_3, i4, i5, i6])\n", | |
" tir.reads(pad_temp_1[nn, rc, yy * 2 + ry, xx * 2 + rx], rxplaceholder_3[ff, rc, ry, rx])\n", | |
" tir.writes(conv2d_nchw_1[nn, ff, yy, xx])\n", | |
" with tir.init():\n", | |
" conv2d_nchw_1[nn, ff, yy, xx] = tir.float32(0)\n", | |
" conv2d_nchw_1[nn, ff, yy, xx] = conv2d_nchw_1[nn, ff, yy, xx] + pad_temp_1[nn, rc, yy * 2 + ry, xx * 2 + rx] * rxplaceholder_3[ff, rc, ry, rx]\n", | |
" \n", | |
" @tir.prim_func\n", | |
" def conv2d_nchw9(rxplaceholder_2: tir.Buffer[(1, 512, 28, 28), \"float32\"], rxplaceholder_3: tir.Buffer[(128, 512, 1, 1), \"float32\"], conv2d_nchw_1: tir.Buffer[(1, 128, 28, 28), \"float32\"]) -> None:\n", | |
" # function attr dict\n", | |
" tir.func_attr({\"global_symbol\": \"conv2d_nchw9\", \"tir.noalias\": True})\n", | |
" # body\n", | |
" # with tir.block(\"root\")\n", | |
" pad_temp_1 = tir.alloc_buffer([1, 512, 28, 28], dtype=\"float32\")\n", | |
" for i0, i1, i2, i3 in tir.grid(1, 512, 28, 28):\n", | |
" with tir.block(\"pad_temp\"):\n", | |
" i0_2, i1_2, i2_2, i3_2 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n", | |
" tir.reads(rxplaceholder_2[i0_2, i1_2, i2_2, i3_2])\n", | |
" tir.writes(pad_temp_1[i0_2, i1_2, i2_2, i3_2])\n", | |
" pad_temp_1[i0_2, i1_2, i2_2, i3_2] = rxplaceholder_2[i0_2, i1_2, i2_2, i3_2]\n", | |
" for i0_3, i1_3, i2_3, i3_3, i4, i5, i6 in tir.grid(1, 128, 28, 28, 512, 1, 1):\n", | |
" with tir.block(\"conv2d_nchw\"):\n", | |
" nn, ff, yy, xx, rc, ry, rx = tir.axis.remap(\"SSSSRRR\", [i0_3, i1_3, i2_3, i3_3, i4, i5, i6])\n", | |
" tir.reads(pad_temp_1[nn, rc, yy + ry, xx + rx], rxplaceholder_3[ff, rc, ry, rx])\n", | |
" tir.writes(conv2d_nchw_1[nn, ff, yy, xx])\n", | |
" with tir.init():\n", | |
" conv2d_nchw_1[nn, ff, yy, xx] = tir.float32(0)\n", | |
" conv2d_nchw_1[nn, ff, yy, xx] = conv2d_nchw_1[nn, ff, yy, xx] + pad_temp_1[nn, rc, yy + ry, xx + rx] * rxplaceholder_3[ff, rc, ry, rx]\n", | |
" \n", | |
" @tir.prim_func\n", | |
" def conv2d_nchw17(rxplaceholder_2: tir.Buffer[(1, 512, 7, 7), \"float32\"], rxplaceholder_3: tir.Buffer[(2048, 512, 1, 1), \"float32\"], conv2d_nchw_1: tir.Buffer[(1, 2048, 7, 7), \"float32\"]) -> None:\n", | |
" # function attr dict\n", | |
" tir.func_attr({\"global_symbol\": \"conv2d_nchw17\", \"tir.noalias\": True})\n", | |
" # body\n", | |
" # with tir.block(\"root\")\n", | |
" pad_temp_1 = tir.alloc_buffer([1, 512, 7, 7], dtype=\"float32\")\n", | |
" for i0, i1, i2, i3 in tir.grid(1, 512, 7, 7):\n", | |
" with tir.block(\"pad_temp\"):\n", | |
" i0_2, i1_2, i2_2, i3_2 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n", | |
" tir.reads(rxplaceholder_2[i0_2, i1_2, i2_2, i3_2])\n", | |
" tir.writes(pad_temp_1[i0_2, i1_2, i2_2, i3_2])\n", | |
" pad_temp_1[i0_2, i1_2, i2_2, i3_2] = rxplaceholder_2[i0_2, i1_2, i2_2, i3_2]\n", | |
" for i0_3, i1_3, i2_3, i3_3, i4, i5, i6 in tir.grid(1, 2048, 7, 7, 512, 1, 1):\n", | |
" with tir.block(\"conv2d_nchw\"):\n", | |
" nn, ff, yy, xx, rc, ry, rx = tir.axis.remap(\"SSSSRRR\", [i0_3, i1_3, i2_3, i3_3, i4, i5, i6])\n", | |
" tir.reads(pad_temp_1[nn, rc, yy + ry, xx + rx], rxplaceholder_3[ff, rc, ry, rx])\n", | |
" tir.writes(conv2d_nchw_1[nn, ff, yy, xx])\n", | |
" with tir.init():\n", | |
" conv2d_nchw_1[nn, ff, yy, xx] = tir.float32(0)\n", | |
" conv2d_nchw_1[nn, ff, yy, xx] = conv2d_nchw_1[nn, ff, yy, xx] + pad_temp_1[nn, rc, yy + ry, xx + rx] * rxplaceholder_3[ff, rc, ry, rx]\n", | |
" \n", | |
" @tir.prim_func\n", | |
" def batch_norm(rxplaceholder_5: tir.Buffer[(1, 3, 224, 224), \"float32\"], rxplaceholder_6: tir.Buffer[(3,), \"float32\"], rxplaceholder_7: tir.Buffer[(3,), \"float32\"], rxplaceholder_8: tir.Buffer[(3,), \"float32\"], rxplaceholder_9: tir.Buffer[(3,), \"float32\"], T_add_2: tir.Buffer[(1, 3, 224, 224), \"float32\"], T_multiply_2: tir.Buffer[(3,), \"float32\"], T_multiply_3: tir.Buffer[(3,), \"float32\"]) -> None:\n", | |
" # function attr dict\n", | |
" tir.func_attr({\"global_symbol\": \"batch_norm\", \"tir.noalias\": True})\n", | |
" # body\n", | |
" # with tir.block(\"root\")\n", | |
" T_reshape_3 = tir.alloc_buffer([1, 3, 1, 1], dtype=\"float32\")\n", | |
" T_subtract_1 = tir.alloc_buffer([1, 3, 224, 224], dtype=\"float32\")\n", | |
" T_reshape_4 = tir.alloc_buffer([1, 3, 1, 1], dtype=\"float32\")\n", | |
" T_add_3 = tir.alloc_buffer([1, 3, 1, 1], dtype=\"float32\")\n", | |
" compute_1 = tir.alloc_buffer([1, 3, 1, 1], dtype=\"float32\")\n", | |
" T_divide_1 = tir.alloc_buffer([1, 3, 224, 224], dtype=\"float32\")\n", | |
" T_reshape_5 = tir.alloc_buffer([1, 3, 1, 1], dtype=\"float32\")\n", | |
" for i0, i1, i2, i3 in tir.grid(1, 3, 1, 1):\n", | |
" with tir.block(\"T_reshape\"):\n", | |
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n", | |
" tir.reads(rxplaceholder_8[(ax1 + ax2 + ax3) % 3])\n", | |
" tir.writes(T_reshape_3[ax0, ax1, ax2, ax3])\n", | |
" T_reshape_3[ax0, ax1, ax2, ax3] = rxplaceholder_8[(ax1 + ax2 + ax3) % 3]\n", | |
" for i0, i1, i2, i3 in tir.grid(1, 3, 224, 224):\n", | |
" with tir.block(\"T_subtract\"):\n", | |
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n", | |
" tir.reads(rxplaceholder_5[ax0, ax1, ax2, ax3], T_reshape_3[ax0, ax1, 0, 0])\n", | |
" tir.writes(T_subtract_1[ax0, ax1, ax2, ax3])\n", | |
" T_subtract_1[ax0, ax1, ax2, ax3] = rxplaceholder_5[ax0, ax1, ax2, ax3] - T_reshape_3[ax0, ax1, 0, 0]\n", | |
" for i0, i1, i2, i3 in tir.grid(1, 3, 1, 1):\n", | |
" with tir.block(\"T_reshape_1\"):\n", | |
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n", | |
" tir.reads(rxplaceholder_9[(ax1 + ax2 + ax3) % 3])\n", | |
" tir.writes(T_reshape_4[ax0, ax1, ax2, ax3])\n", | |
" T_reshape_4[ax0, ax1, ax2, ax3] = rxplaceholder_9[(ax1 + ax2 + ax3) % 3]\n", | |
" for i0, i1, i2, i3 in tir.grid(1, 3, 1, 1):\n", | |
" with tir.block(\"T_add\"):\n", | |
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n", | |
" tir.reads(T_reshape_4[ax0, ax1, ax2, ax3])\n", | |
" tir.writes(T_add_3[ax0, ax1, ax2, ax3])\n", | |
" T_add_3[ax0, ax1, ax2, ax3] = T_reshape_4[ax0, ax1, ax2, ax3] + tir.float32(1.9999999494757503e-05)\n", | |
" for i0, i1, i2, i3 in tir.grid(1, 3, 1, 1):\n", | |
" with tir.block(\"compute\"):\n", | |
" i0_2, i1_2, i2_2, i3_2 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n", | |
" tir.reads(T_add_3[i0_2, i1_2, i2_2, i3_2])\n", | |
" tir.writes(compute_1[i0_2, i1_2, i2_2, i3_2])\n", | |
" compute_1[i0_2, i1_2, i2_2, i3_2] = tir.sqrt(T_add_3[i0_2, i1_2, i2_2, i3_2], dtype=\"float32\")\n", | |
" for i0_3, i1_3, i2_3, i3_3 in tir.grid(1, 3, 224, 224):\n", | |
" with tir.block(\"T_divide\"):\n", | |
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0_3, i1_3, i2_3, i3_3])\n", | |
" tir.reads(T_subtract_1[ax0, ax1, ax2, ax3], compute_1[ax0, ax1, 0, 0])\n", | |
" tir.writes(T_divide_1[ax0, ax1, ax2, ax3])\n", | |
" T_divide_1[ax0, ax1, ax2, ax3] = T_subtract_1[ax0, ax1, ax2, ax3] / compute_1[ax0, ax1, 0, 0]\n", | |
" for i0_4, i1_4, i2_4, i3_4 in tir.grid(1, 3, 1, 1):\n", | |
" with tir.block(\"T_reshape_2\"):\n", | |
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0_4, i1_4, i2_4, i3_4])\n", | |
" tir.reads(rxplaceholder_7[(ax1 + ax2 + ax3) % 3])\n", | |
" tir.writes(T_reshape_5[ax0, ax1, ax2, ax3])\n", | |
" T_reshape_5[ax0, ax1, ax2, ax3] = rxplaceholder_7[(ax1 + ax2 + ax3) % 3]\n", | |
" for i0_5, i1_5, i2_5, i3_5 in tir.grid(1, 3, 224, 224):\n", | |
" with tir.block(\"T_add_1\"):\n", | |
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0_5, i1_5, i2_5, i3_5])\n", | |
" tir.reads(T_divide_1[ax0, ax1, ax2, ax3], T_reshape_5[ax0, ax1, 0, 0])\n", | |
" tir.writes(T_add_2[ax0, ax1, ax2, ax3])\n", | |
" T_add_2[ax0, ax1, ax2, ax3] = T_divide_1[ax0, ax1, ax2, ax3] + T_reshape_5[ax0, ax1, 0, 0]\n", | |
" for i0_6 in tir.serial(3):\n", | |
" with tir.block(\"T_multiply\"):\n", | |
" ax0 = tir.axis.spatial(3, i0_6)\n", | |
" tir.reads(rxplaceholder_8[ax0])\n", | |
" tir.writes(T_multiply_2[ax0])\n", | |
" T_multiply_2[ax0] = rxplaceholder_8[ax0]\n", | |
" for i0_7 in tir.serial(3):\n", | |
" with tir.block(\"T_multiply_1\"):\n", | |
" ax0 = tir.axis.spatial(3, i0_7)\n", | |
" tir.reads(rxplaceholder_9[ax0])\n", | |
" tir.writes(T_multiply_3[ax0])\n", | |
" T_multiply_3[ax0] = rxplaceholder_9[ax0]\n", | |
" \n", | |
" @tir.prim_func\n", | |
" def conv2d_nchw14(rxplaceholder_2: tir.Buffer[(1, 1024, 14, 14), \"float32\"], rxplaceholder_3: tir.Buffer[(256, 1024, 1, 1), \"float32\"], conv2d_nchw_1: tir.Buffer[(1, 256, 14, 14), \"float32\"]) -> None:\n", | |
" # function attr dict\n", | |
" tir.func_attr({\"global_symbol\": \"conv2d_nchw14\", \"tir.noalias\": True})\n", | |
" # body\n", | |
" # with tir.block(\"root\")\n", | |
" pad_temp_1 = tir.alloc_buffer([1, 1024, 14, 14], dtype=\"float32\")\n", | |
" for i0, i1, i2, i3 in tir.grid(1, 1024, 14, 14):\n", | |
" with tir.block(\"pad_temp\"):\n", | |
" i0_2, i1_2, i2_2, i3_2 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n", | |
" tir.reads(rxplaceholder_2[i0_2, i1_2, i2_2, i3_2])\n", | |
" tir.writes(pad_temp_1[i0_2, i1_2, i2_2, i3_2])\n", | |
" pad_temp_1[i0_2, i1_2, i2_2, i3_2] = rxplaceholder_2[i0_2, i1_2, i2_2, i3_2]\n", | |
" for i0_3, i1_3, i2_3, i3_3, i4, i5, i6 in tir.grid(1, 256, 14, 14, 1024, 1, 1):\n", | |
" with tir.block(\"conv2d_nchw\"):\n", | |
" nn, ff, yy, xx, rc, ry, rx = tir.axis.remap(\"SSSSRRR\", [i0_3, i1_3, i2_3, i3_3, i4, i5, i6])\n", | |
" tir.reads(pad_temp_1[nn, rc, yy + ry, xx + rx], rxplaceholder_3[ff, rc, ry, rx])\n", | |
" tir.writes(conv2d_nchw_1[nn, ff, yy, xx])\n", | |
" with tir.init():\n", | |
" conv2d_nchw_1[nn, ff, yy, xx] = tir.float32(0)\n", | |
" conv2d_nchw_1[nn, ff, yy, xx] = conv2d_nchw_1[nn, ff, yy, xx] + pad_temp_1[nn, rc, yy + ry, xx + rx] * rxplaceholder_3[ff, rc, ry, rx]\n", | |
" \n", | |
" @tir.prim_func\n", | |
" def relu8(rxplaceholder_1: tir.Buffer[(1, 2048, 7, 7), \"float32\"], T_relu_1: tir.Buffer[(1, 2048, 7, 7), \"float32\"]) -> None:\n", | |
" # function attr dict\n", | |
" tir.func_attr({\"global_symbol\": \"relu8\", \"tir.noalias\": True})\n", | |
" # body\n", | |
" # with tir.block(\"root\")\n", | |
" for i0, i1, i2, i3 in tir.grid(1, 2048, 7, 7):\n", | |
" with tir.block(\"T_relu\"):\n", | |
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n", | |
" tir.reads(rxplaceholder_1[ax0, ax1, ax2, ax3])\n", | |
" tir.writes(T_relu_1[ax0, ax1, ax2, ax3])\n", | |
" T_relu_1[ax0, ax1, ax2, ax3] = tir.max(rxplaceholder_1[ax0, ax1, ax2, ax3], tir.float32(0))\n", | |
" \n", | |
" @tir.prim_func\n", | |
" def relu4(rxplaceholder_1: tir.Buffer[(1, 512, 28, 28), \"float32\"], T_relu_1: tir.Buffer[(1, 512, 28, 28), \"float32\"]) -> None:\n", | |
" # function attr dict\n", | |
" tir.func_attr({\"global_symbol\": \"relu4\", \"tir.noalias\": True})\n", | |
" # body\n", | |
" # with tir.block(\"root\")\n", | |
" for i0, i1, i2, i3 in tir.grid(1, 512, 28, 28):\n", | |
" with tir.block(\"T_relu\"):\n", | |
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n", | |
" tir.reads(rxplaceholder_1[ax0, ax1, ax2, ax3])\n", | |
" tir.writes(T_relu_1[ax0, ax1, ax2, ax3])\n", | |
" T_relu_1[ax0, ax1, ax2, ax3] = tir.max(rxplaceholder_1[ax0, ax1, ax2, ax3], tir.float32(0))\n", | |
" \n", | |
" @tir.prim_func\n", | |
" def batch_flatten(rxplaceholder_1: tir.Buffer[(1, 2048, 1, 1), \"float32\"], tensor_1: tir.Buffer[(1, 2048), \"float32\"]) -> None:\n", | |
" # function attr dict\n", | |
" tir.func_attr({\"global_symbol\": \"batch_flatten\", \"tir.noalias\": True})\n", | |
" # body\n", | |
" # with tir.block(\"root\")\n", | |
" for i0, i1 in tir.grid(1, 2048):\n", | |
" with tir.block(\"tensor\"):\n", | |
" ax0, ax1 = tir.axis.remap(\"SS\", [i0, i1])\n", | |
" tir.reads(rxplaceholder_1[ax0, ax1 % 2048, 0, 0])\n", | |
" tir.writes(tensor_1[ax0, ax1])\n", | |
" tensor_1[ax0, ax1] = rxplaceholder_1[ax0, ax1 % 2048, 0, 0]\n", | |
" \n", | |
" @tir.prim_func\n", | |
" def relu6(rxplaceholder_1: tir.Buffer[(1, 1024, 14, 14), \"float32\"], T_relu_1: tir.Buffer[(1, 1024, 14, 14), \"float32\"]) -> None:\n", | |
" # function attr dict\n", | |
" tir.func_attr({\"global_symbol\": \"relu6\", \"tir.noalias\": True})\n", | |
" # body\n", | |
" # with tir.block(\"root\")\n", | |
" for i0, i1, i2, i3 in tir.grid(1, 1024, 14, 14):\n", | |
" with tir.block(\"T_relu\"):\n", | |
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n", | |
" tir.reads(rxplaceholder_1[ax0, ax1, ax2, ax3])\n", | |
" tir.writes(T_relu_1[ax0, ax1, ax2, ax3])\n", | |
" T_relu_1[ax0, ax1, ax2, ax3] = tir.max(rxplaceholder_1[ax0, ax1, ax2, ax3], tir.float32(0))\n", | |
" \n", | |
" @tir.prim_func\n", | |
" def bias_add(rxplaceholder_2: tir.Buffer[(1, 1000), \"float32\"], rxplaceholder_3: tir.Buffer[(1000,), \"float32\"], T_add_1: tir.Buffer[(1, 1000), \"float32\"]) -> None:\n", | |
" # function attr dict\n", | |
" tir.func_attr({\"global_symbol\": \"bias_add\", \"tir.noalias\": True})\n", | |
" # body\n", | |
" # with tir.block(\"root\")\n", | |
" for i0, i1 in tir.grid(1, 1000):\n", | |
" with tir.block(\"T_add\"):\n", | |
" ax0, ax1 = tir.axis.remap(\"SS\", [i0, i1])\n", | |
" tir.reads(rxplaceholder_2[ax0, ax1], rxplaceholder_3[ax1])\n", | |
" tir.writes(T_add_1[ax0, ax1])\n", | |
" T_add_1[ax0, ax1] = rxplaceholder_2[ax0, ax1] + rxplaceholder_3[ax1]\n", | |
" \n", | |
" @tir.prim_func\n", | |
" def relu2(rxplaceholder_1: tir.Buffer[(1, 256, 56, 56), \"float32\"], T_relu_1: tir.Buffer[(1, 256, 56, 56), \"float32\"]) -> None:\n", | |
" # function attr dict\n", | |
" tir.func_attr({\"global_symbol\": \"relu2\", \"tir.noalias\": True})\n", | |
" # body\n", | |
" # with tir.block(\"root\")\n", | |
" for i0, i1, i2, i3 in tir.grid(1, 256, 56, 56):\n", | |
" with tir.block(\"T_relu\"):\n", | |
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n", | |
" tir.reads(rxplaceholder_1[ax0, ax1, ax2, ax3])\n", | |
" tir.writes(T_relu_1[ax0, ax1, ax2, ax3])\n", | |
" T_relu_1[ax0, ax1, ax2, ax3] = tir.max(rxplaceholder_1[ax0, ax1, ax2, ax3], tir.float32(0))\n", | |
" \n", | |
" @tir.prim_func\n", | |
" def conv2d_nchw12(rxplaceholder_2: tir.Buffer[(1, 256, 14, 14), \"float32\"], rxplaceholder_3: tir.Buffer[(1024, 256, 1, 1), \"float32\"], conv2d_nchw_1: tir.Buffer[(1, 1024, 14, 14), \"float32\"]) -> None:\n", | |
" # function attr dict\n", | |
" tir.func_attr({\"global_symbol\": \"conv2d_nchw12\", \"tir.noalias\": True})\n", | |
" # body\n", | |
" # with tir.block(\"root\")\n", | |
" pad_temp_1 = tir.alloc_buffer([1, 256, 14, 14], dtype=\"float32\")\n", | |
" for i0, i1, i2, i3 in tir.grid(1, 256, 14, 14):\n", | |
" with tir.block(\"pad_temp\"):\n", | |
" i0_2, i1_2, i2_2, i3_2 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n", | |
" tir.reads(rxplaceholder_2[i0_2, i1_2, i2_2, i3_2])\n", | |
" tir.writes(pad_temp_1[i0_2, i1_2, i2_2, i3_2])\n", | |
" pad_temp_1[i0_2, i1_2, i2_2, i3_2] = rxplaceholder_2[i0_2, i1_2, i2_2, i3_2]\n", | |
" for i0_3, i1_3, i2_3, i3_3, i4, i5, i6 in tir.grid(1, 1024, 14, 14, 256, 1, 1):\n", | |
" with tir.block(\"conv2d_nchw\"):\n", | |
" nn, ff, yy, xx, rc, ry, rx = tir.axis.remap(\"SSSSRRR\", [i0_3, i1_3, i2_3, i3_3, i4, i5, i6])\n", | |
" tir.reads(pad_temp_1[nn, rc, yy + ry, xx + rx], rxplaceholder_3[ff, rc, ry, rx])\n", | |
" tir.writes(conv2d_nchw_1[nn, ff, yy, xx])\n", | |
" with tir.init():\n", | |
" conv2d_nchw_1[nn, ff, yy, xx] = tir.float32(0)\n", | |
" conv2d_nchw_1[nn, ff, yy, xx] = conv2d_nchw_1[nn, ff, yy, xx] + pad_temp_1[nn, rc, yy + ry, xx + rx] * rxplaceholder_3[ff, rc, ry, rx]\n", | |
" \n", | |
" @tir.prim_func\n", | |
" def conv2d_nchw16(rxplaceholder_2: tir.Buffer[(1, 512, 7, 7), \"float32\"], rxplaceholder_3: tir.Buffer[(512, 512, 3, 3), \"float32\"], conv2d_nchw_1: tir.Buffer[(1, 512, 7, 7), \"float32\"]) -> None:\n", | |
" # function attr dict\n", | |
" tir.func_attr({\"global_symbol\": \"conv2d_nchw16\", \"tir.noalias\": True})\n", | |
" # body\n", | |
" # with tir.block(\"root\")\n", | |
" pad_temp_1 = tir.alloc_buffer([1, 512, 9, 9], dtype=\"float32\")\n", | |
" for i0, i1, i2, i3 in tir.grid(1, 512, 9, 9):\n", | |
" with tir.block(\"pad_temp\"):\n", | |
" i0_2, i1_2, i2_2, i3_2 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n", | |
" tir.reads(rxplaceholder_2[i0_2, i1_2, i2_2 - 1, i3_2 - 1])\n", | |
" tir.writes(pad_temp_1[i0_2, i1_2, i2_2, i3_2])\n", | |
" pad_temp_1[i0_2, i1_2, i2_2, i3_2] = tir.if_then_else(1 <= i2_2 and i2_2 < 8 and 1 <= i3_2 and i3_2 < 8, rxplaceholder_2[i0_2, i1_2, i2_2 - 1, i3_2 - 1], tir.float32(0), dtype=\"float32\")\n", | |
" for i0_3, i1_3, i2_3, i3_3, i4, i5, i6 in tir.grid(1, 512, 7, 7, 512, 3, 3):\n", | |
" with tir.block(\"conv2d_nchw\"):\n", | |
" nn, ff, yy, xx, rc, ry, rx = tir.axis.remap(\"SSSSRRR\", [i0_3, i1_3, i2_3, i3_3, i4, i5, i6])\n", | |
" tir.reads(pad_temp_1[nn, rc, yy + ry, xx + rx], rxplaceholder_3[ff, rc, ry, rx])\n", | |
" tir.writes(conv2d_nchw_1[nn, ff, yy, xx])\n", | |
" with tir.init():\n", | |
" conv2d_nchw_1[nn, ff, yy, xx] = tir.float32(0)\n", | |
" conv2d_nchw_1[nn, ff, yy, xx] = conv2d_nchw_1[nn, ff, yy, xx] + pad_temp_1[nn, rc, yy + ry, xx + rx] * rxplaceholder_3[ff, rc, ry, rx]\n", | |
" \n", | |
" @tir.prim_func\n", | |
" def conv2d_nchw8(rxplaceholder_2: tir.Buffer[(1, 256, 56, 56), \"float32\"], rxplaceholder_3: tir.Buffer[(512, 256, 1, 1), \"float32\"], conv2d_nchw_1: tir.Buffer[(1, 512, 28, 28), \"float32\"]) -> None:\n", | |
" # function attr dict\n", | |
" tir.func_attr({\"global_symbol\": \"conv2d_nchw8\", \"tir.noalias\": True})\n", | |
" # body\n", | |
" # with tir.block(\"root\")\n", | |
" pad_temp_1 = tir.alloc_buffer([1, 256, 56, 56], dtype=\"float32\")\n", | |
" for i0, i1, i2, i3 in tir.grid(1, 256, 56, 56):\n", | |
" with tir.block(\"pad_temp\"):\n", | |
" i0_2, i1_2, i2_2, i3_2 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n", | |
" tir.reads(rxplaceholder_2[i0_2, i1_2, i2_2, i3_2])\n", | |
" tir.writes(pad_temp_1[i0_2, i1_2, i2_2, i3_2])\n", | |
" pad_temp_1[i0_2, i1_2, i2_2, i3_2] = rxplaceholder_2[i0_2, i1_2, i2_2, i3_2]\n", | |
" for i0_3, i1_3, i2_3, i3_3, i4, i5, i6 in tir.grid(1, 512, 28, 28, 256, 1, 1):\n", | |
" with tir.block(\"conv2d_nchw\"):\n", | |
" nn, ff, yy, xx, rc, ry, rx = tir.axis.remap(\"SSSSRRR\", [i0_3, i1_3, i2_3, i3_3, i4, i5, i6])\n", | |
" tir.reads(pad_temp_1[nn, rc, yy * 2 + ry, xx * 2 + rx], rxplaceholder_3[ff, rc, ry, rx])\n", | |
" tir.writes(conv2d_nchw_1[nn, ff, yy, xx])\n", | |
" with tir.init():\n", | |
" conv2d_nchw_1[nn, ff, yy, xx] = tir.float32(0)\n", | |
" conv2d_nchw_1[nn, ff, yy, xx] = conv2d_nchw_1[nn, ff, yy, xx] + pad_temp_1[nn, rc, yy * 2 + ry, xx * 2 + rx] * rxplaceholder_3[ff, rc, ry, rx]\n", | |
" \n", | |
" @tir.prim_func\n", | |
" def batch_norm4(rxplaceholder_5: tir.Buffer[(1, 128, 28, 28), \"float32\"], rxplaceholder_6: tir.Buffer[(128,), \"float32\"], rxplaceholder_7: tir.Buffer[(128,), \"float32\"], rxplaceholder_8: tir.Buffer[(128,), \"float32\"], rxplaceholder_9: tir.Buffer[(128,), \"float32\"], T_add_2: tir.Buffer[(1, 128, 28, 28), \"float32\"], T_multiply_3: tir.Buffer[(128,), \"float32\"], T_multiply_4: tir.Buffer[(128,), \"float32\"]) -> None:\n", | |
" # function attr dict\n", | |
" tir.func_attr({\"global_symbol\": \"batch_norm4\", \"tir.noalias\": True})\n", | |
" # body\n", | |
" # with tir.block(\"root\")\n", | |
" T_reshape_4 = tir.alloc_buffer([1, 128, 1, 1], dtype=\"float32\")\n", | |
" T_subtract_1 = tir.alloc_buffer([1, 128, 28, 28], dtype=\"float32\")\n", | |
" T_reshape_5 = tir.alloc_buffer([1, 128, 1, 1], dtype=\"float32\")\n", | |
" T_add_3 = tir.alloc_buffer([1, 128, 1, 1], dtype=\"float32\")\n", | |
" compute_1 = tir.alloc_buffer([1, 128, 1, 1], dtype=\"float32\")\n", | |
" T_divide_1 = tir.alloc_buffer([1, 128, 28, 28], dtype=\"float32\")\n", | |
" T_reshape_6 = tir.alloc_buffer([1, 128, 1, 1], dtype=\"float32\")\n", | |
" T_multiply_5 = tir.alloc_buffer([1, 128, 28, 28], dtype=\"float32\")\n", | |
" T_reshape_7 = tir.alloc_buffer([1, 128, 1, 1], dtype=\"float32\")\n", | |
" for i0, i1, i2, i3 in tir.grid(1, 128, 1, 1):\n", | |
" with tir.block(\"T_reshape\"):\n", | |
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n", | |
" tir.reads(rxplaceholder_8[(ax1 + ax2 + ax3) % 128])\n", | |
" tir.writes(T_reshape_4[ax0, ax1, ax2, ax3])\n", | |
" T_reshape_4[ax0, ax1, ax2, ax3] = rxplaceholder_8[(ax1 + ax2 + ax3) % 128]\n", | |
" for i0, i1, i2, i3 in tir.grid(1, 128, 28, 28):\n", | |
" with tir.block(\"T_subtract\"):\n", | |
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n", | |
" tir.reads(rxplaceholder_5[ax0, ax1, ax2, ax3], T_reshape_4[ax0, ax1, 0, 0])\n", | |
" tir.writes(T_subtract_1[ax0, ax1, ax2, ax3])\n", | |
" T_subtract_1[ax0, ax1, ax2, ax3] = rxplaceholder_5[ax0, ax1, ax2, ax3] - T_reshape_4[ax0, ax1, 0, 0]\n", | |
" for i0, i1, i2, i3 in tir.grid(1, 128, 1, 1):\n", | |
" with tir.block(\"T_reshape_1\"):\n", | |
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n", | |
" tir.reads(rxplaceholder_9[(ax1 + ax2 + ax3) % 128])\n", | |
" tir.writes(T_reshape_5[ax0, ax1, ax2, ax3])\n", | |
" T_reshape_5[ax0, ax1, ax2, ax3] = rxplaceholder_9[(ax1 + ax2 + ax3) % 128]\n", | |
" for i0, i1, i2, i3 in tir.grid(1, 128, 1, 1):\n", | |
" with tir.block(\"T_add\"):\n", | |
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n", | |
" tir.reads(T_reshape_5[ax0, ax1, ax2, ax3])\n", | |
" tir.writes(T_add_3[ax0, ax1, ax2, ax3])\n", | |
" T_add_3[ax0, ax1, ax2, ax3] = T_reshape_5[ax0, ax1, ax2, ax3] + tir.float32(1.9999999494757503e-05)\n", | |
" for i0, i1, i2, i3 in tir.grid(1, 128, 1, 1):\n", | |
" with tir.block(\"compute\"):\n", | |
" i0_2, i1_2, i2_2, i3_2 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n", | |
" tir.reads(T_add_3[i0_2, i1_2, i2_2, i3_2])\n", | |
" tir.writes(compute_1[i0_2, i1_2, i2_2, i3_2])\n", | |
" compute_1[i0_2, i1_2, i2_2, i3_2] = tir.sqrt(T_add_3[i0_2, i1_2, i2_2, i3_2], dtype=\"float32\")\n", | |
" for i0_3, i1_3, i2_3, i3_3 in tir.grid(1, 128, 28, 28):\n", | |
" with tir.block(\"T_divide\"):\n", | |
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0_3, i1_3, i2_3, i3_3])\n", | |
" tir.reads(T_subtract_1[ax0, ax1, ax2, ax3], compute_1[ax0, ax1, 0, 0])\n", | |
" tir.writes(T_divide_1[ax0, ax1, ax2, ax3])\n", | |
" T_divide_1[ax0, ax1, ax2, ax3] = T_subtract_1[ax0, ax1, ax2, ax3] / compute_1[ax0, ax1, 0, 0]\n", | |
" for i0_4, i1_4, i2_4, i3_4 in tir.grid(1, 128, 1, 1):\n", | |
" with tir.block(\"T_reshape_2\"):\n", | |
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0_4, i1_4, i2_4, i3_4])\n", | |
" tir.reads(rxplaceholder_6[(ax1 + ax2 + ax3) % 128])\n", | |
" tir.writes(T_reshape_6[ax0, ax1, ax2, ax3])\n", | |
" T_reshape_6[ax0, ax1, ax2, ax3] = rxplaceholder_6[(ax1 + ax2 + ax3) % 128]\n", | |
" for i0_5, i1_5, i2_5, i3_5 in tir.grid(1, 128, 28, 28):\n", | |
" with tir.block(\"T_multiply\"):\n", | |
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0_5, i1_5, i2_5, i3_5])\n", | |
" tir.reads(T_divide_1[ax0, ax1, ax2, ax3], T_reshape_6[ax0, ax1, 0, 0])\n", | |
" tir.writes(T_multiply_5[ax0, ax1, ax2, ax3])\n", | |
" T_multiply_5[ax0, ax1, ax2, ax3] = T_divide_1[ax0, ax1, ax2, ax3] * T_reshape_6[ax0, ax1, 0, 0]\n", | |
" for i0_6, i1_6, i2_6, i3_6 in tir.grid(1, 128, 1, 1):\n", | |
" with tir.block(\"T_reshape_3\"):\n", | |
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0_6, i1_6, i2_6, i3_6])\n", | |
" tir.reads(rxplaceholder_7[(ax1 + ax2 + ax3) % 128])\n", | |
" tir.writes(T_reshape_7[ax0, ax1, ax2, ax3])\n", | |
" T_reshape_7[ax0, ax1, ax2, ax3] = rxplaceholder_7[(ax1 + ax2 + ax3) % 128]\n", | |
" for i0_7, i1_7, i2_7, i3_7 in tir.grid(1, 128, 28, 28):\n", | |
" with tir.block(\"T_add_1\"):\n", | |
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0_7, i1_7, i2_7, i3_7])\n", | |
" tir.reads(T_multiply_5[ax0, ax1, ax2, ax3], T_reshape_7[ax0, ax1, 0, 0])\n", | |
" tir.writes(T_add_2[ax0, ax1, ax2, ax3])\n", | |
" T_add_2[ax0, ax1, ax2, ax3] = T_multiply_5[ax0, ax1, ax2, ax3] + T_reshape_7[ax0, ax1, 0, 0]\n", | |
" for i0_8 in tir.serial(128):\n", | |
" with tir.block(\"T_multiply_1\"):\n", | |
" ax0 = tir.axis.spatial(128, i0_8)\n", | |
" tir.reads(rxplaceholder_8[ax0])\n", | |
" tir.writes(T_multiply_3[ax0])\n", | |
" T_multiply_3[ax0] = rxplaceholder_8[ax0]\n", | |
" for i0_9 in tir.serial(128):\n", | |
" with tir.block(\"T_multiply_2\"):\n", | |
" ax0 = tir.axis.spatial(128, i0_9)\n", | |
" tir.reads(rxplaceholder_9[ax0])\n", | |
" tir.writes(T_multiply_4[ax0])\n", | |
" T_multiply_4[ax0] = rxplaceholder_9[ax0]\n", | |
" \n", | |
" @tir.prim_func\n", | |
" def add(rxplaceholder_2: tir.Buffer[(1, 256, 56, 56), \"float32\"], rxplaceholder_3: tir.Buffer[(1, 256, 56, 56), \"float32\"], T_add_1: tir.Buffer[(1, 256, 56, 56), \"float32\"]) -> None:\n", | |
" # function attr dict\n", | |
" tir.func_attr({\"global_symbol\": \"add\", \"tir.noalias\": True})\n", | |
" # body\n", | |
" # with tir.block(\"root\")\n", | |
" for i0, i1, i2, i3 in tir.grid(1, 256, 56, 56):\n", | |
" with tir.block(\"T_add\"):\n", | |
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n", | |
" tir.reads(rxplaceholder_2[ax0, ax1, ax2, ax3], rxplaceholder_3[ax0, ax1, ax2, ax3])\n", | |
" tir.writes(T_add_1[ax0, ax1, ax2, ax3])\n", | |
" T_add_1[ax0, ax1, ax2, ax3] = rxplaceholder_2[ax0, ax1, ax2, ax3] + rxplaceholder_3[ax0, ax1, ax2, ax3]\n", | |
" \n", | |
" @tir.prim_func\n", | |
" def batch_norm2(rxplaceholder_5: tir.Buffer[(1, 64, 56, 56), \"float32\"], rxplaceholder_6: tir.Buffer[(64,), \"float32\"], rxplaceholder_7: tir.Buffer[(64,), \"float32\"], rxplaceholder_8: tir.Buffer[(64,), \"float32\"], rxplaceholder_9: tir.Buffer[(64,), \"float32\"], T_add_2: tir.Buffer[(1, 64, 56, 56), \"float32\"], T_multiply_3: tir.Buffer[(64,), \"float32\"], T_multiply_4: tir.Buffer[(64,), \"float32\"]) -> None:\n", | |
" # function attr dict\n", | |
" tir.func_attr({\"global_symbol\": \"batch_norm2\", \"tir.noalias\": True})\n", | |
" # body\n", | |
" # with tir.block(\"root\")\n", | |
" T_reshape_4 = tir.alloc_buffer([1, 64, 1, 1], dtype=\"float32\")\n", | |
" T_subtract_1 = tir.alloc_buffer([1, 64, 56, 56], dtype=\"float32\")\n", | |
" T_reshape_5 = tir.alloc_buffer([1, 64, 1, 1], dtype=\"float32\")\n", | |
" T_add_3 = tir.alloc_buffer([1, 64, 1, 1], dtype=\"float32\")\n", | |
" compute_1 = tir.alloc_buffer([1, 64, 1, 1], dtype=\"float32\")\n", | |
" T_divide_1 = tir.alloc_buffer([1, 64, 56, 56], dtype=\"float32\")\n", | |
" T_reshape_6 = tir.alloc_buffer([1, 64, 1, 1], dtype=\"float32\")\n", | |
" T_multiply_5 = tir.alloc_buffer([1, 64, 56, 56], dtype=\"float32\")\n", | |
" T_reshape_7 = tir.alloc_buffer([1, 64, 1, 1], dtype=\"float32\")\n", | |
" for i0, i1, i2, i3 in tir.grid(1, 64, 1, 1):\n", | |
" with tir.block(\"T_reshape\"):\n", | |
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n", | |
" tir.reads(rxplaceholder_8[(ax1 + ax2 + ax3) % 64])\n", | |
" tir.writes(T_reshape_4[ax0, ax1, ax2, ax3])\n", | |
" T_reshape_4[ax0, ax1, ax2, ax3] = rxplaceholder_8[(ax1 + ax2 + ax3) % 64]\n", | |
" for i0, i1, i2, i3 in tir.grid(1, 64, 56, 56):\n", | |
" with tir.block(\"T_subtract\"):\n", | |
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n", | |
" tir.reads(rxplaceholder_5[ax0, ax1, ax2, ax3], T_reshape_4[ax0, ax1, 0, 0])\n", | |
" tir.writes(T_subtract_1[ax0, ax1, ax2, ax3])\n", | |
" T_subtract_1[ax0, ax1, ax2, ax3] = rxplaceholder_5[ax0, ax1, ax2, ax3] - T_reshape_4[ax0, ax1, 0, 0]\n", | |
" for i0, i1, i2, i3 in tir.grid(1, 64, 1, 1):\n", | |
" with tir.block(\"T_reshape_1\"):\n", | |
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n", | |
" tir.reads(rxplaceholder_9[(ax1 + ax2 + ax3) % 64])\n", | |
" tir.writes(T_reshape_5[ax0, ax1, ax2, ax3])\n", | |
" T_reshape_5[ax0, ax1, ax2, ax3] = rxplaceholder_9[(ax1 + ax2 + ax3) % 64]\n", | |
" for i0, i1, i2, i3 in tir.grid(1, 64, 1, 1):\n", | |
" with tir.block(\"T_add\"):\n", | |
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n", | |
" tir.reads(T_reshape_5[ax0, ax1, ax2, ax3])\n", | |
" tir.writes(T_add_3[ax0, ax1, ax2, ax3])\n", | |
" T_add_3[ax0, ax1, ax2, ax3] = T_reshape_5[ax0, ax1, ax2, ax3] + tir.float32(1.9999999494757503e-05)\n", | |
" for i0, i1, i2, i3 in tir.grid(1, 64, 1, 1):\n", | |
" with tir.block(\"compute\"):\n", | |
" i0_2, i1_2, i2_2, i3_2 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n", | |
" tir.reads(T_add_3[i0_2, i1_2, i2_2, i3_2])\n", | |
" tir.writes(compute_1[i0_2, i1_2, i2_2, i3_2])\n", | |
" compute_1[i0_2, i1_2, i2_2, i3_2] = tir.sqrt(T_add_3[i0_2, i1_2, i2_2, i3_2], dtype=\"float32\")\n", | |
" for i0_3, i1_3, i2_3, i3_3 in tir.grid(1, 64, 56, 56):\n", | |
" with tir.block(\"T_divide\"):\n", | |
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0_3, i1_3, i2_3, i3_3])\n", | |
" tir.reads(T_subtract_1[ax0, ax1, ax2, ax3], compute_1[ax0, ax1, 0, 0])\n", | |
" tir.writes(T_divide_1[ax0, ax1, ax2, ax3])\n", | |
" T_divide_1[ax0, ax1, ax2, ax3] = T_subtract_1[ax0, ax1, ax2, ax3] / compute_1[ax0, ax1, 0, 0]\n", | |
" for i0_4, i1_4, i2_4, i3_4 in tir.grid(1, 64, 1, 1):\n", | |
" with tir.block(\"T_reshape_2\"):\n", | |
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0_4, i1_4, i2_4, i3_4])\n", | |
" tir.reads(rxplaceholder_6[(ax1 + ax2 + ax3) % 64])\n", | |
" tir.writes(T_reshape_6[ax0, ax1, ax2, ax3])\n", | |
" T_reshape_6[ax0, ax1, ax2, ax3] = rxplaceholder_6[(ax1 + ax2 + ax3) % 64]\n", | |
" for i0_5, i1_5, i2_5, i3_5 in tir.grid(1, 64, 56, 56):\n", | |
" with tir.block(\"T_multiply\"):\n", | |
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0_5, i1_5, i2_5, i3_5])\n", | |
" tir.reads(T_divide_1[ax0, ax1, ax2, ax3], T_reshape_6[ax0, ax1, 0, 0])\n", | |
" tir.writes(T_multiply_5[ax0, ax1, ax2, ax3])\n", | |
" T_multiply_5[ax0, ax1, ax2, ax3] = T_divide_1[ax0, ax1, ax2, ax3] * T_reshape_6[ax0, ax1, 0, 0]\n", | |
" for i0_6, i1_6, i2_6, i3_6 in tir.grid(1, 64, 1, 1):\n", | |
" with tir.block(\"T_reshape_3\"):\n", | |
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0_6, i1_6, i2_6, i3_6])\n", | |
" tir.reads(rxplaceholder_7[(ax1 + ax2 + ax3) % 64])\n", | |
" tir.writes(T_reshape_7[ax0, ax1, ax2, ax3])\n", | |
" T_reshape_7[ax0, ax1, ax2, ax3] = rxplaceholder_7[(ax1 + ax2 + ax3) % 64]\n", | |
" for i0_7, i1_7, i2_7, i3_7 in tir.grid(1, 64, 56, 56):\n", | |
" with tir.block(\"T_add_1\"):\n", | |
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0_7, i1_7, i2_7, i3_7])\n", | |
" tir.reads(T_multiply_5[ax0, ax1, ax2, ax3], T_reshape_7[ax0, ax1, 0, 0])\n", | |
" tir.writes(T_add_2[ax0, ax1, ax2, ax3])\n", | |
" T_add_2[ax0, ax1, ax2, ax3] = T_multiply_5[ax0, ax1, ax2, ax3] + T_reshape_7[ax0, ax1, 0, 0]\n", | |
" for i0_8 in tir.serial(64):\n", | |
" with tir.block(\"T_multiply_1\"):\n", | |
" ax0 = tir.axis.spatial(64, i0_8)\n", | |
" tir.reads(rxplaceholder_8[ax0])\n", | |
" tir.writes(T_multiply_3[ax0])\n", | |
" T_multiply_3[ax0] = rxplaceholder_8[ax0]\n", | |
" for i0_9 in tir.serial(64):\n", | |
" with tir.block(\"T_multiply_2\"):\n", | |
" ax0 = tir.axis.spatial(64, i0_9)\n", | |
" tir.reads(rxplaceholder_9[ax0])\n", | |
" tir.writes(T_multiply_4[ax0])\n", | |
" T_multiply_4[ax0] = rxplaceholder_9[ax0]\n", | |
" \n", | |
" @tir.prim_func\n", | |
" def conv2d_nchw2(rxplaceholder_2: tir.Buffer[(1, 64, 56, 56), \"float32\"], rxplaceholder_3: tir.Buffer[(64, 64, 3, 3), \"float32\"], conv2d_nchw_1: tir.Buffer[(1, 64, 56, 56), \"float32\"]) -> None:\n", | |
" # function attr dict\n", | |
" tir.func_attr({\"global_symbol\": \"conv2d_nchw2\", \"tir.noalias\": True})\n", | |
" # body\n", | |
" # with tir.block(\"root\")\n", | |
" pad_temp_1 = tir.alloc_buffer([1, 64, 58, 58], dtype=\"float32\")\n", | |
" for i0, i1, i2, i3 in tir.grid(1, 64, 58, 58):\n", | |
" with tir.block(\"pad_temp\"):\n", | |
" i0_2, i1_2, i2_2, i3_2 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n", | |
" tir.reads(rxplaceholder_2[i0_2, i1_2, i2_2 - 1, i3_2 - 1])\n", | |
" tir.writes(pad_temp_1[i0_2, i1_2, i2_2, i3_2])\n", | |
" pad_temp_1[i0_2, i1_2, i2_2, i3_2] = tir.if_then_else(1 <= i2_2 and i2_2 < 57 and 1 <= i3_2 and i3_2 < 57, rxplaceholder_2[i0_2, i1_2, i2_2 - 1, i3_2 - 1], tir.float32(0), dtype=\"float32\")\n", | |
" for i0_3, i1_3, i2_3, i3_3, i4, i5, i6 in tir.grid(1, 64, 56, 56, 64, 3, 3):\n", | |
" with tir.block(\"conv2d_nchw\"):\n", | |
" nn, ff, yy, xx, rc, ry, rx = tir.axis.remap(\"SSSSRRR\", [i0_3, i1_3, i2_3, i3_3, i4, i5, i6])\n", | |
" tir.reads(pad_temp_1[nn, rc, yy + ry, xx + rx], rxplaceholder_3[ff, rc, ry, rx])\n", | |
" tir.writes(conv2d_nchw_1[nn, ff, yy, xx])\n", | |
" with tir.init():\n", | |
" conv2d_nchw_1[nn, ff, yy, xx] = tir.float32(0)\n", | |
" conv2d_nchw_1[nn, ff, yy, xx] = conv2d_nchw_1[nn, ff, yy, xx] + pad_temp_1[nn, rc, yy + ry, xx + rx] * rxplaceholder_3[ff, rc, ry, rx]\n", | |
" \n", | |
" @tir.prim_func\n", | |
" def relu7(rxplaceholder_1: tir.Buffer[(1, 512, 7, 7), \"float32\"], T_relu_1: tir.Buffer[(1, 512, 7, 7), \"float32\"]) -> None:\n", | |
" # function attr dict\n", | |
" tir.func_attr({\"global_symbol\": \"relu7\", \"tir.noalias\": True})\n", | |
" # body\n", | |
" # with tir.block(\"root\")\n", | |
" for i0, i1, i2, i3 in tir.grid(1, 512, 7, 7):\n", | |
" with tir.block(\"T_relu\"):\n", | |
" ax0, ax1, ax2, ax3 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n", | |
" tir.reads(rxplaceholder_1[ax0, ax1, ax2, ax3])\n", | |
" tir.writes(T_relu_1[ax0, ax1, ax2, ax3])\n", | |
" T_relu_1[ax0, ax1, ax2, ax3] = tir.max(rxplaceholder_1[ax0, ax1, ax2, ax3], tir.float32(0))\n", | |
" \n", | |
" @tir.prim_func\n", | |
" def conv2d_nchw18(rxplaceholder_2: tir.Buffer[(1, 1024, 14, 14), \"float32\"], rxplaceholder_3: tir.Buffer[(2048, 1024, 1, 1), \"float32\"], conv2d_nchw_1: tir.Buffer[(1, 2048, 7, 7), \"float32\"]) -> None:\n", | |
" # function attr dict\n", | |
" tir.func_attr({\"global_symbol\": \"conv2d_nchw18\", \"tir.noalias\": True})\n", | |
" # body\n", | |
" # with tir.block(\"root\")\n", | |
" pad_temp_1 = tir.alloc_buffer([1, 1024, 14, 14], dtype=\"float32\")\n", | |
" for i0, i1, i2, i3 in tir.grid(1, 1024, 14, 14):\n", | |
" with tir.block(\"pad_temp\"):\n", | |
" i0_2, i1_2, i2_2, i3_2 = tir.axis.remap(\"SSSS\", [i0, i1, i2, i3])\n", | |
" tir.reads(rxplaceholder_2[i0_2, i1_2, i2_2, i3_2])\n", | |
" tir.writes(pad_temp_1[i0_2, i1_2, i2_2, i3_2])\n", | |
" pad_temp_1[i0_2, i1_2, i2_2, i3_2] = rxplaceholder_2[i0_2, i1_2, i2_2, i3_2]\n", | |
" for i0_3, i1_3, i2_3, i3_3, i4, i5, i6 in tir.grid(1, 2048, 7, 7, 1024, 1, 1):\n", | |
" with tir.block(\"conv2d_nchw\"):\n", | |
" nn, ff, yy, xx, rc, ry, rx = tir.axis.remap(\"SSSSRRR\", [i0_3, i1_3, i2_3, i3_3, i4, i5, i6])\n", | |
" tir.reads(pad_temp_1[nn, rc, yy * 2 + ry, xx * 2 + rx], rxplaceholder_3[ff, rc, ry, rx])\n", | |
" tir.writes(conv2d_nchw_1[nn, ff, yy, xx])\n", | |
" with tir.init():\n", | |
" conv2d_nchw_1[nn, ff, yy, xx] = tir.float32(0)\n", | |
" conv2d_nchw_1[nn, ff, yy, xx] = conv2d_nchw_1[nn, ff, yy, xx] + pad_temp_1[nn, rc, yy * 2 + ry, xx * 2 + rx] * rxplaceholder_3[ff, rc, ry, rx]\n", | |
" \n" | |
] | |
} | |
], | |
"source": [ | |
"from tvm import relay\n", | |
"import tvm.relay.testing\n", | |
"from tvm.relax.testing import relay_translator\n", | |
"\n", | |
"relay_mod, _ = relay.testing.resnet.get_workload(num_layers=50, batch_size=1, dtype=\"float32\")\n", | |
"\n", | |
"# translate the ResNet model from Relay to Relax\n", | |
"relax_mod = relay_translator.from_relay(relay_mod[\"main\"])\n", | |
"\n", | |
"# print the ResNet IRmodule got translated\n", | |
"R.parser.pretty_print(relax_mod)" | |
] | |
}, | |
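{ | |
"cell_type": "markdown", | |
"id": "1a2b3c4d", | |
"metadata": {}, | |
"source": [ | |
"The translator is not specific to the ResNet workload: `relay_translator.from_relay` accepts any Relay function. A minimal sketch (the tiny add function below is hypothetical, written only to illustrate the API):\n", | |
"\n", | |
"```python\n", | |
"# build a tiny Relay function: f(x, y) = x + y\n", | |
"x = relay.var(\"x\", shape=(2, 2), dtype=\"float32\")\n", | |
"y = relay.var(\"y\", shape=(2, 2), dtype=\"float32\")\n", | |
"fn = relay.Function([x, y], relay.add(x, y))\n", | |
"\n", | |
"# translate it to Relax and inspect the result, as with ResNet above\n", | |
"tiny_mod = relay_translator.from_relay(fn)\n", | |
"R.parser.pretty_print(tiny_mod)\n", | |
"```" | |
] | |
}, | |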
{ | |
"cell_type": "markdown", | |
"id": "df9a585e", | |
"metadata": {}, | |
"source": [ | |
"## Tuning with MetaSchedule" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 13, | |
"id": "a8c19a40", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"[09:22:10] /home/yuchenj/dup/relax/src/meta_schedule/task_scheduler/task_scheduler.cc:96: Initializing Task #0: \"tir_matmul\"\n", | |
"[09:22:10] /home/yuchenj/dup/relax/src/meta_schedule/task_scheduler/task_scheduler.cc:102: \n", | |
"@tvm.script.ir_module\n", | |
"class Module:\n", | |
" @tir.prim_func\n", | |
" def main(A_1: tir.Buffer[(32, 32), \"float32\"], B_1: tir.Buffer[(32, 32), \"float32\"], C_1: tir.Buffer[(32, 32), \"float32\"]) -> None:\n", | |
" # function attr dict\n", | |
" tir.func_attr({\"global_symbol\": \"tir_matmul\"})\n", | |
" # body\n", | |
" # with tir.block(\"root\")\n", | |
" for i0, j0, k0 in tir.grid(32, 32, 32):\n", | |
" with tir.block():\n", | |
" i, j, k = tir.axis.remap(\"SSR\", [i0, j0, k0])\n", | |
" tir.reads(A_1[i, k], B_1[j, k])\n", | |
" tir.writes(C_1[i, j])\n", | |
" with tir.init():\n", | |
" C_1[i, j] = tir.float32(0)\n", | |
" C_1[i, j] = C_1[i, j] + A_1[i, k] * B_1[j, k]\n", | |
" \n", | |
"[09:22:11] /home/yuchenj/dup/relax/src/meta_schedule/task_scheduler/task_scheduler.cc:106: Total 3 design space(s) generated\n", | |
"[09:22:11] /home/yuchenj/dup/relax/src/meta_schedule/task_scheduler/task_scheduler.cc:111: Design space #0:\n", | |
"@tvm.script.ir_module\n", | |
"class Module:\n", | |
" @tir.prim_func\n", | |
" def main(A_1: tir.Buffer[(32, 32), \"float32\"], B_1: tir.Buffer[(32, 32), \"float32\"], C_1: tir.Buffer[(32, 32), \"float32\"]) -> None:\n", | |
" # function attr dict\n", | |
" tir.func_attr({\"global_symbol\": \"tir_matmul\"})\n", | |
" # body\n", | |
" with tir.block(\"root\"):\n", | |
" tir.reads()\n", | |
" tir.writes()\n", | |
" tir.block_attr({\"meta_schedule.parallel\":256, \"meta_schedule.unroll_explicit\":64, \"meta_schedule.vectorize\":64})\n", | |
" C_global_1 = tir.alloc_buffer([32, 32], dtype=\"float32\")\n", | |
" for i0_0, j0_0, i0_1, j0_1 in tir.grid(1, 2, 4, 2):\n", | |
" for k0_0, i0_2, j0_2, k0_1, i0_3, j0_3 in tir.grid(4, 8, 1, 8, 1, 8):\n", | |
" with tir.block():\n", | |
" i = tir.axis.spatial(32, i0_1 * 8 + i0_2)\n", | |
" j = tir.axis.spatial(32, j0_0 * 16 + j0_1 * 8 + j0_3)\n", | |
" k = tir.axis.reduce(32, k0_0 * 8 + k0_1)\n", | |
" tir.reads(A_1[i, k], B_1[j, k])\n", | |
" tir.writes(C_global_1[i, j])\n", | |
" tir.block_attr({\"meta_schedule.tiling_structure\":\"SSRSRS\"})\n", | |
" with tir.init():\n", | |
" C_global_1[i, j] = tir.float32(0)\n", | |
" C_global_1[i, j] = C_global_1[i, j] + A_1[i, k] * B_1[j, k]\n", | |
" for ax0, ax1 in tir.grid(8, 8):\n", | |
" with tir.block(\"C_global\"):\n", | |
" v0 = tir.axis.spatial(32, i0_1 * 8 + ax0)\n", | |
" v1 = tir.axis.spatial(32, j0_0 * 16 + j0_1 * 8 + ax1)\n", | |
" tir.reads(C_global_1[v0, v1])\n", | |
" tir.writes(C_1[v0, v1])\n", | |
" C_1[v0, v1] = C_global_1[v0, v1]\n", | |
" \n", | |
"b0 = sch.get_block(name=\"\", func_name=\"main\")\n", | |
"b1 = sch.get_block(name=\"root\", func_name=\"main\")\n", | |
"sch.annotate(block_or_loop=b0, ann_key=\"meta_schedule.tiling_structure\", ann_val=\"SSRSRS\")\n", | |
"l2, l3, l4 = sch.get_loops(block=b0)\n", | |
"v5, v6, v7, v8 = sch.sample_perfect_tile(loop=l2, n=4, max_innermost_factor=64, decision=[1, 4, 8, 1])\n", | |
"l9, l10, l11, l12 = sch.split(loop=l2, factors=[v5, v6, v7, v8])\n", | |
"v13, v14, v15, v16 = sch.sample_perfect_tile(loop=l3, n=4, max_innermost_factor=64, decision=[2, 2, 1, 8])\n", | |
"l17, l18, l19, l20 = sch.split(loop=l3, factors=[v13, v14, v15, v16])\n", | |
"v21, v22 = sch.sample_perfect_tile(loop=l4, n=2, max_innermost_factor=64, decision=[4, 8])\n", | |
"l23, l24 = sch.split(loop=l4, factors=[v21, v22])\n", | |
"sch.reorder(l9, l17, l10, l18, l23, l11, l19, l24, l12, l20)\n", | |
"b25 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope=\"global\")\n", | |
"sch.reverse_compute_at(block=b25, loop=l18, preserve_unit_loops=True)\n", | |
"sch.annotate(block_or_loop=b1, ann_key=\"meta_schedule.parallel\", ann_val=256)\n", | |
"sch.annotate(block_or_loop=b1, ann_key=\"meta_schedule.vectorize\", ann_val=64)\n", | |
"v26 = sch.sample_categorical(candidates=[0, 16, 64, 512], probs=[0.25, 0.25, 0.25, 0.25], decision=2)\n", | |
"sch.annotate(block_or_loop=b1, ann_key=\"meta_schedule.unroll_explicit\", ann_val=v26)\n", | |
"[09:22:11] /home/yuchenj/dup/relax/src/meta_schedule/task_scheduler/task_scheduler.cc:111: Design space #1:\n", | |
"@tvm.script.ir_module\n", | |
"class Module:\n", | |
" @tir.prim_func\n", | |
" def main(A_1: tir.Buffer[(32, 32), \"float32\"], B_1: tir.Buffer[(32, 32), \"float32\"], C_1: tir.Buffer[(32, 32), \"float32\"]) -> None:\n", | |
" # function attr dict\n", | |
" tir.func_attr({\"global_symbol\": \"tir_matmul\"})\n", | |
" # body\n", | |
" with tir.block(\"root\"):\n", | |
" tir.reads()\n", | |
" tir.writes()\n", | |
" tir.block_attr({\"meta_schedule.parallel\":256, \"meta_schedule.unroll_explicit\":0, \"meta_schedule.vectorize\":64})\n", | |
" C_global_1 = tir.alloc_buffer([32, 32], dtype=\"float32\")\n", | |
" for i0_0, j0_0 in tir.grid(1, 2):\n", | |
" for i0_1, j0_1, k0_0, i0_2, j0_2, k0_1, i0_3, j0_3 in tir.grid(4, 2, 4, 8, 1, 8, 1, 8):\n", | |
" with tir.block():\n", | |
" i = tir.axis.spatial(32, i0_1 * 8 + i0_2)\n", | |
" j = tir.axis.spatial(32, j0_0 * 16 + j0_1 * 8 + j0_3)\n", | |
" k = tir.axis.reduce(32, k0_0 * 8 + k0_1)\n", | |
" tir.reads(A_1[i, k], B_1[j, k])\n", | |
" tir.writes(C_global_1[i, j])\n", | |
" tir.block_attr({\"meta_schedule.tiling_structure\":\"SSRSRS\"})\n", | |
" with tir.init():\n", | |
" C_global_1[i, j] = tir.float32(0)\n", | |
" C_global_1[i, j] = C_global_1[i, j] + A_1[i, k] * B_1[j, k]\n", | |
" for ax0, ax1 in tir.grid(32, 16):\n", | |
" with tir.block(\"C_global\"):\n", | |
" v0 = tir.axis.spatial(32, ax0)\n", | |
" v1 = tir.axis.spatial(32, j0_0 * 16 + ax1)\n", | |
" tir.reads(C_global_1[v0, v1])\n", | |
" tir.writes(C_1[v0, v1])\n", | |
" C_1[v0, v1] = C_global_1[v0, v1]\n", | |
" \n", | |
"b0 = sch.get_block(name=\"\", func_name=\"main\")\n", | |
"b1 = sch.get_block(name=\"root\", func_name=\"main\")\n", | |
"sch.annotate(block_or_loop=b0, ann_key=\"meta_schedule.tiling_structure\", ann_val=\"SSRSRS\")\n", | |
"l2, l3, l4 = sch.get_loops(block=b0)\n", | |
"v5, v6, v7, v8 = sch.sample_perfect_tile(loop=l2, n=4, max_innermost_factor=64, decision=[1, 4, 8, 1])\n", | |
"l9, l10, l11, l12 = sch.split(loop=l2, factors=[v5, v6, v7, v8])\n", | |
"v13, v14, v15, v16 = sch.sample_perfect_tile(loop=l3, n=4, max_innermost_factor=64, decision=[2, 2, 1, 8])\n", | |
"l17, l18, l19, l20 = sch.split(loop=l3, factors=[v13, v14, v15, v16])\n", | |
"v21, v22 = sch.sample_perfect_tile(loop=l4, n=2, max_innermost_factor=64, decision=[4, 8])\n", | |
"l23, l24 = sch.split(loop=l4, factors=[v21, v22])\n", | |
"sch.reorder(l9, l17, l10, l18, l23, l11, l19, l24, l12, l20)\n", | |
"b25 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope=\"global\")\n", | |
"sch.reverse_compute_at(block=b25, loop=l17, preserve_unit_loops=True)\n", | |
"sch.annotate(block_or_loop=b1, ann_key=\"meta_schedule.parallel\", ann_val=256)\n", | |
"sch.annotate(block_or_loop=b1, ann_key=\"meta_schedule.vectorize\", ann_val=64)\n", | |
"v26 = sch.sample_categorical(candidates=[0, 16, 64, 512], probs=[0.25, 0.25, 0.25, 0.25], decision=0)\n", | |
"sch.annotate(block_or_loop=b1, ann_key=\"meta_schedule.unroll_explicit\", ann_val=v26)\n", | |
"[09:22:11] /home/yuchenj/dup/relax/src/meta_schedule/task_scheduler/task_scheduler.cc:111: Design space #2:\n", | |
"@tvm.script.ir_module\n", | |
"class Module:\n", | |
" @tir.prim_func\n", | |
" def main(A_1: tir.Buffer[(32, 32), \"float32\"], B_1: tir.Buffer[(32, 32), \"float32\"], C_1: tir.Buffer[(32, 32), \"float32\"]) -> None:\n", | |
" # function attr dict\n", | |
" tir.func_attr({\"global_symbol\": \"tir_matmul\"})\n", | |
" # body\n", | |
" with tir.block(\"root\"):\n", | |
" tir.reads()\n", | |
" tir.writes()\n", | |
" tir.block_attr({\"meta_schedule.parallel\":256, \"meta_schedule.unroll_explicit\":512, \"meta_schedule.vectorize\":64})\n", | |
" for i0_0, j0_0, i0_1, j0_1, k0_0, i0_2, j0_2, k0_1, i0_3, j0_3 in tir.grid(1, 2, 4, 2, 4, 8, 1, 8, 1, 8):\n", | |
" with tir.block():\n", | |
" i = tir.axis.spatial(32, i0_1 * 8 + i0_2)\n", | |
" j = tir.axis.spatial(32, j0_0 * 16 + j0_1 * 8 + j0_3)\n", | |
" k = tir.axis.reduce(32, k0_0 * 8 + k0_1)\n", | |
" tir.reads(A_1[i, k], B_1[j, k])\n", | |
" tir.writes(C_1[i, j])\n", | |
" tir.block_attr({\"meta_schedule.tiling_structure\":\"SSRSRS\"})\n", | |
" with tir.init():\n", | |
" C_1[i, j] = tir.float32(0)\n", | |
" C_1[i, j] = C_1[i, j] + A_1[i, k] * B_1[j, k]\n", | |
" \n", | |
"b0 = sch.get_block(name=\"\", func_name=\"main\")\n", | |
"b1 = sch.get_block(name=\"root\", func_name=\"main\")\n", | |
"sch.annotate(block_or_loop=b0, ann_key=\"meta_schedule.tiling_structure\", ann_val=\"SSRSRS\")\n", | |
"l2, l3, l4 = sch.get_loops(block=b0)\n", | |
"v5, v6, v7, v8 = sch.sample_perfect_tile(loop=l2, n=4, max_innermost_factor=64, decision=[1, 4, 8, 1])\n", | |
"l9, l10, l11, l12 = sch.split(loop=l2, factors=[v5, v6, v7, v8])\n", | |
"v13, v14, v15, v16 = sch.sample_perfect_tile(loop=l3, n=4, max_innermost_factor=64, decision=[2, 2, 1, 8])\n", | |
"l17, l18, l19, l20 = sch.split(loop=l3, factors=[v13, v14, v15, v16])\n", | |
"v21, v22 = sch.sample_perfect_tile(loop=l4, n=2, max_innermost_factor=64, decision=[4, 8])\n", | |
"l23, l24 = sch.split(loop=l4, factors=[v21, v22])\n", | |
"sch.reorder(l9, l17, l10, l18, l23, l11, l19, l24, l12, l20)\n", | |
"sch.annotate(block_or_loop=b1, ann_key=\"meta_schedule.parallel\", ann_val=256)\n", | |
"sch.annotate(block_or_loop=b1, ann_key=\"meta_schedule.vectorize\", ann_val=64)\n", | |
"v25 = sch.sample_categorical(candidates=[0, 16, 64, 512], probs=[0.25, 0.25, 0.25, 0.25], decision=3)\n", | |
"sch.annotate(block_or_loop=b1, ann_key=\"meta_schedule.unroll_explicit\", ann_val=v25)\n", | |
"[09:22:11] /home/yuchenj/dup/relax/src/meta_schedule/task_scheduler/task_scheduler.cc:96: Initializing Task #1: \"tir_relu\"\n", | |
"[09:22:11] /home/yuchenj/dup/relax/src/meta_schedule/task_scheduler/task_scheduler.cc:102: \n", | |
"@tvm.script.ir_module\n", | |
"class Module:\n", | |
" @tir.prim_func\n", | |
" def main(A_1: tir.Buffer[(32, 32), \"float32\"], B_1: tir.Buffer[(32, 32), \"float32\"]) -> None:\n", | |
" # function attr dict\n", | |
" tir.func_attr({\"global_symbol\": \"tir_relu\"})\n", | |
" # body\n", | |
" # with tir.block(\"root\")\n", | |
" for i, j in tir.grid(32, 32):\n", | |
" with tir.block():\n", | |
" vi, vj = tir.axis.remap(\"SS\", [i, j])\n", | |
" tir.reads(A_1[vi, vj])\n", | |
" tir.writes(B_1[vi, vj])\n", | |
" B_1[vi, vj] = tir.max(A_1[vi, vj], tir.float32(0))\n", | |
" \n", | |
"[09:22:11] /home/yuchenj/dup/relax/src/meta_schedule/task_scheduler/task_scheduler.cc:106: Total 1 design space(s) generated\n", | |
"[09:22:11] /home/yuchenj/dup/relax/src/meta_schedule/task_scheduler/task_scheduler.cc:111: Design space #0:\n", | |
"@tvm.script.ir_module\n", | |
"class Module:\n", | |
" @tir.prim_func\n", | |
" def main(A_1: tir.Buffer[(32, 32), \"float32\"], B_1: tir.Buffer[(32, 32), \"float32\"]) -> None:\n", | |
" # function attr dict\n", | |
" tir.func_attr({\"global_symbol\": \"tir_relu\"})\n", | |
" # body\n", | |
" with tir.block(\"root\"):\n", | |
" tir.reads()\n", | |
" tir.writes()\n", | |
" tir.block_attr({\"meta_schedule.parallel\":256, \"meta_schedule.unroll_explicit\":512, \"meta_schedule.vectorize\":64})\n", | |
" for i, j in tir.grid(32, 32):\n", | |
" with tir.block():\n", | |
" vi, vj = tir.axis.remap(\"SS\", [i, j])\n", | |
" tir.reads(A_1[vi, vj])\n", | |
" tir.writes(B_1[vi, vj])\n", | |
" B_1[vi, vj] = tir.max(A_1[vi, vj], tir.float32(0))\n", | |
" \n", | |
"b0 = sch.get_block(name=\"root\", func_name=\"main\")\n", | |
"sch.annotate(block_or_loop=b0, ann_key=\"meta_schedule.parallel\", ann_val=256)\n", | |
"sch.annotate(block_or_loop=b0, ann_key=\"meta_schedule.vectorize\", ann_val=64)\n", | |
"v1 = sch.sample_categorical(candidates=[0, 16, 64, 512], probs=[0.25, 0.25, 0.25, 0.25], decision=3)\n", | |
"sch.annotate(block_or_loop=b0, ann_key=\"meta_schedule.unroll_explicit\", ann_val=v1)\n", | |
"[09:22:11] /home/yuchenj/dup/relax/src/meta_schedule/task_scheduler/gradient_based.cc:111: \n", | |
" ID | Name | FLOP | Weight | Speed (GFLOPS) | Latency (us) | Weighted Latency (us) | Trials | Terminated \n", | |
"----------------------------------------------------------------------------------------------------------------\n", | |
" 0 | tir_matmul | 65536 | 1 | N/A | N/A | N/A | 0 | \n", | |
" 1 | tir_relu | 1024 | 1 | N/A | N/A | N/A | 0 | \n", | |
"----------------------------------------------------------------------------------------------------------------\n", | |
"Total trials: 0\n", | |
"Total latency (us): 0\n", | |
"\n", | |
"[09:22:11] /home/yuchenj/dup/relax/src/meta_schedule/task_scheduler/task_scheduler.cc:125: Scheduler picks Task #0: \"tir_matmul\"\n", | |
"[09:22:11] /home/yuchenj/dup/relax/src/meta_schedule/search_strategy/evolutionary_search.cc:656: Generating candidates......\n", | |
"[09:22:11] /home/yuchenj/dup/relax/src/meta_schedule/search_strategy/evolutionary_search.cc:658: Picked top 0 candidate(s) from database\n", | |
"[09:22:11] /home/yuchenj/dup/relax/src/meta_schedule/search_strategy/evolutionary_search.cc:494: Sample-Init-Population summary:\n", | |
"Postproc #0 [meta_schedule.DisallowDynamicLoop(0x562e5c06ff48)]: 0 failure(s)\n", | |
"Postproc #1 [meta_schedule.RewriteParallelVectorizeUnroll(0x562e5bf55038)]: 0 failure(s)\n", | |
"Postproc #2 [meta_schedule.RewriteReductionBlock(0x562e5bd8d6e8)]: 0 failure(s)\n", | |
"[09:22:11] /home/yuchenj/dup/relax/src/meta_schedule/search_strategy/evolutionary_search.cc:660: Sampled 2048 candidate(s)\n", | |
"[09:22:11] /home/yuchenj/dup/relax/src/meta_schedule/search_strategy/evolutionary_search.cc:571: Evolve iter #0 done. Summary:\n", | |
"Postproc #0 [meta_schedule.DisallowDynamicLoop(0x562e5c06ff48)]: 0 failure(s)\n", | |
"Postproc #1 [meta_schedule.RewriteParallelVectorizeUnroll(0x562e5bf55038)]: 0 failure(s)\n", | |
"Postproc #2 [meta_schedule.RewriteReductionBlock(0x562e5bd8d6e8)]: 0 failure(s)\n", | |
"[09:22:12] /home/yuchenj/dup/relax/src/meta_schedule/search_strategy/evolutionary_search.cc:571: Evolve iter #1 done. Summary:\n", | |
"Postproc #0 [meta_schedule.DisallowDynamicLoop(0x562e5c06ff48)]: 0 failure(s)\n", | |
"Postproc #1 [meta_schedule.RewriteParallelVectorizeUnroll(0x562e5bf55038)]: 0 failure(s)\n", | |
"Postproc #2 [meta_schedule.RewriteReductionBlock(0x562e5bd8d6e8)]: 0 failure(s)\n", | |
"[09:22:12] /home/yuchenj/dup/relax/src/meta_schedule/search_strategy/evolutionary_search.cc:571: Evolve iter #2 done. Summary:\n", | |
"Postproc #0 [meta_schedule.DisallowDynamicLoop(0x562e5c06ff48)]: 0 failure(s)\n", | |
"Postproc #1 [meta_schedule.RewriteParallelVectorizeUnroll(0x562e5bf55038)]: 0 failure(s)\n", | |
"Postproc #2 [meta_schedule.RewriteReductionBlock(0x562e5bd8d6e8)]: 0 failure(s)\n", | |
"[09:22:13] /home/yuchenj/dup/relax/src/meta_schedule/search_strategy/evolutionary_search.cc:571: Evolve iter #3 done. Summary:\n", | |
"Postproc #0 [meta_schedule.DisallowDynamicLoop(0x562e5c06ff48)]: 0 failure(s)\n", | |
"Postproc #1 [meta_schedule.RewriteParallelVectorizeUnroll(0x562e5bf55038)]: 0 failure(s)\n", | |
"Postproc #2 [meta_schedule.RewriteReductionBlock(0x562e5bd8d6e8)]: 0 failure(s)\n", | |
"[09:22:13] /home/yuchenj/dup/relax/src/meta_schedule/search_strategy/evolutionary_search.cc:595: Scores of the best 2 candidates:\n", | |
"[1 : 2]:\t0.9999 0.9999\n", | |
"[09:22:13] /home/yuchenj/dup/relax/src/meta_schedule/search_strategy/evolutionary_search.cc:664: Got 2 candidate(s) with evolutionary search\n", | |
"[09:22:13] /home/yuchenj/dup/relax/src/meta_schedule/search_strategy/evolutionary_search.cc:666: Sending 2 candidates(s) for measurement\n", | |
"[09:22:13] /home/yuchenj/dup/relax/src/meta_schedule/task_scheduler/task_scheduler.cc:32: Sending 2 sample(s) to builder\n", | |
"[09:22:13] /home/yuchenj/dup/relax/src/meta_schedule/task_scheduler/task_scheduler.cc:53: Sending 2 sample(s) to runner\n", | |
"[09:22:14] /home/yuchenj/dup/relax/src/meta_schedule/task_scheduler/task_scheduler.cc:125: Scheduler picks Task #1: \"tir_relu\"\n", | |
"[09:22:14] /home/yuchenj/dup/relax/src/meta_schedule/search_strategy/evolutionary_search.cc:656: Generating candidates......\n", | |
"[09:22:14] /home/yuchenj/dup/relax/src/meta_schedule/search_strategy/evolutionary_search.cc:658: Picked top 0 candidate(s) from database\n", | |
"[09:22:14] /home/yuchenj/dup/relax/src/meta_schedule/search_strategy/evolutionary_search.cc:494: Sample-Init-Population summary:\n", | |
"Postproc #0 [meta_schedule.DisallowDynamicLoop(0x562e5bc75bf8)]: 0 failure(s)\n", | |
"Postproc #1 [meta_schedule.RewriteParallelVectorizeUnroll(0x562e5bdae618)]: 0 failure(s)\n", | |
"Postproc #2 [meta_schedule.RewriteReductionBlock(0x562e5bd8c008)]: 0 failure(s)\n", | |
"[09:22:14] /home/yuchenj/dup/relax/src/meta_schedule/search_strategy/evolutionary_search.cc:660: Sampled 2048 candidate(s)\n", | |
"[09:22:15] /home/yuchenj/dup/relax/src/meta_schedule/search_strategy/evolutionary_search.cc:571: Evolve iter #0 done. Summary:\n", | |
"Postproc #0 [meta_schedule.DisallowDynamicLoop(0x562e5bc75bf8)]: 0 failure(s)\n", | |
"Postproc #1 [meta_schedule.RewriteParallelVectorizeUnroll(0x562e5bdae618)]: 0 failure(s)\n", | |
"Postproc #2 [meta_schedule.RewriteReductionBlock(0x562e5bd8c008)]: 0 failure(s)\n", | |
"[09:22:15] /home/yuchenj/dup/relax/src/meta_schedule/search_strategy/evolutionary_search.cc:571: Evolve iter #1 done. Summary:\n", | |
"Postproc #0 [meta_schedule.DisallowDynamicLoop(0x562e5bc75bf8)]: 0 failure(s)\n", | |
"Postproc #1 [meta_schedule.RewriteParallelVectorizeUnroll(0x562e5bdae618)]: 0 failure(s)\n", | |
"Postproc #2 [meta_schedule.RewriteReductionBlock(0x562e5bd8c008)]: 0 failure(s)\n", | |
"[09:22:16] /home/yuchenj/dup/relax/src/meta_schedule/search_strategy/evolutionary_search.cc:571: Evolve iter #2 done. Summary:\n", | |
"Postproc #0 [meta_schedule.DisallowDynamicLoop(0x562e5bc75bf8)]: 0 failure(s)\n", | |
"Postproc #1 [meta_schedule.RewriteParallelVectorizeUnroll(0x562e5bdae618)]: 0 failure(s)\n", | |
"Postproc #2 [meta_schedule.RewriteReductionBlock(0x562e5bd8c008)]: 0 failure(s)\n", | |
"[09:22:17] /home/yuchenj/dup/relax/src/meta_schedule/search_strategy/evolutionary_search.cc:571: Evolve iter #3 done. Summary:\n", | |
"Postproc #0 [meta_schedule.DisallowDynamicLoop(0x562e5bc75bf8)]: 0 failure(s)\n", | |
"Postproc #1 [meta_schedule.RewriteParallelVectorizeUnroll(0x562e5bdae618)]: 0 failure(s)\n", | |
"Postproc #2 [meta_schedule.RewriteReductionBlock(0x562e5bd8c008)]: 0 failure(s)\n", | |
"[09:22:17] /home/yuchenj/dup/relax/src/meta_schedule/search_strategy/evolutionary_search.cc:595: Scores of the best 2 candidates:\n", | |
"[1 : 2]:\t0.8936 0.6800\n", | |
"[09:22:17] /home/yuchenj/dup/relax/src/meta_schedule/search_strategy/evolutionary_search.cc:664: Got 2 candidate(s) with evolutionary search\n", | |
"[09:22:17] /home/yuchenj/dup/relax/src/meta_schedule/search_strategy/evolutionary_search.cc:666: Sending 2 candidates(s) for measurement\n", | |
"[09:22:17] /home/yuchenj/dup/relax/src/meta_schedule/task_scheduler/task_scheduler.cc:32: Sending 2 sample(s) to builder\n", | |
"[09:22:17] /home/yuchenj/dup/relax/src/meta_schedule/task_scheduler/task_scheduler.cc:53: Sending 2 sample(s) to runner\n", | |
"[09:22:17] /home/yuchenj/dup/relax/src/meta_schedule/measure_callback/echo_statistics.cc:52: [Task #0: \"tir_matmul\"] Trial #0: GFLOPs: 3.4616. Time: 0.0189 ms. Best GFLOPs: 3.4616\n", | |
"[09:22:17] /home/yuchenj/dup/relax/src/meta_schedule/measure_callback/echo_statistics.cc:52: [Task #0: \"tir_matmul\"] Trial #1: GFLOPs: 4.3685. Time: 0.0150 ms. Best GFLOPs: 4.3685\n", | |
"/home/yuchenj/.local/lib/python3.7/site-packages/pkg_resources/__init__.py:119: PkgResourcesDeprecationWarning: 0.18ubuntu0.18.04.1 is an invalid version and will not be supported in a future release\n", | |
" PkgResourcesDeprecationWarning,\n", | |
"/home/yuchenj/.local/lib/python3.7/site-packages/xgboost/training.py:17: UserWarning: Old style callback is deprecated. See: https://xgboost.readthedocs.io/en/latest/python/callbacks.html\n", | |
" warnings.warn(f'Old style callback is deprecated. See: {link}', UserWarning)\n", | |
"[09:22:18] /home/yuchenj/dup/relax/src/meta_schedule/task_scheduler/gradient_based.cc:172: [Updated] Task #0: \"tir_matmul\"\n", | |
" ID | Name | FLOP | Weight | Speed (GFLOPS) | Latency (us) | Weighted Latency (us) | Trials | Terminated \n", | |
"----------------------------------------------------------------------------------------------------------------\n", | |
" 0 | tir_matmul | 65536 | 1 | 4.3685 | 15.0020 | 15.0020 | 2 | \n", | |
" 1 | tir_relu | 1024 | 1 | N/A | N/A | N/A | 0 | \n", | |
"----------------------------------------------------------------------------------------------------------------\n", | |
"Total trials: 2\n", | |
"Total latency (us): 15.002\n", | |
"\n", | |
"[09:22:18] /home/yuchenj/dup/relax/src/meta_schedule/task_scheduler/task_scheduler.cc:149: Task #0 has finished. Remaining task(s): 1\n", | |
"[09:22:18] /home/yuchenj/dup/relax/src/meta_schedule/measure_callback/echo_statistics.cc:52: [Task #1: \"tir_relu\"] Trial #0: GFLOPs: 0.0400. Time: 0.0256 ms. Best GFLOPs: 0.0400\n", | |
"[09:22:18] /home/yuchenj/dup/relax/src/meta_schedule/measure_callback/echo_statistics.cc:52: [Task #1: \"tir_relu\"] Trial #1: GFLOPs: 0.0836. Time: 0.0122 ms. Best GFLOPs: 0.0836\n", | |
"[09:22:18] /home/yuchenj/dup/relax/src/meta_schedule/task_scheduler/gradient_based.cc:172: [Updated] Task #1: \"tir_relu\"\n", | |
" ID | Name | FLOP | Weight | Speed (GFLOPS) | Latency (us) | Weighted Latency (us) | Trials | Terminated \n", | |
"----------------------------------------------------------------------------------------------------------------\n", | |
" 0 | tir_matmul | 65536 | 1 | 4.3685 | 15.0020 | 15.0020 | 2 | Y \n", | |
" 1 | tir_relu | 1024 | 1 | 0.0836 | 12.2439 | 12.2439 | 2 | \n", | |
"----------------------------------------------------------------------------------------------------------------\n", | |
"Total trials: 4\n", | |
"Total latency (us): 27.2459\n", | |
"\n", | |
"[09:22:18] /home/yuchenj/dup/relax/src/meta_schedule/task_scheduler/task_scheduler.cc:149: Task #1 has finished. Remaining task(s): 0\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/plain": [ | |
"<tvm.nd.NDArray shape=(32, 32), cpu(0)>\n", | |
"array([[ 9.073223 , 9.206748 , 8.944179 , ..., 10.142985 , 9.609654 ,\n", | |
" 9.918516 ],\n", | |
" [ 8.463486 , 9.127991 , 8.411135 , ..., 8.769029 , 8.403664 ,\n", | |
" 9.8124075],\n", | |
" [ 7.24664 , 7.115847 , 7.343585 , ..., 8.223137 , 7.9694247,\n", | |
" 9.62473 ],\n", | |
" ...,\n", | |
" [ 8.291355 , 8.538074 , 7.597375 , ..., 8.382829 , 8.619025 ,\n", | |
" 9.292543 ],\n", | |
" [ 8.903806 , 9.442456 , 7.942852 , ..., 8.686344 , 7.853592 ,\n", | |
" 10.73006 ],\n", | |
" [ 7.6947894, 8.574124 , 7.977135 , ..., 8.067622 , 7.8615704,\n", | |
" 9.567776 ]], dtype=float32)" | |
] | |
}, | |
"execution_count": 13, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"from tvm import meta_schedule as ms\n", | |
"import tempfile\n", | |
"from tvm.meta_schedule.testing import DummyDatabase\n", | |
"\n", | |
"\n", | |
"database = DummyDatabase()\n", | |
"\n", | |
"@tvm.script.ir_module\n", | |
"class InputModule:\n", | |
" @T.prim_func\n", | |
" def tir_matmul(x: T.handle, y: T.handle, z: T.handle) -> None:\n", | |
" T.func_attr({\"global_symbol\": \"tir_matmul\"})\n", | |
" m = T.var(\"int32\")\n", | |
" n = T.var(\"int32\")\n", | |
" k = T.var(\"int32\")\n", | |
" A = T.match_buffer(x, (32, 32))\n", | |
" B = T.match_buffer(y, (32, 32))\n", | |
" C = T.match_buffer(z, (32, 32))\n", | |
"\n", | |
" for (i0, j0, k0) in T.grid(32, 32, 32):\n", | |
" with T.block():\n", | |
" i, j, k = T.axis.remap(\"SSR\", [i0, j0, k0])\n", | |
" with T.init():\n", | |
" C[i, j] = 0.0\n", | |
" C[i, j] += A[i, k] * B[j, k]\n", | |
"\n", | |
" @T.prim_func\n", | |
" def tir_relu(x: T.handle, y: T.handle):\n", | |
" T.func_attr({\"global_symbol\": \"tir_relu\"})\n", | |
" m = T.var(\"int32\")\n", | |
" n = T.var(\"int32\")\n", | |
" A = T.match_buffer(x, (32, 32))\n", | |
" B = T.match_buffer(y, (32, 32))\n", | |
" for (i, j) in T.grid(32, 32):\n", | |
" with T.block():\n", | |
" vi, vj = T.axis.remap(\"SS\", [i, j])\n", | |
" B[vi, vj] = T.max(A[vi, vj], 0.0)\n", | |
"\n", | |
" @R.function\n", | |
" def main(x: Tensor((32, 32), \"float32\"), w: Tensor((32, 32), \"float32\")) -> Tensor:\n", | |
" with R.dataflow():\n", | |
" lv0 = R.call_tir(tir_matmul, (x, w), (32, 32), dtype=\"float32\")\n", | |
" lv1 = R.call_tir(tir_relu, (lv0), (32, 32), dtype=\"float32\")\n", | |
" relax.output(lv1)\n", | |
" return lv1\n", | |
"\n", | |
"mod = InputModule\n", | |
"target = Target(\"llvm --num-cores=16\")\n", | |
"dev = tvm.cpu()\n", | |
"database = DummyDatabase()\n", | |
"\n", | |
"with tempfile.TemporaryDirectory() as work_dir:\n", | |
" relax_ex = ms.tune_relax(\n", | |
" mod=mod,\n", | |
" target=target,\n", | |
" config=ms.EvolutionarySearchConfig(\n", | |
" num_trials_per_iter=2,\n", | |
" max_trials_per_task=4,\n", | |
" max_trials_global=4,\n", | |
" ),\n", | |
" work_dir=work_dir,\n", | |
" database=database,\n", | |
" )\n", | |
"\n", | |
"vm = relax.VirtualMachine(relax_ex, dev)\n", | |
"data = tvm.nd.array(np.random.rand(32, 32).astype(np.float32), dev)\n", | |
"weight = tvm.nd.array(np.random.rand(32, 32).astype(np.float32), dev)\n", | |
"vm[\"main\"](data, weight)" | |
] | |
}, | |
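{ | |
"cell_type": "markdown", | |
"id": "5e6f7a8b", | |
"metadata": {}, | |
"source": [ | |
"As a quick numerical sanity check, the tuned VM result can be compared against a NumPy reference. Recall that `tir_matmul` reduces over `B[j, k]`, so it computes `A @ B.T` rather than `A @ B` (a sketch, assuming the cell above has just been run):\n", | |
"\n", | |
"```python\n", | |
"# NumPy reference for the two kernels: matmul (A @ B.T) followed by ReLU\n", | |
"expected = np.maximum(data.asnumpy() @ weight.asnumpy().T, 0.0)\n", | |
"np.testing.assert_allclose(vm[\"main\"](data, weight).asnumpy(), expected, rtol=1e-5)\n", | |
"```" | |
] | |
}, | |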
{ | |
"cell_type": "markdown", | |
"id": "cc0e909f", | |
"metadata": {}, | |
"source": [ | |
"## Relax minimum build" | |
] | |
}, | |
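{ | |
"cell_type": "markdown", | |
"id": "9c0d1e2f", | |
"metadata": {}, | |
"source": [ | |
"The cells below register a packed function, define a small Relax program, and then lower it one pass at a time so each intermediate IR can be inspected. The same four passes could also be composed into a single pipeline; a sketch using the standard TVM pass infrastructure:\n", | |
"\n", | |
"```python\n", | |
"# the step-by-step lowering below, composed as one pass sequence\n", | |
"lowering = tvm.transform.Sequential(\n", | |
"    [\n", | |
"        relax.transform.ToNonDataflow(),\n", | |
"        relax.transform.CallTIRRewrite(),\n", | |
"        relax.transform.VMMemoryLower(),\n", | |
"        relax.transform.VMShapeLower(),\n", | |
"    ]\n", | |
")\n", | |
"# lowered = lowering(MyIRModule)  # MyIRModule is defined two cells below\n", | |
"```" | |
] | |
}, | |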
{ | |
"cell_type": "code", | |
"execution_count": 14, | |
"id": "b3ab5492", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"@tvm.register_func(\"test.vm.tile\")\n", | |
"def tile_packed(a, b):\n", | |
" b[:] = tvm.nd.array(np.tile(a.asnumpy(), (1, 2)))" | |
] | |
}, | |
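{ | |
"cell_type": "markdown", | |
"id": "3a4b5c6d", | |
"metadata": {}, | |
"source": [ | |
"A registered packed function can also be fetched from the global registry and called directly, which makes it easy to test `test.vm.tile` before wiring it into a Relax program (a sketch; the shapes are arbitrary):\n", | |
"\n", | |
"```python\n", | |
"f = tvm.get_global_func(\"test.vm.tile\")\n", | |
"a = tvm.nd.array(np.random.rand(3, 4).astype(np.float32))\n", | |
"b = tvm.nd.empty((3, 8), dtype=\"float32\")  # destination buffer, twice as wide\n", | |
"f(a, b)\n", | |
"np.testing.assert_allclose(b.asnumpy(), np.tile(a.asnumpy(), (1, 2)))\n", | |
"```" | |
] | |
}, | |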
{ | |
"cell_type": "code", | |
"execution_count": 15, | |
"id": "dfedf857", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"======================\n", | |
"Original Relax Program\n", | |
"\n", | |
"@tvm.script.ir_module\n", | |
"class Module:\n", | |
" @relax.function\n", | |
" def foo(x: Tensor(_, \"float32\")) -> Tensor(_, _):\n", | |
" # block 0\n", | |
" with relax.dataflow():\n", | |
" relax.match_shape(x, (n, m))\n", | |
" y = relax.call_tir(\"test.vm.tile\", x, (n, (m * 2)), dtype=\"float32\")\n", | |
" relax.output(y)\n", | |
" return y\n", | |
" \n" | |
] | |
} | |
], | |
"source": [ | |
"@tvm.script.ir_module\n", | |
"class MyIRModule:\n", | |
" @R.function\n", | |
" def foo(x: Tensor(_, \"float32\")) -> Tensor:\n", | |
" with R.dataflow():\n", | |
" R.match_shape(x, (n, m))\n", | |
" y = R.call_tir(\"test.vm.tile\", (x), (n, m * 2), dtype=\"float32\")\n", | |
" R.output(y)\n", | |
" return y\n", | |
"\n", | |
"# Original Relax Program\n", | |
"print(\"======================\")\n", | |
"print(\"Original Relax Program\\n\")\n", | |
"mod = MyIRModule\n", | |
"code = R.parser.astext(mod)\n", | |
"print(code)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 16, | |
"id": "9fb478f3", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"======================\n", | |
"PASS0: To Non Dataflow\n", | |
"\n", | |
"@tvm.script.ir_module\n", | |
"class Module:\n", | |
" @relax.function\n", | |
" def foo(x: Tensor(_, \"float32\")) -> Tensor(_, _):\n", | |
" # block 0\n", | |
" relax.match_shape(x, (n, m))\n", | |
" y = relax.call_tir(\"test.vm.tile\", x, (n, (m * 2)), dtype=\"float32\")\n", | |
" return y\n", | |
" \n" | |
] | |
} | |
], | |
"source": [ | |
"# ToNonDataflow Pass\n", | |
"print(\"======================\")\n", | |
"print(\"PASS0: To Non Dataflow\\n\")\n", | |
"mod = relax.transform.ToNonDataflow()(mod)\n", | |
"print(R.parser.astext(mod))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 17, | |
"id": "4221356d", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"======================\n", | |
"PASS1: CallDPS Rewrite\n", | |
"\n", | |
"@tvm.script.ir_module\n", | |
"class Module:\n", | |
" @relax.function\n", | |
" def foo(x: Tensor(_, \"float32\")) -> Tensor(_, _):\n", | |
" # block 0\n", | |
" relax.match_shape(x, (n, m))\n", | |
" alloc = relax.builtin.alloc_tensor((n, (m * 2)), dtype=\"float32\", runtime_device_index=0, attrs_type_key=\"relax.attrs.AllocTensorAttrs\")\n", | |
" _ = relax.call_packed(\"test.vm.tile\", x, alloc)\n", | |
" y = alloc\n", | |
" return y\n", | |
" \n" | |
] | |
} | |
], | |
"source": [ | |
"# CallDPS Rewrite\n", | |
"print(\"======================\")\n", | |
"print(\"PASS1: CallDPS Rewrite\\n\")\n", | |
"mod = relax.transform.CallTIRRewrite()(mod)\n", | |
"print(R.parser.astext(mod))" | |
] | |
}, | |
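{ | |
"cell_type": "markdown", | |
"id": "7e8f9a0b", | |
"metadata": {}, | |
"source": [ | |
"The rewrite turns the call into destination-passing style (DPS): the caller allocates the output tensor and passes it in as an extra argument, so the callee never allocates memory itself. A plain-NumPy analogy of the rewritten program above:\n", | |
"\n", | |
"```python\n", | |
"import numpy as np\n", | |
"\n", | |
"def tile_dps(a, out):\n", | |
"    # the callee writes into a pre-allocated output instead of returning a new array\n", | |
"    out[:] = np.tile(a, (1, 2))\n", | |
"\n", | |
"x_np = np.random.rand(3, 4).astype(\"float32\")\n", | |
"alloc = np.empty((3, 8), dtype=\"float32\")  # relax.builtin.alloc_tensor analogue\n", | |
"tile_dps(x_np, alloc)                      # relax.call_packed(\"test.vm.tile\", x, alloc)\n", | |
"y = alloc\n", | |
"```" | |
] | |
}, | |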
{ | |
"cell_type": "code", | |
"execution_count": 18, | |
"id": "f13f5dc5", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"======================\n", | |
"PASS2: Memory Lower\n", | |
"\n", | |
"@tvm.script.ir_module\n", | |
"class Module:\n", | |
" @relax.function\n", | |
" def foo(x: Tensor(_, \"float32\")) -> Tensor(_, _):\n", | |
" # block 0\n", | |
" relax.match_shape(x, (n, m))\n", | |
" storage = relax.vm.builtin.alloc_storage((((n * (m * 2)) * 4),), dtype=\"float32\", runtime_device_index=0, attrs_type_key=\"relax.attrs.VMAllocStorageAttrs\")\n", | |
" tensor = relax.vm.builtin.alloc_tensor(storage, (n, (m * 2)), offset=0, dtype=\"float32\", attrs_type_key=\"relax.attrs.VMAllocTensorAttrs\")\n", | |
" alloc = tensor\n", | |
" _ = relax.call_packed(\"test.vm.tile\", x, alloc)\n", | |
" y = alloc\n", | |
" return y\n", | |
" \n" | |
] | |
} | |
], | |
"source": [ | |
"# Memory Lower\n", | |
"print(\"======================\")\n", | |
"print(\"PASS2: Memory Lower\\n\")\n", | |
"mod = relax.transform.VMMemoryLower()(mod)\n", | |
"print(R.parser.astext(mod))" | |
] | |
}, | |
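{ | |
"cell_type": "markdown", | |
"id": "1f2e3d4c", | |
"metadata": {}, | |
"source": [ | |
"`VMMemoryLower` splits allocation in two: `alloc_storage` reserves a flat buffer of `n * (m * 2) * 4` bytes (float32 is 4 bytes), and `alloc_tensor` creates a typed `(n, 2m)` tensor view into that storage at a given offset. A NumPy analogy:\n", | |
"\n", | |
"```python\n", | |
"import numpy as np\n", | |
"\n", | |
"n, m = 3, 4\n", | |
"# alloc_storage: one flat byte buffer of n * (m * 2) * 4 bytes\n", | |
"storage = np.empty(n * (m * 2) * 4, dtype=\"uint8\")\n", | |
"# alloc_tensor: a float32 view of shape (n, 2m) at offset 0 into that storage\n", | |
"tensor = storage.view(\"float32\").reshape(n, m * 2)\n", | |
"```" | |
] | |
}, | |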
{ | |
"cell_type": "code", | |
"execution_count": 19, | |
"id": "82592858", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"======================\n", | |
"PASS3: Shape Lower\n", | |
"\n", | |
"@tvm.script.ir_module\n", | |
"class Module:\n", | |
" @tir.prim_func\n", | |
" def shape_func(H_1: tir.Buffer[(tir.int64(4),), \"int64\"]) -> None:\n", | |
" # function attr dict\n", | |
" tir.func_attr({\"global_symbol\": \"shape_func\"})\n", | |
" # body\n", | |
" H_1[2] = H_1[0] * (H_1[1] * tir.int64(2)) * tir.int64(4)\n", | |
" \n", | |
" @tir.prim_func\n", | |
" def shape_func1(H_1: tir.Buffer[(tir.int64(4),), \"int64\"]) -> None:\n", | |
" # function attr dict\n", | |
" tir.func_attr({\"global_symbol\": \"shape_func1\"})\n", | |
" # body\n", | |
" H_1[0] = H_1[0]\n", | |
" H_1[3] = H_1[1] * tir.int64(2)\n", | |
" \n", | |
" @relax.function\n", | |
" def foo(x: Tensor(_, \"float32\")) -> Tensor(_, _):\n", | |
" # block 0\n", | |
" shape_heap: Tensor((4,), \"int64\") = relax.call_packed(\"vm.builtin.alloc_shape_heap\", (4,))\n", | |
" # block 1\n", | |
" sh = relax.call_packed(\"vm.builtin.shape_of\", x)\n", | |
" gv = relax.vm.builtin.store_shape(sh, shape_heap, indices=[0, 1], attrs_type_key=\"relax.attrs.ShapeHeapAttrs\")\n", | |
" _ = shape_func(shape_heap)\n", | |
" sh1 = relax.vm.builtin.load_shape(shape_heap, indices=[2], attrs_type_key=\"relax.attrs.ShapeHeapAttrs\")\n", | |
" storage = relax.vm.builtin.alloc_storage(sh1, dtype=\"float32\", runtime_device_index=0, attrs_type_key=\"relax.attrs.VMAllocStorageAttrs\")\n", | |
" _1 = shape_func1(shape_heap)\n", | |
" sh2 = relax.vm.builtin.load_shape(shape_heap, indices=[0, 3], attrs_type_key=\"relax.attrs.ShapeHeapAttrs\")\n", | |
" tensor = relax.vm.builtin.alloc_tensor(storage, sh2, offset=0, dtype=\"float32\", attrs_type_key=\"relax.attrs.VMAllocTensorAttrs\")\n", | |
" alloc = tensor\n", | |
" _2 = relax.call_packed(\"test.vm.tile\", x, alloc)\n", | |
" y = alloc\n", | |
" return y\n", | |
" \n" | |
] | |
} | |
], | |
"source": [ | |
"# Shape Lower\n", | |
"print(\"======================\")\n", | |
"print(\"PASS3: Shape Lower\\n\")\n", | |
"mod = relax.transform.VMShapeLower()(mod)\n", | |
"print(R.parser.astext(mod))" | |
] | |
}, | |
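{ | |
"cell_type": "markdown", | |
"id": "5d6c7b8a", | |
"metadata": {}, | |
"source": [ | |
"`VMShapeLower` makes the symbolic shape arithmetic explicit: a small `int64` shape heap holds the known dimensions, and the generated `shape_func` / `shape_func1` PrimFuncs fill in the derived slots. A NumPy analogy of the heap traffic above:\n", | |
"\n", | |
"```python\n", | |
"import numpy as np\n", | |
"\n", | |
"shape_heap = np.zeros(4, dtype=\"int64\")                  # vm.builtin.alloc_shape_heap((4,))\n", | |
"shape_heap[0], shape_heap[1] = 3, 4                      # store_shape: x's (n, m) at [0, 1]\n", | |
"shape_heap[2] = shape_heap[0] * (shape_heap[1] * 2) * 4  # shape_func: bytes of storage\n", | |
"shape_heap[3] = shape_heap[1] * 2                        # shape_func1: output's second dim\n", | |
"out_shape = (shape_heap[0], shape_heap[3])               # load_shape at indices [0, 3]\n", | |
"```" | |
] | |
}, | |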
{ | |
"cell_type": "code", | |
"execution_count": 20, | |
"id": "63e9ad0a", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"======================\n", | |
"Build & Execute\n", | |
"input: [[0.0434758 0.54549795 0.01151199 0.26422125]\n", | |
" [0.06118515 0.20250504 0.9074851 0.0011075 ]\n", | |
" [0.09346122 0.71868443 0.22026683 0.20938304]]\n", | |
"output: [[0.0434758 0.54549795 0.01151199 0.26422125 0.0434758 0.54549795\n", | |
" 0.01151199 0.26422125]\n", | |
" [0.06118515 0.20250504 0.9074851 0.0011075 0.06118515 0.20250504\n", | |
" 0.9074851 0.0011075 ]\n", | |
" [0.09346122 0.71868443 0.22026683 0.20938304 0.09346122 0.71868443\n", | |
" 0.22026683 0.20938304]]\n" | |
] | |
} | |
], | |
"source": [ | |
"# Build & Execute\n", | |
"print(\"======================\")\n", | |
"print(\"Build & Execute\")\n", | |
"\n", | |
"target = tvm.target.Target(\"llvm\", host=\"llvm\")\n", | |
"ex = relax.vm.build(mod, target)\n", | |
"vm = relax.VirtualMachine(ex, tvm.cpu())\n", | |
"\n", | |
"shape = (3, 4)\n", | |
"inp = tvm.nd.array(np.random.rand(*shape).astype(np.float32))\n", | |
"out = vm[\"foo\"](inp)\n", | |
"print(\"input: \", inp)\n", | |
"print(\"output: \", out)\n", | |
"np.testing.assert_allclose(np.tile(inp.asnumpy(), (1, 2)), out.asnumpy())" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3 (ipykernel)", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.7.13" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 5 | |
} |