Skip to content

Instantly share code, notes, and snippets.

@ptrblck
Created May 15, 2019 17:59
Show Gist options
  • Save ptrblck/c9d7de2dc21b08fe63895250eac83531 to your computer and use it in GitHub Desktop.
Save ptrblck/c9d7de2dc21b08fe63895250eac83531 to your computer and use it in GitHub Desktop.
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from apex import amp
class MyModel(nn.Module):
    """Two stacked Linear -> ReLU -> BatchNorm1d stages, each 10 features wide.

    The four layers are exposed as separate attributes (``lin1``, ``bn1``,
    ``lin2``, ``bn2``) rather than wrapped in a container, so each one can be
    addressed individually (e.g. frozen or given its own optimizer group).
    """

    def __init__(self):
        super().__init__()
        width = 10
        # Registration order determines both named_parameters() order and the
        # order in which initialization consumes the RNG: lin1, bn1, lin2, bn2.
        self.lin1 = nn.Linear(width, width)
        self.bn1 = nn.BatchNorm1d(width)
        self.lin2 = nn.Linear(width, width)
        self.bn2 = nn.BatchNorm1d(width)

    def forward(self, x):
        """Apply both stages in turn: ``norm(relu(linear(x)))``."""
        stages = ((self.lin1, self.bn1), (self.lin2, self.bn2))
        for linear, norm in stages:
            x = norm(F.relu(linear(x)))
        return x
# Create model
device = 'cuda:0'
model = MyModel().to(device)

# Pass only lin1 and bn1 to the optimizer; bn1 gets its own, larger learning
# rate (1e-0) while lin1 uses the default lr of 1e-1.
optimizer = optim.Adam([
    {'params': model.lin1.parameters()},
    {'params': model.bn1.parameters(), 'lr': 1e-0},
], lr=1e-1)
print(optimizer.param_groups)

# Freeze the layers the optimizer does not (yet) know about, so autograd
# computes no gradients for their weight and bias.
for frozen_module in (model.lin2, model.bn2):
    for param in frozen_module.parameters():
        param.requires_grad_(False)

# Snapshot all parameters to compare against later. detach() keeps the
# snapshot out of the autograd graph: a plain clone() of a parameter that
# requires grad would carry a grad_fn and drag autograd into every later diff.
ref_state_dict = {
    name: param.detach().clone() for name, param in model.named_parameters()
}

# Initialize amp
model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
def _train_epochs(model, optimizer, criterion, data, target, nb_epochs):
    """Run ``nb_epochs`` full-batch amp training steps on ``(data, target)``.

    Prints the unscaled loss after every step.
    """
    for epoch in range(nb_epochs):
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        # Backward goes through amp.scale_loss so apex can apply its loss scaling.
        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()
        optimizer.step()
        print(f'Epoch {epoch}, loss {loss.item()}')


def _print_param_diffs(model, reference):
    """Print the mean absolute change of each parameter relative to ``reference``.

    ``reference`` maps names as yielded by ``model.named_parameters()`` to
    tensors captured earlier. The comparison runs under ``no_grad`` so no
    autograd graph is built, and plain floats are printed.
    """
    current = dict(model.named_parameters())
    with torch.no_grad():
        for name, ref in reference.items():
            diff = torch.abs(ref - current[name]).mean()
            print(f'Diff for {name}: {diff.item()}')


# Create random data and criterion
data = torch.randn(10, 10, device=device)
target = torch.randn(10, 10, device=device)
criterion = nn.MSELoss()

# Phase 1: train with lin2/bn2 frozen
nb_epochs = 10
_train_epochs(model, optimizer, criterion, data, target, nb_epochs)

# Make sure the frozen layer did not accumulate gradients (expected: None)
print('Grad of freezed layer ', model.lin2.weight.grad)

# Compare with the reference (only lin1 and bn1 should be updated)
_print_param_diffs(model, ref_state_dict)

# Phase 2: add lin2 to the optimizer with its own learning rate, then unfreeze it
optimizer.add_param_group({'params': model.lin2.parameters(), 'lr': 1e-2})
for param in model.lin2.parameters():
    param.requires_grad_(True)

# Train some more
_train_epochs(model, optimizer, criterion, data, target, nb_epochs)

# Compare to the reference again (lin2 should also be changed by now; bn2 stays frozen)
_print_param_diffs(model, ref_state_dict)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment