Last active
April 10, 2023 18:48
-
-
Save KeAWang/04b4f9ce2235d1767d6587005906a555 to your computer and use it in GitHub Desktop.
Parameterized Arrays in Jax
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# %%
from abc import ABC, abstractmethod
from typing import Any, Union

import equinox as eqx
import jax
import jax.numpy as jnp
# Either a concrete array or an object that evaluates to one. The class name is
# quoted because ParameterizedArray is defined further down in the file.
MaybeParameterizedArray = Union[jax.Array, "ParameterizedArray"]
class ParameterizedArray(ABC, eqx.Module):
    """Abstract base for objects that evaluate to a single ``jax.Array``.

    Subclasses hold their underlying values as fields and implement
    :meth:`eval` to produce the concrete array.
    """

    @abstractmethod
    def eval(self) -> jax.Array:
        """Return the concrete array represented by this object."""
        return

    @staticmethod
    def tree_eval(pytree: Any) -> Any:
        """Turn every (possibly nested) ParameterizedArray into a single jax.Array.

        Args:
            pytree: Any pytree. Leaves that are not ``ParameterizedArray``
                instances are returned unchanged.

        Returns:
            A pytree in which each ``ParameterizedArray`` encountered has been
            replaced by the result of its ``eval()``.
        """

        def _is_parameterized(x: MaybeParameterizedArray) -> bool:
            return isinstance(x, ParameterizedArray)

        def _eval(x):
            return x.eval() if _is_parameterized(x) else x

        # `jax.tree_map` is deprecated/removed in newer JAX releases;
        # `jax.tree_util.tree_map` is the stable spelling of the same function.
        # `is_leaf` stops traversal at ParameterizedArray nodes so `_eval`
        # receives the whole object instead of its individual fields.
        return jax.tree_util.tree_map(_eval, pytree, is_leaf=_is_parameterized)
class PositiveArray(ParameterizedArray):
    """Array kept positive by storing its inverse-softplus and applying softplus.

    The value passed to the constructor is mapped through ``inv_softplus`` and
    stored; ``eval()`` maps it back through ``jax.nn.softplus``.
    """

    unconstrained_array: MaybeParameterizedArray

    @staticmethod
    def inv_softplus(x: jax.Array) -> jax.Array:
        """Return the inverse of softplus at ``x``.

        Computes ``x + log(1 - exp(-x))`` (algebraically ``log(exp(x) - 1)``)
        using ``expm1``. The log argument is non-positive for ``x <= 0``, so
        ``x`` is expected to be strictly positive.
        """
        return x + jnp.log(-jnp.expm1(-x))

    def __init__(self, val: jax.Array):
        """Store ``val`` in unconstrained (inverse-softplus) form."""
        self.unconstrained_array = PositiveArray.inv_softplus(val)

    def eval(self) -> jax.Array:
        """Return softplus of the (evaluated) unconstrained array."""
        unconstrained_array = ParameterizedArray.tree_eval(self.unconstrained_array)
        return jax.nn.softplus(unconstrained_array)

    # The original read the misspelled attribute `unconstrainted_array`, which
    # does not exist and raised AttributeError on access.
    @property
    def shape(self):
        """Shape of the stored unconstrained array."""
        return self.unconstrained_array.shape

    @property
    def dtype(self):
        """Dtype of the stored unconstrained array."""
        return self.unconstrained_array.dtype
class MaybeLearnableArray(ParameterizedArray):
    """Array whose gradient is passed through or blocked by a static flag."""

    array: MaybeParameterizedArray
    learnable: bool = eqx.static_field()

    def eval(self) -> jax.Array:
        """Return the evaluated array, gradient-stopped when not learnable."""
        array = ParameterizedArray.tree_eval(self.array)
        # `learnable` is a static Python bool, so an ordinary branch is enough;
        # `jax.lax.cond` would trace both branches for a predicate that is
        # already known when the function is traced.
        if self.learnable:
            return array
        return jax.lax.stop_gradient(array)

    @property
    def shape(self):
        """Shape of the wrapped array."""
        return self.array.shape

    @property
    def dtype(self):
        """Dtype of the wrapped array."""
        return self.array.dtype
class PartiallyLearnableArray(ParameterizedArray):
    """Array whose gradient is blocked wherever ``learnable_mask`` is false."""

    array: MaybeParameterizedArray
    # NOTE(review): static fields are generally expected to be hashable; a
    # jax.Array mask may not be -- confirm this behaves as intended under jit.
    learnable_mask: jax.Array = eqx.static_field()  # 1 if learnable, 0 if not

    def eval(self) -> jax.Array:
        """Return the evaluated array with gradients masked elementwise.

        Both branches of ``jnp.where`` hold the same values; only the branch
        chosen where the mask is false goes through ``stop_gradient``.
        """
        array = ParameterizedArray.tree_eval(self.array)
        # Stop gradients on the *evaluated* array. The original passed
        # `self.array`, which is still a ParameterizedArray when nested and
        # cannot be combined with `array` by `jnp.where`.
        return jnp.where(self.learnable_mask, array, jax.lax.stop_gradient(array))

    @property
    def shape(self):
        """Shape of the wrapped array."""
        return self.array.shape

    @property
    def dtype(self):
        """Dtype of the wrapped array."""
        return self.array.dtype
class BoundedArray(ParameterizedArray):
    """Array constrained to lie between ``lb`` and ``ub`` via a sigmoid.

    The constructor rescales the value to [0, 1] and stores its logit;
    ``eval()`` applies a clipped sigmoid and rescales back to ``[lb, ub]``.
    """

    unconstrained_array: MaybeParameterizedArray
    lb: float = eqx.static_field()
    ub: float = eqx.static_field()
    # NOTE(review): in float32, `1.0 - 1e-8` rounds to 1.0, so the upper clip
    # may be a no-op at that precision -- confirm whether a dtype-aware
    # epsilon is wanted.
    CONSTRAINT_EPS: float = eqx.static_field(default=1e-8)

    def _constrain(self, x):
        """Constrain (-inf, inf) to [0+EPS, 1-EPS]"""
        constrained_x = jax.nn.sigmoid(x)
        # Bounds are passed positionally: the `a_min`/`a_max` keywords of
        # `jnp.clip` are deprecated in newer JAX releases.
        constrained_x = jnp.clip(
            constrained_x, 0.0 + self.CONSTRAINT_EPS, 1.0 - self.CONSTRAINT_EPS
        )
        return constrained_x

    def _unconstrain(self, x):
        """Unconstrain [0, 1] to (-inf, inf)"""
        assert jnp.all(x >= 0.0) and jnp.all(x <= 1.0)
        unconstrained_x = jnp.clip(
            x, 0.0 + self.CONSTRAINT_EPS, 1.0 - self.CONSTRAINT_EPS
        )  # In case we get 0 or 1
        unconstrained_x = jax.scipy.special.logit(unconstrained_x)
        return unconstrained_x

    # TODO: separate this into a different initializer
    def __init__(self, val, lb, ub):
        """Store ``val`` in unconstrained form.

        Args:
            val: Initial value; every element must satisfy ``lb <= val <= ub``.
            lb: Lower bound, strictly less than ``ub``.
            ub: Upper bound.

        Raises:
            AssertionError: If ``lb >= ub`` or ``val`` is outside ``[lb, ub]``.
        """
        assert lb < ub
        assert jnp.all(val >= lb) and jnp.all(val <= ub)
        self.unconstrained_array = self._unconstrain((val - lb) / (ub - lb))
        self.lb = lb
        self.ub = ub

    def eval(self) -> jax.Array:
        """Return the array mapped back into the ``[lb, ub]`` range."""
        unconstrained_array = ParameterizedArray.tree_eval(self.unconstrained_array)
        return self.lb + self._constrain(unconstrained_array) * (self.ub - self.lb)

    # The original read `self.array`, which this class does not define; the
    # stored field is `unconstrained_array`.
    @property
    def shape(self):
        """Shape of the stored unconstrained array."""
        return self.unconstrained_array.shape

    @property
    def dtype(self):
        """Dtype of the stored unconstrained array."""
        return self.unconstrained_array.dtype
if __name__ == "__main__":
    # Usage example: two custom parameterizations, one nested inside the other.

    class InnerArray(ParameterizedArray):
        """Scales its wrapped array by ``multiplier``."""

        a: MaybeParameterizedArray
        multiplier: float = 2

        def eval(self) -> jax.Array:
            # A MaybeParameterizedArray field must go through tree_eval()
            # before it can be used as a plain array inside eval().
            return ParameterizedArray.tree_eval(self.a) * self.multiplier

    class OuterArray(ParameterizedArray):
        """Scales its wrapped array by a gradient-stopped ``multiplier``."""

        b: MaybeParameterizedArray
        multiplier: jax.Array

        def eval(self) -> jax.Array:
            frozen_multiplier = jax.lax.stop_gradient(self.multiplier)
            return ParameterizedArray.tree_eval(self.b) * frozen_multiplier

    inner_array = InnerArray(jnp.array([1, 2, 3]))
    outer_array = OuterArray(inner_array, jnp.array([1, 0, 1]))
    # A pytree mixing parameterized arrays, a plain array and a non-array leaf.
    pytree = (outer_array, inner_array, jnp.array([4, 5, 6]), True)

    print("Outer", outer_array.eval())
    print("Inner", inner_array.eval())
    print(
        "Evaluate PyTree of Parameterized Arrays",
        ParameterizedArray.tree_eval(pytree),
    )
# %%
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment