class Decoder(EncoderDecoderBase):

    EVALUATION = 0
    SAMPLING = 1
    BEAM_SEARCH = 2
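    # The decoder graph can be built in one of three modes (see build_decoder):
    # EVALUATION scores full target sequences against the encoder output,
    # while SAMPLING and BEAM_SEARCH advance the decoder one step at a time
    # from previously computed hidden states.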
    def __init__(self, state, rng, prefix='dec',
                 skip_init=False, compute_alignment=False):
        self.state = state
        self.rng = rng
        self.prefix = prefix
        self.skip_init = skip_init
        self.compute_alignment = compute_alignment

        # Actually there is a problem here -
        # we do not distinguish between the number of input layers
        # and the number of output layers.
        self.num_levels = self.state['decoder_stack']

        if 'dim_mult' not in self.state:
            self.state['dim_mult'] = 1.
    def create_layers(self):
        """Create all elements of the Decoder's computation graph."""

        self.default_kwargs = dict(
            init_fn=self.state['weight_init_fn'] if not self.skip_init else "sample_zeros",
            weight_noise=self.state['weight_noise'],
            scale=self.state['weight_scale'])

        self._create_embedding_layers()
        self._create_transition_layers()
        self._create_inter_level_layers()
        self._create_initialization_layers()
        self._create_decoding_layers()
        self._create_readout_layers()

        if self.state['search']:
            assert self.num_levels == 1
            self.transitions[0].set_decoding_layers(
                self.decode_inputers[0],
                self.decode_reseters[0],
                self.decode_updaters[0])
    def _create_initialization_layers(self):
        logger.debug("_create_initialization_layers")
        self.initializers = [ZeroLayer()] * self.num_levels
        if self.state['bias_code']:
            for level in range(self.num_levels):
                self.initializers[level] = MultiLayer(
                    self.rng,
                    n_in=self.state['dim'],
                    n_hids=[self.state['dim'] * self.state['hid_mult']],
                    activation=[prefix_lookup(self.state, 'dec', 'activ')],
                    bias_scale=[self.state['bias']],
                    name='{}_initializer_{}'.format(self.prefix, level),
                    **self.default_kwargs)
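    # Note: in build_decoder and build_sampler these initializers are fed the
    # last state['dim'] units of the first encoder state (c[0, ..., -dim:])
    # to produce each decoder level's initial hidden state.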
    def _create_decoding_layers(self):
        logger.debug("_create_decoding_layers")
        self.decode_inputers = [lambda x: 0] * self.num_levels
        self.decode_reseters = [lambda x: 0] * self.num_levels
        self.decode_updaters = [lambda x: 0] * self.num_levels
        self.back_decode_inputers = [lambda x: 0] * self.num_levels
        self.back_decode_reseters = [lambda x: 0] * self.num_levels
        self.back_decode_updaters = [lambda x: 0] * self.num_levels

        decoding_kwargs = dict(self.default_kwargs)
        decoding_kwargs.update(dict(
            n_in=self.state['c_dim'],
            n_hids=self.state['dim'] * self.state['dim_mult'],
            activation=['lambda x:x'],
            learn_bias=False))

        if self.state['decoding_inputs']:
            for level in range(self.num_levels):
                # Input contributions
                self.decode_inputers[level] = MultiLayer(
                    self.rng,
                    name='{}_dec_inputter_{}'.format(self.prefix, level),
                    **decoding_kwargs)
                # Update gate contributions
                if prefix_lookup(self.state, 'dec', 'rec_gating'):
                    self.decode_updaters[level] = MultiLayer(
                        self.rng,
                        name='{}_dec_updater_{}'.format(self.prefix, level),
                        **decoding_kwargs)
                # Reset gate contributions
                if prefix_lookup(self.state, 'dec', 'rec_reseting'):
                    self.decode_reseters[level] = MultiLayer(
                        self.rng,
                        name='{}_dec_reseter_{}'.format(self.prefix, level),
                        **decoding_kwargs)
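    # When state['decoding_inputs'] is set, the layers above project the
    # encoder representation c into the input, update and reset signals of
    # each recurrent level. Without the search mechanism they are applied
    # directly in build_decoder; with state['search'] they are handed to the
    # transition via set_decoding_layers (see create_layers above).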
    def _create_readout_layers(self):
        softmax_layer = self.state['softmax_layer'] if 'softmax_layer' in self.state \
            else 'SoftmaxLayer'

        logger.debug("_create_readout_layers")

        readout_kwargs = dict(self.default_kwargs)
        readout_kwargs.update(dict(
            n_hids=self.state['dim'],
            activation='lambda x: x',
        ))

        self.repr_readout = MultiLayer(
            self.rng,
            n_in=self.state['c_dim'],
            learn_bias=False,
            name='{}_repr_readout'.format(self.prefix),
            **readout_kwargs)

        # Note: these are the only readout layers
        # with a trainable bias. Be careful with that.
        self.hidden_readouts = [None] * self.num_levels
        for level in range(self.num_levels):
            self.hidden_readouts[level] = MultiLayer(
                self.rng,
                n_in=self.state['dim'],
                name='{}_hid_readout_{}'.format(self.prefix, level),
                **readout_kwargs)

        self.prev_word_readout = 0
        if self.state['bigram']:
            self.prev_word_readout = MultiLayer(
                self.rng,
                n_in=self.state['rank_n_approx'],
                n_hids=self.state['dim'],
                activation=['lambda x:x'],
                learn_bias=False,
                name='{}_prev_readout_{}'.format(self.prefix, level),
                **self.default_kwargs)

        if self.state['deep_out']:
            act_layer = UnaryOp(activation=eval(self.state['unary_activ']))
            drop_layer = DropOp(rng=self.rng, dropout=self.state['dropout'])
            self.output_nonlinearities = [act_layer, drop_layer]
            self.output_layer = eval(softmax_layer)(
                self.rng,
                self.state['dim'] / self.state['maxout_part'],
                self.state['n_sym_target'],
                sparsity=-1,
                rank_n_approx=self.state['rank_n_approx'],
                name='{}_deep_softmax'.format(self.prefix),
                use_nce=self.state['use_nce'] if 'use_nce' in self.state else False,
                **self.default_kwargs)
        else:
            self.output_nonlinearities = []
            self.output_layer = eval(softmax_layer)(
                self.rng,
                self.state['dim'],
                self.state['n_sym_target'],
                sparsity=-1,
                rank_n_approx=self.state['rank_n_approx'],
                name='dec_softmax',
                sum_over_time=True,
                use_nce=self.state['use_nce'] if 'use_nce' in self.state else False,
                **self.default_kwargs)
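    # The readout fed to the softmax (see build_decoder) is the sum of three
    # projections: the context (repr_readout), the decoder hidden states
    # (hidden_readouts) and, when state['bigram'] is set, the previous target
    # word (prev_word_readout). With state['deep_out'] the sum additionally
    # passes through the unary activation (cf. 'unary_activ' / 'maxout_part')
    # and dropout before the softmax.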
    def build_decoder(self, c, y,
                      c_mask=None,
                      y_mask=None,
                      step_num=None,
                      mode=EVALUATION,
                      given_init_states=None,
                      T=1):
        """Create the computational graph of the RNN Decoder.

        :param c:
            representations produced by an encoder.
            (n_samples, dim) matrix if mode == sampling or
            (max_seq_len, batch_size, dim) matrix if mode == evaluation
        :param c_mask:
            if mode == evaluation, a 0/1 matrix identifying valid positions in c
        :param y:
            if mode == evaluation
                target sequences, a matrix of word indices of shape
                (max_seq_len, batch_size), where each column is a sequence
            if mode != evaluation
                a vector of previous words of shape (n_samples,)
        :param y_mask:
            if mode == evaluation, a 0/1 matrix determining lengths
            of the target sequences; must be None otherwise
        :param mode:
            chooses one of three modes: evaluation, sampling and beam_search
        :param given_init_states:
            for sampling and beam_search. A list of hidden-state matrices,
            one per layer; each matrix is (n_samples, dim)
        :param T:
            sampling temperature
        """
        # Check parameter consistency
        if mode == Decoder.EVALUATION:
            assert not given_init_states
        else:
            assert not y_mask
            assert given_init_states
            if mode == Decoder.BEAM_SEARCH:
                assert T == 1

        # For log-likelihood evaluation the representation
        # is replicated for convenience. This is not done
        # when the search (attention) mechanism is used.
        # Shape if mode == evaluation
        #   (max_seq_len, batch_size, dim)
        # Shape if mode != evaluation
        #   (n_samples, dim)
        if not self.state['search']:
            if mode == Decoder.EVALUATION:
                c = PadLayer(y.shape[0])(c)
            else:
                assert step_num
                c_pos = TT.minimum(step_num, c.shape[0] - 1)
        # Low rank embeddings of all the input words.
        # Shape if mode == evaluation
        #   (n_words, rank_n_approx)
        # Shape if mode != evaluation
        #   (n_samples, rank_n_approx)
        approx_embeddings = self.approx_embedder(y)

        # Low rank embeddings are projected to contribute
        # to the input, reset and update signals.
        # All the shapes if mode == evaluation:
        #   (n_words, dim)
        #   where n_words = max_seq_len * batch_size
        # All the shapes if mode != evaluation:
        #   (n_samples, dim)
        input_signals = []
        reset_signals = []
        update_signals = []
        for level in range(self.num_levels):
            # Contributions directly from input words.
            input_signals.append(self.input_embedders[level](approx_embeddings))
            update_signals.append(self.update_embedders[level](approx_embeddings))
            reset_signals.append(self.reset_embedders[level](approx_embeddings))

            # Contributions from the encoded source sentence.
            if not self.state['search']:
                current_c = c if mode == Decoder.EVALUATION else c[c_pos]
                input_signals[level] += self.decode_inputers[level](current_c)
                update_signals[level] += self.decode_updaters[level](current_c)
                reset_signals[level] += self.decode_reseters[level](current_c)

        # Hidden layers' initial states.
        # Shapes if mode == evaluation:
        #   (batch_size, dim)
        # Shapes if mode != evaluation:
        #   (n_samples, dim)
        init_states = given_init_states
        if not init_states:
            init_states = []
            for level in range(self.num_levels):
                init_c = c[0, :, -self.state['dim']:]
                init_states.append(self.initializers[level](init_c))

        # Hidden layers' states.
        # Shapes if mode == evaluation:
        #   (seq_len, batch_size, dim)
        # Shapes if mode != evaluation:
        #   (n_samples, dim)
        hidden_layers = []
        contexts = []
        # The default value for the alignment must be something computable.
        alignment = TT.zeros((1,))
        for level in range(self.num_levels):
            if level > 0:
                input_signals[level] += self.inputers[level](hidden_layers[level - 1])
                update_signals[level] += self.updaters[level](hidden_layers[level - 1])
                reset_signals[level] += self.reseters[level](hidden_layers[level - 1])
            add_kwargs = (dict(state_before=init_states[level])
                          if mode != Decoder.EVALUATION
                          else dict(init_state=init_states[level],
                                    batch_size=y.shape[1] if y.ndim == 2 else 1,
                                    nsteps=y.shape[0]))
            if self.state['search']:
                add_kwargs['c'] = c
                add_kwargs['c_mask'] = c_mask
                add_kwargs['return_alignment'] = self.compute_alignment
                if mode != Decoder.EVALUATION:
                    add_kwargs['step_num'] = step_num
            result = self.transitions[level](
                input_signals[level],
                mask=y_mask,
                gater_below=none_if_zero(update_signals[level]),
                reseter_below=none_if_zero(reset_signals[level]),
                one_step=mode != Decoder.EVALUATION,
                use_noise=mode == Decoder.EVALUATION,
                **add_kwargs)

            if self.state['search']:
                if self.compute_alignment:
                    # This implicitly wraps each element of result.out with a Layer
                    # to keep track of the parameters.
                    # It is equivalent to h = result[0], ctx = result[1], etc.
                    h, ctx, alignment = result
                    if mode == Decoder.EVALUATION:
                        alignment = alignment.out
                else:
                    # This implicitly wraps each element of result.out with a Layer
                    # to keep track of the parameters.
                    # It is equivalent to h = result[0], ctx = result[1].
                    h, ctx = result
            else:
                h = result
                if mode == Decoder.EVALUATION:
                    ctx = c
                else:
                    ctx = ReplicateLayer(given_init_states[0].shape[0])(c[c_pos]).out
            hidden_layers.append(h)
            contexts.append(ctx)
        # import pdb; pdb.set_trace()  # leftover debugging breakpoint, disabled

        # In hidden_layers we do not have the initial state, but we need it.
        # Instead we have the last one, which we do not need.
        # So we discard the last one and prepend the initial one.
        if mode == Decoder.EVALUATION:
            for level in range(self.num_levels):
                hidden_layers[level].out = TT.concatenate([
                    TT.shape_padleft(init_states[level].out),
                    hidden_layers[level].out])[:-1]
        # The output representation to be fed into the softmax.
        # Shape if mode == evaluation
        #   (n_words, dim_r)
        # Shape if mode != evaluation
        #   (n_samples, dim_r)
        # ... where dim_r depends on the 'deep_out' option.
        readout = self.repr_readout(contexts[0])
        for level in range(self.num_levels):
            if mode != Decoder.EVALUATION:
                read_from = init_states[level]
            else:
                read_from = hidden_layers[level]
            read_from_var = read_from if type(read_from) == theano.tensor.TensorVariable else read_from.out
            if read_from_var.ndim == 3:
                read_from_var = read_from_var[:, :, :self.state['dim']]
            else:
                read_from_var = read_from_var[:, :self.state['dim']]
            if type(read_from) != theano.tensor.TensorVariable:
                read_from.out = read_from_var
            else:
                read_from = read_from_var
            readout += self.hidden_readouts[level](read_from)
        if self.state['bigram']:
            if mode != Decoder.EVALUATION:
                check_first_word = (y > 0
                                    if self.state['check_first_word']
                                    else TT.ones((y.shape[0]), dtype="float32"))
                # padright is necessary as we want to multiply each row by a scalar
                readout += TT.shape_padright(check_first_word) * self.prev_word_readout(approx_embeddings).out
            else:
                if y.ndim == 1:
                    readout += Shift()(self.prev_word_readout(approx_embeddings).reshape(
                        (y.shape[0], 1, self.state['dim'])))
                else:
                    # This place needs explanation. When prev_word_readout is applied to
                    # approx_embeddings the resulting shape is
                    # (n_batches * sequence_length, repr_dimensionality). We first
                    # transform it into a 3D tensor to shift it forward in time, then
                    # reshape it back.
                    readout += Shift()(self.prev_word_readout(approx_embeddings).reshape(
                        (y.shape[0], y.shape[1], self.state['dim']))).reshape(
                            readout.out.shape)
        for fun in self.output_nonlinearities:
            readout = fun(readout)

        if mode == Decoder.SAMPLING:
            sample = self.output_layer.get_sample(
                state_below=readout,
                temp=T)
            # The current SoftmaxLayer.get_cost interface is awkward,
            # which is why we have to reshape a lot.
            self.output_layer.get_cost(
                state_below=readout.out,
                temp=T,
                target=sample)
            log_prob = self.output_layer.cost_per_sample
            return [sample] + [log_prob] + hidden_layers
        elif mode == Decoder.BEAM_SEARCH:
            return self.output_layer(
                state_below=readout.out,
                temp=T).out
        elif mode == Decoder.EVALUATION:
            return (self.output_layer.train(
                        state_below=readout,
                        target=y,
                        mask=y_mask,
                        reg=None),
                    alignment)
        else:
            raise Exception("Unknown mode for build_decoder")
    def sampling_step(self, *args):
        """Implements one step of sampling.

        Args are handled manually since their number (and order) can vary."""

        args = iter(args)

        # Arguments that correspond to scan's "sequences" parameter:
        step_num = next(args)
        assert step_num.ndim == 0

        # Arguments that correspond to scan's "outputs_info" parameter:
        prev_word = next(args)
        assert prev_word.ndim == 1
        # Skip the previous word's log probability.
        assert next(args).ndim == 1
        prev_hidden_states = [next(args) for k in range(self.num_levels)]
        assert prev_hidden_states[0].ndim == 2

        # Arguments that correspond to scan's "non_sequences":
        c = next(args)
        assert c.ndim == 2
        T = next(args)
        assert T.ndim == 0

        decoder_args = dict(given_init_states=prev_hidden_states, T=T, c=c)
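        # The decoder graph is built twice below: the first call draws a word
        # sample (and its log-probability) conditioned on prev_word and the
        # previous hidden states, and the second call recomputes the hidden
        # states conditioned on the word that was just sampled.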
        sample, log_prob = self.build_decoder(
            y=prev_word, step_num=step_num, mode=Decoder.SAMPLING, **decoder_args)[:2]
        hidden_states = self.build_decoder(
            y=sample, step_num=step_num, mode=Decoder.SAMPLING, **decoder_args)[2:]
        return [sample, log_prob] + hidden_states
    def build_initializers(self, c):
        return [init(c).out for init in self.initializers]

    def build_sampler(self, n_samples, n_steps, T, c):
        states = [TT.zeros(shape=(n_samples,), dtype='int64'),
                  TT.zeros(shape=(n_samples,), dtype='float32')]
        init_c = c[0, -self.state['dim']:]
        states += [ReplicateLayer(n_samples)(init(init_c).out).out
                   for init in self.initializers]

        if not self.state['search']:
            c = PadLayer(n_steps)(c).out

        # Pad with final states
        non_sequences = [c, T]

        outputs, updates = theano.scan(self.sampling_step,
                                       outputs_info=states,
                                       non_sequences=non_sequences,
                                       sequences=[TT.arange(n_steps, dtype="int64")],
                                       n_steps=n_steps,
                                       name="{}_sampler_scan".format(self.prefix))
        return (outputs[0], outputs[1]), updates
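    # Note: build_sampler above returns the sampled word indices and the
    # per-step log probabilities produced by the scan over sampling_step,
    # together with the scan updates, which should be passed on to
    # theano.function when compiling the sampler.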
    def build_next_probs_predictor(self, c, step_num, y, init_states):
        return self.build_decoder(c, y, mode=Decoder.BEAM_SEARCH,
                                  given_init_states=init_states, step_num=step_num)

    def build_next_states_computer(self, c, step_num, y, init_states):
        return self.build_decoder(c, y, mode=Decoder.SAMPLING,
                                  given_init_states=init_states, step_num=step_num)[2:]
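
# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original file). It assumes a
# configuration dict `state` with the keys referenced above ('decoder_stack',
# 'dim', 'c_dim', 'search', 'bigram', 'deep_out', ...); how `state` and the
# encoder representation are produced is not shown here.
#
#   import numpy
#   import theano
#   import theano.tensor as TT
#
#   rng = numpy.random.RandomState(1234)
#   decoder = Decoder(state, rng, prefix='dec')
#   decoder.create_layers()
#
#   c = TT.matrix('c')                 # encoder states, (seq_len, c_dim)
#   n_samples = TT.lscalar('n_samples')
#   n_steps = TT.lscalar('n_steps')
#   T = TT.scalar('T')                 # sampling temperature
#   (sample, log_prob), updates = decoder.build_sampler(n_samples, n_steps, T, c)
#   sample_fn = theano.function([c, n_samples, n_steps, T],
#                               [sample, log_prob], updates=updates)
# ---------------------------------------------------------------------------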