Created
November 19, 2024 17:20
-
-
Save FawadHa1der/6a230bcfb3bf0e49764a9547ba254d79 to your computer and use it in GitHub Desktop.
Create a Binius tower-field multiplication circuit as a fully unrolled XAG (XOR-AND graph) using mockturtle
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Build an XAG (XOR-AND graph) network that multiplies two Binius tower-field
// elements, fully unrolled into gates.
//
// Bit vectors are little-endian: index 0 is bit 0 (so "1ULL << k" is the
// vector with only entry k set). An element of width `length` is split into a
// low half L and a high half R and treated as L + R*X, where X is this level's
// generator. Products are reduced with X^2 = X'*X + 1, where X' is the
// previous level's generator: the half-width element with only bit
// `length / 4` set, or simply 1 when the half width is a single bit.
// Karatsuba is used, so each general level makes three half-width general
// multiplications plus one multiplication by X'.
//
// @param xag          network that receives the new gates.
// @param v1, v2       operand signals; both must have exactly `length` entries.
// @param length       bit width of the operands; must be a power of two.
// @param is_constant  hint that `v1` may be this level's generator X
//                     (L1 == 0, R1 == 1). When it is, the product is built
//                     from XORs only and no AND gates are created.
// @return             `length` signals holding the product, bit 0 first.
//
// Side effect: increments the global `AND_count` (declared elsewhere) once per
// AND gate created in the single-bit base case.
std::vector<xag_network::signal> binius_mul_xag(xag_network& xag, const std::vector<xag_network::signal>& v1, const std::vector<xag_network::signal>& v2, uint32_t length, bool is_constant)
{
    assert(v1.size() == length && v2.size() == length); // operands must match the bit length
    if (length == 0) {
        return {};
    }
    // Tower elements always have a power-of-two width; any other width would
    // make the half/quarter splits below inconsistent with the field structure.
    assert((length & (length - 1)) == 0);

    // Base case: multiplication in GF(2) is a single AND.
    if (length == 1) {
        AND_count++;
        return {xag.create_and(v1[0], v2[0])};
    }

    const uint32_t halflen = length / 2;
    const uint32_t quarterlen = length / 4;
    const auto zero = xag.get_constant(false);
    const auto one = xag.get_constant(true);

    // Split v1 = L1 + R1*X and v2 = L2 + R2*X.
    std::vector<xag_network::signal> L1(v1.begin(), v1.begin() + halflen);
    std::vector<xag_network::signal> R1(v1.begin() + halflen, v1.end());
    std::vector<xag_network::signal> L2(v2.begin(), v2.begin() + halflen);
    std::vector<xag_network::signal> R2(v2.begin() + halflen, v2.end());

    const auto is_zero = [&](const xag_network::signal& s) { return s == zero; };

    // Fast path for v1 == X (L1 == 0, R1 == 1):
    //   X * (L2 + R2*X) = R2 + (L2 + X'*R2) * X
    // Every bit of R1 is checked, not only bit 0, so a caller passing
    // is_constant with some other constant still falls through to the
    // general (correct) multiplication below.
    const bool v1_is_generator = is_constant
        && std::all_of(L1.begin(), L1.end(), is_zero)
        && R1[0] == one
        && std::all_of(R1.begin() + 1, R1.end(), is_zero);
    if (v1_is_generator) {
        std::vector<xag_network::signal> result(length);
        if (length == 2) {
            // X' == 1 at this level, so the high bit is just L2 ^ R2.
            result[1] = xag.create_xor(R2[0], L2[0]);
            result[0] = R2[0];
            return result;
        }
        // X' * R2, where X' has only bit quarterlen set; recursing with
        // is_constant keeps this AND-free.
        std::vector<xag_network::signal> prev_gen(halflen, zero);
        prev_gen[quarterlen] = one;
        auto outR = binius_mul_xag(xag, prev_gen, R2, halflen, true);
        for (size_t i = 0; i < halflen; ++i) {
            result[i + halflen] = xag.create_xor(outR[i], L2[i]);
            result[i] = R2[i]; // low half is R2 unchanged
        }
        return result;
    }

    // General case (Karatsuba over the tower):
    //   low  = L1L2 ^ R1R2
    //   high = Z3 ^ L1L2 ^ R1R2 ^ X'*R1R2,  with Z3 = (L1 ^ R1) * (L2 ^ R2)
    auto L1L2 = binius_mul_xag(xag, L1, L2, halflen, false);
    auto R1R2 = binius_mul_xag(xag, R1, R2, halflen, false);
    std::vector<xag_network::signal> Z3_input_v1(halflen), Z3_input_v2(halflen);
    for (size_t i = 0; i < halflen; ++i) {
        Z3_input_v1[i] = xag.create_xor(L1[i], R1[i]);
        Z3_input_v2[i] = xag.create_xor(L2[i], R2[i]);
    }

    // X' * R1R2. When halflen == 1, X' == 1, so this term equals R1R2 and
    // cancels against the R1R2 term in `high`; it is not built at all there.
    std::vector<xag_network::signal> R1R2_high;
    if (halflen > 1) {
        std::vector<xag_network::signal> prev_gen(halflen, zero);
        prev_gen[quarterlen] = one;
        R1R2_high = binius_mul_xag(xag, prev_gen, R1R2, halflen, true);
    }
    auto Z3 = binius_mul_xag(xag, Z3_input_v1, Z3_input_v2, halflen, false);

    std::vector<xag_network::signal> result(length);
    for (size_t i = 0; i < halflen; ++i) {
        result[i] = xag.create_xor(L1L2[i], R1R2[i]);
    }
    if (halflen == 1) {
        // high = Z3 ^ L1L2 after the R1R2 terms cancel.
        result[1] = xag.create_xor(Z3[0], L1L2[0]);
    } else {
        for (size_t i = 0; i < halflen; ++i) {
            // result[i] already holds L1L2 ^ R1R2 for this bit.
            auto upper = xag.create_xor(result[i], Z3[i]);
            result[i + halflen] = xag.create_xor(upper, R1R2_high[i]);
        }
    }
    return result;
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment