Skip to content

Instantly share code, notes, and snippets.

@Lyken17
Created August 15, 2022 17:05
Show Gist options
  • Select an option

  • Save Lyken17/fb31f72c55e2969efd0b25acc67be606 to your computer and use it in GitHub Desktop.

Select an option

Save Lyken17/fb31f72c55e2969efd0b25acc67be606 to your computer and use it in GitHub Desktop.
relay_ast_replace_div2mul
import numpy as np
from collections import Counter
import tvm
from tvm import relay
# from tvm.relay import ExprFunctor, ExprMutator, ExprVisitor
from tvm.relay.expr_functor import ExprMutator, Call
class ReplaceOP(ExprMutator):
    """Rewrite ``divide(x, c)`` into ``multiply(x, 1 / c)`` for a constant ``c``.

    Only calls whose operator is the registered ``divide`` op and whose
    divisor is a floating-point ``relay.Constant`` are rewritten. Integer
    division is left untouched: multiplying by a reciprocal does not reproduce
    integer division. In floating point, ``x * (1 / c)`` is not always
    bit-identical to ``x / c``; it is exact when ``c`` is a power of two. So
    this is a value-approximating rewrite.
    """

    def visit_call(self, call):
        """Visit ``call``, turning a constant floating-point divide into a multiply.

        Parameters
        ----------
        call : relay.Call
            The call node being visited.

        Returns
        -------
        relay.Expr
            A ``multiply`` call when the rewrite applies. Otherwise, ``call``
            rebuilt from the mutated callee and arguments.
        """
        new_fn = self.visit(call.op)
        args = [self.visit(arg) for arg in call.args]

        # Match the registered Op by type and name instead of str(call.op).
        # The string form depends on the printer, and non-Op callees
        # (functions, global vars) must never match.
        is_divide = isinstance(call.op, tvm.ir.Op) and call.op.name == "divide"
        if is_divide and isinstance(args[1], relay.expr.Constant):
            divisor = args[1].data.numpy()
            if np.issubdtype(divisor.dtype, np.floating):
                print("replacing divide to multiply")
                # x / 0 and x * inf agree under IEEE 754, so silence numpy's
                # divide-by-zero warning for a zero divisor.
                with np.errstate(divide="ignore"):
                    reciprocal = np.asarray(np.reciprocal(divisor), dtype=divisor.dtype)
                # Pass the dtype explicitly. Otherwise relay.const maps a
                # float64 value to float32, and the multiply's operands could
                # disagree with x's dtype.
                return relay.multiply(
                    args[0], relay.const(reciprocal, dtype=str(divisor.dtype))
                )

        # Keep type_args and span, as the default ExprMutator.visit_call does.
        return Call(new_fn, args, call.attrs, call.type_args, call.span)
def test_replace_op():
    """Check that ReplaceOP turns ``x / 0.5`` into ``x * 2``.

    Builds ``divide(%x, 0.5f)`` over a ``(1, 10)`` float32 input, runs the
    mutator, and asserts the result is structurally equal to
    ``multiply(%x, 2f)``.
    """
    x = relay.var("x", shape=[1, 10])
    expr = relay.Function([x], relay.divide(x, relay.const(0.5)))
    print(expr)
    # fn (%x: Tensor[(1, 10), float32]) {
    #   divide(%x, 0.5f)
    # }

    result = ReplaceOP().visit(expr)
    print(result)
    # fn (%x: Tensor[(1, 10), float32]) {
    #   multiply(%x, 2f)
    # }

    # Actually verify the rewrite instead of relying on eyeballing the prints.
    expected = relay.Function([x], relay.multiply(x, relay.const(2.0)))
    assert tvm.ir.structural_equal(result, expected), result
if __name__ == "__main__":
    # Run the divide -> multiply demo when executed as a script.
    test_replace_op()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment