# import blackjax # terrible to install
import jax
import jax.numpy as jnp
import jax.random as jr
import jax.tree_util as jtu
import matplotlib as mpl
import matplotlib.pyplot as plt
import tensorflow_probability.substrates.jax as tfp
from jax.config import config

# Enable Float64 for more stable matrix inversions.
# This has to happen before GPJax is imported: GPJax creates its default
# parameter arrays at import time, and arrays created while x64 is off stay
# float32 even after the flag is flipped.
config.update("jax_enable_x64", True)
# force CPU: save memory
# Set before the first array (the PRNG key below) is created, so the platform
# choice is already in place when JAX first needs a backend.
jax.config.update('jax_platform_name', 'cpu')

from jaxtyping import install_import_hook

# Import GPJax under the jaxtyping/beartype hook so its annotations are
# checked at runtime.
with install_import_hook("gpjax", "beartype.beartype"):
    import gpjax as gpx
import numpy as np

tfd = tfp.distributions
key = jr.PRNGKey(123)
plt.style.use(
    "https://raw.githubusercontent.com/JaxGaussianProcesses/GPJax/main/docs/examples/gpjax.mplstyle"
)
# Colour cycle of the active style, reused by every plot below.
cols = mpl.rcParams["axes.prop_cycle"].by_key()["color"]
import datetime

jdate = datetime.datetime.now().date()

# Report which backend JAX actually selected.
print("Checking GPU or CPU:")
from jax.lib import xla_bridge

print(xla_bridge.get_backend().platform)
jax.default_backend()
jax.devices()
print("Checking done")
/nfs/scistore12/hpcgrp/jyeung/miniconda3/envs/gpjax-0.6.7/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
from .autonotebook import tqdm as notebook_tqdm
Checking GPU or CPU:
cpu
Checking done
# Toy data set: N Gaussian draws around a constant level, observed on a
# regular grid of inputs.
N = 500
tmax = 1
mean1 = 6
var1 = 9
obs_noise1 = jnp.sqrt(var1)  # noise standard deviation

# Inputs: N evenly spaced points on [0, tmax), stored as a column.
xvec = jnp.linspace(0, tmax, N, endpoint=False)[:, None]

# Outputs: constant mean plus scaled standard-normal noise, drawn with a
# fixed seed so the simulation is reproducible.
noise_key = jr.PRNGKey(1)
yvec1 = mean1 + jr.normal(noise_key, (N, 1)) * obs_noise1

# Histogram of the simulated observations.
fig, axs = plt.subplots(1)
axs.hist(yvec1.ravel(), bins=50, alpha=0.5, label='yvec1')
(array([ 2., 0., 1., 0., 7., 7., 4., 1., 8., 4., 12., 8., 11.,
10., 12., 18., 12., 15., 23., 21., 32., 18., 24., 12., 22., 27.,
22., 20., 17., 19., 17., 14., 15., 14., 16., 8., 7., 5., 1.,
5., 1., 4., 2., 0., 1., 0., 0., 0., 0., 1.]),
array([-2.590723 , -2.22452536, -1.85832773, -1.4921301 , -1.12593247,
-0.75973484, -0.39353721, -0.02733957, 0.33885806, 0.70505569,
1.07125332, 1.43745095, 1.80364858, 2.16984622, 2.53604385,
2.90224148, 3.26843911, 3.63463674, 4.00083437, 4.36703201,
4.73322964, 5.09942727, 5.4656249 , 5.83182253, 6.19802016,
6.5642178 , 6.93041543, 7.29661306, 7.66281069, 8.02900832,
8.39520595, 8.76140359, 9.12760122, 9.49379885, 9.85999648,
10.22619411, 10.59239174, 10.95858938, 11.32478701, 11.69098464,
12.05718227, 12.4233799 , 12.78957753, 13.15577517, 13.5219728 ,
13.88817043, 14.25436806, 14.62056569, 14.98676332, 15.35296096,
15.71915859]),
<BarContainer object of 50 artists>)

# Sample variance (ddof=1) of the simulated observations, for comparison
# with the variance used to generate them.
var1check = yvec1.var(ddof=1)
print(var1check)
9.792602878664741
# fit constant kernel, check obs_noise
from dataclasses import dataclass
from gpjax.base.param import param_field
import tensorflow_probability.substrates.jax.bijectors as tfb
from gpjax.typing import (
Array,
ScalarFloat,
)
from jaxtyping import Float, Integer
@dataclass
class ConstantKernelToy(gpx.kernels.AbstractKernel):
    r"""Constant (a.k.a. bias) kernel: ``k(x, y) = sigma2`` for every input pair.

    Modelled on the constant kernel in GPflow:
    https://github.com/GPflow/GPflow/blob/develop/gpflow/kernels/statics.py

    Functions drawn from a GP with this kernel are constant, i.e.
    ``f(x) = c`` with ``c ~ N(0, sigma2)``, where ``sigma2`` is the variance
    parameter.
    """

    # Variance parameter; declared trainable and transformed with a Softplus
    # bijector.
    sigma2: ScalarFloat = param_field(
        jnp.array(0.01),
        bijector=tfb.Softplus(),
        trainable=True,
    )
    name: str = "Constant"

    def __call__(
        self, x1: Float[Array, " D"], x2: Float[Array, " D"]
    ) -> ScalarFloat:
        r"""Evaluate the kernel on a single pair of inputs.

        Args:
            x1: First input; not used in the computation.
            x2: Second input; not used in the computation.

        Returns:
            ``sigma2`` with any singleton dimensions squeezed out.
        """
        variance = self.sigma2
        return variance.squeeze()
# Wrap the simulated inputs and outputs in a GPJax dataset.
Data = gpx.Dataset(X=xvec, y=yvec1)

# GP prior: constant mean function plus the toy constant kernel, whose
# variance is initialised to a very small value.
sigma2prior = jnp.float64(1e-7)
meanf = gpx.mean_functions.Constant()
kernelconstant = ConstantKernelToy(sigma2=sigma2prior)
prior = gpx.Prior(mean_function=meanf, kernel=kernelconstant)

# Evaluate the prior at the training inputs (first column, as a column
# vector cast with jnp.float64) and draw 20 samples from it.
xtest = jnp.float64(Data.X[:, 0])[:, None]
prior_dist = prior.predict(xtest)
prior_mean = prior_dist.mean()
prior_std = prior_dist.variance()  # NOTE: holds the variance, despite the name
samples = prior_dist.sample(seed=key, sample_shape=(20,))

# Plot the draws, the prior mean, and a band of mean -/+ variance.
band_lower = prior_mean - prior_std
band_upper = prior_mean + prior_std

fig, ax = plt.subplots()
ax.plot(xtest, samples.T, alpha=0.5, color=cols[0], marker='o')
ax.plot(xtest, prior_mean, color=cols[1], label="Prior mean")
ax.fill_between(
    xtest.flatten(),
    band_lower,
    band_upper,
    alpha=0.3,
    color=cols[1],
    label="Prior variance",
)
ax.legend(loc="best")
<matplotlib.legend.Legend at 0x14e16c91b550>

# Posterior = GP prior combined with a Gaussian likelihood over the
# training points.
likelihood = gpx.Gaussian(num_datapoints=Data.n)
posterior = prior * likelihood

# Objective: jit-compiled negative conjugate marginal log-likelihood,
# evaluated once at the initial parameters.
negative_mll = jax.jit(gpx.objectives.ConjugateMLL(negative=True))
negative_mll(posterior, train_data=Data)

import optax as ox

# Minimise the objective with Adam; `history` records the objective value
# per iteration and is plotted afterwards.
optimiser = ox.adam(learning_rate=0.01)
opt_posterior, history = gpx.fit(
    model=posterior,
    objective=negative_mll,
    train_data=Data,
    optim=optimiser,
    num_iters=5000,
    safe=True,
    key=key,
)
Running: 100%|██████████| 5000/5000 [00:46<00:00, 106.86it/s, Value=1295.02]
# Training curve: objective value against iteration number.
fig, ax = plt.subplots()
ax.plot(history, color=cols[1])
ax.set_xlabel("Training iteration")
ax.set_ylabel("Negative marginal log likelihood")
[Text(0.5, 0, 'Training iteration'),
Text(0, 0.5, 'Negative marginal log likelihood')]

# Posterior over the latent function at the test inputs, then passed through
# the likelihood to obtain the predictive distribution.
latent_dist = opt_posterior.predict(xtest, train_data=Data)
predictive_dist = opt_posterior.likelihood(latent_dist)
predictive_mean = predictive_dist.mean()
predictive_std = predictive_dist.stddev()

# Edges of the two-standard-deviation band around the predictive mean.
two_sigma_lower = predictive_mean - 2 * predictive_std
two_sigma_upper = predictive_mean + 2 * predictive_std

fig, ax = plt.subplots(figsize=(7.5, 2.5))
ax.plot(xvec, yvec1, "x", label="Observations", color=cols[0], alpha=0.5)
ax.fill_between(
    xtest.squeeze(),
    two_sigma_lower,
    two_sigma_upper,
    alpha=0.2,
    label="Two sigma",
    color=cols[1],
)
# Dashed outline on both edges of the band (lower first, then upper).
for band_edge in (two_sigma_lower, two_sigma_upper):
    ax.plot(
        xtest,
        band_edge,
        linestyle="--",
        linewidth=1,
        color=cols[1],
    )
ax.plot(xtest, predictive_mean, label="Predictive mean", color=cols[1])
ax.legend(loc="center left", bbox_to_anchor=(0.975, 0.5))
<matplotlib.legend.Legend at 0x14e13a170b90>

# Report the simulated noise level next to the fitted one on the same scale.
# NOTE(review): in GPJax 0.6.x the conjugate objective appears to add
# `obs_noise` to the diagonal of the Gram matrix, so the fitted value is a
# variance and should be compared with `var1`, not with the standard
# deviation `obs_noise1`. Re-check if the GPJax version changes.
print(f'Real obs noise (std): {obs_noise1}')
print(f'Real obs noise (variance): {var1}')
print(f'Estimated obs noise (variance): {opt_posterior.likelihood.obs_noise}')
Real obs noise: 3.0
Estimated obs noise: 7.2693954