Skip to content

Instantly share code, notes, and snippets.

@cbeck88
Created October 26, 2024 04:40
Show Gist options
  • Save cbeck88/3553de29d5a530fb5bdce95ef7fc0661 to your computer and use it in GitHub Desktop.
Save cbeck88/3553de29d5a530fb5bdce95ef7fc0661 to your computer and use it in GitHub Desktop.
use bigdecimal::{BigDecimal as Decimal, One, Zero};
/// Precomputed constants for computing natural logarithms of [`Decimal`] values.
///
/// The bigdecimal crate supplies `exp` but not `ln`, so this type fills the gap.
/// Building one with [`LnContext::new`] evaluates `exp` twice, so create it once
/// and reuse it for many `ln` calls.
#[derive(Debug, Clone)]
pub struct LnContext {
    /// Range-reduction step size (set to 1/5 by `new`).
    x: Decimal,
    /// e^x, used to scale small arguments up toward 1.
    e_x: Decimal,
    /// e^{-x}, kept so large arguments can be scaled down by multiplication
    /// instead of division.
    e_x_inv: Decimal,
}
impl LnContext {
    /// Significant digits kept for `arg` while it is scaled into `[e^{-x}, e^{x}]`.
    ///
    /// Exact multiplication by the ~100-digit `e^{±x}` constants would otherwise
    /// append roughly another 100 digits to `arg` on every step, making range
    /// reduction quadratically slow for very large or very small arguments.
    /// 50 digits is far beyond the accuracy the series step delivers.
    const REDUCTION_PRECISION: u64 = 50;

    /// Decimal places kept for the series variable and each of its powers.
    ///
    /// Truncating here keeps the big-number intermediates small; it is what
    /// limits the overall accuracy of [`LnContext::ln`] to about 1e-6.
    const SERIES_SCALE: i64 = 6;

    /// Create a new context for computing ln of a Decimal.
    ///
    /// Precomputes `e^{1/5}` and `e^{-1/5}` so individual `ln` calls never
    /// need to call `exp` themselves.
    pub fn new() -> Self {
        let x: Decimal = Decimal::one() / 5;
        Self {
            // `exp` borrows, so no clone is needed for the positive exponent.
            e_x: x.exp(),
            e_x_inv: (-x.clone()).exp(),
            x,
        }
    }

    /// Compute the natural logarithm of `arg`.
    ///
    /// The result has an absolute error on the order of 1e-6, because the
    /// series variable and its powers are truncated to a few decimal places
    /// for speed (see `SERIES_SCALE`).
    ///
    /// # Panics
    ///
    /// Panics if `arg` is zero or negative, where ln is undefined.
    pub fn ln(&self, mut arg: Decimal) -> Decimal {
        if arg <= Decimal::zero() {
            panic!("ln argument out of bounds");
        }
        // Range reduction: ln(arg) = ln(arg * e^{-k*x}) + k*x. Divide (or multiply)
        // by e^x until arg lies in [e^{-x}, e^{x}], accumulating k*x in
        // `adjustment`, which is added to the series result at the end. The
        // Maclaurin series for ln(1 + y) only converges quickly for small |y|,
        // so this step is what keeps the series both fast and accurate. A loop
        // is used instead of a recursive call for performance.
        let mut adjustment = Decimal::zero();
        while arg > self.e_x {
            arg = (&arg * &self.e_x_inv).with_prec(Self::REDUCTION_PRECISION);
            adjustment += self.x.clone();
        }
        while arg < self.e_x_inv {
            arg = (&arg * &self.e_x).with_prec(Self::REDUCTION_PRECISION);
            adjustment -= self.x.clone();
        }

        // arg is now in [e^{-x}, e^{x}], so y = arg - 1 lies in roughly
        // [-0.18, 0.23] and ln(1 + y) = y - y^2/2 + y^3/3 - ... converges quickly.
        let y = (arg - Decimal::one()).with_scale(Self::SERIES_SCALE);
        let mut result = Decimal::zero();
        // Holds y^(2n+1) at the top of each iteration.
        let mut power = y.clone();
        // Each iteration adds one odd (+) and one even (-) term: 20 terms total.
        for n in 0..10 {
            result += power.clone() / (2 * n + 1);
            power *= y.clone();
            result -= power.clone() / (2 * n + 2);
            power *= y.clone();
            // Truncate the running power so its digit count does not keep growing.
            power = power.with_scale(Self::SERIES_SCALE);
        }
        result + adjustment
    }
}

impl Default for LnContext {
    /// Equivalent to [`LnContext::new`].
    fn default() -> Self {
        Self::new()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Assert that exp(ln(arg)) recovers `arg` to within a relative error of 1e-4.
    fn assert_round_trip(ln_context: &LnContext, arg: &Decimal) {
        let epsilon = Decimal::new(1.into(), 4);
        let range = Decimal::one() - epsilon.clone()..Decimal::one() + epsilon;
        let round_trip = ln_context.ln(arg.clone()).exp();
        let ratio = round_trip / arg.clone();
        if !range.contains(&ratio) {
            panic!("Accuracy loss on exp(ln({arg})): ratio = {ratio}");
        }
    }

    /// Check the round trip for both `arg` and `1 / arg`, so each case also
    /// exercises the branch of `ln` that scales small arguments up.
    fn assert_ln_accuracy(ln_context: &LnContext, arg: Decimal) {
        assert_round_trip(ln_context, &arg);
        assert_round_trip(ln_context, &arg.inverse());
    }

    #[test]
    fn test_ln_accuracy() {
        let c = LnContext::new();
        assert_ln_accuracy(&c, Decimal::new(1.into(), 0));
        assert_ln_accuracy(&c, Decimal::new(2.into(), 0));
        assert_ln_accuracy(&c, Decimal::new(3.into(), 0));
        assert_ln_accuracy(&c, Decimal::new(4.into(), 0));
    }

    #[test]
    fn test_ln_accuracy_large() {
        let c = LnContext::new();
        // A negative scale multiplies by a power of ten: these are 10, 100 and 10^4.
        assert_ln_accuracy(&c, Decimal::new(1.into(), -1));
        assert_ln_accuracy(&c, Decimal::new(1.into(), -2));
        assert_ln_accuracy(&c, Decimal::new(1.into(), -4));
    }

    #[test]
    fn test_ln_of_one_is_zero() {
        // 1 needs no range reduction and the series variable is exactly 0.
        let c = LnContext::new();
        assert_eq!(c.ln(Decimal::one()), Decimal::zero());
    }

    #[test]
    #[should_panic(expected = "ln argument out of bounds")]
    fn test_ln_of_zero_panics() {
        LnContext::new().ln(Decimal::zero());
    }

    #[test]
    #[should_panic(expected = "ln argument out of bounds")]
    fn test_ln_of_negative_panics() {
        LnContext::new().ln(-Decimal::one());
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment