Skip to content

Instantly share code, notes, and snippets.

@apivovarov
Created February 20, 2025 02:04
Show Gist options
  • Save apivovarov/ae183980c31c4f419c47f758e7437ef4 to your computer and use it in GitHub Desktop.
Save apivovarov/ae183980c31c4f419c47f758e7437ef4 to your computer and use it in GitHub Desktop.
Function to perform uint32 division without loops using float32 div, uint32 add, sub, mul
#include <iostream>
#include <cstdint>
#include <limits>
#include <vector>
#include <utility>
// Returns an estimate of `remainder / divisor` that is guaranteed not to
// exceed the true quotient, so `estimate * divisor <= remainder` and the
// caller's subtraction can never wrap around.
//
// Precondition: divisor >= 2. This keeps the float quotient below 2^32, so the
// float -> uint32 conversion is always in range.
//
// Why the bias is needed: the two uint32 -> float32 conversions and the float32
// division each round to nearest, so the raw float quotient can be larger than
// the true quotient by a relative error of a few 2^-24. Subtracting
// (estimate >> 21) + 1 shrinks the estimate by at least a factor of
// (1 - 2^-21) = (1 - 8 * 2^-24), which more than cancels that error.
static uint32_t underestimate_quotient_u32(uint32_t remainder, uint32_t divisor) {
  const float remainder_f = static_cast<float>(remainder);
  const float divisor_f = static_cast<float>(divisor);
  const float estimate_f = remainder_f / divisor_f;  // HLO: divide
  const uint32_t estimate = static_cast<uint32_t>(estimate_f);

  // Keep a zero estimate at zero instead of wrapping to 0xFFFFFFFF.
  const uint32_t is_nonzero = (estimate != 0);
  const uint32_t bias = (estimate >> 21) + is_nonzero;
  return estimate - bias;
}

// Divides two unsigned 32-bit integers without loops, using float32 division
// for the quotient estimates and uint32 add/sub/mul/shift/compare for the
// exact bookkeeping (the comments map the steps onto HLO-style operations).
//
// Parameters:
//   dividend - value to be divided.
//   divisor  - value to divide by.
// Returns:
//   floor(dividend / divisor), or 0 when divisor == 0.
//
// Algorithm:
//   1. Two estimate rounds. Each round takes a never-too-large float32
//      estimate of remainder / divisor, adds it to the quotient and subtracts
//      estimate * divisor from the remainder. Because the estimate never
//      overshoots, the product never exceeds the remainder and no uint32
//      operation wraps. After the first round the remainder is below roughly
//      2 * divisor + 3072; after the second it is below 3 * divisor.
//   2. Three fixed "+1 if remainder >= divisor" corrections. Two are enough
//      for the bound above; the third is inexpensive slack.
//
// The previous version used a round-to-nearest estimate of
// (dividend >> 16) * 65536 / divisor directly. That estimate can exceed the
// true quotient (e.g. 0xFFFF0000 / 23), which made the remainder wrap around
// and produced a wrong result.
uint32_t divide_uint32(uint32_t dividend, uint32_t divisor) {
  // Division by zero is defined here to yield 0.
  if (divisor == 0) {
    return 0;  // HLO: select(is_zero_divisor, 0, ...)
  }
  // divisor == 1 must bypass the float path: float(0xFFFFFFFF) / 1.0f is 2^32,
  // which does not fit in uint32.
  if (divisor == 1) {
    return dividend;  // HLO: select(is_divisor_one, dividend, ...)
  }
  // Also covers dividend == 0, since divisor >= 2 at this point.
  if (dividend < divisor) {
    return 0;  // HLO: select(is_dividend_less_than_divisor, 0, ...)
  }

  uint32_t quotient = 0;
  uint32_t remainder = dividend;

  // Estimate round 1 (HLO: convert, divide, shift, multiply, subtract, add).
  const uint32_t estimate_1 = underestimate_quotient_u32(remainder, divisor);
  quotient += estimate_1;
  remainder -= estimate_1 * divisor;

  // Estimate round 2: the remainder is small now, so this estimate is within
  // a couple of units of the exact remaining quotient.
  const uint32_t estimate_2 = underestimate_quotient_u32(remainder, divisor);
  quotient += estimate_2;
  remainder -= estimate_2 * divisor;

  // Fixed number of exact correction steps (HLO: compare, select, add,
  // subtract). No loops.
  const uint32_t correction_1 = (remainder >= divisor);
  quotient += correction_1;
  remainder -= correction_1 * divisor;

  const uint32_t correction_2 = (remainder >= divisor);
  quotient += correction_2;
  remainder -= correction_2 * divisor;

  const uint32_t correction_3 = (remainder >= divisor);
  quotient += correction_3;

  return quotient;
}
// Test driver: runs divide_uint32 on a fixed table of {dividend, divisor}
// pairs, prints every result, and prints an extra "Error" line whenever the
// result differs from the built-in `/` operator. Always exits with status 0.
int main() {
  typedef std::pair<uint32_t, uint32_t> DivisionCase;

  // The same dividends are exercised against assorted divisors and against 10.
  const std::vector<DivisionCase> cases = {
      {16841473, 6571},
      {16846599, 6573},
      {167704699, 6541},
      {167755977, 6543},
      {2147483647, 6},
      {357913941, 993},
      {2147483641, 2699},
      {16841473, 10},
      {16846599, 10},
      {167704699, 10},
      {167755977, 10},
      {2147483647, 10},
      {357913941, 10},
      {2147483641, 10},
  };

  for (std::vector<DivisionCase>::const_iterator it = cases.begin();
       it != cases.end(); ++it) {
    const uint32_t numerator = it->first;
    const uint32_t denominator = it->second;

    const uint32_t actual = divide_uint32(numerator, denominator);
    const uint32_t expected = numerator / denominator;  // reference result

    std::cout << numerator << " / " << denominator << " = " << actual
              << std::endl;
    if (actual == expected) {
      continue;
    }
    std::cout << "Error: res / exp: " << actual << " / " << expected
              << std::endl;
  }
  return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment