Last active
October 31, 2020 21:06
-
-
Save dantp-ai/42dd51e340668ee3ef320cc69d9bce56 to your computer and use it in GitHub Desktop.
A simple test showing that, for the same random seed, code, and neural-network weight initialization, the updated weights are identical (PyTorch v1.6.0).
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import torch.nn as nn | |
from copy import deepcopy | |
def trial(seed, num_steps=10, device=None):
    """Train a small MLP for a few Adam steps from a fixed seed and record its weights.

    Running this twice with the same ``seed`` on the same device is expected to
    produce bit-identical initial and updated parameters; the ``__main__`` block
    checks exactly that. (On CUDA this additionally assumes the kernels used are
    deterministic — see the PyTorch reproducibility notes.)

    Args:
        seed: Passed to ``torch.manual_seed`` before the network is built, so it
            fixes the weight initialization, the random inputs and the random
            targets.
        num_steps: Number of optimizer steps to take (default 10).
        device: Device to run on. ``None`` picks CUDA when available, else CPU.

    Returns:
        A pair ``(initial_params, history)``. ``initial_params`` is a list of
        detached copies of every parameter tensor right after initialization.
        ``history`` has one entry per step, each a list of detached copies of
        every parameter tensor after that step's update.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    torch.manual_seed(seed)

    input_size = 4
    output_size = 1
    step_size = 0.005
    hidden_size_1 = 16
    hidden_size_2 = 32

    net = nn.Sequential(
        nn.Linear(input_size, hidden_size_1), nn.ReLU(),
        nn.Linear(hidden_size_1, hidden_size_2), nn.ReLU(),
        nn.Linear(hidden_size_2, output_size),
    ).to(device)

    # Re-initialize every linear layer explicitly. The nn.init helpers run
    # under no_grad themselves, so no `.data` access is needed.
    for module in net:
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=1.0)
            nn.init.constant_(module.bias, 0.01)

    optimizer = torch.optim.Adam(net.parameters(), lr=step_size)

    def snapshot():
        # detach().clone() yields independent copies that the optimizer's
        # later in-place updates cannot alter.
        return [p.detach().clone() for p in net.parameters()]

    initial_params = snapshot()
    history = []
    for _ in range(num_steps):
        x = torch.rand(input_size).to(device)
        output = net(x)
        # Regress onto a random target in [0, 1) with MSE so the loss is a
        # bounded scalar. (torch.randint(0, output_size, ...) would always be 0
        # here, since its upper bound is exclusive and output_size == 1.)
        target = torch.rand(output_size).to(device)
        loss = nn.functional.mse_loss(output, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        history.append(snapshot())
    return initial_params, history
if __name__ == "__main__":
    initial_params1, params1 = trial(0)
    initial_params2, params2 = trial(0)

    # torch.equal requires identical shapes AND values; torch.eq broadcasts,
    # so tensors with mismatched (but broadcastable) shapes could slip through.
    assert len(initial_params1) == len(initial_params2)
    for p1, p2 in zip(initial_params1, initial_params2):
        assert torch.equal(p1, p2)

    # Compare the parameter snapshots taken after every optimizer step.
    assert len(params1) == len(params2)
    for step_params1, step_params2 in zip(params1, params2):
        assert len(step_params1) == len(step_params2)
        for p1, p2 in zip(step_params1, step_params2):
            assert torch.equal(p1, p2)

    print("Initial and updated parameters are identical across both runs.")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.