Skip to content

Instantly share code, notes, and snippets.

@iamlemec
Created January 13, 2025 19:18
Show Gist options
  • Save iamlemec/7d6eca7111052a6c5f3ac31bf3635d05 to your computer and use it in GitHub Desktop.
Save iamlemec/7d6eca7111052a6c5f3ac31bf3635d05 to your computer and use it in GitHub Desktop.
Trying out some different JAX options for demeaning.
from functools import partial
import jax
import jax.numpy as jnp
import numpy as np
from jax import config
def _apply_factor(x, f, w, ng):
"""Process a single factor."""
wx = x * w[:, None]
# Compute group weights and weighted sums
group_weights = jnp.bincount(f, weights=w, length=ng)
group_sums = jax.vmap(
lambda col: jnp.bincount(f, weights=col, length=ng)
)(wx.T).T
# Compute and subtract means
means = group_sums / group_weights[:, None]
return x - means[f], None
def _demean_step(x_curr, flist, weights, n_groups):
    """Run one full demeaning sweep over every factor.

    Scans across the columns of flist (one factor per scan step), threading
    the progressively-demeaned matrix through as the carry.
    """
    def sweep(carry, factor):
        # _apply_factor already returns the (carry, output) pair scan wants.
        return _apply_factor(carry, factor, weights, n_groups)

    demeaned, _ = jax.lax.scan(sweep, x_curr, flist.T)
    return demeaned
def _cond_fun(state, tol, maxiter):
"""Condition function for while_loop."""
i, _, max_diff = state
return jnp.logical_and(i < maxiter, max_diff > tol)
def _body_fun(state, flist, weights, n_groups):
    """One fixed-point iteration: demean over all factors, then measure the change."""
    iteration, x_prev, _ = state
    x_next = _demean_step(x_prev, flist, weights, n_groups)
    # Sup-norm of the update, used by _cond_fun as the convergence criterion.
    delta = jnp.abs(x_next - x_prev).max()
    return iteration + 1, x_next, delta
@partial(jax.jit, static_argnames=("n_groups", "maxiter"))
def _demean_jax_impl(
    x: jnp.ndarray,
    flist: jnp.ndarray,
    weights: jnp.ndarray,
    n_groups: int,
    tol: float,
    maxiter: int,
) -> tuple[jnp.ndarray, jnp.ndarray]:
    """JIT-compiled alternating-projections demeaning loop.

    Repeatedly sweeps over all factors until the largest absolute change in
    a sweep drops to tol or below, or maxiter sweeps have run.

    Args:
        x: (n, k) matrix to demean.
        flist: (n, n_factors) integer matrix of group labels, one factor per column.
        weights: (n,) observation weights.
        n_groups: static number of groups (shared bincount length for all factors).
        tol: convergence tolerance on the sup-norm of each sweep's update.
        maxiter: static cap on the number of sweeps.

    Returns:
        Tuple of (demeaned x, max absolute change from the final sweep). The
        second element is a scalar array; the caller compares it to tol.
    """
    # Start max_diff at +inf (not 1.0) so at least one sweep always runs;
    # with 1.0 any tol >= 1.0 would skip the loop and falsely look converged.
    init_state = (0, x, jnp.inf)
    _, final_x, max_diff = jax.lax.while_loop(
        lambda state: _cond_fun(state, tol, maxiter),
        lambda state: _body_fun(state, flist, weights, n_groups),
        init_state,
    )
    return final_x, max_diff
def demean_jax(
    x: np.ndarray,
    flist: np.ndarray,
    weights: np.ndarray,
    tol: float = 1e-08,
    maxiter: int = 100_000,
) -> tuple[np.ndarray, bool]:
    """Iteratively demean x within the groups of every factor in flist.

    Args:
        x: (n, k) data matrix to demean.
        flist: (n, n_factors) integer matrix; each column holds one factor's
            group labels (assumed non-negative and dense from 0 — TODO confirm).
        weights: (n,) observation weights.
        tol: convergence tolerance on the max absolute change per sweep.
        maxiter: cap on the number of sweeps.

    Returns:
        Tuple of (demeaned data as a float64 np.ndarray, True if the loop
        converged within maxiter).
    """
    # NOTE: mutates global JAX config as a side effect; must happen before the
    # arrays below are created for the float64 dtypes to take effect.
    config.update("jax_enable_x64", True)

    # n_groups must be a concrete Python int: it is a static jit argument.
    n_groups = int(np.max(flist) + 1)

    x_jax = jnp.asarray(x, dtype=jnp.float64)
    flist_jax = jnp.asarray(flist, dtype=jnp.int32)
    weights_jax = jnp.asarray(weights, dtype=jnp.float64)

    result_jax, max_diff = _demean_jax_impl(
        x_jax, flist_jax, weights_jax, n_groups, tol, maxiter
    )
    # bool(...) converts the 0-d JAX array to a host Python bool, matching
    # the annotated return type (the bare comparison would leak a jax array).
    return np.array(result_jax), bool(max_diff < tol)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment