Skip to content

Instantly share code, notes, and snippets.

@GleasonK
Created November 7, 2024 14:53
Show Gist options
  • Save GleasonK/23a1052ccc5cf1717816c8ea02c96579 to your computer and use it in GitHub Desktop.
Save GleasonK/23a1052ccc5cf1717816c8ea02c96579 to your computer and use it in GitHub Desktop.
# NOTE: The JAX APIs for creating a composite are WIP, we are looking into a
# jax.lax.composite API -- this file as-is uses JAX internal APIs which are subject
# to change and should not be relied on for anything in production, but may be OK for
# experimentation.
"""XLA Composite Test."""
from absl.testing import absltest
import jax
from jax import export as jax_export
from jax._src import test_util as jtu
from jax._src.lax import lax
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import hlo as stablehlo
from jax.interpreters import mlir
import jax.numpy as jnp
import jaxtyping
ArrayLike = jaxtyping.ArrayLike
Array = jaxtyping.Array
# Step 1: Define a jax primitive.
my_acos_p = lax.standard_unop(lax._float | lax._complex, "my_acos")
# Step 2: Define a jax api.
def my_acos(x: ArrayLike) -> Array:
return my_acos_p.bind(x)
# Step 3: Define auto diff rule.
lax.ad.defjvp(
my_acos_p,
lambda g, x: lax.mul(g, -lax.rsqrt(lax._const(x, 1) - lax.square(x))),
)
# Step 4: Define lowering to stablehlo.composite.
def _composite_acos_lowering(
ctx: mlir.LoweringRuleContext, arg: mlir.ir.BlockArgument
) -> mlir.ir.OpResultList:
@jax.jit
def my_acos_impl(x: ArrayLike) -> Array:
return jnp.acos(x)
# TODO: Implementation currently leaks a call op which can be DCE'd
# This will be fixed in future JAX API
lowered_fun = mlir.lower_fun(my_acos_impl, multiple_results=False)
call_op = lowered_fun(ctx, arg)[0].owner
composite = stablehlo.CompositeOp(
[result.type for result in call_op.results],
call_op.operands,
name=ir.StringAttr.get("chlo.acos"),
composite_attributes=ir.DictAttr.get({}),
decomposition=call_op.attributes["callee"],
)
return composite.results
# Step 5: Register your custom composite lowering to stablehlo.composite.
mlir.register_lowering(my_acos_p, _composite_acos_lowering)
class XlaCompositeTest(jtu.JaxTestCase):
def test_acos_composite(self):
@jax.jit
def f(x: ArrayLike) -> Array:
return my_acos(x)
x = jnp.array(1.0, dtype=jnp.float32)
self.assertAllClose(jnp.acos(x), f(x))
mlir_module = jax_export.export(f)(x).mlir_module()
self.assertIn('stablehlo.composite "chlo.acos"', mlir_module)
if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment