Skip to content

Instantly share code, notes, and snippets.

@dlwh
Last active May 23, 2025 19:21
Show Gist options
  • Save dlwh/0f16ccb33a769997eb880d433d1f71a7 to your computer and use it in GitHub Desktop.
llama3 vs neox bpb regressor
Coefficients: slope=0.083992
Intercept: -0.011938
R² score: 0.729319
Saved metrics_differences_with_predictions.csv
Predicted BPB for run2: 0.726887 ± 0.000513 BPB
Process finished with exit code 0
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import wandb
from sklearn.linear_model import LinearRegression
# ---- Configuration ------------------------------------------------------
# Paloma eval domain whose loss/BPB curves are compared between the runs.
domain = "c4_en"
# domain = "dolma_100_programing_languages"

# W&B project and the two runs being compared (llama3 vs neox tokenizer).
project_name = "marin-community/marin"
run1_id = "llama3-tokenizer-095cea"
run2_id = "neox-tokenizer-ad549d"

# Metric keys logged under the chosen Paloma domain.
loss_key = f"eval/paloma/{domain}/loss"
bpb_key = f"eval/paloma/{domain}/bpb"

# ---- Fetch both runs through the W&B API --------------------------------
api = wandb.Api()
run1, run2 = (api.run(f"{project_name}/{run_id}") for run_id in (run1_id, run2_id))
def fetch_metrics(run, keys):
    """Fetch the logged history for *keys* from a W&B run.

    Returns a DataFrame whose columns are '_step' followed by the requested
    metric keys, one row per logged step.
    """
    history = run.history(keys=keys, pandas=True)
    history = history.reset_index()
    return history[['_step', *keys]]
# Pull both histories and drop the first logged row (the initial state
# before any training), which would otherwise skew the regression.
llama3_history = fetch_metrics(run1, [loss_key, bpb_key]).iloc[1:]
neox_history = fetch_metrics(run2, [loss_key, bpb_key]).iloc[1:]

# Inner-join on '_step' so each row pairs the two runs at the same step.
data = llama3_history.merge(
    neox_history,
    on='_step',
    suffixes=('_llama3', '_neox'),
)

# Per-step differences (llama3 minus neox) used as target and feature.
data['bpb_diff'] = data[bpb_key + '_llama3'] - data[bpb_key + '_neox']
data['loss_diff'] = data[loss_key + '_llama3'] - data[loss_key + '_neox']
# 1) Scatter plot of BPB against loss for each tokenizer's run.
plt.figure()
for label, suffix in (("Llama3", "_llama3"), ("NeoX", "_neox")):
    plt.scatter(data[loss_key + suffix], data[bpb_key + suffix], label=label)
plt.xlabel("Loss")
plt.ylabel("BPB")
plt.title(f"BPB vs Loss for {domain}")
plt.legend()
plt.savefig(f"bpb_vs_loss_{domain}.png")
plt.show()
# 2) Ordinary least squares: ΔBPB ~ Δloss.
#    (A quadratic (Δloss)^2 term was tried earlier and dropped.)
X = data[['loss_diff']].values
y = data['bpb_diff'].values
model = LinearRegression().fit(X, y)
print(f"Coefficients: slope={model.coef_[0]:.6f}")
print(f"Intercept: {model.intercept_:.6f}")
print(f"R² score: {model.score(X, y):.6f}")
# 3) Visual fit check: actual vs predicted ΔBPB.
data['predicted_bpb_diff'] = model.predict(X)
fig = plt.figure()
ax = fig.gca()
ax.scatter(data['bpb_diff'], data['predicted_bpb_diff'])
ax.set_xlabel("Actual ΔBPB")
ax.set_ylabel("Predicted ΔBPB")
ax.set_title("Actual vs Predicted ΔBPB")
fig.savefig("actual_vs_predicted_bpb_diff.png")
plt.show()

# 4) Persist the merged metrics together with the model's predictions.
data.to_csv("metrics_differences_with_predictions.csv", index=False)
print("Saved metrics_differences_with_predictions.csv")
# 5) Standalone prediction with an uncertainty band taken from the
#    spread of the in-sample residuals.
residuals = data['predicted_bpb_diff'] - data['bpb_diff']
sigma = residuals.std()

# Example losses/BPB for a hypothetical pair of checkpoints.
llama3_loss = 2.28
neox_loss = 2.43
llama3_bpb = 0.70235

loss_delta = llama3_loss - neox_loss
predicted_diff = model.predict([[loss_delta]])[0]  # predicted ΔBPB (llama3 − neox)
pred_bpb2 = llama3_bpb - predicted_diff            # recover the NeoX run's BPB

# Report the point estimate with a ±1σ residual band.
print(f"Predicted BPB for run2: {pred_bpb2:.6f} ± {sigma:.6f} BPB")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment