Created April 1, 2016 20:54
Patch required to make reverse_words work with LSTM
diff --git a/reverse_words/__init__.py b/reverse_words/__init__.py
index b649ab5..ac296e7 100644
--- a/reverse_words/__init__.py
+++ b/reverse_words/__init__.py
@@ -14,7 +14,7 @@ from theano import tensor
 from blocks.bricks import Tanh, Initializable
 from blocks.bricks.base import application
 from blocks.bricks.lookup import LookupTable
-from blocks.bricks.recurrent import SimpleRecurrent, Bidirectional
+from blocks.bricks.recurrent import SimpleRecurrent, LSTM, Bidirectional
 from blocks.bricks.attention import SequenceContentAttention
 from blocks.bricks.parallel import Fork
 from blocks.bricks.sequence_generators import (
@@ -94,13 +94,13 @@ class WordReverser(Initializable):
     def __init__(self, dimension, alphabet_size, **kwargs):
         super(WordReverser, self).__init__(**kwargs)
         encoder = Bidirectional(
-            SimpleRecurrent(dim=dimension, activation=Tanh()))
+            LSTM(dim=dimension, activation=Tanh()))
         fork = Fork([name for name in encoder.prototype.apply.sequences
                      if name != 'mask'])
         fork.input_dim = dimension
         fork.output_dims = [encoder.prototype.get_dim(name) for name in fork.input_names]
         lookup = LookupTable(alphabet_size, dimension)
-        transition = SimpleRecurrent(
+        transition = LSTM(
             activation=Tanh(),
             dim=dimension, name="transition")
         attention = SequenceContentAttention(
@@ -130,7 +130,7 @@ class WordReverser(Initializable):
                 attended=self.encoder.apply(
                     **dict_union(
                         self.fork.apply(self.lookup.apply(chars), as_dict=True),
-                        mask=chars_mask)),
+                        mask=chars_mask))[0],
                 attended_mask=chars_mask)

     @application
@@ -139,7 +139,7 @@ class WordReverser(Initializable):
             n_steps=3 * chars.shape[0], batch_size=chars.shape[1],
             attended=self.encoder.apply(
                 **dict_union(
-                    self.fork.apply(self.lookup.apply(chars), as_dict=True))),
+                    self.fork.apply(self.lookup.apply(chars), as_dict=True)))[0],
             attended_mask=tensor.ones(chars.shape))
@@ -228,7 +228,7 @@ def main(mode, save_path, num_batches, data_path=None):
     # Construct the main loop and start training!
     average_monitoring = TrainingDataMonitoring(
         observables, prefix="average", every_n_batches=10)
-
+
     main_loop = MainLoop(
         model=model,
         data_stream=data_stream,
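For context on the `[0]` indexing in the patch: unlike `SimpleRecurrent`, the `LSTM` brick produces two output sequences (hidden states and memory cells), and `Bidirectional` concatenates the forward and backward runs output by output and returns a list, so only element 0 (the hidden states) should be passed as `attended`. Below is a minimal sketch of that behaviour; the dimensions and initializers are illustrative and not taken from the patch.

# Minimal sketch (illustrative dims/initializers, not from the patch) of why
# encoder.apply(...)[0] is needed: LSTM.apply returns both hidden states and
# memory cells, so a Bidirectional LSTM returns a list of two tensors.
import numpy
import theano
from theano import tensor
from blocks.bricks import Tanh
from blocks.bricks.recurrent import LSTM, Bidirectional
from blocks.initialization import Constant, IsotropicGaussian

dim = 3
encoder = Bidirectional(LSTM(dim=dim, activation=Tanh()))
encoder.weights_init = IsotropicGaussian(0.01)
encoder.biases_init = Constant(0)
encoder.initialize()

# An LSTM brick expects inputs of width 4 * dim (one slab per gate);
# in the patched example this is what the Fork provides.
x = tensor.tensor3('x')
outputs = encoder.apply(x)   # list: [hidden states, memory cells]
states = outputs[0]          # only the hidden states are attended
f = theano.function([x], states)

batch = numpy.ones((5, 2, 4 * dim), dtype=theano.config.floatX)
print(f(batch).shape)  # (5, 2, 2 * dim): forward/backward states concatenated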
Thanks, Dima, for sharing your experience with me. It helped me a lot.
Did you try sampling or beam search as well? I am getting an error during sequence generation. For example, in the reverse_words example, I got a ValueError from here.