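# Demonstrates mixed-precision training with NVIDIA apex amp (opt_level 'O1')
# while part of the model (lin2, bn2) is frozen, then unfreezes lin2 and adds
# it to the optimizer with its own learning rate.
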
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from apex import amp


class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.lin1 = nn.Linear(10, 10)
        self.bn1 = nn.BatchNorm1d(10)
        self.lin2 = nn.Linear(10, 10)
        self.bn2 = nn.BatchNorm1d(10)

    def forward(self, x):
        x = F.relu(self.lin1(x))
        x = self.bn1(x)
        x = F.relu(self.lin2(x))
        x = self.bn2(x)
        return x


# Create model
device = 'cuda:0'
model = MyModel().to(device)

# Pass only lin1 and bn1 to optimizer
optimizer = optim.Adam([
    {'params': model.lin1.parameters()},
    {'params': model.bn1.parameters(), 'lr': 1e-0}
], lr=1e-1)
print(optimizer.param_groups)

# Freeze unused parameters
model.lin2.weight.requires_grad_(False)
model.lin2.bias.requires_grad_(False)
model.bn2.weight.requires_grad_(False)
model.bn2.bias.requires_grad_(False)
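# With requires_grad set to False, autograd will not populate .grad for these
# parameters, so they cannot be updated even if they were passed to an optimizer.
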
# Store reference parameters to compare later
state_dict = dict(model.named_parameters())
ref_state_dict = {k: state_dict[k].clone() for k in state_dict}

# Initialize amp
model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
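# Note: with opt_level='O1', apex patches selected functions to run in FP16
# while the model parameters themselves stay in FP32 (see the apex docs on
# opt levels for details).
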
# Create random data and criterion
data = torch.randn(10, 10, device=device)
target = torch.randn(10, 10, device=device)
criterion = nn.MSELoss()

# Train for some epochs
nb_epochs = 10
for epoch in range(nb_epochs):
    optimizer.zero_grad()
    output = model(data)
    loss = criterion(output, target)
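    # Scale the loss so small FP16 gradients don't underflow; apex restores
    # the unscaled gradients before optimizer.step() is called.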
    with amp.scale_loss(loss, optimizer) as scaled_loss:
        scaled_loss.backward()
    optimizer.step()
    print(f'Epoch {epoch}, loss {loss.item()}')

# Make sure the frozen layer did not accumulate gradients
print('Grad of frozen layer ', model.lin2.weight.grad)

# Get updated parameters and compare with reference (only lin1 and bn1 should be updated)
state_dict = dict(model.named_parameters())
update_state_dict = {k: state_dict[k].clone() for k in state_dict}
for k in ref_state_dict:
    print('Diff for {}: {}'.format(
        k, torch.abs(ref_state_dict[k] - update_state_dict[k]).mean()))

# Add lin2 to optimizer
optimizer.add_param_group({'params': model.lin2.parameters(), 'lr': 1e-2})
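# The newly added param group uses its own learning rate (1e-2) and is updated
# from the next optimizer.step() on, once its parameters require gradients again.
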
# Unfreeze lin2
model.lin2.weight.requires_grad_(True)
model.lin2.bias.requires_grad_(True)

# Train some more
nb_epochs = 10
for epoch in range(nb_epochs):
    optimizer.zero_grad()
    output = model(data)
    loss = criterion(output, target)
    with amp.scale_loss(loss, optimizer) as scaled_loss:
        scaled_loss.backward()
    optimizer.step()
    print(f'Epoch {epoch}, loss {loss.item()}')

# Get updated parameters and compare to reference again (lin2 should also be changed by now)
state_dict = dict(model.named_parameters())
update_state_dict = {k: state_dict[k].clone() for k in state_dict}
for k in ref_state_dict:
    print('Diff for {}: {}'.format(
        k, torch.abs(ref_state_dict[k] - update_state_dict[k]).mean()))
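
# --- Optional: the same workflow with native AMP (not part of the original) ---
# A minimal sketch, assuming PyTorch >= 1.6 with torch.cuda.amp available, that
# mirrors the partial-freezing setup above without apex. It reuses data, target,
# criterion and nb_epochs from above; the "_native"-suffixed names are
# introduced here purely for illustration.
scaler = torch.cuda.amp.GradScaler()

model_native = MyModel().to(device)
optimizer_native = optim.Adam([
    {'params': model_native.lin1.parameters()},
    {'params': model_native.bn1.parameters(), 'lr': 1e-0}
], lr=1e-1)

# Freeze lin2 and bn2 as before
model_native.lin2.weight.requires_grad_(False)
model_native.lin2.bias.requires_grad_(False)
model_native.bn2.weight.requires_grad_(False)
model_native.bn2.bias.requires_grad_(False)

for epoch in range(nb_epochs):
    optimizer_native.zero_grad()
    # Run the forward pass and loss computation in mixed precision
    with torch.cuda.amp.autocast():
        output = model_native(data)
        loss = criterion(output, target)
    # Scale the loss for backward; scaler.step unscales the gradients internally
    scaler.scale(loss).backward()
    scaler.step(optimizer_native)
    scaler.update()
    print(f'Native AMP epoch {epoch}, loss {loss.item()}')

# Unfreezing and adding lin2 to the optimizer later works the same way
optimizer_native.add_param_group({'params': model_native.lin2.parameters(), 'lr': 1e-2})
model_native.lin2.weight.requires_grad_(True)
model_native.lin2.bias.requires_grad_(True)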