shrinivas stablegradients

stablegradients / hessian.py
Created February 27, 2025 03:22
A function to compute the top eigenvalues of the Hessian
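Spectrum estimators of this kind never build the full Hessian; they only need Hessian-vector products. The sketch below illustrates that idea on a toy problem and is not the gist's implementation. It assumes a plain-Python gradient function, approximates the Hessian-vector product with a central finite difference instead of autograd, and uses power iteration instead of Lanczos, so it recovers only the single eigenvalue of largest magnitude. The names `hvp_fd`, `top_eigenvalue`, and `quadratic_grad` are hypothetical.

```python
import math
import random


def hvp_fd(grad_fn, w, v, eps=1e-4):
    """Approximate H @ v with a central finite difference of the gradient."""
    w_plus = [wi + eps * vi for wi, vi in zip(w, v)]
    w_minus = [wi - eps * vi for wi, vi in zip(w, v)]
    g_plus = grad_fn(w_plus)
    g_minus = grad_fn(w_minus)
    return [(gp - gm) / (2.0 * eps) for gp, gm in zip(g_plus, g_minus)]


def top_eigenvalue(grad_fn, w, max_iter=100, tol=1e-6, seed=0):
    """Power iteration on the Hessian at w, using only Hessian-vector products.

    Returns the Rayleigh quotient v^T H v of the final unit vector v, which
    approaches the eigenvalue of largest magnitude.
    """
    rng = random.Random(seed)
    v = [rng.gauss(0.0, 1.0) for _ in w]
    norm = math.sqrt(sum(x * x for x in v))
    v = [x / norm for x in v]

    eig = 0.0
    for _ in range(max_iter):
        hv = hvp_fd(grad_fn, w, v)
        new_eig = sum(a * b for a, b in zip(v, hv))  # Rayleigh quotient
        norm = math.sqrt(sum(x * x for x in hv))
        if norm == 0.0:  # v lies in the null space of H
            return 0.0
        v = [x / norm for x in hv]
        if abs(new_eig - eig) < tol:
            return new_eig
        eig = new_eig
    return eig


# Toy check: f(w) = 0.5 * w^T A w has gradient A w and Hessian A.
A = [[4.0, 1.0], [1.0, 3.0]]


def quadratic_grad(w):
    return [sum(a * x for a, x in zip(row, w)) for row in A]


estimate = top_eigenvalue(quadratic_grad, [0.5, -0.2])
exact = (7.0 + math.sqrt(5.0)) / 2.0  # largest eigenvalue of A
print(estimate, exact)
```

Lanczos uses the same Hessian-vector product but builds a small tridiagonal matrix from successive products, which yields several leading eigenvalues at once and is why it suits a `num_eigenthings`-style interface.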
def compute_hessian_spectrum(model, feature_model, data_loader, loss_fn, num_eigenthings=20, max_iter=10, tol=1e-6, output_dir=None):
    """
    Compute the top eigenvalues of the Hessian matrix using the Lanczos algorithm.

    Args:
        model: The model to compute Hessian spectrum for (linear classifier)
        feature_model: The feature extraction model
        data_loader: DataLoader providing input data
        loss_fn: Loss function