Skip to content

Instantly share code, notes, and snippets.

@FawadHa1der
Created November 19, 2024 17:20
Show Gist options
  • Save FawadHa1der/6a230bcfb3bf0e49764a9547ba254d79 to your computer and use it in GitHub Desktop.
Save FawadHa1der/6a230bcfb3bf0e49764a9547ba254d79 to your computer and use it in GitHub Desktop.
Creates a fully unrolled XAG (XOR-AND graph) circuit for Binius tower-field multiplication using mockturtle.
// Build an XAG (XOR-AND graph) that multiplies two elements of a Binius binary
// tower field, fully unrolled down to single-bit AND gates.
//
// Encoding: an element of the `length`-bit level is a little-endian vector of
// signals. It is split as  a = lo + hi * X  with lo = bits [0, length/2) and
// hi = bits [length/2, length). The recursion below uses the reduction
//   X^2 = X_prev * X + 1,
// where X_prev is the generator of the half-width subfield (the half-width
// element with only bit length/4 set; for the 2-bit level X_prev == 1).
// Hence, with Karatsuba:
//   lo(a*b) = lo1*lo2 + hi1*hi2
//   hi(a*b) = lo1*lo2 + hi1*hi2 + (lo1+hi1)*(lo2+hi2) + X_prev*(hi1*hi2)
//
// Parameters:
//   xag         - network the gates are added to.
//   v1, v2      - operand signals; each must have exactly `length` entries.
//   length      - operand bit width; must be a power of two (0 yields an
//                 empty result).
//   is_constant - hint that v1 may be this level's generator X (low half all
//                 constant-0, high half == 1). If v1 matches that pattern, the
//                 cheaper "multiply by X" formula is used instead of Karatsuba.
// Returns the `length` output signals of v1 * v2.
// Side effect: increments the global AND_count (declared elsewhere) once per
// AND gate created.
std::vector<xag_network::signal> binius_mul_xag(xag_network& xag, const std::vector<xag_network::signal>& v1, const std::vector<xag_network::signal>& v2, uint32_t length, bool is_constant)
{
    assert(v1.size() == length && v2.size() == length); // Ensure input vectors match the bit length

    // Empty operands have an empty product. Without this guard, length == 0
    // recursed forever (halflen == 0 at every level) or indexed R1[0] out of bounds.
    if (length == 0) {
        return {};
    }
    // The recursive halving only lines up for power-of-two widths.
    assert((length & (length - 1)) == 0);

    // Base case: AND the single bits
    if (length == 1) {
        AND_count++;
        return {xag.create_and(v1[0], v2[0])};
    }

    const uint32_t halflen = length / 2;
    const uint32_t quarterlen = length / 4;
    const auto zero = xag.get_constant(false);
    const auto one = xag.get_constant(true);

    // Split each operand into its low half (L = lo) and high half (R = hi).
    const std::vector<xag_network::signal> L1(v1.begin(), v1.begin() + halflen);
    const std::vector<xag_network::signal> R1(v1.begin() + halflen, v1.end());
    const std::vector<xag_network::signal> L2(v2.begin(), v2.begin() + halflen);
    const std::vector<xag_network::signal> R2(v2.begin() + halflen, v2.end());

    // Generator X_prev of the half-width subfield: only bit `quarterlen` set
    // (1ULL << quarterlen). For halflen == 1 this is the constant 1.
    const auto half_generator = [&]() {
        std::vector<xag_network::signal> g(halflen, zero);
        g[quarterlen] = one;
        return g;
    };
    const auto is_zero = [&](const xag_network::signal& s) { return s == zero; };

    // Fast path: v1 is exactly this level's generator X (lo == 0, hi == 1).
    // X * (lo2 + hi2*X) = hi2 + (lo2 + X_prev*hi2) * X.
    // The whole high half is checked, not just R1[0], so other constants
    // flagged with is_constant never take this shortcut by mistake.
    if (is_constant && std::all_of(L1.begin(), L1.end(), is_zero) && R1[0] == one
        && std::all_of(R1.begin() + 1, R1.end(), is_zero)) {
        std::vector<xag_network::signal> result(length);
        if (length == 2) {
            // X_prev == 1 at the 2-bit level, so no multiplication is needed.
            result[1] = xag.create_xor(R2[0], L2[0]);
            result[0] = R2[0];
            return result;
        }
        // X_prev * hi2, computed recursively with the same constant fast path.
        auto outR = binius_mul_xag(xag, half_generator(), R2, halflen, true);
        for (size_t i = 0; i < halflen; ++i) {
            outR[i] = xag.create_xor(outR[i], L2[i]);
        }
        for (size_t i = 0; i < halflen; ++i) {
            result[i + halflen] = outR[i];
            result[i] = R2[i]; // Lower bits as R2
        }
        return result;
    }

    // General case: Karatsuba with three half-width products.
    const auto L1L2 = binius_mul_xag(xag, L1, L2, halflen, false);
    const auto R1R2 = binius_mul_xag(xag, R1, R2, halflen, false);
    std::vector<xag_network::signal> Z3_input_v1(halflen);
    std::vector<xag_network::signal> Z3_input_v2(halflen);
    for (size_t i = 0; i < halflen; ++i) {
        Z3_input_v1[i] = xag.create_xor(L1[i], R1[i]);
        Z3_input_v2[i] = xag.create_xor(L2[i], R2[i]);
    }

    std::vector<xag_network::signal> result(length);
    if (halflen == 1) {
        // X_prev == 1: the X_prev*hi1*hi2 term cancels hi1*hi2 in the high bit,
        // leaving hi = lo1*lo2 + Z3.
        const auto Z3 = binius_mul_xag(xag, Z3_input_v1, Z3_input_v2, halflen, false);
        result[0] = xag.create_xor(L1L2[0], R1R2[0]);
        result[1] = xag.create_xor(Z3[0], L1L2[0]);
        return result;
    }

    // X_prev * (hi1*hi2), via the constant fast path.
    const auto R1R2_high = binius_mul_xag(xag, half_generator(), R1R2, halflen, true);
    // Z3 = (lo1 + hi1) * (lo2 + hi2)
    const auto Z3 = binius_mul_xag(xag, Z3_input_v1, Z3_input_v2, halflen, false);

    // Low half: lo1*lo2 + hi1*hi2
    for (size_t i = 0; i < halflen; ++i) {
        result[i] = xag.create_xor(L1L2[i], R1R2[i]);
    }
    // High half: (lo1*lo2 + hi1*hi2) + Z3 + X_prev*hi1*hi2
    for (size_t i = 0; i < halflen; ++i) {
        const auto partial = xag.create_xor(result[i], Z3[i]);
        result[i + halflen] = xag.create_xor(partial, R1R2_high[i]);
    }
    return result;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment