Created
August 11, 2022 14:00
-
-
Save demotomohiro/704e6110919cf163de18ab36b8bbef4c to your computer and use it in GitHub Desktop.
Fast approximated BigInt sqrt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import std/[math, options] | |
import bigints | |
func sqrt(a: BigInt): Option[BigInt] =
  ## Fast, approximate square root of a `BigInt`.
  ##
  ## Returns `none(BigInt)` when `a` is negative, otherwise `some(root)`.
  ##
  ## The value is treated as `z = x * 2^(2y)` with `0 <= x < 2^62`, so that
  ## `sqrt(z) = sqrt(x) * 2^y`. Only the leading bits of `a` (the part kept
  ## after the right shift) take part in the computation, and the root of
  ## that part is evaluated in floating point with `math.sqrt`, so the
  ## result is an approximation rather than an exact integer square root.
  if a < 0'bi:
    return none(BigInt)

  let
    # Position of the highest set bit, with the lowest bit cleared so the
    # exponent is even and can be halved exactly.
    evenLog2 = a.fastLog2 and (not 1)
    # Number of low bits discarded so that the remainder fits in an int64.
    dropBits = max(evenLog2 - 60, 0)
    # Taking the square root halves the exponent that was shifted out.
    halfDrop = dropBits div 2
    # Part of that exponent (at most 30 bits) is applied to the float root
    # *before* truncating to int64, so fractional bits of the root are not
    # thrown away by the conversion.
    floatShift = min(halfDrop, 30)
    # Whatever is left of the exponent is applied afterwards on the BigInt.
    bigShift = halfDrop - floatShift
    topBits = (a shr dropBits).toInt[:int64]().get
    scaledRoot = (sqrt(topBits.float) * pow(2.0, floatShift.float)).int64

  some(scaledRoot.initBigInt shl bigShift)
proc test =
  ## Checks `sqrt` against a table of `(input, expected root)` pairs and
  ## verifies that a negative input yields no result.
  ##
  ## Some expected values are derived from floating-point expressions and
  ## truncated to an integer, mirroring how `sqrt` itself produces them.

  func truncated(x: float): BigInt =
    ## Truncates `x` to `int` and converts it to `BigInt`.
    x.int.initBigInt

  # A negative number has no square root here.
  doAssert sqrt(-400'bi).isNone

  let
    two500 = pow(2'bi, 500)
    two1000 = pow(2'bi, 1000)
    two5000 = pow(2'bi, 5000)
    two10000 = pow(2'bi, 10000)
    square = 12345678987654321'bi
    squareRoot = 111111111'bi
    cases = [
      # Small values.
      (4'bi, 2'bi),
      (10'bi, 3'bi),
      (400'bi, 20'bi),
      (4000'bi, 63'bi),
      (12345'bi, 111'bi),
      (square, squareRoot),
      (0xfff_ffff_ffff_ffff'bi, 1073741824'bi),
      # Values just above 64 bits.
      (0x4_0000_0000_0000_0000'bi, 0x2_0000_0000'bi),
      (0x8_0000_0000_0000_0000'bi, truncated(sqrt(8.0) * pow(16.0, 8))),
      (0x9_0000_0000_0000_0000'bi, 0x3_0000_0000.int.initBigInt),
      (0xc_0000_0000_0000_0000'bi,
        truncated(2.0 * sqrt(3.0) * pow(16.0, 8))),
      # Around 80 bits.
      (0x1_0000_0000_0000_0000_0000'bi, 0x100_0000_0000'bi),
      (0x2_0000_0000_0000_0000_0000'bi,
        truncated(sqrt(2.0) * pow(16.0, 10))),
      (0x4_0000_0000_0000_0000_0000'bi, 0x200_0000_0000'bi),
      (0x9_0000_0000_0000_0000_0000'bi, 0x300_0000_0000'bi),
      (0xc_0000_0000_0000_0000_0000'bi,
        truncated(2.0 * sqrt(3.0) * pow(16.0, 10))),
      # Around 96 bits.
      (0x1_0000_0000_0000_0000_0000_0000'bi, 0x1_0000_0000_0000'bi),
      (0x2_0000_0000_0000_0000_0000_0000'bi,
        truncated(sqrt(2.0) * pow(16.0, 12))),
      (0x4_0000_0000_0000_0000_0000_0000'bi, 0x2_0000_0000_0000'bi),
      (0x9_0000_0000_0000_0000_0000_0000'bi, 0x3_0000_0000_0000'bi),
      (0xc_0000_0000_0000_0000_0000_0000'bi,
        truncated(2.0 * sqrt(3.0) * pow(16.0, 12))),
      # Small squares scaled by 2^1000.
      (two1000, two500),
      (4'bi * two1000, pow(2'bi, 501)),
      (9'bi * two1000, 3'bi * two500),
      (25'bi * two1000, 5'bi * two500),
      # Small squares scaled by 2^10000.
      (two10000, two5000),
      (4'bi * two10000, 2'bi * two5000),
      (9'bi * two10000, 3'bi * two5000),
      (16'bi * two10000, 4'bi * two5000),
      (25'bi * two10000, 5'bi * two5000),
      # A large square scaled by even powers of two.
      (square * two10000, squareRoot * two5000),
      (square * pow(2'bi, 10002), squareRoot * pow(2'bi, 5001)),
      (square * pow(2'bi, 10004), squareRoot * pow(2'bi, 5002)),
    ]

  for (input, expected) in cases:
    doAssert sqrt(input).get == expected

test()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment