import torch, torch.nn as nn, torch.nn.functional as F
import torch.optim as optim

# tied autoencoder using off-the-shelf nn modules
class TiedAutoEncoderOffTheShelf(nn.Module):
    def __init__(self, inp, out, weight):
        super().__init__()
        self.encoder = nn.Linear(inp, out, bias=False)
        self.decoder = nn.Linear(out, inp, bias=False)

        # tie the weights: both nn.Parameters wrap views of the same storage,
        # so encoder and decoder share one underlying weight tensor
        self.encoder.weight = nn.Parameter(weight)
        self.decoder.weight = nn.Parameter(weight.transpose(0, 1))

    def forward(self, input):
        encoded_feats = self.encoder(input)
        reconstructed_output = self.decoder(encoded_feats)
        return encoded_feats, reconstructed_output
# tied autoencoder using functional calls
class TiedAutoEncoderFunctional(nn.Module):
    def __init__(self, inp, out):
        super().__init__()
        self.param = nn.Parameter(torch.randn(out, inp))

    def forward(self, input):
        encoded_feats = F.linear(input, self.param)
        reconstructed_output = F.linear(encoded_feats, self.param.t())
        return encoded_feats, reconstructed_output
# mixed approach: nn.Linear encoder, functional decoder reusing its weight
class MixedApproachTiedAutoEncoder(nn.Module):
    def __init__(self, inp, out, weight):
        super().__init__()
        self.encoder = nn.Linear(inp, out, bias=False)
        self.encoder.weight = nn.Parameter(weight)

    def forward(self, input):
        encoded_feats = self.encoder(input)
        reconstructed_output = F.linear(encoded_feats, self.encoder.weight.t())
        return encoded_feats, reconstructed_output
if __name__ == '__main__':
    # instantiate functional auto-encoder
    tied_module_F = TiedAutoEncoderFunctional(5, 6)

    # instantiate off-the-shelf auto-encoder with the same initial weight
    offshelf_weight = tied_module_F.param.data.clone()
    tied_module_offshelf = TiedAutoEncoderOffTheShelf(5, 6, offshelf_weight)

    # instantiate mixed-type auto-encoder with the same initial weight
    mixed_weight = tied_module_F.param.data.clone()
    tied_module_mixed = MixedApproachTiedAutoEncoder(5, 6, mixed_weight)

    assert torch.equal(tied_module_offshelf.encoder.weight.data, tied_module_F.param.data), 'F vs offshelf: param not equal'
    assert torch.equal(tied_module_mixed.encoder.weight.data, tied_module_F.param.data), 'F vs mixed: param not equal'

    optim_F = optim.SGD(tied_module_F.parameters(), lr=1)
    optim_offshelf = optim.SGD(tied_module_offshelf.parameters(), lr=1)
    optim_mixed = optim.SGD(tied_module_mixed.parameters(), lr=1)
    # common input
    input = torch.rand(5, 5)

    # zero the gradients
    optim_F.zero_grad()
    optim_offshelf.zero_grad()
    optim_mixed.zero_grad()

    # get output from all three modules
    reconstruction_F = tied_module_F(input)
    reconstruction_offshelf = tied_module_offshelf(input)
    reconstruction_mixed = tied_module_mixed(input)

    # back propagation
    reconstruction_F[1].sum().backward()
    reconstruction_offshelf[1].sum().backward()
    reconstruction_mixed[1].sum().backward()

    # step
    optim_F.step()
    optim_offshelf.step()
    optim_mixed.step()
    # check the equality of outputs and parameters
    assert torch.equal(reconstruction_offshelf[0], reconstruction_F[0]), 'F vs offshelf: bottleneck not equal'
    assert torch.equal(reconstruction_offshelf[1], reconstruction_F[1]), 'F vs offshelf: output not equal'
    assert (tied_module_offshelf.encoder.weight.data - tied_module_F.param.data).pow(2).sum() < 1e-10, 'F vs offshelf: param after step not equal'
    assert (tied_module_offshelf.encoder.weight.data - offshelf_weight).pow(2).sum() < 1e-10, 'offshelf: source weight tensor not updated in place'

    assert torch.equal(reconstruction_mixed[0], reconstruction_F[0]), 'F vs mixed: bottleneck not equal'
    assert torch.equal(reconstruction_mixed[1], reconstruction_F[1]), 'F vs mixed: output not equal'
    assert (tied_module_mixed.encoder.weight.data - tied_module_F.param.data).pow(2).sum() < 1e-10, 'F vs mixed: param after step not equal'
    assert (tied_module_mixed.encoder.weight.data - mixed_weight).pow(2).sum() < 1e-10, 'mixed: source weight tensor not updated in place'

    print('success!')
To me, the tied auto-encoder with functional calls looks cleanest, since it avoids wrapping another layer's weight in nn.Parameter(another_layer.weight). Apart from that, I do not see any particular merit in the other approaches.
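For illustration, here is a minimal sketch of how the functional style extends to a deeper model; the StackedTiedAutoEncoder name and the layer sizes are made up for this example and are not part of the gist above:

import torch, torch.nn as nn, torch.nn.functional as F

# sketch: stacked tied autoencoder in the functional style; each layer owns
# one nn.Parameter, and the decoder reuses the same parameters transposed
class StackedTiedAutoEncoder(nn.Module):
    def __init__(self, sizes):
        super().__init__()
        # sizes = [input_dim, hidden1, hidden2, ...]
        self.weights = nn.ParameterList(
            nn.Parameter(torch.randn(o, i)) for i, o in zip(sizes, sizes[1:])
        )

    def forward(self, input):
        x = input
        for w in self.weights:            # encode
            x = F.linear(x, w)
        encoded_feats = x
        for w in reversed(self.weights):  # decode with the transposed weights
            x = F.linear(x, w.t())
        return encoded_feats, x

# usage: a 5 -> 6 -> 3 bottleneck, mirrored on the way back
enc, rec = StackedTiedAutoEncoder([5, 6, 3])(torch.rand(4, 5))

No cross-references between layers are needed; the decoder simply walks the same parameter list in reverse.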
Hello, @InnovArul. Thank you for this nice work! I am currently building an autoencoder for dimensionality reduction with a beginner's level of knowledge in PyTorch. Sorry if my question is very trivial, but can the same concept be applied to a non-linear model? I was thinking of setting requires_grad=False on the decoder layer so that the model trains only the encoder weights. Is this a correct approach?
Hi, sorry that I missed your message. I hope you have already found the answer. Just to answer your question: yes, in my understanding, calling decoder.requires_grad_(False) stops the decoder from contributing gradients to its weights, so the weights receive gradients only through the encoder.
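For concreteness, here is a minimal sketch of that idea on a non-linear variant of the off-the-shelf module above; the ReLU placement and the NonLinearTiedAutoEncoder name are my own additions for illustration:

import torch, torch.nn as nn

# sketch: non-linear tied autoencoder with the decoder weight frozen
class NonLinearTiedAutoEncoder(nn.Module):
    def __init__(self, inp, out, weight):
        super().__init__()
        self.encoder = nn.Linear(inp, out, bias=False)
        self.decoder = nn.Linear(out, inp, bias=False)
        self.encoder.weight = nn.Parameter(weight)
        self.decoder.weight = nn.Parameter(weight.transpose(0, 1))

    def forward(self, input):
        encoded_feats = torch.relu(self.encoder(input))  # non-linearity added
        return encoded_feats, self.decoder(encoded_feats)

model = NonLinearTiedAutoEncoder(5, 6, torch.randn(6, 5))
model.decoder.weight.requires_grad_(False)  # freeze the decoder parameter

_, reconstruction = model(torch.rand(4, 5))
reconstruction.sum().backward()

# the frozen decoder weight accumulates no gradient of its own, but gradients
# still flow *through* the decoder into the encoder weight (shared storage)
assert model.decoder.weight.grad is None
assert model.encoder.weight.grad is not None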
Yes, thank you. Actually, that's what I did. But maybe my question should have been more about whether there are any merits to using the other approaches you have listed here?