Skip to content

Instantly share code, notes, and snippets.

@frangio
Last active September 26, 2022 21:57
Show Gist options
  • Save frangio/2cfcda1ce297dc04b74d31762452b157 to your computer and use it in GitHub Desktop.
Save frangio/2cfcda1ce297dc04b74d31762452b157 to your computer and use it in GitHub Desktop.
Multiply-and-divide without intermediate overflow, without particularly clever algorithms.
// SPDX-License-Identifier: MIT
// Custom errors (`error Name();`) are only accepted by the compiler from
// Solidity 0.8.4 onwards, so the version floor must not be lower than that.
pragma solidity ^0.8.4;

/// Raised by div_2_1_1 when the high word of the dividend is greater than or
/// equal to the divisor, i.e. the quotient would not fit in a single word.
/// A zero divisor always satisfies that condition and therefore raises it too.
error Overflow();
/// Computes floor(x * y / d) using a full two-word intermediate product, so the
/// multiplication itself can never overflow.
/// Reverts with Overflow() (from div_2_1_1) when the high word of the product
/// is >= d, which covers both a quotient wider than one word and d == 0.
/// @param x first factor
/// @param y second factor
/// @param d divisor
/// @return r = (x * y) / d
function mulDiv(uint x, uint y, uint d) pure returns (uint r) {
    unchecked {
        // Two-word product: productHi * 2**256 + productLo.
        (uint productHi, uint productLo) = mul_1_1_2(x, y);
        // Only the quotient is needed; the remainder is discarded.
        (uint quotient, ) = div_2_1_1(productHi, productLo, d);
        r = quotient;
    }
}
/// Full-width product of two single words, returned as a two-word value.
/// Schoolbook multiplication on 128-bit halves; nothing here can overflow
/// in a way that loses information.
/// @param x first factor
/// @param y second factor
/// @return r1 high word of x * y
/// @return r0 low word of x * y
function mul_1_1_2(uint x, uint y) pure returns (uint r1, uint r0) {
    unchecked {
        // Split each operand into 128-bit halves. The explicit uint128
        // conversion truncates, keeping only the low 128 bits.
        uint xHi = x >> 128;
        uint xLo = uint128(x);
        uint yHi = y >> 128;
        uint yLo = uint128(y);

        // Partial products; each is a product of two 128-bit values and fits in a word.
        uint lowProd = xLo * yLo;
        uint crossA = xHi * yLo;
        uint crossB = xLo * yHi;

        // Sum of everything that lands in bits 128..255 of the result. Each
        // term is below 2**128, so the sum fits; its top half is the carry
        // into the high word.
        uint middle = (lowProd >> 128) + uint128(crossA) + uint128(crossB);

        // The wrapping product is exactly the low word.
        r0 = x * y;
        r1 = xHi * yHi + (crossA >> 128) + (crossB >> 128) + (middle >> 128);
    }
}
// Number base used by div_2_1_1: each 256-bit word is treated as two
// base-2**128 digits.
uint constant b = 2 ** 128;
/// Divides 2 words by 1 word and returns 1 word results, reverting on overflow.
/// Based on `divlu` from Hacker's Delight.
/// The dividend is u1 * 2**256 + u0. Each word is handled as two base-`b`
/// (128-bit) digits and the quotient is produced one digit at a time.
/// Reverts with Overflow() when u1 >= v; this includes v == 0.
/// @param u1 high word of the dividend
/// @param u0 low word of the dividend
/// @param v divisor
/// @return q quotient
/// @return r remainder
function div_2_1_1(uint u1, uint u0, uint v) pure returns (uint q, uint r) { unchecked {
// If u1 >= v the quotient needs more than one word. v == 0 always lands here.
if (u1 >= v) revert Overflow();
// Normalize: shift the divisor left by its leading-zero count.
// v > u1 >= 0 past the check above, so v is nonzero here.
uint s = nlz(v);
v = v << s;
// High and low 128-bit digits of the shifted divisor.
uint vn1 = v >> 128;
uint vn0 = v & type(uint128).max;
// Shift the two-word dividend left by the same amount. un32 is the high word:
// u1 shifted, with the top s bits of u0 moved in below it. The mask
// uint(-int(s) >> 255) is all ones when s != 0 and zero when s == 0, so the
// u0 term is dropped when no shift takes place.
uint un32 = (u1 << s) | (u0 >> 256 - s) & uint(-int(s) >> 255);
// Low word of the shifted dividend, then its two 128-bit digits.
uint un10 = u0 << s;
uint un1 = un10 >> 128;
uint un0 = un10 & type(uint128).max;
// Estimate the first quotient digit from the leading divisor digit only.
uint q1 = un32/vn1;
uint rhat = un32 - q1*vn1;
// Correct the estimate downwards while it is too large. The q1 >= b test is
// evaluated first, so q1*vn0 is only computed once q1 is below b.
while (q1 >= b || q1*vn0 > b*rhat + un1) {
q1 = q1 - 1;
rhat = rhat + vn1;
// Stop once rhat no longer fits in one digit.
if (rhat >= b) break;
}
// Multiply and subtract: the partial remainder carried into the second digit.
uint un21 = un32*b + un1 - q1*v;
// Second quotient digit, estimated and corrected the same way.
uint q0 = un21/vn1;
rhat = un21 - q0*vn1;
while (q0 >= b || q0*vn0 > b*rhat + un0) {
q0 = q0 - 1;
rhat = rhat + vn1;
if (rhat >= b) break;
}
// Assemble the quotient from its two digits.
q = q1*b + q0;
// The remainder was computed against the shifted divisor; shift it back by s.
r = (un21*b + un0 - q0*v) >> s;
}}
/// Number of leading zero bits in a 256-bit word (256 for x == 0).
/// Binary search in the style of Hacker's Delight: test the upper half of the
/// remaining window and, if it is non-empty, continue inside it.
/// @param x value to inspect
/// @return n count of zero bits above the most significant set bit
function nlz(uint x) pure returns (uint n) {
    n = 256;
    // Window widths 128, 64, ..., 2.
    for (uint width = 128; width >= 2; width >>= 1) {
        uint upper = x >> width;
        if (upper != 0) {
            n -= width;
            x = upper;
        }
    }
    // At this point x is at most two bits wide.
    if (x >> 1 != 0) {
        return n - 2;
    }
    // x is 0 or 1 here.
    return n - x;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment