Created
March 13, 2020 08:20
-
-
Save srishanbhattarai/66fed4241a304ddca77902ad25e7e71d to your computer and use it in GitHub Desktop.
Basic Markov chain in Rust (it clones data around, so it is not the most performant, but it gets the job done)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use std::collections::HashMap; | |
use std::error::Error; | |
use std::fs::File; | |
use std::io::{BufRead, BufReader}; | |
use std::path::Path; | |
/// A word-level Markov chain of configurable order.
///
/// * `order` is the number of consecutive words that make up one state.
/// * `chain` maps each state to the words observed immediately after it,
///   together with how often each of those words followed the state.
/// * `freqs` counts how many times each state was observed with a successor
///   (the sum of that state's counts in `chain`).
struct MarkovChain {
    order: usize,
    chain: HashMap<Vec<String>, HashMap<String, usize>>,
    freqs: HashMap<Vec<String>, usize>,
}

impl Default for MarkovChain {
    /// Builds an empty first-order chain, i.e. every state is a single word.
    fn default() -> Self {
        MarkovChain::with_order(1)
    }
}

impl MarkovChain {
    /// Creates an empty chain whose states are `order` words long.
    fn with_order(order: usize) -> Self {
        MarkovChain {
            order,
            chain: HashMap::new(),
            freqs: HashMap::new(),
        }
    }

    /// Records every `(state, next word)` transition found in `s`.
    ///
    /// The sentence is split on whitespace, so repeated spaces or tabs never
    /// produce empty "words". A sentence with `order` words or fewer contains
    /// no complete transition and leaves the chain untouched.
    fn train_sentence(&mut self, s: String) {
        let words: Vec<&str> = s.split_whitespace().collect();
        let order = self.order;

        // Each window holds one state (`order` words) plus the word that
        // follows it. `windows` yields nothing when the sentence is too short,
        // so there is no length arithmetic that could underflow, and the word
        // list is never mutated while it is being walked.
        for window in words.windows(order.saturating_add(1)) {
            let (state, next) = window.split_at(order);
            let state: Vec<String> = state.iter().map(|&w| w.to_owned()).collect();

            // The state is stored as a key in both maps, hence the one clone.
            *self
                .chain
                .entry(state.clone())
                .or_default()
                .entry(next[0].to_owned())
                .or_default() += 1;
            *self.freqs.entry(state).or_default() += 1;
        }
    }

    /// Trains the chain on the file at `p`, treating each line as a sentence.
    ///
    /// # Errors
    ///
    /// Returns an error if the file cannot be opened or if reading a line
    /// fails (for example on invalid UTF-8). Lines read before the failure
    /// have already been trained on.
    pub fn train_file(&mut self, p: &Path) -> Result<(), Box<dyn Error>> {
        let reader = BufReader::new(File::open(p)?);
        for line in reader.lines() {
            self.train_sentence(line?);
        }
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Smoke test: trains a default chain on `test.txt` (resolved against the
    /// working directory), expects training to succeed, then prints every
    /// state together with its successor counts.
    #[test]
    fn it_works() {
        let path = Path::new("test.txt");
        let mut markov = MarkovChain::default();
        assert!(markov.train_file(path).is_ok());

        // Consume the table; it is not needed after being printed.
        for (state, successors) in markov.chain {
            println!("Key: {}", state.join(","));
            for (word, count) in successors {
                println!("{} = {}", word, count);
            }
        }
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment