Skip to content

Instantly share code, notes, and snippets.

@r9y9
Created June 21, 2019 02:07
Show Gist options
  • Save r9y9/87a2b54a68daadfbe959fa558e82bb94 to your computer and use it in GitHub Desktop.
Save r9y9/87a2b54a68daadfbe959fa558e82bb94 to your computer and use it in GitHub Desktop.
import torch
from torch import nn
# Demonstrates how nn.Sequential treats distinct vs. repeated submodules:
#   * model1 stacks two independently constructed nn.Linear(1, 1) layers;
#   * model2 stacks the *same* nn.Linear object twice, so its parameters are
#     shared between positions "0" and "1".
# Both models report four state_dict entries, but in model2 the "0.*" and
# "1.*" entries alias the same underlying storage.
torch.manual_seed(1234)  # make model1's random initialisation reproducible

# Two distinct Linear modules, each with its own freshly initialised weights.
model1 = nn.Sequential(*[nn.Linear(1, 1) for _ in range(2)])

# One Linear module repeated: Sequential registers it under both "0" and "1".
layer = nn.Linear(1, 1)
model2 = nn.Sequential(*[layer for _ in range(2)])

print("Model1 (two different linear layers):")
assert model1[0] is not model1[1]
assert not torch.equal(model1[0].weight, model1[1].weight)
for k, v in model1.state_dict().items():
    print(k, v)

print("Model2 (two equivalent linear layers):")
# Check identity rather than mere value equality: both positions hold the
# very same module object, hence the very same weight Parameter.
assert model2[0] is model2[1]
assert model2[0].weight is model2[1].weight
# parameters() de-duplicates shared Parameters, so only weight + bias remain.
assert len(list(model2.parameters())) == 2
model2_state = model2.state_dict()
# state_dict() does not de-duplicate: both prefixes are present and the
# (detached) entries share storage with each other.
assert model2_state["0.weight"].data_ptr() == model2_state["1.weight"].data_ptr()
for k, v in model2_state.items():
    print(k, v)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment