Skip to content

Instantly share code, notes, and snippets.

import torch
import torch.nn as nn
# --- setup -----------------------------------------------------------------
# Two embedding tables of identical size (4 rows x 4 columns): `emb1` is the
# default (dense) variant, `emb2` is created with sparse=True.
emb1 = nn.Embedding(num_embeddings=4, embedding_dim=4)
emb2 = nn.Embedding(num_embeddings=4, embedding_dim=4, sparse=True)

# Overwrite emb2's freshly initialised weights with a copy of emb1's so both
# tables start from the same values.
emb2.load_state_dict(emb1.state_dict())

# One optimizer per table, both with a learning rate of 1.0: plain Adam for
# the dense table, SparseAdam for the sparse one.
opt1 = torch.optim.Adam(params=emb1.parameters(), lr=1.0)
opt2 = torch.optim.SparseAdam(params=emb2.parameters(), lr=1.0)