Created August 25, 2021 20:18
Multiexp and FFT OpenCL/CUDA kernel
// Defines to make the code work with both CUDA and OpenCL
//#ifdef __NVCC__
#ifdef __CUDACC__
#define DEVICE __device__
#define GLOBAL
#define KERNEL extern "C" __global__
#define LOCAL __shared__
#define GET_GLOBAL_ID() blockIdx.x * blockDim.x + threadIdx.x
#define GET_GROUP_ID() blockIdx.x
#define GET_LOCAL_ID() threadIdx.x
#define GET_LOCAL_SIZE() blockDim.x
#define BARRIER_LOCAL() __syncthreads()
typedef unsigned char uchar;
#define CUDA
#else // OpenCL
#define DEVICE
#define GLOBAL __global
#define KERNEL __kernel
#define LOCAL __local
#define GET_GLOBAL_ID() get_global_id(0)
#define GET_GROUP_ID() get_group_id(0)
#define GET_LOCAL_ID() get_local_id(0)
#define GET_LOCAL_SIZE() get_local_size(0)
#define BARRIER_LOCAL() barrier(CLK_LOCAL_MEM_FENCE)
#endif
typedef uint limb;
#define LIMB_BITS (32)
// TODO: revert the #define CUDA's back to #define OPENCL_NVIDIA where it makes sense
#ifdef __NV_CL_C_VERSION
#define OPENCL_NVIDIA
#endif
#if defined(__WinterPark__) || defined(__BeaverCreek__) || defined(__Turks__) || \
    defined(__Caicos__) || defined(__Tahiti__) || defined(__Pitcairn__) || \
    defined(__Capeverde__) || defined(__Cayman__) || defined(__Barts__) || \
    defined(__Cypress__) || defined(__Juniper__) || defined(__Redwood__) || \
    defined(__Cedar__) || defined(__ATI_RV770__) || defined(__ATI_RV730__) || \
    defined(__ATI_RV710__) || defined(__Loveland__) || defined(__GPU__) || \
    defined(__Hawaii__)
#define AMD
#endif
// Returns the low 32 bits of a * b + c + *d; stores the high 32 bits (the carry) in *d.
DEVICE limb mac_with_carry(limb a, limb b, limb c, limb *d) {
  ulong res = (ulong)a * b + c + *d;
  *d = res >> 32;
  return res;
}
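// Example: a = 0xFFFFFFFF, b = 0xFFFFFFFF, c = 1, *d = 1 gives
// res = 0xFFFFFFFE00000001 + 2 = 0xFFFFFFFE00000003, so the function
// returns 0x00000003 and leaves *d = 0xFFFFFFFE.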
// Returns the low word of a + *b; stores the carry (0 or 1) in *b.
DEVICE limb add_with_carry(limb a, limb *b) {
#ifdef OPENCL_NVIDIA
  limb lo, hi;
  asm("add.cc.u32 %0, %2, %3;\r\n"
      "addc.u32 %1, 0, 0;\r\n"
      : "=r"(lo), "=r"(hi) : "r"(a), "r"(*b));
  *b = hi;
  return lo;
#else
  limb lo = a + *b;
  *b = lo < a;
  return lo;
#endif
}
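// Example: a = 0xFFFFFFFF, *b = 1: the 32-bit sum wraps to 0, so the
// function returns 0 and leaves *b = 1 (the carry out of the addition).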
#ifdef CUDA
// Code based on the work from Supranational, with special thanks to Niall Emmart:
//
// We would like to acknowledge Niall Emmart at Nvidia for his significant
// contribution of concepts and code for generating efficient SASS on
// Nvidia GPUs. The following papers may be of interest:
//   Optimizing Modular Multiplication for NVIDIA's Maxwell GPUs
//   https://ieeexplore.ieee.org/document/7563271
//
//   Faster modular exponentiation using double precision floating point
//   arithmetic on the GPU
//   https://ieeexplore.ieee.org/document/8464792
typedef uint uint32_t;
typedef int int32_t;
typedef uint limb;
DEVICE inline uint32_t add_cc(uint32_t a, uint32_t b) {
  uint32_t r;
  asm volatile ("add.cc.u32 %0, %1, %2;" : "=r"(r) : "r"(a), "r"(b));
  return r;
}
DEVICE inline uint32_t addc_cc(uint32_t a, uint32_t b) {
  uint32_t r;
  asm volatile ("addc.cc.u32 %0, %1, %2;" : "=r"(r) : "r"(a), "r"(b));
  return r;
}
DEVICE inline uint32_t addc(uint32_t a, uint32_t b) {
  uint32_t r;
  asm volatile ("addc.u32 %0, %1, %2;" : "=r"(r) : "r"(a), "r"(b));
  return r;
}
DEVICE inline uint32_t madlo(uint32_t a, uint32_t b, uint32_t c) {
  uint32_t r;
  asm volatile ("mad.lo.u32 %0, %1, %2, %3;" : "=r"(r) : "r"(a), "r"(b), "r"(c));
  return r;
}
DEVICE inline uint32_t madlo_cc(uint32_t a, uint32_t b, uint32_t c) {
  uint32_t r;
  asm volatile ("mad.lo.cc.u32 %0, %1, %2, %3;" : "=r"(r) : "r"(a), "r"(b), "r"(c));
  return r;
}
DEVICE inline uint32_t madloc_cc(uint32_t a, uint32_t b, uint32_t c) {
  uint32_t r;
  asm volatile ("madc.lo.cc.u32 %0, %1, %2, %3;" : "=r"(r) : "r"(a), "r"(b), "r"(c));
  return r;
}
DEVICE inline uint32_t madloc(uint32_t a, uint32_t b, uint32_t c) {
  uint32_t r;
  asm volatile ("madc.lo.u32 %0, %1, %2, %3;" : "=r"(r) : "r"(a), "r"(b), "r"(c));
  return r;
}
DEVICE inline uint32_t madhi(uint32_t a, uint32_t b, uint32_t c) {
  uint32_t r;
  asm volatile ("mad.hi.u32 %0, %1, %2, %3;" : "=r"(r) : "r"(a), "r"(b), "r"(c));
  return r;
}
DEVICE inline uint32_t madhi_cc(uint32_t a, uint32_t b, uint32_t c) {
  uint32_t r;
  asm volatile ("mad.hi.cc.u32 %0, %1, %2, %3;" : "=r"(r) : "r"(a), "r"(b), "r"(c));
  return r;
}
DEVICE inline uint32_t madhic_cc(uint32_t a, uint32_t b, uint32_t c) {
  uint32_t r;
  asm volatile ("madc.hi.cc.u32 %0, %1, %2, %3;" : "=r"(r) : "r"(a), "r"(b), "r"(c));
  return r;
}
DEVICE inline uint32_t madhic(uint32_t a, uint32_t b, uint32_t c) {
  uint32_t r;
  asm volatile ("madc.hi.u32 %0, %1, %2, %3;" : "=r"(r) : "r"(a), "r"(b), "r"(c));
  return r;
}
typedef struct {
  int32_t _position;
} chain_t;
DEVICE inline
void chain_init(chain_t *c) {
  c->_position = 0;
}
DEVICE inline
uint32_t chain_add(chain_t *ch, uint32_t a, uint32_t b) {
  uint32_t r;
  ch->_position++;
  if(ch->_position==1)
    r=add_cc(a, b);
  else
    r=addc_cc(a, b);
  return r;
}
DEVICE inline
uint32_t chain_madlo(chain_t *ch, uint32_t a, uint32_t b, uint32_t c) {
  uint32_t r;
  ch->_position++;
  if(ch->_position==1)
    r=madlo_cc(a, b, c);
  else
    r=madloc_cc(a, b, c);
  return r;
}
DEVICE inline
uint32_t chain_madhi(chain_t *ch, uint32_t a, uint32_t b, uint32_t c) {
  uint32_t r;
  ch->_position++;
  if(ch->_position==1)
    r=madhi_cc(a, b, c);
  else
    r=madhic_cc(a, b, c);
  return r;
}
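// A minimal usage sketch of the carry chain (illustrative only, not part of
// the original kernels): adding two 2-limb values limb by limb. The first
// chain_add emits add.cc, every later one emits addc.cc, so the running
// carry stays in the flags register instead of a general-purpose register.
DEVICE inline void chain_add2_example(const uint32_t a[2], const uint32_t b[2], uint32_t r[3]) {
  chain_t ch;
  chain_init(&ch);
  r[0] = chain_add(&ch, a[0], b[0]); // add.cc.u32:  sets the carry flag
  r[1] = chain_add(&ch, a[1], b[1]); // addc.cc.u32: consumes and re-sets it
  r[2] = chain_add(&ch, 0, 0);       // addc.cc.u32: captures the final carry-out
}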
#endif
#define Fr_LIMBS 8
#define Fr_ONE ((Fr){ { 4294967294, 1, 215042, 1485092858, 3971764213, 2576109551, 2898593135, 405057881 } })
#define Fr_P ((Fr){ { 1, 4294967295, 4294859774, 1404937218, 161601541, 859428872, 698187080, 1944954707 } })
#define Fr_R2 ((Fr){ { 4092763245, 3382307216, 2274516003, 728559051, 1918122383, 97719446, 2673475345, 122214873 } })
#define Fr_ZERO ((Fr){ { 0, 0, 0, 0, 0, 0, 0, 0 } })
#define Fr_INV 4294967295
typedef struct { limb val[Fr_LIMBS]; } Fr;
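// Fr is the 255-bit scalar field of BLS12-381, stored as 8 little-endian
// 32-bit limbs in Montgomery form (R = 2^256): Fr_ONE is R mod P, Fr_R2 is
// R^2 mod P, and Fr_INV is -P^-1 mod 2^32, the per-limb Montgomery factor.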
#ifdef CUDA
DEVICE Fr Fr_sub_nvidia(Fr a, Fr b) {
  asm("sub.cc.u32 %0, %0, %8;\r\n"
      "subc.cc.u32 %1, %1, %9;\r\n"
      "subc.cc.u32 %2, %2, %10;\r\n"
      "subc.cc.u32 %3, %3, %11;\r\n"
      "subc.cc.u32 %4, %4, %12;\r\n"
      "subc.cc.u32 %5, %5, %13;\r\n"
      "subc.cc.u32 %6, %6, %14;\r\n"
      "subc.u32 %7, %7, %15;\r\n"
      :"+r"(a.val[0]), "+r"(a.val[1]), "+r"(a.val[2]), "+r"(a.val[3]), "+r"(a.val[4]), "+r"(a.val[5]), "+r"(a.val[6]), "+r"(a.val[7])
      :"r"(b.val[0]), "r"(b.val[1]), "r"(b.val[2]), "r"(b.val[3]), "r"(b.val[4]), "r"(b.val[5]), "r"(b.val[6]), "r"(b.val[7]));
  return a;
}
DEVICE Fr Fr_add_nvidia(Fr a, Fr b) {
  asm("add.cc.u32 %0, %0, %8;\r\n"
      "addc.cc.u32 %1, %1, %9;\r\n"
      "addc.cc.u32 %2, %2, %10;\r\n"
      "addc.cc.u32 %3, %3, %11;\r\n"
      "addc.cc.u32 %4, %4, %12;\r\n"
      "addc.cc.u32 %5, %5, %13;\r\n"
      "addc.cc.u32 %6, %6, %14;\r\n"
      "addc.u32 %7, %7, %15;\r\n"
      :"+r"(a.val[0]), "+r"(a.val[1]), "+r"(a.val[2]), "+r"(a.val[3]), "+r"(a.val[4]), "+r"(a.val[5]), "+r"(a.val[6]), "+r"(a.val[7])
      :"r"(b.val[0]), "r"(b.val[1]), "r"(b.val[2]), "r"(b.val[3]), "r"(b.val[4]), "r"(b.val[5]), "r"(b.val[6]), "r"(b.val[7]));
  return a;
}
#endif
// FinalityLabs - 2019
// Arbitrary size prime-field arithmetic library (add, sub, mul, pow)
#define Fr_BITS (Fr_LIMBS * LIMB_BITS)
// Greater than or equal
DEVICE bool Fr_gte(Fr a, Fr b) {
  for(char i = Fr_LIMBS - 1; i >= 0; i--){
    if(a.val[i] > b.val[i])
      return true;
    if(a.val[i] < b.val[i])
      return false;
  }
  return true;
}
// Equals
DEVICE bool Fr_eq(Fr a, Fr b) {
  for(uchar i = 0; i < Fr_LIMBS; i++)
    if(a.val[i] != b.val[i])
      return false;
  return true;
}
// Normal addition
#ifdef CUDA
#define Fr_add_ Fr_add_nvidia
#define Fr_sub_ Fr_sub_nvidia
#else
DEVICE Fr Fr_add_(Fr a, Fr b) {
  bool carry = 0;
  for(uchar i = 0; i < Fr_LIMBS; i++) {
    limb old = a.val[i];
    a.val[i] += b.val[i] + carry;
    carry = carry ? old >= a.val[i] : old > a.val[i];
  }
  return a;
}
DEVICE Fr Fr_sub_(Fr a, Fr b) {
  bool borrow = 0;
  for(uchar i = 0; i < Fr_LIMBS; i++) {
    limb old = a.val[i];
    a.val[i] -= b.val[i] + borrow;
    borrow = borrow ? old <= a.val[i] : old < a.val[i];
  }
  return a;
}
#endif
// Modular subtraction
DEVICE Fr Fr_sub(Fr a, Fr b) {
  Fr res = Fr_sub_(a, b);
  if(!Fr_gte(a, b)) res = Fr_add_(res, Fr_P);
  return res;
}
// Modular addition
DEVICE Fr Fr_add(Fr a, Fr b) {
  Fr res = Fr_add_(a, b);
  if(Fr_gte(res, Fr_P)) res = Fr_sub_(res, Fr_P);
  return res;
}
#ifdef CUDA
DEVICE void Fr_reduce(uint32_t accLow[Fr_LIMBS], uint32_t np0, uint32_t fq[Fr_LIMBS]) {
  // accLow is an IN and OUT vector
  // count must be even
  const uint32_t count = Fr_LIMBS;
  uint32_t accHigh[Fr_LIMBS];
  uint32_t bucket=0, lowCarry=0, highCarry=0, q;
  int32_t i, j;
  #pragma unroll
  for(i=0;i<count;i++)
    accHigh[i]=0;
  // bucket is used so we don't have to push a carry all the way down the line
  #pragma unroll
  for(j=0;j<count;j++) { // main iteration
    if(j%2==0) {
      add_cc(bucket, 0xFFFFFFFF);
      accLow[0]=addc_cc(accLow[0], accHigh[1]);
      bucket=addc(0, 0);
      q=accLow[0]*np0;
      chain_t chain1;
      chain_init(&chain1);
      #pragma unroll
      for(i=0;i<count;i+=2) {
        accLow[i]=chain_madlo(&chain1, q, fq[i], accLow[i]);
        accLow[i+1]=chain_madhi(&chain1, q, fq[i], accLow[i+1]);
      }
      lowCarry=chain_add(&chain1, 0, 0);
      chain_t chain2;
      chain_init(&chain2);
      for(i=0;i<count-2;i+=2) {
        accHigh[i]=chain_madlo(&chain2, q, fq[i+1], accHigh[i+2]); // note the shift down
        accHigh[i+1]=chain_madhi(&chain2, q, fq[i+1], accHigh[i+3]);
      }
      accHigh[i]=chain_madlo(&chain2, q, fq[i+1], highCarry);
      accHigh[i+1]=chain_madhi(&chain2, q, fq[i+1], 0);
    }
    else {
      add_cc(bucket, 0xFFFFFFFF);
      accHigh[0]=addc_cc(accHigh[0], accLow[1]);
      bucket=addc(0, 0);
      q=accHigh[0]*np0;
      chain_t chain3;
      chain_init(&chain3);
      #pragma unroll
      for(i=0;i<count;i+=2) {
        accHigh[i]=chain_madlo(&chain3, q, fq[i], accHigh[i]);
        accHigh[i+1]=chain_madhi(&chain3, q, fq[i], accHigh[i+1]);
      }
      highCarry=chain_add(&chain3, 0, 0);
      chain_t chain4;
      chain_init(&chain4);
      for(i=0;i<count-2;i+=2) {
        accLow[i]=chain_madlo(&chain4, q, fq[i+1], accLow[i+2]); // note the shift down
        accLow[i+1]=chain_madhi(&chain4, q, fq[i+1], accLow[i+3]);
      }
      accLow[i]=chain_madlo(&chain4, q, fq[i+1], lowCarry);
      accLow[i+1]=chain_madhi(&chain4, q, fq[i+1], 0);
    }
  }
  // at this point, accHigh needs to be shifted back a word and added to accLow
  // we'll use one other trick. Bucket is either 0 or 1 at this point, so we
  // can just push it into the carry chain.
  chain_t chain5;
  chain_init(&chain5);
  chain_add(&chain5, bucket, 0xFFFFFFFF); // push the carry into the chain
  #pragma unroll
  for(i=0;i<count-1;i++)
    accLow[i]=chain_add(&chain5, accLow[i], accHigh[i+1]);
  accLow[i]=chain_add(&chain5, accLow[i], highCarry);
}
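// In outline, Fr_reduce (above) is a Montgomery reduction that keeps the
// even-indexed partial-product columns in accLow and the odd-indexed ones in
// accHigh, so each half-iteration runs as one uninterrupted PTX carry chain.
// q = acc[0] * np0 (np0 = -P^-1 mod 2^32) is the usual Montgomery quotient
// digit, and `bucket` parks the carry between the two accumulators instead
// of rippling it down the whole line.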
// Requirement: yLimbs >= xLimbs
DEVICE inline
void Fr_mult_v1(uint32_t *x, uint32_t *y, uint32_t *xy) {
  const uint32_t xLimbs = Fr_LIMBS;
  const uint32_t yLimbs = Fr_LIMBS;
  const uint32_t xyLimbs = Fr_LIMBS * 2;
  uint32_t temp[Fr_LIMBS * 2];
  uint32_t carry = 0;
  #pragma unroll
  for (int32_t i = 0; i < xyLimbs; i++) {
    temp[i] = 0;
  }
  #pragma unroll
  for (int32_t i = 0; i < xLimbs; i++) {
    chain_t chain1;
    chain_init(&chain1);
    #pragma unroll
    for (int32_t j = 0; j < yLimbs; j++) {
      if ((i + j) % 2 == 1) {
        temp[i + j - 1] = chain_madlo(&chain1, x[i], y[j], temp[i + j - 1]);
        temp[i + j] = chain_madhi(&chain1, x[i], y[j], temp[i + j]);
      }
    }
    if (i % 2 == 1) {
      temp[i + yLimbs - 1] = chain_add(&chain1, 0, 0);
    }
  }
  #pragma unroll
  for (int32_t i = xyLimbs - 1; i > 0; i--) {
    temp[i] = temp[i - 1];
  }
  temp[0] = 0;
  #pragma unroll
  for (int32_t i = 0; i < xLimbs; i++) {
    chain_t chain2;
    chain_init(&chain2);
    #pragma unroll
    for (int32_t j = 0; j < yLimbs; j++) {
      if ((i + j) % 2 == 0) {
        temp[i + j] = chain_madlo(&chain2, x[i], y[j], temp[i + j]);
        temp[i + j + 1] = chain_madhi(&chain2, x[i], y[j], temp[i + j + 1]);
      }
    }
    if ((i + yLimbs) % 2 == 0 && i != yLimbs - 1) {
      temp[i + yLimbs] = chain_add(&chain2, temp[i + yLimbs], carry);
      temp[i + yLimbs + 1] = chain_add(&chain2, temp[i + yLimbs + 1], 0);
      carry = chain_add(&chain2, 0, 0);
    }
    if ((i + yLimbs) % 2 == 1 && i != yLimbs - 1) {
      carry = chain_add(&chain2, carry, 0);
    }
  }
  #pragma unroll
  for(int32_t i = 0; i < xyLimbs; i++) {
    xy[i] = temp[i];
  }
}
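// Fr_mult_v1 (above) is a full 8x8-limb schoolbook multiply split into two
// passes: the first accumulates the odd (i+j) columns, the intermediate
// result is then shifted up one word, and the second pass accumulates the
// even columns on top. Each pass per x-limb is a single madlo/madhi carry
// chain, which is why the columns are interleaved this way.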
DEVICE Fr Fr_mul_nvidia(Fr a, Fr b) {
  // Perform full multiply
  // TODO: should not be +1, then addition below should stop one word earlier!
  limb ab[2 * Fr_LIMBS + 1];
  Fr_mult_v1(a.val, b.val, ab);
  uint32_t io[Fr_LIMBS];
  #pragma unroll
  for(int i=0;i<Fr_LIMBS;i++) {
    io[i]=ab[i];
  }
  Fr_reduce(io, Fr_INV, Fr_P.val);
  // Add io to the upper words of ab
  ab[Fr_LIMBS] = add_cc(ab[Fr_LIMBS], io[0]);
  int j;
  #pragma unroll
  for (j = 1; j < Fr_LIMBS; j++) {
    ab[j + Fr_LIMBS] = addc_cc(ab[j + Fr_LIMBS], io[j]);
  }
  // Capture the final carry only; io has just Fr_LIMBS words, so reading
  // io[Fr_LIMBS] here would be out of bounds.
  ab[2 * Fr_LIMBS] = addc(ab[2 * Fr_LIMBS], 0);
  Fr r;
  #pragma unroll
  for (int i = 0; i < Fr_LIMBS; i++) {
    r.val[i] = ab[i + Fr_LIMBS];
  }
  if (Fr_gte(r, Fr_P)) {
    r = Fr_sub_(r, Fr_P);
  }
  return r;
}
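// Net effect of Fr_mul_nvidia: r = a * b * R^-1 mod P, i.e. a Montgomery
// product. The low half of the double-width product is reduced by Fr_reduce
// and folded into the high half.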
#endif
// Modular multiplication
DEVICE Fr Fr_mul_default(Fr a, Fr b) {
  /* CIOS Montgomery multiplication, inspired from Tolga Acar's thesis:
   * https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf
   * Learn more:
   * https://en.wikipedia.org/wiki/Montgomery_modular_multiplication
   * https://alicebob.cryptoland.net/understanding-the-montgomery-reduction-algorithm/
   */
  limb t[Fr_LIMBS + 2] = {0};
  for(uchar i = 0; i < Fr_LIMBS; i++) {
    limb carry = 0;
    for(uchar j = 0; j < Fr_LIMBS; j++)
      t[j] = mac_with_carry(a.val[j], b.val[i], t[j], &carry);
    t[Fr_LIMBS] = add_with_carry(t[Fr_LIMBS], &carry);
    t[Fr_LIMBS + 1] = carry;
    carry = 0;
    limb m = Fr_INV * t[0];
    mac_with_carry(m, Fr_P.val[0], t[0], &carry);
    for(uchar j = 1; j < Fr_LIMBS; j++)
      t[j - 1] = mac_with_carry(m, Fr_P.val[j], t[j], &carry);
    t[Fr_LIMBS - 1] = add_with_carry(t[Fr_LIMBS], &carry);
    t[Fr_LIMBS] = t[Fr_LIMBS + 1] + carry;
  }
  Fr result;
  for(uchar i = 0; i < Fr_LIMBS; i++) result.val[i] = t[i];
  if(Fr_gte(result, Fr_P)) result = Fr_sub_(result, Fr_P);
  return result;
}
#ifdef CUDA
DEVICE Fr Fr_mul(Fr a, Fr b) {
  return Fr_mul_nvidia(a, b);
}
#else
DEVICE Fr Fr_mul(Fr a, Fr b) {
  return Fr_mul_default(a, b);
}
#endif
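// A usage sketch (hypothetical helper, not part of the original kernels):
// Fr_mul computes a * b * R^-1 mod P, so both operands are expected to be in
// Montgomery form. To multiply a value that is NOT in Montgomery form,
// convert it first by multiplying with R^2:
DEVICE Fr Fr_mul_plain_example(Fr a, Fr b) {
  Fr a_mont = Fr_mul(a, Fr_R2); // a * R^2 * R^-1 = a * R mod P
  return Fr_mul(a_mont, b);     // (a * R) * b * R^-1 = a * b mod P
}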
// Squaring is a special case of multiplication which can be done ~1.5x faster.
// https://stackoverflow.com/a/16388571/1348497
DEVICE Fr Fr_sqr(Fr a) {
  return Fr_mul(a, a);
}
// Left-shift the limbs by one bit and subtract the modulus if the result
// exceeds it. Faster version of Fr_add(a, a).
DEVICE Fr Fr_double(Fr a) {
  for(uchar i = Fr_LIMBS - 1; i >= 1; i--)
    a.val[i] = (a.val[i] << 1) | (a.val[i - 1] >> (LIMB_BITS - 1));
  a.val[0] <<= 1;
  if(Fr_gte(a, Fr_P)) a = Fr_sub_(a, Fr_P);
  return a;
}
// Modular exponentiation (Exponentiation by Squaring)
// https://en.wikipedia.org/wiki/Exponentiation_by_squaring
DEVICE Fr Fr_pow(Fr base, uint exponent) {
  Fr res = Fr_ONE;
  while(exponent > 0) {
    if (exponent & 1)
      res = Fr_mul(res, base);
    exponent = exponent >> 1;
    base = Fr_sqr(base);
  }
  return res;
}
// Store squares of the base in a lookup table for faster evaluation.
DEVICE Fr Fr_pow_lookup(GLOBAL Fr *bases, uint exponent) {
  Fr res = Fr_ONE;
  uint i = 0;
  while(exponent > 0) {
    if (exponent & 1)
      res = Fr_mul(res, bases[i]);
    exponent = exponent >> 1;
    i++;
  }
  return res;
}
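// Fr_pow_lookup assumes bases[i] = base^(2^i), i.e. the successive squarings
// are precomputed (this matches the `omegas` argument of radix_fft below:
// [omega, omega^2, omega^4, ...]). Example: exponent = 0b101 multiplies
// bases[0] * bases[2] = base^1 * base^4 = base^5.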
DEVICE Fr Fr_mont(Fr a) {
  return Fr_mul(a, Fr_R2);
}
DEVICE Fr Fr_unmont(Fr a) {
  Fr one = Fr_ZERO;
  one.val[0] = 1;
  return Fr_mul(a, one);
}
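// Fr_mont converts into Montgomery form (a -> a * R mod P, since
// a * R^2 * R^-1 = a * R) and Fr_unmont converts back (a * R -> a, since
// a * R * 1 * R^-1 = a), so Fr_unmont(Fr_mont(a)) == a.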
// Get the `i`th bit (counting from the most significant bit) of the field element.
DEVICE bool Fr_get_bit(Fr l, uint i) {
  return (l.val[Fr_LIMBS - 1 - i / LIMB_BITS] >> (LIMB_BITS - 1 - (i % LIMB_BITS))) & 1;
}
// Get `window` consecutive bits (starting from the `skip`th bit) of the field element.
DEVICE uint Fr_get_bits(Fr l, uint skip, uint window) {
  uint ret = 0;
  for(uint i = 0; i < window; i++) {
    ret <<= 1;
    ret |= Fr_get_bit(l, skip + i);
  }
  return ret;
}
DEVICE void Fr_print(Fr a) {
  // printf("0x");
  // for (uint i = 0; i < Fr_LIMBS; i++) {
  //   printf("%08x", a.val[Fr_LIMBS - i - 1]);
  // }
}
DEVICE uint bitreverse(uint n, uint bits) {
  uint r = 0;
  for(int i = 0; i < bits; i++) {
    r = (r << 1) | (n & 1);
    n >>= 1;
  }
  return r;
}
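// Example: bitreverse(0b0011, 4) == 0b1100 (3 -> 12). The FFT below uses this
// to undo the bit-reversed output order of its butterfly network.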
/*
 * FFT algorithm is inspired from: http://www.bealto.com/gpu-fft_group-1.html
 */
KERNEL void radix_fft(GLOBAL Fr* x, // Source buffer
                      GLOBAL Fr* y, // Destination buffer
                      GLOBAL Fr* pq, // Precalculated twiddle factors
                      GLOBAL Fr* omegas, // [omega, omega^2, omega^4, ...]
                      LOCAL Fr* u, // Local buffer to store intermediary values
                      uint n, // Number of elements
                      uint lgp, // Log2 of `p` (Read more in the link above)
                      uint deg, // 1=>radix2, 2=>radix4, 3=>radix8, ...
                      uint max_deg) // Maximum degree supported, according to `pq` and `omegas`
{
  uint lid = GET_LOCAL_ID();
  uint lsize = GET_LOCAL_SIZE();
  uint index = GET_GROUP_ID();
  uint t = n >> deg;
  uint p = 1 << lgp;
  uint k = index & (p - 1);
  x += index;
  y += ((index - k) << deg) + k;
  uint count = 1 << deg; // 2^deg
  uint counth = count >> 1; // Half of count
  uint counts = count / lsize * lid;
  uint counte = counts + count / lsize;
  // Compute powers of twiddle
  const Fr twiddle = Fr_pow_lookup(omegas, (n >> lgp >> deg) * k);
  Fr tmp = Fr_pow(twiddle, counts);
  for(uint i = counts; i < counte; i++) {
    u[i] = Fr_mul(tmp, x[i*t]);
    tmp = Fr_mul(tmp, twiddle);
  }
  BARRIER_LOCAL();
  const uint pqshift = max_deg - deg;
  for(uint rnd = 0; rnd < deg; rnd++) {
    const uint bit = counth >> rnd;
    for(uint i = counts >> 1; i < counte >> 1; i++) {
      const uint di = i & (bit - 1);
      const uint i0 = (i << 1) - di;
      const uint i1 = i0 + bit;
      tmp = u[i0];
      u[i0] = Fr_add(u[i0], u[i1]);
      u[i1] = Fr_sub(tmp, u[i1]);
      if(di != 0) u[i1] = Fr_mul(pq[di << rnd << pqshift], u[i1]);
    }
    BARRIER_LOCAL();
  }
  for(uint i = counts >> 1; i < counte >> 1; i++) {
    y[i*p] = u[bitreverse(i, deg)];
    y[(i+counth)*p] = u[bitreverse(i + counth, deg)];
  }
}
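// One radix_fft launch performs a single radix-2^deg pass over all n
// elements: each work-group loads 2^deg strided elements into the local
// buffer u, scales them by powers of the group's twiddle factor, runs `deg`
// rounds of add/sub butterflies (with per-round twiddles from `pq`), and
// writes the results back in bit-reversed order. A full n-point FFT needs
// multiple passes; presumably the host re-launches the kernel with `lgp`
// advanced by `deg` each time, swapping the x and y buffers between passes.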
/// Multiplies all of the elements by `field`
KERNEL void mul_by_field(GLOBAL Fr* elements,
                         uint n,
                         Fr field) {
  const uint gid = GET_GLOBAL_ID();
  elements[gid] = Fr_mul(elements[gid], field);
}
#define Fq_LIMBS 12
#define Fq_ONE ((Fq){ { 196605, 1980301312, 3289120770, 3958636555, 1405573306, 1598593111, 1884444485, 2010011731, 2723605613, 1543969431, 4202751123, 368467651 } })
#define Fq_P ((Fq){ { 4294945451, 3120496639, 2975072255, 514588670, 4138792484, 1731252896, 4085584575, 1685539716, 1129032919, 1260103606, 964683418, 436277738 } })
#define Fq_R2 ((Fq){ { 473175878, 4108263220, 164693233, 175564454, 1284880085, 2380613484, 2476573632, 1743489193, 3038352685, 2591637125, 2462770090, 295210981 } })
#define Fq_ZERO ((Fq){ { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 } })
#define Fq_INV 4294770685
typedef struct { limb val[Fq_LIMBS]; } Fq;
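// Fq is the 381-bit base field of BLS12-381, stored as 12 little-endian
// 32-bit limbs with the same Montgomery conventions as Fr (here R = 2^384).
// Everything from here down to Fq_print mirrors the Fr code above, with only
// the limb count changed.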
#ifdef CUDA
DEVICE Fq Fq_sub_nvidia(Fq a, Fq b) {
  asm("sub.cc.u32 %0, %0, %12;\r\n"
      "subc.cc.u32 %1, %1, %13;\r\n"
      "subc.cc.u32 %2, %2, %14;\r\n"
      "subc.cc.u32 %3, %3, %15;\r\n"
      "subc.cc.u32 %4, %4, %16;\r\n"
      "subc.cc.u32 %5, %5, %17;\r\n"
      "subc.cc.u32 %6, %6, %18;\r\n"
      "subc.cc.u32 %7, %7, %19;\r\n"
      "subc.cc.u32 %8, %8, %20;\r\n"
      "subc.cc.u32 %9, %9, %21;\r\n"
      "subc.cc.u32 %10, %10, %22;\r\n"
      "subc.u32 %11, %11, %23;\r\n"
      :"+r"(a.val[0]), "+r"(a.val[1]), "+r"(a.val[2]), "+r"(a.val[3]), "+r"(a.val[4]), "+r"(a.val[5]), "+r"(a.val[6]), "+r"(a.val[7]), "+r"(a.val[8]), "+r"(a.val[9]), "+r"(a.val[10]), "+r"(a.val[11])
      :"r"(b.val[0]), "r"(b.val[1]), "r"(b.val[2]), "r"(b.val[3]), "r"(b.val[4]), "r"(b.val[5]), "r"(b.val[6]), "r"(b.val[7]), "r"(b.val[8]), "r"(b.val[9]), "r"(b.val[10]), "r"(b.val[11]));
  return a;
}
DEVICE Fq Fq_add_nvidia(Fq a, Fq b) {
  asm("add.cc.u32 %0, %0, %12;\r\n"
      "addc.cc.u32 %1, %1, %13;\r\n"
      "addc.cc.u32 %2, %2, %14;\r\n"
      "addc.cc.u32 %3, %3, %15;\r\n"
      "addc.cc.u32 %4, %4, %16;\r\n"
      "addc.cc.u32 %5, %5, %17;\r\n"
      "addc.cc.u32 %6, %6, %18;\r\n"
      "addc.cc.u32 %7, %7, %19;\r\n"
      "addc.cc.u32 %8, %8, %20;\r\n"
      "addc.cc.u32 %9, %9, %21;\r\n"
      "addc.cc.u32 %10, %10, %22;\r\n"
      "addc.u32 %11, %11, %23;\r\n"
      :"+r"(a.val[0]), "+r"(a.val[1]), "+r"(a.val[2]), "+r"(a.val[3]), "+r"(a.val[4]), "+r"(a.val[5]), "+r"(a.val[6]), "+r"(a.val[7]), "+r"(a.val[8]), "+r"(a.val[9]), "+r"(a.val[10]), "+r"(a.val[11])
      :"r"(b.val[0]), "r"(b.val[1]), "r"(b.val[2]), "r"(b.val[3]), "r"(b.val[4]), "r"(b.val[5]), "r"(b.val[6]), "r"(b.val[7]), "r"(b.val[8]), "r"(b.val[9]), "r"(b.val[10]), "r"(b.val[11]));
  return a;
}
#endif
// FinalityLabs - 2019
// Arbitrary size prime-field arithmetic library (add, sub, mul, pow)
#define Fq_BITS (Fq_LIMBS * LIMB_BITS)
// Greater than or equal
DEVICE bool Fq_gte(Fq a, Fq b) {
  for(char i = Fq_LIMBS - 1; i >= 0; i--){
    if(a.val[i] > b.val[i])
      return true;
    if(a.val[i] < b.val[i])
      return false;
  }
  return true;
}
// Equals
DEVICE bool Fq_eq(Fq a, Fq b) {
  for(uchar i = 0; i < Fq_LIMBS; i++)
    if(a.val[i] != b.val[i])
      return false;
  return true;
}
// Normal addition
#ifdef CUDA
#define Fq_add_ Fq_add_nvidia
#define Fq_sub_ Fq_sub_nvidia
#else
DEVICE Fq Fq_add_(Fq a, Fq b) {
  bool carry = 0;
  for(uchar i = 0; i < Fq_LIMBS; i++) {
    limb old = a.val[i];
    a.val[i] += b.val[i] + carry;
    carry = carry ? old >= a.val[i] : old > a.val[i];
  }
  return a;
}
DEVICE Fq Fq_sub_(Fq a, Fq b) {
  bool borrow = 0;
  for(uchar i = 0; i < Fq_LIMBS; i++) {
    limb old = a.val[i];
    a.val[i] -= b.val[i] + borrow;
    borrow = borrow ? old <= a.val[i] : old < a.val[i];
  }
  return a;
}
#endif
// Modular subtraction
DEVICE Fq Fq_sub(Fq a, Fq b) {
  Fq res = Fq_sub_(a, b);
  if(!Fq_gte(a, b)) res = Fq_add_(res, Fq_P);
  return res;
}
// Modular addition
DEVICE Fq Fq_add(Fq a, Fq b) {
  Fq res = Fq_add_(a, b);
  if(Fq_gte(res, Fq_P)) res = Fq_sub_(res, Fq_P);
  return res;
}
#ifdef CUDA
DEVICE void Fq_reduce(uint32_t accLow[Fq_LIMBS], uint32_t np0, uint32_t fq[Fq_LIMBS]) {
  // accLow is an IN and OUT vector
  // count must be even
  const uint32_t count = Fq_LIMBS;
  uint32_t accHigh[Fq_LIMBS];
  uint32_t bucket=0, lowCarry=0, highCarry=0, q;
  int32_t i, j;
  #pragma unroll
  for(i=0;i<count;i++)
    accHigh[i]=0;
  // bucket is used so we don't have to push a carry all the way down the line
  #pragma unroll
  for(j=0;j<count;j++) { // main iteration
    if(j%2==0) {
      add_cc(bucket, 0xFFFFFFFF);
      accLow[0]=addc_cc(accLow[0], accHigh[1]);
      bucket=addc(0, 0);
      q=accLow[0]*np0;
      chain_t chain1;
      chain_init(&chain1);
      #pragma unroll
      for(i=0;i<count;i+=2) {
        accLow[i]=chain_madlo(&chain1, q, fq[i], accLow[i]);
        accLow[i+1]=chain_madhi(&chain1, q, fq[i], accLow[i+1]);
      }
      lowCarry=chain_add(&chain1, 0, 0);
      chain_t chain2;
      chain_init(&chain2);
      for(i=0;i<count-2;i+=2) {
        accHigh[i]=chain_madlo(&chain2, q, fq[i+1], accHigh[i+2]); // note the shift down
        accHigh[i+1]=chain_madhi(&chain2, q, fq[i+1], accHigh[i+3]);
      }
      accHigh[i]=chain_madlo(&chain2, q, fq[i+1], highCarry);
      accHigh[i+1]=chain_madhi(&chain2, q, fq[i+1], 0);
    }
    else {
      add_cc(bucket, 0xFFFFFFFF);
      accHigh[0]=addc_cc(accHigh[0], accLow[1]);
      bucket=addc(0, 0);
      q=accHigh[0]*np0;
      chain_t chain3;
      chain_init(&chain3);
      #pragma unroll
      for(i=0;i<count;i+=2) {
        accHigh[i]=chain_madlo(&chain3, q, fq[i], accHigh[i]);
        accHigh[i+1]=chain_madhi(&chain3, q, fq[i], accHigh[i+1]);
      }
      highCarry=chain_add(&chain3, 0, 0);
      chain_t chain4;
      chain_init(&chain4);
      for(i=0;i<count-2;i+=2) {
        accLow[i]=chain_madlo(&chain4, q, fq[i+1], accLow[i+2]); // note the shift down
        accLow[i+1]=chain_madhi(&chain4, q, fq[i+1], accLow[i+3]);
      }
      accLow[i]=chain_madlo(&chain4, q, fq[i+1], lowCarry);
      accLow[i+1]=chain_madhi(&chain4, q, fq[i+1], 0);
    }
  }
  // at this point, accHigh needs to be shifted back a word and added to accLow
  // we'll use one other trick. Bucket is either 0 or 1 at this point, so we
  // can just push it into the carry chain.
  chain_t chain5;
  chain_init(&chain5);
  chain_add(&chain5, bucket, 0xFFFFFFFF); // push the carry into the chain
  #pragma unroll
  for(i=0;i<count-1;i++)
    accLow[i]=chain_add(&chain5, accLow[i], accHigh[i+1]);
  accLow[i]=chain_add(&chain5, accLow[i], highCarry);
}
// Requirement: yLimbs >= xLimbs
DEVICE inline
void Fq_mult_v1(uint32_t *x, uint32_t *y, uint32_t *xy) {
  const uint32_t xLimbs = Fq_LIMBS;
  const uint32_t yLimbs = Fq_LIMBS;
  const uint32_t xyLimbs = Fq_LIMBS * 2;
  uint32_t temp[Fq_LIMBS * 2];
  uint32_t carry = 0;
  #pragma unroll
  for (int32_t i = 0; i < xyLimbs; i++) {
    temp[i] = 0;
  }
  #pragma unroll
  for (int32_t i = 0; i < xLimbs; i++) {
    chain_t chain1;
    chain_init(&chain1);
    #pragma unroll
    for (int32_t j = 0; j < yLimbs; j++) {
      if ((i + j) % 2 == 1) {
        temp[i + j - 1] = chain_madlo(&chain1, x[i], y[j], temp[i + j - 1]);
        temp[i + j] = chain_madhi(&chain1, x[i], y[j], temp[i + j]);
      }
    }
    if (i % 2 == 1) {
      temp[i + yLimbs - 1] = chain_add(&chain1, 0, 0);
    }
  }
  #pragma unroll
  for (int32_t i = xyLimbs - 1; i > 0; i--) {
    temp[i] = temp[i - 1];
  }
  temp[0] = 0;
  #pragma unroll
  for (int32_t i = 0; i < xLimbs; i++) {
    chain_t chain2;
    chain_init(&chain2);
    #pragma unroll
    for (int32_t j = 0; j < yLimbs; j++) {
      if ((i + j) % 2 == 0) {
        temp[i + j] = chain_madlo(&chain2, x[i], y[j], temp[i + j]);
        temp[i + j + 1] = chain_madhi(&chain2, x[i], y[j], temp[i + j + 1]);
      }
    }
    if ((i + yLimbs) % 2 == 0 && i != yLimbs - 1) {
      temp[i + yLimbs] = chain_add(&chain2, temp[i + yLimbs], carry);
      temp[i + yLimbs + 1] = chain_add(&chain2, temp[i + yLimbs + 1], 0);
      carry = chain_add(&chain2, 0, 0);
    }
    if ((i + yLimbs) % 2 == 1 && i != yLimbs - 1) {
      carry = chain_add(&chain2, carry, 0);
    }
  }
  #pragma unroll
  for(int32_t i = 0; i < xyLimbs; i++) {
    xy[i] = temp[i];
  }
}
DEVICE Fq Fq_mul_nvidia(Fq a, Fq b) {
  // Perform full multiply
  // TODO: should not be +1, then addition below should stop one word earlier!
  limb ab[2 * Fq_LIMBS + 1];
  Fq_mult_v1(a.val, b.val, ab);
  uint32_t io[Fq_LIMBS];
  #pragma unroll
  for(int i=0;i<Fq_LIMBS;i++) {
    io[i]=ab[i];
  }
  Fq_reduce(io, Fq_INV, Fq_P.val);
  // Add io to the upper words of ab
  ab[Fq_LIMBS] = add_cc(ab[Fq_LIMBS], io[0]);
  int j;
  #pragma unroll
  for (j = 1; j < Fq_LIMBS; j++) {
    ab[j + Fq_LIMBS] = addc_cc(ab[j + Fq_LIMBS], io[j]);
  }
  // Capture the final carry only; io has just Fq_LIMBS words, so reading
  // io[Fq_LIMBS] here would be out of bounds.
  ab[2 * Fq_LIMBS] = addc(ab[2 * Fq_LIMBS], 0);
  Fq r;
  #pragma unroll
  for (int i = 0; i < Fq_LIMBS; i++) {
    r.val[i] = ab[i + Fq_LIMBS];
  }
  if (Fq_gte(r, Fq_P)) {
    r = Fq_sub_(r, Fq_P);
  }
  return r;
}
#endif
// Modular multiplication
DEVICE Fq Fq_mul_default(Fq a, Fq b) {
  /* CIOS Montgomery multiplication, inspired from Tolga Acar's thesis:
   * https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf
   * Learn more:
   * https://en.wikipedia.org/wiki/Montgomery_modular_multiplication
   * https://alicebob.cryptoland.net/understanding-the-montgomery-reduction-algorithm/
   */
  limb t[Fq_LIMBS + 2] = {0};
  for(uchar i = 0; i < Fq_LIMBS; i++) {
    limb carry = 0;
    for(uchar j = 0; j < Fq_LIMBS; j++)
      t[j] = mac_with_carry(a.val[j], b.val[i], t[j], &carry);
    t[Fq_LIMBS] = add_with_carry(t[Fq_LIMBS], &carry);
    t[Fq_LIMBS + 1] = carry;
    carry = 0;
    limb m = Fq_INV * t[0];
    mac_with_carry(m, Fq_P.val[0], t[0], &carry);
    for(uchar j = 1; j < Fq_LIMBS; j++)
      t[j - 1] = mac_with_carry(m, Fq_P.val[j], t[j], &carry);
    t[Fq_LIMBS - 1] = add_with_carry(t[Fq_LIMBS], &carry);
    t[Fq_LIMBS] = t[Fq_LIMBS + 1] + carry;
  }
  Fq result;
  for(uchar i = 0; i < Fq_LIMBS; i++) result.val[i] = t[i];
  if(Fq_gte(result, Fq_P)) result = Fq_sub_(result, Fq_P);
  return result;
}
#ifdef CUDA
DEVICE Fq Fq_mul(Fq a, Fq b) {
  return Fq_mul_nvidia(a, b);
}
#else
DEVICE Fq Fq_mul(Fq a, Fq b) {
  return Fq_mul_default(a, b);
}
#endif
// Squaring is a special case of multiplication which can be done ~1.5x faster.
// https://stackoverflow.com/a/16388571/1348497
DEVICE Fq Fq_sqr(Fq a) {
  return Fq_mul(a, a);
}
// Left-shift the limbs by one bit and subtract the modulus if the result
// exceeds it. Faster version of Fq_add(a, a).
DEVICE Fq Fq_double(Fq a) {
  for(uchar i = Fq_LIMBS - 1; i >= 1; i--)
    a.val[i] = (a.val[i] << 1) | (a.val[i - 1] >> (LIMB_BITS - 1));
  a.val[0] <<= 1;
  if(Fq_gte(a, Fq_P)) a = Fq_sub_(a, Fq_P);
  return a;
}
// Modular exponentiation (Exponentiation by Squaring)
// https://en.wikipedia.org/wiki/Exponentiation_by_squaring
DEVICE Fq Fq_pow(Fq base, uint exponent) {
  Fq res = Fq_ONE;
  while(exponent > 0) {
    if (exponent & 1)
      res = Fq_mul(res, base);
    exponent = exponent >> 1;
    base = Fq_sqr(base);
  }
  return res;
}
// Store squares of the base in a lookup table for faster evaluation.
DEVICE Fq Fq_pow_lookup(GLOBAL Fq *bases, uint exponent) {
  Fq res = Fq_ONE;
  uint i = 0;
  while(exponent > 0) {
    if (exponent & 1)
      res = Fq_mul(res, bases[i]);
    exponent = exponent >> 1;
    i++;
  }
  return res;
}
DEVICE Fq Fq_mont(Fq a) {
  return Fq_mul(a, Fq_R2);
}
DEVICE Fq Fq_unmont(Fq a) {
  Fq one = Fq_ZERO;
  one.val[0] = 1;
  return Fq_mul(a, one);
}
// Get the `i`th bit (counting from the most significant bit) of the field element.
DEVICE bool Fq_get_bit(Fq l, uint i) {
  return (l.val[Fq_LIMBS - 1 - i / LIMB_BITS] >> (LIMB_BITS - 1 - (i % LIMB_BITS))) & 1;
}
// Get `window` consecutive bits (starting from the `skip`th bit) of the field element.
DEVICE uint Fq_get_bits(Fq l, uint skip, uint window) {
  uint ret = 0;
  for(uint i = 0; i < window; i++) {
    ret <<= 1;
    ret |= Fq_get_bit(l, skip + i);
  }
  return ret;
}
DEVICE void Fq_print(Fq a) {
  // printf("0x");
  // for (uint i = 0; i < Fq_LIMBS; i++) {
  //   printf("%08x", a.val[Fq_LIMBS - i - 1]);
  // }
}
// Elliptic curve operations (Short Weierstrass Jacobian form)
#define G1_ZERO ((G1_projective){Fq_ZERO, Fq_ONE, Fq_ZERO})
typedef struct {
  Fq x;
  Fq y;
} G1_affine;
typedef struct {
  Fq x;
  Fq y;
  Fq z;
} G1_projective;
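// Jacobian coordinates: (X, Y, Z) represents the affine point
// (X / Z^2, Y / Z^3), and Z = 0 encodes the point at infinity (hence the
// G1_ZERO definition above). All coordinates are Fq values in Montgomery form.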
// http://www.hyperelliptic.org/EFD/g1p/auto-shortw-jacobian-0.html#doubling-dbl-2009-l
DEVICE G1_projective G1_double(G1_projective inp) {
  const Fq local_zero = Fq_ZERO;
  if(Fq_eq(inp.z, local_zero)) {
    return inp;
  }
  const Fq a = Fq_sqr(inp.x); // A = X1^2
  const Fq b = Fq_sqr(inp.y); // B = Y1^2
  Fq c = Fq_sqr(b); // C = B^2
  // D = 2*((X1+B)^2-A-C)
  Fq d = Fq_add(inp.x, b);
  d = Fq_sqr(d); d = Fq_sub(Fq_sub(d, a), c); d = Fq_double(d);
  const Fq e = Fq_add(Fq_double(a), a); // E = 3*A
  const Fq f = Fq_sqr(e);
  inp.z = Fq_mul(inp.y, inp.z); inp.z = Fq_double(inp.z); // Z3 = 2*Y1*Z1
  inp.x = Fq_sub(Fq_sub(f, d), d); // X3 = F-2*D
  // Y3 = E*(D-X3)-8*C
  c = Fq_double(c); c = Fq_double(c); c = Fq_double(c);
  inp.y = Fq_sub(Fq_mul(Fq_sub(d, inp.x), e), c);
  return inp;
}
// http://www.hyperelliptic.org/EFD/g1p/auto-shortw-jacobian-0.html#addition-madd-2007-bl
DEVICE G1_projective G1_add_mixed(G1_projective a, G1_affine b) {
  const Fq local_zero = Fq_ZERO;
  if(Fq_eq(a.z, local_zero)) {
    const Fq local_one = Fq_ONE;
    a.x = b.x;
    a.y = b.y;
    a.z = local_one;
    return a;
  }
  const Fq z1z1 = Fq_sqr(a.z);
  const Fq u2 = Fq_mul(b.x, z1z1);
  const Fq s2 = Fq_mul(Fq_mul(b.y, a.z), z1z1);
  if(Fq_eq(a.x, u2) && Fq_eq(a.y, s2)) {
    return G1_double(a);
  }
  const Fq h = Fq_sub(u2, a.x); // H = U2-X1
  const Fq hh = Fq_sqr(h); // HH = H^2
  Fq i = Fq_double(hh); i = Fq_double(i); // I = 4*HH
  Fq j = Fq_mul(h, i); // J = H*I
  Fq r = Fq_sub(s2, a.y); r = Fq_double(r); // r = 2*(S2-Y1)
  const Fq v = Fq_mul(a.x, i);
  G1_projective ret;
  // X3 = r^2 - J - 2*V
  ret.x = Fq_sub(Fq_sub(Fq_sqr(r), j), Fq_double(v));
  // Y3 = r*(V-X3)-2*Y1*J
  j = Fq_mul(a.y, j); j = Fq_double(j);
  ret.y = Fq_sub(Fq_mul(Fq_sub(v, ret.x), r), j);
  // Z3 = (Z1+H)^2-Z1Z1-HH
  ret.z = Fq_add(a.z, h); ret.z = Fq_sub(Fq_sub(Fq_sqr(ret.z), z1z1), hh);
  return ret;
}
// http://www.hyperelliptic.org/EFD/g1p/auto-shortw-jacobian-0.html#addition-add-2007-bl
DEVICE G1_projective G1_add(G1_projective a, G1_projective b) {
  const Fq local_zero = Fq_ZERO;
  if(Fq_eq(a.z, local_zero)) return b;
  if(Fq_eq(b.z, local_zero)) return a;
  const Fq z1z1 = Fq_sqr(a.z); // Z1Z1 = Z1^2
  const Fq z2z2 = Fq_sqr(b.z); // Z2Z2 = Z2^2
  const Fq u1 = Fq_mul(a.x, z2z2); // U1 = X1*Z2Z2
  const Fq u2 = Fq_mul(b.x, z1z1); // U2 = X2*Z1Z1
  Fq s1 = Fq_mul(Fq_mul(a.y, b.z), z2z2); // S1 = Y1*Z2*Z2Z2
  const Fq s2 = Fq_mul(Fq_mul(b.y, a.z), z1z1); // S2 = Y2*Z1*Z1Z1
  if(Fq_eq(u1, u2) && Fq_eq(s1, s2))
    return G1_double(a);
  else {
    const Fq h = Fq_sub(u2, u1); // H = U2-U1
    Fq i = Fq_double(h); i = Fq_sqr(i); // I = (2*H)^2
    const Fq j = Fq_mul(h, i); // J = H*I
    Fq r = Fq_sub(s2, s1); r = Fq_double(r); // r = 2*(S2-S1)
    const Fq v = Fq_mul(u1, i); // V = U1*I
    a.x = Fq_sub(Fq_sub(Fq_sub(Fq_sqr(r), j), v), v); // X3 = r^2 - J - 2*V
    // Y3 = r*(V - X3) - 2*S1*J
    a.y = Fq_mul(Fq_sub(v, a.x), r);
    s1 = Fq_mul(s1, j); s1 = Fq_double(s1); // S1 = S1 * J * 2
    a.y = Fq_sub(a.y, s1);
    // Z3 = ((Z1+Z2)^2 - Z1Z1 - Z2Z2)*H
    a.z = Fq_add(a.z, b.z); a.z = Fq_sqr(a.z);
    a.z = Fq_sub(Fq_sub(a.z, z1z1), z2z2);
    a.z = Fq_mul(a.z, h);
    return a;
  }
}
/*
 * Same multiexp algorithm used in Bellman, with some modifications.
 * https://github.com/zkcrypto/bellman/blob/10c5010fd9c2ca69442dc9775ea271e286e776d8/src/multiexp.rs#L174
 * The CPU version parallelizes multiexp by dividing the exponent values into
 * smaller windows and then applying a sequence of rounds to each window.
 * The GPU kernel not only assigns a thread to each window but also divides
 * the bases into several groups, which greatly increases the number of
 * threads running in parallel for a single multiexp instance.
 */
KERNEL void G1_bellman_multiexp(
    GLOBAL G1_affine *bases,
    GLOBAL G1_projective *buckets,
    GLOBAL G1_projective *results,
    GLOBAL Fr *exps,
    uint n,
    uint num_groups,
    uint num_windows,
    uint window_size) {
  // We have `num_windows` * `num_groups` threads per multiexp.
  const uint gid = GET_GLOBAL_ID();
  if(gid >= num_windows * num_groups) return;
  // We have (2^window_size - 1) buckets.
  const uint bucket_len = ((1 << window_size) - 1);
  // Each thread has its own set of buckets in global memory.
  buckets += bucket_len * gid;
  const G1_projective local_zero = G1_ZERO;
  for(uint i = 0; i < bucket_len; i++) buckets[i] = local_zero;
  const uint len = (uint)ceil(n / (float)num_groups); // Num of elements in each group
  // This thread runs the multiexp algorithm on elements from `nstart` to `nend`
  // on the window [`bits`, `bits` + `w`)
  const uint nstart = len * (gid / num_windows);
  const uint nend = min(nstart + len, n);
  const uint bits = (gid % num_windows) * window_size;
  const ushort w = min((ushort)window_size, (ushort)(Fr_BITS - bits));
  G1_projective res = G1_ZERO;
  for(uint i = nstart; i < nend; i++) {
    uint ind = Fr_get_bits(exps[i], bits, w);
#ifdef OPENCL_NVIDIA
    // O_o, weird optimization: having a single special case makes it
    // tremendously faster!
    // 511 is chosen because it's half of the maximum bucket len, but
    // any other number works... Bigger indices seem to be better...
    if(ind == 511) buckets[510] = G1_add_mixed(buckets[510], bases[i]);
    else if(ind--) buckets[ind] = G1_add_mixed(buckets[ind], bases[i]);
#else
    if(ind--) buckets[ind] = G1_add_mixed(buckets[ind], bases[i]);
#endif
  }
  // Summation by parts
  // e.g. 3a + 2b + 1c = a +
  //                     (a) + b +
  //                     ((a) + b) + c
  G1_projective acc = G1_ZERO;
  for(int j = bucket_len - 1; j >= 0; j--) {
    acc = G1_add(acc, buckets[j]);
    res = G1_add(res, acc);
  }
  results[gid] = res;
}
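// This is Pippenger's bucket method. Example: window_size = 8 gives
// bucket_len = 255 buckets per thread and, with Fr_BITS = 256, num_windows =
// 32. Each thread sorts its slice of bases into buckets by the window's
// digit, and the "summation by parts" loop then yields sum(j * bucket[j])
// using only point additions. The per-window partial results written to
// `results` still have to be combined into the final sum, presumably on the
// host: result = sum over windows w of 2^(w * window_size) * result[w],
// evaluated from the top window down by repeated doubling.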
// Fq2 extension field where u^2 + 1 = 0
#define Fq2_LIMB_BITS LIMB_BITS
#define Fq2_ZERO ((Fq2){Fq_ZERO, Fq_ZERO})
#define Fq2_ONE ((Fq2){Fq_ONE, Fq_ZERO})
typedef struct {
  Fq c0;
  Fq c1;
} Fq2; // Represents: c0 + u * c1
DEVICE bool Fq2_eq(Fq2 a, Fq2 b) {
  return Fq_eq(a.c0, b.c0) && Fq_eq(a.c1, b.c1);
}
DEVICE Fq2 Fq2_sub(Fq2 a, Fq2 b) {
  a.c0 = Fq_sub(a.c0, b.c0);
  a.c1 = Fq_sub(a.c1, b.c1);
  return a;
}
DEVICE Fq2 Fq2_add(Fq2 a, Fq2 b) {
  a.c0 = Fq_add(a.c0, b.c0);
  a.c1 = Fq_add(a.c1, b.c1);
  return a;
}
DEVICE Fq2 Fq2_double(Fq2 a) {
  a.c0 = Fq_double(a.c0);
  a.c1 = Fq_double(a.c1);
  return a;
}
/*
 * (a_0 + u * a_1)(b_0 + u * b_1) = a_0 * b_0 - a_1 * b_1 + u * (a_0 * b_1 + a_1 * b_0)
 * Therefore:
 * c_0 = a_0 * b_0 - a_1 * b_1
 * c_1 = (a_0 * b_1 + a_1 * b_0) = (a_0 + a_1) * (b_0 + b_1) - a_0 * b_0 - a_1 * b_1
 */
DEVICE Fq2 Fq2_mul(Fq2 a, Fq2 b) {
  const Fq aa = Fq_mul(a.c0, b.c0);
  const Fq bb = Fq_mul(a.c1, b.c1);
  const Fq o = Fq_add(b.c0, b.c1);
  a.c1 = Fq_add(a.c1, a.c0);
  a.c1 = Fq_mul(a.c1, o);
  a.c1 = Fq_sub(a.c1, aa);
  a.c1 = Fq_sub(a.c1, bb);
  a.c0 = Fq_sub(aa, bb);
  return a;
}
/*
 * (a_0 + u * a_1)(a_0 + u * a_1) = a_0 ^ 2 - a_1 ^ 2 + u * 2 * a_0 * a_1
 * Therefore:
 * c_0 = (a_0 * a_0 - a_1 * a_1) = (a_0 + a_1)(a_0 - a_1)
 * c_1 = 2 * a_0 * a_1
 */
DEVICE Fq2 Fq2_sqr(Fq2 a) {
  const Fq ab = Fq_mul(a.c0, a.c1);
  const Fq c0c1 = Fq_add(a.c0, a.c1);
  a.c0 = Fq_mul(Fq_sub(a.c0, a.c1), c0c1);
  a.c1 = Fq_double(ab);
  return a;
}
// Elliptic curve operations (Short Weierstrass Jacobian form)
#define G2_ZERO ((G2_projective){Fq2_ZERO, Fq2_ONE, Fq2_ZERO})
typedef struct {
  Fq2 x;
  Fq2 y;
} G2_affine;
typedef struct {
  Fq2 x;
  Fq2 y;
  Fq2 z;
} G2_projective;
// http://www.hyperelliptic.org/EFD/g1p/auto-shortw-jacobian-0.html#doubling-dbl-2009-l
DEVICE G2_projective G2_double(G2_projective inp) {
  const Fq2 local_zero = Fq2_ZERO;
  if(Fq2_eq(inp.z, local_zero)) {
    return inp;
  }
  const Fq2 a = Fq2_sqr(inp.x); // A = X1^2
  const Fq2 b = Fq2_sqr(inp.y); // B = Y1^2
  Fq2 c = Fq2_sqr(b); // C = B^2
  // D = 2*((X1+B)^2-A-C)
  Fq2 d = Fq2_add(inp.x, b);
  d = Fq2_sqr(d); d = Fq2_sub(Fq2_sub(d, a), c); d = Fq2_double(d);
  const Fq2 e = Fq2_add(Fq2_double(a), a); // E = 3*A
  const Fq2 f = Fq2_sqr(e);
  inp.z = Fq2_mul(inp.y, inp.z); inp.z = Fq2_double(inp.z); // Z3 = 2*Y1*Z1
  inp.x = Fq2_sub(Fq2_sub(f, d), d); // X3 = F-2*D
  // Y3 = E*(D-X3)-8*C
  c = Fq2_double(c); c = Fq2_double(c); c = Fq2_double(c);
  inp.y = Fq2_sub(Fq2_mul(Fq2_sub(d, inp.x), e), c);
  return inp;
}
// http://www.hyperelliptic.org/EFD/g1p/auto-shortw-jacobian-0.html#addition-madd-2007-bl
DEVICE G2_projective G2_add_mixed(G2_projective a, G2_affine b) {
  const Fq2 local_zero = Fq2_ZERO;
  if(Fq2_eq(a.z, local_zero)) {
    const Fq2 local_one = Fq2_ONE;
    a.x = b.x;
    a.y = b.y;
    a.z = local_one;
    return a;
  }
  const Fq2 z1z1 = Fq2_sqr(a.z);
  const Fq2 u2 = Fq2_mul(b.x, z1z1);
  const Fq2 s2 = Fq2_mul(Fq2_mul(b.y, a.z), z1z1);
  if(Fq2_eq(a.x, u2) && Fq2_eq(a.y, s2)) {
    return G2_double(a);
  }
  const Fq2 h = Fq2_sub(u2, a.x); // H = U2-X1
  const Fq2 hh = Fq2_sqr(h); // HH = H^2
  Fq2 i = Fq2_double(hh); i = Fq2_double(i); // I = 4*HH
  Fq2 j = Fq2_mul(h, i); // J = H*I
  Fq2 r = Fq2_sub(s2, a.y); r = Fq2_double(r); // r = 2*(S2-Y1)
  const Fq2 v = Fq2_mul(a.x, i);
  G2_projective ret;
  // X3 = r^2 - J - 2*V
  ret.x = Fq2_sub(Fq2_sub(Fq2_sqr(r), j), Fq2_double(v));
  // Y3 = r*(V-X3)-2*Y1*J
  j = Fq2_mul(a.y, j); j = Fq2_double(j);
  ret.y = Fq2_sub(Fq2_mul(Fq2_sub(v, ret.x), r), j);
  // Z3 = (Z1+H)^2-Z1Z1-HH
  ret.z = Fq2_add(a.z, h); ret.z = Fq2_sub(Fq2_sub(Fq2_sqr(ret.z), z1z1), hh);
  return ret;
}
// http://www.hyperelliptic.org/EFD/g1p/auto-shortw-jacobian-0.html#addition-add-2007-bl
DEVICE G2_projective G2_add(G2_projective a, G2_projective b) {
  const Fq2 local_zero = Fq2_ZERO;
  if(Fq2_eq(a.z, local_zero)) return b;
  if(Fq2_eq(b.z, local_zero)) return a;
  const Fq2 z1z1 = Fq2_sqr(a.z); // Z1Z1 = Z1^2
  const Fq2 z2z2 = Fq2_sqr(b.z); // Z2Z2 = Z2^2
  const Fq2 u1 = Fq2_mul(a.x, z2z2); // U1 = X1*Z2Z2
  const Fq2 u2 = Fq2_mul(b.x, z1z1); // U2 = X2*Z1Z1
  Fq2 s1 = Fq2_mul(Fq2_mul(a.y, b.z), z2z2); // S1 = Y1*Z2*Z2Z2
  const Fq2 s2 = Fq2_mul(Fq2_mul(b.y, a.z), z1z1); // S2 = Y2*Z1*Z1Z1
  if(Fq2_eq(u1, u2) && Fq2_eq(s1, s2))
    return G2_double(a);
  else {
    const Fq2 h = Fq2_sub(u2, u1); // H = U2-U1
    Fq2 i = Fq2_double(h); i = Fq2_sqr(i); // I = (2*H)^2
    const Fq2 j = Fq2_mul(h, i); // J = H*I
    Fq2 r = Fq2_sub(s2, s1); r = Fq2_double(r); // r = 2*(S2-S1)
    const Fq2 v = Fq2_mul(u1, i); // V = U1*I
    a.x = Fq2_sub(Fq2_sub(Fq2_sub(Fq2_sqr(r), j), v), v); // X3 = r^2 - J - 2*V
    // Y3 = r*(V - X3) - 2*S1*J
    a.y = Fq2_mul(Fq2_sub(v, a.x), r);
    s1 = Fq2_mul(s1, j); s1 = Fq2_double(s1); // S1 = S1 * J * 2
    a.y = Fq2_sub(a.y, s1);
    // Z3 = ((Z1+Z2)^2 - Z1Z1 - Z2Z2)*H
    a.z = Fq2_add(a.z, b.z); a.z = Fq2_sqr(a.z);
    a.z = Fq2_sub(Fq2_sub(a.z, z1z1), z2z2);
    a.z = Fq2_mul(a.z, h);
    return a;
  }
}
/*
 * Same multiexp algorithm used in Bellman, with some modifications.
 * https://github.com/zkcrypto/bellman/blob/10c5010fd9c2ca69442dc9775ea271e286e776d8/src/multiexp.rs#L174
 * The CPU version parallelizes multiexp by dividing the exponent values into
 * smaller windows and then applying a sequence of rounds to each window.
 * The GPU kernel not only assigns a thread to each window but also divides
 * the bases into several groups, which greatly increases the number of
 * threads running in parallel for a single multiexp instance.
 */
KERNEL void G2_bellman_multiexp(
    GLOBAL G2_affine *bases,
    GLOBAL G2_projective *buckets,
    GLOBAL G2_projective *results,
    GLOBAL Fr *exps,
    uint n,
    uint num_groups,
    uint num_windows,
    uint window_size) {
  // We have `num_windows` * `num_groups` threads per multiexp.
  const uint gid = GET_GLOBAL_ID();
  if(gid >= num_windows * num_groups) return;
  // We have (2^window_size - 1) buckets.
  const uint bucket_len = ((1 << window_size) - 1);
  // Each thread has its own set of buckets in global memory.
  buckets += bucket_len * gid;
  const G2_projective local_zero = G2_ZERO;
  for(uint i = 0; i < bucket_len; i++) buckets[i] = local_zero;
  const uint len = (uint)ceil(n / (float)num_groups); // Num of elements in each group
  // This thread runs the multiexp algorithm on elements from `nstart` to `nend`
  // on the window [`bits`, `bits` + `w`)
  const uint nstart = len * (gid / num_windows);
  const uint nend = min(nstart + len, n);
  const uint bits = (gid % num_windows) * window_size;
  const ushort w = min((ushort)window_size, (ushort)(Fr_BITS - bits));
  G2_projective res = G2_ZERO;
  for(uint i = nstart; i < nend; i++) {
    uint ind = Fr_get_bits(exps[i], bits, w);
#ifdef OPENCL_NVIDIA
    // O_o, weird optimization: having a single special case makes it
    // tremendously faster!
    // 511 is chosen because it's half of the maximum bucket len, but
    // any other number works... Bigger indices seem to be better...
    if(ind == 511) buckets[510] = G2_add_mixed(buckets[510], bases[i]);
    else if(ind--) buckets[ind] = G2_add_mixed(buckets[ind], bases[i]);
#else
    if(ind--) buckets[ind] = G2_add_mixed(buckets[ind], bases[i]);
#endif
  }
  // Summation by parts
  // e.g. 3a + 2b + 1c = a +
  //                     (a) + b +
  //                     ((a) + b) + c
  G2_projective acc = G2_ZERO;
  for(int j = bucket_len - 1; j >= 0; j--) {
    acc = G2_add(acc, buckets[j]);
    res = G2_add(res, acc);
  }
  results[gid] = res;
}