import logging
import pprint

import numpy
import theano
import theano.tensor as TT

logger = logging.getLogger(__name__)

# NOTE: this class is an excerpt from GroundHog-style neural MT code.
# Encoder, Decoder, ReplicateLayer, Concatenate and LM_Model are defined
# in the surrounding module / the groundhog package and are not shown here;
# only the imports this excerpt clearly needs are added above.


class RNNEncoderDecoder(object):
    """This class encapsulates the translation model.

    The expected usage pattern is:

    >>> encdec = RNNEncoderDecoder(...)
    >>> encdec.build(...)
    >>> useful_smth = encdec.create_useful_smth(...)

    Functions from the create_smth family (except create_lm_model),
    when called, compile and return functions that do useful stuff.
    """
    def __init__(self, state, rng,
            skip_init=False,
            compute_alignment=False):
        """Constructor.

        :param state:
            A state in the usual groundhog sense.
        :param rng:
            Random number generator. Something like numpy.random.RandomState(seed).
        :param skip_init:
            If True, all the layers are initialized with zeros. Saves time spent on
            parameter initialization if they are loaded later anyway.
        :param compute_alignment:
            If True, the alignment is returned by the decoder.
        """
        self.state = state
        self.rng = rng
        self.skip_init = skip_init
        self.compute_alignment = compute_alignment

    def build(self):
        logger.debug("Create input variables")
        self.x = TT.lmatrix('x')
        self.x_mask = TT.matrix('x_mask')
        self.y = TT.lmatrix('y')
        self.y_mask = TT.matrix('y_mask')
        self.inputs = [self.x, self.y, self.x_mask, self.y_mask]

        # Annotation for the log-likelihood computation
        training_c_components = []

        logger.debug("Create encoder")
        self.encoder = Encoder(self.state, self.rng,
                prefix="enc",
                skip_init=self.skip_init)
        self.encoder.create_layers()

        logger.debug("Build encoding computation graph")
        forward_training_c = self.encoder.build_encoder(
                self.x, self.x_mask,
                use_noise=True,
                return_hidden_layers=True)

        if self.state['encoder_stack'] > 0:
            logger.debug("Create backward encoder")
            self.backward_encoder = Encoder(self.state, self.rng,
                    prefix="back_enc",
                    skip_init=self.skip_init)
            self.backward_encoder.create_layers()

            logger.debug("Build backward encoding computation graph")
            backward_training_c = self.backward_encoder.build_encoder(
                    self.x[::-1],
                    self.x_mask[::-1],
                    use_noise=True,
                    approx_embeddings=self.encoder.approx_embedder(self.x[::-1]),
                    return_hidden_layers=True)
            # Reverse time for backward representations.
            backward_training_c.out = backward_training_c.out[::-1]

        if self.state['encoder_stack'] < 1:
            training_c_components.append(forward_training_c)
        else:
            if self.state['forward']:
                training_c_components.append(forward_training_c)
            if self.state['last_forward']:
                training_c_components.append(
                        ReplicateLayer(self.x.shape[0])(forward_training_c[-1]))
            if self.state['backward']:
                training_c_components.append(backward_training_c)
            if self.state['last_backward']:
                training_c_components.append(ReplicateLayer(self.x.shape[0])
                        (backward_training_c[0]))
        self.state['c_dim'] = len(training_c_components) * self.state['dim']

        logger.debug("Create decoder")
        self.decoder = Decoder(self.state, self.rng,
                skip_init=self.skip_init, compute_alignment=self.compute_alignment)
        self.decoder.create_layers()
        logger.debug("Build log-likelihood computation graph")
        self.predictions, self.alignment = self.decoder.build_decoder(
                c=Concatenate(axis=2)(*training_c_components), c_mask=self.x_mask,
                y=self.y, y_mask=self.y_mask)

        # Annotation for sampling
        sampling_c_components = []

        logger.debug("Build sampling computation graph")
        self.sampling_x = TT.lvector("sampling_x")
        self.n_samples = TT.lscalar("n_samples")
        self.n_steps = TT.lscalar("n_steps")
        self.T = TT.scalar("T")
        self.forward_sampling_c = self.encoder.build_encoder(
                self.sampling_x,
                return_hidden_layers=True).out
        if self.state['encoder_stack'] > 0:
            self.backward_sampling_c = self.backward_encoder.build_encoder(
                    self.sampling_x[::-1],
                    approx_embeddings=self.encoder.approx_embedder(self.sampling_x[::-1]),
                    return_hidden_layers=True).out[::-1]
        if self.state['encoder_stack'] < 1:
            sampling_c_components.append(self.forward_sampling_c)
        else:
            if self.state['forward']:
                sampling_c_components.append(self.forward_sampling_c)
            if self.state['last_forward']:
                sampling_c_components.append(ReplicateLayer(self.sampling_x.shape[0])
                        (self.forward_sampling_c[-1]))
            if self.state['backward']:
                sampling_c_components.append(self.backward_sampling_c)
            if self.state['last_backward']:
                sampling_c_components.append(ReplicateLayer(self.sampling_x.shape[0])
                        (self.backward_sampling_c[0]))

        self.sampling_c = Concatenate(axis=1)(*sampling_c_components).out
        (self.sample, self.sample_log_prob), self.sampling_updates =\
            self.decoder.build_sampler(self.n_samples, self.n_steps, self.T,
                    c=self.sampling_c)

        logger.debug("Create auxiliary variables")
        self.c = TT.matrix("c")
        self.step_num = TT.lscalar("step_num")
        self.current_states = [TT.matrix("cur_{}".format(i))
                for i in range(self.decoder.num_levels)]
        self.gen_y = TT.lvector("gen_y")

    def create_lm_model(self):
        if hasattr(self, 'lm_model'):
            return self.lm_model
        self.lm_model = LM_Model(
            cost_layer=self.predictions,
            sample_fn=self.create_sampler(),
            weight_noise_amount=self.state['weight_noise_amount'],
            indx_word=self.state['indx_word_target'],
            indx_word_src=self.state['indx_word'],
            rng=self.rng)
        self.lm_model.load_dict(self.state)
        logger.debug("Model params:\n{}".format(
            pprint.pformat(sorted([p.name for p in self.lm_model.params]))))
        return self.lm_model

    def create_representation_computer(self):
        if not hasattr(self, "repr_fn"):
            self.repr_fn = theano.function(
                    inputs=[self.sampling_x],
                    outputs=[self.sampling_c],
                    name="repr_fn")
        return self.repr_fn

    def create_initializers(self):
        if not hasattr(self, "init_fn"):
            init_c = self.sampling_c[0, -self.state['dim']:]
            self.init_fn = theano.function(
                    inputs=[self.sampling_c],
                    outputs=self.decoder.build_initializers(init_c),
                    name="init_fn")
        return self.init_fn

    def create_sampler(self, many_samples=False):
        if hasattr(self, 'sample_fn'):
            return self.sample_fn
        logger.debug("Compile sampler")
        self.sample_fn = theano.function(
                inputs=[self.n_samples, self.n_steps, self.T, self.sampling_x],
                outputs=[self.sample, self.sample_log_prob],
                updates=self.sampling_updates,
                name="sample_fn")
        if not many_samples:
            def sampler(*args):
                return map(lambda x: x.squeeze(), self.sample_fn(1, *args))
            return sampler
        return self.sample_fn

    def create_scorer(self, batch=False):
        if not hasattr(self, 'score_fn'):
            logger.debug("Compile scorer")
            self.score_fn = theano.function(
                    inputs=self.inputs,
                    outputs=[-self.predictions.cost_per_sample],
                    name="score_fn")
        if batch:
            return self.score_fn
        def scorer(x, y):
            x_mask = numpy.ones(x.shape[0], dtype="float32")
            y_mask = numpy.ones(y.shape[0], dtype="float32")
            return self.score_fn(x[:, None], y[:, None],
                    x_mask[:, None], y_mask[:, None])
        return scorer

    def create_next_probs_computer(self):
        if not hasattr(self, 'next_probs_fn'):
            self.next_probs_fn = theano.function(
                    inputs=[self.c, self.step_num, self.gen_y] + self.current_states,
                    outputs=[self.decoder.build_next_probs_predictor(
                        self.c, self.step_num, self.gen_y, self.current_states)],
                    name="next_probs_fn")
        return self.next_probs_fn

    def create_next_states_computer(self):
        if not hasattr(self, 'next_states_fn'):
            self.next_states_fn = theano.function(
                    inputs=[self.c, self.step_num, self.gen_y] + self.current_states,
                    outputs=self.decoder.build_next_states_computer(
                        self.c, self.step_num, self.gen_y, self.current_states),
                    name="next_states_fn")
        return self.next_states_fn

    def create_probs_computer(self, return_alignment=False):
        if not hasattr(self, 'probs_fn'):
            logger.debug("Compile probs computer")
            self.probs_fn = theano.function(
                    inputs=self.inputs,
                    outputs=[self.predictions.word_probs, self.alignment],
                    name="probs_fn")
        def probs_computer(x, y):
            x_mask = numpy.ones(x.shape[0], dtype="float32")
            y_mask = numpy.ones(y.shape[0], dtype="float32")
            probs, alignment = self.probs_fn(x[:, None], y[:, None],
                    x_mask[:, None], y_mask[:, None])
            if return_alignment:
                return probs, alignment
            else:
                return probs
        return probs_computer
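

# ----------------------------------------------------------------------------
# Minimal usage sketch (added for illustration; not part of the original
# class).  It spells out the pattern named in the class docstring: construct,
# build, create the LM wrapper, then compile the helpers you need.  The
# `prototype_state` import and the weights file name are assumptions based on
# how GroundHog-style sampling/scoring scripts are usually driven, and may
# differ in your setup.
if __name__ == "__main__":
    from experiments.nmt import prototype_state  # assumed GroundHog helper

    state = prototype_state()                 # groundhog-style state dictionary
    rng = numpy.random.RandomState(state['seed'])

    # skip_init=True: parameters are loaded from a trained model below,
    # so time spent on random initialization would be wasted.
    enc_dec = RNNEncoderDecoder(state, rng, skip_init=True)
    enc_dec.build()

    lm_model = enc_dec.create_lm_model()
    lm_model.load("model.npz")                # hypothetical path to trained weights

    # Each create_* call compiles its Theano function once and caches it.
    sampler = enc_dec.create_sampler(many_samples=True)
    scorer = enc_dec.create_scorer(batch=False)       # expects 1-D index vectors
    probs_computer = enc_dec.create_probs_computer()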