Skip to content

Instantly share code, notes, and snippets.

@InnovArul
Created September 30, 2018 00:10
Show Gist options
  • Save InnovArul/d66a9e2bf56507dd0803135fb2d910d4 to your computer and use it in GitHub Desktop.
Save InnovArul/d66a9e2bf56507dd0803135fb2d910d4 to your computer and use it in GitHub Desktop.
To validate if the shared module weight is still shared while saving and loading state_dict()
import torch
import torch.nn as nn
class SubModule(nn.Module):
def __init__(self, embedding):
super(SubModule, self).__init__()
self.embedding = embedding
self.fc = nn.Linear(200, 200)
def forward(self, input):
return self.fc(self.embedding(input))
def change_embed_weights(self, num):
nn.init.constant_(self.embedding.weight, num)
class Model(nn.Module):
def __init__(self):
super(Model,self).__init__()
self.embedding = nn.Embedding(5, 5)
self.net_a = SubModule(self.embedding)
self.net_b = SubModule(self.embedding)
def forward(self, input):
return self.net_a(input) + self.net_b(input)
def is_embed_weights_equal(self):
return torch.all((self.net_a.embedding.weight-self.net_b.embedding.weight) == 0)
def print_embed_weights(self):
print(self.net_a.embedding.weight)
print(self.net_b.embedding.weight)
m = Model()
print('named params')
for n, p in m.named_parameters():
print(n, p.shape, p.data_ptr)
print('params')
for p in m.parameters():
print(p.shape, p.data_ptr)
print('state dict')
for n, p in m.state_dict().items():
print(n, p.shape, p.data_ptr)
# TO CHANGE WEIGHTS AND CHECK EQUALITY
# before saving
m.net_a.change_embed_weights(1)
m.print_embed_weights()
print(m.is_embed_weights_equal())
# save the model
torch.save(m.state_dict(), 'sample.pth')
new_m = Model()
new_m.load_state_dict(torch.load('sample.pth'))
# before saving
new_m.net_b.change_embed_weights(2)
new_m.print_embed_weights()
print(new_m.is_embed_weights_equal())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment