Created February 27, 2025 03:22
-
-
Save stablegradients/4d5ea6c53773c8bd5e118288dc3d213c to your computer and use it in GitHub Desktop.
A function to calculate the Hessian spectrum
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def compute_hessian_spectrum(model, feature_model, data_loader, loss_fn, num_eigenthings=20, max_iter=10, tol=1e-6, output_dir=None):
    """
    Compute the top eigenvalues of the Hessian matrix using the Lanczos algorithm.

    The Hessian is that of ``loss_fn(model(feature_model(data)), labels)`` with
    respect to the trainable parameters of ``model`` only. ``feature_model`` is
    run under ``torch.no_grad()``, so its outputs are treated as constants.

    A single batch (the first one yielded by ``data_loader``) is fetched once
    and reused for every Hessian-vector product. This way every Lanczos step
    applies the same operator, even when the loader shuffles.

    Args:
        model: The model to compute the Hessian spectrum for (e.g. a linear classifier).
        feature_model: The feature extraction model; its output is fed to ``model``.
        data_loader: Iterable yielding ``(data, labels)`` batches; only the first
            batch is used.
        loss_fn: Loss function, called as ``loss_fn(output, labels)``.
        num_eigenthings: Maximum number of eigenvalues to return.
        max_iter: Maximum number of Lanczos iterations. The count is also capped
            at the number of trainable parameters, because the Krylov space
            cannot grow beyond that dimension.
        tol: Absolute threshold on the residual norm (beta). Iteration stops
            early once beta falls below it.
        output_dir: If given (and on the main process), the directory where the
            spectrum plot and ``hessian_eigenvalues.npy`` are saved. It is
            created if missing.

    Returns:
        eigenvalues: 1-D tensor of Ritz values (Hessian eigenvalue estimates),
            sorted in descending order. It has at most
            ``min(num_eigenthings, iterations run)`` entries.

    Raises:
        ValueError: If ``model`` has no trainable parameters.
    """
    logger.info("Computing Hessian spectrum using Lanczos algorithm")

    # Trainable parameters of the head and their total count.
    params = [p for p in model.parameters() if p.requires_grad]
    if not params:
        raise ValueError("model has no trainable parameters")
    n_params = sum(p.numel() for p in params)
    param_sizes = [p.numel() for p in params]
    device = params[0].device
    dtype = params[0].dtype
    logger.info(f"Number of parameters: {n_params}")

    # Fetch ONE batch and compute its features ONCE. Lanczos assumes a fixed
    # symmetric operator; drawing a new (possibly shuffled) batch per
    # Hessian-vector product would change the matrix between steps.
    data, labels = next(iter(data_loader))
    data = data.to(device)
    labels = labels.to(device)
    with torch.no_grad():
        features = feature_model(data)

    def hessian_vector_product(vec):
        """Return H @ vec as a flat detached tensor, for a flat vector `vec` of length n_params."""
        outputs = model(features)
        # If the model returns several outputs, use the first one.
        output = next(iter(outputs.values())) if isinstance(outputs, dict) else outputs
        loss = loss_fn(output, labels)
        # First-order gradients, keeping the graph so they can be differentiated again.
        grads = torch.autograd.grad(loss, params, create_graph=True)
        chunks = torch.split(vec, param_sizes)
        grad_dot_vec = sum((g * c.reshape(g.shape)).sum() for g, c in zip(grads, chunks))
        # d/dθ (∇L · v) = H v. No retain_graph: the graph can be freed right away.
        hvp = torch.autograd.grad(grad_dot_vec, params, allow_unused=True)
        # A parameter that does not influence the gradient has a zero Hessian row.
        return torch.cat([
            (h if h is not None else torch.zeros_like(p)).reshape(-1)
            for h, p in zip(hvp, params)
        ]).detach()

    # More than n_params steps would only iterate on round-off noise.
    num_steps = min(max_iter, n_params)

    # Random unit starting vector.
    q = torch.randn(n_params, device=device, dtype=dtype)
    q = q / q.norm()
    basis = [q]    # orthonormal Lanczos vectors q_0, q_1, ...
    alphas = []    # diagonal of the tridiagonal matrix T
    betas = []     # off-diagonal of T

    for i in range(num_steps):
        logger.info(f"Lanczos iteration {i+1}/{num_steps}")
        q_curr = basis[-1]
        w = hessian_vector_product(q_curr)
        alpha = torch.dot(w, q_curr)
        alphas.append(alpha)

        # Three-term recurrence: w = H q_i - alpha_i q_i - beta_{i-1} q_{i-1}.
        w = w - alpha * q_curr
        if i > 0:
            w = w - betas[-1] * basis[-2]

        # Full reorthogonalization against every previous Lanczos vector.
        # In finite precision the recurrence loses orthogonality, which
        # otherwise produces spurious (ghost) eigenvalues.
        Q = torch.stack(basis)
        w = w - Q.t() @ (Q @ w)

        # The last step only contributes a diagonal entry.
        if i == num_steps - 1:
            break

        beta = w.norm()
        if beta < tol:
            # An invariant subspace was found; T is complete.
            logger.info(f"Lanczos converged at iteration {i+1}")
            break
        betas.append(beta)
        basis.append(w / beta)

    # Build the symmetric tridiagonal matrix T: len(alphas) x len(alphas),
    # with len(alphas) - 1 off-diagonal entries.
    T = torch.diag(torch.stack(alphas))
    if betas:
        off_diag = torch.stack(betas)
        T = T + torch.diag(off_diag, 1) + torch.diag(off_diag, -1)

    # The eigenvalues of T (Ritz values) approximate those of the Hessian.
    eigenvalues = torch.linalg.eigvalsh(T).sort(descending=True).values

    # Keep the top num_eigenthings (slicing is a no-op when fewer are available).
    eigenvalues = eigenvalues[:num_eigenthings]

    # Plot and save the spectrum.
    if output_dir and distributed.is_main_process():
        os.makedirs(output_dir, exist_ok=True)
        plt.figure(figsize=(10, 6))
        plt.stem(eigenvalues.cpu().numpy())
        plt.xlabel('Index')
        plt.ylabel('Eigenvalue')
        plt.title('Hessian Spectrum')
        plt.savefig(os.path.join(output_dir, 'hessian_spectrum.png'))
        plt.close()
        # Save eigenvalues
        np.save(os.path.join(output_dir, 'hessian_eigenvalues.npy'), eigenvalues.cpu().numpy())

    logger.info(f"Top-5 eigenvalues: {eigenvalues[:5].cpu().numpy()}")
    return eigenvalues
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment