Skip to content

Instantly share code, notes, and snippets.

@jakobrs
Last active July 15, 2023 13:50
Show Gist options
  • Save jakobrs/5b6a246bba2aafd17cce4632af05e70f to your computer and use it in GitHub Desktop.
Save jakobrs/5b6a246bba2aafd17cce4632af05e70f to your computer and use it in GitHub Desktop.
Montgomery arithmetic in Rust
use std::ops::{Add, Mul, Sub};
/// Extended Euclidean algorithm.
///
/// Returns `(g, x, y)` with `a * x + b * y == g`, where `g` is the last
/// non-zero remainder of the Euclidean remainder sequence (a gcd of `a` and
/// `b`; it can be negative when the inputs are negative, e.g. `xgcd(-4, 0)`).
pub const fn xgcd(a: i64, b: i64) -> (i64, i64, i64) {
    // Each triple is (remainder, coefficient of a, coefficient of b); the
    // invariant a * triple.1 + b * triple.2 == triple.0 holds for both.
    let mut prev = (a, 1, 0);
    let mut curr = (b, 0, 1);
    while curr.0 != 0 {
        let q = prev.0 / curr.0;
        let next = (prev.0 - q * curr.0, prev.1 - q * curr.1, prev.2 - q * curr.2);
        prev = curr;
        curr = next;
    }
    prev
}
/// A residue modulo `M`, stored in Montgomery form (`x * R mod M`).
///
/// `M` is the modulus and `R = 2^P` is the Montgomery radix. Correct results
/// require `M` to be odd (so it is invertible mod `R`), `M < R`, and
/// `M <= i32::MAX` (the residue is stored in an `i32`).
///
/// Every operation keeps the stored value in the canonical range `0..M`, so
/// the derived equality compares residues exactly.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct Montgomery<const M: i64, const P: i32>(i32);

impl<const M: i64, const P: i32> Montgomery<M, P> {
    /// The Montgomery radix R = 2^P.
    const R: i64 = 1i64 << P;
    /// R2 ≡ R * R (mod M), used by `new` and `From` to enter Montgomery form.
    ///
    /// `R` is reduced mod `M` before squaring so the product stays below `M^2`
    /// instead of overflowing for large `P`.
    pub const R2: i64 = (Self::R % M) * (Self::R % M) % M;
    /// M * M_PRIME ≡ -1 (mod R). The value may be negative; only its low `P`
    /// bits are used.
    pub const M_PRIME: i64 = -xgcd(Self::R, M).2 % Self::R;
    /// R_MASK = R - 1, used to get the remainder mod R with a bitwise AND.
    pub const R_MASK: i64 = Self::R - 1;
    /// Zero
    pub const ZERO: Self = Self(0);
    /// One
    pub const ONE: Self = Self::new(1);

    /// Montgomery reduction: for `0 <= t < M * R`, returns `t * R^-1 mod M`
    /// in the range `0..M`.
    const fn redc(t: i64) -> Self {
        // Only the low P bits of t * M_PRIME matter, so wrap instead of
        // overflowing (an overflow panics in debug builds and is a hard error
        // during const evaluation once P >= 32).
        let m = (t & Self::R_MASK).wrapping_mul(Self::M_PRIME) & Self::R_MASK;
        // t + m * M can exceed i64::MAX (e.g. P = 32 with M near 2^31), so sum
        // in i128. The choice of m makes the sum divisible by R, so the shift
        // is an exact division.
        let t = ((t as i128 + m as i128 * M as i128) >> P) as i64;
        // For t < M * R the quotient is below 2 * M; one subtraction suffices.
        Self((if t >= M { t - M } else { t }) as i32)
    }

    /// Converts a number to Montgomery form (`n * R mod M`); negative `n` is
    /// first reduced into `0..M`.
    pub const fn new(n: i32) -> Self {
        // Sanity-check the modulus/radix pairing. Wrapping ops keep the check
        // itself from overflowing when P >= 32 (only the low P bits matter).
        debug_assert!(M.wrapping_mul(Self::M_PRIME).wrapping_add(1) & Self::R_MASK == 0);
        Self::redc((n as i64).rem_euclid(M) * Self::R2)
    }

    /// Converts a number out of Montgomery form, returning a value in `0..M`.
    pub const fn extract(self) -> i32 {
        Self::redc(self.0 as i64).0
    }
}

// Implements `From<$ty>` for each listed integer type. Values are widened to
// i128 before reduction so every listed type converts losslessly (a plain
// `as i64` would wrap usize/u64 values above i64::MAX to negative numbers).
macro_rules! impl_montgomery_from {
    ($($ty:ty),* $(,)?) => {
        $(
            impl<const M: i64, const P: i32> From<$ty> for Montgomery<M, P> {
                fn from(value: $ty) -> Self {
                    let reduced = (value as i128).rem_euclid(M as i128) as i64;
                    Self::redc(reduced * Self::R2)
                }
            }
        )*
    };
}

impl_montgomery_from!(i32, i64, u32, u64, usize);

impl<const M: i64, const P: i32> Add for Montgomery<M, P> {
    type Output = Self;

    /// Modular addition; Montgomery form is preserved by plain addition mod M.
    fn add(self, rhs: Self) -> Self::Output {
        // Sum in i64: two residues below M can exceed i32::MAX when M > 2^30.
        let sum = self.0 as i64 + rhs.0 as i64;
        Self((if sum >= M { sum - M } else { sum }) as i32)
    }
}

impl<const M: i64, const P: i32> Sub for Montgomery<M, P> {
    type Output = Self;

    /// Modular subtraction; the difference of two residues lies in `(-M, M)`,
    /// which always fits in an `i32`.
    fn sub(self, rhs: Self) -> Self::Output {
        let diff = self.0 - rhs.0;
        Self(if diff < 0 { diff + M as i32 } else { diff })
    }
}

impl<const M: i64, const P: i32> Mul for Montgomery<M, P> {
    type Output = Self;

    /// Modular multiplication: (aR)(bR) R^-1 = (ab)R, via one reduction.
    fn mul(self, rhs: Self) -> Self::Output {
        Self::redc(self.0 as i64 * rhs.0 as i64)
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment