Skip to content

Instantly share code, notes, and snippets.

@mattjj
Last active September 10, 2017 15:16
Show Gist options
  • Save mattjj/68ffb068595f629e9921edda532dcc6b to your computer and use it in GitHub Desktop.
Save mattjj/68ffb068595f629e9921edda532dcc6b to your computer and use it in GitHub Desktop.
from autograd.core import primitive
from autograd.container_types import make_tuple
from autograd.util import quick_grad_check
import autograd.numpy as np
@primitive
def _f(tupl):
    """Return ``2*a + 3*b`` for the pair ``tupl = (a, b)``.

    Marked as an autograd primitive, so autograd does not trace this body.
    Its gradient comes from the VJP registered with ``_f.defvjp`` below.
    """
    first, second = tupl
    return 2. * first + 3. * second


# VJP with respect to the (tuple) argument: the partial derivatives are 2 for
# the first element and 3 for the second, so the incoming gradient is scaled
# into a matching pair. The remaining positional parameters are unused here.
_f.defvjp(lambda grad_out, ans, vs, gvs, tupl: (2. * grad_out, 3. * grad_out))
def f(tupl):
    """Apply ``_f`` to the elements of ``tupl``, repacked with ``make_tuple``.

    This is for the caller's convenience: the caller may pass an ordinary
    tuple (possibly containing boxed values). Repacking it ensures ``_f``
    receives a boxed tuple rather than a tuple of boxes.
    """
    packed = make_tuple(*tupl)
    return _f(packed)
def g(x):
    """Evaluate ``f`` on the pair ``(x, 2*x)``; with ``_f`` as defined, that is 8*x."""
    pair = (x, 2. * x)
    return f(pair)


# Compare autograd's gradient of g against a numerical estimate at x = 1.0.
quick_grad_check(g, 1.)
# A second version of f with a slightly different signature: the tuple
# elements are passed as separate positional arguments instead of as one tuple.
def f2(*args):
    """Pack the positional arguments into a boxed tuple and apply ``_f`` to it."""
    boxed = make_tuple(*args)
    return _f(boxed)
def g2(x):
    """Compute the same value as ``g``, but route it through the varargs ``f2``."""
    return f2(x, 2. * x)


# Gradient-check g2 itself so that the f2 (varargs) path is actually
# exercised, rather than re-checking g.
quick_grad_check(g2, 1.)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment