Created
January 14, 2014 17:00
-
-
Save rgbkrk/8421724 to your computer and use it in GitHub Desktop.
Just a simple in memory markov model.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import random | |
import bisect | |
import pickle | |
from collections import Counter | |
from IPython.lib.pretty import pprint | |
class Sentinel(object):
    """A named marker token that never equals any ordinary (string) token.

    Sentinels compare and hash by their name rather than by identity.
    Unpickling creates fresh Sentinel objects. With the default
    identity-based ``__eq__``/``__hash__``, a pickled-and-reloaded model's
    sentinel keys would never match the module-level constants, so lookups
    and ``!=`` checks against them would silently fail.
    """

    def __init__(self, name):
        self.__name = name

    def __eq__(self, other):
        if not isinstance(other, Sentinel):
            # Let Python fall back to its default (identity) comparison, so
            # a sentinel never equals a plain string such as "__start__".
            return NotImplemented
        return self.__name == other.__name

    def __ne__(self, other):
        # Needed on Python 2, which does not derive != from __eq__.
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result

    def __hash__(self):
        # Must agree with __eq__: equal sentinels hash equally.
        return hash((Sentinel.__name__, self.__name))

    def __str__(self):
        return self.__name

    def __repr__(self):
        return "<Sentinel({})>".format(self.__name)
# Markers that bracket every tokenized text: tokenize() puts START before the
# first word and END after the last, so the chain learns which words begin
# and end a text.
START, END = Sentinel("__start__"), Sentinel("__end__")
def tokenize(text, stripper=lambda x: x):
    """Split *text* on whitespace and bracket the words with sentinels.

    :param text: the input string.
    :param stripper: callable applied to each whitespace-separated word
        (identity by default).
    :returns: a new list ``[START, stripper(w1), ..., stripper(wn), END]``.
        For empty or all-whitespace text this is just ``[START, END]``.
    """
    # Build the list directly instead of using map(): on Python 3 map()
    # returns a lazy iterator, which has no insert()/append() methods.
    return [START] + [stripper(word) for word in text.split()] + [END]
def make_transitions(tokens):
    """Count how often each token is immediately followed by each other token.

    :param tokens: a sliceable sequence of hashable tokens.
    :returns: a dict mapping every token that has a successor to a
        ``Counter`` of its successors. A sequence of length 0 or 1
        yields ``{}``.
    """
    successors = {}
    for current, following in zip(tokens[:-1], tokens[1:]):
        successors.setdefault(current, Counter())[following] += 1
    return successors
class MarkovChainModel(object):
    """A first-order Markov chain over whitespace-separated tokens, held in memory.

    ``self.model`` maps each token (including the ``START`` sentinel) to a
    ``Counter`` of the tokens observed immediately after it (including
    ``END``).
    """

    def __init__(self):
        self.model = {}

    def learn(self, text):
        """Add the word-to-word transitions found in *text* to the model.

        Counts accumulate across calls, so learning the same text twice
        doubles its weight.
        """
        tokens = tokenize(text)
        transitions = make_transitions(tokens)
        for key in transitions:
            if key in self.model:
                self.model[key].update(transitions[key])
            else:
                # make_transitions builds fresh Counters, so storing this one
                # directly does not alias anything the caller holds.
                self.model[key] = transitions[key]

    def one_round(self, prev=START):
        """Pick a successor of *prev* at random, weighted by observed counts.

        :returns: the chosen next token, or ``None`` when *prev* has no
            recorded successors (an unseen token or an untrained model).
        """
        transitions = self.model.get(prev)
        if not transitions:
            # Dead end: play() treats None as "stop". The original code
            # referenced an undefined ``playback`` here (NameError).
            return None
        # list() so the pairs can be indexed below; on Python 3
        # dict.items() returns a non-indexable view.
        transition_tuples = list(transitions.items())
        total = 0
        cumulative_weights = []
        for _token, count in transition_tuples:
            total += count
            cumulative_weights.append(total)
        # Throw a dart at the line of cumulative weights. bisect (right)
        # maps a point in [cumulative_weights[i-1], cumulative_weights[i])
        # to index i.
        dart_toss = random.random() * total
        choice = bisect.bisect(cumulative_weights, dart_toss)
        return transition_tuples[choice][0]

    def play(self, start=START, max_length=200):
        """Generate text by walking the chain from *start*.

        Generation stops at ``END``, at a token with no recorded successors,
        or after *max_length* tokens.

        :returns: the generated tokens joined by single spaces, or ``""``
            when nothing was generated. *start* itself is not included.
        """
        prev = start
        playback = []
        while prev is not None and prev != END and len(playback) < max_length:
            prev = self.one_round(prev)
            # Never record the dead-end marker (None) or END; join() would
            # raise TypeError on them.
            if prev is not None and prev != END:
                playback.append(prev)
        # Always return the text, including when max_length or a dead end
        # ended the walk.
        return " ".join(playback)

    def store(self, filename):
        """Pickle ``self.model`` to *filename*, overwriting the file."""
        with open(filename, 'wb') as stored_model:
            pickle.dump(self.model, stored_model)

    def load(self, filename):
        """Replace ``self.model`` with the model pickled in *filename*.

        WARNING: unpickling can execute arbitrary code. Only load files
        from a trusted source.
        """
        with open(filename, 'rb') as stored_model:
            self.model = pickle.load(stored_model)
if __name__ == "__main__":
    # Demo: train on every file named on the command line, then generate.
    # With no arguments the model stays empty, so play() has nothing to
    # generate, just as before.
    import sys

    mcm = MarkovChainModel()
    for path in sys.argv[1:]:
        with open(path) as corpus:
            mcm.learn(corpus.read())
    print(mcm.play())
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment