@kstoneriv3
Created June 22, 2025 23:31
Compare different penalties added to the NN training objective
#!/usr/bin/env -S uv run --script
# /// script
# requires-python = ">=3.12"
# dependencies = ["numpy", "jax", "equinox", "matplotlib", "optax", "scikit-learn", "xgboost"]
# ///
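"""Compare different penalties added to the NN training objective.

Fits a noisy sine curve with an XGBoost baseline and with Equinox MLPs trained
under several gradient-based penalties, then writes an animation of the fits
(training_progress.mp4) and a log-log plot of ground-truth MSE over training
(gt_mse.png). The shebang runs the script via `uv run`, which resolves the
inline dependency block above.

Penalty names are expanded here from how each term is computed in `loss_fn`:
- Naive: plain MSE, no penalty.
- GNP: gradient norm penalty on the parameter gradient of the loss.
- GVP: gradient variance penalty between two random half-batch gradients.
- GLP: gradient Lipschitzness penalty on the model as a function of its input.
"""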
import logging
from concurrent.futures import ThreadPoolExecutor
from itertools import chain, product
import equinox as eqx
import jax
import jax.numpy as jnp
import matplotlib.animation as animation
import matplotlib.pyplot as plt
import numpy as np
import optax
from matplotlib.animation import FFMpegWriter
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split
from xgboost import XGBRegressor
logger = logging.getLogger(__name__)
# --- Configuration ---
SEED = 42
N_SAMPLES = 1000
TEST_SIZE = 0.9
XGB_PARAMS = {
    "n_estimators": 100,
    "max_depth": 3,
    "learning_rate": 0.1,
    "random_state": SEED,
}
NN_CONFIG = {
    "width_size": 128,
    "depth": 2,
    "activation": jax.nn.gelu,
    "learning_rate": 1e-2,
    "epochs": 501,
    "plot_every_epochs": 25,
    # Quick smoke-test settings:
    # "epochs": 5,
    # "plot_every_epochs": 1,
}
PENALTIES = ["Naive", "GNP", "GVP", "GLP"]
ALPHAS = np.geomspace(0.01, 0.3, 5)
ALPHA_LINESTYLES = {
    a: s
    for a, s in zip(
        ALPHAS[::-1],
        [
            (0, ()),       # "solid"
            (0, (5, 1)),   # "densely dashed"
            (0, (5, 5)),   # "dashed"
            (0, (5, 10)),  # "loosely dashed"
            (0, (1, 1)),   # "densely dotted"
            (0, (1, 2)),   # "dotted"
            (0, (1, 5)),   # "loosely dotted"
        ],
    )
}
CMAP = [plt.get_cmap("tab10")(i) for i in range(10)]
CMAP_PEN = {p: CMAP[i] for i, p in enumerate(["XGB"] + PENALTIES)}
# --- Data Generation ---
def generate_data(seed: int = SEED):
    # Noisy sine regression task: train/test targets carry unit Gaussian
    # noise, while the noise-free curve (X, y) is kept for ground-truth MSE.
    np.random.seed(seed)
    X = np.linspace(-2 * np.pi, 2 * np.pi, N_SAMPLES).reshape(-1, 1)
    y = np.sin(X).ravel()  # + np.sin(X / 2).ravel()
    y_noise = y + np.random.normal(0, 1, X.shape[0])
    X_train, X_test, y_train, y_test = train_test_split(X, y_noise, test_size=TEST_SIZE, random_state=seed)
    return X_train, X_test, X, y_train, y_test, y
# --- XGBoost Training ---
def train_xgb(X_train, y_train, X_test, y_test):
    model = XGBRegressor(**XGB_PARAMS)
    model.fit(X_train, y_train)
    return model
# --- Equinox NN Setup ---
def build_model(key):
    return eqx.nn.MLP(
        in_size=1,
        out_size=1,
        width_size=NN_CONFIG["width_size"],
        depth=NN_CONFIG["depth"],
        activation=NN_CONFIG["activation"],
        key=key,
    )


optimizer = optax.adam(learning_rate=NN_CONFIG["learning_rate"])
# --- Penalty Utilities ---
def norm(pytree):
    leaves = jax.tree.leaves(eqx.filter(pytree, eqx.is_inexact_array))
    return jnp.sqrt(sum(jnp.sum(a ** 2) for a in leaves))


def squared_l2_norm_diff(p0, p1):
    leaves0 = jax.tree.leaves(eqx.filter(p0, eqx.is_inexact_array))
    leaves1 = jax.tree.leaves(eqx.filter(p1, eqx.is_inexact_array))
    return sum(jnp.sum((a - b) ** 2) for a, b in zip(leaves0, leaves1))
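
# squared_l2_norm_diff above is used by the GVP penalty to compare the
# gradients of two half-batches; EPS_SCALE below sets the scale of the
# finite-difference perturbation used by the GLP penalty.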
EPS_SCALE = 1e-1


def sample_smoothness(model, x, key):
    # Approximate gradient-Lipschitzness penalty: differentiate a central
    # finite difference of the model through a randomly perturbed input.
    key_eps, key_deps = jax.random.split(key)
    eps = jax.random.normal(key_eps, shape=x.shape)
    deps = jax.random.normal(key_deps, shape=x.shape) * EPS_SCALE
    sample_diff_grad = jax.grad(
        lambda x: (model(x + deps) - model(x - deps)).squeeze()
    )(x + eps)
    return jnp.linalg.norm(sample_diff_grad / (2 * EPS_SCALE))
    # Alternative (unused): differentiate the input-gradient norm directly.
    # Probably not a great idea: the penalty vanishes wherever the gradient
    # norm is locally constant, e.g. on a cone.
    # grad_norm_grad = jax.grad(
    #     lambda x: jnp.linalg.norm(
    #         jax.grad(lambda x: model(x).squeeze())(x)
    #     )
    # )(x + eps)
    # return jnp.linalg.norm(grad_norm_grad)
# --- Loss and Update ---
def loss_fn(model, x, y, key, penalty, alpha=0.1):
    y_pred = jax.vmap(model)(x).flatten()
    loss = jnp.mean((y_pred - y) ** 2)
    if penalty == "GNP":
        # Gradient norm penalty: penalize the norm of the parameter gradient
        # of the MSE loss.
        model_grad = eqx.filter_grad(
            lambda m: jnp.mean((jax.vmap(m)(x).ravel() - y) ** 2)
        )(model)
        loss += alpha * norm(model_grad)
    elif penalty == "GVP":
        # Gradient variance penalty: penalize the disagreement between the
        # parameter gradients of two random half-batches.
        n = x.shape[0]
        assert n % 2 == 0, "batch size must be even for GVP"
        idx = jax.random.choice(key, n, (n // 2,), replace=False)
        mask = jnp.zeros(n, bool).at[idx].set(True)
        model_grad0 = eqx.filter_grad(
            lambda m: jnp.mean(mask * (eqx.filter_vmap(m)(x).ravel() - y) ** 2)
        )(model)
        model_grad1 = eqx.filter_grad(
            lambda m: jnp.mean((1 - mask) * (eqx.filter_vmap(m)(x).ravel() - y) ** 2)
        )(model)
        loss += alpha * (squared_l2_norm_diff(model_grad0, model_grad1) / n) ** 0.5
    elif penalty == "GLP":
        # Gradient Lipschitzness penalty: penalize a stochastic estimate of
        # how fast the model's input gradient changes (see sample_smoothness).
        keys = jax.random.split(key, x.shape[0])
        smoothness = jax.vmap(lambda xi, k: sample_smoothness(model, xi, k))(x, keys)
        loss += alpha * jnp.mean(smoothness)
    return loss
@eqx.filter_jit
def update_step(model, opt_state, x, y, epoch, penalty, alpha):
    key = jax.random.PRNGKey(epoch)
    loss, grads = eqx.filter_value_and_grad(loss_fn)(model, x, y, key, penalty, alpha)
    grads = eqx.filter(grads, eqx.is_array)
    updates, opt_state = optimizer.update(grads, opt_state, model)
    model = eqx.apply_updates(model, updates)
    return model, opt_state, loss
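# Note: eqx.filter_jit traces array arguments and holds everything else
# static, so each distinct (epoch, penalty, alpha) value compiles its own
# version of update_step.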
# --- Main Execution ---
def main():
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    )
    # 1. Data
    X_train, X_test, X_plot, y_train, y_test, y_plot = generate_data()
    # 2. XGB
    model_xgb = train_xgb(X_train, y_train, X_test, y_test)
    y_test_pred = model_xgb.predict(X_test)
    mse_test_xgb = mean_squared_error(y_test, y_test_pred)
    y_plot_pred = model_xgb.predict(X_plot)
    mse_plot_xgb = mean_squared_error(y_plot, y_plot_pred)
    logger.info(f"XGB Test MSE: {mse_test_xgb:.3f}")
    logger.info(f"XGB GT MSE: {mse_plot_xgb:.3f}")
    results = {
        ("XGB", None, None): (model_xgb, mse_test_xgb, mse_plot_xgb, y_plot_pred)
    }
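    # results maps (penalty, alpha, epoch) -> (model, test MSE, ground-truth
    # MSE, predictions on X_plot); the XGB baseline is keyed ("XGB", None, None).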
    # 3. NN variants
    # --- Training Loop ---
    X_train, X_test, X_plot, y_train, y_test, y_plot = map(
        jnp.asarray,
        [X_train, X_test, X_plot, y_train, y_test, y_plot],
    )

    def train_nn(penalty, alpha=0.1):
        key = jax.random.PRNGKey(SEED)
        model = build_model(key)
        opt_state = optimizer.init(eqx.filter(model, eqx.is_array))
        for epoch in range(NN_CONFIG["epochs"]):
            model, opt_state, loss = update_step(
                model, opt_state, X_train, y_train, epoch, penalty, alpha
            )
            if epoch % NN_CONFIG["plot_every_epochs"] == 0:
                y_test_pred = np.array(eqx.filter_vmap(model)(X_test)).ravel()
                mse_test = mean_squared_error(y_test, y_test_pred)
                y_plot_pred = np.array(eqx.filter_vmap(model)(X_plot)).ravel()
                mse_plot = mean_squared_error(y_plot, y_plot_pred)
                results[(penalty, alpha, epoch)] = (model, mse_test, mse_plot, y_plot_pred)
                logger.info(
                    f"Penalty={penalty}, alpha={alpha:.3f}: "
                    f"Epoch {epoch}, "
                    f"Train Loss={loss:.3f}, "
                    f"Test MSE={mse_test:.3f}, "
                    f"Plot MSE={mse_plot:.3f}"
                )
        logger.info(f"### {penalty}-{alpha:.3f} NN Test MSE: {mse_test:.3f} ###")
        return model

    with ThreadPoolExecutor(20) as ex:
        # Train "Naive" once (with alpha=0) plus every (penalty, alpha) pair;
        # zip(*...) transposes the pairs into the two parallel argument lists
        # that ex.map expects.
        list(ex.map(train_nn, *zip(*chain([("Naive", 0)], product(PENALTIES[1:], ALPHAS)))))
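    # Sort results so that legend entries and line styles appear in a
    # deterministic order regardless of thread completion order.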
    results = dict(sorted(results.items()))
    # 4. Plot
    fig, ax = plt.subplots(figsize=(20, 15))
    ims = []
    for epoch in [e for e in range(NN_CONFIG["epochs"]) if e % NN_CONFIG["plot_every_epochs"] == 0]:
        art0 = ax.scatter(X_test, y_test, alpha=0.3, color=CMAP[0])
        art1 = ax.scatter(X_train, y_train, alpha=0.3, color=CMAP[1])
        art2, = ax.plot(X_plot, y_plot, color="black")
        im = [art0, art1, art2]
        legend = ["Noisy Test", "Noisy Train", "Ground Truth"]
        for (penalty, alpha, e), (model, mse_test, mse_plot, y_plot_pred) in results.items():
            if penalty == "XGB" or e == epoch:
                art, = ax.plot(
                    X_plot,
                    y_plot_pred,
                    alpha=0.5,
                    linestyle=ALPHA_LINESTYLES[alpha] if alpha else "solid",
                    color=CMAP_PEN[penalty],
                )
                im.append(art)
                legend.append(
                    f"{penalty}-{alpha or 0:.3f} (Test MSE={mse_test:.3f}, GT MSE={mse_plot:.3f})"
                )
        ims.append(im)
    plt.legend(im, legend)
    plt.xlabel("x")
    plt.ylabel("y")
    plt.title("Method Comparison: XGB vs NN variants")
    # plt.savefig("penalty_comparison.png")
    # plt.savefig("penalty_comparison.svg")
    # Build animation
    ani = animation.ArtistAnimation(fig, ims, interval=200, blit=True)
    # Save as MP4
    writer = FFMpegWriter(fps=2.5, metadata=dict(artist="Me"), bitrate=1800)
    ani.save("training_progress.mp4", writer=writer)
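    # FFMpegWriter shells out to an ffmpeg binary, which must be installed
    # and on PATH; matplotlib does not bundle one.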
    # 5. Ground-truth MSE over training
    plt.figure(figsize=(20, 15))
    epochs = [
        e for e in range(NN_CONFIG["epochs"])
        if e % NN_CONFIG["plot_every_epochs"] == 0
    ]
    for penalty, alpha in chain([("Naive", 0)], product(PENALTIES[1:], ALPHAS)):
        mses = []
        for e in epochs:
            model, mse_test, mse_plot, y_plot_pred = results[(penalty, alpha, e)]
            mses.append(mse_plot)
        plt.plot(
            epochs,
            mses,
            alpha=0.5,
            linestyle=ALPHA_LINESTYLES[alpha] if alpha else "solid",
            color=CMAP_PEN[penalty],
            label=f"{penalty}-{alpha or 0:.3f}",
        )
    plt.legend()
    plt.xlabel("epochs")
    plt.ylabel("Ground Truth MSE")
    plt.xscale("log")
    plt.yscale("log")
    plt.ylim(0.001, 2.0)
    plt.savefig("gt_mse.png")
if __name__ == "__main__":
    main()