Skip to content

Instantly share code, notes, and snippets.

@stablegradients
Created February 27, 2025 03:22
Show Gist options
  • Save stablegradients/4d5ea6c53773c8bd5e118288dc3d213c to your computer and use it in GitHub Desktop.
Save stablegradients/4d5ea6c53773c8bd5e118288dc3d213c to your computer and use it in GitHub Desktop.
A function to estimate the top eigenvalues of the loss Hessian using the Lanczos algorithm
def compute_hessian_spectrum(model, feature_model, data_loader, loss_fn, num_eigenthings=20, max_iter=10, tol=1e-6, output_dir=None):
    """
    Estimate the top eigenvalues of the loss Hessian with the Lanczos algorithm.

    The Hessian is taken with respect to the trainable parameters of ``model``
    and is evaluated on a single batch: the first one yielded by
    ``data_loader``. That batch (and the features extracted from it) is fetched
    once and reused for every Hessian-vector product, so that all Lanczos
    iterations see the same symmetric operator.

    The Hessian is never materialised; only Hessian-vector products obtained by
    double back-propagation are used. The plain three-term recurrence is used
    (no re-orthogonalisation).

    Args:
        model: The model to compute the Hessian spectrum for (linear classifier).
            Its output may be a tensor or a dict; for a dict the first value is
            passed to ``loss_fn``.
        feature_model: The feature extraction model. It is called under
            ``torch.no_grad()`` and its output is fed to ``model``.
        data_loader: Iterable yielding ``(data, labels)`` batches.
        loss_fn: Loss function called as ``loss_fn(output, labels)``.
        num_eigenthings: Maximum number of eigenvalues to return.
        max_iter: Maximum number of Lanczos iterations. The tridiagonal matrix
            has at most ``max_iter`` rows, so at most ``max_iter`` eigenvalues
            can be returned regardless of ``num_eigenthings``.
        tol: Lanczos stops early once the residual norm drops below this value.
        output_dir: If set (and this is the main process), a stem plot and a
            ``.npy`` dump of the eigenvalues are written to this directory.

    Returns:
        eigenvalues: 1-D tensor of the estimated eigenvalues in descending
        order, truncated to ``num_eigenthings`` entries.
    """
    logger.info("Computing Hessian spectrum using Lanczos algorithm")

    # Get model parameters and their count
    params = [p for p in model.parameters() if p.requires_grad]
    n_params = sum(p.numel() for p in params)
    device = params[0].device
    logger.info(f"Number of parameters: {n_params}")

    # Fetch ONE batch up front. Drawing a new batch inside every
    # Hessian-vector product would change the operator between Lanczos
    # iterations (and re-create the loader iterator each time).
    data, labels = next(iter(data_loader))
    data = data.to(device)
    labels = labels.to(device)

    # The feature extractor is not differentiated through, so its output is a
    # constant of the problem and can be computed once.
    with torch.no_grad():
        features = feature_model(data)

    def hessian_vector_product(v):
        """Return H @ v as a list of tensors shaped like ``params``."""
        # torch.autograd.grad returns the gradients instead of accumulating
        # them into ``p.grad``, so the caller's ``.grad`` buffers are left alone.
        outputs = model(features)
        if isinstance(outputs, dict):
            # If model returns multiple outputs, use the first one
            output = next(iter(outputs.values()))
        else:
            output = outputs
        loss = loss_fn(output, labels)

        # First-order gradients, keeping the graph so they can be
        # differentiated a second time.
        grads = torch.autograd.grad(loss, params, create_graph=True)

        # d/dtheta (grad . v) == H @ v
        grad_vector_product = sum((g * v_chunk).sum() for g, v_chunk in zip(grads, v))
        # The graph is rebuilt on the next call, so it need not be retained.
        hvp = torch.autograd.grad(grad_vector_product, params)
        return [h.detach() for h in hvp]

    # Random, normalised start vector (one chunk per parameter tensor).
    q_curr = [torch.randn_like(p) for p in params]
    q_norm = torch.sqrt(sum((q * q).sum() for q in q_curr))
    q_curr = [q / q_norm for q in q_curr]
    # The three-term recurrence only needs the previous and current vectors.
    q_prev = None

    alpha = torch.zeros(max_iter, device=device)      # diagonal of T
    beta = torch.zeros(max_iter - 1, device=device)   # off-diagonal of T

    # Lanczos iterations
    for i in range(max_iter):
        logger.info(f"Lanczos iteration {i+1}/{max_iter}")

        w_list = hessian_vector_product(q_curr)
        alpha[i] = sum((w * q).sum() for w, q in zip(w_list, q_curr))

        # The last diagonal entry is all that is needed from the final step.
        if i == max_iter - 1:
            break

        # Residual: w <- H q_i - alpha_i q_i - beta_{i-1} q_{i-1}
        w_list = [w - alpha[i] * q for w, q in zip(w_list, q_curr)]
        if q_prev is not None:
            w_list = [w - beta[i - 1] * q for w, q in zip(w_list, q_prev)]

        beta[i] = torch.sqrt(sum((w * w).sum() for w in w_list))
        if beta[i] < tol:
            # Residual vanished: the Krylov subspace is (numerically)
            # invariant, so keep only the entries computed so far.
            logger.info(f"Lanczos converged at iteration {i+1}")
            alpha = alpha[:i+1]
            beta = beta[:i]
            break

        q_prev, q_curr = q_curr, [w / beta[i] for w in w_list]

    # Construct the symmetric tridiagonal matrix T (len(beta) == len(alpha) - 1).
    T = torch.diag(alpha)
    for k, b in enumerate(beta):
        T[k, k + 1] = b
        T[k + 1, k] = b

    # Eigenvalues of T approximate the extreme eigenvalues of the Hessian.
    eigenvalues = torch.linalg.eigvalsh(T)
    eigenvalues = eigenvalues.sort(descending=True)[0]

    # Take top num_eigenthings
    if num_eigenthings < len(eigenvalues):
        eigenvalues = eigenvalues[:num_eigenthings]

    # Plot and save the spectrum (main process only)
    if output_dir and distributed.is_main_process():
        plt.figure(figsize=(10, 6))
        plt.stem(eigenvalues.cpu().numpy())
        plt.xlabel('Index')
        plt.ylabel('Eigenvalue')
        plt.title('Hessian Spectrum')
        plt.savefig(os.path.join(output_dir, 'hessian_spectrum.png'))
        plt.close()
        # Save eigenvalues
        np.save(os.path.join(output_dir, 'hessian_eigenvalues.npy'), eigenvalues.cpu().numpy())

    logger.info(f"Top-5 eigenvalues: {eigenvalues[:5].cpu().numpy()}")
    return eigenvalues
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment