Last active
March 5, 2024 06:50
-
-
Save lucidfrontier45/abfc42d274155000093d59d32db57489 to your computer and use it in GitHub Desktop.
Automatic generation of a NumPyro prior-sampling function for a Flax model
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import flax.linen as nn
import jax
import jax.random as jrand
import numpyro
import numpyro.distributions as dist
def autoprior(model: nn.Module, model_args, scale=1.0, prefix: str = "param"):
    """Build NumPyro prior helpers for every parameter leaf of a Flax model.

    The model is initialised once with a fixed PRNG key; only the pytree
    structure and the leaf shapes of the result are kept.

    Args:
        model: Flax module whose parameters receive priors.
        model_args: Positional arguments forwarded to ``model.init`` after
            the PRNG key.
        scale: Scale of the zero-centred Laplace prior placed on each leaf.
        prefix: Prefix of the generated sample-site names, which have the
            form ``f"{prefix}_{index:08}"``.

    Returns:
        A ``(sample_prior, pack_params)`` tuple of callables:

        * ``sample_prior()`` draws one NumPyro sample site per leaf and
          returns the draws arranged in the model's parameter structure.
        * ``pack_params(param_dict)`` rebuilds the parameter structure from
          a mapping of site name to value.
    """
    key = jrand.PRNGKey(0)
    init_params = model.init(key, *model_args)
    flat_params, tree_def = jax.tree.flatten(init_params)
    shapes = [leaf.shape for leaf in flat_params]
    # One site name per leaf, in flatten order. The names are shared by both
    # helpers so that sampling and packing always agree.
    site_names = [f"{prefix}_{i:08}" for i in range(len(shapes))]

    def pack_params(param_dict: dict):
        """Rebuild the parameter pytree from a ``{site name: value}`` mapping.

        Only the sites generated by ``sample_prior`` are used; other entries
        are ignored, even when their key happens to start with ``prefix``.
        """
        # Select by exact site name rather than by prefix so that unrelated
        # sites such as "param_noise" cannot be mistaken for model leaves.
        leaves = [param_dict[name] for name in site_names if name in param_dict]
        return jax.tree.unflatten(tree_def, leaves)

    def sample_prior():
        """Sample every leaf from ``Laplace(0, scale)`` and return the pytree."""
        priors = [
            numpyro.sample(name, dist.Laplace(0.0, scale).expand(shape))
            for name, shape in zip(site_names, shapes)
        ]
        return jax.tree.unflatten(tree_def, priors)

    return sample_prior, pack_params
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment