@albanD
Created September 25, 2019 21:03
Compute full Hessian of a network
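# Approach: build the Hessian row by row. The first call to autograd.grad uses
# create_graph=True so the gradients are themselves differentiable; a second
# call with a one-hot grad_outputs vector then gives one Hessian-vector
# product, i.e. one row of the full (n_params x n_params) matrix.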
import torch
from torch import nn
from torchviz import make_dot
from torch.autograd.gradcheck import gradcheck
torch.set_default_tensor_type(torch.DoubleTensor)
my_mod = nn.Sequential(
    nn.Linear(2, 2, bias=False), nn.Sigmoid(),
    nn.Linear(2, 2, bias=False), nn.Sigmoid(),
    nn.Linear(2, 1, bias=False),
)
params = list(my_mod.parameters())
print("params")
print(params)
inp = torch.rand(1, 2)
print("inp")
print(inp)
out = my_mod(inp)
print("out")
print(out)
# First-order gradients; create_graph=True makes them differentiable again.
J = torch.autograd.grad(out, params, create_graph=True)
print("J")
print(J)
n_params = 0
basis = []
for p in my_mod.parameters():
    n_params += p.nelement()
    basis.append(torch.zeros_like(p))
H = torch.zeros(n_params, n_params)
global_idx = 0
for t, grad in zip(basis, J):
    n_el = t.nelement()
    for i in range(n_el):
        # One-hot basis vector for this parameter entry: the double backward
        # below then returns one row of the Hessian. Reset the entry to zero
        # afterwards so the next iteration uses a clean basis vector.
        t.view(-1).select(0, i).fill_(1)
        gradgrad = torch.autograd.grad(grad, params, grad_outputs=t, retain_graph=True, allow_unused=True)
        t.view(-1).select(0, i).fill_(0)
        offset = 0
        for g, p in zip(gradgrad, params):
            if g is None:
                # Parameter unused in this derivative: leave its block at zero.
                offset += p.nelement()
                continue
            H.select(0, global_idx).narrow(0, offset, g.nelement()).copy_(g.view(-1))
            offset += g.nelement()
        global_idx += 1
# Hard to read
# print("H")
# print(H)
print("Diagonal terms of the Hessian for layers:")
offset = 0
for i, p in enumerate(params):
print("For the {}th layer".format(i))
n_el = p.nelement()
print(H.narrow(0, offset, n_el).narrow(1, offset, n_el))
offset += n_el
print("Extradiagonal terms of the Hessian:")
row_offset = 0
for p1_i, p1 in enumerate(params):
p1_nel = p1.nelement()
col_offset = 0
for p2_i, p2 in enumerate(params):
p2_nel = p2.nelement()
if p1 is not p2:
print("For the {}th layer wrt to the {}th layer".format(p1_i, p2_i))
print(H.narrow(0, col_offset, p2_nel).narrow(1, row_offset, p1_nel))
col_offset += p2_nel
row_offset += p1_nel
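
As an optional sanity check (not part of the original gist), the result can be compared against torch.autograd.functional.hessian, which was added to PyTorch after this gist was written (1.5+). The sketch below assumes that API is available; the helper loss_from_flat is introduced here only for illustration. It rebuilds the forward pass from a flattened copy of the weights so the network output becomes a scalar function of a single tensor.

# Optional cross-check against torch.autograd.functional.hessian (PyTorch >= 1.5).
# This helper is an illustration added here, not part of the original gist.
import torch.nn.functional as F

def loss_from_flat(flat):
    # Re-run the forward pass with weights sliced out of `flat`,
    # so the scalar output is a function of a single flat tensor.
    sizes = [p.nelement() for p in params]
    shapes = [p.shape for p in params]
    chunks = flat.split(sizes)
    x = inp
    w_idx = 0
    for layer in my_mod:
        if isinstance(layer, nn.Linear):
            x = F.linear(x, chunks[w_idx].view(shapes[w_idx]))
            w_idx += 1
        else:
            x = layer(x)
    return x.squeeze()

flat_params = torch.cat([p.detach().reshape(-1) for p in params])
H_ref = torch.autograd.functional.hessian(loss_from_flat, flat_params)
print("Matches torch.autograd.functional.hessian:", torch.allclose(H, H_ref))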