Created
January 15, 2020 18:52
-
-
Save viswanathgs/2ec3e1e82fab9c90748bae0bafa5e439 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class Trie(object):
    """Node of a character trie in which every node carries a probability.

    A node stores a single character, the probability attached to that
    character, and its children keyed by their character. A node whose
    character is ``'$'`` is the end token: it terminates a sequence and
    cannot have children.
    """

    def __init__(self, character, prob):
        self.character = character
        self.probability = prob
        # Maps a child's character to the child node.
        self.children = {}

    def add_child(self, child):
        """Attach ``child`` below this node, reusing an existing match.

        If a child with the same character is already present, that
        existing node is kept and returned instead of ``child``.

        Returns:
            The node stored under ``child.character``.

        Raises:
            ValueError: if this node is the end token, or if a child with
                the same character already exists with a different
                probability.
        """
        # Explicit raises instead of ``assert`` so the checks still run
        # when Python is started with -O.
        if self.character == '$':
            raise ValueError("Cannot add child to end token")
        existing = self.children.setdefault(child.character, child)
        if existing.probability != child.probability:
            raise ValueError(
                "Child %r already exists with probability %r (got %r)"
                % (child.character, existing.probability, child.probability))
        return existing

    def dfs(self, prefix='', prefix_prob=1.0):
        """Enumerate every complete sequence below this node.

        Args:
            prefix: characters accumulated on the way to this node.
            prefix_prob: product of the probabilities on the way here.

        Returns:
            A list of ``(sequence, probability)`` tuples, one per end
            token reachable from this node. Branches that never reach an
            end token contribute nothing.
        """
        new_prefix = prefix + self.character
        new_prob = prefix_prob * self.probability
        if self.character == '$':
            return [(new_prefix, new_prob)]
        results = []
        for child in self.children.values():
            results.extend(child.dfs(new_prefix, new_prob))
        return results

    @classmethod
    def make_trie(cls, raw_list):
        """Build a trie from an iterable of ``(character, prob)`` sequences.

        Sequences that share a prefix share the corresponding nodes. The
        returned root has the empty string as character and probability
        1.0.
        """
        root = cls('', 1.0)
        for item in raw_list:
            node = root
            for character, prob in item:
                node = node.add_child(cls(character, prob))
        return root
# Incorrect impl: a decoding that reaches the end token is retired to
# ``final_decodings`` immediately and uses up one of the k beam slots for
# good, even when its probability is low. Better, longer hypotheses can be
# pruned as a result. On the sample data in this file with k=2 this class
# returns 'cas$' although 'cat$' is more probable. BeamSearch2 keeps
# finished decodings in competition instead.
class BeamSearch(object):
    """Beam search over a trie that retires finished beams eagerly."""

    # Init takes Trie's root
    def __init__(self, trie_root):
        self.trie_root = trie_root
        # Finished decodings collected by the most recent search.
        self.final_decodings = []

    # Select k decodings with maximum probabilities
    def max_k(self, decodings, k):
        return sorted(decodings, key=lambda x: x['probability'], reverse=True)[:k]

    def select_k_recur(self, decodings, k):
        """Advance every beam by one character and keep the best ``k``.

        Each decoding is a dict with the keys ``'output'`` (characters so
        far), ``'node'`` (current trie node) and ``'probability'``.
        Selected decodings that end in ``'$'`` are appended to
        ``self.final_decodings`` and reduce the beam width used for the
        next step; the rest are expanded recursively.
        """
        candidates = [
            {
                'output': decoding['output'] + child.character,
                'node': child,
                'probability': decoding['probability'] * child.probability,
            }
            for decoding in decodings
            for child in decoding['node'].children.values()
        ]
        remaining_k = k
        live = []
        for decoding in self.max_k(candidates, k):
            # endswith() rather than [-1] so an empty output cannot raise.
            if decoding['output'].endswith('$'):
                self.final_decodings.append(decoding)
                remaining_k -= 1
            else:
                live.append(decoding)
        # Stop when the beam is used up OR nothing is left to expand.
        # Without the second check the recursion never ends once every
        # path is exhausted while fewer than k decodings have finished.
        if remaining_k > 0 and live:
            self.select_k_recur(live, remaining_k)

    def k_step_decoding(self, trie_root, k):
        """Return the most probable finished output found with beam width ``k``.

        The root's own probability is not part of the product; scores
        start from 1.0 at the root.

        Raises:
            IndexError: if no decoding reaches an end token (for example
                an empty trie or ``k < 1``).
        """
        self.final_decodings = []
        # Seed the search with the root itself so the first level goes
        # through the same select/retire step as every other level; an end
        # token directly below the root is then finalized, not dropped.
        seed = {
            'output': trie_root.character,
            'node': trie_root,
            'probability': 1.0,
        }
        self.select_k_recur([seed], k)
        # Select the output with maximum probability among all finished beams
        return self.max_k(self.final_decodings, 1)[0]['output']
class BeamSearch2(object):
    """Beam search over a trie that keeps finished decodings in the beam.

    Decodings that have reached the end token ``'$'`` stay among the
    candidates at every step, so they compete on probability with the
    still-growing decodings instead of reserving a beam slot.
    """

    # Init takes Trie's root
    def __init__(self, trie_root):
        self.trie_root = trie_root

    # Select k decodings with maximum probabilities
    def max_k(self, decodings, k):
        return sorted(decodings, key=lambda x: x['probability'], reverse=True)[:k]

    def select_k_recur(self, decodings, k):
        """Run beam search of width ``k`` starting from ``decodings``.

        Each decoding is a dict with the keys ``'output'`` (characters so
        far), ``'node'`` (current trie node) and ``'probability'``.

        Implemented as a loop rather than self-recursion so the depth of
        the trie is not limited by the interpreter's recursion limit.

        Returns:
            The best ``k`` decodings once no beam can be extended any
            further. Decodings stuck on a childless node that is not an
            end token are discarded.
        """
        beam = decodings
        while True:
            # Keep track of already finished decodings
            candidates = [d for d in beam if d['node'].character == '$']
            # Extend every beam by each child of its current node
            expanded = [
                {
                    'output': decoding['output'] + child.character,
                    'node': child,
                    'probability': decoding['probability'] * child.probability,
                }
                for decoding in beam
                for child in decoding['node'].children.values()
            ]
            candidates.extend(expanded)
            # Select new k candidates from all of the possible decodings above
            beam = self.max_k(candidates, k)
            if not expanded:
                # Nothing new was added, so the beam is final.
                return beam

    def k_step_decoding(self, trie_root, k):
        """Return the most probable output found with beam width ``k``.

        Raises:
            IndexError: if the search ends with no decoding (for example
                an empty trie or ``k < 1``).
        """
        init_decodings = [{
            'output': '',
            'node': trie_root,
            'probability': 1.0,
        }]
        final_decodings = self.select_k_recur(init_decodings, k)
        return self.max_k(final_decodings, 1)[0]['output']
if __name__ == '__main__':
    # Demo: three sequences sharing the prefix 'c', each a list of
    # (character, probability) pairs ending in the end token '$'.
    data = [
        [('c', 1.0), ('$', 0.1)],
        [('c', 1.0), ('a', 1.0), ('t', 0.7), ('$', 1.0)],
        [('c', 1.0), ('a', 1.0), ('s', 0.8), ('$', 0.2)],
    ]
    root = Trie.make_trie(data)
    k = 2

    print("Brute force:")
    # Sort by the probability field using an index-based key. A
    # tuple-unpacking lambda parameter (``lambda (a, b): ...``) is a
    # SyntaxError on Python 3, and the old name ``chr`` shadowed a builtin.
    print(sorted(root.dfs(), key=lambda item: item[1], reverse=True))

    print("BeamSearch1:")
    print(BeamSearch(root).k_step_decoding(root, k))

    print("BeamSearch2:")
    print(BeamSearch2(root).k_step_decoding(root, k))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment