Skip to content

Instantly share code, notes, and snippets.

@ngarneau
Created February 20, 2018 15:20
Show Gist options
  • Save ngarneau/f66829cc4c1ff99beb9a2b7c206bf6bd to your computer and use it in GitHub Desktop.
Save ngarneau/f66829cc4c1ff99beb9a2b7c206bf6bd to your computer and use it in GitHub Desktop.
PyTorch Embedding Issue
import torch
import torch.optim as optim
from torch import autograd
from torch.nn import Module, Embedding, Linear, MSELoss, functional as F
from torch.utils.data import TensorDataset, DataLoader
import random
class IssueModule(Module):
    """Embedding lookup -> max pooling over the sequence -> linear regressor.

    Args:
        vocab_size: number of rows in the embedding table.
        embedding_dim: size of each embedding vector.
        sparse: if True, the embedding's weight gradient is a sparse tensor
            covering only the rows looked up in the batch. If False (the
            default, and the original behaviour), it is a dense tensor the
            size of the whole ``vocab_size x embedding_dim`` table. For large
            vocabularies the dense gradient makes every backward pass and
            every optimizer update scale with ``vocab_size``. Only some
            optimizers accept sparse gradients (e.g. ``optim.SGD``,
            ``optim.SparseAdam``, ``optim.Adagrad``).
    """

    def __init__(self, vocab_size, embedding_dim, sparse=False):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.embedding = Embedding(vocab_size, embedding_dim, sparse=sparse)
        self.linear = Linear(embedding_dim, 1)

    def forward(self, x):
        """Map each index sequence to a single scalar prediction.

        Args:
            x: LongTensor of token indices, expected shape (batch, seq_len).

        Returns:
            Tensor of shape (batch, 1).
        """
        embedded = self.embedding(x)  # (batch, seq_len, embedding_dim)
        # max_pool1d pools over the last axis, so move seq_len there and use
        # a kernel spanning the whole sequence: one max per embedding feature.
        pooled = F.max_pool1d(embedded.transpose(1, 2), embedded.size(1))
        return self.linear(pooled.squeeze(-1))  # (batch, embedding_dim) -> (batch, 1)
def main(vocab_size=500000, embedding_dim=10, seq_len=100, steps=1000, lr=0.1):
    """Train IssueModule on one fixed random sample to profile training speed.

    Each step runs a forward pass, an MSE loss against the constant target
    0.1, a backward pass, and an SGD update. The model is built with its
    default dense embedding gradient. Each backward pass and SGD update
    therefore touches all ``vocab_size * embedding_dim`` embedding weights,
    even though a step only looks up ``seq_len`` rows.
    A sparse embedding gradient (``torch.nn.Embedding(..., sparse=True)``)
    is the usual mitigation. It is deliberately not applied here, so this
    function stays a reproduction of the slow case.

    Args:
        vocab_size: number of rows in the embedding table.
        embedding_dim: width of each embedding vector.
        seq_len: number of token indices in the single training sample.
        steps: number of optimisation steps to run.
        lr: SGD learning rate.
    """
    # Batch size of 1: one sequence of `seq_len` indices drawn uniformly from
    # the vocabulary. randrange avoids materialising a vocab_size-long list of
    # candidate indices just to sample from it.
    toy_sample = torch.LongTensor(
        [[random.randrange(vocab_size) for _ in range(seq_len)]]
    )
    # Alternative sample that only touches the first `seq_len` embeddings:
    # toy_sample = torch.LongTensor([list(range(seq_len))])

    net = IssueModule(vocab_size, embedding_dim)
    criterion = MSELoss()
    optimizer = optim.SGD(net.parameters(), lr=lr)

    # Variable wrappers are required on PyTorch < 0.4. Later versions
    # deprecate them and just return the tensor. Input and target are
    # loop-invariant, so wrap them once instead of on every step.
    inputs = autograd.Variable(toy_sample)
    target = autograd.Variable(torch.FloatTensor([[0.1]]))

    for _ in range(steps):
        net.zero_grad()
        loss = criterion(net(inputs), target)
        loss.backward()
        optimizer.step()


if __name__ == '__main__':
    main()
337287 function calls (330550 primitive calls) in 0.717 seconds
Ordered by: internal time
ncalls tottime percall cumtime percall filename:lineno(function)
1000 0.143 0.000 0.143 0.000 {method 'run_backward' of 'torch._C._EngineBase' objects}
16 0.064 0.004 0.070 0.004 {built-in method _imp.create_dynamic}
270 0.048 0.000 0.048 0.000 {built-in method marshal.loads}
819/816 0.022 0.000 0.035 0.000 {built-in method builtins.__build_class__}
3000 0.019 0.000 0.084 0.000 {built-in method apply}
3000 0.014 0.000 0.014 0.000 {method 'add_' of 'torch._C.FloatTensorBase' objects}
337241 function calls (330504 primitive calls) in 1.366 seconds
Ordered by: internal time
ncalls tottime percall cumtime percall filename:lineno(function)
1000 0.543 0.001 0.543 0.001 {method 'run_backward' of 'torch._C._EngineBase' objects}
3000 0.213 0.000 0.213 0.000 {method 'add_' of 'torch._C.FloatTensorBase' objects}
16 0.050 0.003 0.056 0.004 {built-in method _imp.create_dynamic}
3000 0.031 0.000 0.148 0.000 {built-in method apply}
270 0.028 0.000 0.028 0.000 {built-in method marshal.loads}
337218 function calls (330481 primitive calls) in 11.216 seconds
Ordered by: internal time
ncalls tottime percall cumtime percall filename:lineno(function)
1000 7.534 0.008 7.534 0.008 {method 'run_backward' of 'torch._C._EngineBase' objects}
3000 2.727 0.001 2.727 0.001 {method 'add_' of 'torch._C.FloatTensorBase' objects}
1 0.264 0.264 0.264 0.264 {method 'normal_' of 'torch._C.FloatTensorBase' objects}
16 0.048 0.003 0.052 0.003 {built-in method _imp.create_dynamic}
3000 0.042 0.000 0.184 0.000 {built-in method apply}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment