class RNNEncoderDecoder(object):
    """This class encapsulates the translation model.

    The expected usage pattern is:
    >>> encdec = RNNEncoderDecoder(...)
    >>> encdec.build(...)
    >>> useful_smth = encdec.create_useful_smth(...)

    Functions from the create_smth family (except create_lm_model),
    when called, compile and return functions that do useful stuff.
    """
    def __init__(self, state, rng,
                 skip_init=False,
                 compute_alignment=False):
        """Constructor.

        :param state:
            A state in the usual groundhog sense.
        :param rng:
            Random number generator. Something like numpy.random.RandomState(seed).
        :param skip_init:
            If True, all the layers are initialized with zeros. Saves time spent on
            parameter initialization if they are loaded later anyway.
        :param compute_alignment:
            If True, the alignment is returned by the decoder.
        """
        self.state = state
        self.rng = rng
        self.skip_init = skip_init
        self.compute_alignment = compute_alignment
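
    # A minimal construction sketch (hedged: `state` is a GroundHog-style
    # configuration dict whose exact contents come from the experiment
    # prototype, not from this class):
    #
    #   rng = numpy.random.RandomState(1234)
    #   encdec = RNNEncoderDecoder(state, rng, skip_init=False)
    #   encdec.build()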
    def build(self):
        logger.debug("Create input variables")
        self.x = TT.lmatrix('x')
        self.x_mask = TT.matrix('x_mask')
        self.y = TT.lmatrix('y')
        self.y_mask = TT.matrix('y_mask')
        self.inputs = [self.x, self.y, self.x_mask, self.y_mask]

        # Annotations for the log-likelihood computation
        training_c_components = []

        logger.debug("Create encoder")
        self.encoder = Encoder(self.state, self.rng,
                               prefix="enc",
                               skip_init=self.skip_init)
        self.encoder.create_layers()

        logger.debug("Build encoding computation graph")
        forward_training_c = self.encoder.build_encoder(
            self.x, self.x_mask,
            use_noise=True,
            return_hidden_layers=True)

        if self.state['encoder_stack'] > 0:
            logger.debug("Create backward encoder")
            self.backward_encoder = Encoder(self.state, self.rng,
                                            prefix="back_enc",
                                            skip_init=self.skip_init)
            self.backward_encoder.create_layers()

            logger.debug("Build backward encoding computation graph")
            # The backward encoder reads the reversed source, sharing the
            # forward encoder's embeddings.
            backward_training_c = self.backward_encoder.build_encoder(
                self.x[::-1],
                self.x_mask[::-1],
                use_noise=True,
                approx_embeddings=self.encoder.approx_embedder(self.x[::-1]),
                return_hidden_layers=True)
            # Reverse time for backward representations.
            backward_training_c.out = backward_training_c.out[::-1]

        if self.state['encoder_stack'] < 1:
            training_c_components.append(forward_training_c)
        else:
            # Assemble the annotation from the components enabled in the
            # state: per-step forward/backward states and, optionally, the
            # final state of either direction replicated over all steps.
            if self.state['forward']:
                training_c_components.append(forward_training_c)
            if self.state['last_forward']:
                training_c_components.append(
                    ReplicateLayer(self.x.shape[0])(forward_training_c[-1]))
            if self.state['backward']:
                training_c_components.append(backward_training_c)
            if self.state['last_backward']:
                training_c_components.append(ReplicateLayer(self.x.shape[0])
                                             (backward_training_c[0]))
        self.state['c_dim'] = len(training_c_components) * self.state['dim']

        logger.debug("Create decoder")
        self.decoder = Decoder(self.state, self.rng,
                               skip_init=self.skip_init,
                               compute_alignment=self.compute_alignment)
        self.decoder.create_layers()
        logger.debug("Build log-likelihood computation graph")
        self.predictions, self.alignment = self.decoder.build_decoder(
            c=Concatenate(axis=2)(*training_c_components),
            c_mask=self.x_mask,
            y=self.y, y_mask=self.y_mask)

        # Annotations for sampling
        sampling_c_components = []

        logger.debug("Build sampling computation graph")
        self.sampling_x = TT.lvector("sampling_x")
        self.n_samples = TT.lscalar("n_samples")
        self.n_steps = TT.lscalar("n_steps")
        self.T = TT.scalar("T")
        self.forward_sampling_c = self.encoder.build_encoder(
            self.sampling_x,
            return_hidden_layers=True).out
        if self.state['encoder_stack'] > 0:
            self.backward_sampling_c = self.backward_encoder.build_encoder(
                self.sampling_x[::-1],
                approx_embeddings=self.encoder.approx_embedder(self.sampling_x[::-1]),
                return_hidden_layers=True).out[::-1]

        if self.state['encoder_stack'] < 1:
            sampling_c_components.append(self.forward_sampling_c)
        else:
            if self.state['forward']:
                sampling_c_components.append(self.forward_sampling_c)
            if self.state['last_forward']:
                sampling_c_components.append(ReplicateLayer(self.sampling_x.shape[0])
                                             (self.forward_sampling_c[-1]))
            if self.state['backward']:
                sampling_c_components.append(self.backward_sampling_c)
            if self.state['last_backward']:
                sampling_c_components.append(ReplicateLayer(self.sampling_x.shape[0])
                                             (self.backward_sampling_c[0]))

        self.sampling_c = Concatenate(axis=1)(*sampling_c_components).out
        (self.sample, self.sample_log_prob), self.sampling_updates = \
            self.decoder.build_sampler(self.n_samples, self.n_steps, self.T,
                                       c=self.sampling_c)

        logger.debug("Create auxiliary variables")
        self.c = TT.matrix("c")
        self.step_num = TT.lscalar("step_num")
        self.current_states = [TT.matrix("cur_{}".format(i))
                               for i in range(self.decoder.num_levels)]
        self.gen_y = TT.lvector("gen_y")
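
    # Shape conventions assumed by the graph above (inferred from the
    # Theano variable types, not stated in the original source):
    #   x, y            int64 matrices of shape (n_timesteps, n_sequences)
    #   x_mask, y_mask  floatX matrices of the same shape, 1 where a
    #                   position holds a real token, 0 where it is padding
    #   sampling_x      a single source sentence as an int64 vector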
    def create_lm_model(self):
        if hasattr(self, 'lm_model'):
            return self.lm_model
        self.lm_model = LM_Model(
            cost_layer=self.predictions,
            sample_fn=self.create_sampler(),
            weight_noise_amount=self.state['weight_noise_amount'],
            indx_word=self.state['indx_word_target'],
            indx_word_src=self.state['indx_word'],
            rng=self.rng)
        self.lm_model.load_dict(self.state)
        logger.debug("Model params:\n{}".format(
            pprint.pformat(sorted([p.name for p in self.lm_model.params]))))
        return self.lm_model
    def create_representation_computer(self):
        if not hasattr(self, "repr_fn"):
            self.repr_fn = theano.function(
                inputs=[self.sampling_x],
                outputs=[self.sampling_c],
                name="repr_fn")
        return self.repr_fn
    def create_initializers(self):
        if not hasattr(self, "init_fn"):
            init_c = self.sampling_c[0, -self.state['dim']:]
            self.init_fn = theano.function(
                inputs=[self.sampling_c],
                outputs=self.decoder.build_initializers(init_c),
                name="init_fn")
        return self.init_fn
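
    # Note on `init_c` above: with the usual bidirectional configuration
    # (forward and backward components both active), the last `dim` units
    # of the first annotation row are the backward encoder's state at t=0,
    # the standard choice for initializing the decoder. This reading is an
    # inference from the concatenation order, not documented behaviour.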
    def create_sampler(self, many_samples=False):
        if hasattr(self, 'sample_fn'):
            return self.sample_fn
        logger.debug("Compile sampler")
        self.sample_fn = theano.function(
            inputs=[self.n_samples, self.n_steps, self.T, self.sampling_x],
            outputs=[self.sample, self.sample_log_prob],
            updates=self.sampling_updates,
            name="sample_fn")
        if not many_samples:
            def sampler(*args):
                return map(lambda x: x.squeeze(), self.sample_fn(1, *args))
            return sampler
        return self.sample_fn
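
    # Hypothetical call into the single-sample wrapper returned above
    # (many_samples=False): the wrapper pins n_samples to 1 and squeezes
    # the outputs, so the remaining arguments follow the compiled inputs:
    #
    #   sampler = encdec.create_sampler()
    #   sample, log_prob = sampler(n_steps, T, source_word_indices)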
    def create_scorer(self, batch=False):
        if not hasattr(self, 'score_fn'):
            logger.debug("Compile scorer")
            self.score_fn = theano.function(
                inputs=self.inputs,
                outputs=[-self.predictions.cost_per_sample],
                name="score_fn")
        if batch:
            return self.score_fn
        def scorer(x, y):
            x_mask = numpy.ones(x.shape[0], dtype="float32")
            y_mask = numpy.ones(y.shape[0], dtype="float32")
            return self.score_fn(x[:, None], y[:, None],
                                 x_mask[:, None], y_mask[:, None])
        return scorer
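
    # Hypothetical use of the single-pair wrapper returned above: x and y
    # are 1-D index arrays; the wrapper adds the batch axis and all-ones
    # masks itself, and the compiled function returns a one-element list
    # holding the per-sample log-likelihood (negated cost):
    #
    #   scorer = encdec.create_scorer()
    #   log_prob = scorer(numpy.asarray(src_ids), numpy.asarray(trg_ids))[0]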
    def create_next_probs_computer(self):
        if not hasattr(self, 'next_probs_fn'):
            self.next_probs_fn = theano.function(
                inputs=[self.c, self.step_num, self.gen_y] + self.current_states,
                outputs=[self.decoder.build_next_probs_predictor(
                    self.c, self.step_num, self.gen_y, self.current_states)],
                name="next_probs_fn")
        return self.next_probs_fn
    def create_next_states_computer(self):
        if not hasattr(self, 'next_states_fn'):
            self.next_states_fn = theano.function(
                inputs=[self.c, self.step_num, self.gen_y] + self.current_states,
                outputs=self.decoder.build_next_states_computer(
                    self.c, self.step_num, self.gen_y, self.current_states),
                name="next_states_fn")
        return self.next_states_fn
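
    # Together, the two computers above are step primitives for search:
    # next_probs_fn yields the next-word distribution given the annotation
    # `c`, the step number, the previously generated words and the current
    # decoder states, while next_states_fn advances those states once the
    # next words have been chosen.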
    def create_probs_computer(self, return_alignment=False):
        if not hasattr(self, 'probs_fn'):
            logger.debug("Compile probs computer")
            self.probs_fn = theano.function(
                inputs=self.inputs,
                outputs=[self.predictions.word_probs, self.alignment],
                name="probs_fn")
        def probs_computer(x, y):
            x_mask = numpy.ones(x.shape[0], dtype="float32")
            y_mask = numpy.ones(y.shape[0], dtype="float32")
            probs, alignment = self.probs_fn(x[:, None], y[:, None],
                                             x_mask[:, None], y_mask[:, None])
            if return_alignment:
                return probs, alignment
            else:
                return probs
        return probs_computer
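
# A minimal end-to-end sketch of the intended call sequence, assuming a
# prepared GroundHog `state` dict (data paths, dimensions, etc. omitted):
#
#   encdec = RNNEncoderDecoder(state, numpy.random.RandomState(1234))
#   encdec.build()
#   lm_model = encdec.create_lm_model()   # training-time wrapper
#   sampler = encdec.create_sampler()     # draws translations
#   scorer = encdec.create_scorer()       # log-likelihood of (x, y) pairs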