Created
February 27, 2025 03:22
Revisions
-
stablegradients created this gist
Feb 27, 2025. There are no files selected for viewing
def _param_dot(xs, ys):
    """Return the inner product of two equally-structured lists of tensors."""
    return sum((x * y).sum() for x, y in zip(xs, ys))


def compute_hessian_spectrum(
    model,
    feature_model,
    data_loader,
    loss_fn,
    num_eigenthings=20,
    max_iter=10,
    tol=1e-6,
    output_dir=None,
):
    """
    Estimate the top eigenvalues of the loss Hessian with the Lanczos algorithm.

    The Hessian is taken with respect to the trainable parameters of ``model``
    and is evaluated on a single batch: the first batch yielded by
    ``data_loader``. That batch (and the features extracted from it) is fetched
    once and reused for every Hessian-vector product so that all Lanczos
    iterations see the same linear operator.

    Args:
        model: The model to compute the Hessian spectrum for (linear
            classifier). It is called on the output of ``feature_model``; if it
            returns a dict, the first value is used as the prediction.
        feature_model: The feature extraction model (run under ``no_grad``).
        data_loader: Iterable yielding ``(data, labels)`` batches.
        loss_fn: Loss function called as ``loss_fn(output, labels)``.
        num_eigenthings: Maximum number of eigenvalues to return.
        max_iter: Number of Lanczos iterations. This is also the size of the
            tridiagonal matrix, so at most ``max_iter`` eigenvalues are
            returned regardless of ``num_eigenthings``.
        tol: Breakdown tolerance; iteration stops early when the norm of the
            residual vector falls below it.
        output_dir: If truthy (and this is the main process), a stem plot and
            an ``.npy`` dump of the eigenvalues are written there.

    Returns:
        eigenvalues: 1-D tensor of Ritz values sorted in descending order.

    Raises:
        ValueError: If ``max_iter < 1``, the model has no trainable
            parameters, or ``data_loader`` yields no batch.
    """
    if max_iter < 1:
        raise ValueError("max_iter must be at least 1")

    logger.info("Computing Hessian spectrum using Lanczos algorithm")

    # Get model parameters and their count
    params = [p for p in model.parameters() if p.requires_grad]
    if not params:
        raise ValueError("model has no trainable parameters")
    n_params = sum(p.numel() for p in params)
    device = params[0].device

    logger.info(f"Number of parameters: {n_params}")

    # Fetch ONE batch up front. Drawing a fresh batch per Hessian-vector
    # product would change the operator between Lanczos steps (e.g. with a
    # shuffling loader) and re-create the loader iterator every time.
    try:
        batch = next(iter(data_loader))
    except StopIteration as err:
        raise ValueError("data_loader yielded no batches") from err
    data, labels = batch
    data = data.to(device)
    labels = labels.to(device)

    # The features do not depend on the classifier parameters, so they are
    # computed once and shared by every Hessian-vector product.
    with torch.no_grad():
        features = feature_model(data)

    def hessian_vector_product(v):
        """Return H @ v as a list of tensors shaped like ``params``."""
        # Use features as input to the linear classifier
        outputs = model(features)
        if isinstance(outputs, dict):
            # If model returns multiple outputs, use the first one
            output = next(iter(outputs.values()))
        else:
            output = outputs

        loss = loss_fn(output, labels)

        # First-order gradients, keeping the graph for double backward.
        # torch.autograd.grad does not touch p.grad, so the caller's
        # accumulated gradients are left alone.
        grads = torch.autograd.grad(loss, params, create_graph=True)
        grad_vector_product = _param_dot(grads, v)

        if not grad_vector_product.requires_grad:
            # Gradient is constant in the parameters: the Hessian is zero.
            return [torch.zeros_like(p) for p in params]

        hvp = torch.autograd.grad(grad_vector_product, params, allow_unused=True)
        # A parameter with no second-order dependence comes back as None.
        return [
            torch.zeros_like(p) if h is None else h.detach()
            for p, h in zip(params, hvp)
        ]

    # Random, normalized starting vector
    start = [torch.randn_like(p) for p in params]
    start_norm = torch.sqrt(_param_dot(start, start))
    basis = [[s / start_norm for s in start]]

    alpha = torch.zeros(max_iter, device=device)
    beta = torch.zeros(max_iter - 1, device=device)

    # Lanczos iterations; basis[i] is always the i-th Lanczos vector.
    for i in range(max_iter):
        logger.info(f"Lanczos iteration {i+1}/{max_iter}")

        q_curr = basis[i]
        w_list = hessian_vector_product(q_curr)
        alpha[i] = _param_dot(w_list, q_curr)

        # The last diagonal entry is all that is needed from the final step.
        if i == max_iter - 1:
            break

        # Three-term recurrence
        w_list = [w - alpha[i] * q for w, q in zip(w_list, q_curr)]
        if i > 0:
            w_list = [w - beta[i - 1] * q for w, q in zip(w_list, basis[i - 1])]

        # Full reorthogonalization against all stored vectors: in finite
        # precision the plain recurrence loses orthogonality and produces
        # spurious duplicate ("ghost") eigenvalues.
        for q_prev in basis:
            overlap = _param_dot(w_list, q_prev)
            w_list = [w - overlap * q for w, q in zip(w_list, q_prev)]

        beta[i] = torch.sqrt(_param_dot(w_list, w_list))

        if beta[i] < tol:
            # Krylov subspace is (numerically) invariant; stop early
            logger.info(f"Lanczos converged at iteration {i+1}")
            alpha = alpha[:i + 1]
            beta = beta[:i]
            break

        basis.append([w / beta[i] for w in w_list])

    # Construct the tridiagonal matrix
    T = torch.diag(alpha)
    for k, off_diag in enumerate(beta):
        T[k, k + 1] = off_diag
        T[k + 1, k] = off_diag

    # Compute eigenvalues of T and sort them in descending order
    eigenvalues = torch.linalg.eigvalsh(T)
    eigenvalues = eigenvalues.sort(descending=True)[0]

    # Take top num_eigenthings
    if num_eigenthings < len(eigenvalues):
        eigenvalues = eigenvalues[:num_eigenthings]

    # Plot the spectrum
    if output_dir and distributed.is_main_process():
        plt.figure(figsize=(10, 6))
        plt.stem(eigenvalues.cpu().numpy())
        plt.xlabel('Index')
        plt.ylabel('Eigenvalue')
        plt.title('Hessian Spectrum')
        plt.savefig(os.path.join(output_dir, 'hessian_spectrum.png'))
        plt.close()

        # Save eigenvalues
        np.save(os.path.join(output_dir, 'hessian_eigenvalues.npy'), eigenvalues.cpu().numpy())

    logger.info(f"Top-5 eigenvalues: {eigenvalues[:5].cpu().numpy()}")

    return eigenvalues