Skip to content

Instantly share code, notes, and snippets.

@vinx13
Created July 24, 2019 05:37
Show Gist options
  • Save vinx13/276e301660db6a3167eefe9e91ee7a87 to your computer and use it in GitHub Desktop.
Save vinx13/276e301660db6a3167eefe9e91ee7a87 to your computer and use it in GitHub Desktop.
import numpy as np
import argparse
import tvm
import tvm.relay as relay
import tvm.relay.testing
from tvm.relay.testing import layers
import mxnet as mx
from mxnet import gluon
from mxnet.gluon.model_zoo import vision
def create_workload(net, initializer=None, seed=0, input_names=("data", "target")):
    """Helper function to create benchmark image classification workload.

    Wraps ``net`` in a module, runs type inference, and initializes every
    parameter of ``main`` except the runtime inputs named in ``input_names``.

    Parameters
    ----------
    net : tvm.relay.Function
        The selected function of the network.
    initializer : Initializer, optional
        Callable invoked as ``initializer(name, array)``. Its return value is
        ignored, so it is expected to fill ``array`` in place. Defaults to
        ``relay.testing.init.Xavier()``.
    seed : int
        Seed passed to ``np.random.seed`` before any parameter is initialized.
    input_names : iterable of str, optional
        Names of ``main`` parameters that are runtime inputs rather than
        weights. They are neither initialized nor returned in ``params``.
        Defaults to ``("data", "target")``.

    Returns
    -------
    mod : tvm.relay.Module
        The created relay module, after type inference.
    params : dict of str to NDArray
        The initialized parameters, keyed by the variable's name hint.
    """
    mod = relay.Module.from_expr(net)
    # Type inference populates checked_type, which is needed for shapes/dtypes.
    mod = relay.transform.InferType()(mod)
    # Map each parameter name to its inferred type (read for shape and dtype).
    param_types = {v.name_hint: v.checked_type for v in mod["main"].params}
    skip = frozenset(input_names)
    np.random.seed(seed)
    if initializer is None:
        initializer = relay.testing.init.Xavier()
    params = {}
    for name, param_type in param_types.items():
        if name in skip:
            continue
        # Allocate directly in the parameter's dtype instead of float64 + cast.
        init_value = np.zeros(param_type.concrete_shape, dtype=param_type.dtype)
        initializer(name, init_value)
        params[name] = tvm.nd.array(init_value, ctx=tvm.cpu(0))
    return mod, params
def get_network(batch=1, image_shape=(3, 224, 224), num_classes=10, dtype='float32'):
    """Build a single dense layer with a cross-entropy loss as a Relay workload.

    The input is flattened, passed through one dense+bias layer named
    ``fc7``, and scored against ``target`` with ``relay.nn.cross_entropy``.
    The function returns the tuple ``(fc, loss)``.

    Parameters
    ----------
    batch : int
        Batch size. Defaults to 1.
    image_shape : sequence of int
        Per-sample input shape (without the batch dimension). Defaults to
        ``(3, 224, 224)``.
    num_classes : int
        Number of dense output units and width of ``target``. Defaults to 10.
    dtype : str
        Dtype of the ``data`` and ``target`` variables. Defaults to
        ``'float32'``.

    Returns
    -------
    mod, params
        The module and initialized weights produced by ``create_workload``.
    """
    # tuple() lets callers pass any sequence (e.g. a list) for image_shape.
    data_shape = (batch,) + tuple(image_shape)
    data = relay.var("data", shape=data_shape, dtype=dtype)
    target = relay.var("target", shape=(batch, num_classes), dtype=dtype)
    flatten = relay.nn.batch_flatten(data)
    fc = layers.dense_add_bias(data=flatten, units=num_classes, name="fc7")
    # NOTE(review): cross_entropy is applied to the raw dense output with no
    # softmax. Confirm whether relay.nn.cross_entropy expects probabilities
    # here, or whether this is intentional for the repro.
    loss = relay.nn.cross_entropy(fc, target)
    ret = relay.Tuple([fc, loss])
    # All free variables (inputs and weights) become the function's parameters.
    func = relay.Function(relay.analysis.free_vars(ret), ret)
    return create_workload(func)
def main():
    """Differentiate the toy network, simplify the result, and build for LLVM.

    The build result is discarded. The script only checks that the pipeline
    compiles.
    """
    mod, params = get_network()

    # Normalize the module and attach types before taking the gradient.
    for pre_pass in (relay.transform.CanonicalizeOps(), relay.transform.InferType()):
        mod = pre_pass(mod)

    # gradient() operates on a single function, so replace 'main' in place.
    mod['main'] = relay.transform.gradient(mod['main'])

    # Simplify the differentiated program and convert it to graph form
    # before handing it to relay.build.
    post_passes = (
        relay.transform.PartialEvaluate(),
        relay.transform.DeadCodeElimination(),
        relay.transform.ToGraphNormalForm(),
    )
    for post_pass in post_passes:
        mod = post_pass(mod)

    relay.build(mod, params=params, target='llvm')
# Run the repro only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment