Skip to content

Instantly share code, notes, and snippets.

@maedoc
Last active April 20, 2023 13:22
Show Gist options
  • Save maedoc/d57818381ea7e6bcf11e56fb739ac061 to your computer and use it in GitHub Desktop.
Save maedoc/d57818381ea7e6bcf11e56fb739ac061 to your computer and use it in GitHub Desktop.
Simple Bayesian classifier
Display the source blob
Display the rendered blob
Raw
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
@maedoc
Copy link
Author

maedoc commented Apr 20, 2023

As a TL;DR, in code:

import jax, jax.numpy as np, jax.random as jr
import numpyro, numpyro.infer
from numpyro.distributions import Normal, Bernoulli

# Synthetic data: n_obs points in n_obs_dim dimensions, each drawn around one of
# n_class random class centres with Gaussian noise of std `confusion`.
key = jr.PRNGKey(42)
n_class = 5
n_obs_dim = 20
n_obs = 40
confusion = 0.1
# JAX keys must never be reused: drawing several arrays from one key yields
# correlated (or even repeated) samples. Split into one independent subkey per
# draw, and rebind `key` to a fresh subkey for later use (e.g. SVI).
key, key_means, key_classes, key_noise = jr.split(key, 4)
class_means = jr.normal(key_means, shape=(n_class, n_obs_dim))
obs_classes = jr.randint(key_classes, shape=(n_obs,), minval=0, maxval=n_class)
obs = class_means[obs_classes] + confusion * jr.normal(key_noise, shape=(n_obs, n_obs_dim))
# One-hot labels transposed to (n_class, n_obs) so they line up with the
# model's class-by-observation probability matrix.
obs_classes_oh = jax.nn.one_hot(obs_classes, num_classes=n_class).T
    
def log_p():
    """NumPyro model: a linear softmax classifier with standard-normal priors.

    Samples a weight matrix ``mix`` of shape (n_class, n_obs_dim) and a
    per-class ``offset`` of shape (n_class, 1), maps the module-level
    observations ``obs`` (n_obs, n_obs_dim) to class probabilities with a
    softmax over the class axis, and conditions on the module-level one-hot
    labels ``obs_classes_oh`` (n_class, n_obs).

    Takes no arguments and returns None. It is intended to be traced by
    NumPyro inference (such as SVI with an autoguide), not called directly.
    """
    # Priors: independent standard normals on every weight and offset.
    mix = numpyro.sample('mix', Normal(np.zeros((n_class, n_obs_dim)), 1))
    offset = numpyro.sample('offset', Normal(np.zeros((n_class, 1)), 1))
    # Forward model: logits are (n_class, n_obs); softmax over axis 0 (classes)
    # so each observation's column of probabilities sums to one.
    probs = jax.nn.softmax(mix @ obs.T + offset, axis=0)
    # Likelihood: every entry of the one-hot label matrix is scored as an
    # independent Bernoulli.
    # NOTE(review): this is a one-vs-rest pseudo-likelihood; a Categorical over
    # classes would be the exact multiclass likelihood. Kept as-is to preserve
    # the model the downstream check was written against.
    numpyro.sample('obs_classes_oh',
                   Bernoulli(probs=probs),
                   obs=obs_classes_oh)

# Fit a diagonal-Gaussian variational posterior to `log_p` by maximising the
# ELBO with Adam for 5000 steps; the fitted guide parameters end up in
# `svi_result.params`.
optimizer = numpyro.optim.Adam(step_size=1e-3)
guide = numpyro.infer.autoguide.AutoDiagonalNormal(log_p)
elbo = numpyro.infer.Trace_ELBO()
svi = numpyro.infer.SVI(log_p, guide, optimizer, loss=elbo)
svi_result = svi.run(key, num_steps=5000, progress_bar=False)

# Check that the fitted posterior over the classifier weights recovers the
# ground-truth class means. (`mix` exists only inside log_p, so the module-level
# reference point is `class_means`.)
loc = svi_result.params['auto_loc']
scale = svi_result.params['auto_scale']
# AutoDiagonalNormal keeps all latent sites in one flat vector. This assumes the
# 'mix' site occupies its first n_class * n_obs_dim entries.
# NOTE(review): verify the site ordering for the installed numpyro version.
n_mix = class_means.size
z_mix = (
    (loc[:n_mix].reshape(class_means.shape) - class_means)
    / scale[:n_mix].reshape(class_means.shape)
)

# Median |z| below 1: for most weights the truth lies within about one
# posterior standard deviation of the posterior mean.
assert np.percentile(np.abs(z_mix), 50) < 1, 'posterior over mix does not recover class_means'

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment