Created
February 20, 2025 02:04
-
-
Save apivovarov/ae183980c31c4f419c47f758e7437ef4 to your computer and use it in GitHub Desktop.
Function that performs uint32 division without loops, using only float32 division and uint32 add, subtract, and multiply.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <iostream> | |
#include <cstdint> | |
#include <limits> | |
#include <vector> | |
#include <utility> | |
/// Loop-free unsigned 32-bit division built from float32 division plus
/// uint32 add / subtract / multiply / shift / compare-select.
///
/// @param dividend  value to be divided
/// @param divisor   value to divide by
/// @return floor(dividend / divisor); 0 when divisor == 0 (there is no
///         trap or exception for division by zero).
///
/// Method: two float32 quotient estimates followed by a fixed number of
/// +1 corrections. A float32 quotient only carries 24 significant bits and
/// may round *up*, so each estimate is deliberately biased downward before
/// it is multiplied back. That guarantees `estimate * divisor` never exceeds
/// the value being divided, so every remainder below is a true non-negative
/// remainder and no unsigned subtraction can wrap around.
uint32_t divide_uint32(uint32_t dividend, uint32_t divisor) {
  // Special cases (HLO: select). Division by zero yields 0 by convention.
  if (divisor == 0) {
    return 0;
  }
  if (divisor == 1) {
    return dividend;
  }
  if (dividend < divisor) {
    return 0;  // also covers dividend == 0
  }

  const float divisor_f = static_cast<float>(divisor);

  // ---- Stage 1: coarse quotient from the upper 16 bits of the dividend ----
  // high_half * 2^16 has at most 16 significant bits, so it is exact in
  // float32; only the divisor conversion and the division itself round.
  const uint32_t high_half = dividend >> 16;
  const float high_scaled_f = static_cast<float>(high_half) * 65536.0f;
  const uint32_t coarse_raw = static_cast<uint32_t>(high_scaled_f / divisor_f);

  // The relative error of coarse_raw is below 2^-22, so subtracting
  // (coarse_raw >> 22) + 1 makes the estimate a guaranteed underestimate.
  // Clamp at zero instead of wrapping when the raw estimate is 0 or 1.
  const uint32_t coarse_margin = (coarse_raw >> 22) + 1;
  const uint32_t coarse =
      (coarse_raw > coarse_margin) ? coarse_raw - coarse_margin : 0;

  uint32_t quotient = coarse;
  // coarse * divisor <= (high_half << 16) <= dividend: no overflow, no wrap.
  uint32_t remainder = dividend - coarse * divisor;

  // ---- Stage 2: fine quotient from the stage-1 remainder ----
  // After stage 1, remainder / divisor is below roughly 2^15 + 2, so the
  // float32 estimate is off by less than one before truncation. Subtracting
  // one (clamped at zero) therefore makes it a guaranteed underestimate that
  // is short of the true value by at most two.
  const uint32_t fine_raw =
      static_cast<uint32_t>(static_cast<float>(remainder) / divisor_f);
  const uint32_t fine = (fine_raw > 0) ? fine_raw - 1 : 0;

  quotient += fine;
  remainder -= fine * divisor;  // fine * divisor <= remainder: cannot wrap

  // ---- Stage 3: fixed, unrolled +1 corrections (no loops) ----
  // At most two are required by the bounds above; the third is slack and is
  // a no-op whenever remainder < divisor.
  const uint32_t bump_1 = (remainder >= divisor) ? 1u : 0u;
  quotient += bump_1;
  remainder -= bump_1 * divisor;

  const uint32_t bump_2 = (remainder >= divisor) ? 1u : 0u;
  quotient += bump_2;
  remainder -= bump_2 * divisor;

  const uint32_t bump_3 = (remainder >= divisor) ? 1u : 0u;
  quotient += bump_3;

  return quotient;
}
int main() { | |
std::vector<std::pair<uint32_t, uint32_t>> aa={ | |
{16841473, 6571}, | |
{16846599, 6573}, | |
{167704699, 6541}, | |
{167755977, 6543}, | |
{2147483647, 6}, | |
{357913941, 993}, | |
{2147483641, 2699}, | |
{16841473, 10}, | |
{16846599, 10}, | |
{167704699, 10}, | |
{167755977, 10}, | |
{2147483647, 10}, | |
{357913941, 10}, | |
{2147483641, 10}, | |
}; | |
for (auto a : aa) { | |
uint32_t c = divide_uint32(a.first, a.second); | |
uint32_t ec = a.first / a.second; | |
std::cout << a.first << " / " << a.second << " = " << c << std::endl; | |
if (c != ec) { | |
std::cout << "Error: res / exp: " << c << " / " << ec << std::endl; | |
} | |
} | |
return 0; | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment