Created
September 18, 2025 20:19
-
-
Save xjunko/317d66fc97d9f7e9c2587104fa65c384 to your computer and use it in GitHub Desktop.
A simple Markov text generator written in Rust.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| use std::collections::HashMap; | |
| use rand::Rng; | |
| use regex::Regex; | |
/// Follower counts: maps a word to how many times it was seen after a state.
type Weight = HashMap<String, i32>;
/// A window of consecutive words used as a key into the model.
type State = Vec<String>;
/// Transition table: every observed state together with its follower counts.
type Model = HashMap<State, Weight>;

/// Number of words that make up one state.
const STATE_SIZE: usize = 2;
/// Sentinel word used to pad the front of every run.
const BEGIN: &str = "___BEGIN__";
/// Sentinel word appended after the last word of every run.
const END: &str = "___END__";
| #[derive(Debug, Default)] | |
| pub struct Chain { | |
| pub model: Model, | |
| begin_choices: Vec<String>, | |
| begin_weights: Vec<i32>, | |
| } | |
| impl Chain { | |
| pub fn accumulate(ns: Vec<i32>) -> Vec<i32> { | |
| let mut numbers: Vec<i32> = Vec::new(); | |
| let mut total = ns[0]; | |
| for n in ns { | |
| numbers.push(total); | |
| total += n; | |
| } | |
| numbers | |
| } | |
| pub fn compile_next(data: Weight) -> (Vec<String>, Vec<i32>) { | |
| let words: Vec<String> = data.keys().cloned().collect(); | |
| let cum: Vec<i32> = Self::accumulate(data.values().cloned().collect()); | |
| (words, cum) | |
| } | |
| fn bisect_right<T: Ord>(slice: &[T], x: &T) -> usize { | |
| match slice.binary_search(x) { | |
| Ok(idx) => idx + 1, // if found, insert to the right | |
| Err(idx) => idx, // if not found, Err gives insertion point | |
| } | |
| } | |
| } | |
| impl Chain { | |
| pub fn new(data: Vec<Vec<String>>) -> Self { | |
| let mut chain = Self::default(); | |
| chain.model = chain.build(data); | |
| chain.compute(); | |
| chain | |
| } | |
| pub fn build(&self, data: Vec<Vec<String>>) -> Model { | |
| let mut model: Model = HashMap::new(); | |
| for run in data { | |
| let mut items: Vec<&str> = vec![BEGIN; STATE_SIZE]; | |
| items.extend(run.iter().map(|s| s.as_str())); | |
| items.push(END); | |
| for i in 0..run.len() + 1 { | |
| let state: State = items[i..i + STATE_SIZE] | |
| .iter() | |
| .map(|s| (*s).to_string()) | |
| .collect::<Vec<String>>() | |
| .try_into() | |
| .unwrap(); | |
| let follow: &str = items[i + STATE_SIZE]; | |
| // FIXME: clones | |
| if !model.contains_key(&state) { | |
| model.insert(state.clone(), HashMap::new()); | |
| } | |
| // FIXME: &str to String | |
| if !model.get(&state).unwrap().contains_key(follow) { | |
| model.get_mut(&state).unwrap().insert(follow.to_string(), 0); | |
| } | |
| *model.get_mut(&state).unwrap().get_mut(follow).unwrap() += 1; | |
| } | |
| } | |
| model | |
| } | |
| pub fn begin_state(&self) -> State { | |
| vec![BEGIN.to_string(); STATE_SIZE] | |
| } | |
| pub fn compute(&mut self) { | |
| let begin_state = self.begin_state(); | |
| let (choices, cum) = Self::compile_next(self.model.get(&begin_state).unwrap().clone()); | |
| self.begin_choices = choices; | |
| self.begin_weights = cum; | |
| } | |
| pub fn r#move(&self, state: &State) -> String { | |
| let (mut choices, mut cumdist) = (self.begin_choices.clone(), self.begin_weights.clone()); | |
| if state != &self.begin_state() { | |
| // FIXME: This is flawed | |
| choices.clear(); | |
| cumdist.clear(); | |
| let mut weights: Vec<i32> = Vec::new(); | |
| for (word, weight) in self.model.get(state).unwrap() { | |
| choices.push(word.clone()); | |
| weights.push(*weight); | |
| } | |
| cumdist = Self::accumulate(weights); | |
| } | |
| let r: f32 = rand::thread_rng().gen_range(0.0..1.0) * (*cumdist.last().unwrap() as f32); | |
| let r_i32 = r as i32; | |
| choices[Self::bisect_right(&cumdist, &r_i32)].clone() | |
| } | |
| pub fn r#gen(&self, init_state: Option<State>) -> Vec<String> { | |
| let mut state = init_state.unwrap_or(self.begin_state()); | |
| let mut result: Vec<String> = Vec::new(); | |
| loop { | |
| let next_word: String = self.r#move(&state); | |
| if next_word == END { | |
| break; | |
| } | |
| // FIXME: clones | |
| result.push(next_word.clone()); | |
| state = state[1..].to_vec(); | |
| state.push(next_word.clone()); | |
| } | |
| result | |
| } | |
| } | |
| #[derive(Debug, Default)] | |
| pub struct Text { | |
| reject: Option<Regex>, | |
| parsed_sentences: Vec<Vec<String>>, | |
| rejoined_text: String, | |
| pub chain: Chain, | |
| } | |
| impl Text { | |
| pub fn sentence_input(&self, s: &str) -> bool { | |
| if s.trim().is_empty() { | |
| return false; | |
| } | |
| let decoded = unidecode::unidecode(s); | |
| if let Some(re) = &self.reject | |
| && re.is_match(&decoded) | |
| { | |
| return false; | |
| } | |
| true | |
| } | |
| } | |
| impl Text { | |
| pub fn new(data: String) -> Self { | |
| let mut text = Text::default(); | |
| text.reject = Regex::new(&format!(r"(^')|('$)|\s'|'\s|[\{}(\(\)\[\])]", '"')).ok(); | |
| text.parsed_sentences = text.parse(data); | |
| text.rejoined_text = text | |
| .parsed_sentences | |
| .iter() | |
| .map(|s| s.join(" ")) | |
| .collect::<Vec<String>>() | |
| .join(" "); | |
| text.chain = Chain::new(text.parsed_sentences.clone()); | |
| text | |
| } | |
| fn verify(&self, words: &[String], mor: f32, mot: usize) -> bool { | |
| let overlap_ratio = (mor * words.len() as f32).round() as usize; | |
| let overlap_max = mot.min(overlap_ratio); | |
| let overlap_over = overlap_max + 1; | |
| let gram_count = (words.len().saturating_sub(overlap_max)).max(1); | |
| for i in 0..gram_count { | |
| let end = (i + overlap_over).min(words.len()); | |
| let gram = &words[i..end]; | |
| let gram_joined = gram.join(" "); | |
| if self.rejoined_text.contains(&gram_joined) { | |
| return false; | |
| } | |
| } | |
| true | |
| } | |
| pub fn parse(&self, data: String) -> Vec<Vec<String>> { | |
| data.split("\n") | |
| .filter(|s| self.sentence_input(s)) | |
| .map(|s| s.split_whitespace().map(|w| w.to_string()).collect()) | |
| .collect() | |
| } | |
| pub fn generate(&self, tries: i32, min_words: i32, max_words: i32) -> String { | |
| for _ in 0..tries { | |
| let words: Vec<String> = self.chain.r#gen(None); | |
| if words.len() > max_words as usize || words.len() < min_words as usize { | |
| continue; | |
| } | |
| if self.verify(&words, 0.7, 15) { | |
| return words.join(" "); | |
| } | |
| } | |
| String::new() | |
| } | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment