import numpy as np

import tvm
from tvm import relay, auto_scheduler
import tvm.relay.testing
from tvm.contrib import graph_executor
def get_network(name, batch_size, layout="NHWC", dtype="float32"):
    """Get the symbol definition and random weights of a network."""
    # auto-scheduler prefers NHWC layout
    if layout == "NHWC":
        image_shape = (224, 224, 3)
    elif layout == "NCHW":
        image_shape = (3, 224, 224)
    else:
        raise ValueError("Invalid layout: " + layout)

    input_shape = (batch_size,) + image_shape
    output_shape = (batch_size, 1000)

    if name.startswith("resnet-"):
        n_layer = int(name.split("-")[1])
        mod, params = relay.testing.resnet.get_workload(
            num_layers=n_layer,
            batch_size=batch_size,
            layout=layout,
            dtype=dtype,
            image_shape=image_shape,
        )
    elif name.startswith("resnet3d-"):
        n_layer = int(name.split("-")[1])
        mod, params = relay.testing.resnet.get_workload(
            num_layers=n_layer,
            batch_size=batch_size,
            layout=layout,
            dtype=dtype,
            image_shape=image_shape,
        )
    elif name == "mobilenet":
        mod, params = relay.testing.mobilenet.get_workload(
            batch_size=batch_size, layout=layout, dtype=dtype, image_shape=image_shape
        )
    elif name == "squeezenet_v1.1":
        assert layout == "NCHW", "squeezenet_v1.1 only supports NCHW layout"
        mod, params = relay.testing.squeezenet.get_workload(
            version="1.1",
            batch_size=batch_size,
            dtype=dtype,
            image_shape=image_shape,
        )
    elif name == "inception_v3":
        input_shape = (batch_size, 3, 299, 299) if layout == "NCHW" else (batch_size, 299, 299, 3)
        mod, params = relay.testing.inception_v3.get_workload(batch_size=batch_size, dtype=dtype)
    elif name == "mxnet":
        # an example for an mxnet model
        from mxnet.gluon.model_zoo.vision import get_model

        assert layout == "NCHW"
        block = get_model("resnet18_v1", pretrained=True)
        mod, params = relay.frontend.from_mxnet(block, shape={"data": input_shape}, dtype=dtype)
        net = mod["main"]
        net = relay.Function(
            net.params, relay.nn.softmax(net.body), None, net.type_params, net.attrs
        )
        mod = tvm.IRModule.from_expr(net)

    return mod, params, input_shape, output_shape
# Define the neural network and compilation target.
network = "resnet-18"
batch_size = 1
layout = "NCHW"
target = {"llvm": "llvm", "cuda": "cuda"}  # heterogeneous target: CPU (llvm) + GPU (cuda)
dtype = "float32"
log_file = "%s-%s-B%d-%s.json" % (network, layout, batch_size, "cuda")

mod, params, input_shape, output_shape = get_network(network, batch_size, layout, dtype=dtype)
@relay.transform.function_pass(opt_level=1)
class MyPass:
    """Annotate every nn.conv2d call to run on the CUDA device; all other ops
    stay on the default (llvm) device."""

    def __init__(self):
        self.var = 0

    # This function defines the pass: rewrite the function with the mutator below.
    def transform_function(self, func, mod, ctx):
        class Test(tvm.relay.ExprMutator):
            def visit_call(self, expr):
                visit = super().visit_call(expr)
                if expr.op == tvm.relay.op.get("nn.conv2d"):
                    return relay.annotation.on_device(visit, "cuda")
                else:
                    return visit

        return Test().visit(func)
dev1 = tvm.device("llvm")
dev2 = tvm.device("cuda")

custom_pass = MyPass()
mod = custom_pass(mod)

if True:
    # Error: this path (auto-scheduler task extraction on the annotated module)
    # reproduces the problem.
    print("Extract tasks...")
    tasks, task_weights = auto_scheduler.extract_tasks(mod["main"], params, target)
    for idx, task in enumerate(tasks):
        print("========== Task %d (workload key: %s) ==========" % (idx, task.workload_key))
        print(task.compute_dag)
    assert False
else:
    # Working: building and running the annotated module directly works.
    print("Compile...")
    with tvm.transform.PassContext(opt_level=3, config={"relay.backend.use_auto_scheduler": True}):
        lib = relay.build(mod, target=target, params=params)

    # Create a graph executor over both devices.
    module = graph_executor.create(lib.get_graph_json(), lib.get_lib(), [dev1, dev2])
    data_tvm = tvm.nd.array((np.random.uniform(size=input_shape)).astype(dtype))
    module.set_input(**lib.get_params())
    module.set_input("data", data_tvm)
    module.run()
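

# --- Not part of the original repro ---
# A minimal sketch of the tuning flow that would normally follow task extraction
# once `extract_tasks` succeeds, using the standard TVM auto_scheduler API
# (TaskScheduler / TuningOptions / ApplyHistoryBest). The helper name
# `run_tuning` and the trial count are illustrative assumptions.
def run_tuning(tasks, task_weights, log_file):
    # Tune all extracted tasks under a shared measurement budget.
    tuner = auto_scheduler.TaskScheduler(tasks, task_weights)
    tune_option = auto_scheduler.TuningOptions(
        num_measure_trials=200,  # small budget, for illustration only
        measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
        verbose=2,
    )
    tuner.tune(tune_option)

    # Re-compile with the tuned schedules applied from the log file.
    with auto_scheduler.ApplyHistoryBest(log_file):
        with tvm.transform.PassContext(
            opt_level=3, config={"relay.backend.use_auto_scheduler": True}
        ):
            return relay.build(mod, target=target, params=params)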