Last active
October 15, 2024 15:34
-
-
Save ClarkeRemy/9ea7f691333220c362810a50795098fd to your computer and use it in GitHub Desktop.
Hermite expansion defunctionalization for a friend
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// original
// fn hermite_expansion(i: i32, j: i32, t: i32, dist: f64, a: f64, b: f64) -> f64 {
//     let p = a + b;
//     let q = a * b / p;
//     if t < 0 || t > i + j {
//         0.0
//     } else if i == j && j == t && t == 0 {
//         f64::exp(-q * dist.powi(2))
//     } else if j == 0 {
//         (2.0 * p).recip() * hermite_expansion(i - 1, j, t - 1, dist, a, b)
//             - (q * dist / a) * hermite_expansion(i - 1, j, t, dist, a, b)
//             + (t + 1) as f64 * hermite_expansion(i - 1, j, t + 1, dist, a, b)
//     } else {
//         (2.0 * p).recip() * hermite_expansion(i, j - 1, t - 1, dist, a, b)
//             + (q * dist / b) * hermite_expansion(i, j - 1, t, dist, a, b)
//             + (t + 1) as f64 * hermite_expansion(i, j - 1, t + 1, dist, a, b)
//     }
// }
/// Evaluates the Hermite expansion recurrence `E(i, j, t)` without using
/// call-stack recursion.
///
/// With `p = a + b` and `q = a * b / p`, the value is defined by:
///
/// * `0.0` when `t < 0` or `t > i + j`;
/// * `exp(-q * dist^2)` when `i == 0`, `j == 0` and `t == 0`;
/// * when `j == 0`:
///   `E(i-1, j, t-1) / (2p) - (q * dist / a) * E(i-1, j, t) + (t+1) * E(i-1, j, t+1)`;
/// * otherwise:
///   `E(i, j-1, t-1) / (2p) + (q * dist / b) * E(i, j-1, t) + (t+1) * E(i, j-1, t+1)`.
///
/// The recursion is defunctionalized: the pending work of every unfinished
/// three-term sum is stored as a `Frame` on an explicit `Vec` used as a stack,
/// and a small state machine (`step`) is driven in a loop until the stack is
/// empty.
///
/// # Parameters
///
/// * `i`, `j`, `t` - indices of the recurrence.
/// * `a`, `b` - the two values combined into `p` and `q` as shown above.
/// * `dist` - the distance term appearing in the base case and in the middle
///   coefficient of each three-term sum.
///
/// Note that the parameter order here is `(a, b, dist)`.
///
/// # Returns
///
/// The value of `E(i, j, t)`.
///
/// # Cost
///
/// Nothing is memoized: every non-base state triggers three sub-evaluations,
/// each with `i + j` reduced by one, so up to `3^(i + j)` base cases may be
/// visited.
fn hermite_expansion(i: i32, j: i32, t: i32, a: f64, b: f64, dist: f64) -> f64 {
    // Driver: keep stepping the machine until it reports a final value.
    let mut state = init(i, j, t, a, b, dist);
    loop {
        state = match state {
            Process::Running(state) => step(state),
            Process::Complete(ret) => return ret,
        }
    }

    // ... WHERE

    // The continuation: a stack of suspended three-term sums (innermost last).
    type Cont = Vec<Frame>;

    // One suspended three-term sum, tagged by how far it has progressed.
    enum Frame {
        // First term is being evaluated; the second and third are still pending.
        Defer1(Defer1),
        // First term is folded into `partial1`; the second is being evaluated.
        Defer2(Defer2),
        // First two terms are folded into `partial2`; the third is being evaluated.
        Combine(Combine),
    }

    // `c2_c3`: coefficients of the second and third terms.
    // `i_j`: the already-decremented `(i, j)` shared by all three sub-evaluations.
    // `t2_t3`: the `t` arguments of the second and third sub-evaluations.
    struct Defer1 {
        c2_c3: (f64, f64),
        i_j: (i32, i32),
        t2_t3: (i32, i32),
    }

    // `partial1`: `c1 * (first sub-result)`.
    struct Defer2 {
        partial1: f64,
        c2_c3: (f64, f64),
        i_j: (i32, i32),
        t3: i32,
    }

    // `partial2`: `partial1 + c2 * (second sub-result)`.
    struct Combine {
        partial2: f64,
        c3: f64,
    }

    // Values that depend only on `a`, `b` and `dist`, computed once in `init`.
    struct Consts {
        // `1 / (2p)`: coefficient of the `t - 1` term in both branches.
        c1: f64,
        // `-(q * dist / a)`: coefficient of the `t` term when `j == 0`.
        c2a: f64,
        // `q * dist / b`: coefficient of the `t` term when `j != 0`.
        c2b: f64,
        // `exp(-q * dist^2)`: the value of `E(0, 0, 0)`.
        non_zero_basecase: f64,
    }

    // "Return `arg` to whatever frame is on top of `cont`."
    struct Apply {
        cont: Cont,
        arg: f64,
        consts: Consts,
    }

    // "Evaluate `E(i, j, t)` and then continue with `cont`."
    struct Loop {
        i_j: (i32, i32),
        t: i32,
        consts: Consts,
        cont: Cont,
    }

    enum Process {
        Running(State),
        Complete(f64),
    }

    enum State {
        Loop(Loop),
        Apply(Apply),
    }

    // Builds the initial machine state: precompute the constants and start
    // evaluating `E(i, j, t)` with an empty continuation stack.
    fn init(i: i32, j: i32, t: i32, a: f64, b: f64, dist: f64) -> Process {
        // partial eval
        let p = a + b;
        let q = a * b / p;
        let consts = Consts {
            c1: (2.0 * p).recip(),
            c2a: -(q * dist / a),
            c2b: q * dist / b,
            non_zero_basecase: f64::exp(-q * dist.powi(2)),
        };
        Process::Running(State::Loop(Loop {
            i_j: (i, j),
            t,
            consts,
            cont: Vec::new(),
        }))
    }

    // Performs one transition of the machine.
    fn step(state: State) -> Process {
        let next = match state {
            State::Loop(Loop { i_j: (i, j), t, consts, mut cont }) => {
                if t < 0 || t > i + j {
                    // Out-of-range `t` contributes nothing.
                    State::Apply(Apply { cont, arg: 0.0, consts })
                } else if i == 0 && j == 0 && t == 0 {
                    State::Apply(Apply { cont, arg: consts.non_zero_basecase, consts })
                } else {
                    // Select the branch of the recurrence: lower `i` when
                    // `j == 0`, otherwise lower `j`.
                    let (c2, i_j) = if j == 0 {
                        (consts.c2a, (i - 1, j))
                    } else {
                        (consts.c2b, (i, j - 1))
                    };
                    // Suspend the sum and start on its first (`t - 1`) term.
                    cont.push(Frame::Defer1(Defer1 {
                        c2_c3: (c2, (t + 1) as f64),
                        i_j,
                        t2_t3: (t, t + 1),
                    }));
                    State::Loop(Loop { i_j, t: t - 1, consts, cont })
                }
            }
            State::Apply(Apply { mut cont, arg, consts }) => match cont.pop() {
                // Nothing is waiting for `arg`, so it is the final answer.
                None => return Process::Complete(arg),
                Some(Frame::Defer1(Defer1 { c2_c3, i_j, t2_t3: (t2, t3) })) => {
                    cont.push(Frame::Defer2(Defer2 {
                        partial1: consts.c1 * arg,
                        c2_c3,
                        i_j,
                        t3,
                    }));
                    State::Loop(Loop { i_j, t: t2, consts, cont })
                }
                Some(Frame::Defer2(Defer2 { partial1, c2_c3: (c2, c3), i_j, t3 })) => {
                    cont.push(Frame::Combine(Combine {
                        partial2: partial1 + c2 * arg,
                        c3,
                    }));
                    State::Loop(Loop { i_j, t: t3, consts, cont })
                }
                Some(Frame::Combine(Combine { partial2, c3 })) => {
                    // All three terms are known: hand the sum to the next frame.
                    State::Apply(Apply { cont, arg: partial2 + c3 * arg, consts })
                }
            },
        };
        Process::Running(next)
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment