Skip to content

Instantly share code, notes, and snippets.

@xjunko
Created September 18, 2025 18:14
Show Gist options
  • Save xjunko/9086bb2cb2d6074a31896a3d2352c58d to your computer and use it in GitHub Desktop.
Save xjunko/9086bb2cb2d6074a31896a3d2352c58d to your computer and use it in GitHub Desktop.
A heavily cut-down version of the markovify library, tailored to chat logs.
"""mini-markov.py - a somewhat opinionated markovify-inspired text generator."""
__author__ = "xjunko"
__license__ = "WTFPL"
import bisect
import random
import re
from typing import Callable, Iterable, TypeAlias
import unidecode
# Follow-word -> number of times it was observed after a given state.
Weight: TypeAlias = dict[str, int]
# A window of consecutive words, used as the lookup key into the model.
State: TypeAlias = tuple[str, ...]
# Transition table: each state maps to the counts of the words that followed it.
Model: TypeAlias = dict[State, Weight]
class Chain:
BEGIN: str = "___BEGIN__"
END: str = "___END__"
def __init__(self, data: list[list[str]], *, state_size: int = 2) -> None:
self.state_size: int = state_size
self.model: Model = self.build(data, state_size)
self.begin_choices: list[str] = []
self.begin_cumdist: list[int] = []
self.compute()
@staticmethod
def accumulate(
iterable: Iterable, method: Callable = lambda x, y: x + y
) -> Iterable:
it = iter(iterable)
total = next(it)
yield total
for elem in it:
total = method(total, elem)
yield total
@staticmethod
def compile_next(data: Weight) -> tuple[list[str], list[int]]:
words = list(data.keys())
cum = list(Chain.accumulate(data.values()))
return words, cum
@staticmethod
def build(data: list[list[str]], state_size: int) -> Model:
model: Model = {}
for run in data:
items: list[str] = ([Chain.BEGIN] * state_size) + run + [Chain.END]
for i in range(len(run) + 1):
state: State = tuple(items[i : i + state_size])
follow: str = items[i + state_size]
if state not in model:
model[state] = {}
if follow not in model[state]:
model[state][follow] = 0
model[state][follow] += 1
return model
def begin_state(self) -> State:
return State([Chain.BEGIN] * self.state_size)
def compute(self) -> None:
begin_state = self.begin_state()
choices, cummulative_distance = Chain.compile_next(self.model[begin_state])
self.begin_choices = choices
self.begin_cumdist = cummulative_distance
def move(self, state: State) -> str:
if state == self.begin_state():
choices = self.begin_choices
cumdist = self.begin_cumdist
else:
choices, weights = zip(*self.model[state].items())
cumdist = list(Chain.accumulate(weights))
r: float = random.random() * cumdist[-1]
return choices[bisect.bisect(cumdist, r)]
def gen(self, *, init_state: State | None = None) -> Iterable[str]:
state = init_state or self.begin_state()
while True:
next_word: str = self.move(state)
if next_word == Chain.END:
break
yield next_word
state = tuple(state[1:]) + (next_word,)
def walk(self, *, init_state: State | None = None) -> list[str]:
return list(self.gen(init_state=init_state))
class Text:
    """Line-oriented text model on top of :class:`Chain`.

    The corpus is split on newlines; each accepted line becomes one run of
    whitespace-separated words used to train the chain.
    """

    # Matches lines with a leading/trailing apostrophe, an apostrophe next to
    # whitespace, or any double quote, parenthesis or square bracket.
    REJECT: re.Pattern = re.compile(r"(^')|('$)|\s'|'\s|[\"(\(\)\[\])]")

    def __init__(self, data: str, *, state_size: int = 2) -> None:
        """Parse *data* and train a chain with the given *state_size*."""
        self.state_size: int = state_size
        self.parsed_sentences: list[list[str]] = self.parse(data)
        # Whole accepted corpus as one string; verify() searches it to detect
        # output that merely copies the training text.
        self.rejoined_text: str = " ".join(
            " ".join(run) for run in self.parsed_sentences
        )
        self.chain: Chain = Chain(self.parsed_sentences, state_size=state_size)

    @staticmethod
    def sentence_input(s: str) -> bool:
        """Return True when line *s* is usable as training input.

        Blank lines are rejected, as are lines whose unidecode-converted
        form matches ``REJECT``.
        """
        if not s.strip():
            return False
        decoded: str = unidecode.unidecode(s)
        return Text.REJECT.search(decoded) is None

    def verify(self, words: list[str], mor: float = 0.7, mot: float = 15) -> bool:
        """Return True when *words* does not overlap the corpus too much.

        Every window of ``overlap_max + 1`` consecutive words is joined with
        spaces and looked up in the rejoined corpus text; one hit rejects the
        candidate.

        Args:
            words: Candidate output.
            mor: Maximum overlap as a fraction of ``len(words)``.
            mot: Maximum overlap as an absolute word count. A fractional
                value is truncated towards zero.
        """
        overlap_ratio = round(mor * len(words))
        # Truncate to int: the value feeds slice bounds and range() below, and
        # a float slice bound raises TypeError.
        overlap_max = int(min(mot, overlap_ratio))
        overlap_over = overlap_max + 1
        gram_count = max(len(words) - overlap_max, 1)
        for i in range(gram_count):
            gram_joined = " ".join(words[i : i + overlap_over])
            if gram_joined in self.rejoined_text:
                return False
        return True

    def parse(self, data: str) -> list[list[str]]:
        """Split *data* into lines, keep the accepted ones, and tokenise them."""
        return [
            line.split()
            for line in data.split("\n")
            if self.sentence_input(line)
        ]

    def generate(
        self,
        tries: int = 10,
        *,
        init_state: State | None = None,
        min_words: int = 0,
        max_words: int = 100
    ) -> str | None:
        """Try up to *tries* times to produce a sentence that passes verify().

        Args:
            tries: Maximum number of random walks to attempt.
            init_state: Optional starting state; its non-BEGIN words are
                prepended to the generated words.
            min_words: Reject candidates with fewer words than this.
            max_words: Reject candidates with more words than this.

        Returns:
            The space-joined sentence, or None when every attempt was rejected.
        """
        prefix: list[str] = (
            [word for word in init_state if word != Chain.BEGIN]
            if init_state
            else []
        )
        for _ in range(tries):
            words: list[str] = prefix + self.chain.walk(init_state=init_state)
            if not min_words <= len(words) <= max_words:
                continue
            if self.verify(words):
                return " ".join(words)
        return None
def main(path: str = "data/myst1a.txt") -> int:
    """Interactive demo: print generated sentences until the user enters "e".

    Args:
        path: UTF-8 text file to train on, one sentence per line.

    Returns:
        Process exit status (always 0).
    """
    # Use a context manager so the corpus file is closed once it is read.
    with open(path, encoding="utf-8") as corpus_file:
        corpus: str = corpus_file.read()
    test_model: Text = Text(corpus)
    while True:
        # generate() may return None when every attempt was rejected.
        print(test_model.generate())
        if input() == "e":
            break
    return 0
# Run the interactive demo only when executed as a script; the value returned
# by main() is passed to SystemExit as the exit status.
if __name__ == "__main__":
    raise SystemExit(main())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment