Created
August 5, 2022 07:03
-
-
Save Lyken17/8b8eb9ec563f5cac6fa0acbf765e569e to your computer and use it in GitHub Desktop.
Replace ops in a Relay pass
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import numpy as np | |
| from collections import Counter | |
| import tvm | |
| from tvm import relay | |
| from tvm.relay import ExprFunctor, ExprMutator, ExprVisitor | |
| from tvm.relay.expr_functor import ExprMutator, Call | |
class ReplaceOP(ExprMutator):
    """Expression mutator that lets a callback substitute individual calls.

    Each call node has its operator and arguments rewritten first; the
    callback is then offered the chance to supply a replacement expression
    for that call.

    Parameters
    ----------
    replace_fn : callable, optional
        Invoked as ``replace_fn(call, args, attrs)`` where ``call`` is the
        original call node, ``args`` is the list of already-rewritten
        arguments and ``attrs`` is ``call.attrs``.  Return ``None`` to keep
        the call; any other return value is used as the replacement.
    """

    def __init__(self, replace_fn=None):
        super().__init__()
        self.replace_fn = replace_fn

    def visit_call(self, call):
        """Rewrite one call node, deferring to ``replace_fn`` when it is set.

        Returns the callback's replacement if it provides one, otherwise a
        new ``Call`` built from the rewritten operator and arguments.
        """
        new_fn = self.visit(call.op)
        args = [self.visit(arg) for arg in call.args]

        if self.replace_fn is not None:
            replacement = self.replace_fn(call, args, call.attrs)
            # Compare against None explicitly: a valid replacement may be
            # falsy (e.g. an expression type that defines a length, such as
            # an empty tuple) and must not be silently discarded.
            if replacement is not None:
                return replacement

        # Forward type_args and span as well so the rebuilt call keeps the
        # type arguments and source location of the original node.
        # NOTE(review): assumes Call(op, args, attrs, type_args, span), the
        # form used by TVM's own ExprMutator.visit_call -- confirm against
        # the installed TVM version.
        return Call(new_fn, args, call.attrs, call.type_args, call.span)
def replace_op(expr, fn):
    """Run a ``ReplaceOP`` mutator configured with ``fn`` over ``expr``.

    ``fn`` is passed through as the mutator's ``replace_fn`` callback; the
    result of visiting ``expr`` is returned.
    """
    mutator = ReplaceOP(replace_fn=fn)
    return mutator.visit(expr)
def test_replace_op():
    """Build ``(x + y) / z`` and check that ``divide`` becomes ``multiply``."""
    x = relay.var("x", shape=[1, 10])
    y = relay.var("y", shape=[1, 10])
    z = relay.var("z", shape=[1, 10])
    out = relay.add(x, y)
    out = relay.divide(out, z)
    expr = relay.Function([x, y, z], out)
    print(expr)
    # Expected printout:
    #   fn (%x: Tensor[(1, 10), float32], %y: Tensor[(1, 10), float32], %z: Tensor[(1, 10), float32]) {
    #     %0 = add(%x, %y);
    #     divide(%0, %z)
    #   }

    def replace_fn(call, args, attrs):
        """Swap ``divide`` calls for ``multiply``; keep every other call."""
        # Match on the operator's registered name instead of its printed
        # form, which is a display format rather than a stable identifier.
        # The isinstance guard skips callees that are not primitive ops.
        if isinstance(call.op, tvm.ir.Op) and call.op.name == "divide":
            print("replacing divide with multiply")
            return relay.multiply(args[0], args[1])
        return None

    rewritten = ReplaceOP(replace_fn=replace_fn).visit(expr)
    print(rewritten)
    # Expected printout:
    #   fn (%x: Tensor[(1, 10), float32], %y: Tensor[(1, 10), float32], %z: Tensor[(1, 10), float32]) {
    #     %0 = add(%x, %y);
    #     multiply(%0, %z)
    #   }

    # Verify the rewrite programmatically instead of relying on eyeballing
    # the printed IR.
    expected = relay.Function([x, y, z], relay.multiply(relay.add(x, y), z))
    assert tvm.ir.structural_equal(rewritten, expected)
# Run the demonstration only when this file is executed as a script.
if __name__ == "__main__":
    test_replace_op()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment