Skip to content

Instantly share code, notes, and snippets.

@zomux
Created July 3, 2015 05:56
Show Gist options
  • Save zomux/8d414a561cc68e0763fd to your computer and use it in GitHub Desktop.
class Decoder(EncoderDecoderBase):
    """RNN decoder half of an encoder-decoder translation model.

    The decoder can build three different computation graphs (see
    ``build_decoder``), selected by the mode constants below.
    """

    # Modes accepted by build_decoder.
    EVALUATION = 0
    SAMPLING = 1
    BEAM_SEARCH = 2

    def __init__(self, state, rng, prefix='dec',
                 skip_init=False, compute_alignment=False):
        """Store configuration; no layers are created until create_layers().

        :param state: configuration mapping (accessed with ``[]`` / ``in``).
        :param rng: random number generator handed to every created layer.
        :param prefix: string prepended to all layer names.
        :param skip_init: when True, weights are initialised to zeros.
        :param compute_alignment: when True, attention weights are returned.
        """
        self.state = state
        self.rng = rng
        self.prefix = prefix
        self.skip_init = skip_init
        self.compute_alignment = compute_alignment
        # Original author's note: the number of input layers and the number
        # of output layers are not distinguished here, which is a known
        # limitation of this implementation.
        self.num_levels = self.state['decoder_stack']
        # Default multiplier for gate dimensions when the config omits it.
        if 'dim_mult' not in self.state:
            self.state['dim_mult'] = 1.
def create_layers(self):
    """Create all elements of the Decoder's computation graph."""
    # Keyword arguments shared by every layer built by the helpers below.
    self.default_kwargs = {
        'init_fn': "sample_zeros" if self.skip_init else self.state['weight_init_fn'],
        'weight_noise': self.state['weight_noise'],
        'scale': self.state['weight_scale'],
    }
    self._create_embedding_layers()
    self._create_transition_layers()
    self._create_inter_level_layers()
    self._create_initialization_layers()
    self._create_decoding_layers()
    self._create_readout_layers()
    if self.state['search']:
        # Attention-based search is only supported for a single-level stack.
        assert self.num_levels == 1
        self.transitions[0].set_decoding_layers(
            self.decode_inputers[0],
            self.decode_reseters[0],
            self.decode_updaters[0])
def _create_initialization_layers(self):
    """Create the layers that compute the initial hidden state per level.

    By default every level starts from a ZeroLayer; when
    ``state['bias_code']`` is set, each level instead gets a trainable
    MultiLayer that maps the encoder representation to an initial state.
    """
    logger.debug("_create_initialization_layers")
    self.initializers = [ZeroLayer()] * self.num_levels
    if not self.state['bias_code']:
        return
    for level in range(self.num_levels):
        self.initializers[level] = MultiLayer(
            self.rng,
            n_in=self.state['dim'],
            n_hids=[self.state['dim'] * self.state['hid_mult']],
            activation=[prefix_lookup(self.state, 'dec', 'activ')],
            bias_scale=[self.state['bias']],
            name='{}_initializer_{}'.format(self.prefix, level),
            **self.default_kwargs)
def _create_decoding_layers(self):
    """Create the projections of the encoder context into the decoder.

    Each level gets up to three contributions from the context: input,
    update-gate and reset-gate. Levels without a real layer keep a
    callable that contributes nothing (returns 0).
    """
    logger.debug("_create_decoding_layers")

    # Placeholder "no contribution" callable; summing its result is a no-op.
    def zero(_):
        return 0

    self.decode_inputers = [zero] * self.num_levels
    self.decode_reseters = [zero] * self.num_levels
    self.decode_updaters = [zero] * self.num_levels
    self.back_decode_inputers = [zero] * self.num_levels
    self.back_decode_reseters = [zero] * self.num_levels
    self.back_decode_updaters = [zero] * self.num_levels

    # Shared kwargs: project c_dim -> dim * dim_mult with identity activation.
    decoding_kwargs = dict(
        self.default_kwargs,
        n_in=self.state['c_dim'],
        n_hids=self.state['dim'] * self.state['dim_mult'],
        activation=['lambda x:x'],
        learn_bias=False)

    if not self.state['decoding_inputs']:
        return
    for level in range(self.num_levels):
        # Input contributions
        self.decode_inputers[level] = MultiLayer(
            self.rng,
            name='{}_dec_inputter_{}'.format(self.prefix, level),
            **decoding_kwargs)
        # Update gate contributions
        if prefix_lookup(self.state, 'dec', 'rec_gating'):
            self.decode_updaters[level] = MultiLayer(
                self.rng,
                name='{}_dec_updater_{}'.format(self.prefix, level),
                **decoding_kwargs)
        # Reset gate contributions
        if prefix_lookup(self.state, 'dec', 'rec_reseting'):
            self.decode_reseters[level] = MultiLayer(
                self.rng,
                name='{}_dec_reseter_{}'.format(self.prefix, level),
                **decoding_kwargs)
def _create_readout_layers(self):
    """Create the layers that map decoder state to next-word scores.

    Builds:
      * ``repr_readout``      -- projection of the context representation,
      * ``hidden_readouts``   -- one projection per hidden level,
      * ``prev_word_readout`` -- bigram contribution of the previous word
                                 (only when ``state['bigram']`` is set),
      * ``output_layer``      -- the softmax, optionally preceded by a
                                 maxout/dropout stack (``state['deep_out']``).
    """
    # The softmax implementation may be overridden by class name in the config.
    softmax_layer = self.state['softmax_layer'] if 'softmax_layer' in self.state \
        else 'SoftmaxLayer'
    logger.debug("_create_readout_layers")

    readout_kwargs = dict(self.default_kwargs)
    readout_kwargs.update(dict(
        n_hids=self.state['dim'],
        activation='lambda x: x',
    ))

    self.repr_readout = MultiLayer(
        self.rng,
        n_in=self.state['c_dim'],
        learn_bias=False,
        name='{}_repr_readout'.format(self.prefix),
        **readout_kwargs)

    # Attention - this is the only readout layer
    # with trainable bias. Should be careful with that.
    self.hidden_readouts = [None] * self.num_levels
    for level in range(self.num_levels):
        self.hidden_readouts[level] = MultiLayer(
            self.rng,
            n_in=self.state['dim'],
            name='{}_hid_readout_{}'.format(self.prefix, level),
            **readout_kwargs)

    self.prev_word_readout = 0
    if self.state['bigram']:
        # BUGFIX: the original formatted the layer name with the loop
        # variable `level` *leaked* from the loop above. Spell out the
        # equivalent value (num_levels - 1) so the name no longer depends
        # on loop-variable leakage; the resulting string is identical.
        self.prev_word_readout = MultiLayer(
            self.rng,
            n_in=self.state['rank_n_approx'],
            n_hids=self.state['dim'],
            activation=['lambda x:x'],
            learn_bias=False,
            name='{}_prev_readout_{}'.format(self.prefix, self.num_levels - 1),
            **self.default_kwargs)

    if self.state['deep_out']:
        # NOTE(review): eval() on config-supplied strings is unsafe for
        # untrusted configs; kept as-is for compatibility with this codebase.
        act_layer = UnaryOp(activation=eval(self.state['unary_activ']))
        drop_layer = DropOp(rng=self.rng, dropout=self.state['dropout'])
        self.output_nonlinearities = [act_layer, drop_layer]
        self.output_layer = eval(softmax_layer)(
            self.rng,
            # BUGFIX: floor division -- the original `/` relied on Python 2
            # integer division; under Python 3 it would pass a float size.
            # For integer dims the result is unchanged.
            self.state['dim'] // self.state['maxout_part'],
            self.state['n_sym_target'],
            sparsity=-1,
            rank_n_approx=self.state['rank_n_approx'],
            name='{}_deep_softmax'.format(self.prefix),
            use_nce=self.state['use_nce'] if 'use_nce' in self.state else False,
            **self.default_kwargs)
    else:
        self.output_nonlinearities = []
        self.output_layer = eval(softmax_layer)(
            self.rng,
            self.state['dim'],
            self.state['n_sym_target'],
            sparsity=-1,
            rank_n_approx=self.state['rank_n_approx'],
            name='dec_softmax',
            sum_over_time=True,
            use_nce=self.state['use_nce'] if 'use_nce' in self.state else False,
            **self.default_kwargs)
def build_decoder(self, c, y,
                  c_mask=None,
                  y_mask=None,
                  step_num=None,
                  mode=EVALUATION,
                  given_init_states=None,
                  T=1):
    """Create the computational graph of the RNN Decoder.

    BUGFIX: a leftover debugging breakpoint (``import pdb;pdb.set_trace()``)
    that halted every call after the transition loop has been removed.

    :param c:
        representations produced by an encoder.
        (n_samples, dim) matrix if mode == sampling or
        (max_seq_len, batch_size, dim) matrix if mode == evaluation
    :param c_mask:
        if mode == evaluation a 0/1 matrix identifying valid positions in c
    :param y:
        if mode == evaluation
            target sequences, matrix of word indices of shape
            (max_seq_len, batch_size), where each column is a sequence
        if mode != evaluation
            a vector of previous words of shape (n_samples,)
    :param y_mask:
        if mode == evaluation a 0/1 matrix determining lengths
        of the target sequences, must be None otherwise
    :param mode:
        chooses one of three modes: evaluation, sampling and beam_search
    :param given_init_states:
        for sampling and beam_search. A list of hidden states
        matrices for each layer, each matrix is (n_samples, dim)
    :param T:
        sampling temperature
    """
    # Check parameter consistency
    if mode == Decoder.EVALUATION:
        assert not given_init_states
    else:
        assert not y_mask
        assert given_init_states
    if mode == Decoder.BEAM_SEARCH:
        assert T == 1

    # For log-likelihood evaluation the representation
    # is replicated for convenience. In case backward RNN is used
    # it is not done.
    # Shape if mode == evaluation: (max_seq_len, batch_size, dim)
    # Shape if mode != evaluation: (n_samples, dim)
    if not self.state['search']:
        if mode == Decoder.EVALUATION:
            c = PadLayer(y.shape[0])(c)
        else:
            assert step_num
            # Clamp the step index so we never read past the end of c.
            c_pos = TT.minimum(step_num, c.shape[0] - 1)

    # Low rank embeddings of all the input words.
    # Shape if mode == evaluation: (n_words, rank_n_approx),
    # Shape if mode != evaluation: (n_samples, rank_n_approx)
    approx_embeddings = self.approx_embedder(y)

    # Low rank embeddings are projected to contribute
    # to input, reset and update signals.
    # All the shapes if mode == evaluation: (n_words, dim)
    # where n_words = max_seq_len * batch_size.
    # All the shapes if mode != evaluation: (n_samples, dim)
    input_signals = []
    reset_signals = []
    update_signals = []
    for level in range(self.num_levels):
        # Contributions directly from input words.
        input_signals.append(self.input_embedders[level](approx_embeddings))
        update_signals.append(self.update_embedders[level](approx_embeddings))
        reset_signals.append(self.reset_embedders[level](approx_embeddings))

        # Contributions from the encoded source sentence.
        if not self.state['search']:
            current_c = c if mode == Decoder.EVALUATION else c[c_pos]
            input_signals[level] += self.decode_inputers[level](current_c)
            update_signals[level] += self.decode_updaters[level](current_c)
            reset_signals[level] += self.decode_reseters[level](current_c)

    # Hidden layers' initial states.
    # Shapes if mode == evaluation: (batch_size, dim)
    # Shape if mode != evaluation: (n_samples, dim)
    init_states = given_init_states
    if not init_states:
        init_states = []
        for level in range(self.num_levels):
            # The last `dim` units of the first context vector seed the state.
            init_c = c[0, :, -self.state['dim']:]
            init_states.append(self.initializers[level](init_c))

    # Hidden layers' states.
    # Shapes if mode == evaluation: (seq_len, batch_size, dim)
    # Shapes if mode != evaluation: (n_samples, dim)
    hidden_layers = []
    contexts = []
    # Default value for alignment must be smth computable
    alignment = TT.zeros((1,))
    for level in range(self.num_levels):
        if level > 0:
            # Stacked levels also receive the level below as input.
            input_signals[level] += self.inputers[level](hidden_layers[level - 1])
            update_signals[level] += self.updaters[level](hidden_layers[level - 1])
            reset_signals[level] += self.reseters[level](hidden_layers[level - 1])
        add_kwargs = (dict(state_before=init_states[level])
                      if mode != Decoder.EVALUATION
                      else dict(init_state=init_states[level],
                                batch_size=y.shape[1] if y.ndim == 2 else 1,
                                nsteps=y.shape[0]))
        if self.state['search']:
            add_kwargs['c'] = c
            add_kwargs['c_mask'] = c_mask
            add_kwargs['return_alignment'] = self.compute_alignment
            if mode != Decoder.EVALUATION:
                add_kwargs['step_num'] = step_num
        result = self.transitions[level](
            input_signals[level],
            mask=y_mask,
            gater_below=none_if_zero(update_signals[level]),
            reseter_below=none_if_zero(reset_signals[level]),
            one_step=mode != Decoder.EVALUATION,
            use_noise=mode == Decoder.EVALUATION,
            **add_kwargs)
        if self.state['search']:
            if self.compute_alignment:
                # This implicitly wraps each element of result.out with a
                # Layer to keep track of the parameters.
                # It is equivalent to h=result[0], ctx=result[1] etc.
                h, ctx, alignment = result
                if mode == Decoder.EVALUATION:
                    alignment = alignment.out
            else:
                # This implicitly wraps each element of result.out with a
                # Layer to keep track of the parameters.
                # It is equivalent to h=result[0], ctx=result[1]
                h, ctx = result
        else:
            h = result
            if mode == Decoder.EVALUATION:
                ctx = c
            else:
                ctx = ReplicateLayer(given_init_states[0].shape[0])(c[c_pos]).out
        hidden_layers.append(h)
        contexts.append(ctx)

    # In hidden_layers we do not have the initial state, but we need it.
    # Instead of it we have the last one, which we do not need.
    # So what we do is discard the last one and prepend the initial one.
    if mode == Decoder.EVALUATION:
        for level in range(self.num_levels):
            hidden_layers[level].out = TT.concatenate([
                TT.shape_padleft(init_states[level].out),
                hidden_layers[level].out])[:-1]

    # The output representation to be fed in softmax.
    # Shape if mode == evaluation: (n_words, dim_r)
    # Shape if mode != evaluation: (n_samples, dim_r)
    # ... where dim_r depends on the 'deep_out' option.
    readout = self.repr_readout(contexts[0])
    for level in range(self.num_levels):
        if mode != Decoder.EVALUATION:
            read_from = init_states[level]
        else:
            read_from = hidden_layers[level]
        # Strip gate units: only the first `dim` units feed the readout.
        read_from_var = read_from if type(read_from) == theano.tensor.TensorVariable else read_from.out
        if read_from_var.ndim == 3:
            read_from_var = read_from_var[:, :, :self.state['dim']]
        else:
            read_from_var = read_from_var[:, :self.state['dim']]
        if type(read_from) != theano.tensor.TensorVariable:
            read_from.out = read_from_var
        else:
            read_from = read_from_var
        readout += self.hidden_readouts[level](read_from)
    if self.state['bigram']:
        if mode != Decoder.EVALUATION:
            check_first_word = (y > 0
                                if self.state['check_first_word']
                                else TT.ones((y.shape[0]), dtype="float32"))
            # padright is necessary as we want to multiply each row
            # with a certain scalar.
            readout += TT.shape_padright(check_first_word) * self.prev_word_readout(approx_embeddings).out
        else:
            if y.ndim == 1:
                readout += Shift()(self.prev_word_readout(approx_embeddings).reshape(
                    (y.shape[0], 1, self.state['dim'])))
            else:
                # This place needs explanation. When prev_word_readout is
                # applied to approx_embeddings the resulting shape is
                # (n_batches * sequence_length, repr_dimensionality). We first
                # transform it into a 3D tensor to shift forward in time.
                # Then reshape it back.
                readout += Shift()(self.prev_word_readout(approx_embeddings).reshape(
                    (y.shape[0], y.shape[1], self.state['dim']))).reshape(
                        readout.out.shape)
    for fun in self.output_nonlinearities:
        readout = fun(readout)

    if mode == Decoder.SAMPLING:
        sample = self.output_layer.get_sample(
            state_below=readout,
            temp=T)
        # Current SoftmaxLayer.get_cost is stupid,
        # that's why we have to reshape a lot.
        self.output_layer.get_cost(
            state_below=readout.out,
            temp=T,
            target=sample)
        log_prob = self.output_layer.cost_per_sample
        return [sample] + [log_prob] + hidden_layers
    elif mode == Decoder.BEAM_SEARCH:
        return self.output_layer(
            state_below=readout.out,
            temp=T).out
    elif mode == Decoder.EVALUATION:
        return (self.output_layer.train(
                    state_below=readout,
                    target=y,
                    mask=y_mask,
                    reg=None),
                alignment)
    else:
        raise Exception("Unknown mode for build_decoder")
def sampling_step(self, *args):
    """Implements one step of sampling.

    Args are necessary since the number (and the order) of arguments
    can vary: scan passes its "sequences", previous "outputs" and
    "non_sequences" positionally.
    """
    args = iter(args)

    # Arguments that correspond to scan's "sequences" parameter:
    step_num = next(args)
    assert step_num.ndim == 0

    # Arguments that correspond to scan's "outputs" parameter:
    prev_word = next(args)
    assert prev_word.ndim == 1
    # Skip the previous word's log probability.
    assert next(args).ndim == 1
    # One hidden-state matrix per decoder level.
    prev_hidden_states = [next(args) for _ in range(self.num_levels)]
    assert prev_hidden_states[0].ndim == 2

    # Arguments that correspond to scan's "non_sequences":
    c = next(args)
    assert c.ndim == 2
    T = next(args)
    assert T.ndim == 0

    shared_kwargs = dict(step_num=step_num, mode=Decoder.SAMPLING,
                         given_init_states=prev_hidden_states, T=T, c=c)
    # First pass: draw the next word and its log probability...
    sample, log_prob = self.build_decoder(y=prev_word, **shared_kwargs)[:2]
    # ...second pass: advance the hidden states conditioned on that word.
    hidden_states = self.build_decoder(y=sample, **shared_kwargs)[2:]
    return [sample, log_prob] + hidden_states
def build_initializers(self, c):
    """Apply each level's initializer to the representation ``c``.

    Returns a list with one initial-state variable (the layer output's
    ``.out`` attribute) per decoder level.
    """
    outputs = []
    for initializer in self.initializers:
        outputs.append(initializer(c).out)
    return outputs
def build_sampler(self, n_samples, n_steps, T, c):
    """Build a scan loop sampling ``n_steps`` words for ``n_samples`` rows.

    :param T: sampling temperature.
    :param c: encoder representation.
    :returns: ((sampled words, their log-probs), scan updates)
    """
    # Seed scan's outputs: dummy previous words and log-probabilities.
    states = [TT.zeros(shape=(n_samples,), dtype='int64'),
              TT.zeros(shape=(n_samples,), dtype='float32')]
    # Initial hidden states come from the last `dim` units of the first
    # context vector, replicated once per sample.
    init_c = c[0, -self.state['dim']:]
    for init in self.initializers:
        states.append(ReplicateLayer(n_samples)(init(init_c).out).out)
    if not self.state['search']:
        # Pad with final states
        c = PadLayer(n_steps)(c).out
    non_sequences = [c, T]
    outputs, updates = theano.scan(
        self.sampling_step,
        outputs_info=states,
        non_sequences=non_sequences,
        sequences=[TT.arange(n_steps, dtype="int64")],
        n_steps=n_steps,
        name="{}_sampler_scan".format(self.prefix))
    # Only the sampled words and their log-probabilities are returned;
    # the per-level hidden states stay internal to the scan.
    return (outputs[0], outputs[1]), updates
def build_next_probs_predictor(self, c, step_num, y, init_states):
    """Graph returning next-word probabilities for beam search.

    Thin wrapper around ``build_decoder`` in BEAM_SEARCH mode.
    """
    return self.build_decoder(
        c, y,
        step_num=step_num,
        given_init_states=init_states,
        mode=Decoder.BEAM_SEARCH)
def build_next_states_computer(self, c, step_num, y, init_states):
    """Graph returning the next hidden states (one per level).

    Thin wrapper around ``build_decoder`` in SAMPLING mode; the sampled
    word and log-probability (the first two outputs) are discarded.
    """
    outputs = self.build_decoder(
        c, y,
        step_num=step_num,
        given_init_states=init_states,
        mode=Decoder.SAMPLING)
    return outputs[2:]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment