Created
July 24, 2019 05:37
-
-
Save vinx13/276e301660db6a3167eefe9e91ee7a87 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import argparse | |
import tvm | |
import tvm.relay as relay | |
import tvm.relay.testing | |
from tvm.relay.testing import layers | |
import mxnet as mx | |
from mxnet import gluon | |
from mxnet.gluon.model_zoo import vision | |
def create_workload(net, initializer=None, seed=0):
    """Wrap a Relay function in a module and initialize its parameters.

    Every input variable of ``net`` except ``data`` and ``target`` is treated
    as a parameter: a zero array with the variable's inferred shape and dtype
    is handed to ``initializer`` and then stored as a CPU ``NDArray``.

    Parameters
    ----------
    net : tvm.relay.Function
        The network function to wrap; it becomes the module's ``main``.
    initializer : Initializer, optional
        Called as ``initializer(name, array)`` for each parameter. When it
        is falsy, ``relay.testing.init.Xavier()`` is used instead.
    seed : int
        Value passed to ``np.random.seed`` before any parameter is filled.

    Returns
    -------
    mod : tvm.relay.Module
        The type-inferred module built from ``net``.
    params : dict of str to NDArray
        Initialized parameter values keyed by variable name.
    """
    module = relay.transform.InferType()(relay.Module.from_expr(net))

    # Map each input variable's name to its inferred tensor type; later
    # variables with a duplicate name overwrite earlier ones.
    param_types = {}
    for var in module["main"].params:
        param_types[var.name_hint] = var.checked_type

    np.random.seed(seed)
    if not initializer:
        initializer = relay.testing.init.Xavier()

    # These variables are treated as network inputs, not as parameters.
    inputs = ("data", "target")
    params = {}
    for name, ttype in param_types.items():
        if name in inputs:
            continue
        buf = np.zeros(ttype.concrete_shape, dtype=ttype.dtype)
        # The initializer's return value is ignored, so it is expected to
        # fill ``buf`` in place.
        initializer(name, buf)
        params[name] = tvm.nd.array(buf, ctx=tvm.cpu(0))
    return module, params
def get_network(batch=1, image_shape=(3, 224, 224), num_classes=10,
                dtype='float32'):
    """Build a single dense-layer classifier together with its training loss.

    The input image is flattened and passed through one dense layer with
    bias (named ``fc7``). The resulting logits are turned into probabilities
    with softmax, and the cross-entropy loss is computed against the
    ``target`` variable (presumably one-hot class labels -- confirm with the
    data feeding code).

    Parameters
    ----------
    batch : int
        Batch size of the ``data`` and ``target`` inputs.
    image_shape : tuple of int
        Per-sample input shape, e.g. (channels, height, width).
    num_classes : int
        Number of output units of the dense layer.
    dtype : str
        Data type of the ``data`` and ``target`` inputs.

    Returns
    -------
    mod : tvm.relay.Module
        Module whose ``main`` returns the tuple ``(logits, loss)``.
    params : dict of str to NDArray
        Initialized values for every free variable except ``data`` and
        ``target``, as produced by ``create_workload``.
    """
    data_shape = (batch,) + tuple(image_shape)
    data = relay.var("data", shape=data_shape, dtype=dtype)
    target = relay.var("target", shape=(batch, num_classes), dtype=dtype)
    flatten = relay.nn.batch_flatten(data)
    fc = layers.dense_add_bias(data=flatten, units=num_classes, name="fc7")
    # relay.nn.cross_entropy is documented as taking predictions
    # (probabilities) and applies log() to them. Raw dense-layer logits can
    # be negative, which would make the loss NaN, so normalize them with
    # softmax over the class axis first.
    prob = relay.nn.softmax(fc)
    loss = relay.nn.cross_entropy(prob, target)
    ret = relay.Tuple([fc, loss])
    f = relay.Function(relay.analysis.free_vars(ret), ret)
    return create_workload(f)
def main():
    """Differentiate the toy network and compile the result for LLVM.

    Pipeline: canonicalize ops and infer types, replace ``main`` with its
    gradient function, then partially evaluate, eliminate dead code and
    convert to graph normal form before calling ``relay.build`` for the
    ``llvm`` target. The value returned by ``relay.build`` is discarded, so
    this presumably serves as a check that compilation succeeds.
    """
    mod, params = get_network()

    # Passes applied before differentiation (types are inferred here,
    # presumably because gradient needs checked types -- confirm).
    prepare = (
        relay.transform.CanonicalizeOps(),
        relay.transform.InferType(),
    )
    for pass_ in prepare:
        mod = pass_(mod)

    # gradient is applied to the single function, not to the whole module.
    mod['main'] = relay.transform.gradient(mod['main'])

    # Passes that simplify the differentiated program before building.
    cleanup = (
        relay.transform.PartialEvaluate(),
        relay.transform.DeadCodeElimination(),
        relay.transform.ToGraphNormalForm(),
    )
    for pass_ in cleanup:
        mod = pass_(mod)

    relay.build(mod, params=params, target='llvm')
if __name__ == "__main__":
    # Build the pipeline only when this file is executed as a script,
    # not when it is imported.
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment