@stablegradients
Created February 27, 2025 03:22

hessian.py
import os
import logging

import numpy as np
import torch
import matplotlib.pyplot as plt

# NOTE: `distributed` below is assumed to be the surrounding project's utility
# module (exposing is_main_process()); it is only needed when output_dir is set.
logger = logging.getLogger(__name__)


def compute_hessian_spectrum(model, feature_model, data_loader, loss_fn,
                             num_eigenthings=20, max_iter=10, tol=1e-6, output_dir=None):
    """
    Compute the top eigenvalues of the Hessian matrix using the Lanczos algorithm.

    Args:
        model: The model to compute the Hessian spectrum for (linear classifier).
        feature_model: The feature extraction model.
        data_loader: DataLoader providing input data.
        loss_fn: Loss function.
        num_eigenthings: Number of eigenvalues to compute.
        max_iter: Maximum number of Lanczos iterations.
        tol: Tolerance for convergence.
        output_dir: Directory to save the spectrum plot.

    Returns:
        eigenvalues: Computed eigenvalues of the Hessian.
    """
    logger.info("Computing Hessian spectrum using Lanczos algorithm")

    # Get model parameters and their count
    params = [p for p in model.parameters() if p.requires_grad]
    n_params = sum(p.numel() for p in params)
    device = params[0].device

    logger.info(f"Number of parameters: {n_params}")

    # Fix a single batch so every Hessian-vector product is evaluated on the
    # same loss, i.e. Lanczos sees one fixed linear operator.
    data, labels = next(iter(data_loader))
    data = data.to(device)
    labels = labels.to(device)

    # Function to compute the Hessian-vector product H @ v
    def hessian_vector_product(v):
        # Zero out any stale gradients
        for p in params:
            if p.grad is not None:
                p.grad.zero_()

        # Get features from the feature model (kept out of the autograd graph)
        with torch.no_grad():
            features = feature_model(data)

        # Get model output and loss - use features as input to the linear classifier
        outputs = model(features)

        if isinstance(outputs, dict):
            # If the model returns multiple outputs, use the first one
            key = list(outputs.keys())[0]
            output = outputs[key]
        else:
            output = outputs

        loss = loss_fn(output, labels)

        # First-order gradients w.r.t. the loss, keeping the graph for the second pass
        grads = torch.autograd.grad(loss, params, create_graph=True)

        # Dot product of the gradients and the probe vector
        grad_vector_product = sum((g * v_chunk).sum() for g, v_chunk in zip(grads, v))

        # Hessian-vector product via the derivative of the gradient-vector product
        hvp = torch.autograd.grad(grad_vector_product, params)

        return [h.detach() for h in hvp]

    # Lanczos algorithm: start from a random unit vector
    v_list = [torch.randn_like(p) for p in params]
    # Normalize
    v_norm = torch.sqrt(sum((v * v).sum() for v in v_list))
    v_list = [v / v_norm for v in v_list]

    alpha = torch.zeros(max_iter, device=device)
    beta = torch.zeros(max_iter - 1, device=device)  # the off-diagonal has max_iter - 1 entries

    # List of Lanczos vectors; q_list[i] is the i-th Krylov basis vector.
    # (No zero placeholder is needed: the beta term is skipped at i == 0.)
    q_list = [v_list]

    # Lanczos iterations
    for i in range(max_iter):
        logger.info(f"Lanczos iteration {i + 1}/{max_iter}")

        # Compute Hessian-vector product
        w_list = hessian_vector_product(q_list[i])

        # Update alpha (diagonal entry: q_i^T H q_i)
        alpha[i] = sum((w * q).sum() for w, q in zip(w_list, q_list[i]))

        # Update w: orthogonalize against the current and previous Lanczos vectors
        w_list = [w - alpha[i] * q_list[i][j] - (beta[i - 1] if i > 0 else 0) * q_list[i - 1][j]
                  for j, w in enumerate(w_list)]

        # Stop if this is the last iteration
        if i == max_iter - 1:
            break

        # Update beta (off-diagonal entry: ||w||)
        beta[i] = torch.sqrt(sum((w * w).sum() for w in w_list))

        if beta[i] < tol:
            # Early stopping if converged
            logger.info(f"Lanczos converged at iteration {i + 1}")
            alpha = alpha[:i + 1]
            beta = beta[:i]
            break

        # Next Lanczos vector: q_{i+1} = w / beta_i
        q_list.append([w / beta[i] for w in w_list])

    # Construct the tridiagonal matrix T
    T = torch.diag(alpha)

    # Only add off-diagonal elements up to the size of beta
    for i in range(len(beta)):
        T[i, i + 1] = beta[i]
        T[i + 1, i] = beta[i]

    # Eigenvalues of T approximate the extreme eigenvalues of the Hessian
    eigenvalues = torch.linalg.eigvalsh(T)

    # Sort eigenvalues in descending order
    eigenvalues = eigenvalues.sort(descending=True)[0]

    # Take the top num_eigenthings
    if num_eigenthings < len(eigenvalues):
        eigenvalues = eigenvalues[:num_eigenthings]

    # Plot and save the spectrum (main process only)
    if output_dir and distributed.is_main_process():
        plt.figure(figsize=(10, 6))
        plt.stem(eigenvalues.cpu().numpy())
        plt.xlabel('Index')
        plt.ylabel('Eigenvalue')
        plt.title('Hessian Spectrum')
        plt.savefig(os.path.join(output_dir, 'hessian_spectrum.png'))
        plt.close()

        # Save eigenvalues
        np.save(os.path.join(output_dir, 'hessian_eigenvalues.npy'), eigenvalues.cpu().numpy())

    logger.info(f"Top-5 eigenvalues: {eigenvalues[:5].cpu().numpy()}")
    return eigenvalues
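

# --- Usage sketch (illustrative addition, not part of the original gist) -----
# A minimal example of how compute_hessian_spectrum might be called on a linear
# classifier over frozen features. The feature extractor, classifier, synthetic
# dataset, and all sizes below are hypothetical placeholders.
if __name__ == "__main__":
    import torch.nn as nn
    from torch.utils.data import DataLoader, TensorDataset

    logging.basicConfig(level=logging.INFO)

    feature_dim, num_classes = 64, 10
    feature_model = nn.Identity()                 # stand-in for a frozen feature extractor
    model = nn.Linear(feature_dim, num_classes)   # linear classifier probed for its Hessian
    loss_fn = nn.CrossEntropyLoss()

    # Tiny synthetic dataset purely for illustration
    X = torch.randn(256, feature_dim)
    y = torch.randint(0, num_classes, (256,))
    data_loader = DataLoader(TensorDataset(X, y), batch_size=64, shuffle=False)

    # output_dir=None skips plotting, so the project-specific `distributed`
    # module is not required for this sketch.
    eigenvalues = compute_hessian_spectrum(
        model, feature_model, data_loader, loss_fn,
        num_eigenthings=10, max_iter=10, output_dir=None,
    )
    print("Top Hessian eigenvalues:", eigenvalues)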