Skip to content

Instantly share code, notes, and snippets.

@milkey-mouse
Created October 27, 2024 01:23
Show Gist options
  • Save milkey-mouse/eddd2657d3d8820e0178346ffef7d5ba to your computer and use it in GitHub Desktop.
Save milkey-mouse/eddd2657d3d8820e0178346ffef7d5ba to your computer and use it in GitHub Desktop.
Sample uniformly from (Z/nZ)* to generate deterministic random permutations
/// Maps `a` into the non-negative range by adding `m` when `a` is negative.
///
/// Intended for `a` in `-m..m`, where the result is `a mod m` in `0..m`. A non-negative `a` is
/// returned unchanged (it is not reduced). `a` must be `>= -m`; this is checked in debug builds.
#[inline]
pub fn mod_gt_neg_m(a: i64, m: u64) -> u64 {
    // Compare magnitudes in u64 so that moduli above i64::MAX are accepted.
    debug_assert!(
        a >= 0 || a.unsigned_abs() <= m,
        "input to mod_gt_neg_m must be >= -m"
    );
    if a < 0 {
        // m - |a| == a + m, computed without casting m to i64 (which wraps for m > i64::MAX).
        m - a.unsigned_abs()
    } else {
        a as u64
    }
}
/// Reduces `a` modulo `m` under the precondition `a < 2 * m`, using one conditional subtraction.
///
/// The precondition is checked in debug builds.
#[inline]
pub fn mod_lt_2m(a: u64, m: u64) -> u64 {
    // Equivalent to `a < 2 * m`, but `2 * m` itself overflows u64 when m > 2^63.
    debug_assert!(a < m || a - m < m, "input to mod_lt_2m must be less than 2m");
    if a >= m {
        a - m
    } else {
        a
    }
}
/// Returns the inverse of `a` modulo `m` (an `x` with `a * x ≡ 1 (mod m)`), computed with the
/// extended Euclidean algorithm.
///
/// `a` and `m` must be coprime; this is only checked by a `debug_assert`. For `m >= 2` the
/// result lies in `0..m`.
pub fn mod_inv(a: u64, m: u64) -> u64 {
    // Bézout coefficients are tracked in i128. The last `new_t` reaches ±m (up to u64::MAX),
    // and quotients above i64::MAX would be corrupted by an `as i64` cast.
    let (mut t, mut new_t) = (0i128, 1i128);
    let (mut r, mut new_r) = (m, a);
    while new_r != 0 {
        let q = r / new_r;
        (t, new_t) = (new_t, t - i128::from(q) * new_t);
        (r, new_r) = (new_r, r - q * new_r);
    }
    debug_assert_eq!(r, 1, "inputs to mod_inv must be coprime");
    // The final coefficient satisfies |t| <= m, so shifting a negative t up by m lands in 0..m.
    let t = if t < 0 { t + i128::from(m) } else { t };
    t as u64
}
/// Combines `x ≡ r1 (mod m1)` and `x ≡ r2 (mod m2)` using the Chinese remainder theorem.
///
/// Returns `(x, m1 * m2)`, where `x` is the solution in `0..m1 * m2`.
///
/// Requirements:
/// - `m1` and `m2` are coprime;
/// - `r1 < m1` and `r2 < m2`;
/// - `m1 * m2` fits in a `u64`.
///
/// The last two are checked with `debug_assert`s.
pub fn crt(r1: u64, m1: u64, r2: u64, m2: u64) -> (u64, u64) {
    // both of these should always hold when called by seed_to_unit
    debug_assert!(!m1.overflowing_mul(m2).1, "inputs to crt must fit in u64");
    debug_assert!(r1 < m1 && r2 < m2, "inputs to crt must be reduced");
    let m = m1 * m2;
    // Widen to u128: r * inv can reach (modulus - 1)^2, which overflows u64 as soon as a
    // modulus exceeds 2^32.
    let k1 = (u128::from(r1) * u128::from(mod_inv(m2, m1)) % u128::from(m1)) as u64;
    let k2 = (u128::from(r2) * u128::from(mod_inv(m1, m2)) % u128::from(m2)) as u64;
    // Each term is below m, but their sum (< 2m) can exceed u64::MAX when m > 2^63.
    let x = (u128::from(m2) * u128::from(k1) + u128::from(m1) * u128::from(k2)) % u128::from(m);
    (x as u64, m)
}
/// Deterministically maps `*seed` to a unit of `Z/nZ`, i.e. an integer in `0..n` coprime to `n`.
///
/// The seed is consumed as a mixed-radix number. For each prime power `p^k` exactly dividing
/// `n`, two digits are drawn from `*seed`, which is divided down accordingly:
/// - one digit in `0..p^(k-1)`;
/// - one digit in `0..p-1`.
///
/// The two digits are combined into a unit modulo `p^k`. The per-prime units are then joined
/// with the Chinese remainder theorem. Seeds in `0..φ(n)` therefore map one-to-one onto the
/// units, and `*seed` is left holding `seed / φ(n)`.
///
/// `n` is factored by trial division up to the square root of the still-unfactored part. An
/// `n` with a very large prime factor (e.g. a prime near 2^64) therefore takes on the order
/// of 2^31 loop iterations.
///
/// # Panics
/// Panics if `n == 0`.
pub fn seed_to_unit(seed: &mut u64, n: u64) -> u64 {
    assert!(n > 0, "seed_to_unit requires n > 0");
    let mut rem = n;
    let mut accum_result = 0;
    let mut accum_modulus = 1;
    // Folds one prime power p^k (passed as p and p^(k-1)) into the CRT accumulator.
    let mut add_prime_power = |p: u64, p_k_minus_1: u64| {
        let chunk = *seed % p_k_minus_1;
        *seed /= p_k_minus_1;
        let offset = *seed % (p - 1);
        *seed /= p - 1;
        // p * chunk + (1..=p-1) covers every residue mod p^k that is not a multiple of p.
        let unit = p * chunk + 1 + offset;
        let p_k = p * p_k_minus_1;
        (accum_result, accum_modulus) = crt(accum_result, accum_modulus, unit, p_k);
    };
    let two_power = rem.trailing_zeros();
    if two_power > 0 {
        rem >>= two_power;
        add_prime_power(2, 1 << (two_power - 1));
    }
    for p in (3..).step_by(2) {
        // Same test as `p * p > rem`, but cannot overflow when rem is close to u64::MAX.
        if p > rem / p {
            break;
        }
        if rem % p != 0 {
            continue;
        }
        // Strip every factor of p; p_k_minus_1 ends up as p^(k-1) for the exact power p^k.
        let mut p_k_minus_1 = 1;
        rem /= p;
        while rem % p == 0 {
            rem /= p;
            p_k_minus_1 *= p;
        }
        add_prime_power(p, p_k_minus_1);
    }
    // Whatever survives trial division is a single prime factor larger than sqrt of itself.
    if rem > 1 {
        add_prime_power(rem, 1);
    }
    debug_assert!(accum_result < n);
    debug_assert_eq!(accum_modulus, n);
    accum_result
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment