Created
October 26, 2024 04:40
-
-
Save cbeck88/3553de29d5a530fb5bdce95ef7fc0661 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use bigdecimal::{BigDecimal as Decimal, One, Zero}; | |
/// Bigdecimal crate supplies exp but not log unfortunately. | |
/// | |
/// Context for computation of ln | |
pub struct LnContext { | |
x: Decimal, | |
e_x: Decimal, | |
e_x_inv: Decimal, | |
} | |
impl LnContext { | |
/// Create a new context for computing ln of a Decimal | |
pub fn new() -> Self { | |
let x: Decimal = Decimal::one() / 5; | |
Self { | |
e_x: (x.clone()).exp(), | |
e_x_inv: (-x.clone()).exp(), | |
x, | |
} | |
} | |
/// Compute ln of a Decimal | |
pub fn ln(&self, mut arg: Decimal) -> Decimal { | |
if arg <= Decimal::zero() { | |
panic!("ln argument out of bounds"); | |
} | |
// We start by dividing out powers of e^x and adding to "adjustment", which is added to our result at the end. | |
// Intuitively, we could structure this as a recursive call instead, but this way should be more performant. | |
// | |
// The reason to do this step at all is that the maclauran series doesn't converge well except very close to 1, | |
// so we divide / multiply by e^x as long as we are not in the range [e^{-x}, e^{x}], in order to move closer | |
// to that range. Once we are in the range we can safely use maclauran series without killing perf / accuracy. | |
let mut adjustment = Decimal::zero(); | |
while arg > self.e_x { | |
arg *= self.e_x_inv.clone(); | |
adjustment += self.x.clone(); | |
} | |
while arg < self.e_x_inv { | |
arg *= self.e_x.clone(); | |
adjustment -= self.x.clone(); | |
} | |
// We now know arg is in the range [e^{-x}, e^{x}], so a maclauran series may be reasonable | |
// Drop some low order terms to make it faster | |
let x = (arg - Decimal::one()).with_scale(6); | |
let mut result = Decimal::zero(); | |
let mut x_to_two_n_plus_one = x.clone(); | |
for n in 0..10 { | |
result += x_to_two_n_plus_one.clone() / (2 * n + 1); | |
x_to_two_n_plus_one *= x.clone(); | |
result -= x_to_two_n_plus_one.clone() / (2 * n + 2); | |
x_to_two_n_plus_one *= x.clone(); | |
// Drop low order terms to make it faster | |
x_to_two_n_plus_one = x_to_two_n_plus_one.with_scale(6); | |
} | |
result + adjustment | |
} | |
} | |
#[cfg(test)]
mod tests {
    use super::*;

    /// Panics unless `exp(ln(value))` is within a relative error of `1e-4` of `value`.
    fn assert_round_trip(ln_context: &LnContext, value: Decimal) {
        let epsilon = Decimal::new(1.into(), 4);
        let range = Decimal::one() - epsilon.clone()..Decimal::one() + epsilon;
        let ratio = ln_context.ln(value.clone()).exp() / &value;
        if !range.contains(&ratio) {
            let arg = value;
            panic!("Accuracy loss on exp(ln({arg})): ratio = {ratio}");
        }
    }

    /// Checks the `exp(ln(_))` round trip for both `arg` and its reciprocal.
    fn assert_ln_accuracy(ln_context: &LnContext, arg: Decimal) {
        let reciprocal = arg.inverse();
        assert_round_trip(ln_context, arg);
        assert_round_trip(ln_context, reciprocal);
    }

    #[test]
    fn test_ln_accuracy() {
        let c = LnContext::new();
        for n in 1..=4 {
            assert_ln_accuracy(&c, Decimal::new(n.into(), 0));
        }
    }

    #[test]
    fn test_ln_accuracy_large() {
        let c = LnContext::new();
        // 10, 100 and 10_000 (and, via the reciprocal check, their inverses).
        for exponent in [1, 2, 4] {
            assert_ln_accuracy(&c, Decimal::new(1.into(), -exponent));
        }
    }

    #[test]
    fn test_ln_of_one_is_zero() {
        let c = LnContext::new();
        assert_eq!(c.ln(Decimal::one()), Decimal::zero());
    }

    #[test]
    fn test_ln_of_e_is_one() {
        let c = LnContext::new();
        let error = (c.ln(Decimal::one().exp()) - Decimal::one()).abs();
        assert!(error < Decimal::new(1.into(), 5), "ln(e) off by {error}");
    }

    #[test]
    #[should_panic(expected = "ln argument out of bounds")]
    fn test_ln_of_zero_panics() {
        let _ = LnContext::new().ln(Decimal::zero());
    }

    #[test]
    #[should_panic(expected = "ln argument out of bounds")]
    fn test_ln_of_negative_panics() {
        let _ = LnContext::new().ln(Decimal::new((-1).into(), 0));
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment