import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

use_adam = False
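# The flag above selects the optimizer for the shared-optimizer experiment below.
# Plain SGD (without momentum or weight decay) only changes parameters with
# non-zero gradients, while stateful optimizers such as Adam keep running
# averages per parameter and can behave differently once a parameter has
# accumulated optimizer state.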
class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.enc = nn.Linear(64, 10)
        self.dec1 = nn.Linear(10, 64)
        self.dec2 = nn.Linear(10, 64)

    def forward(self, x, decoder_idx):
        x = F.relu(self.enc(x))
        if decoder_idx == 1:
            print('Using dec1')
            x = self.dec1(x)
        elif decoder_idx == 2:
            print('Using dec2')
            x = self.dec2(x)
        else:
            print('Unknown decoder_idx')
        return x
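# MyModel uses a single shared encoder and two independent decoders; forward()
# selects the decoder via decoder_idx. The first experiment below trains the
# model with one optimizer over all parameters and checks, after each step,
# which parameters received gradients and which were actually updated.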
x = torch.randn(1, 64)
y = x.clone()

model = MyModel()
criterion = nn.MSELoss()
if use_adam:
    optimizer = optim.Adam(model.parameters(), lr=1.)
else:
    optimizer = optim.SGD(model.parameters(), lr=1.)

# Save init values
old_state_dict = {}
for key in model.state_dict():
    old_state_dict[key] = model.state_dict()[key].clone()

# Training procedure
optimizer.zero_grad()
output = model(x, 1)
loss = criterion(output, y)
loss.backward()

# Check for gradients in dec1, dec2
print('Dec1 grad: {}\nDec2 grad: {}'.format(
    model.dec1.weight.grad, model.dec2.weight.grad))
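# dec2 was not used in this forward pass, so dec2.weight.grad is expected to
# still be None here, while the encoder and dec1 received gradients.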
optimizer.step()

# Save new params
new_state_dict = {}
for key in model.state_dict():
    new_state_dict[key] = model.state_dict()[key].clone()

# Compare params
for key in old_state_dict:
    if not (old_state_dict[key] == new_state_dict[key]).all():
        print('Diff in {}'.format(key))

# Update
old_state_dict = {}
for key in model.state_dict():
    old_state_dict[key] = model.state_dict()[key].clone()

# Pass through dec2
optimizer.zero_grad()
output = model(x, 2)
loss = criterion(output, y)
loss.backward()
print('Dec1 grad: {}\nDec2 grad: {}'.format(
    model.dec1.weight.grad, model.dec2.weight.grad))
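# optimizer.zero_grad() cleared dec1's gradients from the previous step
# (zeroed or set to None, depending on the PyTorch version), so only the
# encoder and dec2 carry fresh gradients for this step.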
optimizer.step()

# Save new params
new_state_dict = {}
for key in model.state_dict():
    new_state_dict[key] = model.state_dict()[key].clone()

# Compare params
for key in old_state_dict:
    if not (old_state_dict[key] == new_state_dict[key]).all():
        print('Diff in {}'.format(key))

## Create separate optimizers
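# Second experiment: one optimizer per decoder. Each optimizer owns the shared
# encoder parameters plus the parameters of one decoder, so stepping an
# optimizer only ever touches its own parameter group.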
model = MyModel()
dec1_params = list(model.enc.parameters()) + list(model.dec1.parameters())
optimizer1 = optim.Adam(dec1_params, lr=1.)
dec2_params = list(model.enc.parameters()) + list(model.dec2.parameters())
optimizer2 = optim.Adam(dec2_params, lr=1.)

# Save init values
old_state_dict = {}
for key in model.state_dict():
    old_state_dict[key] = model.state_dict()[key].clone()

# Training procedure
optimizer1.zero_grad()
output = model(x, 1)
loss = criterion(output, y)
loss.backward()

# Check for gradients in dec1, dec2
print('Dec1 grad: {}\nDec2 grad: {}'.format(
    model.dec1.weight.grad, model.dec2.weight.grad))
optimizer1.step()

# Save new params
new_state_dict = {}
for key in model.state_dict():
    new_state_dict[key] = model.state_dict()[key].clone()

# Compare params
for key in old_state_dict:
    if not (old_state_dict[key] == new_state_dict[key]).all():
        print('Diff in {}'.format(key))

# Update
old_state_dict = {}
for key in model.state_dict():
    old_state_dict[key] = model.state_dict()[key].clone()

# Pass through dec2
optimizer1.zero_grad()
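# Note: zeroing optimizer1 here clears the leftover gradients of the shared
# encoder (and dec1) from the previous step; dec2's gradients are still None
# at this point, since it has not been used yet with this model instance.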
output = model(x, 2)
loss = criterion(output, y)
loss.backward()
print('Dec1 grad: {}\nDec2 grad: {}'.format(
    model.dec1.weight.grad, model.dec2.weight.grad))
optimizer2.step()

# Save new params
new_state_dict = {}
for key in model.state_dict():
    new_state_dict[key] = model.state_dict()[key].clone()

# Compare params
for key in old_state_dict:
    if not (old_state_dict[key] == new_state_dict[key]).all():
        print('Diff in {}'.format(key))
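# In this two-step script, both setups should report a diff in the encoder
# parameters on every step (the encoder always receives gradients) and in a
# decoder only on the step where it was used. Over longer runs, a stateful
# optimizer such as Adam may keep updating a decoder from its running averages
# on steps where that decoder's gradients are zero (rather than None), which is
# worth keeping in mind when sharing one stateful optimizer across decoders.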