Skip to content

Instantly share code, notes, and snippets.

@ClarkeRemy
Created October 16, 2024 06:42
Show Gist options
  • Save ClarkeRemy/ea92523aa0c769d239419d20d4ec3b21 to your computer and use it in GitHub Desktop.
Save ClarkeRemy/ea92523aa0c769d239419d20d4ec3b21 to your computer and use it in GitHub Desktop.
Yet another defunctionalization exercise, based on a friend's code.
// Stand-in for the real `boys` function, which is not available in this snippet; it
// exists only so the code below type-checks.
mod boys {
    pub mod micb25 {
        /// Placeholder: ignores both arguments and always returns `1.0`.
        ///
        /// NOTE(review): the real implementation presumably evaluates the Boys function
        /// F_n(x) (order `n`, argument `x`) — confirm against the actual crate.
        pub fn boys(_order: u64, _x: f64) -> f64 {
            1.0
        }
    }
}
/// A plain three-component vector; methods are only provided for `Vector3<f64>`.
#[derive(Clone, Copy)]
pub struct Vector3<T> {
    x: T,
    y: T,
    z: T,
}

impl Vector3<f64> {
    /// Squared Euclidean length `x*x + y*y + z*z` (no square root is taken).
    fn norm_squared(&self) -> f64 {
        let Self { x, y, z } = *self;
        x * x + y * y + z * z
    }
}
// direct style
/// Reference (direct-style) recursion for `R_{t,u,v}^n`.
///
/// NOTE(review): presumably the Hermite Coulomb auxiliary integral of the
/// McMurchie–Davidson scheme; what follows is exactly what the code computes.
///
/// * All indices zero: `(-2p)^n * boys(n, p * |diff|^2)`.
/// * Otherwise one index `k` is lowered: `v` when `t == u == 0`, else `u` when `t == 0`,
///   else `t`. The result is
///   `c * R(k - 1, n + 1) + (k - 1) * R(k - 2, n + 1)`, where `c` is `diff.z` for `v`,
///   `diff.y` for `u`, and `diff.x` for `t`. The second term is dropped when `k <= 1`.
///
/// Indices are expected to be non-negative: a negative index never reaches the base
/// case, so the recursion would not terminate. `n` is cast to `u64` for `boys`.
pub /* (super) */ fn coulomb_auxiliary(t: i32, u: i32, v: i32, n: i32, p: f64, diff: Vector3<f64>) -> f64 {
    match (t, u, v) {
        // Base case: every index is zero.
        (0, 0, 0) => (-2.0 * p).powi(n) * boys::micb25::boys(n as u64, p * diff.norm_squared()),
        // Only `v` remains: lower it, paired with the z component.
        (0, 0, _) => {
            let head = diff.z * coulomb_auxiliary(t, u, v - 1, n + 1, p, diff);
            let tail = if v > 1 {
                (v - 1) as f64 * coulomb_auxiliary(t, u, v - 2, n + 1, p, diff)
            } else {
                0.0
            };
            head + tail
        }
        // `t` is zero but `u` is not: lower `u`, paired with the y component.
        (0, _, _) => {
            let head = diff.y * coulomb_auxiliary(t, u - 1, v, n + 1, p, diff);
            let tail = if u > 1 {
                (u - 1) as f64 * coulomb_auxiliary(t, u - 2, v, n + 1, p, diff)
            } else {
                0.0
            };
            head + tail
        }
        // `t` is non-zero: lower `t`, paired with the x component.
        _ => {
            let head = diff.x * coulomb_auxiliary(t - 1, u, v, n + 1, p, diff);
            let tail = if t > 1 {
                (t - 1) as f64 * coulomb_auxiliary(t - 2, u, v, n + 1, p, diff)
            } else {
                0.0
            };
            head + tail
        }
    }
}
// // direct style made tabular
// pub /* (super) */ fn coulomb_auxiliary_(t: i32, u: i32, v: i32, n: i32, p: f64, diff: Vector3<f64>) -> f64 {
// if t == u && u == v && v == 0 { (-2.0 * p).powi(n) * boys::micb25::boys(n as u64, p * diff.norm_squared())
// }
// else if t == u && u == 0 { diff.z * coulomb_auxiliary_(t , u , v - 1, n + 1, p, diff) + if v > 1 { (v - 1) as f64 * coulomb_auxiliary_(t , u , v - 2, n + 1, p, diff) } else { 0.0 }
// } else if t == 0 { diff.y * coulomb_auxiliary_(t , u - 1, v , n + 1, p, diff) + if u > 1 { (u - 1) as f64 * coulomb_auxiliary_(t , u - 2, v , n + 1, p, diff) } else { 0.0}
// } else { diff.x * coulomb_auxiliary_(t - 1, u , v , n + 1, p, diff) + if t > 1 { (t - 1) as f64 * coulomb_auxiliary_(t - 2, u , v , n + 1, p, diff) } else { 0.0 }
// }
// }
// direct style with minimal recursive branches
/// Same recursion as `coulomb_auxiliary`, restructured to have only two recursive call
/// sites.
///
/// Every non-base case is expressed as `c * R(first, n + 1) + k * R(second, n + 1)`,
/// where the second term is optional. The index to lower and its `diff` component are
/// chosen exactly as in `coulomb_auxiliary`:
/// * `v` pairs with z;
/// * `u` pairs with y;
/// * `t` pairs with x.
///
/// Indices are expected to be non-negative: a negative index never reaches the base
/// case, so the recursion would not terminate.
pub fn coulomb_auxiliary_recursive_branch_reduction(t: i32, u: i32, v: i32, n: i32, p: f64, diff: Vector3<f64>) -> f64 {
    if t == 0 && u == 0 && v == 0 {
        return (-2.0 * p).powi(n) * boys::micb25::boys(n as u64, p * diff.norm_squared());
    }
    let n_ = n + 1;
    // Indices of the first sub-term, its `diff` factor, and, when the lowered index is
    // > 1, the coefficient `k - 1` together with the indices of the second sub-term.
    // NB: z must go with v and x with t; swapping them silently gives wrong values.
    let ([t_1, u_1, v_1], diff_, cond) = if t == 0 && u == 0 {
        ([t, u, v - 1], diff.z, if v > 1 { Some((v - 1, [t, u, v - 2])) } else { None })
    } else if t == 0 {
        ([t, u - 1, v], diff.y, if u > 1 { Some((u - 1, [t, u - 2, v])) } else { None })
    } else {
        ([t - 1, u, v], diff.x, if t > 1 { Some((t - 1, [t - 2, u, v])) } else { None })
    };
    let first = diff_ * coulomb_auxiliary_recursive_branch_reduction(t_1, u_1, v_1, n_, p, diff);
    let second = match cond {
        Some((d, [t_2, u_2, v_2])) => {
            d as f64 * coulomb_auxiliary_recursive_branch_reduction(t_2, u_2, v_2, n_, p, diff)
        }
        None => 0.0,
    };
    first + second
}
// CPS
/// Continuation-passing version of `coulomb_auxiliary`.
///
/// `loop_` never combines results on the way back up. After picking the first sub-term,
/// it recurses with a boxed continuation. When that sub-term's value arrives, the
/// continuation either finishes (`cont(partial)`) or evaluates the optional second
/// sub-term and then finishes.
///
/// Rust does not guarantee tail calls, so this still grows the native call stack; it is
/// the intermediate step towards `coulomb_auxiliary_defunctionalized`.
pub /* (super) */ fn coulomb_auxiliary_cps(t: i32, u: i32, v: i32, n: i32, p: f64, diff: Vector3<f64>) -> f64 {
    /// Evaluate R(t, u, v; n) and pass the result to `cont`.
    pub fn loop_(t: i32, u: i32, v: i32, n: i32, p: f64, diff: Vector3<f64>, cont: Box<dyn FnOnce(f64) -> f64>) -> f64 {
        if t == 0 && u == 0 && v == 0 {
            return cont((-2.0 * p).powi(n) * boys::micb25::boys(n as u64, p * diff.norm_squared()));
        }
        let n_ = n + 1;
        // z pairs with v, y with u, x with t (as in `coulomb_auxiliary`).
        let ([t_1, u_1, v_1], diff_, cond) = if t == 0 && u == 0 {
            ([t, u, v - 1], diff.z, if v > 1 { Some((v - 1, [t, u, v - 2])) } else { None })
        } else if t == 0 {
            ([t, u - 1, v], diff.y, if u > 1 { Some((u - 1, [t, u - 2, v])) } else { None })
        } else {
            ([t - 1, u, v], diff.x, if t > 1 { Some((t - 1, [t - 2, u, v])) } else { None })
        };
        loop_(t_1, u_1, v_1, n_, p, diff, Box::new(move |ret1| {
            let partial = diff_ * ret1;
            match cond {
                // The second sub-term must use its own lowered indices; re-entering with
                // the current (t, u, v) would never reach the base case.
                Some((d_field, [t_2, u_2, v_2])) => loop_(t_2, u_2, v_2, n_, p, diff, Box::new(move |ret2| {
                    cont(partial + d_field as f64 * ret2)
                })),
                None => cont(partial),
            }
        }))
    }
    loop_(t, u, v, n, p, diff, Box::new(|x| x))
}
// defunctionalized
pub fn coulomb_auxiliary_defunctionalized<const STACK_SIZE : usize>(t: i32, u: i32, v: i32, n: i32, p: f64, diff: Vector3<f64>, ) -> f64 {
// initialize the state
let mut state = Loop{ t, u, v, n, consts: Consts { p, diff }, cont: if STACK_SIZE > 0 {Vec::with_capacity(STACK_SIZE)} else {Vec::new()} }.do_();
loop {
state = match state {
State::Loop(r#loop) => r#loop.do_(),
State::Apply(apply) => apply.do_(),
State::Complete(out) => return out,
}
}
// WHERE ...
struct Consts{p: f64, diff: Vector3<f64>}
struct Defer1{ diff_ : f64, n_ :i32, cond : Option<(i32, [i32; 3])>};
struct Defer2{ partial : f64, d_field : i32};
enum Frame { Defer1(Defer1), Defer2(Defer2), }
type Cont = Vec<Frame>;
enum State { Loop(Loop), Apply(Apply), Complete(f64)}
struct Loop{t: i32, u: i32, v: i32, n: i32, consts : Consts, cont : Cont};
impl Loop {
#[inline(always)]
fn do_(self)->State {
let Loop { t, u, v, n, consts : consts @ Consts { p, diff }, mut cont } = self;
if t == u && u == v && v == 0 {
let arg = ((-2.0 * p).powi(n) * boys::micb25::boys(n as u64, p * diff.norm_squared()));
return State::Apply(Apply { cont, arg , consts })
}
let n_ = n+1;
let ([t_1,u_1,v_1], diff_, cond) =
if t == u && u == 0 {let v_ = v-1 ; ([t , u , v - 1], diff.x, if v > 1 { Some((v_, [t , u , v - 2])) } else {None})
} else if t == 0 {let u_ = u-1 ; ([t , u - 1, v ], diff.y, if u > 1 { Some((u_, [t , u - 2, v ])) } else {None})
} else {let t_ = t-1 ; ([t - 1, u , v ], diff.z, if t > 1 { Some((t_, [t - 2, u , v ])) } else {None})
};
cont.push(Frame::Defer1(Defer1 { diff_, n_, cond }));
State::Loop(Loop { t: t_1, u: u_1, v: v_1, n: n_, consts, cont })
}
}
// apply continuation `cont` on argument `arg``
struct Apply{cont : Cont, arg : f64, consts : Consts};
impl Apply {
#[inline(always)]
fn do_(self)->State {
let Apply { mut cont, arg, consts } = self;
let Some(frame) = cont.pop() else { return State::Complete(arg); };
match frame {
Frame::Defer1(Defer1 { diff_, n_, cond }) => {
let partial = diff_ * arg;
if let Some((d_field, [t_2,u_2,v_2])) = cond {
cont.push(Frame::Defer2(Defer2 { partial, d_field }));
State::Loop(Loop { t: t_2, u: u_2, v: v_2, n: n_, consts, cont })
} else {
State::Apply(Apply{cont, arg : partial, consts})
}
},
Frame::Defer2(Defer2 { partial, d_field }) => State::Apply(Apply { cont, arg: partial + d_field as f64 * arg, consts }),
}
}
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment