# SWA (torchcontrib) + NVIDIA apex amp mixed-precision example
# Gist by @ptrblck, created May 14, 2019
# https://gist.github.com/ptrblck/e9a7e11384cbd92afea32c95a4e23a73
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

from apex import amp
from torchcontrib.optim import SWA

torch.manual_seed(2809)
device = 'cuda'
class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 3, 1, 1)
        self.bn1 = nn.BatchNorm2d(6)
        self.pool1 = nn.MaxPool2d(2)
        self.conv2 = nn.Conv2d(6, 12, 3, 1, 1)
        self.bn2 = nn.BatchNorm2d(12)
        self.pool2 = nn.MaxPool2d(2)
        self.lin1 = nn.Linear(12 * 8 * 8, 32)
        self.lin2 = nn.Linear(32, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.bn1(x)
        x = self.pool1(x)
        x = F.relu(self.conv2(x))
        x = self.bn2(x)
        x = self.pool2(x)
        x = x.view(x.size(0), -1)  # flatten to (N, 12*8*8)
        x = F.relu(self.lin1(x))
        x = self.lin2(x)
        return x
model = MyModel().to(device)

# Freeze the affine parameters of both BatchNorm layers; their running
# statistics are still updated during training.
model.bn1.weight.requires_grad_(False)
model.bn1.bias.requires_grad_(False)
model.bn2.weight.requires_grad_(False)
model.bn2.bias.requires_grad_(False)
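
# Illustrative sanity check (not in the original gist): list which parameters
# will still receive gradients after freezing the BatchNorm affine terms.
for name, param in model.named_parameters():
    print(name, param.requires_grad)  # bn*.weight / bn*.bias should be False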
# Tiny random dataset; since nb_samples < batch_size, the loader yields a
# single batch of 10 samples per epoch.
nb_samples = 10
dataset = TensorDataset(
    torch.randn(nb_samples, 3, 32, 32),
    torch.randint(0, 10, (nb_samples,))
)

bs = 16
loader = DataLoader(
    dataset,
    batch_size=bs,
    num_workers=2,
    shuffle=False,
    pin_memory=False
)
optimizer = optim.SGD(model.parameters(), lr=1e-1)
# Wrap the base optimizer in SWA: averaging starts after step 1 and a
# snapshot is taken every step.
optimizer = SWA(optimizer, swa_start=1, swa_freq=1, swa_lr=0.05)

# Patch model and optimizer for mixed-precision training.
model, optimizer = amp.initialize(model, optimizer, opt_level='O1')

criterion = nn.CrossEntropyLoss()
nb_epochs = 10
for epoch in range(nb_epochs):
    for data, target in loader:
        old_param = model.lin2.weight.clone()
        data, target = data.to(device), target.to(device)

        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        # Scale the loss through amp so the gradients match the loss scale.
        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()
        optimizer.step()

    print('Epoch {}, loss {}'.format(epoch, loss.item()))
    # print('Param diff {}'.format(torch.abs(old_param - model.lin2.weight).mean()))

# Print lin2.weight before and after swapping in the SWA running averages to
# verify that swap_swa_sgd actually replaces the trained parameters.
print(model.lin2.weight)
optimizer.swap_swa_sgd()
print(model.lin2.weight)
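
# Optional follow-up (not in the original gist): the BatchNorm running stats
# above were accumulated with the SGD iterates, not the averaged weights.
# torchcontrib's SWA provides bn_update to recompute them with a forward pass
# over the data. A minimal sketch, assuming the torchcontrib signature
# bn_update(loader, model, device=None); check your installed version.
optimizer.bn_update(loader, model, device=device)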