Created
September 18, 2025 18:14
-
-
Save xjunko/9086bb2cb2d6074a31896a3d2352c58d to your computer and use it in GitHub Desktop.
A heavily cut-down version of the markovify library, tailored to chat logs.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| """mini-markov.py - a somewhat opinionated markovify-inspired text generator.""" | |
| __author__ = "xjunko" | |
| __license__ = "WTFPL" | |
import bisect
import itertools
import random
import re
from typing import Callable, Iterable, TypeAlias

import unidecode
| Weight: TypeAlias = dict[str, int] | |
| State: TypeAlias = tuple[str, ...] | |
| Model: TypeAlias = dict[State, Weight] | |
| class Chain: | |
| BEGIN: str = "___BEGIN__" | |
| END: str = "___END__" | |
| def __init__(self, data: list[list[str]], *, state_size: int = 2) -> None: | |
| self.state_size: int = state_size | |
| self.model: Model = self.build(data, state_size) | |
| self.begin_choices: list[str] = [] | |
| self.begin_cumdist: list[int] = [] | |
| self.compute() | |
| @staticmethod | |
| def accumulate( | |
| iterable: Iterable, method: Callable = lambda x, y: x + y | |
| ) -> Iterable: | |
| it = iter(iterable) | |
| total = next(it) | |
| yield total | |
| for elem in it: | |
| total = method(total, elem) | |
| yield total | |
| @staticmethod | |
| def compile_next(data: Weight) -> tuple[list[str], list[int]]: | |
| words = list(data.keys()) | |
| cum = list(Chain.accumulate(data.values())) | |
| return words, cum | |
| @staticmethod | |
| def build(data: list[list[str]], state_size: int) -> Model: | |
| model: Model = {} | |
| for run in data: | |
| items: list[str] = ([Chain.BEGIN] * state_size) + run + [Chain.END] | |
| for i in range(len(run) + 1): | |
| state: State = tuple(items[i : i + state_size]) | |
| follow: str = items[i + state_size] | |
| if state not in model: | |
| model[state] = {} | |
| if follow not in model[state]: | |
| model[state][follow] = 0 | |
| model[state][follow] += 1 | |
| return model | |
| def begin_state(self) -> State: | |
| return State([Chain.BEGIN] * self.state_size) | |
| def compute(self) -> None: | |
| begin_state = self.begin_state() | |
| choices, cummulative_distance = Chain.compile_next(self.model[begin_state]) | |
| self.begin_choices = choices | |
| self.begin_cumdist = cummulative_distance | |
| def move(self, state: State) -> str: | |
| if state == self.begin_state(): | |
| choices = self.begin_choices | |
| cumdist = self.begin_cumdist | |
| else: | |
| choices, weights = zip(*self.model[state].items()) | |
| cumdist = list(Chain.accumulate(weights)) | |
| r: float = random.random() * cumdist[-1] | |
| return choices[bisect.bisect(cumdist, r)] | |
| def gen(self, *, init_state: State | None = None) -> Iterable[str]: | |
| state = init_state or self.begin_state() | |
| while True: | |
| next_word: str = self.move(state) | |
| if next_word == Chain.END: | |
| break | |
| yield next_word | |
| state = tuple(state[1:]) + (next_word,) | |
| def walk(self, *, init_state: State | None = None) -> list[str]: | |
| return list(self.gen(init_state=init_state)) | |
class Text:
    """Line-oriented text model that generates sentences from a ``Chain``.

    The corpus is split on newlines, lines that are blank or match ``REJECT``
    are dropped, and the remaining lines are whitespace-tokenised and fed to
    a ``Chain``.
    """

    # Lines containing these quote/bracket patterns are excluded from training.
    REJECT: re.Pattern = re.compile(r"(^')|('$)|\s'|'\s|[\"(\(\)\[\])]")

    def __init__(self, data: str, *, state_size: int = 2) -> None:
        """Parse ``data`` and build the underlying chain.

        Args:
            data: Raw corpus, one sentence per line.
            state_size: Number of preceding words that make up a chain state.
        """
        self.state_size: int = state_size
        self.parsed_sentences: list[list[str]] = self.parse(data)
        # Flattened corpus used by verify() for its overlap substring checks.
        self.rejoined_text: str = " ".join(
            " ".join(sentence) for sentence in self.parsed_sentences
        )
        self.chain: Chain = Chain(self.parsed_sentences, state_size=state_size)

    @staticmethod
    def sentence_input(s: str) -> bool:
        """Return True when line ``s`` is usable as training input.

        Blank lines are rejected, as are lines whose ASCII transliteration
        matches ``REJECT``.
        """
        if not s.strip():
            return False
        return Text.REJECT.search(unidecode.unidecode(s)) is None

    def verify(self, words: list[str], mor: float = 0.7, mot: float = 15) -> bool:
        """Return True when ``words`` does not copy too much of the corpus.

        Args:
            words: Candidate sentence as a list of words.
            mor: Maximum overlap as a ratio of the sentence length.
            mot: Maximum overlap as an absolute word count; truncated to an
                integer because it is used as a slice bound.

        Returns:
            False if any window of ``min(mot, round(mor * len(words))) + 1``
            consecutive words, joined by spaces, occurs as a substring of the
            rejoined corpus text; True otherwise.
        """
        # int() keeps a float ``mot`` from turning the slice bounds into floats.
        overlap_max = min(int(mot), round(mor * len(words)))
        window = overlap_max + 1
        gram_count = max(len(words) - overlap_max, 1)
        return not any(
            " ".join(words[start : start + window]) in self.rejoined_text
            for start in range(gram_count)
        )

    def parse(self, data: str) -> list[list[str]]:
        """Split ``data`` into lines and tokenise the ones that pass filtering."""
        return [
            line.split() for line in data.split("\n") if self.sentence_input(line)
        ]

    def generate(
        self,
        tries: int = 10,
        *,
        init_state: State | None = None,
        min_words: int = 0,
        max_words: int = 100
    ) -> str | None:
        """Try up to ``tries`` times to produce an acceptable sentence.

        Args:
            tries: Number of walks to attempt.
            init_state: Optional starting state; its non-``BEGIN`` words are
                prepended to the generated words.
            min_words: Reject candidates shorter than this many words.
            max_words: Reject candidates longer than this many words.

        Returns:
            The first candidate within the length bounds that passes
            ``verify``, joined by spaces, or None if every attempt failed.
        """
        prefix: list[str] = (
            [word for word in init_state if word != Chain.BEGIN]
            if init_state
            else []
        )
        for _ in range(tries):
            words: list[str] = prefix + self.chain.walk(init_state=init_state)
            if not min_words <= len(words) <= max_words:
                continue
            if self.verify(words):
                return " ".join(words)
        return None
def main(path: str = "data/myst1a.txt") -> int:
    """Interactively print generated sentences from the corpus at ``path``.

    A sentence is printed, then a line is read from standard input; entering
    ``e`` (or reaching end-of-input) stops the loop.

    Args:
        path: UTF-8 text file to train on, one sentence per line.

    Returns:
        Process exit status, always 0.
    """
    # Context manager guarantees the corpus file handle is closed.
    with open(path, encoding="utf-8") as corpus_file:
        corpus: str = corpus_file.read()
    test_model: Text = Text(corpus)
    while True:
        print(test_model.generate())
        try:
            if input() == "e":
                break
        except EOFError:
            # stdin was closed or piped input ran out; exit cleanly.
            break
    return 0


if __name__ == "__main__":
    raise SystemExit(main())
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment