import torch
import torch.nn as nn
# setup
emb1 = nn.Embedding(4, 4)
opt1 = torch.optim.Adam(emb1.parameters(), lr=1.)
emb2 = nn.Embedding(4, 4, sparse=True)
emb2.load_state_dict(emb1.state_dict())
opt2 = torch.optim.SparseAdam(emb2.parameters(), lr=1.)
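# sanity check (sketch, not part of the original experiment): after load_state_dict
# both embeddings hold identical weights, so any later divergence comes purely
# from the two optimizers
print(torch.equal(emb1.weight, emb2.weight))
# expected: True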
# 1st update
x = torch.tensor([0, 2])
out1 = emb1(x)
out1.mean().backward()
# gradients at expected indices
print(emb1.weight.grad)
opt1.step()
opt1.zero_grad()
out2 = emb2(x)
out2.mean().backward()
# gradients at expected indices
print(emb2.weight.grad)
opt2.step()
opt2.zero_grad()
# compare
print((emb1.weight - emb2.weight).abs().mean(1))
# tensor([2.3544e-06, 0.0000e+00, 2.3544e-06, 0.0000e+00],
#        grad_fn=<MeanBackward1>)
# small abs differences due to limited floating point precision, but the results are equal
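# optional check (sketch): the element-wise mismatch should stay within
# float32 tolerance; printing the max absolute difference makes that visible
print((emb1.weight - emb2.weight).abs().max())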
# 2nd update at a new index
x = torch.tensor([1])
out1 = emb1(x)
out1.mean().backward()
# gradient at expected index
print(emb1.weight.grad)
opt1.step()
opt1.zero_grad()
out2 = emb2(x)
out2.mean().backward()
# gradient at expected index
print(emb2.weight.grad)
opt2.step()
opt2.zero_grad()
# compare
print((emb1.weight - emb2.weight).abs().mean(1))
# tensor([6.7006e-01, 9.5367e-07, 6.7006e-01, 0.0000e+00],
#        grad_fn=<MeanBackward1>)
# differences now at indices 0 and 2, since `Adam` updated them via its running stats
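# the running stats can be inspected directly (sketch, assuming Adam's
# per-parameter state with its 'exp_avg'/'exp_avg_sq' buffers): they are
# non-zero at indices 0 and 2 from the 1st update, which is why dense Adam
# keeps moving those rows even though their current grad is zero
print(opt1.state[emb1.weight]["exp_avg"])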
# fake updates
w1 = emb1.weight.clone()
print(emb1.weight - w1)
for _ in range(3):
    # updates the weights even though the grad is all zeros
    opt1.step()
print(emb1.weight - w1)
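# these "fake" updates happen because opt1.zero_grad() left an all-zero grad
# tensor (not None): Adam still runs its update rule, and the non-zero running
# stats keep nudging the previously updated rows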
w2 = emb2.weight.clone()
print(emb2.weight - w2)
for _ in range(3):
    # no updates
    opt2.step()
print(emb2.weight - w2)
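# SparseAdam only touches the rows referenced by the sparse gradient, so these
# extra step() calls leave emb2.weight unchanged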
# now let's use set_to_none=True
opt1.zero_grad(set_to_none=True)
w1 = emb1.weight.clone()
print(emb1.weight - w1)
for _ in range(3):
    # no updates anymore
    opt1.step()
print(emb1.weight - w1)
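# with set_to_none=True the grads become None instead of zero-filled tensors;
# the optimizer skips parameters whose .grad is None, so dense Adam no longer
# applies the momentum-only "fake" updates
# (set_to_none=True has since become the default in newer PyTorch versions)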