Last active: November 9, 2017, 05:50
-
-
Save mattjj/0877a878c5d5318df99dd5f5c0350df3 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import types | |
import autograd | |
from autograd.tracer import trace, toposort | |
from autograd.core import VJPNode, add_outgrads | |
### closure conversion
def closure_conversion(f):
    """Split a function into closure-free parts: a rebuilder and its captured values.

    Returns a pair ``(f_maker, env)``:

    * ``env`` is a tuple with the current contents of every cell in ``f``'s
      closure, in closure order (empty when ``f`` closes over nothing).
    * ``f_maker(env)`` builds a new function from ``f``'s code object,
      globals, positional defaults and keyword-only defaults, whose closure
      cells hold the values of ``env``.  So ``f_maker(env)`` with the
      returned ``env`` behaves like ``f``, and passing a different tuple of
      the same length swaps in new captured values.

    The ``__code__`` / ``__globals__`` / ``__closure__`` / ``__defaults__``
    spellings exist on both Python 2.6+ and Python 3; the ``func_*``
    spellings are Python 2 only.
    """
    code, globs = f.__code__, f.__globals__
    # Without these the rebuilt function would lose its default arguments.
    argdefs = f.__defaults__
    # Keyword-only defaults only exist on Python 3.
    kwdefaults = getattr(f, '__kwdefaults__', None)
    # NOTE: cell_contents raises ValueError for a cell whose variable has not
    # been bound yet.
    env = tuple(c.cell_contents for c in f.__closure__ or ())

    def make_cell(val):
        # The inner lambda closes over ``val``, so its only closure cell is a
        # fresh cell holding ``val``.
        return (lambda: val).__closure__[0]

    def f_maker(env):
        closure = tuple(make_cell(val) for val in env)
        new_f = types.FunctionType(code, globs, argdefs=argdefs, closure=closure)
        if kwdefaults:
            # Copy so the rebuilt function cannot mutate f's defaults dict.
            new_f.__kwdefaults__ = dict(kwdefaults)
        return new_f

    return f_maker, env
### split vjp
def make_split_vjp(fun, x):
    """Trace ``fun`` at ``x`` and return its VJP split into code and data.

    Returns ``(vjp, tape, end_value)``:

    * ``end_value`` is the forward value reported by ``trace``.
    * ``tape`` is the closure-free backward graph produced by ``flatten_vjp``.
    * ``vjp(g, tape)`` runs the backward pass over ``tape`` from cotangent ``g``.

    NOTE(review): if ``fun``'s output does not depend on ``x``, autograd's
    ``trace`` presumably returns ``None`` as the end node, which
    ``flatten_vjp`` does not handle — confirm before relying on that case.
    """
    root = VJPNode.new_root(x)
    result, final_node = trace(root, fun, x)
    backward_fn, flat_tape = flatten_vjp(final_node)
    return backward_fn, flat_tape, result
def flatten_vjp(end_node):
    """Flatten the backward graph ending at ``end_node`` into a closure-free tape.

    Returns ``(vjp_noclosures, tape)``.  ``tape`` is a list of 4-tuples
    ``(node_id, parent_ids, vjp_maker, env)``, one per node, in ``toposort``
    order starting from ``end_node``.  ``vjp_maker(env)`` rebuilds that
    node's VJP function (see ``closure_conversion``).

    ``vjp_noclosures(g, tape)`` seeds the first tape entry with cotangent
    ``g``, then propagates gradients to each node's parents.  It returns the
    gradient that reaches the last node on the tape.  The tape can be reused
    for any number of calls.
    """
    def vjp_noclosures(g, tape):
        # node id -> (accumulated gradient, flag), the pair format consumed
        # and produced by autograd's add_outgrads (the flag presumably marks
        # an already-mutable buffer — confirm against autograd.core).
        outgrads = {tape[0][0]: (g, False)}
        for id_, parents, vjp, env in tape:
            outgrad = outgrads.pop(id_)
            ingrads = vjp(env)(outgrad[0])
            for parent, ingrad in zip(parents, ingrads):
                outgrads[parent] = add_outgrads(outgrads.get(parent), ingrad)
        return outgrad[0]

    # Ids are taken while every node is still reachable from end_node, so
    # they are distinct within the tape.  Parent ids are stored as a tuple
    # rather than a bare map(): on Python 3 map() is a one-shot iterator, so
    # a second vjp call over the same tape would see no parents.
    tape = [(id(node), tuple(id(parent) for parent in node.parents))
            + closure_conversion(node.vjp)
            for node in toposort(end_node)]
    return vjp_noclosures, tape
### script
if __name__ == '__main__':
    # Demo: split the VJP of x**2 at x = 1.0, then apply it to cotangent 1.0.
    # The function-call form of print works on both Python 2 and Python 3.
    vjp, tape, val = make_split_vjp(lambda x: x**2, 1.)
    print(val)                # forward value: 1.0
    print(vjp(1., tape))      # d/dx x**2 at x = 1 is 2.0
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.