non-deterministic programming in Rust
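//! Non-deterministic programming with cloneable coroutines: a computation
//! yields at every choice point, and a driver clones the suspended coroutine
//! once per alternative to explore all branches. Three examples follow:
//! root-to-leaf tree sums, N-queens, and a forking recursive-descent parser
//! for the grammar E -> T + E | T, T -> F * T | F, F -> (E) | int.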
#![feature(coroutines, coroutine_trait, coroutine_clone)]
#![feature(type_alias_impl_trait)]
use std::any::Any;
use std::collections::VecDeque;
use std::ops::Range;
use std::{ops::Coroutine, sync::Arc};
use std::{ops::CoroutineState, pin::Pin};
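// A value-carrying tree; children live behind an `Arc` slice so that subtrees
// are cheap to clone when the search forks.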
#[derive(Debug, Clone)]
enum Tree<T> {
    Node(T, Arc<[Tree<T>]>),
    Leaf(T),
}
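// "Non-deterministically pick any child": yield the children and let the
// driver resume (a clone of) the coroutine once per child.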
macro_rules! any_child {
    ($children:ident) => {
        yield $children
    };
}
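// Reference implementation: an explicit BFS that computes the sum of values
// along every root-to-leaf path, without coroutines.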
fn compute_sum_bfs(root: Tree<u64>) -> Vec<u64> {
    let mut states = VecDeque::new();
    let mut results = Vec::new();
    states.push_back((0, root));
    while let Some((sum, root)) = states.pop_front() {
        match root {
            Tree::Node(val, children) => {
                for child in children.iter() {
                    states.push_back((sum + val, child.clone()));
                }
            }
            Tree::Leaf(val) => {
                results.push(sum + val);
            }
        }
    }
    results
}
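// Warm-up example (not called from `main`): the value passed to `resume`
// becomes the result of the corresponding `yield` inside the coroutine.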
#[allow(dead_code)]
fn number_generator(n: usize) -> impl Coroutine<usize, Yield = usize, Return = usize> {
    #[coroutine]
    move |x: usize| {
        println!("x={}", x);
        let mut sum = 0;
        for i in 0..n {
            let x = yield i;
            println!("x={}", x);
            sum += i;
        }
        sum
    }
}
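// The coroutine version of the tree sum: it walks a single root-to-leaf path,
// yielding the children at every node. Because it is `Clone`, the driver can
// fork one copy per child at each suspension point.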
fn compute_sum_task() -> impl Coroutine<Tree<u64>, Yield = Arc<[Tree<u64>]>, Return = u64> + Clone {
    #[coroutine]
    |mut next: Tree<u64>| {
        let mut sum = 0;
        loop {
            match next {
                Tree::Node(val, children) => {
                    sum += val;
                    next = any_child!(children);
                }
                Tree::Leaf(val) => {
                    sum += val;
                    break sum;
                }
            }
        }
    }
}
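// Drives `compute_sum_task` with a work queue: each entry pairs the next input
// with a suspended coroutine. Every yielded child forks a clone, and every
// completed clone contributes one root-to-leaf sum.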
fn compute_sum_driver(root: Tree<u64>) -> Vec<u64> {
    let mut states = VecDeque::new();
    states.push_back((root, compute_sum_task()));
    let mut results = Vec::<u64>::new();
    // pop_front: breadth-first search; pop_back: depth-first search
    while let Some((root, mut state)) = states.pop_front() {
        let state_ = Pin::new(&mut state);
        match state_.resume(root) {
            CoroutineState::Yielded(children) => {
                for child in children.iter() {
                    states.push_back((child.clone(), state.clone()));
                }
            }
            CoroutineState::Complete(result) => {
                results.push(result);
            }
        }
    }
    results
}
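// "Non-deterministically pick the next column": yield the candidate range and
// receive the concrete column chosen by the driver.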
macro_rules! next_position {
    ($range:expr) => {
        yield $range
    };
}
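// One branch of the N-queens search: the column of the row-0 queen arrives as
// the first resume argument, each later row asks the driver for a column, and
// the branch returns `None` as soon as a conflict appears or `Some(columns)`
// once every row is placed.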
fn eight_queens_task(
    n: usize,
) -> impl Coroutine<usize, Yield = Range<usize>, Return = Option<Vec<usize>>> + Clone {
    #[coroutine]
    move |first_row_placement: usize| {
        let mut queens = vec![first_row_placement];
        for row in 1..n {
            let col = next_position!(0..n);
            if (0..row).all(|i| {
                queens[i] != col
                    && queens[i] as isize != col as isize - (row - i) as isize
                    && queens[i] as isize != col as isize + (row - i) as isize
            }) {
                queens.push(col);
            } else {
                return None;
            }
        }
        Some(queens)
    }
}
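// Seeds the queue with one task per starting column, then forks a clone for
// every candidate column a task yields; tasks that complete with `Some`
// contribute a solution.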
fn eight_queens_driver(n: usize) -> Vec<Vec<usize>> {
    let mut states = VecDeque::new();
    for i in 0..n {
        states.push_back((i, eight_queens_task(n)));
    }
    let mut results = Vec::<Vec<usize>>::new();
    // pop_front: breadth-first search; pop_back: depth-first search
    while let Some((input, mut state)) = states.pop_front() {
        let state_ = Pin::new(&mut state);
        match state_.resume(input) {
            CoroutineState::Yielded(children) => {
                for child in children {
                    states.push_back((child, state.clone()));
                }
            }
            CoroutineState::Complete(result) => {
                if let Some(result) = result {
                    results.push(result);
                }
            }
        }
    }
    results
}
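// Tokens of the arithmetic expression grammar handled by the parser below.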
#[derive(Debug, Clone)]
enum Token {
    Int(usize),
    Add,
    Mul,
    LParen,
    RParen,
}
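// What a parser coroutine receives when resumed: the initial kick-off, which
// side of a fork to take, the token that was just consumed, or the result of
// a sub-parse.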
#[derive(Debug)]
enum ActionResult {
    Initial,
    ForkLeft,
    ForkRight,
    ConsumeToken(Token),
    Result(usize),
}
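// What a parser coroutine yields to the driver: recurse into a nonterminal,
// consume a token of a given kind (matched by discriminant), or fork into two
// alternatives.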
enum Action {
    ParseExpr,
    ParseTerm,
    ParseFactor,
    ConsumeToken(std::mem::Discriminant<Token>),
    Fork,
}
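// Each macro below wraps one yield/resume round-trip so the grammar rules read
// like ordinary recursive-descent code.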
macro_rules! parse_expr {
    () => {{
        let ActionResult::Result(res) = yield Action::ParseExpr else {
            unreachable!()
        };
        res
    }};
}
macro_rules! parse_term {
    () => {{
        let ActionResult::Result(res) = yield Action::ParseTerm else {
            unreachable!()
        };
        res
    }};
}
macro_rules! parse_factor {
    () => {{
        let ActionResult::Result(res) = yield Action::ParseFactor else {
            unreachable!()
        };
        res
    }};
}
macro_rules! consume_token {
    ($token:expr) => {{
        let ActionResult::ConsumeToken(token) =
            yield Action::ConsumeToken(std::mem::discriminant(&$token))
        else {
            unreachable!()
        };
        token
    }};
}
macro_rules! fork {
    () => {
        match (yield Action::Fork) {
            ActionResult::ForkLeft => false,
            ActionResult::ForkRight => true,
            _ => unreachable!(),
        }
    };
}
// F -> (E) | int
fn parse_factor_task() -> ParseFactorTask {
    #[coroutine]
    move |_| {
        if fork!() {
            consume_token!(Token::LParen);
            let res = parse_expr!();
            consume_token!(Token::RParen);
            res
        } else {
            let Token::Int(res) = consume_token!(Token::Int(0)) else {
                unreachable!()
            };
            res
        }
    }
}
// T -> F * T | F
fn parse_term_task() -> ParseTermTask {
    #[coroutine]
    move |_| {
        if fork!() {
            let left = parse_factor!();
            consume_token!(Token::Mul);
            let right = parse_term!();
            left * right
        } else {
            let res = parse_factor!();
            res
        }
    }
}
// E -> T + E | T
fn parse_expr_task() -> ParseExprTask {
    #[coroutine]
    move |_| {
        if fork!() {
            let left = parse_term!();
            consume_token!(Token::Add);
            let right = parse_expr!();
            left + right
        } else {
            let res = parse_term!();
            res
        }
    }
}
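// Nameable opaque types for the three parser coroutines, defined through
// `type_alias_impl_trait`.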
type ParseExprTask = impl Coroutine<ActionResult, Yield = Action, Return = usize> + Clone;
type ParseFactorTask = impl Coroutine<ActionResult, Yield = Action, Return = usize> + Clone;
type ParseTermTask = impl Coroutine<ActionResult, Yield = Action, Return = usize> + Clone;
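// Object-safe wrapper so differently-typed parser coroutines can share one
// `dyn` stack; `dyn_clone` lets a whole stack of suspended parsers be
// duplicated at a fork.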
trait ParseTask: Coroutine<ActionResult, Yield = Action, Return = usize> + Any {
    fn dyn_clone(&self) -> Pin<Box<dyn ParseTask>>;
}
impl<T: Coroutine<ActionResult, Yield = Action, Return = usize> + Any + Clone> ParseTask for T {
    fn dyn_clone(&self) -> Pin<Box<dyn ParseTask>> {
        Box::pin(self.clone())
    }
}
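// Deep-clones every suspended parser on the stack so a forked branch can be
// explored independently.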
fn clone_stack(stack: &[Pin<Box<dyn ParseTask>>]) -> Vec<Pin<Box<dyn ParseTask>>> {
    stack.iter().map(|s| s.dyn_clone()).collect()
}
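// Breadth-first driver for the parser. Each queued state is (value to resume
// with, remaining tokens, stack of suspended coroutines): `Fork` duplicates
// the whole state, `ConsumeToken` advances the input when the token kind
// matches (and silently drops the branch otherwise), and the nonterminal
// actions push a fresh sub-parser onto the stack. A branch succeeds when its
// stack empties with no tokens left over.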
fn parse_driver(tokens: Vec<Token>) -> Result<usize, &'static str> {
    let mut states: VecDeque<(ActionResult, Vec<Token>, Vec<Pin<Box<dyn ParseTask>>>)> =
        VecDeque::new();
    states.push_back((
        ActionResult::Initial,
        tokens,
        vec![Box::pin(parse_expr_task()) as Pin<Box<dyn ParseTask>>],
    ));
    while let Some((input, tokens, mut stack)) = states.pop_front() {
        let state_ = Pin::new(&mut *stack.last_mut().unwrap());
        match state_.resume(input) {
            CoroutineState::Yielded(action) => match action {
                Action::Fork => {
                    states.push_back((ActionResult::ForkLeft, tokens.clone(), clone_stack(&stack)));
                    states.push_back((ActionResult::ForkRight, tokens, stack));
                }
                Action::ConsumeToken(token) => {
                    let Some((a, rest)) = tokens.split_first() else {
                        continue;
                    };
                    if token == std::mem::discriminant(a) {
                        states.push_back((
                            ActionResult::ConsumeToken(a.clone()),
                            rest.to_vec(),
                            stack,
                        ));
                    }
                }
                Action::ParseExpr => {
                    let mut stack = stack;
                    stack.push(Box::pin(parse_expr_task()));
                    states.push_back((ActionResult::Initial, tokens, stack));
                }
                Action::ParseTerm => {
                    let mut stack = stack;
                    stack.push(Box::pin(parse_term_task()));
                    states.push_back((ActionResult::Initial, tokens, stack));
                }
                Action::ParseFactor => {
                    let mut stack = stack;
                    stack.push(Box::pin(parse_factor_task()));
                    states.push_back((ActionResult::Initial, tokens, stack));
                }
            },
            CoroutineState::Complete(result) => {
                // Pass the result to the previous state in the stack
                let mut stack = stack;
                stack.pop();
                if stack.is_empty() {
                    if tokens.is_empty() {
                        return Ok(result);
                    } else {
                        continue;
                    }
                }
                states.push_back((ActionResult::Result(result), tokens, stack));
            }
        }
    }
    Err("cannot parse")
}
fn main() {
    let root = Tree::Node(
        0,
        vec![
            Tree::Node(1, vec![Tree::Leaf(2), Tree::Leaf(5)].into()),
            Tree::Node(3, vec![Tree::Leaf(4)].into()),
        ]
        .into(),
    );
    println!("bfs = {:?}", compute_sum_bfs(root.clone()));
    println!("non_deterministic = {:?}", compute_sum_driver(root));
    let result = eight_queens_driver(8);
    println!("num_solutions={}", result.len());
    for r in result {
        println!("{:?}", r);
    }
    // 1 + 2 + 3 * (4 + 5) = 30
    let result = parse_driver(vec![
        Token::Int(1),
        Token::Add,
        Token::Int(2),
        Token::Add,
        Token::Int(3),
        Token::Mul,
        Token::LParen,
        Token::Int(4),
        Token::Add,
        Token::Int(5),
        Token::RParen,
    ])
    .unwrap();
    println!("result={}", result);
    // "1 +" has no valid parse, so print the error instead of unwrapping it.
    let result = parse_driver(vec![Token::Int(1), Token::Add]);
    println!("result={:?}", result);
}