Skip to content

Instantly share code, notes, and snippets.

@amake
Last active January 19, 2023 06:10
Show Gist options
  • Save amake/1a4d51235e1f2fb145fa9021c80e896e to your computer and use it in GitHub Desktop.
Save amake/1a4d51235e1f2fb145fa9021c80e896e to your computer and use it in GitHub Desktop.
Markov text generator
"""Markov Chain-driven text generator. Suitable for use with e.g.
tinyshakespeare:
https://github.com/karpathy/char-rnn/blob/master/data/tinyshakespeare/input.txt
Invoke with input text file and chain order:
python markov.py input.txt 5
"""
import sys
import re
import logging
from random import choice
from collections import defaultdict
def iterslice(items, size):
    """Yield every SIZE-long sliding window of ITEMS with the element after it.

    Windows overlap and advance one position at a time. Each yielded value
    is a ``(window, trailing)`` pair, where ``trailing`` is the element
    immediately following the window, or None for the final window. Nothing
    is yielded when ITEMS is shorter than SIZE.
    """
    total = len(items)
    last_start = total - size
    start = 0
    while start <= last_start:
        end = start + size
        follower = items[end] if end < total else None
        yield items[start:end], follower
        start += 1
class MarkovData:
    """Character-level Markov chain trained on one or more strings.

    Call init() before generate(): it builds the chains and collects seeds.

    Attributes:
        raw_data: training strings. Both training and seed collection iterate
            over it, so it must be re-iterable (e.g. a list, not a generator).
        order: number of characters of context used to pick the next one.
        chains: set by init(); maps each ``order``-length substring of the
            training data to the list of characters that followed it, with
            None recorded where the substring ended a training string.
        seeds: set by init(); ``order``-length substrings starting with an
            uppercase ASCII letter at the start of a string or after
            whitespace. Generation starts from one of these.
    """

    def __init__(self, raw_data, order=5):
        # order 0 would produce empty chain keys and an invalid seed regex.
        if order < 1:
            raise ValueError(f'order must be at least 1, got {order!r}')
        self.raw_data = raw_data
        self.order = order
        self.chains = None
        self.seeds = None

    def _train(self):
        """Map every ``order``-length window of each string to its followers."""
        result = defaultdict(list)
        for item in self.raw_data:
            for key, value in iterslice(item, self.order):
                result[str(key)].append(value)
        return dict(result)

    def _get_seeds(self):
        """Collect ``order``-length substrings that begin a capitalized word."""
        # An ASCII uppercase letter at the very start or right after
        # whitespace, followed by order - 1 more non-newline characters.
        regex = re.compile(fr'(?:^|(?<=\s))([A-Z].{{{self.order - 1}}})')
        return [m.group(1)
                for item in self.raw_data
                for m in regex.finditer(item)]

    def init(self):
        "Train the Markov chain and generate seeds."
        self.chains = self._train()
        self.seeds = self._get_seeds()
        logging.debug('Seeds: %d', len(self.seeds))
        logging.debug('Chain keys: %d', len(self.chains))
        # Sample a key only when it will actually be logged. choice() raises
        # IndexError on an empty chain (input shorter than order), and
        # building the key list is wasted work when debug output is off.
        if self.chains and logging.getLogger().isEnabledFor(logging.DEBUG):
            logging.debug('Random key: %s', choice(list(self.chains)))

    def generate(self):
        """Generate output, which will continue until a natural break occurs
        per the training data.

        Yields a randomly chosen seed string first, then one character at a
        time. Stops when the current context has no recorded followers or
        the chosen follower is the end-of-text marker (None). If the chain
        contains a cycle, this can run indefinitely.

        Yields nothing when no seeds were found in the training data.

        Raises:
            RuntimeError: if init() has not been called.
        """
        if self.chains is None or self.seeds is None:
            raise RuntimeError('call init() before generate()')
        if not self.seeds:
            # No capitalized starting point exists in the training data.
            return
        seed = choice(self.seeds)
        yield seed
        while True:
            values = self.chains.get(seed)
            if not values:
                break
            value = choice(values)
            if not value:
                break
            yield value
            # Slide the context window one character forward.
            seed = seed[1:] + value
def main():
    """Command-line entry point: ``python markov.py INPUT_FILE ORDER``.

    Reads INPUT_FILE as UTF-8, trains a MarkovData of the given ORDER on its
    whole contents, and streams generated text to stdout. Bad arguments exit
    with a message on stderr. Generation stops quietly on Ctrl-C, or when
    the reader of stdout goes away (e.g. when piped into ``head``).
    """
    import os  # only needed for the broken-pipe cleanup below

    usage = 'usage: python markov.py INPUT_FILE ORDER'
    args = sys.argv[1:]
    if len(args) != 2:
        sys.exit(usage)
    in_file, order_arg = args
    try:
        order = int(order_arg)
    except ValueError:
        sys.exit(f'ORDER must be an integer, got {order_arg!r}\n{usage}')
    if order < 1:
        sys.exit(f'ORDER must be at least 1, got {order}\n{usage}')

    with open(in_file, encoding='utf-8') as in_data:
        data = MarkovData([in_data.read()], order)
    data.init()
    try:
        for chunk in data.generate():
            if chunk:
                sys.stdout.write(chunk)
        # Flush inside the try so a closed pipe is caught here rather than
        # during interpreter shutdown.
        sys.stdout.flush()
    except BrokenPipeError:
        # Approach recommended by the Python docs (signal module, "Note on
        # SIGPIPE"): point stdout at devnull so the final flush at exit
        # does not raise a second BrokenPipeError.
        devnull = os.open(os.devnull, os.O_WRONLY)
        os.dup2(devnull, sys.stdout.fileno())
        sys.exit(1)
    except KeyboardInterrupt:
        # Ctrl-C is the normal way to stop a chain that cycles; exit quietly.
        pass


if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment