{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "d5783aee",
"metadata": {},
"outputs": [],
"source": [
"from __future__ import annotations  # must import to defer parsing of annotations\n",
"import os\n",
"import numpy as np\n",
"import tvm\n",
"from tvm.relay import Call\n",
"from tvm import relax, tir, topi\n",
"from tvm.runtime import container\n",
"from tvm.relax.testing import nn\n",
"\n",
"import tvm.script\n",
"from tvm.script import tir as T, relax as R"
]
},
{
"cell_type": "markdown",
"id": "62b4198a",
"metadata": {},
"source": [
"## Build and run a neural network in Relax"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "fd0cc6e9",
"metadata": {},
"outputs": [],
"source": [
"builder = relax.BlockBuilder()\n",
"\n",
"input_size = 784\n",
"hidden_sizes = [128, 32]\n",
"output_size = 10"
]
},
{
"cell_type": "markdown",
"id": "068cded9",
"metadata": {},
"source": [
"Build a neural network with three linear layers for a classification task.\n",
"A neural network is an `nn.Module` that can consist of other modules (layers). This nested structure makes it easy to build complex models."
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "a12e45da",
"metadata": {},
"outputs": [],
"source": [
"with builder.function(name=\"main\"):\n",
"    model = nn.Sequential(\n",
"        nn.Linear(input_size, hidden_sizes[0]),\n",
"        nn.ReLU(),\n",
"        nn.Linear(hidden_sizes[0], hidden_sizes[1]),\n",
"        nn.ReLU(),\n",
"        nn.Linear(hidden_sizes[1], output_size),\n",
"        nn.LogSoftmax(),\n",
"    )\n",
"    # n is a symbolic variable to represent a dynamic batch size\n",
"    n = tir.Var(\"n\", \"int64\")\n",
"    data = nn.Placeholder((n, input_size), name=\"data\")\n",
"    output = model(data)\n",
"    params = [data] + model.parameters()\n",
"    builder.emit_func_output(output, params=params)"
]
},
{
"cell_type": "markdown",
"id": "abde5a76",
"metadata": {},
"source": [
"Get and print the IRModule being built."
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "54f1bb9e",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"@tvm.script.ir_module\n",
"class Module:\n",
"    @tir.prim_func\n",
"    def matmul1(var_rxplaceholder_1: tir.handle, rxplaceholder_2: tir.Buffer[(128, 32), \"float32\"], var_T_matmul_1: tir.handle) -> None:\n",
"        # function attr dict\n",
"        tir.func_attr({\"global_symbol\": \"matmul1\", \"tir.noalias\": True})\n",
"        n_1 = tir.var(\"int64\")\n",
"        rxplaceholder_3 = tir.match_buffer(var_rxplaceholder_1, [n_1, 128], dtype=\"float32\")\n",
"        T_matmul_1 = tir.match_buffer(var_T_matmul_1, [n_1, 32], dtype=\"float32\")\n",
"        # body\n",
"        # with tir.block(\"root\")\n",
"        for i0, i1, i2 in tir.grid(n_1, 32, 128):\n",
"            with tir.block(\"T_matmul\"):\n",
"                ax0, ax1, k = tir.axis.remap(\"SSR\", [i0, i1, i2])\n",
"                tir.reads(T_matmul_1[ax0, ax1], rxplaceholder_3[ax0, k], rxplaceholder_2[k, ax1])\n",
"                tir.writes(T_matmul_1[ax0, ax1])\n",
"                with tir.init():\n",
"                    T_matmul_1[ax0, ax1] = tir.float32(0)\n",
"                T_matmul_1[ax0, ax1] = T_matmul_1[ax0, ax1] + rxplaceholder_3[ax0, k] * rxplaceholder_2[k, ax1]\n",
"    \n",
"    @tir.prim_func\n",
"    def relu1(var_rxplaceholder_1: tir.handle, var_compute_1: tir.handle) -> None:\n",
"        # function attr dict\n",
"        tir.func_attr({\"global_symbol\": \"relu1\", \"tir.noalias\": True})\n",
"        n_1 = tir.var(\"int64\")\n",
"        rxplaceholder_1 = tir.match_buffer(var_rxplaceholder_1, [n_1, 32], dtype=\"float32\")\n",
"        compute_1 = tir.match_buffer(var_compute_1, [n_1, 32], dtype=\"float32\")\n",
"        # body\n",
"        # with tir.block(\"root\")\n",
"        for i0, i1 in tir.grid(n_1, 32):\n",
"            with tir.block(\"compute\"):\n",
"                i0_2, i1_2 = tir.axis.remap(\"SS\", [i0, i1])\n",
"                tir.reads(rxplaceholder_1[i0_2, i1_2])\n",
"                tir.writes(compute_1[i0_2, i1_2])\n",
"                compute_1[i0_2, i1_2] = tir.max(rxplaceholder_1[i0_2, i1_2], tir.float32(0))\n",
"    \n",
"    @relax.function\n",
"    def main(data: Tensor[(n, 784), \"float32\"], linear_weight: Tensor[(784, 128), \"float32\"], linear_bias: Tensor[(128,), \"float32\"], linear_weight1: Tensor[(128, 32), \"float32\"], linear_bias1: Tensor[(32,), \"float32\"], linear_weight2: Tensor[(32, 10), \"float32\"], linear_bias2: Tensor[(10,), \"float32\"]) -> Tensor[_, \"float32\"]:\n",
"        # block 0\n",
"        gv = relax.call_tir(matmul, (data, linear_weight), (n, 128), dtype=\"float32\")\n",
"        gv1 = relax.call_tir(add, (gv, linear_bias), (n, 128), dtype=\"float32\")\n",
"        gv2 = relax.call_tir(relu, (gv1,), (n, 128), dtype=\"float32\")\n",
"        gv3 = relax.call_tir(matmul1, (gv2, linear_weight1), (n, 32), dtype=\"float32\")\n",
"        gv4 = relax.call_tir(add1, (gv3, linear_bias1), (n, 32), dtype=\"float32\")\n",
"        gv5 = relax.call_tir(relu1, (gv4,), (n, 32), dtype=\"float32\")\n",
"        gv6 = relax.call_tir(matmul2, (gv5, linear_weight2), (n, 10), dtype=\"float32\")\n",
"        gv7 = relax.call_tir(add2, (gv6, linear_bias2), (n, 10), dtype=\"float32\")\n",
"        gv8 = relax.call_tir(log_softmax, (gv7,), (n, 10), dtype=\"float32\")\n",
"        return gv8\n",
"    \n",
"    @tir.prim_func\n",
"    def add(var_rxplaceholder_1: tir.handle, rxplaceholder_2: tir.Buffer[(128,), \"float32\"], var_T_add_1: tir.handle) -> None:\n",
"        # function attr dict\n",
"        tir.func_attr({\"global_symbol\": \"add\", \"tir.noalias\": True})\n",
"        n_1 = tir.var(\"int64\")\n",
"        rxplaceholder_3 = tir.match_buffer(var_rxplaceholder_1, [n_1, 128], dtype=\"float32\")\n",
"        T_add_1 = tir.match_buffer(var_T_add_1, [n_1, 128], dtype=\"float32\")\n",
"        # body\n",
"        # with tir.block(\"root\")\n",
"        for i0, i1 in tir.grid(n_1, 128):\n",
"            with tir.block(\"T_add\"):\n",
"                ax0, ax1 = tir.axis.remap(\"SS\", [i0, i1])\n",
"                tir.reads(rxplaceholder_3[ax0, ax1], rxplaceholder_2[ax1])\n",
"                tir.writes(T_add_1[ax0, ax1])\n",
"                T_add_1[ax0, ax1] = rxplaceholder_3[ax0, ax1] + rxplaceholder_2[ax1]\n",
"    \n",
"    @tir.prim_func\n",
"    def add2(var_rxplaceholder_1: tir.handle, rxplaceholder_2: tir.Buffer[(10,), \"float32\"], var_T_add_1: tir.handle) -> None:\n",
"        # function attr dict\n",
"        tir.func_attr({\"global_symbol\": \"add2\", \"tir.noalias\": True})\n",
"        n_1 = tir.var(\"int64\")\n",
"        rxplaceholder_3 = tir.match_buffer(var_rxplaceholder_1, [n_1, 10], dtype=\"float32\")\n",
"        T_add_1 = tir.match_buffer(var_T_add_1, [n_1, 10], dtype=\"float32\")\n",
"        # body\n",
"        # with tir.block(\"root\")\n",
"        for i0, i1 in tir.grid(n_1, 10):\n",
"            with tir.block(\"T_add\"):\n",
"                ax0, ax1 = tir.axis.remap(\"SS\", [i0, i1])\n",
"                tir.reads(rxplaceholder_3[ax0, ax1], rxplaceholder_2[ax1])\n",
"                tir.writes(T_add_1[ax0, ax1])\n",
"                T_add_1[ax0, ax1] = rxplaceholder_3[ax0, ax1] + rxplaceholder_2[ax1]\n",
"    \n",
"    @tir.prim_func\n",
"    def add1(var_rxplaceholder_1: tir.handle, rxplaceholder_2: tir.Buffer[(32,), \"float32\"], var_T_add_1: tir.handle) -> None:\n",
"        # function attr dict\n",
"        tir.func_attr({\"global_symbol\": \"add1\", \"tir.noalias\": True})\n",
"        n_1 = tir.var(\"int64\")\n",
"        rxplaceholder_3 = tir.match_buffer(var_rxplaceholder_1, [n_1, 32], dtype=\"float32\")\n",
"        T_add_1 = tir.match_buffer(var_T_add_1, [n_1, 32], dtype=\"float32\")\n",
"        # body\n",
"        # with tir.block(\"root\")\n",
"        for i0, i1 in tir.grid(n_1, 32):\n",
"            with tir.block(\"T_add\"):\n",
"                ax0, ax1 = tir.axis.remap(\"SS\", [i0, i1])\n",
"                tir.reads(rxplaceholder_3[ax0, ax1], rxplaceholder_2[ax1])\n",
"                tir.writes(T_add_1[ax0, ax1])\n",
"                T_add_1[ax0, ax1] = rxplaceholder_3[ax0, ax1] + rxplaceholder_2[ax1]\n",
"    \n",
"    @tir.prim_func\n",
"    def matmul2(var_rxplaceholder_1: tir.handle, rxplaceholder_2: tir.Buffer[(32, 10), \"float32\"], var_T_matmul_1: tir.handle) -> None:\n",
"        # function attr dict\n",
"        tir.func_attr({\"global_symbol\": \"matmul2\", \"tir.noalias\": True})\n",
"        n_1 = tir.var(\"int64\")\n",
"        rxplaceholder_3 = tir.match_buffer(var_rxplaceholder_1, [n_1, 32], dtype=\"float32\")\n",
"        T_matmul_1 = tir.match_buffer(var_T_matmul_1, [n_1, 10], dtype=\"float32\")\n",
"        # body\n",
"        # with tir.block(\"root\")\n",
"        for i0, i1, i2 in tir.grid(n_1, 10, 32):\n",
"            with tir.block(\"T_matmul\"):\n",
"                ax0, ax1, k = tir.axis.remap(\"SSR\", [i0, i1, i2])\n",
"                tir.reads(T_matmul_1[ax0, ax1], rxplaceholder_3[ax0, k], rxplaceholder_2[k, ax1])\n",
"                tir.writes(T_matmul_1[ax0, ax1])\n",
"                with tir.init():\n",
"                    T_matmul_1[ax0, ax1] = tir.float32(0)\n",
"                T_matmul_1[ax0, ax1] = T_matmul_1[ax0, ax1] + rxplaceholder_3[ax0, k] * rxplaceholder_2[k, ax1]\n",
"    \n",
"    @tir.prim_func\n",
"    def relu(var_rxplaceholder_1: tir.handle, var_compute_1: tir.handle) -> None:\n",
"        # function attr dict\n",
"        tir.func_attr({\"global_symbol\": \"relu\", \"tir.noalias\": True})\n",
"        n_1 = tir.var(\"int64\")\n",
"        rxplaceholder_1 = tir.match_buffer(var_rxplaceholder_1, [n_1, 128], dtype=\"float32\")\n",
"        compute_1 = tir.match_buffer(var_compute_1, [n_1, 128], dtype=\"float32\")\n",
"        # body\n",
"        # with tir.block(\"root\")\n",
"        for i0, i1 in tir.grid(n_1, 128):\n",
"            with tir.block(\"compute\"):\n",
"                i0_2, i1_2 = tir.axis.remap(\"SS\", [i0, i1])\n",
"                tir.reads(rxplaceholder_1[i0_2, i1_2])\n",
"                tir.writes(compute_1[i0_2, i1_2])\n",
"                compute_1[i0_2, i1_2] = tir.max(rxplaceholder_1[i0_2, i1_2], tir.float32(0))\n",
"    \n",
"    @tir.prim_func\n",
"    def matmul(var_rxplaceholder_1: tir.handle, rxplaceholder_2: tir.Buffer[(784, 128), \"float32\"], var_T_matmul_1: tir.handle) -> None:\n",
"        # function attr dict\n",
"        tir.func_attr({\"global_symbol\": \"matmul\", \"tir.noalias\": True})\n",
"        n_1 = tir.var(\"int64\")\n",
"        rxplaceholder_3 = tir.match_buffer(var_rxplaceholder_1, [n_1, 784], dtype=\"float32\")\n",
"        T_matmul_1 = tir.match_buffer(var_T_matmul_1, [n_1, 128], dtype=\"float32\")\n",
"        # body\n",
"        # with tir.block(\"root\")\n",
"        for i0, i1, i2 in tir.grid(n_1, 128, 784):\n",
"            with tir.block(\"T_matmul\"):\n",
"                ax0, ax1, k = tir.axis.remap(\"SSR\", [i0, i1, i2])\n",
"                tir.reads(T_matmul_1[ax0, ax1], rxplaceholder_3[ax0, k], rxplaceholder_2[k, ax1])\n",
"                tir.writes(T_matmul_1[ax0, ax1])\n",
"                with tir.init():\n",
"                    T_matmul_1[ax0, ax1] = tir.float32(0)\n",
"                T_matmul_1[ax0, ax1] = T_matmul_1[ax0, ax1] + rxplaceholder_3[ax0, k] * rxplaceholder_2[k, ax1]\n",
"    \n",
"    @tir.prim_func\n",
"    def log_softmax(var_rxplaceholder_1: tir.handle, var_compute_1: tir.handle) -> None:\n",
"        # function attr dict\n",
"        tir.func_attr({\"global_symbol\": \"log_softmax\", \"tir.noalias\": True})\n",
"        n_1 = tir.var(\"int64\")\n",
"        rxplaceholder_1 = tir.match_buffer(var_rxplaceholder_1, [n_1, 10], dtype=\"float32\")\n",
"        compute_3 = tir.match_buffer(var_compute_1, [n_1, 10], dtype=\"float32\")\n",
"        # body\n",
"        # with tir.block(\"root\")\n",
"        compute_4 = tir.alloc_buffer([n_1], dtype=\"float32\")\n",
"        compute_5 = tir.alloc_buffer([n_1], dtype=\"float32\")\n",
"        for i0, i1 in tir.grid(n_1, 10):\n",
"            with tir.block(\"compute\"):\n",
"                i, k = tir.axis.remap(\"SR\", [i0, i1])\n",
"                tir.reads(compute_4[i], rxplaceholder_1[i, k])\n",
"                tir.writes(compute_4[i])\n",
"                with tir.init():\n",
"                    compute_4[i] = tir.float32(-3.4028234663852886e+38)\n",
"                compute_4[i] = tir.max(compute_4[i], rxplaceholder_1[i, k])\n",
"        for i0, i1 in tir.grid(n_1, 10):\n",
"            with tir.block(\"compute\"):\n",
"                i, k = tir.axis.remap(\"SR\", [i0, i1])\n",
"                tir.reads(compute_5[i], rxplaceholder_1[i, k], compute_4[i])\n",
"                tir.writes(compute_5[i])\n",
"                with tir.init():\n",
"                    compute_5[i] = tir.float32(0)\n",
"                compute_5[i] = compute_5[i] + tir.exp(rxplaceholder_1[i, k] - compute_4[i], dtype=\"float32\")\n",
"        for i0, i1 in tir.grid(n_1, 10):\n",
"            with tir.block(\"compute\"):\n",
"                i, j = tir.axis.remap(\"SS\", [i0, i1])\n",
"                tir.reads(rxplaceholder_1[i, j], compute_4[i], compute_5[i])\n",
"                tir.writes(compute_3[i, j])\n",
"                compute_3[i, j] = rxplaceholder_1[i, j] - compute_4[i] - tir.log(compute_5[i], dtype=\"float32\")\n",
"    \n"
]
}
],
"source": [
"mod = builder.get()\n",
"print(R.parser.astext(mod))"
]
},
{
"cell_type": "markdown",
"id": "158875f7",
"metadata": {},
"source": [
"Build the model and create the Relax VM."
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "b6e7a8cd",
"metadata": {},
"outputs": [],
"source": [
"# build and create vm executor\n",
"target = tvm.target.Target(\"llvm\")\n",
"ex = relax.vm.build(mod, target)\n",
"vm = relax.VirtualMachine(ex, tvm.cpu())"
]
},
{
"cell_type": "markdown",
"id": "7679f077",
"metadata": {},
"source": [
"Run the model on the Relax VM. Because the batch dimension is symbolic, the same executable handles different batch sizes."
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "1e530019",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[-2.3025851 -2.3025851 -2.3025851 -2.3025851 -2.3025851 -2.3025851\n",
"  -2.3025851 -2.3025851 -2.3025851 -2.3025851]\n",
" [-2.3025851 -2.3025851 -2.3025851 -2.3025851 -2.3025851 -2.3025851\n",
"  -2.3025851 -2.3025851 -2.3025851 -2.3025851]\n",
" [-2.3025851 -2.3025851 -2.3025851 -2.3025851 -2.3025851 -2.3025851\n",
"  -2.3025851 -2.3025851 -2.3025851 -2.3025851]]\n"
]
}
],
"source": [
"# init parameters\n",
"params = nn.init_params(mod)\n",
"# the input data has a minibatch size of 3\n",
"data = tvm.nd.array(np.random.rand(3, input_size).astype(np.float32))\n",
"\n",
"res = vm[\"main\"](data, *params)\n",
"print(res)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "dfe290fb",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[-2.3025851 -2.3025851 -2.3025851 -2.3025851 -2.3025851 -2.3025851\n",
"  -2.3025851 -2.3025851 -2.3025851 -2.3025851]\n",
" [-2.3025851 -2.3025851 -2.3025851 -2.3025851 -2.3025851 -2.3025851\n",
"  -2.3025851 -2.3025851 -2.3025851 -2.3025851]\n",
" [-2.3025851 -2.3025851 -2.3025851 -2.3025851 -2.3025851 -2.3025851\n",
"  -2.3025851 -2.3025851 -2.3025851 -2.3025851]\n",
" [-2.3025851 -2.3025851 -2.3025851 -2.3025851 -2.3025851 -2.3025851\n",
"  -2.3025851 -2.3025851 -2.3025851 -2.3025851]\n",
" [-2.3025851 -2.3025851 -2.3025851 -2.3025851 -2.3025851 -2.3025851\n",
"  -2.3025851 -2.3025851 -2.3025851 -2.3025851]]\n"
]
}
],
"source": [
"# the input data has a minibatch size of 5\n",
"data = tvm.nd.array(np.random.rand(5, input_size).astype(np.float32))\n",
"\n",
"res = vm[\"main\"](data, *params)\n",
"print(res)"
]
},
{
"cell_type": "markdown",
"id": "23b4c322",
"metadata": {},
"source": [
"Define a layer/network with `nn.emit_te` using the `nn.Module` interface."
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "eb687277",
"metadata": {},
"outputs": [],
"source": [
"class Linear(nn.Module):\n",
"    \"\"\"Applies a linear transformation to the input data: :math:`y = xA + b`.\"\"\"\n",
"\n",
"    def __init__(self, in_features, out_features, bias=True):\n",
"        self.in_features = in_features\n",
"        self.out_features = out_features\n",
"        self.weight = nn.Parameter((in_features, out_features), name=\"linear_weight\")\n",
"        if bias:\n",
"            self.bias = nn.Parameter((out_features,), name=\"linear_bias\")\n",
"        else:\n",
"            self.bias = None\n",
"\n",
"    def forward(self, input: relax.Expr) -> relax.Var:\n",
"        y = nn.emit_te(topi.matmul, input, self.weight)\n",
"        if self.bias is not None:\n",
"            y = nn.emit_te(topi.add, y, self.bias)\n",
"        return y"
]
},
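{
"cell_type": "markdown",
"id": "b1c2d3e4",
"metadata": {},
"source": [
"A quick usage sketch (not executed here): it assumes the custom `Linear` composes with the `BlockBuilder` workflow exactly like the built-in `nn` layers above, and the feature sizes `4` and `8` are arbitrary illustrative choices."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e5f6a7b8",
"metadata": {},
"outputs": [],
"source": [
"# usage sketch: build a one-layer model from the custom Linear defined above\n",
"bb = relax.BlockBuilder()\n",
"with bb.function(name=\"linear_main\"):\n",
"    layer = Linear(4, 8)\n",
"    # b is a symbolic batch dimension, as in the earlier example\n",
"    b = tir.Var(\"b\", \"int64\")\n",
"    x = nn.Placeholder((b, 4), name=\"x\")\n",
"    y = layer(x)\n",
"    bb.emit_func_output(y, params=[x] + layer.parameters())\n",
"# print the resulting IRModule\n",
"print(R.parser.astext(bb.get()))"
]
},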
{
"cell_type": "markdown",
"id": "89e095a9",
"metadata": {},
"source": [
"## TE/TOPI Integration"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "7ca3cfdb",
"metadata": {},
"outputs": [],
"source": [
"import tvm.contrib.cblas  # ensure the cblas contrib module is loaded\n",
"\n",
"def build_mlp(data, weight):\n",
"    bb = relax.BlockBuilder()\n",
"\n",
"    with bb.function(\"mlp\", [data, weight]):\n",
"        gv0 = bb.emit_te(tvm.contrib.cblas.matmul, data, weight, transa=False, transb=False)\n",
"        gv1 = bb.emit_te(topi.nn.relu, gv0)\n",
"        bb.emit_func_output(gv1)\n",
"\n",
"    mod = bb.get()\n",
"    return mod"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "c5c1e8f2",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"@tvm.script.ir_module\n",
"class Module:\n",
"    @tir.prim_func\n",
"    def matmul(var_rxplaceholder_2: tir.handle, var_rxplaceholder_3: tir.handle, var_C_1: tir.handle) -> None:\n",
"        # function attr dict\n",
"        tir.func_attr({\"global_symbol\": \"matmul\", \"tir.noalias\": True})\n",
"        m_1 = tir.var(\"int64\")\n",
"        n_1 = tir.var(\"int64\")\n",
"        rxplaceholder_2 = tir.match_buffer(var_rxplaceholder_2, [n_1, m_1], dtype=\"float32\")\n",
"        rxplaceholder_3 = tir.match_buffer(var_rxplaceholder_3, [m_1, n_1], dtype=\"float32\")\n",
"        C_1 = tir.match_buffer(var_C_1, [n_1, n_1], dtype=\"float32\")\n",
"        # body\n",
"        # with tir.block(\"root\")\n",
"        tir.evaluate(tir.tvm_call_packed(\"tvm.contrib.cblas.matmul\", tir.tvm_stack_make_array(rxplaceholder_2.data, tir.tvm_stack_make_shape(n_1, m_1, dtype=\"handle\"), 0, 2, tir.float32(0), tir.int64(0), dtype=\"handle\"), tir.tvm_stack_make_array(rxplaceholder_3.data, tir.tvm_stack_make_shape(m_1, n_1, dtype=\"handle\"), 0, 2, tir.float32(0), tir.int64(0), dtype=\"handle\"), tir.tvm_stack_make_array(C_1.data, tir.tvm_stack_make_shape(n_1, n_1, dtype=\"handle\"), 0, 2, tir.float32(0), tir.int64(0), dtype=\"handle\"), False, False, dtype=\"int32\"))\n",
"    \n",
"    @tir.prim_func\n",
"    def relu(var_rxplaceholder_1: tir.handle, var_compute_1: tir.handle) -> None:\n",
"        # function attr dict\n",
"        tir.func_attr({\"global_symbol\": \"relu\", \"tir.noalias\": True})\n",
"        n_1 = tir.var(\"int64\")\n",
"        rxplaceholder_1 = tir.match_buffer(var_rxplaceholder_1, [n_1, n_1], dtype=\"float32\")\n",
"        compute_1 = tir.match_buffer(var_compute_1, [n_1, n_1], dtype=\"float32\")\n",
"        # body\n",
"        # with tir.block(\"root\")\n",
"        for i0, i1 in tir.grid(n_1, n_1):\n",
"            with tir.block(\"compute\"):\n",
"                i0_2, i1_2 = tir.axis.remap(\"SS\", [i0, i1])\n",
"                tir.reads(rxplaceholder_1[i0_2, i1_2])\n",
"                tir.writes(compute_1[i0_2, i1_2])\n",
"                compute_1[i0_2, i1_2] = tir.max(rxplaceholder_1[i0_2, i1_2], tir.float32(0))\n",
"    \n",
"    @relax.function\n",
"    def mlp(data: Tensor[(n, m), \"float32\"], weight: Tensor[(m, n), \"float32\"]) -> Tensor[_, \"float32\"]:\n",
"        # block 0\n",
"        gv = relax.call_tir(matmul, (data, weight), (n, n), dtype=\"float32\")\n",
"        gv1 = relax.call_tir(relu, (gv,), (n, n), dtype=\"float32\")\n",
"        return gv1\n",
"    \n"
]
}
],
"source": [
"# symbolic dimensions\n",
"n, m = tir.Var(\"n\", \"int64\"), tir.Var(\"m\", \"int64\")\n",
"\n",
"# create data and weight variables\n",
"data = relax.Var(\"data\", [n, m], relax.DynTensorType(2, \"float32\"))\n",
"weight = relax.Var(\"weight\", [m, n], relax.DynTensorType(2, \"float32\"))\n",
"\n",
"# construct an MLP model\n",
"mod = build_mlp(data, weight)\n",
"print(R.parser.astext(mod))\n",
"\n",
"# build and create vm executor\n",
"target = tvm.target.Target(\"llvm\", host=\"llvm\")\n",
"ex = relax.vm.build(mod, target)\n",
"vm = relax.VirtualMachine(ex, tvm.cpu())"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "8102fabf",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[ 8.556568 6.87848 8.625866 9.600383 8.317718 8.227419\n",
"  9.219838 8.9676075 7.623468 8.23633 7.0235295 8.507209\n",
"  9.027352 7.290115 8.554563 7.6762805]\n",
" [ 6.663047 7.8030357 8.069457 8.890441 7.939424 8.519871\n",
"  9.274455 8.174765 7.9989533 7.5531206 8.587893 8.444625\n",
"  9.052873 7.5767756 8.191356 7.0712266]\n",
" [ 7.403376 7.207227 8.18778 8.746606 9.113302 7.779246\n",
"  8.831476 7.4629865 6.509424 7.559846 7.511036 7.851724\n",
"  8.445092 6.7876825 8.292095 7.063372 ]\n",
" [ 7.427354 7.0812135 7.5834956 9.232357 7.942724 8.243788\n",
"  8.2506075 9.12499 7.663587 8.29403 7.9033995 7.4146137\n",
"  8.90403 7.4565487 7.7196827 7.277466 ]\n",
" [ 9.678751 9.520125 11.651919 11.137583 11.08173 10.491489\n",
"  11.80332 10.800545 10.412642 11.5577135 9.599909 10.551059\n",
"  11.7693205 9.735142 11.261857 10.543016 ]\n",
" [ 6.8073425 5.918911 7.9557896 8.429453 7.6718554 7.034882\n",
"  8.698473 7.8499546 7.649053 7.733601 6.118468 7.5574713\n",
"  7.795342 6.8316064 7.1764045 7.349864 ]\n",
" [ 7.958915 6.212535 8.084226 8.112065 8.320807 8.080606\n",
"  9.801969 8.86682 9.270056 8.871507 7.9004436 7.482185\n",
"  8.870282 7.91054 8.001995 7.3347917]\n",
" [ 5.7960997 5.928576 7.2521763 6.63397 6.830027 5.9787736\n",
"  7.379688 7.4080906 4.9670863 6.8002133 5.4344077 5.761625\n",
"  7.6015153 5.913869 6.8978233 6.2877145]\n",
" [ 8.300703 7.4624095 8.288098 8.837615 9.562206 8.780277\n",
"  9.527164 8.587998 7.1991935 7.7677937 7.8223143 7.735983\n",
"  8.476541 7.170992 8.453187 8.275 ]\n",
" [ 8.490779 9.189245 10.063091 10.9346 9.817375 10.8089695\n",
"  11.861148 11.527818 9.8396015 11.470564 8.860413 9.731347\n",
"  11.540259 8.995022 10.125538 9.088766 ]\n",
" [ 7.4908876 7.6106896 8.144763 8.726535 8.755715 8.209023\n",
"  8.834324 7.8766665 6.288496 8.164509 7.5905747 8.019636\n",
"  9.01565 6.7951274 8.550508 7.674556 ]\n",
" [ 6.193086 5.209443 7.299016 6.642886 6.5502434 6.89572\n",
"  8.01048 7.1452103 6.3567533 7.611481 6.6744666 5.2749963\n",
"  6.601325 5.878633 7.0573955 6.89116 ]\n",
" [ 7.7727585 6.6690907 9.18409 8.672666 7.804694 8.393168\n",
"  10.026285 8.695979 7.850456 8.519929 7.7823 7.553142\n",
"  8.590052 7.1167426 7.9311466 8.562405 ]\n",
" [ 8.467529 8.3538885 10.179802 9.085586 9.400566 7.721974\n",
"  10.036323 8.895725 7.5192246 8.652134 7.3148284 8.582638\n",
"  9.310177 8.01297 9.493657 8.240374 ]\n",
" [ 8.044313 6.7396564 8.170473 8.387967 8.349282 8.374153\n",
"  9.131376 8.743106 7.3198457 8.30513 6.786698 8.167014\n",
"  8.164466 6.805107 8.06095 7.866236 ]\n",
" [ 7.4279284 7.7251635 8.799779 9.469591 8.381312 7.999055\n",
"  9.07975 8.277402 6.9867654 7.9436297 8.238659 8.03124\n",
"  9.060811 6.9540024 9.092028 7.9286 ]]\n"
]
}
],
"source": [
"# run the mlp model on the Relax VM\n",
"data = tvm.nd.array(np.random.rand(16, 32).astype(np.float32))\n",
"weight = tvm.nd.array(np.random.rand(32, 16).astype(np.float32))\n",
"res = vm[\"mlp\"](data, weight)\n",
"print(res)"
]
},
{
"cell_type": "markdown",
"id": "cc0e909f",
"metadata": {},
"source": [
"## Relax compilation and execution workflow"
]
},
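{
"cell_type": "markdown",
"id": "3f2a9c10",
"metadata": {},
"source": [
"The cells below lower the program step by step: `ToNonDataflow` removes the dataflow block, `CallTIRRewrite` turns `call_tir` into an explicit output allocation plus a destination-passing call, `VMMemoryLower` lowers allocation to VM storage builtins, and `VMShapeLower` materializes the symbolic shape computation; finally the module is built and run on the Relax VM."
]
},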
{
"cell_type": "code",
"execution_count": 12,
"id": "b3ab5492",
"metadata": {},
"outputs": [],
"source": [
"@tvm.register_func(\"test.vm.tile\")\n",
"def tile_packed(a, b):\n",
"    # tile the input twice along the second axis, writing into the\n",
"    # preallocated output tensor b\n",
"    b[:] = tvm.nd.array(np.tile(a.asnumpy(), (1, 2)))"
]
},
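{
"cell_type": "markdown",
"id": "a9d0c4e2",
"metadata": {},
"source": [
"The registered packed function follows destination-passing style: the caller preallocates the output and passes it as the last argument. A minimal standalone sketch of that calling convention (the shapes here are arbitrary illustrative choices, not part of the original notebook):"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c7b8d9e0",
"metadata": {},
"outputs": [],
"source": [
"a = tvm.nd.array(np.ones((2, 3), dtype=\"float32\"))\n",
"b = tvm.nd.empty((2, 6), dtype=\"float32\")  # destination buffer, preallocated by the caller\n",
"# look up the registered packed function and call it directly\n",
"tvm.get_global_func(\"test.vm.tile\")(a, b)\n",
"print(b)"
]
},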
{
"cell_type": "code",
"execution_count": 13,
"id": "dfedf857",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"======================\n",
"Original Relax Program\n",
"\n",
"@tvm.script.ir_module\n",
"class Module:\n",
"    @relax.function\n",
"    def foo(x: Tensor[(n, m), \"float32\"]) -> Tensor[_, _]:\n",
"        # block 0\n",
"        with relax.dataflow():\n",
"            y = relax.call_tir(\"test.vm.tile\", x, (n, (m * 2)), dtype=\"float32\")\n",
"            relax.output(y)\n",
"        return y\n",
"    \n"
]
}
],
"source": [
"src = \"\"\"@tvm.script.ir_module\n",
"class InputModule:\n",
"    @R.function\n",
"    def foo(x: Tensor[(n, m), \"float32\"]) -> Tensor:\n",
"        with relax.dataflow():\n",
"            y = R.call_tir(\"test.vm.tile\", (x), (n, m * 2), dtype=\"float32\")\n",
"            relax.output(y)\n",
"        return y\n",
"\"\"\"\n",
"\n",
"# Original Relax Program\n",
"print(\"======================\")\n",
"print(\"Original Relax Program\\n\")\n",
"mod = R.parser.from_source(src)\n",
"code = R.parser.astext(mod)\n",
"print(code)"
]
},
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 14, | |
"id": "9fb478f3", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"======================\n", | |
"PASS0: To Non Dataflow\n", | |
"\n", | |
"@tvm.script.ir_module\n", | |
"class Module:\n", | |
" @relax.function\n", | |
" def foo(x: Tensor[(n, m), \"float32\"]) -> Tensor[_, _]:\n", | |
" # block 0\n", | |
" y = relax.call_tir(\"test.vm.tile\", x, (n, (m * 2)), dtype=\"float32\")\n", | |
" return y\n", | |
" \n" | |
] | |
} | |
], | |
"source": [ | |
"# ToNonDataflow Pass\n", | |
"print(\"======================\")\n", | |
"print(\"PASS0: To Non Dataflow\\n\")\n", | |
"mod = relax.transform.ToNonDataflow()(mod)\n", | |
"print(R.parser.astext(mod))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 15, | |
"id": "4221356d", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"======================\n", | |
"PASS1: CallDPS Rewrite\n", | |
"\n", | |
"@tvm.script.ir_module\n", | |
"class Module:\n", | |
" @relax.function\n", | |
" def foo(x: Tensor[(n, m), \"float32\"]) -> Tensor[_, _]:\n", | |
" # block 0\n", | |
" alloc = relax.builtin.alloc_tensor((n, (m * 2)), dtype=\"float32\", attrs_type_key=\"relax.attrs.AllocTensorAttrs\")\n", | |
" _ = relax.call_packed(\"test.vm.tile\", x, alloc)\n", | |
" y = alloc\n", | |
" return y\n", | |
" \n" | |
] | |
} | |
], | |
"source": [ | |
"# CallDPS Rewrite\n", | |
"print(\"======================\")\n", | |
"print(\"PASS1: CallDPS Rewrite\\n\")\n", | |
"mod = relax.transform.CallTIRRewrite()(mod)\n", | |
"print(R.parser.astext(mod))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 16, | |
"id": "f13f5dc5", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"======================\n", | |
"PASS2: Memory Lower\n", | |
"\n", | |
"@tvm.script.ir_module\n", | |
"class Module:\n", | |
" @relax.function\n", | |
" def foo(x: Tensor[(n, m), \"float32\"]) -> Tensor[_, _]:\n", | |
" # block 0\n", | |
" storage = relax.vm.builtin.alloc_storage((((n * (m * 2)) * 4),), device_type=1, dtype=\"float32\", attrs_type_key=\"relax.attrs.VMAllocStorageAttrs\")\n", | |
" tensor = relax.vm.builtin.alloc_tensor(storage, (n, (m * 2)), offset=0, dtype=\"float32\", attrs_type_key=\"relax.attrs.VMAllocTensorAttrs\")\n", | |
" alloc = tensor\n", | |
" _ = relax.call_packed(\"test.vm.tile\", x, alloc)\n", | |
" y = alloc\n", | |
" return y\n", | |
" \n" | |
] | |
} | |
], | |
"source": [ | |
"# Memory Lower\n", | |
"print(\"======================\")\n", | |
"print(\"PASS2: Memory Lower\\n\")\n", | |
"mod = relax.transform.VMMemoryLower()(mod)\n", | |
"print(R.parser.astext(mod))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 17, | |
"id": "82592858", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"======================\n", | |
"PASS3: Shape Lower\n", | |
"\n", | |
"@tvm.script.ir_module\n", | |
"class Module:\n", | |
" @tir.prim_func\n", | |
" def shape_func(H_1: tir.Buffer[(tir.int64(4),), \"int64\"]) -> None:\n", | |
" # function attr dict\n", | |
" tir.func_attr({\"global_symbol\": \"shape_func\"})\n", | |
" # body\n", | |
" H_1[2] = H_1[0] * (H_1[1] * tir.int64(2)) * tir.int64(4)\n", | |
" \n", | |
" @tir.prim_func\n", | |
" def shape_func1(H_1: tir.Buffer[(tir.int64(4),), \"int64\"]) -> None:\n", | |
" # function attr dict\n", | |
" tir.func_attr({\"global_symbol\": \"shape_func1\"})\n", | |
" # body\n", | |
" H_1[0] = H_1[0]\n", | |
" H_1[3] = H_1[1] * tir.int64(2)\n", | |
" \n", | |
" @relax.function\n", | |
" def foo(x: Tensor[(n, m), \"float32\"]) -> Tensor[_, _]:\n", | |
" # block 0\n", | |
" shape_heap: Tensor[(4,), \"int64\"] = relax.call_packed(\"vm.builtin.alloc_shape_heap\", (4,))\n", | |
" sh = relax.call_packed(\"vm.builtin.shape_of\", x)\n", | |
" gv = relax.vm.builtin.store_shape(sh, shape_heap, indices=[0, 1], attrs_type_key=\"relax.attrs.ShapeHeapAttrs\")\n", | |
" # block 1\n", | |
" _ = shape_func(shape_heap)\n", | |
" sh1 = relax.vm.builtin.load_shape(shape_heap, indices=[2], attrs_type_key=\"relax.attrs.ShapeHeapAttrs\")\n", | |
" storage = relax.vm.builtin.alloc_storage(sh1, device_type=1, dtype=\"float32\", attrs_type_key=\"relax.attrs.VMAllocStorageAttrs\")\n", | |
" _1 = shape_func1(shape_heap)\n", | |
" sh2 = relax.vm.builtin.load_shape(shape_heap, indices=[0, 3], attrs_type_key=\"relax.attrs.ShapeHeapAttrs\")\n", | |
" tensor = relax.vm.builtin.alloc_tensor(storage, sh2, offset=0, dtype=\"float32\", attrs_type_key=\"relax.attrs.VMAllocTensorAttrs\")\n", | |
" alloc = tensor\n", | |
" _2 = relax.call_packed(\"test.vm.tile\", x, alloc)\n", | |
" y = alloc\n", | |
" return y\n", | |
" \n" | |
] | |
} | |
], | |
"source": [ | |
"# Shape Lower\n", | |
"print(\"======================\")\n", | |
"print(\"PASS3: Shape Lower\\n\")\n", | |
"mod = relax.transform.VMShapeLower()(mod)\n", | |
"print(R.parser.astext(mod))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 18, | |
"id": "63e9ad0a", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"======================\n", | |
"Build & Execute\n", | |
"input: [[0.46070135 0.1330114 0.8229871 0.3692288 ]\n", | |
" [0.91415226 0.24083617 0.25455388 0.40971643]\n", | |
" [0.03785264 0.329651 0.40605137 0.05099611]]\n", | |
"output: [[0.46070135 0.1330114 0.8229871 0.3692288 0.46070135 0.1330114\n", | |
" 0.8229871 0.3692288 ]\n", | |
" [0.91415226 0.24083617 0.25455388 0.40971643 0.91415226 0.24083617\n", | |
" 0.25455388 0.40971643]\n", | |
" [0.03785264 0.329651 0.40605137 0.05099611 0.03785264 0.329651\n", | |
" 0.40605137 0.05099611]]\n" | |
] | |
} | |
], | |
"source": [ | |
"# Build & Execute\n", | |
"print(\"======================\")\n", | |
"print(\"Build & Execute\")\n", | |
"\n", | |
"target = tvm.target.Target(\"llvm\", host=\"llvm\")\n", | |
"ex = relax.vm.build(mod, target)\n", | |
"vm = relax.VirtualMachine(ex, tvm.cpu())\n", | |
"\n", | |
"shape = (3, 4)\n", | |
"inp = tvm.nd.array(np.random.rand(*shape).astype(np.float32))\n", | |
"out = vm[\"foo\"](inp)\n", | |
"print(\"input: \", inp)\n", | |
"print(\"output: \", out)\n", | |
"np.testing.assert_allclose(np.tile(inp.asnumpy(), (1, 2)), out.asnumpy())" | |
] | |
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.10"
}
},
"nbformat": 4,
"nbformat_minor": 5
} |