-
-
Save davidnvq/c547a72e358bb8baed5f5f04d11d530f to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import torch.nn as nn | |
import torch.optim as optim | |
import torch.nn.functional as F | |
# Seed the global RNG up front: the random input batch and targets drawn later
# in the script depend on it (the models are re-seeded separately before they
# are constructed).
torch.manual_seed(2809)
def check_params(modelA, modelB):
    """Compare every entry of ``modelA.state_dict()`` with the same key in ``modelB``.

    Prints one line per checked key and stops at the first mismatch, after
    printing ``ERROR!``. Keys that exist only in ``modelB`` are never visited.

    Args:
        modelA: Module whose ``state_dict()`` keys drive the comparison.
        modelB: Module that must provide an entry for each of those keys.

    Returns:
        ``True`` if every checked entry matched, ``False`` at the first mismatch.

    Raises:
        KeyError: If ``modelB.state_dict()`` lacks one of ``modelA``'s keys.
    """
    # Build each state dict once instead of rebuilding both for every key.
    state_a = modelA.state_dict()
    state_b = modelB.state_dict()
    for key, tensor_a in state_a.items():
        tensor_b = state_b[key]
        if tensor_a.shape == tensor_b.shape:
            is_equal = (tensor_a == tensor_b).all()
        else:
            # ``==`` broadcasts, so differently shaped tensors could wrongly
            # compare equal (or raise); a shape mismatch is always a mismatch.
            is_equal = torch.tensor(False)
        print('Checking {}, is equal = {}'.format(key, is_equal))
        if not is_equal:
            print('ERROR!')
            return False
    return True
def check_grads(modelA, modelB):
    """Compare the gradient of every parameter of ``modelA`` with ``modelB``'s.

    For each name yielded by ``modelA.named_parameters()`` the same dotted
    path is resolved on ``modelB`` and the two ``.grad`` values are compared.
    Prints one line per parameter and stops at the first mismatch, after
    printing ``ERROR!``. Parameters that exist only in ``modelB`` are never
    visited.

    Args:
        modelA: Module whose parameters drive the comparison.
        modelB: Module expected to expose a parameter under each of those names.

    Returns:
        ``True`` if every checked gradient matched, ``False`` at the first
        mismatch.

    Raises:
        AttributeError: If a dotted name cannot be resolved on ``modelB``.
    """
    for name, param_a in modelA.named_parameters():
        # Walk the dotted name one attribute at a time so parameters at any
        # nesting depth (including directly on the model) are resolved.
        target = modelB
        for attr in name.split('.'):
            target = getattr(target, attr)
        grad_a = param_a.grad
        grad_b = target.grad
        if grad_a is None or grad_b is None:
            # A missing gradient only matches another missing gradient.
            is_equal = grad_a is None and grad_b is None
        elif grad_a.shape != grad_b.shape:
            # Avoid a broadcasting comparison between mismatched shapes.
            is_equal = False
        else:
            is_equal = (grad_a == grad_b).all()
        print('Gradient for {} is equal {}'.format(name, is_equal))
        if not is_equal:
            print('ERROR!')
            return False
    return True
class MyModel(nn.Module):
    """Small CNN: two conv -> ReLU -> max-pool stages and a linear layer with 2 outputs."""

    def __init__(self):
        super().__init__()
        # 3x3 convolutions with stride 1 and padding 1 keep the spatial size;
        # each 2x2 max-pool halves it.
        self.conv1 = nn.Conv2d(3, 6, kernel_size=3, stride=1, padding=1)
        self.pool1 = nn.MaxPool2d(2)
        self.conv2 = nn.Conv2d(6, 12, kernel_size=3, stride=1, padding=1)
        self.pool2 = nn.MaxPool2d(2)
        # 12 channels * 6 * 6 positions: the flattened size for a 24x24 input
        # after the two pooling steps (24 -> 12 -> 6).
        self.fc = nn.Linear(12 * 6 * 6, 2)

    def forward(self, x):
        """Return the ``(batch, 2)`` output of the network for input batch ``x``."""
        x = F.relu(self.conv1(x))
        x = self.pool1(x)
        x = F.relu(self.conv2(x))
        x = self.pool2(x)
        # Flatten everything except the batch dimension for the linear layer.
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x
class MyModelUnused(nn.Module):
    """Same network as ``MyModel`` plus two convolutions that ``forward`` never calls.

    ``conv_unused1`` and ``conv_unused2`` are registered as submodules (so
    their parameters appear in ``parameters()`` and ``state_dict()``) but take
    no part in the forward pass.
    """

    def __init__(self):
        super().__init__()
        # 3x3 convolutions with stride 1 and padding 1 keep the spatial size;
        # each 2x2 max-pool halves it.
        self.conv1 = nn.Conv2d(3, 6, kernel_size=3, stride=1, padding=1)
        self.pool1 = nn.MaxPool2d(2)
        self.conv2 = nn.Conv2d(6, 12, kernel_size=3, stride=1, padding=1)
        self.pool2 = nn.MaxPool2d(2)
        # 12 channels * 6 * 6 positions: the flattened size for a 24x24 input
        # after the two pooling steps (24 -> 12 -> 6).
        self.fc = nn.Linear(12 * 6 * 6, 2)
        # Deliberately unused: created after the layers above and never
        # referenced in forward().
        self.conv_unused1 = nn.Conv2d(12, 24, kernel_size=3, stride=1, padding=1)
        self.conv_unused2 = nn.Conv2d(24, 12, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        """Return the ``(batch, 2)`` output; the unused convolutions are skipped."""
        x = F.relu(self.conv1(x))
        x = self.pool1(x)
        x = F.relu(self.conv2(x))
        x = self.pool2(x)
        # Flatten everything except the batch dimension for the linear layer.
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x
# Fixed random batch: 10 samples of shape (3, 24, 24) with labels drawn from {0, 1}.
x = torch.randn(10, 3, 24, 24)
target = torch.empty(10, dtype=torch.long).random_(2)
criterion = nn.CrossEntropyLoss()

# Re-seed before building each model so both constructions start from the
# same RNG state; check_params below verifies that the common entries match.
torch.manual_seed(2809)
modelA = MyModel()
torch.manual_seed(2809)
modelB = MyModelUnused()

# Check weights for equality
check_params(modelA, modelB)

optimizerA = optim.Adam(modelA.parameters(), lr=1e-3)
optimizerB = optim.Adam(modelB.parameters(), lr=1e-3)

# Train both models in lockstep on the same batch, comparing parameters,
# outputs, losses and gradients at every step.
for epoch in range(10):
    print('Checking epoch {}'.format(epoch))
    optimizerA.zero_grad()
    optimizerB.zero_grad()

    check_params(modelA, modelB)

    outputA = modelA(x)
    outputB = modelB(x)
    # Report the comparison; previously the result was computed and discarded.
    print('Outputs are equal = {}'.format((outputA == outputB).all()))

    lossA = criterion(outputA, target)
    lossB = criterion(outputB, target)
    # Same here: the loss comparison used to be a no-op expression statement.
    print('Losses are equal = {}'.format((lossA == lossB).all()))

    lossA.backward()
    lossB.backward()
    check_grads(modelA, modelB)

    optimizerA.step()
    optimizerB.step()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment