Skip to content

Instantly share code, notes, and snippets.

@Caellian
Last active October 18, 2025 10:09
Show Gist options
  • Select an option

  • Save Caellian/98da404e2288eb1522ae007b33dce0bd to your computer and use it in GitHub Desktop.

Select an option

Save Caellian/98da404e2288eb1522ae007b33dce0bd to your computer and use it in GitHub Desktop.
use rand_distr::{Distribution, Normal, StandardNormal, num_traits::Float, uniform::SampleUniform};
/// Per-distribution data that a [`Truncated`] wrapper precomputes when it is created.
///
/// This trait is `pub` because it appears as a supertrait of the public `Truncation` trait
/// and as a bound on the public `Truncated` struct; a private trait there triggers the
/// `private_bounds` lint and prevents downstream code from naming the bound.
pub trait TruncationDetail {
    /// The precomputed data stored alongside the truncated distribution.
    type Value;
}
/// For a truncated [`Normal`], the precomputed detail is the interval `Φ(α)..Φ(β)` of
/// standard-normal CDF values at the standardized lower and upper truncation bounds.
impl<F> TruncationDetail for Normal<F>
where
    F: Float,
    StandardNormal: Distribution<F>,
{
    type Value = std::ops::Range<F>;
}
pub struct Truncated<T, D: Distribution<T>>
where
D: TruncationDetail,
{
inner: D,
bounds: std::ops::Range<T>,
detail: <D as TruncationDetail>::Value,
}
/// Distributions that can be restricted to a half-open interval of their support.
pub trait Truncation<T>
where
    Self: Sized + Distribution<T> + TruncationDetail,
{
    /// Consumes `self` and wraps it in a [`Truncated`] distribution limited to `bounds`.
    fn truncate(self, bounds: std::ops::Range<T>) -> Truncated<T, Self>;
}
impl<F: Float> Truncation<F> for Normal<F>
where
StandardNormal: Distribution<F>,
{
fn truncate(self, bounds: std::ops::Range<F>) -> Truncated<F, Self> {
// Normalize bounds
let alpha = (bounds.start - self.mean()) / self.std_dev();
let beta = (bounds.end - self.mean()) / self.std_dev();
// Compute their CDF values
let phi_a = normal_cdf_standard(alpha);
let phi_b = normal_cdf_standard(beta);
Truncated {
inner: self,
bounds,
detail: phi_a..phi_b,
}
}
}
/// Approximates the error function `erf(x)` with Abramowitz & Stegun formula 7.1.26.
///
/// The computation is carried out in `f64` regardless of `F`. The approximation is accurate
/// to roughly 7 decimal digits (absolute error on the order of 1e-7).
fn erf<F: Float>(x: F) -> F {
    // Polynomial coefficients a1..a5 and scaling constant p of A&S 7.1.26.
    const A1: f64 = 0.254829592;
    const A2: f64 = -0.284496736;
    const A3: f64 = 1.421413741;
    const A4: f64 = -1.453152027;
    const A5: f64 = 1.061405429;
    const P: f64 = 0.3275911;

    let value = x.to_f64().unwrap();
    // The formula is evaluated on |x|; the sign is reapplied at the end (erf is odd).
    let sign = if value < 0.0 { -1.0 } else { 1.0 };
    let magnitude = value.abs();

    let t = 1.0 / (1.0 + P * magnitude);
    let poly = ((((A5 * t + A4) * t + A3) * t + A2) * t + A1) * t;
    let approx = 1.0 - poly * (-magnitude * magnitude).exp();

    F::from(sign * approx).unwrap()
}
/// Standard normal cumulative distribution function Φ(x) = ½·(1 + erf(x / √2)).
///
/// Its accuracy is bounded by the `erf` approximation it relies on; far enough into either
/// tail the result rounds to exactly 0 or 1.
fn normal_cdf_standard<F: Float>(x: F) -> F {
    let half = F::from(0.5).unwrap();
    let sqrt_2 = F::from(std::f64::consts::SQRT_2).unwrap();
    half * (F::one() + erf(x / sqrt_2))
}
/// Cumulative distribution function of a general normal distribution N(μ, σ²).
///
/// Standardizes `x` as `(x - mu) / sigma` and evaluates the standard normal CDF there.
fn normal_cdf<F: Float>(x: F, mu: F, sigma: F) -> F {
    let z = (x - mu) / sigma;
    normal_cdf_standard(z)
}
/// Inverse CDF (quantile function) Φ⁻¹(p) of the standard normal distribution.
///
/// Uses Peter John Acklam's rational approximation (2003). The central region
/// `P_LOW <= p <= P_HIGH` uses a rational function in `p - 0.5`. Each tail uses a rational
/// function in `q = sqrt(-2 ln(min(p, 1 - p)))`.
///
/// Accurate to ~1e-9 for all p in (0,1).
///
/// # Panics
///
/// Panics if `p` is not strictly between 0 and 1 (this includes NaN).
#[allow(clippy::excessive_precision)] // coefficients are quoted verbatim from Acklam's paper
pub fn normal_quantile_standard<F: Float>(p: F) -> F {
    let p = p.to_f64().unwrap();
    assert!(p > 0.0 && p < 1.0, "p must be in (0, 1)");
    // Coefficients in rational approximations
    const A: [f64; 6] = [
        -3.969683028665376e+01,
        2.209460984245205e+02,
        -2.759285104469687e+02,
        1.383577518672690e+02,
        -3.066479806614716e+01,
        2.506628277459239e+00,
    ];
    const B: [f64; 5] = [
        -5.447609879822406e+01,
        1.615858368580409e+02,
        -1.556989798598866e+02,
        6.680131188771972e+01,
        -1.328068155288572e+01,
    ];
    const C: [f64; 6] = [
        -7.784894002430293e-03,
        -3.223964580411365e-01,
        -2.400758277161838e+00,
        -2.549732539343734e+00,
        4.374664141464968e+00,
        2.938163982698783e+00,
    ];
    const D: [f64; 4] = [
        7.784695709041462e-03,
        3.224671290700398e-01,
        2.445134137142996e+00,
        3.754408661907416e+00,
    ];
    // Define break-points
    const P_LOW: f64 = 0.02425;
    const P_HIGH: f64 = 1.0 - P_LOW;
    let result = if p < P_LOW {
        // Lower tail. As in Acklam's formulation, the rational function is used as-is: with
        // these coefficients it already evaluates to a negative value here.
        let q = (-2.0 * p.ln()).sqrt();
        (((((C[0] * q + C[1]) * q + C[2]) * q + C[3]) * q + C[4]) * q + C[5])
            / ((((D[0] * q + D[1]) * q + D[2]) * q + D[3]) * q + 1.0)
    } else if p <= P_HIGH {
        // Central region
        let q = p - 0.5;
        let r = q * q;
        (((((A[0] * r + A[1]) * r + A[2]) * r + A[3]) * r + A[4]) * r + A[5]) * q
            / (((((B[0] * r + B[1]) * r + B[2]) * r + B[3]) * r + B[4]) * r + 1.0)
    } else {
        // Upper tail: mirror image of the lower tail, Φ⁻¹(p) = -Φ⁻¹(1 - p), hence the negation.
        let q = (-2.0 * (1.0 - p).ln()).sqrt();
        -(((((C[0] * q + C[1]) * q + C[2]) * q + C[3]) * q + C[4]) * q + C[5])
            / ((((D[0] * q + D[1]) * q + D[2]) * q + D[3]) * q + 1.0)
    };
    F::from(result).unwrap()
}
/// Quantile function (inverse CDF) of a general normal distribution N(μ, σ²).
///
/// Computes `mu + sigma * Φ⁻¹(p)`, where Φ⁻¹ is the standard normal quantile.
///
/// # Panics
///
/// Panics if `p` is not strictly inside `(0, 1)`; the check happens in the standard-normal
/// quantile function.
pub fn normal_quantile<F: Float>(p: F, mu: F, sigma: F) -> F {
    let z = normal_quantile_standard(p);
    mu + sigma * z
}
impl<F: Float> Distribution<F> for Truncated<F, Normal<F>>
where
StandardNormal: Distribution<F>,
F: SampleUniform,
{
fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> F {
// Sample uniformly between the two CDFs
let u = rng.random_range(self.detail.clone());
// Transform back via inverse CDF (quantile)
let z = normal_quantile_standard(u);
self.inner.mean() + self.inner.std_dev() * z
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment