Skip to content

Instantly share code, notes, and snippets.

@Lyken17
Created August 5, 2022 07:03
Show Gist options
  • Select an option

  • Save Lyken17/8b8eb9ec563f5cac6fa0acbf765e569e to your computer and use it in GitHub Desktop.

Select an option

Save Lyken17/8b8eb9ec563f5cac6fa0acbf765e569e to your computer and use it in GitHub Desktop.
Replace ops in a Relay expression with a custom ExprMutator pass
import numpy as np
from collections import Counter
import tvm
from tvm import relay
from tvm.relay import ExprFunctor, ExprMutator, ExprVisitor
from tvm.relay.expr_functor import ExprMutator, Call
class ReplaceOP(ExprMutator):
    """Expression mutator that lets a callback substitute individual calls.

    For every call node the operator and the arguments are visited first;
    the callback is then offered the chance to return a replacement
    expression for that node. If it declines, the call is rebuilt from the
    visited operator and arguments.
    """

    def __init__(self, replace_fn=None):
        """Store the optional replacement callback.

        Args:
            replace_fn: Callable invoked as ``replace_fn(call, args, attrs)``,
                where ``call`` is the original call node, ``args`` is the
                list of already-visited arguments and ``attrs`` is
                ``call.attrs``. It returns the replacement expression, or
                ``None`` to keep the call. When ``replace_fn`` itself is
                ``None`` no call is ever replaced.
        """
        super().__init__()
        self.replace_fn = replace_fn

    def visit_call(self, call):
        """Rewrite one call node, deferring to ``replace_fn`` when it is set.

        Args:
            call: The call node being visited.

        Returns:
            The expression returned by ``replace_fn`` when that is not
            ``None``; otherwise a new ``Call`` built from the visited
            operator and arguments.
        """
        new_fn = self.visit(call.op)
        args = [self.visit(arg) for arg in call.args]

        if self.replace_fn is not None:
            replacement = self.replace_fn(call, args, call.attrs)
            # Compare against None explicitly: a truthiness test would
            # silently discard a valid replacement that happens to be falsy
            # (for example an object defining __len__ that is empty).
            if replacement is not None:
                return replacement

        # Carry type_args and span over so the rebuilt call keeps the same
        # type arguments and source location as the node it replaces.
        return Call(new_fn, args, call.attrs, call.type_args, call.span)
def replace_op(expr, fn):
    """Visit ``expr`` with a ``ReplaceOP`` mutator that uses ``fn`` as its callback.

    Args:
        expr: The expression to traverse.
        fn: Callback forwarded to ``ReplaceOP`` as ``replace_fn``.

    Returns:
        Whatever the mutator's ``visit`` returns for ``expr``.
    """
    mutator = ReplaceOP(replace_fn=fn)
    return mutator.visit(expr)
def test_replace_op():
    """Build ``(x + y) / z`` and check that the divide becomes a multiply."""
    x = relay.var("x", shape=[1, 10])
    y = relay.var("y", shape=[1, 10])
    z = relay.var("z", shape=[1, 10])
    out = relay.add(x, y)
    out = relay.divide(out, z)
    expr = relay.Function([x, y, z], out)
    print(expr)
    # Expected output:
    # fn (%x: Tensor[(1, 10), float32], %y: Tensor[(1, 10), float32], %z: Tensor[(1, 10), float32]) {
    #   %0 = add(%x, %y);
    #   divide(%0, %z)
    # }

    def replace_fn(call, args, attrs):
        """Swap a ``divide`` call for ``multiply``; leave every other call alone."""
        op = call.op
        # Match on the operator's name rather than on str(op), so the check
        # does not depend on how the operator happens to be printed.
        if isinstance(op, tvm.ir.Op) and op.name == "divide":
            print("replacing divide with multiply")
            return relay.multiply(args[0], args[1])
        return None

    rewritten = replace_op(expr, replace_fn)
    print(rewritten)
    # Expected output:
    # fn (%x: Tensor[(1, 10), float32], %y: Tensor[(1, 10), float32], %z: Tensor[(1, 10), float32]) {
    #   %0 = add(%x, %y);
    #   multiply(%0, %z)
    # }

    # Verify the rewrite instead of relying on eyeballing the printed text.
    assert rewritten.body.op.name == "multiply"
# Run the replace-op demonstration when this file is executed as a script.
if __name__ == "__main__":
    test_replace_op()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment