Created April 1, 2020 13:58
Save zac-williamson/6c0e08db3f0621d9e3d3f0264403b9f3 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// requires barretenberg numeric dependency | |
#include <inttypes.h>
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include <x86intrin.h>

#include <algorithm>
#include <iostream>

#include <numeric/random/engine.hpp>
namespace { | |
auto& engine = numeric::random::get_debug_engine(); | |
} | |
//#include <float.h> | |
// A pair of floats that the helpers in this file treat as the unevaluated sum
// hi + lo. That gives roughly twice the precision of a single float.
struct doublefloat {
    float hi; // leading component
    float lo; // trailing (error) component
};
// Fast two-sum (Dekker): returns { s, e } with s = fl(a + b) and e the rounding
// error of that addition. Under the precondition, s + e == a + b exactly.
//
// Precondition: |a| >= |b|, and round-to-nearest float arithmetic without
// overflow. If |a| < |b|, e can be wrong; use two_sum() when the ordering of
// the operands is unknown.
// Cost: 3 floating-point additions.
doublefloat quick_two_sum(float a, float b)
{
    const float s = a + b;
    // Given |a| >= |b|, (s - a) is computed exactly, so subtracting it from b
    // recovers exactly the part of b that was rounded away in s.
    const float e = b - (s - a);
    // Plain aggregate initialisation. A C-style compound literal
    // "(doublefloat){ ... }" is a GNU extension, not ISO C++.
    return doublefloat{ s, e };
}
// Error-free sum (Knuth's TwoSum): returns { s, e } with s = fl(a + b) and e
// the exact rounding error, so s + e == a + b. This holds for any signs and
// any ordering of a and b, barring overflow.
//
// Unlike quick_two_sum(), it needs no |a| >= |b| precondition. Ordering the
// operands with std::max/std::min is not a substitute, because they compare
// values rather than magnitudes. For example, (1, -1e8) would lose the error
// term entirely.
// Branch-free; 6 floating-point additions.
doublefloat two_sum(float a, float b)
{
    const float s = a + b;
    // The parts of s that came from b and from a, respectively.
    const float b_virtual = s - a;
    const float a_virtual = s - b_virtual;
    // What each operand lost to rounding; their sum is the exact error.
    const float b_roundoff = b - b_virtual;
    const float a_roundoff = a - a_virtual;
    return doublefloat{ s, a_roundoff + b_roundoff };
}
// Adds two double-floats ("sloppy" double-double addition) and returns the
// renormalised result.
//
// r = fl(a.hi + b.hi), and s gathers r's rounding error plus both low parts.
// The final quick_two_sum() requires its first argument to have the larger
// magnitude, so the pair is ordered by fabsf(). Ordering by value
// (std::max/std::min) would pass a small positive s ahead of a large
// negative r and silently drop the low word.
// Cost: about 15 floating-point operations.
doublefloat add_doublefloats(doublefloat a, doublefloat b)
{
    const float r = a.hi + b.hi;
    const bool a_dominates = fabsf(a.hi) > fabsf(b.hi);
    const doublefloat& big = a_dominates ? a : b;
    const doublefloat& small = a_dominates ? b : a;
    // (big.hi - r) + small.hi is the exact error of r, because |big.hi| >= |small.hi|.
    const float s = ((big.hi - r + small.hi) + small.lo) + big.lo;
    return fabsf(r) >= fabsf(s) ? quick_two_sum(r, s) : quick_two_sum(s, r);
}
// Dekker/Veltkamp split: returns { hi, lo } with hi + lo == a. hi holds the
// upper half of a's 24-bit significand and lo the remainder, so the product of
// any two halves fits exactly in a float.
// Note: splitter * a overflows for |a| near FLT_MAX / 4097.
// Cost: 4 floating-point operations.
doublefloat split(float a)
{
    // 2^12 + 1, i.e. 2^ceil(24 / 2) + 1 for a 24-bit float significand.
    constexpr float splitter = static_cast<float>((1 << 12) + 1);
    const float t = splitter * a;
    const float hi = t - (t - a);
    const float lo = a - hi;
    // ISO C++ aggregate initialisation (no GNU compound literal).
    return doublefloat{ hi, lo };
}
// Error-free product (Dekker's TwoProduct without FMA): returns { p, e } with
// p = fl(a * b) and e its rounding error.
// Both operands are split into halves with split(), so every partial product
// below is exact in float. The subtraction order matters for exactness.
// Cost: about 17 floating-point operations.
doublefloat two_product(float a, float b)
{
    const float product = a * b;
    const doublefloat a_parts = split(a);
    const doublefloat b_parts = split(b);

    // Remove the exact partial products from the rounded product, one at a time.
    float residual = product - (a_parts.hi * b_parts.hi);
    residual = residual - (a_parts.lo * b_parts.hi);
    residual = residual - (a_parts.hi * b_parts.lo);

    const float error = a_parts.lo * b_parts.lo - residual;
    return doublefloat{ product, error };
}
void mac(float a, float b, float c, float carry_in, float& out, float& carry_out) | |
{ | |
auto res = two_product(b, c); // 17 | |
const auto t0 = two_sum(a, carry_in); // 5 | |
res = add_doublefloats(res, t0); // 15 | |
if (res.lo < 0) { | |
res.lo += (float)(1ULL << 23ULL); | |
res.hi -= (float)(1ULL << 23ULL); | |
} // 4 | |
float x = res.hi + (float)(1ULL << 46ULL); | |
float u = x - (float)(1ULL << 46ULL); | |
float v = res.hi - u; | |
u = u * ((float)1 / (float)(1ULL << 23ULL)); // 4 | |
if (v < 0) { | |
v += (float)(1ULL << 23ULL); | |
u -= (float)1; | |
} // 4 | |
out = v + res.lo; | |
if (out >= (float)(1ULL << 23ULL)) { | |
res.lo -= (float)(1ULL << 23ULL); | |
out = v + res.lo; | |
u += 1; | |
} // 5 | |
carry_out = u; | |
// 54 ops | |
} | |
void mul_test(float* a, float* b, float* out) | |
{ | |
for (size_t i = 0; i < 24; ++i) { | |
out[i] = 0; | |
} | |
float carry2 = 0; | |
for (size_t i = 0; i < 12; ++i) { | |
float carry = 0; | |
for (size_t j = 0; j < 11; ++j) { | |
mac(out[i + j], a[i], b[j], carry, out[i + j], carry); | |
} | |
mac(carry2, a[i], b[11], carry, out[i + 11], carry2); | |
} | |
out[23] = carry2; | |
} | |
void convert_into_floats(uint256_t& input, float* output) | |
{ | |
constexpr uint64_t bit_mask = (1UL << 23UL) - 1UL; | |
output[0] = (float)(input.data[0] & bit_mask); | |
output[1] = (float)((input.data[0] >> 23) & bit_mask); | |
output[2] = (float)((input.data[0] >> 46) + ((input.data[1] & ((1ULL << 5ULL) - 1ULL)) << 18)); | |
output[3] = (float)((input.data[1] >> 5) & bit_mask); | |
output[4] = (float)((input.data[1] >> 28) & bit_mask); | |
output[5] = (float)((input.data[1] >> 51) + ((input.data[2] & ((1ULL << 10ULL) - 1ULL)) << 13)); | |
output[6] = (float)((input.data[2] >> 10) & bit_mask); | |
output[7] = (float)((input.data[2] >> 33) & bit_mask); | |
output[8] = (float)((input.data[2] >> 56) + ((input.data[3] & ((1ULL << 15ULL) - 1ULL)) << 8)); | |
output[9] = (float)((input.data[3] >> 15) & bit_mask); | |
output[10] = (float)((input.data[3] >> 38) & bit_mask); | |
output[11] = (float)((input.data[3] >> 61)); | |
} | |
void convert_into_ints(float* input, uint512_t& output) | |
{ | |
uint64_t t0 = (uint64_t)(input[0]) + ((uint64_t)(input[1]) << 23ULL); | |
t0 += (((uint64_t)input[2]) << 46ULL); | |
uint64_t t1 = ((uint64_t)(input[2]) >> 18ULL); | |
t1 += ((uint64_t)(input[3]) << 5ULL); | |
t1 += ((uint64_t)(input[4]) << 28ULL); | |
t1 += ((uint64_t)(input[5]) << 51ULL); | |
uint64_t t2 = ((uint64_t)(input[5]) >> 13ULL); | |
t2 += ((uint64_t)(input[6]) << 10ULL); | |
t2 += ((uint64_t)(input[7]) << 33ULL); | |
t2 += ((uint64_t)(input[8]) << 56ULL); | |
uint64_t t3 = ((uint64_t)(input[8]) >> 8ULL); | |
t3 += ((uint64_t)(input[9]) << 15ULL); | |
t3 += ((uint64_t)(input[10]) << 38ULL); | |
t3 += ((uint64_t)(input[11]) << 61ULL); | |
uint64_t t4 = ((uint64_t)(input[11]) >> 3ULL); | |
t4 += ((uint64_t)(input[12]) << 20ULL); | |
t4 += ((uint64_t)(input[13]) << 43ULL); | |
uint64_t t5 = ((uint64_t)(input[13]) >> 21ULL); | |
t5 += ((uint64_t)(input[14]) << 2ULL); | |
t5 += ((uint64_t)(input[15]) << 25ULL); | |
t5 += ((uint64_t)(input[16]) << 48ULL); | |
uint64_t t6 = ((uint64_t)(input[16]) >> 16ULL); | |
t6 += ((uint64_t)(input[17]) << 7ULL); | |
t6 += ((uint64_t)(input[18]) << 30ULL); | |
t6 += ((uint64_t)(input[19]) << 53ULL); | |
uint64_t t7 = ((uint64_t)(input[19]) >> 11ULL); | |
t7 += ((uint64_t)(input[20]) << 12ULL); | |
t7 += ((uint64_t)(input[21]) << 35ULL); | |
t7 += ((uint64_t)(input[22]) << 58ULL); | |
output.lo.data[0] = t0; | |
output.lo.data[1] = t1; | |
output.lo.data[2] = t2; | |
output.lo.data[3] = t3; | |
output.hi.data[0] = t4; | |
output.hi.data[1] = t5; | |
output.hi.data[2] = t6; | |
output.hi.data[3] = t7; | |
} | |
int main(void) | |
{ | |
bool valid = true; | |
std::cout << "testing 1,000 256x256->512 bit muls" << std::endl; | |
for (size_t i = 0; i < 1000; ++i) { | |
uint256_t left = engine.get_random_uint256(); | |
uint256_t right = engine.get_random_uint256(); | |
uint512_t expected = uint512_t(left) * uint512_t(right); | |
uint512_t result; | |
float left_floats[12]; | |
float right_floats[12]; | |
float output_floats[24]; | |
convert_into_floats(left, left_floats); | |
convert_into_floats(right, right_floats); | |
mul_test(left_floats, right_floats, output_floats); | |
convert_into_ints(output_floats, result); | |
if (result != expected) { | |
valid = false; | |
} | |
} | |
if (valid) { | |
std::cout << "pass" << std::endl; | |
} else { | |
std::cout << "fail" << std::endl; | |
} | |
} |
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.