Not-so-minimal anymore (check the early commits) beam search example
# Author: Kyle Kastner
# License: BSD 3-Clause
# See core implementations here http://geekyisawesome.blogspot.ca/2016/10/using-beam-search-to-generate-most.html
# Also includes a reduction of the post by Yoav Goldberg to a script
# markov_lm.py
# https://gist.github.com/yoavg/d76121dfde2618422139
# These datasets can be a lot of fun...
#
# https://github.com/frnsys/texts
# python minimal_beamsearch.py 2600_phrases_for_effective_performance_reviews.txt -o 5 -d 0
#
# Download kjv.txt from http://www.ccel.org/ccel/bible/kjv.txt
# python minimal_beamsearch.py kjv.txt -o 5 -d 2 -r 2145
# Snippet:
# Queen ording found Raguel: I kill.
# THROUGH JESUS OF OUR BRETHREN, AND PEACE,
#
# NUN.
import numpy as np
import heapq
from collections import defaultdict, Counter
import collections
import os
import sys
import argparse
import time
try:
    # Python 2
    import cPickle as pickle
except ImportError:
    import pickle
class Beam(object):
    """
    From http://geekyisawesome.blogspot.ca/2016/10/using-beam-search-to-generate-most.html
    Heap entries are (score, complete, prob, prefix) tuples, compared
    lexicographically. If two prefixes have equal scores then a complete
    sentence is preferred over an incomplete one, since
    (0.5, False, whatever_prefix) < (0.5, True, some_other_prefix)
    """
    def __init__(self, beam_width, init_beam=None, use_log=True,
                 stochastic=False, temperature=1.0, random_state=None):
        if init_beam is None:
            self.heap = list()
        else:
            self.heap = init_beam
            heapq.heapify(self.heap)
        self.stochastic = stochastic
        self.random_state = random_state
        self.temperature = temperature
        # use_log currently unused...
        self.use_log = use_log
        self.beam_width = beam_width

    def add(self, score, complete, prob, prefix):
        heapq.heappush(self.heap, (score, complete, prob, prefix))
        while len(self.heap) > self.beam_width:
            if self.stochastic:
                # same whether logspace or no?
                probs = np.array([h[0] for h in self.heap])
                probs = probs / self.temperature
                e_x = np.exp(probs - np.max(probs))
                s_x = e_x / e_x.sum()
                # invert the softmax so that low scoring entries are the
                # most likely to be dropped
                is_x = 1. - s_x
                is_x = is_x / is_x.sum()
                to_remove = self.random_state.multinomial(1, is_x).argmax()
                completed = [n for n, h in enumerate(self.heap) if h[1] == True]
                # Don't remove completed sentences by randomness
                if to_remove not in completed:
                    # there must be a faster way...
                    self.heap.pop(to_remove)
                    heapq.heapify(self.heap)
                else:
                    # fall back to dropping the minimum scoring entry
                    heapq.heappop(self.heap)
            else:
                # remove lowest score from heap
                heapq.heappop(self.heap)

    def __iter__(self):
        return iter(self.heap)
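
# A quick sketch (not part of the original gist) of how the Beam container
# behaves with stochastic=False: push more entries than beam_width and the
# lowest-scoring entry is evicted, since heapq keeps the minimum at heap[0].
# The helper name below is hypothetical and never called by the script.
def _beam_container_demo():
    b = Beam(beam_width=2)
    b.add(0.1, False, 0.1, ["a"])
    b.add(0.5, False, 0.5, ["b"])
    b.add(0.9, False, 0.9, ["c"])
    # only the two best-scoring prefixes survive
    return sorted(b)  # [(0.5, False, 0.5, ['b']), (0.9, False, 0.9, ['c'])]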
def beamsearch(probabilities_function, beam_width=10, clip_len=-1,
               start_token="<START>", end_token="<EOS>", use_log=True,
               renormalize=True, length_score=True,
               diversity_score=True,
               stochastic=False, temperature=1.0,
               random_state=None, eps=1E-9):
    """
    From http://geekyisawesome.blogspot.ca/2017/04/getting-top-n-most-probable-sentences.html
    Returns a generator which yields beamsearched sequences in order of their completion.
    "probabilities_function" returns a list of (next_prob, next_word) pairs given a prefix.
    "beam_width" is the number of prefixes to keep (so that instead of keeping the top 10
        prefixes you can keep the top 100, for example). Widening the beam gets closer to
        the actual most probable sentence but also takes longer to process.
    "clip_len" is the maximum length to tolerate, beyond which the most probable prefix
        is returned as an incomplete sentence. Without a maximum length, a faulty
        probabilities function which does not return a highly probable end token
        will lead to an infinite loop or excessively long garbage sentences.
    "start_token" can be a single string (token), or a sequence of tokens.
    "end_token" is a single string (token), or a sequence of tokens, that signifies
        the end of the sequence.
    "use_log", "renormalize", and "length_score" are all related to the calculation
        of which beams to keep, and should improve results when True.
    "stochastic" uses a different sampling algorithm for reducing/aggregating beams;
        it should result in more diverse and interesting outputs.
    "temperature" is the softmax temperature for the underlying stochastic
        beamsearch - the default of 1.0 is usually fine.
    "random_state" is a np.random.RandomState() object, passed when using the
        stochastic beamsearch in order to control randomness.
    "eps" is the minimum probability for log-space calculations, to avoid numerical issues.
    """
    if stochastic:
        if random_state is None:
            raise ValueError("Must pass np.random.RandomState() object if stochastic=True")
    completed_beams = 0
    prev_beam = Beam(beam_width - completed_beams, None, use_log, stochastic,
                     temperature, random_state)
    try:
        basestring
    except NameError:
        basestring = str
    if isinstance(start_token, collections.Sequence) and not isinstance(start_token, basestring):
        start_token_is_seq = True
    else:
        # make it a list with 1 entry
        start_token = [start_token]
        start_token_is_seq = False
    if isinstance(end_token, collections.Sequence) and not isinstance(end_token, basestring):
        end_token_is_seq = True
    else:
        # make it a list with 1 entry
        end_token = [end_token]
        end_token_is_seq = False
    if use_log:
        prev_beam.add(0., False, 0., start_token)
    else:
        prev_beam.add(1.0, False, 1.0, start_token)
    while True:
        curr_beam = Beam(beam_width - completed_beams, None, use_log, stochastic,
                         temperature, random_state)
        if renormalize:
            sorted_prev_beam = sorted(prev_beam)
            # renormalize by the previous minimum value in the beam
            min_prob = sorted_prev_beam[0][0]
        else:
            if use_log:
                min_prob = 0.
            else:
                min_prob = 1.
        if diversity_score:
            # get prefixes, without the start token
            pre = [r[-1][len(start_token):] for r in prev_beam]
            base = set(pre[0])
            diversity_scores = []
            # score for first entry
            if use_log:
                diversity_scores.append(0.)
            else:
                diversity_scores.append(1.)
            if len(pre) > 1:
                for pre_i in pre[1:]:
                    s = set(pre_i)
                    union = base | s
                    # number of new things + (- number of repetitions)
                    sc = (len(union) - len(base)) - (len(pre_i) - len(s))
                    # update it
                    base = union
                    if use_log:
                        diversity_scores.append(sc)
                    else:
                        diversity_scores.append(np.exp(sc))
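            # Worked example (illustrative, not from the original): with
            # previous-beam prefixes "abc" and "abd", base = {a, b, c}; for
            # "abd" the union adds one new symbol ("d") and "abd" repeats no
            # symbol, so its diversity score is 1 - 0 = 1. Beams that
            # introduce symbols the earlier beams lack get boosted.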
        # Add complete sentences that do not yet have the best probability
        # to the current beam; the rest prepare to add more words to them.
        for ni, (prefix_score, complete, prefix_prob, prefix) in enumerate(prev_beam):
            if complete == True:
                curr_beam.add(prefix_score, True, prefix_prob, prefix)
            else:
                # Get probability of each possible next word for the incomplete prefix
                for (next_prob, next_word) in probabilities_function(prefix):
                    # use eps tolerance to avoid log(0.) issues
                    if next_prob > eps:
                        n = next_prob
                    else:
                        n = eps
                    # score is renormalized prob
                    if use_log:
                        if length_score:
                            score = prefix_prob + np.log(n) + np.log(len(prefix)) - min_prob
                        else:
                            score = prefix_prob + np.log(n) - min_prob
                        if diversity_score:
                            score = score + diversity_scores[ni]
                        prob = prefix_prob + np.log(n)
                    else:
                        if length_score:
                            score = (prefix_prob * n * len(prefix)) / min_prob
                        else:
                            score = (prefix_prob * n) / min_prob
                        if diversity_score:
                            score = score * diversity_scores[ni]
                        prob = prefix_prob * n
                    if end_token_is_seq:
                        # compare the last len(end_token) tokens, including next_word
                        left_cmp = (prefix + [next_word])[-len(end_token):]
                        right_cmp = end_token
                    else:
                        left_cmp = next_word
                        # end_token was wrapped in a 1 entry list above, unwrap it
                        right_cmp = end_token[0]
                    if left_cmp == right_cmp:
                        # If next word is the end token then mark prefix as complete
                        curr_beam.add(score, True, prob, prefix + [next_word])
                    else:
                        curr_beam.add(score, False, prob, prefix + [next_word])
        # Get all prefixes in beam sorted by probability
        sorted_beam = sorted(curr_beam)
        any_removals = False
        while True:
            # Get highest probability prefix - heapq is sorted in ascending order
            (best_score, best_complete, best_prob, best_prefix) = sorted_beam[-1]
            if best_complete == True or len(best_prefix) - 1 == clip_len:
                # If the most probable prefix is a complete sentence, or has a length
                # that exceeds the clip length (ignoring the start token), then yield
                # it without the start token, along with its score and probability
                if start_token_is_seq:
                    skip = len(start_token)
                else:
                    skip = 1
                if end_token_is_seq:
                    stop = None
                else:
                    stop = -1
                yield (best_prefix[skip:stop], best_score, best_prob)
                sorted_beam.pop()
                completed_beams += 1
                any_removals = True
                # If there are no more sentences in the beam then stop checking
                if len(sorted_beam) == 0:
                    break
            else:
                break
        if any_removals == True:
            if len(sorted_beam) == 0:
                break
            else:
                prev_beam = Beam(beam_width - completed_beams, sorted_beam, use_log,
                                 stochastic, temperature, random_state)
        else:
            prev_beam = curr_beam
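
# A minimal usage sketch (illustrative, not part of the original script):
# a toy word-level probabilities_function over a hand-built bigram table,
# decoded with the default argmax beam. All names here are hypothetical.
def _beamsearch_demo():
    table = {"<START>": [(0.6, "the"), (0.4, "a")],
             "the": [(0.7, "cat"), (0.3, "dog")],
             "a": [(0.9, "dog"), (0.1, "cat")],
             "cat": [(1.0, "<EOS>")],
             "dog": [(1.0, "<EOS>")]}

    def pf(prefix):
        # condition only on the last token of the prefix
        return table[prefix[-1]]

    gen = beamsearch(pf, beam_width=3, clip_len=10, diversity_score=False)
    # each yield is a (sequence, score, log_probability) tuple with the
    # start and end tokens stripped, e.g. (["the", "cat"], ...)
    return list(gen)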
# Reduce memory on python 2
if sys.version_info < (3, 0):
    range = xrange


def train_char_lm(fname, order=4, temperature=1.0):
    data = open(fname).read()
    lm = defaultdict(Counter)
    pad = "~" * order
    data = pad + data
    for i in range(len(data) - order):
        history, char = data[i:i + order], data[i + order]
        lm[history][char] += 1

    def normalize(counter):
        # Use a proper softmax with temperature
        t = temperature
        ck = counter.keys()
        cv = counter.values()
        s = float(sum([pi for pi in cv]))
        # normalize counts to 0 - 1 to help numerical issues
        p = [pi / s for pi in cv]
        # apply temperature
        p = [pi / float(t) for pi in p]
        mx = max(p)
        # log sum exp
        s_p = mx + np.log(sum([np.exp(pi - mx) for pi in p]))
        # Calculate softmax in a hopefully more stable way
        # s(xi) = exp ^ (xi / t) / sum exp ^ (xi / t)
        # log s(xi) = log (exp ^ (xi / t) / sum exp ^ (xi / t))
        # log s(xi) = log exp ^ (xi / t) - log sum exp ^ (xi / t)
        # with pi = xi / t
        # with s_p = log sum exp ^ (xi / t)
        # log s(xi) = pi - s_p
        # s(xi) = np.exp(pi - s_p)
        p = [np.exp(pi - s_p) for pi in p]
        return [(pi, ci) for ci, pi in zip(ck, p)]
    outlm = {hist: normalize(chars) for hist, chars in lm.iteritems()}
    return outlm
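
# A small sanity check (hypothetical usage, not part of the original gist):
# train on a throwaway file and inspect one history's distribution.
def _train_demo(tmp_path="/tmp/_markov_demo.txt"):
    with open(tmp_path, "w") as f:
        f.write("abcabd")
    lm = train_char_lm(tmp_path, order=2)
    # "ab" is followed once by "c" and once by "d"; the counts are tied, so
    # the tempered softmax gives both continuations probability 0.5
    return lm["ab"]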
def generate_letter(lm, history, order, stochastic, random_state):
    history = history[-order:]
    dist = lm[history]
    if stochastic:
        # inverse-CDF sampling: draw x ~ U(0, 1), then walk the
        # (probability, char) list subtracting until x crosses zero
        x = random_state.rand()
        for v, c in dist:
            x = x - v
            if x <= 0:
                return c
        # randomize choice if it all failed
        li = list(range(len(dist)))
        random_state.shuffle(li)
        _, c = dist[li[0]]
    else:
        # argmax: always take the most probable next character
        probs = np.array([d[0] for d in dist])
        ii = np.argmax(probs)
        _, c = dist[ii]
    return c
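
# Sketch (not in the original): sampling repeatedly from the padding history
# should roughly recover the model's distribution for that history.
def _letter_demo(lm, order, n_samples=1000):
    rs = np.random.RandomState(0)
    hist = "~" * order
    return Counter(generate_letter(lm, hist, order, True, rs)
                   for _ in range(n_samples))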
def step_text(lm, order, stochastic, random_seed, history=None, end=None,
              beam_width=1, n_letters=1000, verbose=False):
    # beam_width argument is ignored, as are end and verbose
    if history is None or history == "<START>":
        history = "~" * order
    else:
        history = "".join(history).decode("string_escape")
    out = []
    random_state = np.random.RandomState(random_seed)
    for i in range(n_letters):
        c = generate_letter(lm, history, order, stochastic, random_state)
        history = history[-order:] + c
        out.append(c)
    # return a list to match beam_text
    return ["".join(out)]
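
# Hypothetical usage (mirrors the command line path below):
#   lm = train_char_lm("shakespeare_input.txt", order=6)
#   print(step_text(lm, 6, True, 1999, n_letters=200)[0])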
def beam_text(lm, order, stochastic, random_seed, history=None,
              end=None, beam_width=10, n_letters=1000, verbose=False):
    def pf(prefix):
        history = prefix[-order:]
        # lm wants key as a single string
        k = "".join(history).decode("string_escape")
        # sometimes the distribution "dead-ends"...
        try:
            dist = lm[k]
        except KeyError:
            alt_keys = [i for i in lm.keys()
                        if "".join(prefix[-order:-1]) in i
                        and "".join(prefix[-order - 1:-1]) != i]
            # if no alternates, start from a random place
            if len(alt_keys) == 0:
                # choose a key at semi-random
                ak = lm.keys()
                dist = lm[ak[random_seed % len(ak)]]
            else:
                dist = lm[alt_keys[0]]
        return dist

    if history is None or history == "<START>":
        start_token = ["~"] * order
    else:
        start_token = history
    if len(start_token) != order:
        raise ValueError("Start length must match order setting of {}! {} is length {}".format(order, history, len(history)))
    if end is None:
        end_token = "<EOS>"
    else:
        end_token = end
    random_state = np.random.RandomState(random_seed)
    b = beamsearch(pf, beam_width, start_token=start_token,
                   end_token=end_token,
                   clip_len=n_letters,
                   stochastic=stochastic,
                   diversity_score=True,
                   random_state=random_state)
    # it is a generator, but unroll it so function prototypes are consistent
    all_r = []
    for r in b:
        all_r.append((r[0], r[1], r[2]))
    # reorder so final scoring is matched (rather than completion order)
    all_r = sorted(all_r, key=lambda x: x[1])
    returns = []
    for r in all_r:
        s_r = "".join(r[0])
        if verbose:
            s_r = s_r + "\nScore: {}".format(r[1]) + "\nProbability: {}".format(r[2])
        returns.append(s_r)
    # return list of all beams, ordered by score (ascending, so best is last)
    return returns
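
# Hypothetical usage (mirrors the command line path below): returns one string
# per surviving beam, sorted ascending by score, so outs[-1] is the best.
#   lm = train_char_lm("kjv.txt", order=5)
#   outs = beam_text(lm, 5, True, 2145, beam_width=10, n_letters=200)
#   print(outs[-1])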
if __name__ == "__main__":
    default_order = 6
    default_temperature = 1.0
    default_beamwidth = 10
    default_start = "<START>"
    default_end = "<EOS>"
    default_decoder = 0
    default_randomseed = 1999
    default_maxlength = 500
    default_cache = 1
    default_print = 1
    default_verbose = 1
    # TODO: Faster cache
    parser = argparse.ArgumentParser(description="A Markov chain character level language model with beamsearch decoding",
                                     epilog="Simple usage:\n python minimal_beamsearch.py shakespeare_input.txt -o 10\nFull usage:\n python minimal_beamsearch.py shakespeare_input.txt -o 10 -d 0 -s 'HOLOFERNES' -e 'crew?\\n' -r 2177",
                                     formatter_class=argparse.RawTextHelpFormatter)
    parser.add_argument("filepath", help="Path to file to use for language modeling. For an example file, try downloading\nhttp://cs.stanford.edu/people/karpathy/char-rnn/shakespeare_input.txt", default=None)
    parser.add_argument("-o", "--order", help="Markov chain order, higher will make better text but takes longer to process.\nDefault {}".format(default_order), default=default_order)
    parser.add_argument("-t", "--temperature", help="Temperature for Markov chain softmax, higher is more random, lower more static.\nDefault {}".format(default_temperature), default=default_temperature)
    parser.add_argument("-d", "--decoder", help="Decoder for Markov chain, 0 is stochastic beamsearch, 1 is argmax beamsearch, 2 is sampled next-step, 3 is argmax next-step.\nDefault {}".format(default_decoder), default=default_decoder)
    parser.add_argument("-b", "--beamwidth", help="Beamwidth to use for beamsearch.\nDefault {}".format(default_beamwidth), default=default_beamwidth)
    parser.add_argument("-r", "--randomseed", help="Random seed to initialize randomness.\nDefault {}".format(default_randomseed), default=default_randomseed)
    parser.add_argument("-s", "--starttoken", help="Start sequence token. Can be a string such as 'hello\\n', extra padding will be inferred from the data.\nDefault {}".format(default_start), default=default_start)
    parser.add_argument("-e", "--endtoken", help="End sequence token. Can be a string such as 'goodbye\\n'.\nDefault {}".format(default_end), default=default_end)
    parser.add_argument("-m", "--maxlength", help="Max generation length.\nDefault {}".format(default_maxlength), default=default_maxlength)
    parser.add_argument("-c", "--cache", help="Whether to cache models for faster use.\nDefault {}".format(default_cache), default=default_cache)
    parser.add_argument("-a", "--allbeams", help="Print all beams for beamsearch, 0 for top only, 1 for all.\nDefault {}".format(default_print), default=default_print)
    parser.add_argument("-v", "--verbose", help="Print the score and probability for beams.\nDefault {}".format(default_verbose), default=default_verbose)
    args = parser.parse_args()
    if args.filepath is None:
        raise ValueError("No text filepath provided!")
    else:
        fpath = args.filepath
        if not os.path.exists(fpath):
            raise ValueError("Unable to find file at %s" % fpath)
    decoder = int(args.decoder)
    # TODO: gumbel-max in stochastic beam decoder...?
    beam_width = int(args.beamwidth)
    temperature = float(args.temperature)
    random_seed = int(args.randomseed)
    maxlength = int(args.maxlength)
    allbeams = int(args.allbeams)
    verbose = int(args.verbose)
    order = int(args.order)
    if order < 1:
        raise ValueError("Order must be at least 1! Was set to {}".format(order))
    cache = int(args.cache)
    if cache not in [0, 1]:
        raise ValueError("Cache must be either 0 (no save) or 1 (save)! Was set to {}".format(cache))
    start_token = args.starttoken.decode("string_escape")
    if start_token != default_start:
        user_start_token = True
    else:
        user_start_token = False
    end_token = args.endtoken.decode("string_escape")
    if end_token != default_end:
        user_end_token = True
    else:
        user_end_token = False
    if decoder == 0:
        # stochastic beam
        stochastic = True
        decode_fun = beam_text
        type_tag = "Stochastic beam search, beam width {}, Markov order {}, temperature {}, seed {}".format(beam_width, order, temperature, random_seed)
    elif decoder == 1:
        # argmax beam
        stochastic = False
        decode_fun = beam_text
        type_tag = "Argmax beam search, beam width {}, Markov order {}".format(beam_width, order)
    elif decoder == 2:
        # stochastic next-step
        stochastic = True
        decode_fun = step_text
        type_tag = "Stochastic next step, Markov order {}, temperature {}, seed {}".format(order, temperature, random_seed)
    elif decoder == 3:
        # argmax next-step
        stochastic = False
        decode_fun = step_text
        type_tag = "Argmax next step, Markov order {}".format(order)
    else:
        raise ValueError("Decoder must be 0, 1, 2, or 3! Was set to {}".format(decoder))
    # only things that affect the language model are training data, temperature, order
    cached_name = "model_{}_t{}_o{}.pkl".format("".join(fpath.split(".")[:-1]), str(temperature).replace(".", "pt"), order)
    if cache == 1 and os.path.exists(cached_name):
        print("Found cached model at {}, loading...".format(cached_name))
        start_time = time.time()
        with open(cached_name, "rb") as f:
            lm = pickle.load(f)
        """
        # codec troubles :(
        with open(cached_name, "r") as f:
            lm = json.load(f, encoding="latin1")
        """
        stop_time = time.time()
        print("Time to load: {} s".format(stop_time - start_time))
    else:
        start_time = time.time()
        lm = train_char_lm(fpath, order=order,
                           temperature=temperature)
        stop_time = time.time()
        print("Time to train: {} s".format(stop_time - start_time))
        if cache == 1:
            print("Caching model now...")
            with open(cached_name, "wb") as f:
                pickle.dump(lm, f)
            """
            # codec troubles :(
            with open(cached_name, "w") as f:
                json.dump(lm, f, encoding="latin1")
            """
            print("Caching complete!")
    # All this logic to handle/match different start keys
    rs = np.random.RandomState(random_seed)
    if user_start_token:
        if len(start_token) > order:
            start_token = start_token[-order:]
            print("WARNING: specified start token larger than order, truncating to\n{}".format(start_token))
        if len(start_token) <= order:
            matching_keys = [k for k in lm.keys() if k.endswith(start_token)]
            all_keys = [k for k in lm.keys()]
            while True:
                if len(matching_keys) == 0:
                    rs.shuffle(all_keys)
                    print("No matching key for `{}` in language model!".format(start_token))
                    print("Please enter another one (suggestions in backticks)\n`{}`\n`{}`\n`{}`".format(all_keys[0], all_keys[1], all_keys[2]))
                    line = raw_input('Prompt ("Ctrl-C" to quit): ')
                    line = line.strip()
                    if len(line) == 0:
                        continue
                    else:
                        start_token = line
                        matching_keys = [k for k in lm.keys() if k.endswith(start_token)]
                else:
                    break
            if len(start_token) < order:
                # choose a key at random that ends with the requested start
                matching_keys = [k for k in lm.keys() if k.endswith(start_token)]
                rs.shuffle(matching_keys)
                start_token = matching_keys[0]
                print("WARNING: start key shorter than order, set to\n`{}`".format(start_token))
        start_token = list(start_token)
    if user_end_token:
        end_token = list(end_token)
    if allbeams not in [0, 1]:
        raise ValueError("Unknown setting for allbeams {}".format(allbeams))
    if verbose == 0:
        verbose = False
    elif verbose == 1:
        verbose = True
    else:
        raise ValueError("Unknown setting for verbose {}".format(verbose))
    start_time = time.time()
    all_o = decode_fun(lm, order, stochastic, random_seed, history=start_token,
                       end=end_token, beam_width=beam_width,
                       n_letters=maxlength, verbose=verbose)
    stop_time = time.time()
    print(type_tag)
    print("Time to decode: {} s".format(stop_time - start_time))
    print("----------")
    if allbeams == 0:
        # results are sorted ascending by score, so the best beam is last
        all_o = [all_o[-1]]
    for n, oi in enumerate(all_o):
        if len(all_o) > 1:
            if n == 0:
                print("BEAM {} (worst score)".format(n + 1))
            elif n != (len(all_o) - 1):
                print("BEAM {}".format(n + 1))
            else:
                print("BEAM {} (best score)".format(n + 1))
            print("----------")
        if user_start_token:
            print("".join(start_token) + oi)
        else:
            print(oi)
        print("----------")