Last active
September 26, 2022 21:57
-
-
Save frangio/2cfcda1ce297dc04b74d31762452b157 to your computer and use it in GitHub Desktop.
Multiply-and-divide without intermediate overflow, without particularly clever algorithms.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// SPDX-License-Identifier: MIT
// Custom errors (`error ...;` / `revert Overflow()`) require Solidity 0.8.4 or later.
pragma solidity ^0.8.4;

/// Raised by `div_2_1_1` when the high word of the dividend is not smaller
/// than the divisor, i.e. the quotient would not fit in one word.
error Overflow();
/// Computes (x * y) / d, keeping the full 512-bit product so the
/// multiplication itself cannot overflow.
/// Reverts with `Overflow` (raised inside `div_2_1_1`) when the high word of
/// the product is not smaller than `d`.
/// @param x First factor.
/// @param y Second factor.
/// @param d Divisor.
/// @return r Quotient of the 512-bit product x * y divided by d.
function mulDiv(uint x, uint y, uint d) pure returns (uint r) {
    // Full-width product as (high word, low word).
    (uint productHi, uint productLo) = mul_1_1_2(x, y);
    // Only the quotient is needed; the remainder is discarded.
    (uint quotient, ) = div_2_1_1(productHi, productLo, d);
    return quotient;
}
/// Full-width multiply: one word times one word gives a two-word product.
/// Each operand is split into 128-bit halves and the four partial products
/// are combined schoolbook-style, so nothing is lost to overflow.
/// @param x First factor.
/// @param y Second factor.
/// @return r1 High 256 bits of x * y.
/// @return r0 Low 256 bits of x * y.
function mul_1_1_2(uint x, uint y) pure returns (uint r1, uint r0) { unchecked {
    // Explicit narrowing to uint128 truncates, i.e. keeps the low 128 bits.
    uint xHi = x >> 128;
    uint xLo = uint128(x);
    uint yHi = y >> 128;
    uint yLo = uint128(y);

    // Partial products; each fits in one word because every factor is < 2^128.
    uint hiHi = xHi * yHi;
    uint hiLo = xHi * yLo;
    uint loHi = xLo * yHi;
    uint loLo = xLo * yLo;

    // Carry out of bits 128..255: low halves of the cross terms plus the
    // high half of the low product, summed in full 256-bit arithmetic.
    uint carry = (uint(uint128(hiLo)) + uint(uint128(loHi)) + (loLo >> 128)) >> 128;

    // The low word is simply the wrapping product.
    r0 = x * y;
    r1 = hiHi + (hiLo >> 128) + (loHi >> 128) + carry;
}}
/// Digit base used by `div_2_1_1`: each 256-bit word is two base-2^128 digits.
uint constant b = 1 << 128;

/// Divides the two-word value (u1, u0) by the one-word divisor v.
/// Follows `divlu` from Hacker's Delight, working in 128-bit digits.
/// Reverts with `Overflow` when u1 >= v (the quotient would need more than
/// one word); this also covers v == 0.
/// @param u1 High word of the dividend.
/// @param u0 Low word of the dividend.
/// @param v Divisor.
/// @return q Quotient.
/// @return r Remainder.
function div_2_1_1(uint u1, uint u0, uint v) pure returns (uint q, uint r) { unchecked {
    if (!(u1 < v)) revert Overflow();

    // Normalize: shift the divisor left until its top bit is set.
    uint s = nlz(v);
    v <<= s;
    uint vHi = v >> 128;
    uint vLo = v & type(uint128).max;

    // Shift the dividend by the same amount. With s == 0 nothing spills from
    // the low word into the high word.
    uint numHi = (u1 << s) | (s == 0 ? 0 : u0 >> (256 - s));
    uint numLo = u0 << s;
    uint numLoHi = numLo >> 128;
    uint numLoLo = numLo & type(uint128).max;

    // Estimate the upper quotient digit, then correct it downwards.
    uint qHi = numHi / vHi;
    uint rhat = numHi - qHi * vHi;
    while (qHi >= b || qHi * vLo > b * rhat + numLoHi) {
        qHi -= 1;
        rhat += vHi;
        if (rhat >= b) break;
    }

    // Partial remainder after removing the upper digit's contribution.
    uint numMid = numHi * b + numLoHi - qHi * v;

    // Same estimate-and-correct step for the lower quotient digit.
    uint qLo = numMid / vHi;
    rhat = numMid - qLo * vHi;
    while (qLo >= b || qLo * vLo > b * rhat + numLoLo) {
        qLo -= 1;
        rhat += vHi;
        if (rhat >= b) break;
    }

    q = qHi * b + qLo;
    // Undo the normalization shift to recover the true remainder.
    r = (numMid * b + numLoLo - qLo * v) >> s;
}}
/// Counts the leading zero bits of a 256-bit word.
/// Binary search from Hacker's Delight: repeatedly test whether the upper
/// half of the remaining bits is non-zero and, if so, continue in that half.
/// @param x Word to inspect.
/// @return n Number of zero bits above the highest set bit; 256 when x is 0.
function nlz(uint x) pure returns (uint n) { unchecked {
    // Unchecked like the other routines in this file: n is still at least 2
    // before the final step and x is 0 or 1 at `n - x`, so no subtraction
    // can underflow.
    uint y;
    n = 256;
    y = x >> 128; if (y != 0) { n = n - 128; x = y; }
    y = x >> 64; if (y != 0) { n = n - 64; x = y; }
    y = x >> 32; if (y != 0) { n = n - 32; x = y; }
    y = x >> 16; if (y != 0) { n = n - 16; x = y; }
    y = x >> 8; if (y != 0) { n = n - 8; x = y; }
    y = x >> 4; if (y != 0) { n = n - 4; x = y; }
    y = x >> 2; if (y != 0) { n = n - 2; x = y; }
    // Two bits remain: the upper one set means both counted positions are used.
    y = x >> 1; if (y != 0) return n - 2;
    return n - x;
}}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment