def compute_hessian_spectrum(model, feature_model, data_loader, loss_fn, num_eigenthings=20, max_iter=10, tol=1e-6, output_dir=None):
    """
    Compute the top eigenvalues of the Hessian matrix using the Lanczos algorithm.

    The Hessian is that of ``loss_fn(model(feature_model(data)), labels)`` with
    respect to the trainable parameters of ``model`` only. ``feature_model`` is
    run under ``torch.no_grad()`` and treated as a fixed feature extractor.
    A single batch is drawn from ``data_loader`` and reused for every
    Hessian-vector product, so every Lanczos step sees the same matrix.

    Args:
        model: The model to compute Hessian spectrum for (linear classifier).
            If it returns a dict, the value under its first key is used.
        feature_model: The feature extraction model.
        data_loader: Iterable yielding ``(data, labels)`` batches; only the
            first batch is used.
        loss_fn: Loss function, called as ``loss_fn(output, labels)``.
        num_eigenthings: Maximum number of eigenvalues to return.
        max_iter: Number of Lanczos iterations. Lanczos yields at most
            ``max_iter`` eigenvalue estimates (Ritz values), so this also caps
            how many values are returned.
        tol: Stop early once the Lanczos residual norm falls below this value.
        output_dir: If given (and on the main process), a stem plot
            ``hessian_spectrum.png`` and ``hessian_eigenvalues.npy`` are saved here.

    Returns:
        eigenvalues: 1-D tensor of Ritz values (approximate Hessian
        eigenvalues) in descending order, truncated to ``num_eigenthings``.

    Raises:
        ValueError: If ``max_iter < 1`` or ``model`` has no trainable parameters.
    """
    logger.info("Computing Hessian spectrum using Lanczos algorithm")

    if max_iter < 1:
        raise ValueError(f"max_iter must be >= 1, got {max_iter}")

    # Get model parameters and their count
    params = [p for p in model.parameters() if p.requires_grad]
    if not params:
        raise ValueError("model has no trainable parameters")
    sizes = [p.numel() for p in params]
    n_params = sum(sizes)
    device = params[0].device
    dtype = params[0].dtype

    logger.info(f"Number of parameters: {n_params}")

    # Kept for backward compatibility: earlier versions cleared existing .grad
    # buffers. torch.autograd.grad below neither reads nor writes them.
    for p in params:
        if p.grad is not None:
            p.grad.zero_()

    # Draw ONE batch and reuse it for every Hessian-vector product. Calling
    # next(iter(data_loader)) per product re-creates the iterator each time and,
    # for a shuffled loader, gives Lanczos a different matrix at every step,
    # which invalidates the recurrence.
    data, labels = next(iter(data_loader))
    data = data.to(device)
    labels = labels.to(device)

    # Get features from the (frozen) feature model; no gradient flows into it.
    with torch.no_grad():
        features = feature_model(data)

    outputs = model(features)
    if isinstance(outputs, dict):
        # If model returns multiple outputs, use the first one
        outputs = outputs[next(iter(outputs))]
    loss = loss_fn(outputs, labels)

    # First-order gradients with create_graph=True so they can be
    # differentiated again. The graph is built once and kept alive
    # (retain_graph=True below) for all Hessian-vector products.
    grads = torch.autograd.grad(loss, params, create_graph=True)

    def hessian_vector_product(vec):
        """Return H @ vec, with vec and the result as flat length-n_params tensors."""
        chunks = [c.view_as(p).to(p.dtype) for c, p in zip(vec.split(sizes), params)]
        grad_dot_vec = sum((g * c).sum() for g, c in zip(grads, chunks))
        hvp = torch.autograd.grad(grad_dot_vec, params, retain_graph=True, allow_unused=True)
        # allow_unused=True: a parameter with zero curvature comes back as None.
        return torch.cat([
            (torch.zeros_like(p) if h is None else h.detach()).reshape(-1).to(vec.dtype)
            for h, p in zip(hvp, params)
        ])

    # Random unit-norm starting vector.
    q0 = torch.randn(n_params, device=device, dtype=dtype)
    q0 = q0 / q0.norm()

    alphas, betas = _lanczos_tridiagonal(hessian_vector_product, q0, max_iter, tol)

    # Construct the symmetric tridiagonal matrix T (alphas on the diagonal,
    # betas on the first off-diagonals).
    alpha = torch.tensor(alphas, device=device)
    beta = torch.tensor(betas, device=device)
    T = torch.diag(alpha) + torch.diag(beta, 1) + torch.diag(beta, -1)

    # Eigenvalues of T (Ritz values) in descending order, top num_eigenthings.
    eigenvalues = torch.linalg.eigvalsh(T).sort(descending=True).values
    eigenvalues = eigenvalues[:num_eigenthings]

    # Plot the spectrum and save the raw values
    if output_dir and distributed.is_main_process():
        eig_np = eigenvalues.cpu().numpy()
        os.makedirs(output_dir, exist_ok=True)
        plt.figure(figsize=(10, 6))
        plt.stem(eig_np)
        plt.xlabel('Index')
        plt.ylabel('Eigenvalue')
        plt.title('Hessian Spectrum')
        plt.savefig(os.path.join(output_dir, 'hessian_spectrum.png'))
        plt.close()

        # Save eigenvalues
        np.save(os.path.join(output_dir, 'hessian_eigenvalues.npy'), eig_np)

    logger.info(f"Top-5 eigenvalues: {eigenvalues[:5].cpu().numpy()}")
    return eigenvalues


def _lanczos_tridiagonal(matvec, q0, max_iter, tol):
    """Run Lanczos with full reorthogonalization and return T's coefficients.

    Args:
        matvec: Callable mapping a flat vector v to A @ v, for symmetric A.
        q0: Unit-norm flat starting vector.
        max_iter: Maximum number of Lanczos steps.
        tol: Stop early once the residual norm drops below this value.

    Returns:
        (alphas, betas): lists of Python floats with the diagonal (length k)
        and off-diagonal (length k - 1) of the tridiagonal T_k, k <= max_iter.
    """
    basis = [q0]
    alphas, betas = [], []
    for i in range(max_iter):
        logger.info(f"Lanczos iteration {i+1}/{max_iter}")
        q = basis[-1]
        w = matvec(q)
        alpha = torch.dot(w, q)
        alphas.append(alpha.item())

        # Three-term recurrence: w = A q_i - alpha_i q_i - beta_{i-1} q_{i-1}
        w = w - alpha * q
        if i > 0:
            w = w - betas[-1] * basis[-2]

        # The last alpha completes T; no further Lanczos vector is needed.
        if i == max_iter - 1:
            break

        # Full reorthogonalization (modified Gram-Schmidt). Plain Lanczos
        # loses orthogonality in floating point and then produces spurious
        # or duplicated Ritz values.
        for b in basis:
            w = w - torch.dot(w, b) * b

        beta = w.norm()
        if beta.item() < tol:
            # Krylov subspace is (numerically) invariant: T is complete.
            logger.info(f"Lanczos converged at iteration {i+1}")
            break
        betas.append(beta.item())
        basis.append(w / beta)
    return alphas, betas