from dataclasses import astuple, dataclass
from typing import Callable, List, NewType, Set, Tuple

import numpy as np

np.random.seed(1337)

# Dummy typings; feel free to switch to PyTorch.
Word = str
Phrase = NewType('Phrase', Tuple[Word, ...])  # A phrase is a tuple of words.
Score = float       # or th.float, etc.
State = np.ndarray  # or th.Tensor, etc.

@dataclass(frozen=True)
class Hypothesis:
    """A node in the beam search beam."""
    story: List[Phrase]      # Phrases placed so far, in order.
    candidates: Set[Phrase]  # Phrases still to be placed.
    score: Score             # Accumulated score of `story`.
    state: State             # Model state after consuming `story`.

class Model:
    """Interface for a left-to-right scoring model."""

    def initial_state(self) -> State:
        raise NotImplementedError

    def score(self, next_word: Word, state: State) -> Score:
        raise NotImplementedError

    def transition(self, next_word: Word, state: State) -> State:
        raise NotImplementedError

    def score_sequence(self, phrases: List[Phrase]) -> Score:
        raise NotImplementedError

class DummyModel(Model):
    def initial_state(self) -> State:
        return np.random.randn(4)

    def score(self, next_word: Word, state: State) -> Score:
        return np.random.random()

    def transition(self, next_word: Word, state: State) -> State:
        return state

    def score_sequence(self, phrases: List[Phrase]) -> Score:
        return np.random.random()

def dummy_heuristic(remainder_to_process: Set[Phrase]) -> Score:
    """
    Just return 0, to show the API.

    By comparison, Schmaltz et al. use 'a very simple unigram future cost
    estimate, g(R) = Σ_{i∈R} Σ_{w∈x_i} log p(w).'
    """
    return Score(0.0)
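
# A minimal sketch of that unigram future-cost estimate, assuming you can
# supply per-word log-probabilities. The `unigram_log_prob` lookup is
# hypothetical; plug in your own frequency table or language model.
def make_unigram_heuristic(
        unigram_log_prob: Callable[[Word], float]
) -> Callable[[Set[Phrase]], Score]:
    def g(remainder: Set[Phrase]) -> Score:
        # g(R): sum over remaining phrases, then over their words, of log p(w).
        return Score(sum(unigram_log_prob(word)
                         for phrase in remainder
                         for word in phrase))
    return g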

def beam_search(
    phrases: Set[Phrase],  # Noun phrases, or phrases containing one token.
    K: int,  # Beam size.
    g: Callable[[Set[Phrase]], Score],  # Future-cost heuristic.
    model: Model,
) -> List[List[Hypothesis]]:
    # Beams are indexed by number of tokens placed, and a phrase may span
    # several tokens, so the number of beams is the total token count.
    M = sum(len(phrase) for phrase in phrases)
    beams: List[List[Hypothesis]] = [[] for _ in range(M + 1)]
    beams[0] = [Hypothesis([], phrases, 0.0, model.initial_state())]
    for m in range(M):  # For each number of tokens placed so far:
        for hypothesis in beams[m]:  # For each hypothesis at this point:
            story, candidates, score, state = astuple(hypothesis)
            for phrase in candidates:
                # Score the phrase word by word, threading the state through.
                new_score, new_state = score, state
                for word in phrase:
                    new_score += model.score(word, new_state)
                    new_state = model.transition(word, new_state)
                j = m + len(phrase)
                new_hypothesis = Hypothesis(story + [phrase],
                                            candidates - {phrase},
                                            new_score, new_state)
                beams[j].append(new_hypothesis)
                # Keep only the top K hypotheses in the target beam;
                # higher score is better, so sort in descending order.
                ordered = sorted(beams[j],
                                 key=lambda hyp: (model.score_sequence(hyp.story)
                                                  + g(hyp.candidates)),
                                 reverse=True)
                beams[j] = ordered[:K]
    return beams
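
# Convenience helper, not part of the algorithm above: re-rank the finished
# hypotheses (those that placed every token) and return the best ordering.
# Assumes the final beam is non-empty, which holds whenever K >= 1.
def extract_best(beams: List[List[Hypothesis]], model: Model) -> List[Phrase]:
    finished = beams[-1]
    best = max(finished, key=lambda hyp: model.score_sequence(hyp.story))
    return best.story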

if __name__ == '__main__':
    from pprint import pprint  # Pretty-print
    import random
    random.seed(1337)

    model = DummyModel()
    raw_tokens = "Papa ate the caviar with a spoon .".split()
    random.shuffle(raw_tokens)
    as_words = [Word(w) for w in raw_tokens]
    as_phrases = [Phrase((w,)) for w in as_words]
    phrases = set(as_phrases)

    beams = beam_search(phrases, 3, dummy_heuristic, model)
    pprint(beams)
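    # Also print the single best ordering, via the extract_best helper above.
    best_story = extract_best(beams, model)
    print(' '.join(word for phrase in best_story for word in phrase))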