Compare different penalty terms added to the NN training objective.
#!/usr/bin/env -S uv run --script
# /// script
# requires-python = ">=3.12"
# dependencies = ["numpy", "jax", "equinox", "matplotlib", "optax", "scikit-learn", "xgboost"]
# ///
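# Usage note: with the inline script metadata above, the file can be executed
# directly (./<this file>) or via `uv run <this file>`; uv resolves the listed
# dependencies on the fly.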
import logging
from concurrent.futures import ThreadPoolExecutor
from itertools import chain, product

import equinox as eqx
import jax
import jax.numpy as jnp
import matplotlib.animation as animation
import matplotlib.pyplot as plt
import numpy as np
import optax
from matplotlib.animation import FFMpegWriter
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split
from xgboost import XGBRegressor

logger = logging.getLogger(__name__)
# --- Configuration ---
SEED = 42
N_SAMPLES = 1000
TEST_SIZE = 0.9
XGB_PARAMS = {
    "n_estimators": 100,
    "max_depth": 3,
    "learning_rate": 0.1,
    "random_state": SEED,
}
NN_CONFIG = {
    "width_size": 128,
    "depth": 2,
    "activation": jax.nn.gelu,
    "learning_rate": 1e-2,
    "epochs": 501,
    "plot_every_epochs": 25,
    # "epochs": 5,
    # "plot_every_epochs": 1,
}
PENALTIES = ["Naive", "GNP", "GVP", "GLP"]
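# Interpretation of the abbreviations (inferred from how loss_fn uses them below,
# not spelled out in the gist): Naive = plain MSE, GNP = gradient norm penalty,
# GVP = gradient variance penalty, GLP = gradient Lipschitzness penalty.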
ALPHAS = np.geomspace(0.01, 0.3, 5)
ALPHA_LINESTYLES = {
    a: s
    for a, s in zip(
        ALPHAS[::-1],
        [
            (0, ()),       # "solid"
            (0, (5, 1)),   # "densely dashed"
            (0, (5, 5)),   # "dashed"
            (0, (5, 10)),  # "loosely dashed"
            (0, (1, 1)),   # "densely dotted"
            (0, (1, 2)),   # "dotted"
            (0, (1, 5)),   # "loosely dotted"
        ],
    )
}
CMAP = [plt.get_cmap("tab10")(i) for i in range(10)]
CMAP_PEN = {p: CMAP[i] for i, p in enumerate(["XGB"] + PENALTIES)}
# --- Data Generation ---
def generate_data(seed: int = SEED):
    np.random.seed(seed)
    X = np.linspace(-2 * np.pi, 2 * np.pi, N_SAMPLES).reshape(-1, 1)
    y = np.sin(X).ravel()  # + np.sin(X / 2).ravel()
    y_noise = y + np.random.normal(0, 1, X.shape[0])
    # With TEST_SIZE = 0.9, only 10% of the samples (100 points) are used for training.
    X_train, X_test, y_train, y_test = train_test_split(
        X, y_noise, test_size=TEST_SIZE, random_state=seed
    )
    return X_train, X_test, X, y_train, y_test, y


# --- XGBoost Training ---
def train_xgb(X_train, y_train, X_test, y_test):
    # X_test / y_test are accepted but not used; kept to mirror the call site.
    model = XGBRegressor(**XGB_PARAMS)
    model.fit(X_train, y_train)
    return model
# --- Equinox NN Setup ---
def build_model(key):
    return eqx.nn.MLP(
        in_size=1,
        out_size=1,
        width_size=NN_CONFIG["width_size"],
        depth=NN_CONFIG["depth"],
        activation=NN_CONFIG["activation"],
        key=key,
    )


optimizer = optax.adam(learning_rate=NN_CONFIG["learning_rate"])
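# Note: the optax optimizer is a stateless gradient transformation, so a single
# module-level instance can be shared across all training runs; each run keeps
# its own optimizer state via optimizer.init(...) inside train_nn.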
# --- Penalty Utilities ---
def norm(pytree):
    leaves = jax.tree.leaves(eqx.filter(pytree, eqx.is_inexact_array))
    return jnp.sqrt(sum(jnp.sum(a ** 2) for a in leaves))


def squared_l2_norm_diff(p0, p1):
    leaves0 = jax.tree.leaves(eqx.filter(p0, eqx.is_inexact_array))
    leaves1 = jax.tree.leaves(eqx.filter(p1, eqx.is_inexact_array))
    return sum(jnp.sum((a - b) ** 2) for a, b in zip(leaves0, leaves1))
EPS_SCALE = 1e-1


def sample_smoothness(model, x, key):
    # Gradient-Lipschitzness penalty, approximated by a finite difference of the
    # input gradient at a randomly perturbed point x + eps.
    # Note: eps and deps are drawn from the same key, so deps is parallel to eps;
    # with a 1-D input this makes no practical difference.
    eps = jax.random.normal(key, shape=x.shape)
    deps = jax.random.normal(key, shape=x.shape) * EPS_SCALE
    sample_diff_grad = jax.grad(
        lambda x: (model(x + deps) - model(x - deps)).squeeze()
    )(x + eps)
    return jnp.linalg.norm(sample_diff_grad / (2 * EPS_SCALE))
    # Unused alternative below: penalize the gradient of the gradient norm.
    # Probably not a great idea; what if the gradient is constant, like a cone?
    # grad_norm_grad = jax.grad(
    #     lambda x: jnp.linalg.norm(
    #         jax.grad(lambda x: model(x).squeeze())(x)
    #     )
    # )(x + eps)
    # return jnp.linalg.norm(grad_norm_grad)
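# Reading note (notation introduced here, not in the original comments): for
# small deps,
#     grad_x[ f(x + deps) - f(x - deps) ] = grad f(x + deps) - grad f(x - deps)
#                                        ~= 2 * H(x) @ deps,
# so dividing by 2 * EPS_SCALE gives roughly H(x + eps) @ (deps / EPS_SCALE),
# i.e. the Hessian applied to a standard-normal direction; its norm acts as a
# stochastic proxy for the local Lipschitz constant of the gradient.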
# --- Loss and Update ---
def loss_fn(model, x, y, key, penalty, alpha=0.1):
    y_pred = jax.vmap(model)(x).flatten()
    loss = jnp.mean((y_pred - y) ** 2)
    if penalty == "GNP":
        # Norm of the parameter gradient of the batch MSE.
        model_grad = eqx.filter_grad(
            lambda m: jnp.mean((jax.vmap(m)(x).ravel() - y) ** 2)
        )(model)
        loss += alpha * norm(model_grad)
    elif penalty == "GVP":
        # Disagreement between the parameter gradients of two random half-batches.
        n = x.shape[0]
        assert n % 2 == 0, "batch size must be even for GVP"
        idx = jax.random.choice(key, n, (n // 2,), replace=False)
        mask = jnp.zeros(n, bool).at[idx].set(True)
        model_grad0 = eqx.filter_grad(
            lambda m: jnp.mean(mask * (eqx.filter_vmap(m)(x).ravel() - y) ** 2)
        )(model)
        model_grad1 = eqx.filter_grad(
            lambda m: jnp.mean((1 - mask) * (eqx.filter_vmap(m)(x).ravel() - y) ** 2)
        )(model)
        loss += alpha * (squared_l2_norm_diff(model_grad0, model_grad1) / n) ** 0.5
    elif penalty == "GLP":
        # Mean of the per-sample gradient-Lipschitzness estimates.
        keys = jax.random.split(key, x.shape[0])
        smoothness = jax.vmap(lambda x, k: sample_smoothness(model, x, k))(x, keys)
        loss += alpha * jnp.mean(smoothness)
    return loss
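# Summary of the penalty terms above, in informal notation (L(theta) = batch MSE,
# A/B = the random half-split, s_i = per-sample smoothness estimate):
#   GNP: alpha * || grad_theta L(theta) ||
#   GVP: alpha * sqrt( || grad_theta L_A(theta) - grad_theta L_B(theta) ||^2 / n )
#   GLP: alpha * mean_i s_i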
@eqx.filter_jit
def update_step(model, opt_state, x, y, epoch, penalty, alpha):
    key = jax.random.PRNGKey(epoch)
    loss, grads = eqx.filter_value_and_grad(loss_fn)(model, x, y, key, penalty, alpha)
    grads = eqx.filter(grads, eqx.is_array)
    updates, opt_state = optimizer.update(grads, opt_state, model)
    model = eqx.apply_updates(model, updates)
    return model, opt_state, loss
# --- Main Execution ---
def main():
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    )
    # 1. Data
    X_train, X_test, X_plot, y_train, y_test, y_plot = generate_data()

    # 2. XGB
    model_xgb = train_xgb(X_train, y_train, X_test, y_test)
    y_test_pred = model_xgb.predict(X_test)
    mse_test_xgb = mean_squared_error(y_test, y_test_pred)
    y_plot_pred = model_xgb.predict(X_plot)
    mse_plot_xgb = mean_squared_error(y_plot, y_plot_pred)
    logger.info(f"XGB Test MSE: {mse_test_xgb:.3f}")
    logger.info(f"XGB GT MSE: {mse_plot_xgb:.3f}")
    results = {
        ("XGB", None, None): (model_xgb, mse_test_xgb, mse_plot_xgb, y_plot_pred)
    }
    # 3. NN variants
    # --- Training Loop ---
    X_train, X_test, X_plot, y_train, y_test, y_plot = map(
        jnp.asarray,
        [X_train, X_test, X_plot, y_train, y_test, y_plot],
    )

    def train_nn(penalty, alpha=0.1):
        key = jax.random.PRNGKey(SEED)
        model = build_model(key)
        opt_state = optimizer.init(eqx.filter(model, eqx.is_array))
        for epoch in range(NN_CONFIG["epochs"]):
            model, opt_state, loss = update_step(
                model, opt_state, X_train, y_train, epoch, penalty, alpha
            )
            if epoch % NN_CONFIG["plot_every_epochs"] == 0:
                y_test_pred = np.array(eqx.filter_vmap(model)(X_test)).ravel()
                mse_test = mean_squared_error(y_test, y_test_pred)
                y_plot_pred = np.array(eqx.filter_vmap(model)(X_plot)).ravel()
                mse_plot = mean_squared_error(y_plot, y_plot_pred)
                results[(penalty, alpha, epoch)] = (model, mse_test, mse_plot, y_plot_pred)
                logger.info(
                    f"Penalty={penalty}, alpha={alpha:.3f}: "
                    f"Epoch {epoch}, "
                    f"Train Loss={loss:.3f}, "
                    f"Test MSE={mse_test:.3f}, "
                    f"Plot MSE={mse_plot:.3f}"
                )
        logger.info(f"### {penalty}-{alpha:.3f} NN Test MSE: {mse_test:.3f} ###")
        return model

    with ThreadPoolExecutor(20) as ex:
        # Train the baseline ("Naive", alpha=0) and every (penalty, alpha) pair in
        # parallel; each worker writes results under its own distinct keys.
        list(ex.map(train_nn, *zip(*chain([("Naive", 0)], product(PENALTIES[1:], ALPHAS)))))
    results = dict(sorted(results.items()))
    # 4. Plot
    fig, ax = plt.subplots(figsize=(20, 15))
    ims = []
    for epoch in [e for e in range(NN_CONFIG["epochs"]) if e % NN_CONFIG["plot_every_epochs"] == 0]:
        art0 = ax.scatter(X_test, y_test, alpha=0.3, color=CMAP[0])
        art1 = ax.scatter(X_train, y_train, alpha=0.3, color=CMAP[1])
        art2, = ax.plot(X_plot, y_plot, color="black")
        im = [art0, art1, art2]
        legend = ["Noisy Test", "Noisy Train", "Ground Truth"]
        for (penalty, alpha, e), (model, mse_test, mse_plot, y_plot_pred) in results.items():
            if penalty == "XGB" or e == epoch:
                art, = ax.plot(
                    X_plot,
                    y_plot_pred,
                    alpha=0.5,
                    linestyle=ALPHA_LINESTYLES[alpha] if alpha else "solid",
                    color=CMAP_PEN[penalty],
                )
                im.append(art)
                legend.append(
                    f"{penalty}-{alpha or 0:.3f} (Test MSE={mse_test:.3f}, GT MSE={mse_plot:.3f})"
                )
        ims.append(im)
        plt.legend(im, legend)
    plt.xlabel("x")
    plt.ylabel("y")
    plt.title("Method Comparison: XGB vs NN variants")
    # plt.savefig("penalty_comparison.png")
    # plt.savefig("penalty_comparison.svg")

    # Build animation
    ani = animation.ArtistAnimation(fig, ims, interval=200, blit=True)
    # Save as MP4
    writer = FFMpegWriter(fps=2.5, metadata=dict(artist="Me"), bitrate=1800)
    ani.save("training_progress.mp4", writer=writer)
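    # Note: FFMpegWriter needs an ffmpeg binary on the PATH; it is not pulled in
    # by the Python dependencies listed in the script header.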
    plt.figure(figsize=(20, 15))
    epochs = [
        e for e in range(NN_CONFIG["epochs"])
        if e % NN_CONFIG["plot_every_epochs"] == 0
    ]
    for penalty, alpha in chain([("Naive", 0)], product(PENALTIES[1:], ALPHAS)):
        mses = []
        for e in epochs:
            model, mse_test, mse_plot, y_plot_pred = results[(penalty, alpha, e)]
            mses.append(mse_plot)
        plt.plot(
            epochs,
            mses,
            alpha=0.5,
            linestyle=ALPHA_LINESTYLES[alpha] if alpha else "solid",
            color=CMAP_PEN[penalty],
            label=f"{penalty}-{alpha or 0:.3f}",
        )
    plt.legend()
    plt.xlabel("epochs")
    plt.ylabel("Ground Truth MSE")
    plt.xscale("log")
    plt.yscale("log")
    plt.ylim(0.001, 2.0)
    plt.savefig("gt_mse.png")


if __name__ == "__main__":
    main()