import torch, torch.nn as nn, torch.nn.functional as F
import numpy as np
import torch.optim as optim

# tied autoencoder using off-the-shelf nn modules
class TiedAutoEncoderOffTheShelf(nn.Module):
    def __init__(self, inp, out, weight):
        super().__init__()
        self.encoder = nn.Linear(inp, out, bias=False)
        self.decoder = nn.Linear(out, inp, bias=False)

        # tie the weights: nn.Parameter shares the underlying memory with the given tensor,
        # so encoder and decoder use the same storage
        self.encoder.weight = nn.Parameter(weight)
        self.decoder.weight = nn.Parameter(self.encoder.weight.transpose(0, 1))

    def forward(self, input):
        encoded_feats = self.encoder(input)
        reconstructed_output = self.decoder(encoded_feats)
        return encoded_feats, reconstructed_output

# tied autoencoder using functional calls
class TiedAutoEncoderFunctional(nn.Module):
    def __init__(self, inp, out):
        super().__init__()
        self.param = nn.Parameter(torch.randn(out, inp))

    def forward(self, input):
        encoded_feats = F.linear(input, self.param)
        reconstructed_output = F.linear(encoded_feats, self.param.t())
        return encoded_feats, reconstructed_output

# mixed approach: nn.Linear for the encoder, functional call for the decoder
class MixedApproachTiedAutoEncoder(nn.Module):
    def __init__(self, inp, out, weight):
        super().__init__()
        self.encoder = nn.Linear(inp, out, bias=False)
        self.encoder.weight = nn.Parameter(weight)

    def forward(self, input):
        encoded_feats = self.encoder(input)
        reconstructed_output = F.linear(encoded_feats, self.encoder.weight.t())
        return encoded_feats, reconstructed_output

if __name__ == '__main__':
    tied_module_F = TiedAutoEncoderFunctional(5, 6)

    # instantiate off-the-shelf autoencoder with the same initial weight
    offshelf_weight = tied_module_F.param.data.clone()
    tied_module_offshelf = TiedAutoEncoderOffTheShelf(5, 6, offshelf_weight)

    # instantiate mixed-type autoencoder with the same initial weight
    mixed_weight = tied_module_F.param.data.clone()
    tied_module_mixed = MixedApproachTiedAutoEncoder(5, 6, mixed_weight)

    assert torch.equal(tied_module_offshelf.encoder.weight.data, tied_module_F.param.data), 'F vs offshelf: param not equal'
    assert torch.equal(tied_module_mixed.encoder.weight.data, tied_module_F.param.data), 'F vs mixed: param not equal'

    optim_F = optim.SGD(tied_module_F.parameters(), lr=1)
    optim_offshelf = optim.SGD(tied_module_offshelf.parameters(), lr=1)
    optim_mixed = optim.SGD(tied_module_mixed.parameters(), lr=1)

    # common input
    input = torch.rand(5, 5)

    # zero the gradients
    optim_F.zero_grad()
    optim_offshelf.zero_grad()
    optim_mixed.zero_grad()

    # get output from all three modules
    reconstruction_F = tied_module_F(input)
    reconstruction_offshelf = tied_module_offshelf(input)
    reconstruction_mixed = tied_module_mixed(input)

    # back propagation
    reconstruction_F[1].sum().backward()
    reconstruction_offshelf[1].sum().backward()
    reconstruction_mixed[1].sum().backward()

    # step
    optim_F.step()
    optim_offshelf.step()
    optim_mixed.step()

    # check the equality of outputs and parameters
    assert torch.equal(reconstruction_offshelf[0], reconstruction_F[0]), 'F vs offshelf: bottleneck not equal'
    assert torch.equal(reconstruction_offshelf[1], reconstruction_F[1]), 'F vs offshelf: output not equal'
    assert (tied_module_offshelf.encoder.weight.data - tied_module_F.param.data).pow(2).sum() < 1e-10, 'F vs offshelf: param after step not equal'
    assert (tied_module_offshelf.encoder.weight.data - offshelf_weight).pow(2).sum() < 1e-10, 'F vs offshelf: source weight tensor not equal'

    assert torch.equal(reconstruction_mixed[0], reconstruction_F[0]), 'F vs mixed: bottleneck not equal'
    assert torch.equal(reconstruction_mixed[1], reconstruction_F[1]), 'F vs mixed: output not equal'
    assert (tied_module_mixed.encoder.weight.data - tied_module_F.param.data).pow(2).sum() < 1e-10, 'F vs mixed: param after step not equal'
    assert (tied_module_mixed.encoder.weight.data - mixed_weight).pow(2).sum() < 1e-10, 'F vs mixed: source weight tensor not equal'

    print('success!')
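The script above only verifies that the three formulations stay numerically identical after one SGD step. As a minimal usage sketch (not part of the original gist; the batch size, learning rate, and epoch count are arbitrary, and it reuses the imports and classes defined above), training one of the tied autoencoders with a reconstruction loss could look like this:

model = TiedAutoEncoderFunctional(5, 6)
optimizer = optim.SGD(model.parameters(), lr=0.01)
data = torch.rand(32, 5)                      # hypothetical batch of 32 samples, 5 features each
for epoch in range(100):
    optimizer.zero_grad()
    _, reconstruction = model(data)
    loss = F.mse_loss(reconstruction, data)   # reconstruction objective
    loss.backward()
    optimizer.step()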
It seems L14 has to be changed as follows for PyTorch 1.7.0 to work:
self.decoder.weight = nn.Parameter(self.encoder.weight.transpose(0,1))
Thanks for letting me know. I have updated the gist to reflect it.
Note for myself: nn.Parameter shares the underlying memory with the given torch.Tensor.
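A tiny sketch of what that note means (assumed shapes, not part of the gist; reuses the imports from the gist above):

w = torch.randn(6, 5)
p = nn.Parameter(w)
p.data.mul_(2)                     # in-place change through the Parameter ...
assert torch.equal(p.data, w)      # ... is visible in the original tensor: same storage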
Hello, @InnovArul Thank you for this nice work. But may I ask why we can't just use this approach?
Ideally, decoder.weight is the transpose of encoder.weight. So, we need decoder.weight = encoder.weight.t(). However, this will throw the following error: cannot assign 'torch.FloatTensor' as parameter 'weight' (torch.nn.Parameter or None expected). To overcome that, we need to wrap it with nn.Parameter().
Does this answer your question?
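For illustration, a minimal standalone sketch of that point (hypothetical layer sizes, not part of the gist; reuses the imports from the gist above):

enc = nn.Linear(5, 6, bias=False)
dec = nn.Linear(6, 5, bias=False)
try:
    dec.weight = enc.weight.t()                # assigning a plain tensor to a parameter slot
except TypeError as err:
    print(err)                                 # cannot assign 'torch.FloatTensor' as parameter 'weight' ...
dec.weight = nn.Parameter(enc.weight.t())      # wrapping it in nn.Parameter() is accepted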
Yes, thank you. Actually, that's what I did. But maybe my question should have been: are there any merits to using the other approaches you have listed here?
To me, the tied autoencoder with functional calls looks clean without involving nn.Parameter(another_layer.weight). Apart from that, I do not see any particular merits in the other approaches.
Hello, @InnovArul Thank you for this nice work! I am currently building an autoencoder for dimensionality reduction with a beginner level of knowledge in PyTorch. Sorry if my question is very trivial, but can the same concept be applied to a non-linear model? I was thinking of setting requires_grad=False on the decoder layer so that the model trains the weights of the encoder only. Is this a correct approach?
Hi, sorry that I missed your message. I hope you have already found the answer.
Just to answer your question: yes, in my understanding, setting decoder.requires_grad_(False) would prevent the decoder's gradient from being added to the weights, and it will let the weights receive gradients only from the encoder.
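As a rough sketch of that idea (an untied, non-linear autoencoder with hypothetical layer sizes; reuses the imports from the gist above), freezing the decoder so that only the encoder is trained could look like this:

class NonLinearAutoEncoder(nn.Module):
    def __init__(self, inp, bottleneck):
        super().__init__()
        self.encoder = nn.Sequential(nn.Linear(inp, bottleneck), nn.ReLU())
        self.decoder = nn.Linear(bottleneck, inp)

    def forward(self, input):
        encoded_feats = self.encoder(input)
        return encoded_feats, self.decoder(encoded_feats)

model = NonLinearAutoEncoder(5, 3)
model.decoder.requires_grad_(False)           # decoder weights stop accumulating gradients
optimizer = optim.SGD([p for p in model.parameters() if p.requires_grad], lr=0.01)

data = torch.rand(8, 5)
optimizer.zero_grad()
_, reconstruction = model(data)
F.mse_loss(reconstruction, data).backward()   # gradients still flow through the frozen decoder
optimizer.step()                              # only the encoder parameters are updated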