Skip to content

Instantly share code, notes, and snippets.

@zomux
Created April 9, 2020 21:57
Show Gist options
  • Save zomux/1e4fc0f4cd5df629bfa318d77e8e7523 to your computer and use it in GitHub Desktop.
import torch
import torch.nn as nn
from torch.autograd import Variable, grad
import torch.utils.data as Data
import torchvision
# MNIST training split: 28x28 grayscale digits, scaled to [0, 1] tensors.
train_data = torchvision.datasets.MNIST(
    root='./mnist/',
    train=True,
    transform=torchvision.transforms.ToTensor(),
    # Fetch the dataset only when it is not already present under ./mnist/;
    # the original hard-coded download=False, which raises on a fresh checkout.
    download=True,
)
# Mini-batches of 50, reshuffled every epoch; 2 worker processes for loading.
train_loader = Data.DataLoader(
    dataset=train_data, batch_size=50, shuffle=True, num_workers=2)
class CNN(nn.Module):
    """Minimal conv net for MNIST: one 5x5 conv layer plus a linear readout.

    forward() returns both the 10-way logits and the flattened conv features.
    """

    def __init__(self):
        super(CNN, self).__init__()
        # 1 input channel -> 16 feature maps; stride 1, padding 2 keeps 28x28.
        self.conv1 = nn.Sequential(nn.Conv2d(1, 16, 5, 1, 2))
        # Flattened feature map (16 * 28 * 28) -> 10 class logits.
        self.out = nn.Linear(16 * 28 * 28, 10)

    def forward(self, x):
        """Return (logits, flattened_features) for a batch of images."""
        features = self.conv1(x)
        flat = features.view(features.size(0), -1)
        logits = self.out(flat)
        return logits, flat
# --- Double-backward demo on a single batch -------------------------------
# Computes per-parameter gradients with create_graph=True, then differentiates
# their sum again: d(sum_i g_i)/dw is the Hessian applied to an all-ones vector.
cnn = CNN()
# Run on GPU when one is available, otherwise fall back to CPU.
# (The original hard-coded .cuda() and crashed on CPU-only hosts.)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
cnn.to(device)
loss_func = nn.CrossEntropyLoss()

for step, (data, label) in enumerate(train_loader):
    # Variable is deprecated since torch 0.4: plain tensors carry autograd.
    # Renamed from `input`, which shadowed the builtin.
    inputs = data.to(device)
    target = label.to(device)

    output = cnn(inputs)[0]  # logits only; forward also returns flat features
    loss = loss_func(output, target)

    # parameters() is a one-shot generator; materialize once so the same
    # parameter list feeds both grad() calls.
    params = list(cnn.parameters())

    # First backward: gradients kept in the graph so we can differentiate
    # through them a second time.
    g = grad(loss, params, create_graph=True)

    # Scalar surrogate for a Hessian-vector product with a ones vector.
    g_sum = sum(g_para.sum() for g_para in g)

    # Second backward through the gradient graph.
    hv = grad(g_sum, params, create_graph=True)
    break  # one batch is enough for this demo
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment