Skip to content

Instantly share code, notes, and snippets.

@rgbkrk
Created January 14, 2014 17:00
Show Gist options
  • Save rgbkrk/8421724 to your computer and use it in GitHub Desktop.
Save rgbkrk/8421724 to your computer and use it in GitHub Desktop.
Just a simple in-memory Markov model.
import random
import bisect
import pickle
from collections import Counter
from IPython.lib.pretty import pprint
class Sentinel(object):
    """A named marker token that never collides with a real word.

    Equality and hashing are based on the sentinel's name, so a sentinel that
    has been through a pickle round-trip (as happens with
    ``MarkovChainModel.store`` / ``load``) still matches the module-level
    instance of the same name when used as a dict key. A plain string never
    compares equal to a sentinel, even if its text matches the name.
    """

    def __init__(self, name):
        self.__name = name

    def __str__(self):
        return self.__name

    def __repr__(self):
        return "<Sentinel({})>".format(self.__name)

    def __eq__(self, other):
        if not isinstance(other, Sentinel):
            # Defer to the other operand / identity; keeps "__start__" != START.
            return NotImplemented
        return self.__name == other.__name

    def __ne__(self, other):
        # Python 2 does not derive != from __eq__, so spell it out.
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result

    def __hash__(self):
        # Must agree with __eq__: equal names => equal hashes.
        return hash(self.__name)
# Markers that tokenize() wraps around every token sequence: START precedes the
# first word and END follows the last one.
START, END = Sentinel("__start__"), Sentinel("__end__")
def tokenize(text, stripper=lambda x: x):
    """Split *text* on whitespace and wrap the tokens in START/END sentinels.

    Args:
        text: The text to tokenize.
        stripper: Callable applied to each whitespace-separated token
            (identity by default).

    Returns:
        A new list ``[START, token_1, ..., token_n, END]``; for empty or
        all-whitespace text this is just ``[START, END]``.
    """
    # Build the list explicitly: on Python 3 ``map`` returns an iterator,
    # which has no ``insert``/``append`` methods.
    return [START] + [stripper(token) for token in text.split()] + [END]
def make_transitions(tokens):
    """Count how often each token is immediately followed by each other token.

    Args:
        tokens: A sliceable sequence of hashable tokens.

    Returns:
        A dict mapping every token that has a successor to a Counter of the
        tokens seen directly after it. The last token only appears as a
        successor, never as a key.
    """
    counts = {}
    for current, following in zip(tokens[:-1], tokens[1:]):
        if current not in counts:
            counts[current] = Counter()
        counts[current][following] += 1
    return counts
class MarkovChainModel(object):
    """A first-order Markov chain over whitespace-separated tokens, held in memory.

    ``self.model`` maps each token (or the START sentinel) to a Counter of the
    tokens observed to follow it; those counts are the transition weights.
    """

    def __init__(self):
        self.model = {}

    def learn(self, text):
        """Add the token transitions observed in *text* to the model."""
        transitions = make_transitions(tokenize(text))
        for key, counter in transitions.items():
            if key in self.model:
                self.model[key].update(counter)
            else:
                # make_transitions builds fresh Counters, so storing it is safe.
                self.model[key] = counter

    def one_round(self, prev=START):
        """Randomly choose the token that follows *prev*.

        The choice is weighted by how often each successor was observed after
        *prev* during ``learn``.

        Returns:
            The chosen successor (a learned token or the END sentinel), or
            None when *prev* has no recorded successors -- e.g. on an untrained
            model or for a token never seen during training.
        """
        transitions = self.model.get(prev)
        if not transitions:
            # Dead end: ``play`` treats None as the end of the walk.
            return None

        choices = []
        cumulative_weights = []
        total = 0
        for token, count in transitions.items():
            if count <= 0:
                # Guard against non-positive counts (possible only if the model
                # was edited, e.g. via Counter.subtract); they get no mass.
                continue
            total += count
            choices.append(token)
            cumulative_weights.append(total)
        if not choices:
            return None

        # Throw a dart at the line of cumulative weights; bisect finds the
        # segment it landed in. dart_toss < total, so the index is in range.
        dart_toss = random.random() * total
        choice = bisect.bisect(cumulative_weights, dart_toss)
        return choices[choice]

    def play(self, start=START, max_length=200):
        """Generate text by walking the chain from *start*.

        The walk stops at the END sentinel, at a token with no recorded
        successors, or once *max_length* tokens have been produced.

        Returns:
            The generated tokens joined by single spaces (an empty string if
            nothing was generated).
        """
        playback = []
        prev = start
        while len(playback) < max_length:
            prev = self.one_round(prev)
            if prev is None or prev == END:
                break
            playback.append(prev)
        return " ".join(playback)

    def store(self, filename):
        """Pickle the learned transition counts to *filename*."""
        with open(filename, 'wb') as stored_model:
            pickle.dump(self.model, stored_model)

    def load(self, filename):
        """Replace the model with transition counts unpickled from *filename*.

        Security: ``pickle.load`` can execute arbitrary code; only load files
        you created or otherwise trust.
        """
        with open(filename, 'rb') as stored_model:
            self.model = pickle.load(stored_model)
# Script entry point: build a model and print one generated sequence.
if __name__ == "__main__":
    # NOTE(review): this model has not learned anything, so play() has no
    # transitions to follow from START; call mcm.learn(...) or mcm.load(...)
    # first to get real output.
    mcm = MarkovChainModel()
    print(mcm.play())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment