Easy-to-understand implementation of the beam search algorithm used in the decoder of seq2seq models.

"""NumPy implementation of Beam Search. Can be used for decoding in Seq2Seq | |
models or transformer. | |
See https://chao-ji.github.io/jekyll/update/2019/01/24/Beam_Search.html | |
for an in-depth disucssion. | |
""" | |
import numpy as np

NEG_INF = -1e9
SOS_ID = 0
EOS_ID = 1

CUR_INDEX = "CUR_INDEX"
ACTIVE_SEQ = "ACTIVE_SEQ"
ACTIVE_LOG_PROBS = "ACTIVE_LOG_PROBS"
ACTIVE_CACHE = "ACTIVE_CACHE"
FINISHED_SEQ = "FINISHED_SEQ"
FINISHED_SCORES = "FINISHED_SCORES"
FINISHED_FLAGS = "FINISHED_FLAGS"

class BeamSearch(object):
  """Beam Search Decoder.

  This implementation of beam search adopts an aggressive strategy -- we
  always maintain `beam_width` active threads of search (i.e. sequences that
  have not yet reached EOS_ID), even though some active searches may
  eventually turn into finished ones. This way the maximum number of active
  candidate sequences is considered in each decoding step, because some of
  them may end up with higher scores than previously finished searches (i.e.
  those that have reached EOS_ID).

  The loop invariants maintained over the search iterations are as follows:
    * CUR_INDEX: the current index of the iteration.
    * ACTIVE_SEQ: top-scoring active sequences.
    * ACTIVE_LOG_PROBS: log-probs of ACTIVE_SEQ.
    * ACTIVE_CACHE: dict storing the cache values used during the ongoing
        searches for active sequences.
    * FINISHED_SEQ: top-scoring finished sequences.
    * FINISHED_SCORES: scores (log-probs / length_norm) of FINISHED_SEQ.
    * FINISHED_FLAGS: values indicating whether entries in FINISHED_SEQ and
        FINISHED_SCORES are real finished seqs or just placeholders.
  """

  def __init__(self,
               decoding_fn,
               vocab_size,
               batch_size,
               beam_width,
               alpha,
               max_decode_length,
               eos_id):
    """Constructor.

    Args:
      decoding_fn: a callable, which is the interface to the Transformer model.
        The input arguments are:
          ids: tensor of shape [batch_size*beam_width, 1].
          cache: nested dictionary of tensors [batch_size*beam_width, ...].
        The function returns a tuple of logits and the updated cache:
          logits: a tensor of shape [batch_size*beam_width, vocab_size].
          updated_cache: nested dictionary with the same structure as the
            input cache.
      vocab_size: int scalar, the size of the vocabulary, used for topk
        computation.
      batch_size: int scalar, the inference batch size.
      beam_width: int scalar, number of beams for beam search.
      alpha: float scalar, defining the strength of length normalization.
      max_decode_length: int scalar, the maximum number of steps to decode
        a sequence.
      eos_id: int scalar, ID of the end-of-sentence token.
    """
    self._decoding_fn = decoding_fn
    self._vocab_size = vocab_size
    self._batch_size = batch_size
    self._beam_width = beam_width
    self._alpha = alpha
    self._max_decode_length = max_decode_length
    self._eos_id = eos_id
    self._doubled_beam_width = 2 * self._beam_width
    self._length_normalization = lambda length: np.power(
        (5. + float(length)) / 6., self._alpha)
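    # Illustrative example (assumed value alpha = 0.6): a sequence of length
    # 10 is normalized by ((5. + 10.) / 6.) ** 0.6 ≈ 1.73, so dividing the
    # (negative) log-prob by this factor keeps longer sequences from being
    # over-penalized.
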
  def search(self, initial_ids, initial_cache):
    """Searches for sequences with greatest log-probs by keeping track of the
    `beam_width` most promising candidates (i.e. beams).

    Args:
      initial_ids: int tensor of shape [batch_size], populated with initial ids
        (i.e. SOS_ID).
      initial_cache: dict of entries
        'encoder_outputs': tensor of shape [batch_size, src_seq_len,
          hidden_size],
        'padding_mask': tensor of shape [batch_size, 1, 1, src_seq_len],
        and entries with keys 'layer_0',...,'layer_[decoder_num_layers - 1]'
        where the value associated with key 'layer_*' is a dict with entries
          'k': tensor of shape [batch_size, 0, num_heads, size_per_head],
          'v': tensor of shape [batch_size, 0, num_heads, size_per_head],
          'tgt_tgt_attention': tensor of shape [batch_size, num_heads, 0, 0],
          'tgt_src_attention': tensor of shape [batch_size, num_heads, 0,
            src_seq_len].

    Returns:
      finished_seqs: int tensor of shape [batch_size, beam_width,
        decode_seq_len], the finished decoded sequences over all beams.
      finished_scores: float tensor of shape [batch_size, beam_width], the
        scores of finished decoded sequences over all beams.
    """
    state = self._create_initial_state(initial_ids, initial_cache)

    # NumPy equivalent of the original TensorFlow loop:
    # finished_state = tf.while_loop(
    #     self._continue_search, self._search_step, loop_vars=[state],
    #     parallel_iterations=1, back_prop=False)
    while self._continue_search(state):
      state = self._search_step(state)[0]
    finished_state = state

    active_seqs = finished_state[ACTIVE_SEQ]
    active_log_probs = finished_state[ACTIVE_LOG_PROBS]
    finished_seqs = finished_state[FINISHED_SEQ]
    finished_scores = finished_state[FINISHED_SCORES]
    finished_flags = finished_state[FINISHED_FLAGS]

    # flag is True if any beam in a thread of search in a batch is finished
    # [batch_size]
    finished_cond = np.any(finished_flags, axis=1)
    # if none of the beams end with finished seqs, we return the remaining
    # active seqs.
    # [batch_size, beam_width, decode_seq_len]
    finished_seqs = np.where(finished_cond[:, np.newaxis, np.newaxis],
                             finished_seqs, active_seqs)
    # [batch_size, beam_width]
    finished_scores = np.where(finished_cond[:, np.newaxis],
                               finished_scores, active_log_probs)
    return finished_seqs, finished_scores

  def _create_initial_state(self, initial_ids, initial_cache):
    """Creates the initial loop invariant tensors. This function expands the
    dimensions and tiles the tensors to match the beam width, so that each
    beam has its own state (active and finished seqs, scores, and caches).

    Args:
      initial_ids: see `initial_ids` in `search`.
      initial_cache: see `initial_cache` in `search`.

    Returns:
      state: a dict with the following entries
        'CUR_INDEX': int scalar tensor, initialized to 0.
        'ACTIVE_SEQ': tensor of shape [batch_size, beam_width, 1].
        'ACTIVE_LOG_PROBS': tensor of shape [batch_size, beam_width].
        'ACTIVE_CACHE': a dict of the same structure as input `initial_cache`,
          except that each tensor is expanded and tiled to shape
          [batch_size, beam_width, ...].
        'FINISHED_SEQ': tensor of shape [batch_size, beam_width, 1].
        'FINISHED_SCORES': tensor of shape [batch_size, beam_width].
        'FINISHED_FLAGS': tensor of shape [batch_size, beam_width].
    """
    cur_index = np.array(0)

    active_seq = _tile_beam_width(initial_ids, self._beam_width)
    active_seq = np.expand_dims(active_seq, axis=2)

    # set the log-probs of all beams to -inf except the first beam, which is
    # set to zero, so that we are effectively using only the first beam in the
    # first decoding step
    # active_log_probs: [batch_size, beam_width]
    active_log_probs = np.tile(np.array(
        [[0.] + [-float("inf")] * (self._beam_width - 1)], dtype='float32'),
        [self._batch_size, 1])

    # expand and tile tensors in `active_cache` to `beam_width`
    active_cache = map_structure(
        lambda tensor: _tile_beam_width(tensor, self._beam_width),
        initial_cache)

    # initialize `finished_seq` and `finished_scores` with placeholder values,
    # and `finished_flags` with False values (i.e. no seq is finished yet).
    finished_seq = np.zeros_like(active_seq, dtype='int32')
    finished_scores = np.zeros_like(active_log_probs, dtype='float32')
    finished_flags = np.zeros_like(active_log_probs, dtype='bool')

    state = {CUR_INDEX: cur_index,
             ACTIVE_SEQ: active_seq,
             ACTIVE_LOG_PROBS: active_log_probs,
             ACTIVE_CACHE: active_cache,
             FINISHED_SEQ: finished_seq,
             FINISHED_SCORES: finished_scores,
             FINISHED_FLAGS: finished_flags}
    return state

  def _continue_search(self, state):
    """Determines whether to keep searching or terminate.

    We terminate the search if either of the following is True:
      1. `cur_index` >= `max_decode_length`.
      2. For all concurrent searches in a batch, the worst score of finished
         seqs over all beams > the best possible score of active seqs over all
         beams -- the remaining active candidate seqs can never outscore the
         current finished seqs (because the scores of active seqs can only get
         lower as their length grows).

    Args:
      state: a dict holding the loop invariant tensors over the decoding
        iterations. See `_create_initial_state` for details.

    Returns:
      a bool scalar, whether to continue the search (True) or not (False).
    """
    i = state[CUR_INDEX]
    # active_log_probs: [batch_size, beam_width]
    # finished_scores: [batch_size, beam_width]
    # finished_flags: [batch_size, beam_width]
    active_log_probs = state[ACTIVE_LOG_PROBS]
    finished_scores = state[FINISHED_SCORES]
    finished_flags = state[FINISHED_FLAGS]

    # active_log_probs are always negative, so the best scores of active seqs
    # are achieved when the length normalization is maximal
    # best_active_scores: [batch_size]
    max_length_norm = self._length_normalization(self._max_decode_length)
    best_active_scores = active_log_probs[:, 0] / max_length_norm
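    # Illustrative example (assumed values): with max_decode_length = 50 and
    # alpha = 0.6, an active log-prob of -8.0 can reach a score of at best
    # -8.0 / ((5. + 50.) / 6.) ** 0.6 ≈ -2.12; if the worst finished score in
    # every batch already exceeds this bound, the search can stop early.
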
    # if there are no finished seqs in a batch, set the worst finished score to
    # negative infinity for that batch
    # finished_batch_flags: [batch_size], True if any beam is finished
    # worst_finished_scores: [batch_size]
    finished_batch_flags = np.any(finished_flags, axis=1)
    worst_finished_scores = np.min(finished_scores, axis=1)
    worst_finished_scores = np.where(
        finished_batch_flags, worst_finished_scores, NEG_INF)

    worst_finished_better_than_best_active = np.all(
        worst_finished_scores > best_active_scores)

    return np.logical_and(
        i < self._max_decode_length,
        np.logical_not(worst_finished_better_than_best_active))

  def _search_step(self, state):
    """Performs a single search step.

    Args:
      state: a dict holding the loop invariant tensors over the decoding
        iterations. See `_create_initial_state` for details.

    Returns:
      a length-1 list holding a dict of the same structure as the input
        `state` with updated tensors.
    """
    new_seq, new_log_probs, new_cache = self._grow_active_seq(state)
    active_state = self._get_new_active_state(new_seq, new_log_probs, new_cache)
    finished_state = self._get_new_finished_state(state, new_seq, new_log_probs)

    new_state = {CUR_INDEX: state[CUR_INDEX] + 1}
    new_state.update(active_state)
    new_state.update(finished_state)
    return [new_state]

  def _grow_active_seq(self, state):
    """Grows the search tree of the active sequences by one level, and gathers
    the top-scoring `2 * beam_width` candidates.

    Note: we may have UP TO `beam_width` finished candidates (i.e. ending with
    EOS_ID) among all `vocab_size * beam_width` candidates, so collecting the
    top-scoring `2 * beam_width` candidates ensures that there are at least
    `beam_width` candidates that are still active (i.e. not ending with
    EOS_ID).

    Args:
      state: a dict holding the loop invariant tensors over the decoding
        iterations. See `_create_initial_state` for details.

    Returns:
      topk_seq: int tensor of shape [batch_size, doubled_beam_width,
        cur_index + 2], the token ids of the extended top-scoring
        `doubled_beam_width` sequences.
      topk_log_probs: float tensor of shape [batch_size, doubled_beam_width],
        log-probs of the extended top-scoring `doubled_beam_width` sequences.
      new_cache: dict of entries
        'encoder_outputs': tensor of shape [batch_size, doubled_beam_width,
          src_seq_len, hidden_size],
        'padding_mask': tensor of shape [batch_size, doubled_beam_width, 1, 1,
          src_seq_len],
        and entries with keys 'layer_0',...,'layer_[decoder_num_layers - 1]'
        where the value associated with key 'layer_*' is a dict with entries
          'k': tensor of shape [batch_size, doubled_beam_width, cur_index + 1,
            num_heads, size_per_head],
          'v': tensor of shape [batch_size, doubled_beam_width, cur_index + 1,
            num_heads, size_per_head],
          'tgt_tgt_attention': tensor of shape [batch_size, doubled_beam_width,
            num_heads, cur_index + 1, cur_index + 1],
          'tgt_src_attention': tensor of shape [batch_size, doubled_beam_width,
            num_heads, cur_index + 1, src_seq_len].
    """
    i = state[CUR_INDEX]

    # active_seq: [batch_size, beam_width, cur_index + 1]
    # active_log_probs: [batch_size, beam_width]
    # active_cache[encoder_outputs]: [batch_size, beam_width, src_seq_len,
    #   hidden_size]
    # active_cache[padding_mask]: [batch_size, beam_width, 1, 1, src_seq_len]
    # active_cache[layer_L][k or v]: [batch_size, beam_width, cur_index,
    #   num_heads, size_per_head]
    # active_cache[layer_L][tgt_tgt_attention]: [batch_size, beam_width,
    #   num_heads, cur_index, cur_index]
    # active_cache[layer_L][tgt_src_attention]: [batch_size, beam_width,
    #   num_heads, cur_index, src_seq_len]
    active_seq = state[ACTIVE_SEQ]
    active_log_probs = state[ACTIVE_LOG_PROBS]
    active_cache = state[ACTIVE_CACHE]

    # flattening
    # for `active_seq` and `active_cache`, do reshaping
    # [batch_size, beam_width, ...] ==> [batch_size * beam_width, ...]
    flat_active_seq = _flatten_beam_dim(active_seq)
    flat_cache = map_structure(_flatten_beam_dim, active_cache)

    # flat_logits: [batch_size * beam_width, vocab_size]
    # the `cur_index` dimension of the `k`, `v`, `tgt_tgt_attention`,
    # `tgt_src_attention` tensors in `flat_cache` is incremented
    flat_logits, flat_cache = self._decoding_fn(
        flat_active_seq[:, -1:], flat_cache)

    # SOS should be excluded from the space of valid output tokens, so we push
    # the logits of SOS_ID to -inf so that SOS will never appear in the decoded
    # sequence
    sos_mask = np.array(
        [1] + [0] * (self._vocab_size - 1), dtype='float32') * NEG_INF
    flat_logits += sos_mask

    # unflattening
    # logits: [batch_size, beam_width, vocab_size]
    # tensors in `new_cache` now have shape [batch_size, beam_width, ...]
    logits = _unflatten_beam_dim(
        flat_logits, self._batch_size, self._beam_width)
    new_cache = map_structure(
        lambda t: _unflatten_beam_dim(t, self._batch_size, self._beam_width),
        flat_cache)

    # convert logits to log-probs with a numerically stable log-softmax
    # (NumPy equivalent of
    #  `logits - tf.reduce_logsumexp(logits, axis=2, keepdims=True)`)
    max_logits = np.max(logits, axis=2, keepdims=True)
    candidate_log_probs = logits - max_logits - np.log(
        np.sum(np.exp(logits - max_logits), axis=2, keepdims=True))

    # log_probs: [batch_size, beam_width, vocab_size]
    log_probs = candidate_log_probs + np.expand_dims(active_log_probs, axis=2)

    flat_log_probs = np.reshape(log_probs,
                                [-1, self._beam_width * self._vocab_size])

    # topk_log_probs, topk_indices: [batch_size, doubled_beam_width]
    # (NumPy equivalent of
    #  `tf.nn.top_k(flat_log_probs, k=self._doubled_beam_width)`)
    topk_log_probs, topk_indices = get_top_k(
        flat_log_probs, k=self._doubled_beam_width)

    # get the beam indices of the top `doubled_beam_width` candidates
    topk_beam_indices = topk_indices // self._vocab_size
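    # Illustrative example (assumed values): with vocab_size = 6, a flattened
    # index of 15 refers to beam 15 // 6 = 2, extended with token id 15 % 6 = 3.
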
    # topk_seq: [batch_size, doubled_beam_width, cur_index + 1]
    # tensors in `new_cache` now have shape [batch_size, doubled_beam_width,...]
    topk_seq, new_cache = _gather_beams(
        [active_seq, new_cache], topk_beam_indices)

    # append the top `doubled_beam_width` ids (`topk_ids`) to the growing
    # active seqs (`topk_seq`)
    # topk_ids: [batch_size, doubled_beam_width, 1]
    topk_ids = np.expand_dims(topk_indices % self._vocab_size, axis=2)
    topk_seq = np.concatenate([topk_seq, topk_ids], axis=2)
    return topk_seq, topk_log_probs, new_cache

  def _get_new_active_state(self, new_seq, new_log_probs, new_cache):
    """Gathers the top `beam_width` active sequences from the larger pool of
    `2 * beam_width` candidates.

    Args:
      new_seq: same as `topk_seq` in `_grow_active_seq`.
      new_log_probs: same as `topk_log_probs` in `_grow_active_seq`.
      new_cache: same as `new_cache` in `_grow_active_seq`.

    Returns:
      a dict with the following entries:
        'ACTIVE_SEQ': tensor of the same shape as input `new_seq`, except the
          beam dimension changes to `beam_width` from `2 * beam_width`.
        'ACTIVE_LOG_PROBS': tensor of the same shape as input `new_log_probs`,
          except the beam dimension changes to `beam_width` from
          `2 * beam_width`.
        'ACTIVE_CACHE': nested structure of tensors, where each tensor has the
          same shape as its counterpart in input `new_cache`, except the beam
          dimension changes to `beam_width` from `2 * beam_width`.
    """
    # [batch_size, doubled_beam_width]
    new_active_flags = np.logical_not(new_seq[:, :, -1] == self._eos_id)
    top_active_seq, top_active_log_probs, top_active_cache = _gather_topk(
        [new_seq, new_log_probs, new_cache],
        new_log_probs, new_active_flags, self._beam_width)
    return {ACTIVE_SEQ: top_active_seq,
            ACTIVE_LOG_PROBS: top_active_log_probs,
            ACTIVE_CACHE: top_active_cache}

  def _get_new_finished_state(self, state, new_seq, new_log_probs):
    """Gets newly finished seqs (if any), combines them with previously
    finished seqs, and gathers the top-scoring `beam_width` of them.

    Args:
      state: a dict holding the loop invariant tensors over the decoding
        iterations. See `_create_initial_state` for details.
      new_seq: same as `topk_seq` in `_grow_active_seq`.
      new_log_probs: same as `topk_log_probs` in `_grow_active_seq`.

    Returns:
      a dict with the following entries:
        'FINISHED_SEQ': tensor of shape [batch_size, beam_width,
          cur_index + 2].
        'FINISHED_SCORES': tensor of shape [batch_size, beam_width].
        'FINISHED_FLAGS': tensor of shape [batch_size, beam_width].
    """
    i = state[CUR_INDEX]
    # finished_seq: [batch_size, beam_width, cur_index + 1]
    # finished_scores: [batch_size, beam_width]
    # finished_flags: [batch_size, beam_width]
    finished_seq = state[FINISHED_SEQ]
    finished_scores = state[FINISHED_SCORES]
    finished_flags = state[FINISHED_FLAGS]

    # zero-pad the previously finished seqs to shape
    # [batch_size, beam_width, cur_index + 2]
    finished_seq = np.pad(finished_seq, [[0, 0], [0, 0], [0, 1]],
                          mode='constant')

    # convert log-probs to scores by length normalization
    new_scores = new_log_probs / self._length_normalization(i + 1)

    # flag the newly finished seqs (if any)
    # [batch_size, doubled_beam_width]
    new_finished_flags = new_seq[:, :, -1] == self._eos_id

    # combine previously finished seqs w/ those newly finished (if any)
    # finished_seq: [batch_size, beam_width * 3, cur_index + 2]
    # finished_scores: [batch_size, beam_width * 3]
    # finished_flags: [batch_size, beam_width * 3]
    finished_seq = np.hstack([finished_seq, new_seq])
    finished_scores = np.hstack([finished_scores, new_scores])
    finished_flags = np.hstack([finished_flags, new_finished_flags])

    top_finished_seq, top_finished_scores, top_finished_flags = _gather_topk(
        [finished_seq, finished_scores, finished_flags],
        finished_scores, finished_flags, self._beam_width)
    return {FINISHED_SEQ: top_finished_seq,
            FINISHED_SCORES: top_finished_scores,
            FINISHED_FLAGS: top_finished_flags}

def _tile_beam_width(tensor, beam_width):
  """Given a tensor of shape [batch_size, ...], inserts a new axis at
  position 1 and tiles the tensor `beam_width` times along that axis.

  Args:
    tensor: tensor of shape [batch_size, ...].
    beam_width: int scalar, beam width.

  Returns:
    tiled_tensor: tensor of shape [batch_size, beam_width, ...].
  """
  tensor = np.expand_dims(tensor, axis=1)
  tile_dims = [1] * tensor.ndim
  tile_dims[1] = beam_width
  tiled_tensor = np.tile(tensor, tile_dims)
  return tiled_tensor
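# Illustrative example for `_tile_beam_width` (assumed shapes): tiling a
# [2, 5, 8] tensor with beam_width = 3 gives each beam its own copy, e.g.
#   _tile_beam_width(np.zeros((2, 5, 8)), 3).shape == (2, 3, 5, 8)
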
def _flatten_beam_dim(tensor):
  """Collapses the batch and beam dimensions into a single dimension.

  Args:
    tensor: tensor of shape [batch_size, beam_width, ...].

  Returns:
    tensor of shape [batch_size * beam_width, ...].
  """
  shape = list(tensor.shape)
  shape[0] *= shape[1]
  shape.pop(1)
  return np.reshape(tensor, shape)

def _unflatten_beam_dim(tensor, batch_size, beam_width):
  """Un-collapses the first dimension back into batch and beam dimensions.

  Args:
    tensor: tensor of shape [batch_size * beam_width, ...].
    batch_size: int scalar, batch size.
    beam_width: int scalar, beam width.

  Returns:
    tensor of shape [batch_size, beam_width, ...].
  """
  shape = list(tensor.shape)
  new_shape = [batch_size, beam_width] + shape[1:]
  return np.reshape(tensor, new_shape)
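# Illustrative round trip for the two helpers above (assumed shapes): with
# batch_size = 2 and beam_width = 3,
#   x = np.zeros((2, 3, 7))
#   _unflatten_beam_dim(_flatten_beam_dim(x), 2, 3).shape == (2, 3, 7)
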
def _gather_beams(nested, beam_indices):
  """Gathers beams from a nested structure of tensors according to beam
  indices.

  Args:
    nested: a dict, list, tuple or a tensor, where elements are recursively
      dicts, lists, tuples or tensors. All tensors have shape [batch_size,
      beam_width, ...].
    beam_indices: int tensor of shape [batch_size, new_beam_width], holding the
      indices of beams (not necessarily unique) to be gathered for each batch.

  Returns:
    an object of the same structure as `nested`, where each tensor has shape
      [batch_size, new_beam_width, ...].
  """
  batch_size, new_beam_width = beam_indices.shape
  # For example, given batch_size = 3,
  # batch_indices = [[0, 0, ..., 0], [1, 1, ..., 1], [2, 2, ..., 2]],
  # where each sublist has length `new_beam_width`
  batch_indices = np.tile(np.arange(batch_size)[:, np.newaxis],
                          [1, new_beam_width])
  indices = np.stack([batch_indices, beam_indices], axis=2)
  return map_structure(
      lambda state: gather_nd(state, indices), nested)
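# Illustrative example for `_gather_beams` (assumed values): with
#   seq = np.array([[[1, 2], [3, 4], [5, 6]]])   # [batch=1, beam_width=3, 2]
#   beam_indices = np.array([[2, 0]])            # [batch=1, new_beam_width=2]
# _gather_beams(seq, beam_indices) returns [[[5, 6], [1, 2]]].
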
def _gather_topk(nested, scores, flags, k):
  """Gathers the top-k scoring valid beams (those whose flag is True).

  For example, given
    scores = [-0.32, -0.59, -0.11, -0.05, -0.96, -0.87]
    flags = [True, False, False, True, False, True]
    k = 4
  the entries with False flags are ranked below every valid entry, so the
  top-4 selection (in order) is the valid entries with scores
  -0.05, -0.32, -0.87, followed by one invalid entry as padding, and the
  gathered flags are [True, True, True, False].

  Note: if the num of valid seqs across all beams for each batch is less than
  `k`, the result is padded with invalid seqs.

  Args:
    nested: a dict, list, tuple or a tensor, where elements are recursively
      dicts, lists, tuples or tensors. All tensors have shape [batch_size,
      beam_width, ...].
    scores: float tensor of shape [batch_size, beam_width], the scores of each
      sequence for a particular batch and beam.
    flags: bool tensor of shape [batch_size, beam_width], indicates the
      validity of each sequence (valid if True).
    k: int scalar, the num of top scoring sequences (<= `beam_width`).

  Returns:
    an object of the same structure as `nested`, where each tensor has shape
      [batch_size, k, ...].
  """
  # push the scores of invalid seqs to NEG_INF, so they will be placed after
  # the valid seqs in `indices`. Note: build a new array instead of an
  # in-place `+=`, which would mutate the caller's `scores` tensor (e.g.
  # `new_log_probs`, which is reused by `_get_new_finished_state`).
  scores = scores + np.logical_not(flags).astype('float32') * NEG_INF
  _, indices = get_top_k(scores, k)
  return _gather_beams(nested, indices)

def map_structure(fn, nested):
  """Recursively applies a function to the elements of a structure recursively
  composed of dicts, lists or tuples.

  Args:
    fn: a callable, the function to be applied.
    nested: a dict, list, tuple or a tensor, where elements are recursively
      dicts, lists, tuples or tensors.

  Returns:
    an object of the same structure as `nested` after applying the function.
  """
  if isinstance(nested, dict):
    d = {}
    for k, v in nested.items():
      d[k] = map_structure(fn, v)
    return d
  elif isinstance(nested, (tuple, list)):
    l = []
    for v in nested:
      l.append(map_structure(fn, v))
    return tuple(l) if isinstance(nested, tuple) else l
  else:
    return fn(nested)
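# Illustrative example for `map_structure` (assumed values): the output
# mirrors the input structure, e.g.
#   map_structure(lambda x: x * 2, {'a': np.array([1]), 'b': [np.array([2])]})
#   returns {'a': np.array([2]), 'b': [np.array([4])]}
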
def get_top_k(inputs, k=1):
  """Finds the values and indices of the `k` largest entries along the last
  dimension.

  Args:
    inputs: numeric tensor of shape [..., d], where d >= `k`.
    k: int scalar, num of top elements to look for along the last dimension.

  Returns:
    values: tensor of shape [..., k], the top-k values in descending order.
    indices: int tensor of shape [..., k], the positions of the top-k values
      within the last dimension of `inputs`.
  """
  shape = inputs.shape
  size = np.prod(shape[:-1])
  inputs = inputs.reshape([size, -1])
  indices = np.argsort(-inputs)[:, :k]
  values = np.array([inputs[i][indices[i]]
                     for i in np.arange(size)])
  values = values.reshape(shape[:-1] + (-1,))
  indices = indices.reshape(values.shape)
  return values, indices
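# Illustrative example for `get_top_k` (assumed values):
#   get_top_k(np.array([[3., 1., 2.]]), k=2)
#   returns values [[3., 2.]] and indices [[0, 2]].
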
def gather_nd(params, indices):
  """NumPy stand-in for `tf.gather_nd`: gathers slices of `params` indexed by
  the innermost dimension of `indices`.

  Args:
    params: tensor from which to gather values.
    indices: int tensor of shape [..., index_depth]; each innermost vector
      indexes into the leading dimensions of `params`.

  Returns:
    tensor of shape indices.shape[:-1] + params.shape[index_depth:].
  """
  shape = indices.shape
  indices = np.reshape(indices, [-1, shape[-1]])
  result = []
  for row in indices:
    # walk down the leading dimensions of `params` one index at a time
    a = params
    for i in row:
      a = a[i]
    result.append(a)
  result = np.stack(result, axis=0)
  new_shape = shape[:-1] + params.shape[shape[-1]:]
  result = np.reshape(result, new_shape)
  return result
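

if __name__ == '__main__':
  # Minimal usage sketch (illustrative, not part of the original gist): it
  # assumes a toy `dummy_decoding_fn` that returns random logits and passes an
  # empty cache through unchanged; a real Transformer decoder would also
  # update its per-layer key/value cache here.
  np.random.seed(0)
  batch_size, beam_width, vocab_size = 2, 4, 10

  def dummy_decoding_fn(ids, cache):
    # ids: [batch_size * beam_width, 1]; returns logits over the vocabulary
    # and the (unchanged) cache.
    logits = np.random.randn(ids.shape[0], vocab_size).astype('float32')
    return logits, cache

  beam_search = BeamSearch(decoding_fn=dummy_decoding_fn,
                           vocab_size=vocab_size,
                           batch_size=batch_size,
                           beam_width=beam_width,
                           alpha=0.6,
                           max_decode_length=8,
                           eos_id=EOS_ID)
  initial_ids = np.full([batch_size], SOS_ID, dtype='int32')
  # an empty cache suffices for the toy decoder; a real model would pass
  # encoder outputs, the padding mask, and per-layer 'k'/'v' tensors
  finished_seqs, finished_scores = beam_search.search(initial_ids, {})
  print('finished_seqs shape:', finished_seqs.shape)      # (2, 4, length)
  print('finished_scores shape:', finished_scores.shape)  # (2, 4)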