import tvm
from tvm import relay
import torch
import torch.nn.functional as F
import inspect
import ast
import numpy as np

# Functions already parsed into Relay, keyed by function name.
_parsed_functions = dict()


def jit_assert(cond, msg="[see stack]"):
    if not cond:
        raise Exception(msg)


def get_methods(mod):
    # Map id(attribute) -> attribute name for every public attribute of
    # `mod`, so call targets can later be matched by object identity.
    m = dict()
    for method in dir(mod):
        if method.startswith("_"):
            continue
        m[id(getattr(mod, method))] = method
    return m


_torch_methods = get_methods(torch)
_tensor_methods = get_methods(torch.Tensor)
_functional_methods = get_methods(F)
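# For example, `_functional_methods[id(F.relu)]` is the string "relu", so a
# call site like `F.relu(x)` can be resolved to its handler by object
# identity rather than by string-matching the source text.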


class RelayParser(ast.NodeVisitor):
    def __init__(self, globals_={}):
        self.symbols = {}  # Symbol table
        self.binds = {}    # relay.Var -> bound expression, applied via relay.bind
        self.globals = globals_
        self.inputs = []
        self.input_types = []
        super().__init__()

    def get_id_from_node_(self, node):
        # Resolve an attribute chain (e.g. `torch.ones`) or a bare name to
        # the live Python object it refers to.
        if isinstance(node, ast.Attribute):
            obj = self.get_id_from_node_(node.value)
            return getattr(obj, node.attr)
        if isinstance(node, ast.Name):
            return self.globals[node.id]
        jit_assert(False, "Cannot get id from node {}".format(ast.dump(node)))

    def get_id_from_node(self, node):
        return id(self.get_id_from_node_(node))

    def torch_builtin(self, func_name, args, keywords):
        if func_name == "ones":
            return relay.const(tvm.ndarray.array(
                np.ones(ast.literal_eval(args[0])).astype(np.float32)
            ))
        if func_name == "relu":
            return relay.nn.relu(self.visit(args[0]))
        if func_name == "conv2d":
            stride = relay.const(1)
            #pad = relay.const(0)
            #for k in keywords:
            #    if k.arg == "stride":
            #        stride = self.visit(k.value)
            #    if k.arg == "padding":
            #        pad = self.visit(k.value)
            r = relay.nn.conv2d(self.visit(args[0]), self.visit(args[1]),
                                strides=[stride, stride])
                                #, padding=(pad, pad))
            return r
        if func_name == "batch_norm":
            # F.batch_norm takes (input, mean, var, weight, bias); Relay's
            # batch_norm takes (data, gamma, beta, moving_mean, moving_var).
            r = relay.nn.batch_norm(
                self.visit(args[0]),  # input
                self.visit(args[3]),  # weight
                self.visit(args[4]),  # bias
                self.visit(args[1]),  # mean
                self.visit(args[2])   # var
            )
            return r[0]
        jit_assert(False, "Couldn't match {} to a torch builtin function".format(func_name))

    def relay_func_from_node(self, node, args, keywords):
        f_id = self.get_id_from_node(node)
        if f_id in _torch_methods:
            return self.torch_builtin(_torch_methods[f_id], args, keywords)
        if f_id in _functional_methods:
            return self.torch_builtin(_functional_methods[f_id], args, keywords)
        jit_assert(False, "Unknown function {}".format(ast.dump(node)))

    def add_symbol(self, name, var, expr=None):
        if name in self.symbols:
            old = str(self.symbols[name])
            new = str(var)
            jit_assert(False, "Symbol conflict [{}] {} -> {}.".format(name, old, new))
        self.symbols[name] = var
        if expr:
            self.binds[var] = expr

    def generic_visit(self, node):
        jit_assert(False, "Couldn't parse node {}".format(ast.dump(node)))

    def visit_Module(self, node):
        jit_assert(len(node.body) == 1,
                   "Only one-function source code will be fed to this parser!")
        return self.visit(node.body[0])

    def visit_FunctionDef(self, node):
        jit_assert(node.name not in _parsed_functions,
                   "Conflicting function name {}".format(node.name))
        for arg in node.args.args:
            var = relay.Var(arg.arg)
            self.add_symbol(arg.arg, var)
            self.inputs.append(var)
        for stmt in node.body:
            self.visit(stmt)  # visit_Return records the output expression
        func = relay.Function(self.inputs, relay.bind(self.output, self.binds))
        _parsed_functions[node.name] = func
        return func

    def visit_Return(self, node):
        self.output = self.visit(node.value)
        return self.output

    def visit_Name(self, node):
        name = node.id
        jit_assert(name in self.symbols, "Couldn't find variable '{}'".format(name))
        return self.symbols[name]

    def visit_BinOp(self, node):
        if isinstance(node.op, ast.Add):
            x = self.visit(node.left)
            y = self.visit(node.right)
            return relay.op.add(x, y)
        if isinstance(node.op, ast.Mult):
            x = self.visit(node.left)
            y = self.visit(node.right)
            return relay.op.multiply(x, y)
        jit_assert(False)

    def visit_Call(self, node):
        return self.relay_func_from_node(node.func, node.args, node.keywords)

    def visit_Assign(self, node):
        rhs = self.visit(node.value)
        rhs = relay.bind(rhs, self.binds)
        lhs = node.targets[0]
        lhs_var = relay.var(lhs.id)
        self.add_symbol(lhs.id, lhs_var, rhs)
        return rhs


class RelayFunc(object):
    def __init__(self, relay_expr, inputs):
        self.expr = relay_expr
        self.compiled = {}       # shape tuple -> compiled forward function
        self.grad = None
        self.grad_compiled = {}  # shape tuple -> compiled gradient function
        self.inputs = inputs

    def compile_func(self, expr, shapes_list, mod=None):
        # Re-type the parameters with concrete input shapes, run type
        # inference, then wrap the result in a graph executor.
        inputs = []
        for i, shape in enumerate(shapes_list):
            v = relay.Var(self.inputs[i].name_hint, relay.TensorType(shape))
            inputs.append(v)
        expr = relay.Function(inputs, relay.Call(expr, inputs))
        expr = relay.ir_pass.infer_type(expr)
        expr = expr.body.op
        graph = relay.create_executor('graph', mod=mod)

        def f(*args):
            return graph.evaluate(expr)(*args).asnumpy()
        return f
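    # Both __call__ and backward cache compilation per tuple of input
    # shapes: the first call with a new shape signature compiles, and
    # later calls reuse the cached executor.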

    def __call__(self, *inputs, torch_mode=False):
        shapes = tuple(i.shape if hasattr(i, 'shape') else () for i in inputs)
        if shapes not in self.compiled:
            try:
                self.compiled[shapes] = self.compile_func(self.expr, shapes)
            except Exception as e:
                print("While compiling\n", self.expr)
                raise e
        if torch_mode:
            # Unwrap torch tensors to numpy for the executor, then wrap the
            # result back up as a torch tensor.
            args = [i.detach().numpy() if hasattr(i, 'numpy') else i for i in inputs]
            return torch.tensor(self.compiled[shapes](*args))
        return self.compiled[shapes](*inputs)

    def backward(self, *inputs):
        if not self.grad:
            self.grad = relay.ir_pass.gradient(self.expr)
        shapes = tuple(i.shape for i in inputs)
        if shapes not in self.grad_compiled:
            self.grad_compiled[shapes] = self.compile_func(self.grad, shapes)
        return self.grad_compiled[shapes](*inputs)


def script(fn):
    parser = RelayParser(fn.__globals__)
    sauce = inspect.getsource(fn)
    try:
        relay_expr = parser.visit(ast.parse(sauce))
    except Exception as e:
        print("Hit exception while parsing:\n\n{}\n".format(sauce))
        raise e
    return RelayFunc(relay_expr, parser.inputs)
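

# A minimal usage sketch, assuming a TVM build that still exposes the
# `relay.ir_pass` API used above. `double_relu` is a hypothetical example;
# it uses only constructs the parser handles (`+` and F.relu).
@script
def double_relu(x):
    return F.relu(x + x)


if __name__ == "__main__":
    inp = np.ones((2, 2)).astype(np.float32)
    # Compiled for this shape on the first call; repeat calls with the
    # same shape reuse the cached executor.
    print(double_relu(inp))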