Using multiple optimizers: one SGD optimizer for an autoencoder (encoder + decoder) and a second SGD optimizer for another network that consumes the encoder's bottleneck output, with a backward hook scaling the gradient the second loss sends into the encoder.
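The key idiom here is Tensor.register_hook, which rescales the gradient flowing back through a tensor without touching the forward pass. A minimal, self-contained sketch of that idiom (the factor 0.5 and the toy tensors are illustrative, not taken from the gist):

import torch

x = torch.ones(2, requires_grad=True)
y = x * 3
# scale the gradient flowing back through y by 0.5
y.register_hook(lambda g: g * 0.5)
y.sum().backward()
print(x.grad)  # tensor([1.5000, 1.5000]) instead of tensor([3., 3.])

In the script below the same trick is applied to the encoder's bottleneck output, so the second network's loss contributes only a fraction (lambda_g) of its gradient to the encoder.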
import torch, torch.nn as nn
import torch.optim as optim

def print_grads(modules, string):
    print(string)
    for mod in modules:
        for p in mod.parameters():
            print(p.grad)
        print('**')
    print("-----")

def main():
    # model creation (encoder - decoder)
    enc = nn.Sequential(
        nn.Linear(3, 5),
        nn.ReLU(inplace=True)
    )
    dec = nn.Linear(5, 3)

    # other network
    othernet = nn.Sequential(
        nn.Linear(5, 4),
        nn.ReLU(inplace=True),
        nn.Linear(4, 2)
    )

    # define optimizers for autoencoder and other net
    autoencoder_optim = optim.SGD(list(enc.parameters()) + list(dec.parameters()), lr=0.0003)
    othernet_optim = optim.SGD(othernet.parameters(), lr=0.0002)

    # data for the network
    data = torch.randn(6, 3)

    # zero grad
    autoencoder_optim.zero_grad()
    othernet_optim.zero_grad()
    print_grads([enc, dec, othernet], "initial")

    # model forward
    bottleneck_out = enc(data)
    dec_out = dec(bottleneck_out)
    othernet_out = othernet(bottleneck_out)

    # calculate autoencoder loss
    autoencoder_loss = ((dec_out - data) ** 2).mean()
    autoencoder_loss.backward(retain_graph=True)
    print_grads([enc, dec, othernet], "after loss1 backward")

    # backward othernet loss
    othernet_loss = othernet_out.mean()  # a dummy loss

    # attach backward hook for bottleneck out
    lambda_g = 0.02  # ratio of othernet loss for encoder
    bottleneck_out.register_hook(lambda g: g * lambda_g)
    othernet_loss.backward()
    print_grads([enc, dec, othernet], "after loss2 backward")

    # step the optimizers if needed

if __name__ == '__main__':
    main()
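The script stops at the "step the optimizers if needed" comment without calling step(). A possible continuation inside main(), right after the second backward pass (not part of the original gist), would be:

    # update encoder + decoder using their accumulated gradients
    autoencoder_optim.step()
    # update the other network using its own gradients
    othernet_optim.step()

Because the two optimizers own disjoint parameter sets, each step() updates only its own network. Note the gradient flow in the script: after the first backward, othernet's gradients are still None; the second backward fills them in and adds the hook-scaled (lambda_g) contribution to the encoder's gradients, while the decoder's gradients are left unchanged since the decoder is not part of othernet_loss's graph.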