Skip to content

Instantly share code, notes, and snippets.

@lucidfrontier45
Last active March 5, 2024 06:50
Show Gist options
  • Save lucidfrontier45/abfc42d274155000093d59d32db57489 to your computer and use it in GitHub Desktop.
Save lucidfrontier45/abfc42d274155000093d59d32db57489 to your computer and use it in GitHub Desktop.
Automatic generation of a NumPyro prior-sampling function for a Flax model
import flax.linen as nn
import jax
import jax.random as jrand
import numpyro
import numpyro.distributions as dist
def autoprior(model: nn.Module, model_args, scale=1.0, prefix: str = "param"):
    """Build NumPyro prior helpers covering every parameter leaf of a Flax model.

    The model is initialised once so that the structure of its parameter
    pytree and the shape of each leaf can be recorded; the initial values
    themselves are not used afterwards.

    Args:
        model: Flax module whose parameters should receive a prior.
        model_args: Positional arguments forwarded to ``model.init`` after
            the PRNG key.
        scale: Scale of the ``Laplace(0.0, scale)`` prior placed on every
            parameter leaf.
        prefix: Prefix of the generated sample-site names, which have the
            form ``f"{prefix}_{index:08}"``.

    Returns:
        A ``(sample_prior, pack_params)`` tuple.

        ``sample_prior()`` registers one ``numpyro.sample`` site per
        parameter leaf and returns the draws arranged in the same pytree
        structure as the model parameters.

        ``pack_params(param_dict)`` takes a mapping keyed by site name and
        rebuilds the parameter pytree from the entries that belong to the
        sites created by ``sample_prior``.
    """
    # The key is only needed to run ``init``; only shapes/structure are kept.
    init_params = model.init(jrand.PRNGKey(0), *model_args)
    leaves, tree_def = jax.tree.flatten(init_params)
    shapes = [leaf.shape for leaf in leaves]

    # Single source of truth for the site names, shared by both helpers so
    # that sampling and packing can never disagree on naming or ordering.
    site_names = [f"{prefix}_{i:08}" for i in range(len(shapes))]

    def pack_params(param_dict: dict):
        """Rebuild the parameter pytree from a ``{site name: value}`` mapping.

        Only the exact site names generated by this ``autoprior`` call are
        read; any other entries in ``param_dict`` are ignored, even when
        their key happens to start with ``prefix``.

        Raises:
            ValueError: If one or more expected sites are absent.
        """
        missing = [name for name in site_names if name not in param_dict]
        if missing:
            raise ValueError(
                f"param_dict is missing {len(missing)} expected site(s): "
                f"{missing[:5]}"
            )
        # Index order equals flatten order, so no sorting of keys is needed.
        return jax.tree.unflatten(
            tree_def, [param_dict[name] for name in site_names]
        )

    def sample_prior():
        """Sample every parameter leaf from its Laplace prior.

        Must be called where ``numpyro.sample`` is valid (e.g. inside a
        NumPyro model function).
        """
        priors = [
            numpyro.sample(name, dist.Laplace(0.0, scale).expand(shape))
            for name, shape in zip(site_names, shapes)
        ]
        return jax.tree.unflatten(tree_def, priors)

    return sample_prior, pack_params
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment