Skip to content

Instantly share code, notes, and snippets.

@erikerlandson
Created August 26, 2019 02:40
Show Gist options
  • Save erikerlandson/f71f7e1c90497d33385cf96396c29196 to your computer and use it in GitHub Desktop.
Save erikerlandson/f71f7e1c90497d33385cf96396c29196 to your computer and use it in GitHub Desktop.
pruning and compiling markovify models
class CompiledMarkovify(object):
    """Sentence sampler built from a markovify model's transition table.

    Every state's ``{word: frequency}`` dict is converted once, at
    construction time, into a word list plus a numpy array of cumulative
    frequencies, so that drawing the next word is a single binary search.
    """

    def __init__(self, model):
        """Compile the transition table of ``model``.

        Args:
            model: object exposing ``chain.model`` (a mapping from state
                tuple to ``{word: frequency}``) and ``state_size``.
        """
        def compile_next(next_dict):
            # Keys and values of a dict iterate in the same order, so
            # cumulative[i] is the running total up to and including words[i].
            words = list(next_dict.keys())
            cumulative = np.array(list(itertools.accumulate(next_dict.values())))
            return (words, cumulative)

        chain_dict = model.chain.model
        self.sxf = {
            state: compile_next(next_dict)
            for (state, next_dict) in chain_dict.items()
        }
        self.state_size = model.state_size
        # Sentinel tokens marking the start and end of a sentence.
        self.BEGIN = '___BEGIN__'
        self.END = '___END__'

    def xf(self, state):
        """Draw one successor word of ``state``, weighted by frequency.

        Raises:
            KeyError: if ``state`` is not in the compiled table.
        """
        words, cff = self.sxf[state]
        # cff[-1] is the total frequency mass for this state.
        r = random.random() * cff[-1]
        return words[cff.searchsorted(r)]

    def emit(self, init = None, max_words = 100):
        """Walk the chain and return the visited ``(state, word)`` pairs.

        Args:
            init: starting state tuple; when falsy, the walk starts from
                ``(BEGIN,) * state_size``.
            max_words: upper bound on the number of pairs returned.

        Returns:
            A list of ``(state, word)`` tuples. The END token is never
            included. The list holds at most ``max_words`` entries; the walk
            stops early when END is drawn.
        """
        state = init or (self.BEGIN,) * self.state_size
        seq = []
        # Checking the bound *before* drawing guarantees len(seq) never
        # exceeds max_words (the previous post-append `>` test let one
        # extra pair through).
        while len(seq) < max_words:
            word = self.xf(state)
            if word == self.END:
                break
            seq.append((state, word))
            state = state[1:] + (word,)
        return seq

    def generate(self, init = None, max_tries = 10, max_words = 50, min_words = 0):
        """Try to produce a sentence whose word count is within bounds.

        Args:
            init: optional starting state, forwarded to ``emit``.
            max_tries: number of attempts before giving up.
            max_words: largest acceptable word count.
            min_words: smallest acceptable word count.

        Returns:
            The words joined by single spaces, or ``None`` when no attempt
            produced an acceptable sentence.
        """
        for _ in range(max_tries):
            # Ask for one word more than allowed so that a walk cut off by
            # the limit (n == max_words + 1) is rejected below rather than
            # being mistaken for a sentence that ended on its own.
            seq = self.emit(init, max_words = max_words + 1)
            n = len(seq)
            if n < min_words or n > max_words:
                continue
            return " ".join([word for (_, word) in seq])
        return None
def prune_tail(next_dict, minfreq, tailprob):
    """Return a copy of ``next_dict`` with low-frequency words thinned out.

    A word whose frequency is at least ``minfreq`` is always kept. Any other
    word is kept only when a fresh ``random.random()`` draw falls below
    ``tailprob``. The input dict is not modified.
    """
    survivors = {}
    for word, freq in next_dict.items():
        # The `or` short-circuits, so a random number is drawn only for
        # words below the frequency threshold.
        if freq >= minfreq or random.random() < tailprob:
            survivors[word] = freq
    return survivors
def prune_next(next_dict, chain_dict, state):
    """Drop successor words that would lead to a state missing from the chain.

    Args:
        next_dict: ``{word: frequency}`` successors of ``state``.
        chain_dict: mapping whose keys are the state tuples still present.
        state: the state tuple that ``next_dict`` belongs to.

    Returns:
        A new dict holding the entries of ``next_dict`` whose word is either
        the ``'___END__'`` token or produces a follow-up state
        (``state`` shifted left by one, with the word appended) that is a key
        of ``chain_dict``.
    """
    # state[1:] rather than state[-(n-1):]: for a one-element state the
    # latter becomes state[-0:], i.e. the whole state, which built a key one
    # element too long and made every non-END word look unreachable.
    shifted = tuple(state[1:])
    return {
        word: freq
        for (word, freq) in next_dict.items()
        if word == '___END__' or (shifted + (word,)) in chain_dict
    }
def prune_markovify_tail(model, minfreq = 2, tailprob = 0.5):
    """Build a new markovify text model with rare transitions pruned away.

    First every state's successors are thinned with ``prune_tail``. Then
    ``prune_next`` is applied to every remaining state repeatedly, each pass
    discarding states left without successors, until a pass no longer changes
    the number of states.

    Args:
        model: source model; ``model.chain.model`` and ``model.state_size``
            are read.
        minfreq: forwarded to ``prune_tail``.
        tailprob: forwarded to ``prune_tail``.

    Returns:
        A ``markovify.Text`` wrapping a ``markovify.Chain`` built from the
        pruned transition table.
    """
    def drop_empty(table):
        # A state with no successors left cannot be continued from.
        return {state: nxt for (state, nxt) in table.items() if len(nxt) > 0}

    thinned = {
        state: prune_tail(next_dict, minfreq, tailprob)
        for (state, next_dict) in model.chain.model.items()
    }
    chain_dict = drop_empty(thinned)

    while True:
        size_before = len(chain_dict)
        # Each pass checks successors against the state set from the
        # previous pass; chain_dict is rebound only after the pass is done.
        narrowed = {
            state: prune_next(next_dict, chain_dict, state)
            for (state, next_dict) in chain_dict.items()
        }
        chain_dict = drop_empty(narrowed)
        if len(chain_dict) == size_before:
            break

    pruned_chain = markovify.Chain(None, model.state_size, model=chain_dict)
    return markovify.Text(None, state_size=model.state_size, chain=pruned_chain, parsed_sentences=None)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment