Last active
January 19, 2023 06:10
-
-
Save amake/1a4d51235e1f2fb145fa9021c80e896e to your computer and use it in GitHub Desktop.
Markov text generator
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Markov Chain-driven text generator. Suitable for use with e.g. | |
tinyshakespeare: | |
https://github.com/karpathy/char-rnn/blob/master/data/tinyshakespeare/input.txt | |
Invoke with input text file and chain order: | |
python markov.py input.txt 5 | |
""" | |
import sys | |
import re | |
import logging | |
from random import choice | |
from collections import defaultdict | |
def iterslice(items, size):
    """Slide a window of SIZE elements over ITEMS, one position at a time.

    Yields ``(window, following)`` pairs, where ``window`` is the slice of
    SIZE elements and ``following`` is the element right after it, or None
    when the window ends at the last element of ITEMS. Nothing is yielded
    when ITEMS is shorter than SIZE.
    """
    total = len(items)
    for start in range(total - size + 1):
        stop = start + size
        if stop < total:
            following = items[stop]
        else:
            # The window reaches the end of ITEMS: nothing follows it.
            following = None
        yield items[start:stop], following
class MarkovData:
    """Markov chain model built from raw text items.

    Call init() to train the chains and collect seeds; generate() trains
    lazily on first use if init() has not been called yet.
    """

    def __init__(self, raw_data, order=5):
        # Iterable of training items. _get_seeds() runs a regex over each
        # item, so items are expected to be strings.
        self.raw_data = raw_data
        # Length of each chain key (the Markov order).
        self.order = order
        # Set by init(): maps each ORDER-long key to the list of values seen
        # immediately after it (None marks the end of an item).
        self.chains = None
        # Set by init(): list of ORDER-long strings that generate() starts from.
        self.seeds = None

    def _train(self):
        """Build the key -> following-values table from the raw data.

        Repeated followers are kept as repeated list entries, so a uniform
        random pick from a list reflects how often each follower occurred.
        """
        result = defaultdict(list)
        for item in self.raw_data:
            for key, value in iterslice(item, self.order):
                result[str(key)].append(value)
        return dict(result)

    def _get_seeds(self):
        """Collect every candidate starting string from the raw data.

        A seed is ORDER characters long, starts with an uppercase ASCII
        letter, and begins at the start of an item or right after a
        whitespace character. The pattern's `.` does not match newlines, so
        a seed never spans a line break.
        """
        regex = re.compile(fr'(?:^|(?<=\s))([A-Z].{{{self.order - 1}}})')
        return [m.group(1)
                for item in self.raw_data
                for m in regex.finditer(item)]

    def init(self):
        """Train the Markov chain and generate seeds."""
        self.chains = self._train()
        self.seeds = self._get_seeds()
        logging.debug('Seeds: %d', len(self.seeds))
        logging.debug('Chain keys: %d', len(self.chains))
        # choice() raises IndexError on an empty sequence; a debug message
        # must not make training fail when the input produced no chains.
        if self.chains:
            logging.debug('Random key: %s', choice(list(self.chains.keys())))

    def generate(self):
        """Yield generated text piece by piece.

        The first piece is a randomly chosen seed; each later piece is one
        value picked at random from the followers of the current key.
        Generation stops at a natural break per the training data: when the
        current key has no recorded followers or the end-of-item marker is
        picked. Nothing is yielded when the training data has no seeds.
        """
        if self.chains is None or self.seeds is None:
            # Not trained yet: do it now instead of failing on None.
            self.init()
        if not self.seeds:
            return
        seed = choice(self.seeds)
        yield seed
        while True:
            values = self.chains.get(seed)
            if not values:
                break
            value = choice(values)
            if not value:
                break
            yield value
            # Slide the key forward by one element.
            seed = seed[1:] + value
def main():
    """Command-line entry point: train on a text file and print generated text.

    Expects exactly two command-line arguments, the input file path and the
    chain order, e.g. ``python markov.py input.txt 5``. Exits with a usage
    message when the argument count is wrong.
    """
    args = sys.argv[1:]
    if len(args) != 2:
        sys.exit(f'Usage: {sys.argv[0]} INPUT_FILE ORDER')
    in_file, order = args
    with open(in_file, encoding='utf-8') as in_data:
        data = MarkovData([in_data.read()], int(order))
    data.init()
    try:
        for chunk in data.generate():
            if chunk:
                sys.stdout.write(chunk)
        # Flush while still inside the try block so that a closed pipe is
        # reported here rather than when the interpreter flushes at exit.
        sys.stdout.flush()
    except (BrokenPipeError, KeyboardInterrupt):
        # NOTE(review): presumably closes stderr to silence the traceback /
        # shutdown-time error output after the reader goes away or the user
        # interrupts -- confirm this is still the desired behavior.
        sys.stderr.close()


if __name__ == '__main__':
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment