def compute_hessian_spectrum(model, feature_model, data_loader, loss_fn, num_eigenthings=20, max_iter=10, tol=1e-6, output_dir=None):
    """
    Estimate the top eigenvalues of the loss Hessian with the Lanczos algorithm.

    The Hessian is taken with respect to the trainable parameters of ``model``
    and is evaluated on a single batch: the first batch yielded by
    ``data_loader``. That batch (and the features extracted from it) is fetched
    once and reused for every Hessian-vector product, so the operator seen by
    Lanczos stays fixed across iterations.

    The returned values are the eigenvalues of the Lanczos tridiagonal matrix
    (Ritz values). At most ``max_iter`` of them exist, so fewer than
    ``num_eigenthings`` values are returned when ``num_eigenthings`` exceeds
    the number of Lanczos steps actually performed.

    Args:
        model: The model to compute Hessian spectrum for (linear classifier).
            It is called on the output of ``feature_model``; if it returns a
            dict, the first value is used as the prediction.
        feature_model: The feature extraction model (run under ``no_grad``)
        data_loader: Iterable yielding ``(data, labels)`` batches
        loss_fn: Loss function, called as ``loss_fn(output, labels)``
        num_eigenthings: Maximum number of eigenvalues to return
        max_iter: Maximum number of Lanczos iterations (capped at the number
            of trainable parameters)
        tol: Lanczos stops early once the residual norm drops below this value
        output_dir: Directory to save the spectrum plot and eigenvalues to
            (created if missing); nothing is saved when falsy

    Returns:
        eigenvalues: 1-D tensor of eigenvalue estimates, sorted descending

    Raises:
        ValueError: If ``model`` has no trainable parameters, ``max_iter`` is
            smaller than 1, or ``data_loader`` yields no batch.
    """
    logger.info("Computing Hessian spectrum using Lanczos algorithm")

    # Get model parameters and their count
    params = [p for p in model.parameters() if p.requires_grad]
    if not params:
        raise ValueError("model has no trainable parameters")
    if max_iter < 1:
        raise ValueError(f"max_iter must be at least 1, got {max_iter}")
    n_params = sum(p.numel() for p in params)
    device = params[0].device
    logger.info(f"Number of parameters: {n_params}")

    # The Krylov subspace cannot have more dimensions than the parameter
    # space, so additional iterations would only add numerical noise.
    if max_iter > n_params:
        logger.info(f"Capping Lanczos iterations at the parameter count ({n_params})")
        max_iter = n_params

    # Fetch the batch and its features ONCE. Drawing a fresh batch per
    # Hessian-vector product (as a shuffling loader would give) changes the
    # operator between iterations, which breaks the Lanczos recurrence.
    try:
        data, labels = next(iter(data_loader))
    except StopIteration:
        raise ValueError("data_loader yielded no batches") from None
    data = data.to(device)
    labels = labels.to(device)
    with torch.no_grad():
        features = feature_model(data)

    def _dot(xs, ys):
        # Inner product of two vectors stored as per-parameter tensor lists
        return sum((x * y).sum() for x, y in zip(xs, ys))

    def hessian_vector_product(v):
        # Forward pass on the fixed features; `.grad` buffers are left alone
        # because torch.autograd.grad never accumulates into them.
        outputs = model(features)
        if isinstance(outputs, dict):
            # If model returns multiple outputs, use the first one
            output = next(iter(outputs.values()))
        else:
            output = outputs
        loss = loss_fn(output, labels)

        # First-order gradients, keeping the graph so they can be
        # differentiated a second time
        grads = torch.autograd.grad(loss, params, create_graph=True)

        # d/dparams (grad . v) == H v
        grad_vector_product = _dot(grads, v)
        hvp = torch.autograd.grad(grad_vector_product, params)
        return [h.detach() for h in hvp]

    # Random unit-norm starting vector
    start = [torch.randn_like(p) for p in params]
    start_norm = torch.sqrt(_dot(start, start))
    q_list = [[v / start_norm for v in start]]

    alpha = torch.zeros(max_iter, device=device)      # diagonal of T
    beta = torch.zeros(max_iter - 1, device=device)   # off-diagonal of T

    # Lanczos iterations
    for i in range(max_iter):
        logger.info(f"Lanczos iteration {i+1}/{max_iter}")
        q_cur = q_list[i]

        w_list = hessian_vector_product(q_cur)
        alpha[i] = _dot(w_list, q_cur)

        # The last iteration only contributes its diagonal entry
        if i == max_iter - 1:
            break

        # Three-term recurrence: w = H q_i - alpha_i q_i - beta_{i-1} q_{i-1}
        w_list = [w - alpha[i] * q for w, q in zip(w_list, q_cur)]
        if i > 0:
            w_list = [w - beta[i - 1] * q for w, q in zip(w_list, q_list[i - 1])]

        # Full reorthogonalization against every stored basis vector to
        # counter the loss of orthogonality in floating point.
        for q_prev in q_list:
            overlap = _dot(w_list, q_prev)
            w_list = [w - overlap * q for w, q in zip(w_list, q_prev)]

        beta[i] = torch.sqrt(_dot(w_list, w_list))
        if beta[i] < tol:
            # Residual vanished: the Krylov subspace is invariant, stop here
            logger.info(f"Lanczos converged at iteration {i+1}")
            alpha = alpha[:i + 1]
            beta = beta[:i]
            break

        # Next basis vector lives at index i + 1 (no placeholder slot)
        q_list.append([w / beta[i] for w in w_list])

    # Construct the symmetric tridiagonal matrix
    T = torch.diag(alpha)
    if beta.numel() > 0:
        idx = torch.arange(beta.numel(), device=device)
        T[idx, idx + 1] = beta
        T[idx + 1, idx] = beta

    # Eigenvalues of T, largest first
    eigenvalues = torch.linalg.eigvalsh(T)
    eigenvalues = eigenvalues.sort(descending=True)[0]

    if num_eigenthings > len(eigenvalues):
        logger.warning(
            f"Requested {num_eigenthings} eigenvalues but only {len(eigenvalues)} "
            "Lanczos steps were performed; returning fewer values"
        )
    # Take top num_eigenthings
    eigenvalues = eigenvalues[:num_eigenthings]

    # Plot and save the spectrum (main process only)
    if output_dir and distributed.is_main_process():
        os.makedirs(output_dir, exist_ok=True)
        eigenvalues_np = eigenvalues.cpu().numpy()
        plt.figure(figsize=(10, 6))
        plt.stem(eigenvalues_np)
        plt.xlabel('Index')
        plt.ylabel('Eigenvalue')
        plt.title('Hessian Spectrum')
        plt.savefig(os.path.join(output_dir, 'hessian_spectrum.png'))
        plt.close()

        # Save eigenvalues
        np.save(os.path.join(output_dir, 'hessian_eigenvalues.npy'), eigenvalues_np)

    logger.info(f"Top-5 eigenvalues: {eigenvalues[:5].cpu().numpy()}")
    return eigenvalues