Created March 12, 2020 17:23
-
-
Save mattjj/561dc690ef3cec9111f699c82ce2082b to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from functools import partial | |
from jax import core | |
from jax.util import safe_map, safe_zip | |
import jax.linear_util as lu | |
# Module-wide aliases: shadow the builtins with jax.util's safe_map/safe_zip
# (JAX-internal convention; presumably length-checking, list-returning
# variants -- confirm against the pinned jax.util).
map, zip = safe_map, safe_zip
@lu.transformation
def _rewrite(rules, args):
  """Run the wrapped flat function under a fresh RewriteTrace.

  Args:
    rules: mapping from primitive to replacement callable. It is stored on
      the master so that RewriteTrace.process_primitive can consult it.
    args: flat list of input values; each is boxed in a RewriteTracer.

  Yields:
    First `(in_tracers, {})`, the positional and keyword arguments for the
    wrapped function. The function's outputs are sent back into the
    generator as `out_tracers`. Second, the list of unboxed output values.
  """
  with core.new_master(RewriteTrace) as master:
    master.rules = rules
    trace = RewriteTrace(master, core.cur_sublevel())
    in_tracers = [RewriteTracer(trace, x) for x in args]
    out_tracers = yield in_tracers, {}
    # full_raise brings every output into this trace (outputs that never
    # touched an input are boxed via pure/lift), so each one has a .val.
    outs = [trace.full_raise(t).val for t in out_tracers]
    # Drop every local that still references the master before the context
    # exits. `trace` holds it directly and `in_tracers` hold it via `trace`.
    # NOTE(review): this is presumably what new_master's leak check expects;
    # confirm against the pinned jax version.
    del master, trace, in_tracers, out_tracers
  yield outs
class RewriteTracer(core.Tracer):
  """Tracer that boxes a single value on behalf of a RewriteTrace.

  Slots:
    _trace: the owning RewriteTrace.
    val: the wrapped value (a concrete value or a lower-level tracer).
  """
  __slots__ = ("_trace", "val")

  def __init__(self, trace, val):
    self._trace, self.val = trace, val

  @property
  def aval(self):
    # The abstract value is whatever JAX infers for the wrapped value.
    return core.get_aval(self.val)

  def full_lower(self):
    # Never lower to `self.val`; the value always stays a RewriteTracer.
    return self
class RewriteTrace(core.Trace):
  """Trace that substitutes user-supplied rules for selected primitives.

  The rule table is read from `self.master.rules`, which `_rewrite` sets.
  Primitives without a rule are re-bound unchanged on the unboxed values.
  """

  def pure(self, val):
    # Box a value that is not a tracer into this trace.
    return RewriteTracer(self, val)

  def lift(self, val):
    # Box a tracer from a lower-level trace; it becomes this tracer's `val`.
    return RewriteTracer(self, val)

  def sublift(self, val):
    # `val` is a RewriteTracer from another sublevel: re-box its inner value.
    return RewriteTracer(self, val.val)

  def process_primitive(self, primitive, tracers, params):
    """Apply `primitive` to the tracers' values, using a rewrite rule if present.

    A rule is called as `rule(*vals_in, **params)` in place of
    `primitive.bind`. It must return a sequence of values when
    `primitive.multiple_results` is true, and a single value otherwise.

    Returns:
      A list of RewriteTracers if `primitive.multiple_results`, else a
      single RewriteTracer.
    """
    vals_in = [t.val for t in tracers]
    # The rule table lives on the master; there is no module-level `rules`.
    rules = self.master.rules
    if primitive in rules:
      vals_out = rules[primitive](*vals_in, **params)
    else:
      vals_out = primitive.bind(*vals_in, **params)
    if primitive.multiple_results:
      return [RewriteTracer(self, v) for v in vals_out]
    return RewriteTracer(self, vals_out)

  def process_call(self, call_primitive, f, tracers, params):
    # Raise explicitly: `assert False` is stripped under `python -O`, which
    # would make this silently return None.
    raise NotImplementedError(
        f"rewrite does not support call primitives yet: {call_primitive}")

  def process_map(self, map_primitive, f, tracers, params):
    raise NotImplementedError(
        f"rewrite does not support map primitives yet: {map_primitive}")
### api.py | |
from jax.api_util import flatten_fun_nokwargs | |
from jax.tree_util import tree_flatten, tree_unflatten | |
def rewrite(fun, *args, rules):
  """Call `fun(*args)` with selected primitives replaced by user rules.

  Args:
    fun: Python callable taking the positional `args`. It is called without
      keyword arguments.
    *args: positional arguments for `fun`. They are flattened with
      `tree_flatten`, and the outputs are rebuilt with `tree_unflatten`.
    rules: keyword-only mapping from primitive to a replacement callable.
      The callable is invoked as `rule(*input_values, **params)` in place of
      `primitive.bind`.

  Returns:
    The output of `fun`, in the same pytree structure, computed under the
    rewrite rules.
  """
  flat_args, in_treedef = tree_flatten(args)
  flat_fun, get_out_tree = flatten_fun_nokwargs(lu.wrap_init(fun), in_treedef)
  flat_outs = _rewrite(flat_fun, rules).call_wrapped(flat_args)
  # `get_out_tree` is a thunk; it is called only after `call_wrapped` has run.
  return tree_unflatten(get_out_tree(), flat_outs)
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.
Usage example: