Skip to content

Instantly share code, notes, and snippets.

@xjunko
Created September 18, 2025 20:19
Show Gist options
  • Save xjunko/317d66fc97d9f7e9c2587104fa65c384 to your computer and use it in GitHub Desktop.
Save xjunko/317d66fc97d9f7e9c2587104fa65c384 to your computer and use it in GitHub Desktop.
A simple Markov chain text generator written in Rust.
use std::collections::HashMap;
use rand::Rng;
use regex::Regex;
/// Sentinel token used to pad the front of every run, so that the first real
/// words are preceded by a full-length state.
const BEGIN: &str = "___BEGIN__";
/// Sentinel token appended to every run; generation stops when it is drawn.
const END: &str = "___END__";
/// Number of preceding tokens that make up one state.
const STATE_SIZE: usize = 2;
/// A state: the sequence of tokens that precedes the next word.
type State = Vec<String>;
/// Follower word -> number of times it was observed after a given state.
type Weight = HashMap<String, i32>;
/// State -> the followers observed after it, with their counts.
type Model = HashMap<State, Weight>;
/// A Markov chain over word sequences.
///
/// `model` holds the transition counts; the two private vectors cache the
/// follower words and cumulative weights of the begin state (filled in by
/// `compute`) so they do not have to be rebuilt for every generated sentence.
#[derive(Debug, Default)]
pub struct Chain {
/// Transition table: state -> follower word -> observation count.
pub model: Model,
/// Words that can follow the begin state, in the same order as `begin_weights`.
begin_choices: Vec<String>,
/// Cumulative weights matching `begin_choices` element for element.
begin_weights: Vec<i32>,
}
impl Chain {
/// Returns the running (prefix) sums of `ns`.
///
/// Element `i` of the result is `ns[0] + ns[1] + ... + ns[i]`, so the last
/// element is the total of all inputs. An empty input yields an empty output.
///
/// # Examples
///
/// `[1, 2, 3]` becomes `[1, 3, 6]`.
pub fn accumulate(ns: Vec<i32>) -> Vec<i32> {
    ns.into_iter()
        .scan(0_i32, |total, n| {
            // Add first, then emit: every entry includes its own weight.
            *total += n;
            Some(*total)
        })
        .collect()
}
/// Splits a follower table into parallel vectors for weighted sampling.
///
/// Returns `(words, cumulative)`, where `words` lists the follower words in
/// the map's iteration order and `cumulative` is the result of passing the
/// matching counts (same order) through [`Self::accumulate`].
pub fn compile_next(data: Weight) -> (Vec<String>, Vec<i32>) {
    // Unzipping one pass over the entries keeps words and counts aligned.
    let (words, counts): (Vec<String>, Vec<i32>) = data.into_iter().unzip();
    let cumulative = Self::accumulate(counts);
    (words, cumulative)
}
/// Returns the right-most insertion point for `x` in the sorted `slice`.
///
/// The result is the index of the first element strictly greater than `x`
/// (or `slice.len()` if there is none), so inserting `x` there keeps the
/// slice sorted and places it after any elements equal to it.
///
/// `slice` must be sorted in ascending order.
fn bisect_right<T: Ord>(slice: &[T], x: &T) -> usize {
    // `binary_search` may stop at any one of several equal elements, which
    // would make `idx + 1` wrong for runs of duplicates; `partition_point`
    // always finds the boundary itself.
    slice.partition_point(|probe| probe <= x)
}
}
impl Chain {
pub fn new(data: Vec<Vec<String>>) -> Self {
let mut chain = Self::default();
chain.model = chain.build(data);
chain.compute();
chain
}
/// Counts, for every state seen in `data`, how often each word follows it.
///
/// Each run is padded with `STATE_SIZE` copies of `BEGIN` at the front and a
/// single `END` at the back. Every window of `STATE_SIZE + 1` consecutive
/// tokens then contributes one observation: the first `STATE_SIZE` tokens are
/// the state and the final token is the follower.
///
/// Returns the resulting transition table; `self` is not read.
pub fn build(&self, data: Vec<Vec<String>>) -> Model {
    let mut model: Model = HashMap::new();
    for run in data {
        let mut items: Vec<&str> = vec![BEGIN; STATE_SIZE];
        items.reserve(run.len() + 1);
        items.extend(run.iter().map(String::as_str));
        items.push(END);

        for window in items.windows(STATE_SIZE + 1) {
            let Some((follow, context)) = window.split_last() else {
                // Windows hold STATE_SIZE + 1 >= 1 items, so this cannot happen.
                continue;
            };
            let state: State = context.iter().map(|word| (*word).to_string()).collect();
            // One lookup for the state; the follower is only turned into an
            // owned `String` the first time it is seen after that state.
            let followers = model.entry(state).or_default();
            if let Some(count) = followers.get_mut(*follow) {
                *count += 1;
            } else {
                followers.insert((*follow).to_string(), 1);
            }
        }
    }
    model
}
/// Returns the state that precedes the first word of a run:
/// `STATE_SIZE` copies of the `BEGIN` token.
pub fn begin_state(&self) -> State {
    (0..STATE_SIZE).map(|_| BEGIN.to_string()).collect()
}
/// Caches the follower words and cumulative weights of the begin state.
///
/// If the model has no entry for the begin state (for example because the
/// corpus was empty), both caches are left empty instead of panicking.
pub fn compute(&mut self) {
    let begin_state = self.begin_state();
    let (choices, cum) = match self.model.get(&begin_state) {
        // `compile_next` takes the table by value, hence the clone.
        Some(followers) => Self::compile_next(followers.clone()),
        None => (Vec::new(), Vec::new()),
    };
    self.begin_choices = choices;
    self.begin_weights = cum;
}
pub fn r#move(&self, state: &State) -> String {
let (mut choices, mut cumdist) = (self.begin_choices.clone(), self.begin_weights.clone());
if state != &self.begin_state() {
// FIXME: This is flawed
choices.clear();
cumdist.clear();
let mut weights: Vec<i32> = Vec::new();
for (word, weight) in self.model.get(state).unwrap() {
choices.push(word.clone());
weights.push(*weight);
}
cumdist = Self::accumulate(weights);
}
let r: f32 = rand::thread_rng().gen_range(0.0..1.0) * (*cumdist.last().unwrap() as f32);
let r_i32 = r as i32;
choices[Self::bisect_right(&cumdist, &r_i32)].clone()
}
/// Generates one run of words by walking the chain until `END` is drawn.
///
/// Starts from `init_state` if given, otherwise from the begin state. After
/// each word the state slides forward by one token: the oldest token is
/// dropped and the new word appended. `END` itself is not part of the result.
pub fn r#gen(&self, init_state: Option<State>) -> Vec<String> {
    // Lazily build the begin state only when no initial state was supplied.
    let mut state = init_state.unwrap_or_else(|| self.begin_state());
    let mut result: Vec<String> = Vec::new();
    loop {
        let next_word = self.r#move(&state);
        if next_word == END {
            break;
        }
        // Slide the window in place instead of reallocating the state.
        if !state.is_empty() {
            state.remove(0);
        }
        // One clone for the state; the original moves into the result.
        state.push(next_word.clone());
        result.push(next_word);
    }
    result
}
}
/// A text corpus together with the Markov chain trained on it.
#[derive(Debug, Default)]
pub struct Text {
/// Pattern for lines to skip while parsing; `None` disables the check.
reject: Option<Regex>,
/// The accepted input lines, each split into whitespace-separated words.
parsed_sentences: Vec<Vec<String>>,
/// All parsed sentences joined with single spaces; searched by `verify`
/// to detect generated output that copies the corpus too closely.
rejoined_text: String,
/// The chain built from `parsed_sentences`.
pub chain: Chain,
}
impl Text {
pub fn sentence_input(&self, s: &str) -> bool {
if s.trim().is_empty() {
return false;
}
let decoded = unidecode::unidecode(s);
if let Some(re) = &self.reject
&& re.is_match(&decoded)
{
return false;
}
true
}
}
impl Text {
pub fn new(data: String) -> Self {
let mut text = Text::default();
text.reject = Regex::new(&format!(r"(^')|('$)|\s'|'\s|[\{}(\(\)\[\])]", '"')).ok();
text.parsed_sentences = text.parse(data);
text.rejoined_text = text
.parsed_sentences
.iter()
.map(|s| s.join(" "))
.collect::<Vec<String>>()
.join(" ");
text.chain = Chain::new(text.parsed_sentences.clone());
text
}
/// Checks that a generated sentence does not copy too much of the corpus.
///
/// `mor` is the maximum overlap as a ratio of the sentence length and `mot`
/// the maximum overlap as an absolute word count; the smaller of the two
/// limits applies. Every run of `limit + 1` consecutive words is joined with
/// spaces and looked up in `rejoined_text`; the sentence is rejected
/// (`false`) as soon as one of them is found there.
fn verify(&self, words: &[String], mor: f32, mot: usize) -> bool {
    let word_count = words.len();
    let by_ratio = (mor * word_count as f32).round() as usize;
    let limit = by_ratio.min(mot);
    let window = limit + 1;
    // Always test at least one window, even for very short sentences.
    let starts = word_count.saturating_sub(limit).max(1);

    let copied = (0..starts).any(|start| {
        let stop = word_count.min(start + window);
        let joined = words[start..stop].join(" ");
        self.rejoined_text.contains(&joined)
    });
    !copied
}
/// Splits raw input into sentences (one per line) and words.
///
/// Lines rejected by `sentence_input` are skipped; each remaining line is
/// split on whitespace into owned words.
pub fn parse(&self, data: String) -> Vec<Vec<String>> {
    let mut sentences: Vec<Vec<String>> = Vec::new();
    for line in data.split("\n") {
        if !self.sentence_input(line) {
            continue;
        }
        let words: Vec<String> = line.split_whitespace().map(str::to_string).collect();
        sentences.push(words);
    }
    sentences
}
/// Tries up to `tries` times to generate an acceptable sentence.
///
/// A candidate is accepted when its word count lies within
/// `min_words..=max_words` and it passes `verify` with a 0.7 overlap ratio
/// and a 15-word overlap cap. The accepted words are returned joined with
/// single spaces; if no attempt succeeds, an empty string is returned.
///
/// A negative `min_words` means "no minimum" and a negative `max_words`
/// means "no maximum".
pub fn generate(&self, tries: i32, min_words: i32, max_words: i32) -> String {
    // Checked conversions: a plain `as usize` would wrap a negative minimum
    // into a huge value and silently reject every candidate.
    let min_len = usize::try_from(min_words).unwrap_or(0);
    let max_len = usize::try_from(max_words).unwrap_or(usize::MAX);

    for _ in 0..tries {
        let words: Vec<String> = self.chain.r#gen(None);
        if !(min_len..=max_len).contains(&words.len()) {
            continue;
        }
        if self.verify(&words, 0.7, 15) {
            return words.join(" ");
        }
    }
    String::new()
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment