@catid
Created July 29, 2018 09:08
Fp=2^61-1 Mersenne prime field - Somewhat tested, somewhat optimized
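//------------------------------------------------------------------------------
// Integer Arithmetic Modulo Mersenne Prime 2^61-1

#include <stdint.h>

// CAT_MUL128(r_hi, r_lo, x, y): 64x64 -> 128-bit unsigned multiply.
// This file does not define the macro. The fallback below is an assumption
// for GCC/Clang (unsigned __int128); substitute the original definition or a
// platform intrinsic where one is available.
#ifndef CAT_MUL128
#define CAT_MUL128(r_hi, r_lo, x, y) \
    do { \
        unsigned __int128 cat_mul128_t = (unsigned __int128)(x) * (y); \
        (r_lo) = (uint64_t)cat_mul128_t; \
        (r_hi) = (uint64_t)(cat_mul128_t >> 64); \
    } while (0)
#endif

// p = 2^61 - 1
static const uint64_t kFp61Prime = ((uint64_t)1 << 61) - 1;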
// x + y (without reduction modulo p)
// For x,y < p this can be done up to 7 times (adding 8 values) without
// overflowing the 64-bit word, since 8 * 2^61 = 2^64.
static inline uint64_t fp61_add(uint64_t x, uint64_t y)
{
    return x + y;
}
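// x - y (congruent modulo p, not fully reduced)
// Requires y <= 2^64 - 8, which holds for any partially reduced input.
static inline uint64_t fp61_sub(uint64_t x, uint64_t y)
{
    const uint64_t s = x - y;
    // If this borrows out, the wrapped value is x - y + 2^64.
    // Since 2^64 = 8 * 2^61 = 8 (mod p), subtract 8 to compensate.
    return x < y ? s - 8 : s;
}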
/*
Largest x,y = p - 1 = 2^61 - 2 = L.
L*L = (2^61-2) * (2^61-2)
    = 2^(61+61) - 4*2^61 + 4
    = 2^122 - 2^63 + 4
That is, the high 6 bits of the 128-bit product are zero.
We represent the product as two 64-bit words, or 128 bits.
Say bit #64, the lowest bit of the high word, is set.
To eliminate this bit we need to subtract (2^61 - 1) * 2^3.
This means we need to add a bit at #3.
Similarly for bit #65 we need to add a bit at #4.
High bits #127 to #125 affect high bits #66 to #64.
High bits #124 to #64 affect low bits #63 to #3.
Low bits #63 to #61 affect low bits #2 to #0.
If we eliminate from high bits to low bits, then we could carry back
up into the high bits again. So we should instead eliminate bits #61
through #63 first to prevent carries into the high word.
*/
static const uint64_t kMask63 = ((uint64_t)1 << 63) - 1;
/*
x * y (mod p)
The bit lengths of x and y must sum to at most 124 bits, so that the
product fits in bits #123 to #0.
The result is congruent to x * y (mod p) but not fully reduced: it may
occupy bits #63 to #0 (all 64 bits of the word).
Call fp61_reduce() to reduce the value to 0 <= x < p.
Example: If x <= 2^61-1 (61 bits), then y <= 2^63-1 (63 bits).
This means that up to 4 values can be accumulated in y.
Example: If x <= 2^62-1 (62 bits), then y <= 2^62-1 (62 bits).
This means that up to 2 values can be accumulated in x and 2 in y.
*/
static inline uint64_t fp61_mul(uint64_t x, uint64_t y)
{
    uint64_t p_lo, p_hi;
    CAT_MUL128(p_hi, p_lo, x, y);
    // Eliminate bits #63 to #61, which may carry back up into bit #61,
    // so we will only definitely reduce #63 and #62.
    p_lo = (p_lo & kFp61Prime) + (p_lo >> 61);
    // Eliminate bits #123 to #64 (60 bits).
    // This stops short of #124, which would affect bit #63, because doing so
    // prevents the addition from overflowing the 64-bit word.
    return p_lo + ((p_hi << 3) & kMask63);
}
// Reduce a value (mod p). Accepts any 64-bit x.
static inline uint64_t fp61_reduce(uint64_t x)
{
    // Eliminate bits #63 to #61, which may carry back up into bit #61,
    // so we will only definitely reduce #63 and #62.
    x = (x & kFp61Prime) + (x >> 61);
    // Eliminate #61. The +1 also handles the case where x = p.
    x = (x + ((x + 1) >> 61)) & kFp61Prime;
    return x; // 0 <= x < p
}
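// x^-1 (mod p) using the extended Euclidean algorithm.
// Requires a fully reduced input, 0 <= x < p.
// Returns 0 for x = 0, which has no inverse.
static inline uint64_t fp61_inv(uint64_t x)
{
    if (x <= 1) {
        return x;
    }
    int64_t s0, s1, r0, r1;
    // Compute the next remainder
    uint64_t uq = kFp61Prime / x;
    // Store the results
    r0 = static_cast<int64_t>(x);
    r1 = static_cast<int64_t>(kFp61Prime - uq * x);
    s0 = 1;
    s1 = -static_cast<int64_t>(uq);
    // Invariant: s0 * x = r0 and s1 * x = r1 (mod p)
    while (r1 != 1)
    {
        int64_t q, r, s;
        q = r0 / r1;
        r = r0 - (q * r1);
        s = s0 - (q * s1);
        r0 = r1;
        r1 = r;
        s0 = s1;
        s1 = s;
    }
    // Map a negative coefficient into the range 0 < s1 < p
    if (s1 < 0) {
        s1 += static_cast<int64_t>(kFp61Prime);
    }
    return static_cast<uint64_t>(s1);
}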
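
Since the description says the code is only somewhat tested, a small independent reference can help when cross-checking `fp61_mul` and `fp61_inv`. The sketch below is one possible way to do that, assuming a GCC or Clang toolchain with `unsigned __int128`. The names `kP`, `ref_mul`, `fold_mul`, and `ref_inv` are hypothetical and not part of the gist. `fold_mul` uses the same identity as the comments above (2^61 = 1 mod p) but in a simpler form than the gist's instruction sequence, and `ref_inv` swaps in Fermat's little theorem (x^(p-2)) in place of the extended Euclidean algorithm.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helper names for this sketch; not part of the gist.
// Assumes a compiler that provides unsigned __int128 (GCC or Clang).
static const uint64_t kP = ((uint64_t)1 << 61) - 1;

// Reference product: full 128-bit multiply followed by a generic modulo.
static inline uint64_t ref_mul(uint64_t x, uint64_t y)
{
    return (uint64_t)(((unsigned __int128)x * y) % kP);
}

// Folded product: because 2^61 = 1 (mod p), every 61-bit slice of the
// 128-bit product can simply be added onto the lowest slice.
static inline uint64_t fold_mul(uint64_t x, uint64_t y)
{
    const unsigned __int128 t = (unsigned __int128)x * y;
    const uint64_t lo  = (uint64_t)t & kP;          // bits #60 to #0
    const uint64_t mid = (uint64_t)(t >> 61) & kP;  // bits #121 to #61
    const uint64_t hi  = (uint64_t)(t >> 122);      // bits #127 to #122
    uint64_t r = lo + mid + hi;                     // below 2^63, no overflow
    r = (r & kP) + (r >> 61);                       // now r <= p + 1
    return r >= kP ? r - kP : r;                    // 0 <= r < p
}

// Reference inverse via Fermat's little theorem: x^(p-2) = x^-1 (mod p)
// for x not divisible by p. Square-and-multiply over the exponent bits.
static inline uint64_t ref_inv(uint64_t x)
{
    uint64_t result = 1;
    uint64_t base = x % kP;
    uint64_t e = kP - 2;
    while (e != 0) {
        if (e & 1) {
            result = ref_mul(result, base);
        }
        base = ref_mul(base, base);
        e >>= 1;
    }
    return result;
}
```

A test loop could then compare `fp61_reduce(fp61_mul(x, y))` against `ref_mul(x, y)`, and `fp61_inv(x)` against `ref_inv(x)`, over random reduced inputs.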