Skip to content

Instantly share code, notes, and snippets.

@hal3
Last active January 24, 2019 20:56
Show Gist options
  • Save hal3/8c170c4400576eb8d0a8bd94ab231232 to your computer and use it in GitHub Desktop.
Save hal3/8c170c4400576eb8d0a8bd94ab231232 to your computer and use it in GitHub Desktop.
PyTorch implementation of a sequence labeler (POS taggger).
"""
PyTorch implementation of a sequence labeler (POS taggger).
Basic architecture:
- take words
- run though bidirectional GRU
- predict labels one word at a time (left to right), using a recurrent neural network "decoder"
The decoder updates hidden state based on:
- most recent word
- the previous action (aka predicted label).
- the previous hidden state
Can it be faster?!?!?!?!?!?
"""
from __future__ import division
import random
import pickle
import torch
from torch import nn
from torch.autograd import Variable
from torch.nn.parameter import Parameter
import torch.nn.functional as F
def reseed(seed=90210):
random.seed(seed)
torch.manual_seed(seed)
reseed()
class Example(object):
def __init__(self, tokens, labels, n_labels):
self.tokens = tokens
self.labels = labels
self.n_labels = n_labels
def minibatch(data, minibatch_size, reshuffle):
if reshuffle:
random.shuffle(data)
for n in xrange(0, len(data), minibatch_size):
yield data[n:n+minibatch_size]
def test_wsj():
print
print '# test on wsj subset'
data, n_types, n_labels = pickle.load(open('wsj.pkl', 'r'))
d_emb = 50
d_rnn = 51
d_hid = 52
d_actemb = 5
minibatch_size = 5
n_epochs = 10
preprocess_minibatch = True
embed_word = nn.Embedding(n_types, d_emb)
gru = nn.GRU(d_emb, d_rnn, bidirectional=True)
embed_action = nn.Embedding(n_labels, d_actemb)
combine_arh = nn.Linear(d_actemb + d_rnn * 2 + d_hid, d_hid)
initial_h_tensor = torch.Tensor(1, d_hid)
initial_h_tensor.zero_()
initial_h = Parameter(initial_h_tensor)
initial_actemb_tensor = torch.Tensor(1, d_actemb)
initial_actemb_tensor.zero_()
initial_actemb = Parameter(initial_actemb_tensor)
policy = nn.Linear(d_hid, n_labels)
loss_fn = torch.nn.MSELoss(size_average=False)
optimizer = torch.optim.Adam(
list(embed_word.parameters()) +
list(gru.parameters()) +
list(embed_action.parameters()) +
list(combine_arh.parameters()) +
list(policy.parameters()) +
[initial_h, initial_actemb]
, lr=0.01)
for _ in xrange(n_epochs):
total_loss = 0
for batch in minibatch(data, minibatch_size, True):
optimizer.zero_grad()
loss = 0
if preprocess_minibatch:
# for efficiency, combine RNN outputs on entire
# minibatch in one go (requires padding with zeros,
# should be masked but isn't right now)
all_tokens = [ex.tokens for ex in batch]
max_length = max(map(len, all_tokens))
all_tokens = [tok + [0] * (max_length - len(tok)) for tok in all_tokens]
all_e = embed_word(Variable(torch.LongTensor(all_tokens), requires_grad=False))
[all_rnn_out, _] = gru(all_e)
for ex in batch:
N = len(ex.tokens)
if preprocess_minibatch:
rnn_out = all_rnn_out[0,:,:].view(-1, 1, 2 * d_rnn)
else:
e = embed_word(Variable(torch.LongTensor(ex.tokens), requires_grad=False)).view(N, 1, -1)
[rnn_out, _] = gru(e)
prev_h = initial_h # previous hidden state
actemb = initial_actemb # embedding of previous action
output = []
for t in xrange(N):
# update hidden state based on most recent
# *predicted* action (not ground truth)
inputs = [actemb, prev_h, rnn_out[t]]
h = F.relu(combine_arh(torch.cat(inputs, 1)))
# make prediction
pred_vec = policy(h)
pred = pred_vec.data.numpy().argmin()
output.append(pred)
# accumulate loss (squared error against costs)
truth = torch.ones(n_labels)
truth[ex.labels[t]] = 0
loss += loss_fn(pred_vec, Variable(truth, requires_grad=False))
# cache hidden state, previous action embedding
prev_h = h
actemb = embed_action(Variable(torch.LongTensor([pred]), requires_grad=False))
# print 'output=%s, truth=%s' % (output, ex.labels)
loss.backward()
total_loss += loss.data.numpy()[0]
optimizer.step()
print total_loss
if __name__ == '__main__':
test_wsj()
(lp0
(lp1
ccopy_reg
_reconstructor
p2
(c__main__
Example
p3
c__builtin__
object
p4
Ntp5
Rp6
(dp7
S'tokens'
p8
(lp9
I2
aI2
aI18
aI2
aI27
aI2
aI2
aI7
aI21
aI2
aI2
aI2
aI17
asS'labels'
p10
(lp11
I0
aI0
aI1
aI2
aI3
aI0
aI0
aI4
aI5
aI0
aI6
aI2
aI7
asS'n_labels'
p12
I32
sbag2
(g3
g4
Ntp13
Rp14
(dp15
g8
(lp16
I21
aI14
aI2
aI7
aI2
aI7
aI18
aI2
aI2
aI2
aI19
aI2
aI21
aI2
aI7
aI29
aI2
aI2
aI2
aI3
aI19
aI2
aI2
aI11
aI2
aI2
aI2
aI2
aI7
aI2
aI16
aI17
asg10
(lp17
I5
aI2
aI2
aI4
aI2
aI4
aI1
aI8
aI9
aI3
aI10
aI1
aI5
aI11
aI4
aI3
aI8
aI9
aI11
aI12
aI10
aI6
aI11
aI13
aI14
aI15
aI11
aI9
aI4
aI11
aI16
aI7
asg12
I32
sbag2
(g3
g4
Ntp18
Rp19
(dp20
g8
(lp21
I2
aI2
aI2
aI2
aI2
aI4
aI2
aI6
aI2
aI2
aI7
aI21
aI2
aI2
aI2
aI20
aI2
aI31
aI2
aI2
aI2
aI27
aI2
aI7
aI6
aI2
aI2
aI3
aI2
aI2
aI2
aI3
aI21
aI2
aI17
asg10
(lp22
I3
aI9
aI11
aI16
aI17
aI18
aI3
aI5
aI2
aI3
aI4
aI5
aI19
aI11
aI14
aI3
aI2
aI20
aI0
aI0
aI0
aI3
aI0
aI4
aI5
aI2
aI9
aI12
aI21
aI9
aI2
aI12
aI5
aI2
aI7
asg12
I32
sbag2
(g3
g4
Ntp23
Rp24
(dp25
g8
(lp26
I6
aI2
aI2
aI16
aI7
aI5
aI2
aI18
aI2
aI2
aI2
aI17
asg10
(lp27
I5
aI0
aI2
aI16
aI4
aI22
aI5
aI1
aI5
aI9
aI2
aI7
asg12
I32
sbag2
(g3
g4
Ntp28
Rp29
(dp30
g8
(lp31
I2
aI2
aI2
aI2
aI2
aI2
aI2
aI2
aI2
aI27
aI14
aI2
aI13
aI2
aI2
aI17
asg10
(lp32
I10
aI14
aI6
aI3
aI11
aI3
aI3
aI2
aI16
aI3
aI2
aI6
aI5
aI9
aI11
aI7
asg12
I32
sbag2
(g3
g4
Ntp33
Rp34
(dp35
g8
(lp36
I2
aI18
aI2
aI14
aI20
aI2
aI2
aI2
aI17
aI30
asg10
(lp37
I23
aI1
aI5
aI2
aI3
aI24
aI11
aI8
aI7
aI25
asg12
I32
sbag2
(g3
g4
Ntp38
Rp39
(dp40
g8
(lp41
I2
aI2
aI2
aI21
aI2
aI2
aI2
aI21
aI22
aI2
aI2
aI27
aI13
aI2
aI26
aI2
aI27
aI21
aI2
aI2
aI17
asg10
(lp42
I5
aI0
aI26
aI5
aI11
aI27
aI16
aI5
aI11
aI16
aI9
aI3
aI5
aI2
aI3
aI11
aI3
aI5
aI0
aI11
aI7
asg12
I32
sbag2
(g3
g4
Ntp43
Rp44
(dp45
g8
(lp46
I5
aI2
aI2
aI2
aI2
aI2
aI26
aI2
aI2
aI15
aI2
aI2
aI7
aI30
aI16
aI2
aI2
aI2
aI27
aI2
aI31
aI2
aI2
aI2
aI17
asg10
(lp47
I22
aI10
aI14
aI5
aI9
aI2
aI3
aI3
aI11
aI14
aI3
aI2
aI4
aI25
aI16
aI0
aI0
aI0
aI3
aI0
aI20
aI0
aI0
aI0
aI7
asg12
I32
sbag2
(g3
g4
Ntp48
Rp49
(dp50
g8
(lp51
I2
aI2
aI2
aI6
aI2
aI27
aI2
aI2
aI21
aI2
aI2
aI2
aI12
aI21
aI2
aI2
aI27
aI2
aI2
aI12
aI2
aI2
aI17
asg10
(lp52
I0
aI0
aI16
aI5
aI2
aI3
aI11
aI3
aI5
aI0
aI0
aI0
aI26
aI5
aI9
aI11
aI3
aI0
aI0
aI26
aI0
aI0
aI7
asg12
I32
sbag2
(g3
g4
Ntp53
Rp54
(dp55
g8
(lp56
I2
aI27
aI21
aI2
aI2
aI22
aI2
aI2
aI2
aI7
aI2
aI2
aI29
aI2
aI2
aI2
aI17
asg10
(lp57
I28
aI3
aI5
aI28
aI6
aI11
aI14
aI9
aI11
aI4
aI6
aI28
aI3
aI8
aI17
aI2
aI7
asg12
I32
sbag2
(g3
g4
Ntp58
Rp59
(dp60
g8
(lp61
I5
aI21
aI2
aI2
aI18
aI6
aI2
aI2
aI2
aI2
aI27
aI2
aI2
aI2
aI2
aI2
aI7
aI30
aI16
aI2
aI2
aI17
asg10
(lp62
I22
aI5
aI2
aI2
aI1
aI5
aI9
aI2
aI3
aI5
aI3
aI10
aI27
aI14
aI9
aI11
aI4
aI25
aI16
aI0
aI0
aI7
asg12
I32
sbag2
(g3
g4
Ntp63
Rp64
(dp65
g8
(lp66
I21
aI2
aI27
aI2
aI2
aI2
aI2
aI21
aI22
aI2
aI21
aI2
aI2
aI7
aI2
aI7
aI2
aI8
aI2
aI3
aI2
aI21
aI2
aI25
aI13
aI14
aI22
aI2
aI20
aI2
aI2
aI2
aI7
aI2
aI16
aI17
asg10
(lp67
I5
aI2
aI3
aI2
aI2
aI11
aI3
aI5
aI11
aI3
aI5
aI0
aI0
aI4
aI0
aI4
aI2
aI2
aI1
aI12
aI21
aI5
aI19
aI3
aI5
aI2
aI11
aI17
aI3
aI9
aI17
aI11
aI4
aI10
aI16
aI7
asg12
I32
sbag2
(g3
g4
Ntp68
Rp69
(dp70
g8
(lp71
I21
aI2
aI7
aI2
aI18
aI2
aI2
aI2
aI2
aI2
aI2
aI7
aI2
aI2
aI2
aI29
aI2
aI3
aI2
aI21
aI2
aI2
aI17
asg10
(lp72
I5
aI2
aI4
aI13
aI1
aI17
aI3
aI0
aI26
aI0
aI0
aI4
aI16
aI3
aI2
aI3
aI2
aI12
aI21
aI5
aI2
aI11
aI7
asg12
I32
sbag2
(g3
g4
Ntp73
Rp74
(dp75
g8
(lp76
I21
aI2
aI2
aI2
aI2
aI2
aI2
aI2
aI11
aI21
aI9
aI2
aI2
aI21
aI2
aI27
aI14
aI2
aI2
aI4
aI2
aI2
aI21
aI2
aI2
aI27
aI14
aI7
aI2
aI7
aI2
aI20
aI2
aI2
aI12
aI2
aI2
aI7
aI2
aI2
aI16
aI17
asg10
(lp77
I5
aI2
aI8
aI29
aI21
aI5
aI27
aI14
aI3
aI5
aI0
aI29
aI21
aI5
aI2
aI3
aI2
aI6
aI2
aI18
aI8
aI3
aI5
aI9
aI2
aI3
aI2
aI4
aI2
aI4
aI17
aI3
aI19
aI11
aI26
aI9
aI11
aI4
aI0
aI0
aI16
aI7
asg12
I32
sbag2
(g3
g4
Ntp78
Rp79
(dp80
g8
(lp81
I21
aI9
aI18
aI2
aI27
aI21
aI2
aI2
aI2
aI11
aI2
aI2
aI2
aI6
aI2
aI2
aI27
aI2
aI25
aI21
aI2
aI7
aI2
aI2
aI2
aI2
aI2
aI11
aI15
aI2
aI2
aI2
aI7
aI2
aI3
aI2
aI2
aI2
aI7
aI6
aI2
aI27
aI2
aI2
aI21
aI2
aI27
aI2
aI2
aI27
aI2
aI17
asg10
(lp82
I5
aI0
aI1
aI28
aI3
aI5
aI9
aI17
aI11
aI13
aI1
aI8
aI21
aI5
aI30
aI2
aI3
aI2
aI3
aI5
aI9
aI4
aI9
aI11
aI9
aI3
aI2
aI13
aI14
aI17
aI3
aI11
aI4
aI6
aI12
aI0
aI0
aI0
aI4
aI5
aI2
aI3
aI2
aI3
aI5
aI0
aI3
aI0
aI0
aI3
aI0
aI7
asg12
I32
sbag2
(g3
g4
Ntp83
Rp84
(dp85
g8
(lp86
I4
aI2
aI2
aI2
aI15
aI2
aI12
aI15
aI4
aI2
aI2
aI2
aI21
aI2
aI7
aI2
aI2
aI2
aI17
asg10
(lp87
I18
aI9
aI2
aI11
aI14
aI9
aI26
aI14
aI18
aI8
aI17
aI3
aI5
aI2
aI4
aI0
aI0
aI16
aI7
asg12
I32
sbag2
(g3
g4
Ntp88
Rp89
(dp90
g8
(lp91
I20
aI2
aI7
aI21
aI2
aI2
aI2
aI2
aI6
aI2
aI2
aI26
aI2
aI2
aI2
aI27
aI14
aI17
asg10
(lp92
I3
aI0
aI4
aI5
aI0
aI0
aI0
aI16
aI5
aI9
aI2
aI3
aI8
aI5
aI11
aI3
aI2
aI7
asg12
I32
sbag2
(g3
g4
Ntp93
Rp94
(dp95
g8
(lp96
I2
aI27
aI21
aI8
aI2
aI2
aI2
aI2
aI21
aI2
aI2
aI2
aI17
asg10
(lp97
I11
aI3
aI5
aI2
aI16
aI8
aI9
aI31
aI5
aI2
aI16
aI17
aI7
asg12
I32
sbag2
(g3
g4
Ntp98
Rp99
(dp100
g8
(lp101
I22
aI2
aI2
aI2
aI2
aI27
aI21
aI2
aI2
aI2
aI6
aI2
aI2
aI7
aI2
aI20
aI2
aI12
aI2
aI2
aI12
aI2
aI2
aI21
aI2
aI2
aI20
aI6
aI2
aI2
aI3
aI2
aI2
aI17
asg10
(lp102
I11
aI16
aI9
aI2
aI11
aI3
aI5
aI17
aI2
aI3
aI5
aI9
aI2
aI4
aI16
aI15
aI2
aI26
aI2
aI11
aI26
aI8
aI16
aI5
aI9
aI11
aI3
aI5
aI2
aI17
aI12
aI21
aI11
aI7
asg12
I32
sbag2
(g3
g4
Ntp103
Rp104
(dp105
g8
(lp106
I22
aI2
aI5
aI2
aI27
aI2
aI2
aI30
aI11
aI2
aI2
aI2
aI27
aI21
aI8
aI7
aI2
aI2
aI2
aI2
aI2
aI21
aI2
aI17
asg10
(lp107
I11
aI16
aI22
aI11
aI3
aI9
aI2
aI25
aI13
aI16
aI3
aI11
aI3
aI5
aI2
aI4
aI8
aI3
aI2
aI11
aI16
aI5
aI2
aI7
asg12
I32
sbag2
(g3
g4
Ntp108
Rp109
(dp110
g8
(lp111
I5
aI2
aI31
aI2
aI2
aI11
aI2
aI27
aI2
aI22
aI12
aI28
aI2
aI2
aI2
aI7
aI30
aI16
aI2
aI2
aI7
aI2
aI2
aI27
aI2
aI2
aI25
aI2
aI2
aI2
aI17
asg10
(lp112
I22
aI23
aI1
aI5
aI2
aI3
aI5
aI3
aI5
aI11
aI26
aI11
aI16
aI9
aI11
aI4
aI25
aI16
aI0
aI0
aI4
aI2
aI2
aI3
aI9
aI11
aI3
aI0
aI26
aI0
aI7
asg12
I32
sbag2
(g3
g4
Ntp113
Rp114
(dp115
g8
(lp116
I19
aI2
aI2
aI2
aI26
aI2
aI2
aI2
aI2
aI17
asg10
(lp117
I10
aI1
aI5
aI2
aI3
aI24
aI2
aI2
aI2
aI7
asg12
I32
sbag2
(g3
g4
Ntp118
Rp119
(dp120
g8
(lp121
I23
aI26
aI2
aI2
aI2
aI2
aI3
aI2
aI7
aI2
aI2
aI11
aI2
aI28
aI2
aI2
aI2
aI20
aI2
aI10
aI17
asg10
(lp122
I11
aI3
aI9
aI9
aI11
aI16
aI12
aI21
aI4
aI3
aI11
aI3
aI2
aI11
aI14
aI9
aI11
aI3
aI2
aI11
aI7
asg12
I32
sbag2
(g3
g4
Ntp123
Rp124
(dp125
g8
(lp126
I2
aI23
aI2
aI2
aI27
aI2
aI12
aI11
aI21
aI2
aI2
aI2
aI25
aI6
aI2
aI17
asg10
(lp127
I2
aI11
aI14
aI2
aI3
aI11
aI26
aI3
aI5
aI9
aI2
aI1
aI3
aI5
aI2
aI7
asg12
I32
sbag2
(g3
g4
Ntp128
Rp129
(dp130
g8
(lp131
I2
aI2
aI15
aI2
aI3
aI2
aI2
aI2
aI10
aI24
aI2
aI2
aI2
aI28
aI3
aI2
aI2
aI2
aI10
aI25
aI6
aI2
aI2
aI17
asg10
(lp132
I30
aI11
aI14
aI17
aI12
aI21
aI6
aI2
aI11
aI3
aI10
aI14
aI2
aI11
aI12
aI21
aI8
aI30
aI11
aI3
aI5
aI30
aI2
aI7
asg12
I32
sbag2
(g3
g4
Ntp133
Rp134
(dp135
g8
(lp136
I2
aI2
aI15
aI2
aI6
aI2
aI27
aI2
aI10
aI24
aI2
aI28
aI2
aI2
aI2
aI10
aI2
aI17
asg10
(lp137
I30
aI11
aI14
aI17
aI5
aI2
aI3
aI6
aI11
aI3
aI2
aI11
aI29
aI21
aI30
aI11
aI8
aI7
asg12
I32
sbag2
(g3
g4
Ntp138
Rp139
(dp140
g8
(lp141
I2
aI7
aI16
aI2
aI2
aI2
aI7
aI2
aI27
aI2
aI2
aI2
aI7
aI23
aI5
aI2
aI2
aI2
aI2
aI2
aI2
aI2
aI2
aI30
aI24
aI27
aI2
aI2
aI20
aI2
aI2
aI10
aI17
asg10
(lp142
I8
aI4
aI16
aI0
aI0
aI0
aI4
aI2
aI3
aI0
aI0
aI0
aI4
aI11
aI22
aI29
aI21
aI15
aI8
aI3
aI10
aI14
aI15
aI25
aI3
aI3
aI9
aI11
aI3
aI9
aI2
aI11
aI7
asg12
I32
sbag2
(g3
g4
Ntp143
Rp144
(dp145
g8
(lp146
I2
aI2
aI2
aI20
aI23
aI7
aI2
aI2
aI3
aI2
aI2
aI2
aI2
aI2
aI17
asg10
(lp147
I3
aI9
aI11
aI3
aI11
aI4
aI11
aI14
aI12
aI21
aI2
aI3
aI2
aI11
aI7
asg12
I32
sbag2
(g3
g4
Ntp148
Rp149
(dp150
g8
(lp151
I2
aI7
aI2
aI23
aI2
aI2
aI2
aI2
aI24
aI2
aI28
aI2
aI2
aI2
aI12
aI2
aI2
aI21
aI2
aI10
aI17
asg10
(lp152
I8
aI4
aI2
aI11
aI14
aI9
aI9
aI11
aI3
aI2
aI11
aI29
aI21
aI11
aI26
aI21
aI3
aI5
aI19
aI11
aI7
asg12
I32
sbag2
(g3
g4
Ntp153
Rp154
(dp155
g8
(lp156
I19
aI2
aI2
aI20
aI2
aI2
aI2
aI12
aI18
aI2
aI2
aI2
aI2
aI7
aI2
aI2
aI2
aI2
aI17
asg10
(lp157
I10
aI1
aI8
aI3
aI9
aI11
aI8
aI26
aI1
aI8
aI6
aI2
aI11
aI4
aI13
aI1
aI24
aI2
aI7
asg12
I32
sbag2
(g3
g4
Ntp158
Rp159
(dp160
g8
(lp161
I2
aI2
aI2
aI2
aI2
aI7
aI2
aI6
aI2
aI2
aI2
aI2
aI7
aI2
aI2
aI17
asg10
(lp162
I10
aI1
aI0
aI0
aI0
aI4
aI8
aI5
aI0
aI0
aI2
aI2
aI4
aI27
aI16
aI7
asg12
I32
sbag2
(g3
g4
Ntp163
Rp164
(dp165
g8
(lp166
I2
aI2
aI2
aI2
aI27
aI2
aI2
aI31
aI2
aI2
aI2
aI17
asg10
(lp167
I0
aI0
aI1
aI28
aI3
aI0
aI0
aI20
aI28
aI2
aI11
aI7
asg12
I32
sbag2
(g3
g4
Ntp168
Rp169
(dp170
g8
(lp171
I21
aI2
aI2
aI2
aI16
aI19
aI2
aI3
aI2
aI2
aI2
aI12
aI2
aI21
aI2
aI2
aI2
aI17
asg10
(lp172
I5
aI2
aI6
aI2
aI16
aI10
aI1
aI12
aI21
aI9
aI2
aI26
aI21
aI5
aI2
aI3
aI2
aI7
asg12
I32
sbag2
(g3
g4
Ntp173
Rp174
(dp175
g8
(lp176
I2
aI18
aI2
aI2
aI2
aI2
aI2
aI29
aI2
aI20
aI21
aI2
aI2
aI2
aI17
asg10
(lp177
I0
aI1
aI5
aI9
aI9
aI6
aI2
aI3
aI11
aI3
aI5
aI9
aI2
aI2
aI7
asg12
I32
sbag2
(g3
g4
Ntp178
Rp179
(dp180
g8
(lp181
I2
aI2
aI7
aI2
aI20
aI2
aI7
aI2
aI7
aI2
aI2
aI2
aI2
aI2
aI17
asg10
(lp182
I0
aI0
aI4
aI17
aI3
aI0
aI4
aI0
aI4
aI1
aI9
aI9
aI11
aI11
aI7
asg12
I32
sbag2
(g3
g4
Ntp183
Rp184
(dp185
g8
(lp186
I21
aI2
aI2
aI2
aI2
aI27
aI9
aI2
aI2
aI24
aI2
aI2
aI2
aI2
aI21
aI2
aI26
aI2
aI2
aI17
asg10
(lp187
I5
aI9
aI2
aI16
aI11
aI3
aI0
aI11
aI11
aI3
aI0
aI1
aI8
aI17
aI5
aI2
aI3
aI2
aI2
aI7
asg12
I32
sbag2
(g3
g4
Ntp188
Rp189
(dp190
g8
(lp191
I2
aI2
aI2
aI7
aI21
aI2
aI2
aI2
aI13
aI2
aI3
aI2
aI2
aI2
aI2
aI27
aI13
aI2
aI7
aI21
aI2
aI16
aI17
asg10
(lp192
I3
aI0
aI1
aI4
aI5
aI2
aI1
aI8
aI5
aI2
aI12
aI21
aI9
aI2
aI11
aI3
aI5
aI2
aI4
aI5
aI0
aI16
aI7
asg12
I32
sbag2
(g3
g4
Ntp193
Rp194
(dp195
g8
(lp196
I2
aI3
aI2
aI21
aI2
aI2
aI18
aI2
aI20
aI21
aI2
aI2
aI2
aI2
aI2
aI17
asg10
(lp197
I2
aI12
aI21
aI5
aI2
aI2
aI1
aI17
aI3
aI5
aI2
aI3
aI6
aI9
aI11
aI7
asg12
I32
sbag2
(g3
g4
Ntp198
Rp199
(dp200
g8
(lp201
I2
aI2
aI2
aI2
aI2
aI2
aI2
aI2
aI12
aI2
aI2
aI27
aI2
aI9
aI2
aI12
aI2
aI2
aI27
aI2
aI2
aI2
aI2
aI2
aI2
aI17
asg10
(lp202
I0
aI0
aI0
aI16
aI17
aI9
aI2
aI2
aI26
aI9
aI2
aI3
aI5
aI0
aI11
aI26
aI2
aI2
aI3
aI9
aI2
aI2
aI0
aI0
aI0
aI7
asg12
I32
sbag2
(g3
g4
Ntp203
Rp204
(dp205
g8
(lp206
I20
aI21
aI2
aI2
aI2
aI2
aI2
aI2
aI31
aI9
aI2
aI7
aI2
aI7
aI2
aI12
aI2
aI2
aI17
asg10
(lp207
I3
aI5
aI9
aI2
aI10
aI29
aI21
aI0
aI20
aI0
aI11
aI4
aI2
aI4
aI11
aI26
aI2
aI11
aI7
asg12
I32
sbag2
(g3
g4
Ntp208
Rp209
(dp210
g8
(lp211
I2
aI19
aI31
aI2
aI25
aI2
aI2
aI2
aI7
aI21
aI2
aI31
aI2
aI2
aI2
aI2
aI2
aI3
aI21
aI2
aI2
aI27
aI2
aI2
aI2
aI2
aI2
aI12
aI2
aI2
aI17
asg10
(lp212
I31
aI10
aI1
aI2
aI3
aI24
aI9
aI2
aI4
aI5
aI2
aI20
aI6
aI11
aI8
aI14
aI15
aI12
aI5
aI9
aI11
aI3
aI2
aI11
aI3
aI0
aI0
aI26
aI0
aI0
aI7
asg12
I32
sbag2
(g3
g4
Ntp213
Rp214
(dp215
g8
(lp216
I2
aI2
aI2
aI17
asg10
(lp217
I8
aI5
aI2
aI7
asg12
I32
sbag2
(g3
g4
Ntp218
Rp219
(dp220
g8
(lp221
I21
aI2
aI2
aI27
aI2
aI2
aI26
aI21
aI2
aI2
aI27
aI2
aI25
aI2
aI2
aI2
aI2
aI17
asg10
(lp222
I5
aI0
aI0
aI3
aI0
aI16
aI3
aI5
aI0
aI2
aI3
aI0
aI3
aI24
aI2
aI2
aI2
aI7
asg12
I32
sbag2
(g3
g4
Ntp223
Rp224
(dp225
g8
(lp226
I12
aI21
aI2
aI2
aI3
aI2
aI2
aI2
aI4
aI2
aI2
aI2
aI2
aI2
aI2
aI8
aI2
aI17
asg10
(lp227
I26
aI5
aI2
aI16
aI12
aI21
aI24
aI11
aI30
aI3
aI2
aI26
aI2
aI11
aI3
aI2
aI11
aI7
asg12
I32
sbag2
(g3
g4
Ntp228
Rp229
(dp230
g8
(lp231
I26
aI21
aI2
aI2
aI27
aI21
aI2
aI2
aI2
aI2
aI2
aI2
aI2
aI2
aI12
aI2
aI7
aI2
aI29
aI2
aI2
aI2
aI2
aI2
aI12
aI21
aI2
aI2
aI2
aI8
aI17
asg10
(lp232
I3
aI5
aI6
aI2
aI3
aI5
aI2
aI16
aI11
aI3
aI11
aI3
aI0
aI0
aI26
aI0
aI4
aI3
aI3
aI30
aI11
aI3
aI0
aI0
aI26
aI5
aI0
aI0
aI0
aI0
aI7
asg12
I32
sbag2
(g3
g4
Ntp233
Rp234
(dp235
g8
(lp236
I25
aI2
aI7
aI21
aI2
aI2
aI2
aI2
aI2
aI2
aI2
aI25
aI2
aI2
aI27
aI21
aI2
aI2
aI2
aI12
aI6
aI2
aI2
aI2
aI2
aI17
asg10
(lp237
I3
aI11
aI4
aI5
aI11
aI16
aI0
aI0
aI0
aI0
aI0
aI3
aI5
aI2
aI3
aI5
aI0
aI0
aI0
aI26
aI5
aI2
aI2
aI0
aI0
aI7
asg12
I32
sbag2
(g3
g4
Ntp238
Rp239
(dp240
g8
(lp241
I2
aI12
aI2
aI2
aI17
asg10
(lp242
I2
aI26
aI2
aI16
aI7
asg12
I32
sbag2
(g3
g4
Ntp243
Rp244
(dp245
g8
(lp246
I21
aI2
aI2
aI7
aI29
aI6
aI2
aI2
aI7
aI2
aI27
aI2
aI12
aI2
aI2
aI2
aI3
aI21
aI2
aI2
aI2
aI7
aI2
aI2
aI2
aI2
aI2
aI2
aI17
asg10
(lp247
I5
aI9
aI2
aI4
aI3
aI5
aI2
aI2
aI4
aI11
aI3
aI11
aI26
aI24
aI11
aI16
aI12
aI5
aI0
aI0
aI0
aI4
aI9
aI3
aI2
aI26
aI9
aI11
aI7
asg12
I32
sbag2
(g3
g4
Ntp248
Rp249
(dp250
g8
(lp251
I21
aI2
aI2
aI2
aI2
aI19
aI7
aI2
aI21
aI2
aI2
aI2
aI21
aI2
aI2
aI17
asg10
(lp252
I5
aI2
aI29
aI8
aI21
aI10
aI4
aI3
aI5
aI2
aI2
aI16
aI5
aI9
aI11
aI7
asg12
I32
sbag2
(g3
g4
Ntp253
Rp254
(dp255
g8
(lp256
I6
aI2
aI2
aI2
aI2
aI20
aI21
aI2
aI7
aI2
aI2
aI12
aI2
aI15
aI2
aI3
aI2
aI2
aI17
asg10
(lp257
I5
aI2
aI2
aI16
aI17
aI3
aI5
aI2
aI4
aI31
aI2
aI26
aI11
aI14
aI17
aI12
aI9
aI11
aI7
asg12
I32
sbaaI32
aI32
a.
@honnibal
Copy link

honnibal commented Aug 21, 2017

I think you could benefit from some pre-computation here. If I'm understanding correctly, on every word, you're computing:

# update hidden state based on most recent
# *predicted* action (not ground truth)
inputs = [actemb, prev_h, rnn_out[t]]
vector = torch.cat(inputs, 1)
hidden1 = combine_arh(vector)
hidden2 = F.relu(hidden1)

You could instead be doing this inside the inner loop:

hidden1 = W_actemb + W_prev_h + W_rnn_out[t]
hidden2 = F.relu(hidden1)

The variable W_actemb is the dot-product of the action embeddings and the hidden layer. You can compute this at the start of the minibatch, and reuse the computation for each word. This is better with bigger batch sizes, obviously. The only part you need to do inside the inner loop is adding the features active for your state, and then applying the non-linearity.

I've been doing this for spaCy's parser. It's especially good with beam search.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment