""" | |
Based on example here: https://gist.github.com/awni/56369a90d03953e370f3964c826ed4b0 | |
""" | |
import numpy as np | |
import math | |
import collections | |
NEG_INF = -float("inf") | |

def logsumexp(x, y):
    """Numerically stable computation of log(exp(x) + exp(y))."""
    max_val = max(x, y)
    if max_val == NEG_INF:
        return NEG_INF
    lsp = math.log(math.exp(x - max_val) + math.exp(y - max_val))
    return max_val + lsp
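
# Quick sanity check (not part of the original gist): summing two
# probabilities of 0.5 in log space should recover log(1.0) = 0.0
# without the underflow a naive log(exp(x) + exp(y)) can hit.
assert abs(logsumexp(math.log(0.5), math.log(0.5))) < 1e-12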

class BeamProb:
    """Log-probabilities of a prefix, split by whether it ends in a blank."""

    __slots__ = ["blank", "label"]

    def __init__(self, blank=NEG_INF, label=NEG_INF):
        self.blank = blank
        self.label = label

    def update_label_prob(self, addendum):
        self.label = logsumexp(self.label, addendum)

    def update_blank_prob(self, addendum):
        self.blank = logsumexp(self.blank, addendum)

    @property
    def total(self):
        return logsumexp(self.blank, self.label)
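
# Small illustration (not part of the original gist): a prefix with
# probability 0.25 of ending in blank and 0.25 of ending in a label
# carries total probability 0.5.
_example = BeamProb(blank=math.log(0.25), label=math.log(0.25))
assert abs(math.exp(_example.total) - 0.5) < 1e-12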

def decode(probs, beam_size=100, blank=0):
    """
    Performs prefix beam search decoding over the given output probabilities.

    Arguments:
        probs: The output probabilities (e.g. post-softmax) for each
            time step. Should be an array of shape (time x output dim).
        beam_size (int): Size of the beam to use during inference.
        blank (int): Index of the CTC blank label.

    Returns the output label sequence and the corresponding negative
    log-likelihood estimated by the decoder.
    """
    T, S = probs.shape
    probs = np.log(probs)
    # Elements in the beam are (prefix, BeamProb(p_blank, p_label)).
    # Initialize the beam with the empty sequence and a probability of 1
    # for ending in blank, 0 for ending in an actual label (in log space).
    beam = [((), BeamProb(0.0, NEG_INF))]

    for t in range(T):
        # A default dictionary to store the next-step candidates.
        next_beam = collections.defaultdict(BeamProb)

        for s in range(S):
            # prob.blank and prob.label are respectively the probabilities
            # of the prefix ending in a blank or in an actual label at
            # this time step.
            for prefix, prob in beam:
                # If we propose a blank, the prefix doesn't change.
                # Only the probability of ending in blank gets updated.
                if s == blank:
                    next_beam[prefix].update_blank_prob(prob.total + probs[t, s])
                    continue

                n_prefix = prefix + (s,)
                if prefix and s != prefix[-1]:
                    # Extend the prefix by the new character s and add it to
                    # the beam. Only the probability of not ending in blank
                    # gets updated.
                    next_beam[n_prefix].update_label_prob(prob.total + probs[t, s])
                else:
                    # We don't include the previous probability of not ending
                    # in blank (prob.label) if s is repeated at the end. The
                    # CTC algorithm merges characters not separated by a blank.
                    next_beam[n_prefix].update_label_prob(prob.blank + probs[t, s])
                    # If s is repeated at the end we also update the unchanged
                    # prefix. This is the merging case.
                    next_beam[prefix].update_label_prob(prob.label + probs[t, s])

                # *NB* this would be a good place to include an LM score,
                # e.g. update_with_lm(next_beam, n_prefix); see the
                # hypothetical sketch after this function.

        # Sort and trim the beam before moving on to the next time step.
        beam = sorted(next_beam.items(), key=lambda x: -x[1].total)[:beam_size]

    labels, beam_prob = beam[0]
    return labels, -beam_prob.total
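
# A minimal, hypothetical sketch of the LM hook suggested by the *NB*
# comment inside decode(). Nothing below is part of the original gist:
# `lm_log_prob` is a uniform stand-in for a real language model (e.g. a
# character n-gram), and the `base_log_prob`/`alpha` parameters are
# assumptions so the LM weight folds into each extension score rather
# than being logsumexp-ed in as extra probability mass.
def lm_log_prob(prefix, s, vocab_size=35):
    # Stand-in LM: every character is equally likely given any prefix.
    return -math.log(vocab_size)


def update_with_lm(next_beam, n_prefix, base_log_prob, alpha=0.3):
    # Replace update_label_prob(prob.total + probs[t, s]) with
    # update_with_lm(next_beam, n_prefix, prob.total + probs[t, s]) to
    # weight each extension by the LM probability of its last character.
    lm = alpha * lm_log_prob(n_prefix[:-1], n_prefix[-1])
    next_beam[n_prefix].update_label_prob(base_log_prob + lm)
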
if __name__ == "__main__":
    np.random.seed(3)
    time = 200
    output_dim = 35

    probs = np.random.rand(time, output_dim)
    probs = probs / np.sum(probs, axis=1, keepdims=True)

    labels, score = decode(probs)
    print("Score {:.3f}, labels length: {}".format(score, len(labels)))