Skip to content

Instantly share code, notes, and snippets.

@ricardoV94
Last active May 26, 2025 17:13
Show Gist options
  • Save ricardoV94/9473a0ff3bb57adb7f53f9edad32ae7b to your computer and use it in GitHub Desktop.
Save ricardoV94/9473a0ff3bb57adb7f53f9edad32ae7b to your computer and use it in GitHub Desktop.
import numpy as np
import pymc as pm
import pytensor.tensor as pt
import pytensor.xtensor as px
# Synthetic data: a noisy piecewise-linear curve observed at N points, with each
# observation randomly assigned to one of N_groups groups.
N = 100
N_groups = 3
# Seed derived from a fixed string so the gist is reproducible.
seed = sum(map(ord, "xarray>=numpy?"))
rng = np.random.default_rng(seed)

x_np = np.linspace(0, 10, N)
# Continuous true signal: slope 0.5 on [0, 3], 0.2 on (3, 7], -0.1 after 7.
y_np = np.piecewise(
    x_np,
    [x_np <= 3, (x_np > 3) & (x_np <= 7), x_np > 7],
    [lambda x: 0.5 * x, lambda x: 1.5 + 0.2 * (x - 3), lambda x: 2.3 - 0.1 * (x - 7)],
)
# Draw order matters for reproducibility: noise first, then group labels.
y_np += rng.normal(0, 0.2, size=N)
group_idx = rng.choice(N_groups, size=N)

# Evenly spaced hinge locations for the linear spline.
N_knots = 13
knots_np = np.linspace(0, 10, num=N_knots)

# Keys must match the `dims=` names used by the models below. The knot
# dimension is "knot" (singular) because "knots" is the name of the
# knot-location data variable.
coords = {
    "group": range(N_groups),
    "knot": range(N_knots),
    "obs": range(N),
}
# Reference model: a per-group linear spline written with plain PyTensor
# tensors and positional axes.
with pm.Model(coords=coords) as model:
    x = pm.Data("x", x_np, dims="obs")
    knots = pm.Data("knots", knots_np, dims="knot")

    sigma = pm.HalfCauchy("sigma", beta=1)
    sigma_beta0 = pm.HalfNormal("sigma_beta0", sigma=10)
    beta0 = pm.HalfNormal("beta_0", sigma=sigma_beta0, dims="group")
    z = pm.Normal("z", dims=("group", "knot"))

    delta_factors = pt.special.softmax(z, axis=-1)  # (group, knot)
    # The segment slope after knot k is beta0 * (1 - sum of the first k softmax
    # weights). The slopes shrink monotonically from beta0 and stay positive,
    # because the weights sum to 1 and the last weight is dropped.
    slope_factors = 1 - pt.cumsum(delta_factors[:, :-1], axis=-1)  # (group, knot-1)
    spline_slopes = pt.join(-1, beta0[:, None], beta0[:, None] * slope_factors)  # (group, knot)
    # Hinge coefficients: beta[:, 0] is the initial slope and beta[:, k] is the
    # change in slope at knot k.
    beta = pt.join(-1, beta0[:, None], pt.diff(spline_slopes, axis=-1))  # (group, knot)
    beta = pm.Deterministic("beta", beta, dims=("group", "knot"))

    # Hinge basis max(0, x - knot_k) for every observation and knot.
    X = pt.maximum(0, x[:, None] - knots[None, :])  # (obs, knot)
    mu = (X * beta[group_idx]).sum(-1)  # (obs, knot) * (obs, knot), summed over knot -> (obs,)
    y = pm.Normal("y", mu=mu, sigma=sigma, observed=y_np, dims="obs")
class XModel(pm.Model):
    """A ``pm.Model`` that hands back dimmed variables as labeled xtensors.

    Random variables registered with ``dims`` are wrapped with
    ``px.as_xtensor``, so downstream code can use named-dimension operations.
    Named variables that are already xtensors take their own ``dims`` when
    none are passed explicitly.
    """

    def register_rv(self, rv, *args, dims=None, **kwargs):
        """Register ``rv`` as ``pm.Model`` does; wrap the result as an xtensor if ``dims`` is given."""
        rv = super().register_rv(rv, *args, dims=dims, **kwargs)
        if dims is not None:
            if isinstance(dims, str):
                # A bare string names a single dimension.
                dims = (dims,)
            rv = px.as_xtensor(rv, dims=dims)
        return rv

    def add_named_variable(self, var, dims=None):
        """Add ``var`` to the model; an xtensor supplies its own dims when none are given.

        Returns whatever ``pm.Model.add_named_variable`` returns.

        Raises:
            ValueError: if ``var`` is an xtensor and explicit ``dims`` differ
                from ``var.dims``.
        """
        if isinstance(var.type, px.type.XTensorType):
            if dims is None:
                dims = var.dims
            else:
                # Callers such as pm.Deterministic forward dims unchanged, so a
                # str or list must be normalized before comparing to the tuple var.dims.
                if isinstance(dims, str):
                    dims = (dims,)
                if tuple(dims) != tuple(var.dims):
                    raise ValueError(
                        f"Provided dims {dims} do not match variable pre-existing {var.dims}. "
                        "Use rename and/or transpose to match new dims"
                    )
        return super().add_named_variable(var, dims)
def XData(name, x, *args, **kwargs):
    """Create a ``pm.Data`` container and return it as an xtensor if it has dims.

    All arguments are forwarded unchanged to ``pm.Data``. If dims were
    registered for the new variable, it is wrapped with ``px.as_xtensor``
    using those dims. Otherwise the plain ``pm.Data`` variable is returned.
    """
    x = pm.Data(name, x, *args, **kwargs)
    # Read dims from the same model pm.Data registered into, honoring an
    # explicit model= keyword instead of always using the context model.
    model = pm.modelcontext(kwargs.get("model"))
    if (dims := model.named_vars_to_dims.get(x.name, None)) is not None:
        x = px.as_xtensor(x, dims=dims)
    return x
# The same spline model, written with named-dimension (xtensor) operations.
with XModel(coords=coords) as xmodel:
    x = XData("x", x_np, dims="obs")
    knots = XData("knots", knots_np, dims="knot")

    sigma = pm.HalfCauchy("sigma", beta=1)
    sigma_beta0 = pm.HalfNormal("sigma_beta0", sigma=10)
    beta0 = pm.HalfNormal("beta_0", sigma=sigma_beta0, dims="group")
    z = pm.Normal("z", dims=("group", "knot"))

    delta_factors = px.special.softmax(z, dim="knot")
    # Drop the last knot, then take the cumulative sum along "knot".
    slope_factors = 1 - delta_factors.isel(knot=np.s_[:-1]).cumsum("knot")
    # beta0 has only the "group" dim; presumably concat gives it a length-1
    # "knot" slice, as xarray does -- confirm against px.concat.
    spline_slopes = px.concat([beta0, beta0 * slope_factors], dim="knot")
    beta = px.concat([beta0, spline_slopes.diff("knot")], dim="knot")
    beta = pm.Deterministic("beta", beta, dims=("group", "knot"))

    # x ("obs") and knots ("knot") combine by dimension name.
    X = px.math.scalar_maximum(0, x - knots)
    # Pick each observation's group row, relabel that axis "obs" to align with X,
    # then reduce over knots.
    mu = (X * beta.isel(group=group_idx).rename(group="obs")).sum("knot")
    # .values unwraps the xtensor into a plain tensor for pm.Normal's mu.
    y_obs = pm.Normal("y_obs", mu=mu.values, sigma=sigma, observed=y_np, dims="obs")
# Print each model's joint log-probability at its initial point, thousands-
# separated. The two models are meant to be equivalent, so the values should match.
for built_model in (model, xmodel):
    logp_fn = built_model.compile_logp()
    print(f"{logp_fn(built_model.initial_point()):,}")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment