Created
February 20, 2018 15:20
-
-
Save ngarneau/f66829cc4c1ff99beb9a2b7c206bf6bd to your computer and use it in GitHub Desktop.
PyTorch Embedding Issue
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import torch.optim as optim | |
from torch import autograd | |
from torch.nn import Module, Embedding, Linear, MSELoss, functional as F | |
from torch.utils.data import TensorDataset, DataLoader | |
import random | |
class IssueModule(Module):
    """Toy regressor: embed token ids, max-pool over the sequence, project to one value.

    Args:
        vocab_size: Number of rows in the embedding table.
        embedding_dim: Width of each embedding vector.
        sparse: Forwarded to ``Embedding(sparse=...)``. With the default
            ``False`` the gradient of the embedding weight is a dense
            ``vocab_size x embedding_dim`` tensor on every backward pass,
            even when only a handful of rows were looked up. Passing
            ``True`` makes that gradient sparse, so only the rows that were
            actually used are touched. Only some optimizers accept sparse
            gradients (see the PyTorch ``Embedding`` notes).
    """

    def __init__(self, vocab_size, embedding_dim, sparse=False):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.embedding = Embedding(vocab_size, embedding_dim, sparse=sparse)
        self.linear = Linear(embedding_dim, 1)

    def forward(self, x):
        """Return one prediction per row of ``x``.

        ``x`` is expected to be a 2-D tensor of token indices
        (batch, seq_len); the ``transpose(1, 2)`` below requires the
        embedded tensor to be 3-D.
        """
        embedded = self.embedding(x)
        # Read the sequence length from the shape instead of indexing the
        # first batch element (`len(x[0])`), which needlessly builds a view.
        seq_len = embedded.size(1)
        # max_pool1d pools over the last axis, so move the sequence axis
        # there and use a window spanning the whole sequence.
        pooled = F.max_pool1d(embedded.transpose(1, 2), seq_len)
        return self.linear(pooled.squeeze(-1))
def main(vocab_size=500000, embedding_dim=10, seq_len=100, n_steps=1000):
    """Train ``IssueModule`` on a single random sample to time the backward pass.

    Called with no arguments this reproduces the original experiment:
    a 500000-row embedding table, one sample of 100 random token ids,
    and 1000 SGD steps against a fixed target of 0.1.

    Args:
        vocab_size: Number of embeddings in the table.
        embedding_dim: Width of each embedding vector.
        seq_len: Number of token ids in the single toy sample.
        n_steps: Number of optimisation steps to run.
    """
    # Batch size of 1
    net = IssueModule(vocab_size, embedding_dim)
    criterion = MSELoss()

    # Draw indices directly with randrange: same uniform choice over
    # [0, vocab_size) without first materialising a vocab_size-long list.
    toy_sample = torch.LongTensor(
        [[random.randrange(vocab_size) for _ in range(seq_len)]]
    )
    # Variant of the experiment: take only the first `seq_len` embeddings
    #   toy_sample = torch.LongTensor([list(range(seq_len))])

    # The input never changes, so wrap it once instead of on every step.
    toy_input = autograd.Variable(toy_sample)
    toy_target = autograd.Variable(torch.FloatTensor([[0.1]]))

    optimizer = optim.SGD(net.parameters(), lr=0.1)
    for _ in range(n_steps):
        net.zero_grad()
        pred = net(toy_input)
        loss = criterion(pred, toy_target)
        loss.backward()
        optimizer.step()
# Run the reproduction only when this file is executed as a script,
# not when it is imported.
if __name__ == '__main__':
    main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
337287 function calls (330550 primitive calls) in 0.717 seconds

Ordered by: internal time

ncalls    tottime  percall  cumtime  percall  filename:lineno(function)
1000      0.143    0.000    0.143    0.000    {method 'run_backward' of 'torch._C._EngineBase' objects}
16        0.064    0.004    0.070    0.004    {built-in method _imp.create_dynamic}
270       0.048    0.000    0.048    0.000    {built-in method marshal.loads}
819/816   0.022    0.000    0.035    0.000    {built-in method builtins.__build_class__}
3000      0.019    0.000    0.084    0.000    {built-in method apply}
3000      0.014    0.000    0.014    0.000    {method 'add_' of 'torch._C.FloatTensorBase' objects}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
337241 function calls (330504 primitive calls) in 1.366 seconds

Ordered by: internal time

ncalls    tottime  percall  cumtime  percall  filename:lineno(function)
1000      0.543    0.001    0.543    0.001    {method 'run_backward' of 'torch._C._EngineBase' objects}
3000      0.213    0.000    0.213    0.000    {method 'add_' of 'torch._C.FloatTensorBase' objects}
16        0.050    0.003    0.056    0.004    {built-in method _imp.create_dynamic}
3000      0.031    0.000    0.148    0.000    {built-in method apply}
270       0.028    0.000    0.028    0.000    {built-in method marshal.loads}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
337218 function calls (330481 primitive calls) in 11.216 seconds

Ordered by: internal time

ncalls    tottime  percall  cumtime  percall  filename:lineno(function)
1000      7.534    0.008    7.534    0.008    {method 'run_backward' of 'torch._C._EngineBase' objects}
3000      2.727    0.001    2.727    0.001    {method 'add_' of 'torch._C.FloatTensorBase' objects}
1         0.264    0.264    0.264    0.264    {method 'normal_' of 'torch._C.FloatTensorBase' objects}
16        0.048    0.003    0.052    0.003    {built-in method _imp.create_dynamic}
3000      0.042    0.000    0.184    0.000    {built-in method apply}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment