Last active
September 3, 2025 20:44
-
-
Save alexhallam/e3fc2c11387b89026bebf79240a587d7 to your computer and use it in GitHub Desktop.
chumsky
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| use logos::Logos; | |
| use serde::{Deserialize, Serialize}; | |
| use std::collections::HashMap; | |
| use thiserror::Error; | |
| // --------------------------- | |
| // LEXER | |
| // --------------------------- | |
| #[derive(Logos, Debug, PartialEq, Clone)] | |
| #[logos(skip r"[ \t\n\f]+")] | |
| enum Token { | |
| #[token("-")] | |
| Minus, | |
| #[token("1")] | |
| One, | |
| #[regex(r"[2-9]\d*")] | |
| Integer, | |
| #[regex(r"[a-zA-Z][a-zA-Z0-9_]*")] | |
| ColumnName, | |
| #[token("~")] | |
| Tilde, | |
| #[token("+")] | |
| Plus, | |
| #[token("(")] | |
| FunctionStart, | |
| #[token(")")] | |
| FunctionEnd, | |
| #[token("poly")] | |
| Poly, | |
| #[token(",")] | |
| Comma, | |
| #[token("=")] | |
| Equal, | |
| #[token("family")] | |
| Family, | |
| #[token("gaussian")] | |
| Gaussian, | |
| #[token("binomial")] | |
| Binomial, | |
| #[token("poisson")] | |
| Poisson, | |
| } | |
| // --------------------------- | |
| // DATA STRUCTURES | |
| // --------------------------- | |
| #[derive(Debug, Serialize, Deserialize, Clone)] | |
| /// Represents the distinct column names as they were input by the user | |
| /// Example: | |
| /// "formula": "y ~ x + poly(x, 2) + poly(x1, 4) + log(x1) - 1, family = gaussian" | |
| /// { | |
| /// "id": 1, | |
| /// "name": "y" | |
| /// }, | |
| /// { | |
| /// "id": 2, | |
| /// "name": "x" | |
| /// }, | |
| /// { | |
| /// "id": 3, | |
| /// "name": "x1" | |
| /// } | |
| struct ColumnNameStruct { | |
| id: u32, | |
| name: String, | |
| } | |
| #[derive(Debug, Serialize, Deserialize, Clone)] | |
| /// Represents transformations applied to a column | |
| /// Example: | |
| /// "formula": "y ~ x + poly(x, 2) + poly(x1, 4) + log(x1) - 1, family = gaussian" | |
| /// { | |
| /// "column_name_struct_id": 2, | |
| /// "name": "poly" | |
| /// }, | |
| /// { | |
| /// "column_name_struct_id": 3, | |
| /// "name": "poly" | |
| /// }, | |
| /// { | |
| /// "column_name_struct_id": 3, | |
| /// "name": "log" | |
| /// } | |
| struct TransformationStruct { | |
| column_name_struct_id: u32, | |
| name: String, | |
| } | |
| #[derive(Debug, Serialize, Deserialize, Clone)] | |
| struct ColumnSuggestedNameStruct { | |
| column_name_struct_id: u32, | |
| name: String, | |
| } | |
| #[derive(Debug, Serialize, Deserialize, Clone)] | |
| struct FormulaMetaData { | |
| transformations: Vec<TransformationStruct>, | |
| column_names: Vec<ColumnNameStruct>, | |
| has_intercept: bool, | |
| formula: String, | |
| response_columns: Vec<ColumnSuggestedNameStruct>, | |
| fix_effects_columns: Vec<ColumnSuggestedNameStruct>, | |
| random_effects_columns: Vec<ColumnSuggestedNameStruct>, | |
| } | |
| // --------------------------- | |
| // SIMPLE AST | |
| // --------------------------- | |
/// Response distribution named in the optional `family = ...` clause.
#[derive(Clone, Debug)]
enum Family {
    /// `family = gaussian`
    Gaussian,
    /// `family = binomial`
    Binomial,
    /// `family = poisson`
    Poisson,
}
| #[derive(Debug, Clone)] | |
| enum Term { | |
| Column(String), | |
| Function { name: String, args: Vec<Argument> }, | |
| } | |
/// A single argument inside a function-call term.
#[derive(Clone, Debug)]
enum Argument {
    /// An identifier argument, e.g. `x` in `poly(x, 2)`.
    Ident(String),
    /// An unsigned integer literal argument, e.g. `2` in `poly(x, 2)`.
    Integer(u32),
}
| // --------------------------- | |
| // PARSER | |
| // --------------------------- | |
| #[derive(Error, Debug)] | |
| /// This checks for the following | |
| /// - lexing errors | |
| /// - unexpected end of input | |
| /// - unexpected tokens | |
| /// - invalid syntax | |
| enum ParseError { | |
| #[error("lexing error at {0:?}")] | |
| Lex(String), | |
| #[error("unexpected end of input")] | |
| Eoi, | |
| #[error("unexpected token: expected {expected:?}, found {found:?}")] | |
| Unexpected { expected: &'static str, found: Option<Token> }, | |
| #[error("invalid syntax: {0}")] | |
| Syntax(String), | |
| } | |
| /// Parser for the formula | |
| /// This is responsible for parsing the formula string into an AST | |
| /// The <'a> means that the parser will borrow the input string for the duration of its lifetime | |
| /// The `input` field is a reference to the original input string | |
| /// The `tokens` field is a vector of all the tokens found in the input string | |
| /// The `pos` field is the current position in the token stream | |
| struct Parser<'a> { | |
| input: &'a str, | |
| tokens: Vec<(Token, &'a str)>, | |
| pos: usize, | |
| } | |
| /// The parser implementation does the actual work of parsing the formula | |
| impl<'a> Parser<'a> { | |
| // `new` creates a new parser instance. This function initializes the lexer and token vector. | |
| // lex.next() iterates through the tokens | |
| // lex.slice() shows the current token as a string | |
| // An example of the input to `new()` is "y ~ x + poly(x, 2) + poly(x1, 4) + log(x1) - 1, family = gaussian" | |
| // the `new()` function would return a Parser instance which has the following data: | |
| // - input: a reference to the original input string | |
| // - tokens: a vector of all the tokens found in the input string | |
| // - pos: the current position in the token stream | |
| fn new(input: &'a str) -> Result<Self, ParseError> { | |
| let mut lex = Token::lexer(input); | |
| let mut tokens = Vec::new(); | |
| while let Some(item) = lex.next() { | |
| match item { | |
| Ok(tok) => { | |
| let slice = lex.slice(); | |
| tokens.push((tok, slice)); | |
| } | |
| Err(()) => { | |
| return Err(ParseError::Lex(lex.slice().to_string())); | |
| } | |
| } | |
| } | |
| Ok(Self { input, tokens, pos: 0 }) | |
| } | |
| // The `peek` function returns the next token without consuming it | |
| // It is important to look ahead without consuming because we may need to check the next token multiple times | |
| // Without `peek`, we would have to call `next` to see the next token, which would consume it. | |
| fn peek(&self) -> Option<&(Token, &'a str)> { | |
| self.tokens.get(self.pos) | |
| } | |
| // The `next` function returns the next token and consumes it | |
| // The consuming action is done by incrementing the `pos` field | |
| fn next(&mut self) -> Option<(Token, &'a str)> { | |
| let t = self.tokens.get(self.pos).cloned(); | |
| if t.is_some() { | |
| self.pos += 1; | |
| } | |
| t | |
| } | |
| // I don't get this generic thing | |
| fn matches<F>(&mut self, pred: F) -> bool | |
| where | |
| F: Fn(&Token) -> bool, | |
| { | |
| if let Some((tok, _)) = self.peek() { | |
| if pred(tok) { | |
| self.pos += 1; | |
| return true; | |
| } | |
| } | |
| false | |
| } | |
| // The `expect` function checks if the next token matches the given pattern | |
| // The `expect_fn` is a function that takes a reference to a Token and returns a boolean | |
| // It is true if the token matches the expected pattern | |
| fn expect( | |
| &mut self, | |
| expect_fn: fn(&Token) -> bool, | |
| expected: &'static str, | |
| ) -> Result<(Token, &'a str), ParseError> { | |
| if let Some((tok, slice)) = self.peek().cloned() { | |
| if expect_fn(&tok) { | |
| self.pos += 1; | |
| Ok((tok, slice)) | |
| } else { | |
| Err(ParseError::Unexpected { | |
| expected, | |
| found: Some(tok), | |
| }) | |
| } | |
| } else { | |
| Err(ParseError::Unexpected { | |
| expected, | |
| found: None, | |
| }) | |
| } | |
| } | |
| // The `parse_formula` function is the main entry point for parsing the formula | |
| // `parse_response()` is defined below. If the token is | |
| fn parse_formula(&mut self) -> Result<(String, Vec<Term>, bool, Option<Family>), ParseError> { | |
| let response = self.parse_response()?; | |
| self.expect(|t| matches!(t, Token::Tilde), "~")?; | |
| let (terms, has_intercept) = self.parse_rhs()?; | |
| let mut family = None; | |
| if self.matches(|t| matches!(t, Token::Comma)) { | |
| self.expect(|t| matches!(t, Token::Family), "family")?; | |
| self.expect(|t| matches!(t, Token::Equal), "=")?; | |
| family = Some(self.parse_family()?); | |
| } | |
| Ok((response, terms, has_intercept, family)) | |
| } | |
| fn parse_response(&mut self) -> Result<String, ParseError> { | |
| let (_, name) = self.expect(|t| matches!(t, Token::ColumnName), "ColumnName")?; | |
| Ok(name.to_string()) | |
| } | |
| fn parse_rhs(&mut self) -> Result<(Vec<Term>, bool), ParseError> { | |
| let mut terms = Vec::new(); | |
| let mut has_intercept = true; | |
| if self.peek().is_some() && !matches!(self.peek().unwrap().0, Token::Comma) { | |
| terms.push(self.parse_term()?); | |
| } | |
| while self.matches(|t| matches!(t, Token::Plus)) { | |
| terms.push(self.parse_term()?); | |
| } | |
| if self.matches(|t| matches!(t, Token::Minus)) { | |
| if self.matches(|t| matches!(t, Token::One)) { | |
| has_intercept = false; | |
| } else { | |
| return Err(ParseError::Syntax( | |
| "expected '1' after '-' to remove intercept".into(), | |
| )); | |
| } | |
| } | |
| Ok((terms, has_intercept)) | |
| } | |
| fn parse_term(&mut self) -> Result<Term, ParseError> { | |
| let (tok, name_slice) = self.expect( | |
| |t| matches!(t, Token::Poly | Token::ColumnName), | |
| "Poly or ColumnName", | |
| )?; | |
| if self.matches(|t| matches!(t, Token::FunctionStart)) { | |
| let fname = match tok { | |
| Token::Poly => "poly".to_string(), | |
| Token::ColumnName => name_slice.to_string(), | |
| _ => unreachable!(), | |
| }; | |
| let args = self.parse_arg_list()?; | |
| self.expect(|t| matches!(t, Token::FunctionEnd), ")")?; | |
| Ok(Term::Function { name: fname, args }) | |
| } else { | |
| match tok { | |
| Token::ColumnName => Ok(Term::Column(name_slice.to_string())), | |
| Token::Poly => Err(ParseError::Syntax("expected '(' after 'poly'".into())), | |
| _ => Err(ParseError::Unexpected { | |
| expected: "term", | |
| found: Some(tok), | |
| }), | |
| } | |
| } | |
| } | |
| fn parse_arg_list(&mut self) -> Result<Vec<Argument>, ParseError> { | |
| let mut args = Vec::new(); | |
| if let Some((tok, _)) = self.peek().cloned() { | |
| if matches!(tok, Token::FunctionEnd) { | |
| return Ok(args); | |
| } | |
| } | |
| args.push(self.parse_arg()?); | |
| while self.matches(|t| matches!(t, Token::Comma)) { | |
| args.push(self.parse_arg()?); | |
| } | |
| Ok(args) | |
| } | |
| fn parse_arg(&mut self) -> Result<Argument, ParseError> { | |
| if let Some((tok, slice)) = self.peek().cloned() { | |
| match tok { | |
| Token::ColumnName => { | |
| self.next(); | |
| Ok(Argument::Ident(slice.to_string())) | |
| } | |
| Token::Integer => { | |
| self.next(); | |
| Ok(Argument::Integer(slice.parse().unwrap())) | |
| } | |
| Token::One => { | |
| self.next(); | |
| Ok(Argument::Integer(1)) | |
| } | |
| _ => Err(ParseError::Unexpected { | |
| expected: "argument", | |
| found: Some(tok), | |
| }), | |
| } | |
| } else { | |
| Err(ParseError::Eoi) | |
| } | |
| } | |
| fn parse_family(&mut self) -> Result<Family, ParseError> { | |
| let (tok, _) = self.expect( | |
| |t| matches!(t, Token::Gaussian | Token::Binomial | Token::Poisson), | |
| "gaussian | binomial | poisson", | |
| )?; | |
| let fam = match tok { | |
| Token::Gaussian => Family::Gaussian, | |
| Token::Binomial => Family::Binomial, | |
| Token::Poisson => Family::Poisson, | |
| _ => unreachable!(), | |
| }; | |
| Ok(fam) | |
| } | |
| } | |
| // --------------------------- | |
| // META BUILDER | |
| // --------------------------- | |
| #[derive(Default)] | |
| struct MetaBuilder { | |
| name_to_id: HashMap<String, u32>, | |
| columns: Vec<ColumnNameStruct>, | |
| transformations: Vec<TransformationStruct>, | |
| response_cols: Vec<ColumnSuggestedNameStruct>, | |
| fixed_cols: Vec<ColumnSuggestedNameStruct>, | |
| random_cols: Vec<ColumnSuggestedNameStruct>, | |
| } | |
| impl MetaBuilder { | |
| fn new() -> Self { | |
| Self::default() | |
| } | |
| fn ensure_col(&mut self, name: &str) -> u32 { | |
| if let Some(&id) = self.name_to_id.get(name) { | |
| return id; | |
| } | |
| let id = self.columns.len() as u32 + 1; | |
| self.columns.push(ColumnNameStruct { | |
| id, | |
| name: name.to_string(), | |
| }); | |
| self.name_to_id.insert(name.to_string(), id); | |
| id | |
| } | |
| fn push_response(&mut self, name: &str) { | |
| let id = self.ensure_col(name); | |
| self.response_cols.push(ColumnSuggestedNameStruct { | |
| column_name_struct_id: id, | |
| name: name.to_string(), | |
| }); | |
| } | |
| fn push_plain_term(&mut self, name: &str) { | |
| let id = self.ensure_col(name); | |
| self.fixed_cols.push(ColumnSuggestedNameStruct { | |
| column_name_struct_id: id, | |
| name: name.to_string(), | |
| }); | |
| } | |
| fn push_function_term(&mut self, fname: &str, args: &[Argument]) { | |
| let base_ident = args.iter().find_map(|a| match a { | |
| Argument::Ident(s) => Some(s.as_str()), | |
| _ => None, | |
| }); | |
| let base_id = base_ident.map(|col| self.ensure_col(col)).unwrap_or(0); | |
| let arg_str = args | |
| .iter() | |
| .map(|a| match a { | |
| Argument::Ident(s) => s.clone(), | |
| Argument::Integer(n) => n.to_string(), | |
| }) | |
| .collect::<Vec<_>>() | |
| .join(", "); | |
| let suggested = format!("{fname}({arg_str})"); | |
| if base_id != 0 { | |
| self.transformations.push(TransformationStruct { | |
| column_name_struct_id: base_id, | |
| name: fname.to_string(), | |
| }); | |
| self.fixed_cols.push(ColumnSuggestedNameStruct { | |
| column_name_struct_id: base_id, | |
| name: suggested, | |
| }); | |
| } else { | |
| self.fixed_cols.push(ColumnSuggestedNameStruct { | |
| column_name_struct_id: 0, | |
| name: suggested, | |
| }); | |
| } | |
| } | |
| fn build(self, input: &str, has_intercept: bool) -> FormulaMetaData { | |
| FormulaMetaData { | |
| transformations: self.transformations, | |
| column_names: self.columns, | |
| has_intercept, | |
| formula: input.to_string(), | |
| response_columns: self.response_cols, | |
| fix_effects_columns: self.fixed_cols, | |
| random_effects_columns: self.random_cols, | |
| } | |
| } | |
| } | |
| // --------------------------- | |
| // DEMO MAIN | |
| // --------------------------- | |
| fn main() -> Result<(), Box<dyn std::error::Error>> { | |
| let input = "y ~ x + poly(x, 2) + poly(x1, 4) + log(x1) - 1, family = gaussian"; | |
| println!("TOKENS:"); | |
| let mut lex = Token::lexer(input); | |
| while let Some(item) = lex.next() { | |
| match item { | |
| Ok(tok) => println!("{:?}: {}", tok, lex.slice()), | |
| Err(()) => println!("LEX ERROR at {:?}", lex.slice()), | |
| } | |
| } | |
| println!(); | |
| let mut p = Parser::new(input)?; | |
| let (response, terms, has_intercept, family_opt) = p.parse_formula()?; | |
| let mut mb = MetaBuilder::new(); | |
| mb.push_response(&response); | |
| for t in terms { | |
| match t { | |
| Term::Column(name) => mb.push_plain_term(&name), | |
| Term::Function { name, args } => mb.push_function_term(&name, &args), | |
| } | |
| } | |
| let meta = mb.build(input, has_intercept); | |
| println!("FAMILY (parsed, not stored): {:?}", family_opt); | |
| println!("FORMULA METADATA:"); | |
| println!("{}", serde_json::to_string_pretty(&meta)?); | |
| Ok(()) | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment