Last active
October 18, 2025 10:09
-
-
Save Caellian/98da404e2288eb1522ae007b33dce0bd to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| use rand_distr::{Distribution, Normal, StandardNormal, num_traits::Float, uniform::SampleUniform}; | |
/// Associates a distribution with the extra data a truncated wrapper must
/// cache for it.
///
/// Each distribution that supports truncation chooses its own `Value` type;
/// the wrapper stores one instance of it next to the distribution.
trait TruncationDetail {
    /// Per-distribution data kept alongside the wrapped distribution.
    type Value;
}
| impl<F: Float> TruncationDetail for Normal<F> | |
| where | |
| StandardNormal: Distribution<F>, | |
| { | |
| type Value = std::ops::Range<F>; | |
| } | |
| pub struct Truncated<T, D: Distribution<T>> | |
| where | |
| D: TruncationDetail, | |
| { | |
| inner: D, | |
| bounds: std::ops::Range<T>, | |
| detail: <D as TruncationDetail>::Value, | |
| } | |
| pub trait Truncation<T>: Sized + Distribution<T> + TruncationDetail { | |
| fn truncate(self, bounds: std::ops::Range<T>) -> Truncated<T, Self>; | |
| } | |
| impl<F: Float> Truncation<F> for Normal<F> | |
| where | |
| StandardNormal: Distribution<F>, | |
| { | |
| fn truncate(self, bounds: std::ops::Range<F>) -> Truncated<F, Self> { | |
| // Normalize bounds | |
| let alpha = (bounds.start - self.mean()) / self.std_dev(); | |
| let beta = (bounds.end - self.mean()) / self.std_dev(); | |
| // Compute their CDF values | |
| let phi_a = normal_cdf_standard(alpha); | |
| let phi_b = normal_cdf_standard(beta); | |
| Truncated { | |
| inner: self, | |
| bounds, | |
| detail: phi_a..phi_b, | |
| } | |
| } | |
| } | |
| /// Abramowitz-Stegun approximation of the error function erf(x) | |
| /// | |
| /// accurate to ~7 decimal digits | |
| fn erf<F: Float>(x: F) -> F { | |
| let x = x.to_f64().unwrap(); | |
| // constants | |
| let a1 = 0.254829592; | |
| let a2 = -0.284496736; | |
| let a3 = 1.421413741; | |
| let a4 = -1.453152027; | |
| let a5 = 1.061405429; | |
| let p = 0.3275911; | |
| // Save the sign of x | |
| let sign = if x < 0.0 { -1.0 } else { 1.0 }; | |
| let x = x.abs(); | |
| // Abramowitz & Stegun formula 7.1.26 | |
| let t = 1.0 / (1.0 + p * x); | |
| let y = 1.0 - (((((a5 * t + a4) * t + a3) * t + a2) * t + a1) * t) * (-x * x).exp(); | |
| F::from(sign * y).unwrap() | |
| } | |
| /// Standard normal CDF Φ(x) | |
| fn normal_cdf_standard<F: Float>(x: F) -> F { | |
| F::from(0.5).unwrap() * (F::one() + erf(x / F::from(std::f64::consts::SQRT_2).unwrap())) | |
| } | |
| /// Normal CDF for N(mu, sigma²) | |
| fn normal_cdf<F: Float>(x: F, mu: F, sigma: F) -> F { | |
| normal_cdf_standard((x - mu) / sigma) | |
| } | |
| /// Inverse CDF (quantile) for the standard normal distribution. | |
| /// | |
| /// Peter John Acklam’s algorithm (2003) | |
| /// | |
| /// Accurate to ~1e-9 for all p in (0,1) | |
| #[allow(clippy::excessive_precision)] | |
| pub fn normal_quantile_standard<F: Float>(p: F) -> F { | |
| let p = p.to_f64().unwrap(); | |
| assert!(p > 0.0 && p < 1.0, "p must be in (0, 1)"); | |
| // Coefficients in rational approximations | |
| const A: [f64; 6] = [ | |
| -3.969683028665376e+01, | |
| 2.209460984245205e+02, | |
| -2.759285104469687e+02, | |
| 1.383577518672690e+02, | |
| -3.066479806614716e+01, | |
| 2.506628277459239e+00, | |
| ]; | |
| const B: [f64; 5] = [ | |
| -5.447609879822406e+01, | |
| 1.615858368580409e+02, | |
| -1.556989798598866e+02, | |
| 6.680131188771972e+01, | |
| -1.328068155288572e+01, | |
| ]; | |
| const C: [f64; 6] = [ | |
| -7.784894002430293e-03, | |
| -3.223964580411365e-01, | |
| -2.400758277161838e+00, | |
| -2.549732539343734e+00, | |
| 4.374664141464968e+00, | |
| 2.938163982698783e+00, | |
| ]; | |
| const D: [f64; 4] = [ | |
| 7.784695709041462e-03, | |
| 3.224671290700398e-01, | |
| 2.445134137142996e+00, | |
| 3.754408661907416e+00, | |
| ]; | |
| // Define break-points | |
| const P_LOW: f64 = 0.02425; | |
| const P_HIGH: f64 = 1.0 - P_LOW; | |
| let result = if p < P_LOW { | |
| // Lower tail | |
| let q = (-2.0 * p.ln()).sqrt(); | |
| -((((((C[0] * q + C[1]) * q + C[2]) * q + C[3]) * q + C[4]) * q + C[5]) | |
| / ((((D[0] * q + D[1]) * q + D[2]) * q + D[3]) * q + 1.0)) | |
| } else if p <= P_HIGH { | |
| // Central region | |
| let q = p - 0.5; | |
| let r = q * q; | |
| (((((A[0] * r + A[1]) * r + A[2]) * r + A[3]) * r + A[4]) * r + A[5]) * q | |
| / (((((B[0] * r + B[1]) * r + B[2]) * r + B[3]) * r + B[4]) * r + 1.0) | |
| } else { | |
| // Upper tail | |
| let q = (-2.0 * (1.0 - p).ln()).sqrt(); | |
| (((((C[0] * q + C[1]) * q + C[2]) * q + C[3]) * q + C[4]) * q + C[5]) | |
| / ((((D[0] * q + D[1]) * q + D[2]) * q + D[3]) * q + 1.0) | |
| }; | |
| F::from(result).unwrap() | |
| } | |
| /// Inverse CDF for a general normal N(mu, sigma²) | |
| pub fn normal_quantile<F: Float>(p: F, mu: F, sigma: F) -> F { | |
| mu + sigma * normal_quantile_standard(p) | |
| } | |
| impl<F: Float> Distribution<F> for Truncated<F, Normal<F>> | |
| where | |
| StandardNormal: Distribution<F>, | |
| F: SampleUniform, | |
| { | |
| fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> F { | |
| // Sample uniformly between the two CDFs | |
| let u = rng.random_range(self.detail.clone()); | |
| // Transform back via inverse CDF (quantile) | |
| let z = normal_quantile_standard(u); | |
| self.inner.mean() + self.inner.std_dev() * z | |
| } | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment