Last active
December 6, 2023 04:02
-
-
Save nrimsky/15ed7eb69e330506bb8623e719e94784 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch as t | |
from torch.autograd import grad | |
from scipy.sparse.linalg import LinearOperator, eigsh | |
import numpy as np | |
def get_hessian_eigenvectors(model, loss_fn, train_data_loader, num_batches, device, n_top_vectors, param_extract_fn):
    """
    Compute the top eigenvalues/eigenvectors of the loss Hessian using
    scipy's Lanczos solver (eigsh) driven by Hessian-vector products.

    model: a pytorch model
    loss_fn: a pytorch loss function, called as loss_fn(model_output, labels)
    train_data_loader: iterable yielding (inputs, labels) batch tuples
    num_batches: number of batches to use for the hessian calculation
    device: the device to use for the hessian calculation
    n_top_vectors: number of top eigenvalues / eigenvectors to return
        (must satisfy n_top_vectors < num_params — an eigsh requirement)
    param_extract_fn: a function that takes a model and returns an iterable of
        parameters to compute the hessian with respect to
        (pass None to use all parameters)
    returns: a tuple of (eigenvalues, eigenvectors)
        eigenvalues: numpy array of the top (largest-magnitude) eigenvalues,
            arranged in increasing order
        eigenvectors: numpy array, shape (n_top_vectors, num_params),
            rows ordered to match eigenvalues
    """
    param_extract_fn = param_extract_fn or (lambda m: m.parameters())
    # Materialize once: .parameters() returns a one-shot generator, and the
    # same parameter list is needed for both differentiation passes below.
    params = list(param_extract_fn(model))
    num_params = sum(p.numel() for p in params)

    # Collect a fixed data subset so every Hessian-vector product evaluates
    # the same loss surface.
    subset_images, subset_labels = [], []
    for batch_idx, (images, labels) in enumerate(train_data_loader):
        if batch_idx >= num_batches:
            break
        subset_images.append(images.to(device))
        subset_labels.append(labels.to(device))
    subset_images = t.cat(subset_images)
    subset_labels = t.cat(subset_labels)

    # Build the first-order gradient graph ONCE (this is the expensive
    # forward + backward over the whole subset). create_graph=True makes the
    # gradient itself differentiable; retain_graph=True in each HVP call
    # below keeps this graph alive across the many matvecs eigsh performs,
    # instead of redoing the forward/backward on every Lanczos iteration.
    model.zero_grad()
    loss = loss_fn(model(subset_images), subset_labels)
    grad_params = grad(loss, params, create_graph=True)
    flat_grad = t.cat([g.view(-1) for g in grad_params])

    def hessian_vector_product(vector):
        # d(g . v)/dw = H v — second differentiation through the retained graph.
        grad_vector_product = t.sum(flat_grad * vector)
        hvp = grad(grad_vector_product, params, retain_graph=True)
        return t.cat([g.contiguous().view(-1) for g in hvp])

    def matvec(v):
        # eigsh supplies float64 numpy vectors; convert to the parameters'
        # dtype (works for float32 and float64 models alike).
        v_tensor = t.tensor(v, dtype=flat_grad.dtype, device=device)
        return hessian_vector_product(v_tensor).cpu().detach().numpy()

    linear_operator = LinearOperator((num_params, num_params), matvec=matvec)
    eigenvalues, eigenvectors = eigsh(linear_operator, k=n_top_vectors, tol=0.001, which='LM', return_eigenvectors=True)
    # eigsh returns eigenvectors as columns; transpose to one row per vector.
    eigenvectors = np.transpose(eigenvectors)
    return eigenvalues, eigenvectors
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment