{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "d5783aee",
"metadata": {},
"outputs": [],
"source": [
"from __future__ import annotations # must import to defer parsing of annotations\n",
"import os\n",
"import numpy as np\n",
"import tvm\n",
"from tvm.relay import Call\n",
"from tvm import relax, tir, topi\n",
"from tvm.runtime import container\n",
"from tvm.relax.testing import nn\n",
"\n",
"import tvm.script\n",
"from tvm.script import tir as T, relax as R"
]
},
{
"cell_type": "markdown",
"id": "62b4198a",
"metadata": {},
"source": [
"## Build and run a neural network in Relax"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "fd0cc6e9",
"metadata": {},
"outputs": [],
"source": [
"builder = relax.BlockBuilder()\n",
"\n",
"input_size = 784\n",
"hidden_sizes = [128, 32]\n",
"output_size = 10"
]
},
{
"cell_type": "markdown",
"id": "068cded9",
"metadata": {},
"source": [
"Build a three linear-layer neural network for a classification task.\n",
"A neural network is a `nn.Module` that can consist of other modules (layers). This nested structure allows for building complex models easily."
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "a12e45da",
"metadata": {},
"outputs": [],
"source": [
"with builder.function(name=\"main\"):\n",
" model = nn.Sequential(\n",
" nn.Linear(input_size, hidden_sizes[0]),\n",
" nn.ReLU(),\n",
" nn.Linear(hidden_sizes[0], hidden_sizes[1]),\n",
" nn.ReLU(),\n",
" nn.Linear(hidden_sizes[1], output_size),\n",
" nn.LogSoftmax(),\n",
" )\n",
" # n is a symbolic variable to represent a dynamic batch size\n",
" n = tir.Var(\"n\", \"int64\")\n",
" data = nn.Placeholder((n, input_size), name=\"data\")\n",
" output = model(data)\n",
" params = [data] + model.parameters()\n",
" builder.emit_func_output(output, params=params) "
]
},
{
"cell_type": "markdown",
"id": "abde5a76",
"metadata": {},
"source": [
"Get and print the IRmodule being built."
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "54f1bb9e",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"@tvm.script.ir_module\n",
"class Module:\n",
" @tir.prim_func\n",
" def matmul1(var_rxplaceholder_1: tir.handle, rxplaceholder_2: tir.Buffer[(128, 32), \"float32\"], var_T_matmul_1: tir.handle) -> None:\n",
" # function attr dict\n",
" tir.func_attr({\"global_symbol\": \"matmul1\", \"tir.noalias\": True})\n",
" n_1 = tir.var(\"int64\")\n",
" rxplaceholder_3 = tir.match_buffer(var_rxplaceholder_1, [n_1, 128], dtype=\"float32\")\n",
" T_matmul_1 = tir.match_buffer(var_T_matmul_1, [n_1, 32], dtype=\"float32\")\n",
" # body\n",
" # with tir.block(\"root\")\n",
" for i0, i1, i2 in tir.grid(n_1, 32, 128):\n",
" with tir.block(\"T_matmul\"):\n",
" ax0, ax1, k = tir.axis.remap(\"SSR\", [i0, i1, i2])\n",
" tir.reads(T_matmul_1[ax0, ax1], rxplaceholder_3[ax0, k], rxplaceholder_2[k, ax1])\n",
" tir.writes(T_matmul_1[ax0, ax1])\n",
" with tir.init():\n",
" T_matmul_1[ax0, ax1] = tir.float32(0)\n",
" T_matmul_1[ax0, ax1] = T_matmul_1[ax0, ax1] + rxplaceholder_3[ax0, k] * rxplaceholder_2[k, ax1]\n",
" \n",
" @tir.prim_func\n",
" def relu1(var_rxplaceholder_1: tir.handle, var_compute_1: tir.handle) -> None:\n",
" # function attr dict\n",
" tir.func_attr({\"global_symbol\": \"relu1\", \"tir.noalias\": True})\n",
" n_1 = tir.var(\"int64\")\n",
" rxplaceholder_1 = tir.match_buffer(var_rxplaceholder_1, [n_1, 32], dtype=\"float32\")\n",
" compute_1 = tir.match_buffer(var_compute_1, [n_1, 32], dtype=\"float32\")\n",
" # body\n",
" # with tir.block(\"root\")\n",
" for i0, i1 in tir.grid(n_1, 32):\n",
" with tir.block(\"compute\"):\n",
" i0_2, i1_2 = tir.axis.remap(\"SS\", [i0, i1])\n",
" tir.reads(rxplaceholder_1[i0_2, i1_2])\n",
" tir.writes(compute_1[i0_2, i1_2])\n",
" compute_1[i0_2, i1_2] = tir.max(rxplaceholder_1[i0_2, i1_2], tir.float32(0))\n",
" \n",
" @relax.function\n",
" def main(data: Tensor[(n, 784), \"float32\"], linear_weight: Tensor[(784, 128), \"float32\"], linear_bias: Tensor[(128,), \"float32\"], linear_weight1: Tensor[(128, 32), \"float32\"], linear_bias1: Tensor[(32,), \"float32\"], linear_weight2: Tensor[(32, 10), \"float32\"], linear_bias2: Tensor[(10,), \"float32\"]) -> Tensor[_, \"float32\"]:\n",
" # block 0\n",
" gv = relax.call_tir(matmul, (data, linear_weight), (n, 128), dtype=\"float32\")\n",
" gv1 = relax.call_tir(add, (gv, linear_bias), (n, 128), dtype=\"float32\")\n",
" gv2 = relax.call_tir(relu, (gv1,), (n, 128), dtype=\"float32\")\n",
" gv3 = relax.call_tir(matmul1, (gv2, linear_weight1), (n, 32), dtype=\"float32\")\n",
" gv4 = relax.call_tir(add1, (gv3, linear_bias1), (n, 32), dtype=\"float32\")\n",
" gv5 = relax.call_tir(relu1, (gv4,), (n, 32), dtype=\"float32\")\n",
" gv6 = relax.call_tir(matmul2, (gv5, linear_weight2), (n, 10), dtype=\"float32\")\n",
" gv7 = relax.call_tir(add2, (gv6, linear_bias2), (n, 10), dtype=\"float32\")\n",
" gv8 = relax.call_tir(log_softmax, (gv7,), (n, 10), dtype=\"float32\")\n",
" return gv8\n",
" \n",
" @tir.prim_func\n",
" def add(var_rxplaceholder_1: tir.handle, rxplaceholder_2: tir.Buffer[(128,), \"float32\"], var_T_add_1: tir.handle) -> None:\n",
" # function attr dict\n",
" tir.func_attr({\"global_symbol\": \"add\", \"tir.noalias\": True})\n",
" n_1 = tir.var(\"int64\")\n",
" rxplaceholder_3 = tir.match_buffer(var_rxplaceholder_1, [n_1, 128], dtype=\"float32\")\n",
" T_add_1 = tir.match_buffer(var_T_add_1, [n_1, 128], dtype=\"float32\")\n",
" # body\n",
" # with tir.block(\"root\")\n",
" for i0, i1 in tir.grid(n_1, 128):\n",
" with tir.block(\"T_add\"):\n",
" ax0, ax1 = tir.axis.remap(\"SS\", [i0, i1])\n",
" tir.reads(rxplaceholder_3[ax0, ax1], rxplaceholder_2[ax1])\n",
" tir.writes(T_add_1[ax0, ax1])\n",
" T_add_1[ax0, ax1] = rxplaceholder_3[ax0, ax1] + rxplaceholder_2[ax1]\n",
" \n",
" @tir.prim_func\n",
" def add2(var_rxplaceholder_1: tir.handle, rxplaceholder_2: tir.Buffer[(10,), \"float32\"], var_T_add_1: tir.handle) -> None:\n",
" # function attr dict\n",
" tir.func_attr({\"global_symbol\": \"add2\", \"tir.noalias\": True})\n",
" n_1 = tir.var(\"int64\")\n",
" rxplaceholder_3 = tir.match_buffer(var_rxplaceholder_1, [n_1, 10], dtype=\"float32\")\n",
" T_add_1 = tir.match_buffer(var_T_add_1, [n_1, 10], dtype=\"float32\")\n",
" # body\n",
" # with tir.block(\"root\")\n",
" for i0, i1 in tir.grid(n_1, 10):\n",
" with tir.block(\"T_add\"):\n",
" ax0, ax1 = tir.axis.remap(\"SS\", [i0, i1])\n",
" tir.reads(rxplaceholder_3[ax0, ax1], rxplaceholder_2[ax1])\n",
" tir.writes(T_add_1[ax0, ax1])\n",
" T_add_1[ax0, ax1] = rxplaceholder_3[ax0, ax1] + rxplaceholder_2[ax1]\n",
" \n",
" @tir.prim_func\n",
" def add1(var_rxplaceholder_1: tir.handle, rxplaceholder_2: tir.Buffer[(32,), \"float32\"], var_T_add_1: tir.handle) -> None:\n",
" # function attr dict\n",
" tir.func_attr({\"global_symbol\": \"add1\", \"tir.noalias\": True})\n",
" n_1 = tir.var(\"int64\")\n",
" rxplaceholder_3 = tir.match_buffer(var_rxplaceholder_1, [n_1, 32], dtype=\"float32\")\n",
" T_add_1 = tir.match_buffer(var_T_add_1, [n_1, 32], dtype=\"float32\")\n",
" # body\n",
" # with tir.block(\"root\")\n",
" for i0, i1 in tir.grid(n_1, 32):\n",
" with tir.block(\"T_add\"):\n",
" ax0, ax1 = tir.axis.remap(\"SS\", [i0, i1])\n",
" tir.reads(rxplaceholder_3[ax0, ax1], rxplaceholder_2[ax1])\n",
" tir.writes(T_add_1[ax0, ax1])\n",
" T_add_1[ax0, ax1] = rxplaceholder_3[ax0, ax1] + rxplaceholder_2[ax1]\n",
" \n",
" @tir.prim_func\n",
" def matmul2(var_rxplaceholder_1: tir.handle, rxplaceholder_2: tir.Buffer[(32, 10), \"float32\"], var_T_matmul_1: tir.handle) -> None:\n",
" # function attr dict\n",
" tir.func_attr({\"global_symbol\": \"matmul2\", \"tir.noalias\": True})\n",
" n_1 = tir.var(\"int64\")\n",
" rxplaceholder_3 = tir.match_buffer(var_rxplaceholder_1, [n_1, 32], dtype=\"float32\")\n",
" T_matmul_1 = tir.match_buffer(var_T_matmul_1, [n_1, 10], dtype=\"float32\")\n",
" # body\n",
" # with tir.block(\"root\")\n",
" for i0, i1, i2 in tir.grid(n_1, 10, 32):\n",
" with tir.block(\"T_matmul\"):\n",
" ax0, ax1, k = tir.axis.remap(\"SSR\", [i0, i1, i2])\n",
" tir.reads(T_matmul_1[ax0, ax1], rxplaceholder_3[ax0, k], rxplaceholder_2[k, ax1])\n",
" tir.writes(T_matmul_1[ax0, ax1])\n",
" with tir.init():\n",
" T_matmul_1[ax0, ax1] = tir.float32(0)\n",
" T_matmul_1[ax0, ax1] = T_matmul_1[ax0, ax1] + rxplaceholder_3[ax0, k] * rxplaceholder_2[k, ax1]\n",
" \n",
" @tir.prim_func\n",
" def relu(var_rxplaceholder_1: tir.handle, var_compute_1: tir.handle) -> None:\n",
" # function attr dict\n",
" tir.func_attr({\"global_symbol\": \"relu\", \"tir.noalias\": True})\n",
" n_1 = tir.var(\"int64\")\n",
" rxplaceholder_1 = tir.match_buffer(var_rxplaceholder_1, [n_1, 128], dtype=\"float32\")\n",
" compute_1 = tir.match_buffer(var_compute_1, [n_1, 128], dtype=\"float32\")\n",
" # body\n",
" # with tir.block(\"root\")\n",
" for i0, i1 in tir.grid(n_1, 128):\n",
" with tir.block(\"compute\"):\n",
" i0_2, i1_2 = tir.axis.remap(\"SS\", [i0, i1])\n",
" tir.reads(rxplaceholder_1[i0_2, i1_2])\n",
" tir.writes(compute_1[i0_2, i1_2])\n",
" compute_1[i0_2, i1_2] = tir.max(rxplaceholder_1[i0_2, i1_2], tir.float32(0))\n",
" \n",
" @tir.prim_func\n",
" def matmul(var_rxplaceholder_1: tir.handle, rxplaceholder_2: tir.Buffer[(784, 128), \"float32\"], var_T_matmul_1: tir.handle) -> None:\n",
" # function attr dict\n",
" tir.func_attr({\"global_symbol\": \"matmul\", \"tir.noalias\": True})\n",
" n_1 = tir.var(\"int64\")\n",
" rxplaceholder_3 = tir.match_buffer(var_rxplaceholder_1, [n_1, 784], dtype=\"float32\")\n",
" T_matmul_1 = tir.match_buffer(var_T_matmul_1, [n_1, 128], dtype=\"float32\")\n",
" # body\n",
" # with tir.block(\"root\")\n",
" for i0, i1, i2 in tir.grid(n_1, 128, 784):\n",
" with tir.block(\"T_matmul\"):\n",
" ax0, ax1, k = tir.axis.remap(\"SSR\", [i0, i1, i2])\n",
" tir.reads(T_matmul_1[ax0, ax1], rxplaceholder_3[ax0, k], rxplaceholder_2[k, ax1])\n",
" tir.writes(T_matmul_1[ax0, ax1])\n",
" with tir.init():\n",
" T_matmul_1[ax0, ax1] = tir.float32(0)\n",
" T_matmul_1[ax0, ax1] = T_matmul_1[ax0, ax1] + rxplaceholder_3[ax0, k] * rxplaceholder_2[k, ax1]\n",
" \n",
" @tir.prim_func\n",
" def log_softmax(var_rxplaceholder_1: tir.handle, var_compute_1: tir.handle) -> None:\n",
" # function attr dict\n",
" tir.func_attr({\"global_symbol\": \"log_softmax\", \"tir.noalias\": True})\n",
" n_1 = tir.var(\"int64\")\n",
" rxplaceholder_1 = tir.match_buffer(var_rxplaceholder_1, [n_1, 10], dtype=\"float32\")\n",
" compute_3 = tir.match_buffer(var_compute_1, [n_1, 10], dtype=\"float32\")\n",
" # body\n",
" # with tir.block(\"root\")\n",
" compute_4 = tir.alloc_buffer([n_1], dtype=\"float32\")\n",
" compute_5 = tir.alloc_buffer([n_1], dtype=\"float32\")\n",
" for i0, i1 in tir.grid(n_1, 10):\n",
" with tir.block(\"compute\"):\n",
" i, k = tir.axis.remap(\"SR\", [i0, i1])\n",
" tir.reads(compute_4[i], rxplaceholder_1[i, k])\n",
" tir.writes(compute_4[i])\n",
" with tir.init():\n",
" compute_4[i] = tir.float32(-3.4028234663852886e+38)\n",
" compute_4[i] = tir.max(compute_4[i], rxplaceholder_1[i, k])\n",
" for i0, i1 in tir.grid(n_1, 10):\n",
" with tir.block(\"compute\"):\n",
" i, k = tir.axis.remap(\"SR\", [i0, i1])\n",
" tir.reads(compute_5[i], rxplaceholder_1[i, k], compute_4[i])\n",
" tir.writes(compute_5[i])\n",
" with tir.init():\n",
" compute_5[i] = tir.float32(0)\n",
" compute_5[i] = compute_5[i] + tir.exp(rxplaceholder_1[i, k] - compute_4[i], dtype=\"float32\")\n",
" for i0, i1 in tir.grid(n_1, 10):\n",
" with tir.block(\"compute\"):\n",
" i, j = tir.axis.remap(\"SS\", [i0, i1])\n",
" tir.reads(rxplaceholder_1[i, j], compute_4[i], compute_5[i])\n",
" tir.writes(compute_3[i, j])\n",
" compute_3[i, j] = rxplaceholder_1[i, j] - compute_4[i] - tir.log(compute_5[i], dtype=\"float32\")\n",
" \n"
]
}
],
"source": [
"mod = builder.get()\n",
"print(R.parser.astext(mod))"
]
},
{
"cell_type": "markdown",
"id": "158875f7",
"metadata": {},
"source": [
"Build the model and create Relax VM."
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "b6e7a8cd",
"metadata": {},
"outputs": [],
"source": [
"# build and create vm executor\n",
"target = tvm.target.Target(\"llvm\")\n",
"ex = relax.vm.build(mod, target)\n",
"vm = relax.VirtualMachine(ex, tvm.cpu())"
]
},
{
"cell_type": "markdown",
"id": "7679f077",
"metadata": {},
"source": [
"Run the model on Relax VM."
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "1e530019",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[-2.3025851 -2.3025851 -2.3025851 -2.3025851 -2.3025851 -2.3025851\n",
" -2.3025851 -2.3025851 -2.3025851 -2.3025851]\n",
" [-2.3025851 -2.3025851 -2.3025851 -2.3025851 -2.3025851 -2.3025851\n",
" -2.3025851 -2.3025851 -2.3025851 -2.3025851]\n",
" [-2.3025851 -2.3025851 -2.3025851 -2.3025851 -2.3025851 -2.3025851\n",
" -2.3025851 -2.3025851 -2.3025851 -2.3025851]]\n"
]
}
],
"source": [
"# init parameters\n",
"params = nn.init_params(mod)\n",
"# the input data has a minibatch size of 3\n",
"data = tvm.nd.array(np.random.rand(3, input_size).astype(np.float32))\n",
"\n",
"res = vm[\"main\"](data, *params)\n",
"print(res)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "dfe290fb",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[-2.3025851 -2.3025851 -2.3025851 -2.3025851 -2.3025851 -2.3025851\n",
" -2.3025851 -2.3025851 -2.3025851 -2.3025851]\n",
" [-2.3025851 -2.3025851 -2.3025851 -2.3025851 -2.3025851 -2.3025851\n",
" -2.3025851 -2.3025851 -2.3025851 -2.3025851]\n",
" [-2.3025851 -2.3025851 -2.3025851 -2.3025851 -2.3025851 -2.3025851\n",
" -2.3025851 -2.3025851 -2.3025851 -2.3025851]\n",
" [-2.3025851 -2.3025851 -2.3025851 -2.3025851 -2.3025851 -2.3025851\n",
" -2.3025851 -2.3025851 -2.3025851 -2.3025851]\n",
" [-2.3025851 -2.3025851 -2.3025851 -2.3025851 -2.3025851 -2.3025851\n",
" -2.3025851 -2.3025851 -2.3025851 -2.3025851]]\n"
]
}
],
"source": [
"# the input data has a minibatch size of 5\n",
"data = tvm.nd.array(np.random.rand(5, input_size).astype(np.float32))\n",
"\n",
"res = vm[\"main\"](data, *params)\n",
"print(res)"
]
},
{
"cell_type": "markdown",
"id": "23b4c322",
"metadata": {},
"source": [
"Define a layer/network with emit_te using the nn.Module interface."
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "eb687277",
"metadata": {},
"outputs": [],
"source": [
"class Linear(nn.Module):\n",
" \"\"\"Applies a linear transformation to the input data: :math:`y = xA + b`.\"\"\"\n",
"\n",
" def __init__(self, in_features, out_features, bias=True):\n",
" self.in_features = in_features\n",
" self.out_features = out_features\n",
" self.weight = Parameter((in_features, out_features), name=\"linear_weight\")\n",
" if bias:\n",
" self.bias = Parameter((out_features,), name=\"linear_bias\")\n",
" else:\n",
" self.bias = None\n",
"\n",
" def forward(self, input: relax.Expr) -> relax.Var:\n",
" y = emit_te(topi.matmul, input, self.weight)\n",
" if self.bias is not None:\n",
" y = emit_te(topi.add, y, self.bias)\n",
" return y"
]
},
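{
"cell_type": "markdown",
"id": "a3f91b27",
"metadata": {},
"source": [
"The custom `Linear` layer above is defined but not exercised elsewhere in this notebook. The next cell is a minimal, hypothetical sketch (reusing the `BlockBuilder`/`nn.Placeholder` pattern and the sizes from the first example) of how it could be wired into a Relax function; the function name `linear_demo` is illustrative only and not part of the original."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c8d402e5",
"metadata": {},
"outputs": [],
"source": [
"# Hypothetical usage sketch for the custom Linear layer defined above,\n",
"# mirroring the earlier BlockBuilder + nn.Placeholder pattern.\n",
"bb = relax.BlockBuilder()\n",
"n = tir.Var(\"n\", \"int64\")  # symbolic batch size, as in the first example\n",
"with bb.function(name=\"linear_demo\"):\n",
"    layer = Linear(input_size, hidden_sizes[0])\n",
"    data = nn.Placeholder((n, input_size), name=\"data\")\n",
"    output = layer(data)\n",
"    params = [data] + layer.parameters()\n",
"    bb.emit_func_output(output, params=params)\n",
"print(R.parser.astext(bb.get()))"
]
},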
{
"cell_type": "markdown",
"id": "89e095a9",
"metadata": {},
"source": [
"## TE/TOPI Integration"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "7ca3cfdb",
"metadata": {},
"outputs": [],
"source": [
"def build_mlp(data, weight):\n",
" bb = relax.BlockBuilder()\n",
"\n",
" with bb.function(\"mlp\", [data, weight]):\n",
" gv0 = bb.emit_te(tvm.contrib.cblas.matmul, data, weight, transa=False, transb=False)\n",
" gv1 = bb.emit_te(topi.nn.relu, gv0)\n",
" bb.emit_func_output(gv1)\n",
"\n",
" mod = bb.get()\n",
" return mod"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "c5c1e8f2",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"@tvm.script.ir_module\n",
"class Module:\n",
" @tir.prim_func\n",
" def matmul(var_rxplaceholder_2: tir.handle, var_rxplaceholder_3: tir.handle, var_C_1: tir.handle) -> None:\n",
" # function attr dict\n",
" tir.func_attr({\"global_symbol\": \"matmul\", \"tir.noalias\": True})\n",
" m_1 = tir.var(\"int64\")\n",
" n_1 = tir.var(\"int64\")\n",
" rxplaceholder_2 = tir.match_buffer(var_rxplaceholder_2, [n_1, m_1], dtype=\"float32\")\n",
" rxplaceholder_3 = tir.match_buffer(var_rxplaceholder_3, [m_1, n_1], dtype=\"float32\")\n",
" C_1 = tir.match_buffer(var_C_1, [n_1, n_1], dtype=\"float32\")\n",
" # body\n",
" # with tir.block(\"root\")\n",
" tir.evaluate(tir.tvm_call_packed(\"tvm.contrib.cblas.matmul\", tir.tvm_stack_make_array(rxplaceholder_2.data, tir.tvm_stack_make_shape(n_1, m_1, dtype=\"handle\"), 0, 2, tir.float32(0), tir.int64(0), dtype=\"handle\"), tir.tvm_stack_make_array(rxplaceholder_3.data, tir.tvm_stack_make_shape(m_1, n_1, dtype=\"handle\"), 0, 2, tir.float32(0), tir.int64(0), dtype=\"handle\"), tir.tvm_stack_make_array(C_1.data, tir.tvm_stack_make_shape(n_1, n_1, dtype=\"handle\"), 0, 2, tir.float32(0), tir.int64(0), dtype=\"handle\"), False, False, dtype=\"int32\"))\n",
" \n",
" @tir.prim_func\n",
" def relu(var_rxplaceholder_1: tir.handle, var_compute_1: tir.handle) -> None:\n",
" # function attr dict\n",
" tir.func_attr({\"global_symbol\": \"relu\", \"tir.noalias\": True})\n",
" n_1 = tir.var(\"int64\")\n",
" rxplaceholder_1 = tir.match_buffer(var_rxplaceholder_1, [n_1, n_1], dtype=\"float32\")\n",
" compute_1 = tir.match_buffer(var_compute_1, [n_1, n_1], dtype=\"float32\")\n",
" # body\n",
" # with tir.block(\"root\")\n",
" for i0, i1 in tir.grid(n_1, n_1):\n",
" with tir.block(\"compute\"):\n",
" i0_2, i1_2 = tir.axis.remap(\"SS\", [i0, i1])\n",
" tir.reads(rxplaceholder_1[i0_2, i1_2])\n",
" tir.writes(compute_1[i0_2, i1_2])\n",
" compute_1[i0_2, i1_2] = tir.max(rxplaceholder_1[i0_2, i1_2], tir.float32(0))\n",
" \n",
" @relax.function\n",
" def mlp(data: Tensor[(n, m), \"float32\"], weight: Tensor[(m, n), \"float32\"]) -> Tensor[_, \"float32\"]:\n",
" # block 0\n",
" gv = relax.call_tir(matmul, (data, weight), (n, n), dtype=\"float32\")\n",
" gv1 = relax.call_tir(relu, (gv,), (n, n), dtype=\"float32\")\n",
" return gv1\n",
" \n"
]
}
],
"source": [
"# symbolic dimensions\n",
"n, m = tir.Var(\"n\", \"int64\"), tir.Var(\"m\", \"int64\")\n",
"\n",
"# create data and weight variables\n",
"data = relax.Var(\"data\", [n, m], relax.DynTensorType(2, \"float32\"))\n",
"weight = relax.Var(\"weight\", [m, n], relax.DynTensorType(2, \"float32\"))\n",
"\n",
"# construct a mlp model\n",
"mod = build_mlp(data, weight)\n",
"print(R.parser.astext(mod))\n",
"\n",
"# build and create vm executor\n",
"target = tvm.target.Target(\"llvm\", host=\"llvm\")\n",
"ex = relax.vm.build(mod, target)\n",
"vm = relax.VirtualMachine(ex, tvm.cpu())"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "8102fabf",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[ 8.556568 6.87848 8.625866 9.600383 8.317718 8.227419\n",
" 9.219838 8.9676075 7.623468 8.23633 7.0235295 8.507209\n",
" 9.027352 7.290115 8.554563 7.6762805]\n",
" [ 6.663047 7.8030357 8.069457 8.890441 7.939424 8.519871\n",
" 9.274455 8.174765 7.9989533 7.5531206 8.587893 8.444625\n",
" 9.052873 7.5767756 8.191356 7.0712266]\n",
" [ 7.403376 7.207227 8.18778 8.746606 9.113302 7.779246\n",
" 8.831476 7.4629865 6.509424 7.559846 7.511036 7.851724\n",
" 8.445092 6.7876825 8.292095 7.063372 ]\n",
" [ 7.427354 7.0812135 7.5834956 9.232357 7.942724 8.243788\n",
" 8.2506075 9.12499 7.663587 8.29403 7.9033995 7.4146137\n",
" 8.90403 7.4565487 7.7196827 7.277466 ]\n",
" [ 9.678751 9.520125 11.651919 11.137583 11.08173 10.491489\n",
" 11.80332 10.800545 10.412642 11.5577135 9.599909 10.551059\n",
" 11.7693205 9.735142 11.261857 10.543016 ]\n",
" [ 6.8073425 5.918911 7.9557896 8.429453 7.6718554 7.034882\n",
" 8.698473 7.8499546 7.649053 7.733601 6.118468 7.5574713\n",
" 7.795342 6.8316064 7.1764045 7.349864 ]\n",
" [ 7.958915 6.212535 8.084226 8.112065 8.320807 8.080606\n",
" 9.801969 8.86682 9.270056 8.871507 7.9004436 7.482185\n",
" 8.870282 7.91054 8.001995 7.3347917]\n",
" [ 5.7960997 5.928576 7.2521763 6.63397 6.830027 5.9787736\n",
" 7.379688 7.4080906 4.9670863 6.8002133 5.4344077 5.761625\n",
" 7.6015153 5.913869 6.8978233 6.2877145]\n",
" [ 8.300703 7.4624095 8.288098 8.837615 9.562206 8.780277\n",
" 9.527164 8.587998 7.1991935 7.7677937 7.8223143 7.735983\n",
" 8.476541 7.170992 8.453187 8.275 ]\n",
" [ 8.490779 9.189245 10.063091 10.9346 9.817375 10.8089695\n",
" 11.861148 11.527818 9.8396015 11.470564 8.860413 9.731347\n",
" 11.540259 8.995022 10.125538 9.088766 ]\n",
" [ 7.4908876 7.6106896 8.144763 8.726535 8.755715 8.209023\n",
" 8.834324 7.8766665 6.288496 8.164509 7.5905747 8.019636\n",
" 9.01565 6.7951274 8.550508 7.674556 ]\n",
" [ 6.193086 5.209443 7.299016 6.642886 6.5502434 6.89572\n",
" 8.01048 7.1452103 6.3567533 7.611481 6.6744666 5.2749963\n",
" 6.601325 5.878633 7.0573955 6.89116 ]\n",
" [ 7.7727585 6.6690907 9.18409 8.672666 7.804694 8.393168\n",
" 10.026285 8.695979 7.850456 8.519929 7.7823 7.553142\n",
" 8.590052 7.1167426 7.9311466 8.562405 ]\n",
" [ 8.467529 8.3538885 10.179802 9.085586 9.400566 7.721974\n",
" 10.036323 8.895725 7.5192246 8.652134 7.3148284 8.582638\n",
" 9.310177 8.01297 9.493657 8.240374 ]\n",
" [ 8.044313 6.7396564 8.170473 8.387967 8.349282 8.374153\n",
" 9.131376 8.743106 7.3198457 8.30513 6.786698 8.167014\n",
" 8.164466 6.805107 8.06095 7.866236 ]\n",
" [ 7.4279284 7.7251635 8.799779 9.469591 8.381312 7.999055\n",
" 9.07975 8.277402 6.9867654 7.9436297 8.238659 8.03124\n",
" 9.060811 6.9540024 9.092028 7.9286 ]]\n"
]
}
],
"source": [
"# run the mlp model on relax vm\n",
"data = tvm.nd.array(np.random.rand(16, 32).astype(np.float32))\n",
"weight = tvm.nd.array(np.random.rand(32, 16).astype(np.float32))\n",
"res = vm[\"mlp\"](data, weight)\n",
"print(res)"
]
},
{
"cell_type": "markdown",
"id": "cc0e909f",
"metadata": {},
"source": [
"## Relax compilation and execution workflow"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "b3ab5492",
"metadata": {},
"outputs": [],
"source": [
"@tvm.register_func(\"test.vm.tile\")\n",
"def tile_packed(a, b):\n",
" b[:] = tvm.nd.array(np.tile(a.asnumpy(), (1, 2)))"
]
},
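{
"cell_type": "markdown",
"id": "1f6b9a3c",
"metadata": {},
"source": [
"As a quick, hypothetical sanity check (not part of the original notebook), the registered packed function can also be fetched with `tvm.get_global_func` and called directly with a pre-allocated output tensor. This illustrates the destination-passing convention that `relax.call_tir` is lowered to in the passes below."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e27d5c40",
"metadata": {},
"outputs": [],
"source": [
"# Hypothetical direct call of the packed function registered above.\n",
"# The output tensor b is allocated by the caller (destination-passing style).\n",
"tile_fn = tvm.get_global_func(\"test.vm.tile\")\n",
"a = tvm.nd.array(np.random.rand(3, 4).astype(np.float32))\n",
"b = tvm.nd.empty((3, 8), \"float32\")  # tiling by (1, 2) doubles the last dimension\n",
"tile_fn(a, b)\n",
"np.testing.assert_allclose(np.tile(a.asnumpy(), (1, 2)), b.asnumpy())"
]
},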
{
"cell_type": "code",
"execution_count": 13,
"id": "dfedf857",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"======================\n",
"Original Relax Program\n",
"\n",
"@tvm.script.ir_module\n",
"class Module:\n",
" @relax.function\n",
" def foo(x: Tensor[(n, m), \"float32\"]) -> Tensor[_, _]:\n",
" # block 0\n",
" with relax.dataflow():\n",
" y = relax.call_tir(\"test.vm.tile\", x, (n, (m * 2)), dtype=\"float32\")\n",
" relax.output(y)\n",
" return y\n",
" \n"
]
}
],
"source": [
"src = \"\"\"@tvm.script.ir_module\n",
"class InputModule:\n",
" @R.function\n",
" def foo(x: Tensor[(n, m), \"float32\"]) -> Tensor:\n",
" with relax.dataflow():\n",
" y = R.call_tir(\"test.vm.tile\", (x), (n, m * 2), dtype=\"float32\")\n",
" relax.output(y)\n",
" return y\n",
"\"\"\"\n",
"\n",
"# Original Relax Program\n",
"print(\"======================\")\n",
"print(\"Original Relax Program\\n\")\n",
"mod = R.parser.from_source(src)\n",
"code = R.parser.astext(mod)\n",
"print(code)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "9fb478f3",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"======================\n",
"PASS0: To Non Dataflow\n",
"\n",
"@tvm.script.ir_module\n",
"class Module:\n",
" @relax.function\n",
" def foo(x: Tensor[(n, m), \"float32\"]) -> Tensor[_, _]:\n",
" # block 0\n",
" y = relax.call_tir(\"test.vm.tile\", x, (n, (m * 2)), dtype=\"float32\")\n",
" return y\n",
" \n"
]
}
],
"source": [
"# ToNonDataflow Pass\n",
"print(\"======================\")\n",
"print(\"PASS0: To Non Dataflow\\n\")\n",
"mod = relax.transform.ToNonDataflow()(mod)\n",
"print(R.parser.astext(mod))"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "4221356d",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"======================\n",
"PASS1: CallDPS Rewrite\n",
"\n",
"@tvm.script.ir_module\n",
"class Module:\n",
" @relax.function\n",
" def foo(x: Tensor[(n, m), \"float32\"]) -> Tensor[_, _]:\n",
" # block 0\n",
" alloc = relax.builtin.alloc_tensor((n, (m * 2)), dtype=\"float32\", attrs_type_key=\"relax.attrs.AllocTensorAttrs\")\n",
" _ = relax.call_packed(\"test.vm.tile\", x, alloc)\n",
" y = alloc\n",
" return y\n",
" \n"
]
}
],
"source": [
"# CallDPS Rewrite\n",
"print(\"======================\")\n",
"print(\"PASS1: CallDPS Rewrite\\n\")\n",
"mod = relax.transform.CallTIRRewrite()(mod)\n",
"print(R.parser.astext(mod))"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "f13f5dc5",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"======================\n",
"PASS2: Memory Lower\n",
"\n",
"@tvm.script.ir_module\n",
"class Module:\n",
" @relax.function\n",
" def foo(x: Tensor[(n, m), \"float32\"]) -> Tensor[_, _]:\n",
" # block 0\n",
" storage = relax.vm.builtin.alloc_storage((((n * (m * 2)) * 4),), device_type=1, dtype=\"float32\", attrs_type_key=\"relax.attrs.VMAllocStorageAttrs\")\n",
" tensor = relax.vm.builtin.alloc_tensor(storage, (n, (m * 2)), offset=0, dtype=\"float32\", attrs_type_key=\"relax.attrs.VMAllocTensorAttrs\")\n",
" alloc = tensor\n",
" _ = relax.call_packed(\"test.vm.tile\", x, alloc)\n",
" y = alloc\n",
" return y\n",
" \n"
]
}
],
"source": [
"# Memory Lower\n",
"print(\"======================\")\n",
"print(\"PASS2: Memory Lower\\n\")\n",
"mod = relax.transform.VMMemoryLower()(mod)\n",
"print(R.parser.astext(mod))"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "82592858",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"======================\n",
"PASS3: Shape Lower\n",
"\n",
"@tvm.script.ir_module\n",
"class Module:\n",
" @tir.prim_func\n",
" def shape_func(H_1: tir.Buffer[(tir.int64(4),), \"int64\"]) -> None:\n",
" # function attr dict\n",
" tir.func_attr({\"global_symbol\": \"shape_func\"})\n",
" # body\n",
" H_1[2] = H_1[0] * (H_1[1] * tir.int64(2)) * tir.int64(4)\n",
" \n",
" @tir.prim_func\n",
" def shape_func1(H_1: tir.Buffer[(tir.int64(4),), \"int64\"]) -> None:\n",
" # function attr dict\n",
" tir.func_attr({\"global_symbol\": \"shape_func1\"})\n",
" # body\n",
" H_1[0] = H_1[0]\n",
" H_1[3] = H_1[1] * tir.int64(2)\n",
" \n",
" @relax.function\n",
" def foo(x: Tensor[(n, m), \"float32\"]) -> Tensor[_, _]:\n",
" # block 0\n",
" shape_heap: Tensor[(4,), \"int64\"] = relax.call_packed(\"vm.builtin.alloc_shape_heap\", (4,))\n",
" sh = relax.call_packed(\"vm.builtin.shape_of\", x)\n",
" gv = relax.vm.builtin.store_shape(sh, shape_heap, indices=[0, 1], attrs_type_key=\"relax.attrs.ShapeHeapAttrs\")\n",
" # block 1\n",
" _ = shape_func(shape_heap)\n",
" sh1 = relax.vm.builtin.load_shape(shape_heap, indices=[2], attrs_type_key=\"relax.attrs.ShapeHeapAttrs\")\n",
" storage = relax.vm.builtin.alloc_storage(sh1, device_type=1, dtype=\"float32\", attrs_type_key=\"relax.attrs.VMAllocStorageAttrs\")\n",
" _1 = shape_func1(shape_heap)\n",
" sh2 = relax.vm.builtin.load_shape(shape_heap, indices=[0, 3], attrs_type_key=\"relax.attrs.ShapeHeapAttrs\")\n",
" tensor = relax.vm.builtin.alloc_tensor(storage, sh2, offset=0, dtype=\"float32\", attrs_type_key=\"relax.attrs.VMAllocTensorAttrs\")\n",
" alloc = tensor\n",
" _2 = relax.call_packed(\"test.vm.tile\", x, alloc)\n",
" y = alloc\n",
" return y\n",
" \n"
]
}
],
"source": [
"# Shape Lower\n",
"print(\"======================\")\n",
"print(\"PASS3: Shape Lower\\n\")\n",
"mod = relax.transform.VMShapeLower()(mod)\n",
"print(R.parser.astext(mod))"
]
},
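{
"cell_type": "markdown",
"id": "4b0a8d19",
"metadata": {},
"source": [
"The four lowering steps above were applied one at a time for illustration. As a hypothetical variant (not in the original notebook), they could also be composed into a single pipeline with `tvm.transform.Sequential`, assuming each `relax.transform` pass behaves as an ordinary `tvm.ir.transform` module pass."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d95e2f71",
"metadata": {},
"outputs": [],
"source": [
"# Hypothetical: compose the same lowering sequence into one pipeline.\n",
"mod2 = R.parser.from_source(src)\n",
"seq = tvm.transform.Sequential(\n",
"    [\n",
"        relax.transform.ToNonDataflow(),\n",
"        relax.transform.CallTIRRewrite(),\n",
"        relax.transform.VMMemoryLower(),\n",
"        relax.transform.VMShapeLower(),\n",
"    ]\n",
")\n",
"mod2 = seq(mod2)\n",
"print(R.parser.astext(mod2))"
]
},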
{
"cell_type": "code",
"execution_count": 18,
"id": "63e9ad0a",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"======================\n",
"Build & Execute\n",
"input: [[0.46070135 0.1330114 0.8229871 0.3692288 ]\n",
" [0.91415226 0.24083617 0.25455388 0.40971643]\n",
" [0.03785264 0.329651 0.40605137 0.05099611]]\n",
"output: [[0.46070135 0.1330114 0.8229871 0.3692288 0.46070135 0.1330114\n",
" 0.8229871 0.3692288 ]\n",
" [0.91415226 0.24083617 0.25455388 0.40971643 0.91415226 0.24083617\n",
" 0.25455388 0.40971643]\n",
" [0.03785264 0.329651 0.40605137 0.05099611 0.03785264 0.329651\n",
" 0.40605137 0.05099611]]\n"
]
}
],
"source": [
"# Build & Execute\n",
"print(\"======================\")\n",
"print(\"Build & Execute\")\n",
"\n",
"target = tvm.target.Target(\"llvm\", host=\"llvm\")\n",
"ex = relax.vm.build(mod, target)\n",
"vm = relax.VirtualMachine(ex, tvm.cpu())\n",
"\n",
"shape = (3, 4)\n",
"inp = tvm.nd.array(np.random.rand(*shape).astype(np.float32))\n",
"out = vm[\"foo\"](inp)\n",
"print(\"input: \", inp)\n",
"print(\"output: \", out)\n",
"np.testing.assert_allclose(np.tile(inp.asnumpy(), (1, 2)), out.asnumpy())"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.10"
}
},
"nbformat": 4,
"nbformat_minor": 5
}