Last active
September 10, 2017 15:16
-
-
Save mattjj/68ffb068595f629e9921edda532dcc6b to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from autograd.core import primitive | |
from autograd.container_types import make_tuple | |
from autograd.util import quick_grad_check | |
import autograd.numpy as np | |
@primitive
def _f(tupl):
    """Return ``2*a + 3*b`` for the pair ``(a, b)`` held in ``tupl``."""
    first, second = tupl
    return 2.0 * first + 3.0 * second


# Gradient rule for _f's single (tuple) argument: the incoming gradient is
# scaled by each coefficient used in _f, yielding one entry per tuple element.
# The remaining parameters are accepted but not used.
_f.defvjp(lambda grad, _ans, _vs, _gvs, _tupl: (2.0 * grad, 3.0 * grad))
# Convenience wrapper for callers: hand _f a single boxed tuple instead of a
# plain tuple whose elements are boxes.
def f(tupl):
    """Pack the elements of ``tupl`` with ``make_tuple`` and apply ``_f``."""
    boxed = make_tuple(*tupl)
    return _f(boxed)
def g(x):
    """Scalar function of ``x`` that evaluates ``f`` on the pair ``(x, 2x)``."""
    pair = (x, 2.0 * x)
    return f(pair)


# Gradient check of g at x = 1.0.
quick_grad_check(g, 1.0)
# Second version with a slightly different signature: the tuple's elements
# are passed as separate positional arguments.
def f2(*args):
    """Pack ``args`` with ``make_tuple`` and apply ``_f``."""
    boxed = make_tuple(*args)
    return _f(boxed)
def g2(x):
    """Scalar function of ``x`` that evaluates ``f2`` on ``x`` and ``2x``."""
    return f2(x, 2.0 * x)


# Gradient check of g2 at x = 1.0, so the f2-based variant is the one that
# actually gets exercised here.
quick_grad_check(g2, 1.0)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment