Skip to content

Instantly share code, notes, and snippets.

@mattjj
Last active November 9, 2017 05:50
Show Gist options
  • Save mattjj/0877a878c5d5318df99dd5f5c0350df3 to your computer and use it in GitHub Desktop.
Save mattjj/0877a878c5d5318df99dd5f5c0350df3 to your computer and use it in GitHub Desktop.
import types
import autograd
from autograd.tracer import trace, toposort
from autograd.core import VJPNode, add_outgrads
### closure conversion
def closure_conversion(f):
    """Split a Python function into a closure-free "maker" plus its captured values.

    Returns a pair ``(f_maker, env)``:

    * ``env`` is a tuple of the values currently held in ``f``'s closure
      cells (in ``f.__code__.co_freevars`` order), or ``()`` when ``f`` has
      no closure.
    * ``f_maker(env)`` builds a new function with the same code, globals,
      name and positional default arguments as ``f``, whose closure cells
      hold the values in ``env``. Passing a different tuple of the same
      length yields the same function closed over different data.

    Uses the ``__code__`` / ``__globals__`` / ``__closure__`` /
    ``__defaults__`` attributes, which exist on Python 2.6+ and Python 3
    (the ``func_*`` spellings are Python 2 only).
    """
    code, globs = f.__code__, f.__globals__
    name, defaults = f.__name__, f.__defaults__
    env = tuple(cell.cell_contents for cell in f.__closure__ or ())

    def make_cell(val):
        # A throwaway lambda capturing ``val`` yields a real cell object that
        # holds it (portable alternative to ``types.CellType``, 3.8+ only).
        return (lambda: val).__closure__[0]

    def f_maker(env):
        closure = tuple(make_cell(val) for val in env)
        # Keep name and defaults so the rebuilt function is call-compatible
        # with ``f``.
        return types.FunctionType(code, globs, name, defaults, closure=closure)

    return f_maker, env
### split vjp
def make_split_vjp(fun, x):
    """Trace ``fun`` at ``x`` and return a closure-free reverse-mode VJP.

    Returns ``(vjp, tape, end_value)``:

    * ``end_value`` is the traced output of ``fun(x)``.
    * ``tape`` is the flattened record that ``flatten_vjp`` builds from the
      traced graph.
    * ``vjp(g, tape)`` propagates the output cotangent ``g`` back through
      that tape.

    NOTE(review): if ``fun``'s output does not depend on ``x``, ``trace``
    presumably returns ``None`` for the end node. That case is not handled
    here and would fail inside ``flatten_vjp`` — confirm against autograd.
    """
    end_value, end_node = trace(VJPNode.new_root(x), fun, x)
    vjp_fn, flat_tape = flatten_vjp(end_node)
    return vjp_fn, flat_tape, end_value
def flatten_vjp(end_node):
    """Flatten the traced graph ending at ``end_node`` into a replayable tape.

    Returns ``(vjp_noclosures, tape)``.

    ``tape`` is a list with one entry per node yielded by
    ``toposort(end_node)``, in that order. Each entry has the form
    ``(node_id, parent_ids, vjp_maker, env)``:

    * ``node_id`` and ``parent_ids`` are the ``id()`` values of the node and
      of its parents.
    * ``(vjp_maker, env)`` is ``closure_conversion(node.vjp)``.

    ``vjp_noclosures(g, tape)`` seeds the first tape entry with cotangent
    ``g`` and walks the tape in order. For each node it rebuilds the VJP via
    ``vjp_maker(env)``, applies it, and accumulates the results into the
    parents with ``add_outgrads``. It returns the gradient of the last entry
    processed (presumably the root/input node, given toposort order).
    The tape holds no node objects, so it can be replayed any number of times.
    """
    def vjp_noclosures(g, tape):
        outgrads = {tape[0][0]: (g, False)}
        for node_id, parent_ids, vjp_maker, env in tape:
            outgrad = outgrads.pop(node_id)
            # Rebuild this node's VJP from closure-free code + saved env.
            ingrads = vjp_maker(env)(outgrad[0])
            for parent_id, ingrad in zip(parent_ids, ingrads):
                outgrads[parent_id] = add_outgrads(outgrads.get(parent_id), ingrad)
        return outgrad[0]

    # Store parent ids as a tuple. On Python 3 ``map`` returns a one-shot
    # iterator, which would let the tape be replayed only once.
    tape = [(id(node), tuple(id(parent) for parent in node.parents))
            + closure_conversion(node.vjp)
            for node in toposort(end_node)]
    return vjp_noclosures, tape
### script
if __name__ == '__main__':
    # Demo: differentiate f(x) = x**2 at x = 1.0 using the split VJP.
    # print(x) with one argument behaves the same on Python 2 and 3.
    vjp, tape, val = make_split_vjp(lambda x: x**2, 1.)
    print(val)            # expected f(1.0)  = 1.0
    print(vjp(1., tape))  # expected f'(1.0) = 2.0
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment