Created
March 29, 2020 16:45
-
-
Save mattjj/bc31f76f9e97f6de03114ea10cb853f7 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from jax import core

# On its own a primitive does nothing: it is simply a named object to which
# evaluation, abstract-evaluation, lowering and differentiation rules get
# attached.
sincos_p = core.Primitive('sincos')


def sincos(x):
    """Apply the ``sincos`` primitive to ``x``.

    Going through ``bind`` (instead of doing any math here) is what lets the
    trace/transform machinery intercept each application. By convention a
    primitive's ``bind`` is exposed through a thin Python wrapper like this.
    """
    return sincos_p.bind(x)
# Until rules are attached, the primitive cannot do anything -- even plain
# evaluation is a rule. The evaluation ("impl") rule below is written in NumPy.
import numpy as onp


def sincos_impl(x):
    """Eagerly evaluate ``sin(cos(x))`` with NumPy."""
    inner = onp.cos(x)
    return onp.sin(inner)
sincos_p.def_impl(sincos_impl)

# With an impl rule registered, the primitive can be evaluated eagerly:
print(sincos(3.))  # -0.8360218615377305

# Building jaxprs and jit compilation (plus a few other transforms) also need
# an abstract evaluation rule. Given the input's abstract value, it must return
# an upper bound, in the abstract-value lattice, on the output's abstract value.
# First, a deliberately verbose version:
from jax.core import ConcreteArray, ShapedArray, UnshapedArray
def sincos_abstract_eval(x):
    """Verbose abstract evaluation rule for ``sincos``.

    Rejects non-floating dtypes with ``TypeError``, then maps each kind of
    abstract input to an abstract output of the same kind:

    * ConcreteArray -> ConcreteArray holding ``sincos_impl`` of the value
    * ShapedArray   -> ShapedArray with the same shape and dtype
    * UnshapedArray -> UnshapedArray with the same dtype

    Any other abstract value raises ``TypeError(x)``.
    """
    if not onp.issubdtype(x.dtype, onp.floating):
        raise TypeError("must be floating dtype")
    # Checked from most to least specific; presumably ConcreteArray subclasses
    # ShapedArray, which subclasses UnshapedArray, so this order matters.
    if isinstance(x, ConcreteArray):
        return ConcreteArray(sincos_impl(x.val))
    if isinstance(x, ShapedArray):
        return ShapedArray(x.shape, x.dtype)
    if isinstance(x, UnshapedArray):
        return UnshapedArray(x.dtype)
    raise TypeError(x)


sincos_p.def_abstract_eval(sincos_abstract_eval)
# A much shorter rule works just as well: validate the dtype, then raise the
# input's abstract value to the shaped level.
from jax.core import raise_to_shaped


def sincos_abstract_eval(x):
    """Short abstract evaluation rule: ``raise_to_shaped`` of the input.

    Raises ``TypeError`` for non-floating dtypes. This registration replaces
    any abstract evaluation rule registered earlier for ``sincos_p``.
    """
    is_floating = onp.issubdtype(x.dtype, onp.floating)
    if not is_floating:
        raise TypeError("must be floating dtype")
    return raise_to_shaped(x)


sincos_p.def_abstract_eval(sincos_abstract_eval)
# An abstract evaluation rule is all make_jaxpr needs:
from jax import make_jaxpr
print(make_jaxpr(sincos)(3.))
# Expected output:
# { lambda ; a.
#   let b = sincos a
#   in (b,) }
# jit compilation additionally requires an XLA translation (lowering) rule.
from jax.interpreters import xla


def sincos_translation_rule(c, x):
    """Lower ``sincos`` to XLA as ``Sin(Cos(x))``.

    ``c`` is the XLA ComputationBuilder and ``x`` is the XlaOp for the operand.
    NOTE(review): ``xla.translations`` and the ``c.Sin``/``c.Cos`` builder
    methods are the 2020-era API; newer JAX releases lower primitives
    differently -- confirm against the JAX version in use.
    """
    cos_op = c.Cos(x)
    return c.Sin(cos_op)


xla.translations[sincos_p] = sincos_translation_rule
# jit now works, including for nested applications of the primitive:
from jax import jit
a = jit(lambda v: sincos(sincos(v)))(3.)
b = sincos(sincos(3.))
print(a)  # 0.62131506
print(b)  # 0.6213150041315272 <-- the NumPy impl keeps more bits
# Trick: derive the impl rule from the translation rule. In effect, "to
# evaluate, jit-compile the translation rule on its own." With this in place
# no NumPy impl is needed; the registration replaces the earlier impl rule.
from functools import partial
sincos_p.def_impl(partial(xla.apply_primitive, sincos_p))
print(sincos(3.))  # -0.83602184
print(sincos(sincos(3.)))  # 0.62131506
# Finally, differentiation rules. First a forward-mode (JVP) rule:
from jax.interpreters import ad
from jax import lax  # most of JAX's own primitives live here


def sincos_jvp_rule(primals, tangents):
    """Forward-mode rule for ``sincos`` via the chain rule.

    Since d/dx sin(cos(x)) = cos(cos(x)) * -sin(x), the output tangent is the
    input tangent scaled by that derivative. Expects exactly one primal and
    one tangent; returns ``(primal_out, tangent_out)``.
    """
    (x,), (t,) = primals, tangents
    primal_out = sincos(x)
    tangent_out = t * (-lax.sin(x)) * lax.cos(lax.cos(x))
    return primal_out, tangent_out


ad.primitive_jvps[sincos_p] = sincos_jvp_rule
# Forward-mode autodiff now works on the primitive...
from jax import jvp
y, y_dot = jvp(sincos, (3.,), (1.,))
print(y)  # -0.83602184
print(y_dot)  # -0.07743199
# ...and agrees with the same function built from lax primitives:
y, y_dot = jvp(lambda x: lax.sin(lax.cos(x)), (3.,), (1.,))
print(y)  # -0.83602184
print(y_dot)  # -0.07743199
# Reverse mode works too: JAX automatically transposes our forward-mode rule
# to obtain reverse mode.
from jax import grad
print(grad(sincos)(3.))  # -0.07743199
# Custom reverse-mode rules take more care, because JAX doesn't implement
# reverse mode directly; none of the JAX primitives in the core define one.
# Moreover, once a custom VJP rule is set up, forward mode no longer works for
# the primitive.
# This mechanism is being replaced; what follows is the not-quite-yet-deprecated
# way of doing it (see https://github.com/google/jax/pull/636):
ad.defvjp2(sincos_p, lambda g, ans, x: g * lax.cos(lax.cos(x)) * (-lax.sin(x)))
print(grad(sincos)(3.))  # -0.07743199
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment