Created
June 30, 2021 19:32
-
-
Save runarorama/a933af7794ae40d103231a65652314db to your computer and use it in GitHub Desktop.
Arbitrary-precision naturals in Unison
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
-- An arbitrary-precision natural number held as a nonempty list of Nat digits.
-- The operations in this file treat the list as little-endian (least
-- significant digit first) in base `Natural.internal.radix` (2^51).
unique type Natural = Natural (Nonempty Nat)
-- Appends a plain list after the last element of a nonempty list.
-- The head is unchanged; `suffix` follows the existing tail in order.
List.Nonempty.appendList : Nonempty a -> [a] -> Nonempty a
List.Nonempty.appendList ne suffix = match ne with
  Nonempty hd tl -> Nonempty hd (tl List.++ suffix)
-- Removes the longest suffix of `as` whose elements all satisfy `p`.
-- (The previous version removed at most one element.)
List.dropRightWhile : (a ->{g} Boolean) -> [a] ->{g} [a]
List.dropRightWhile p as =
  go vs = match vs with
    rest :+ x | p x -> go rest
    _ -> vs
  go as

-- Converts a list to a `Nonempty`, failing via `bug` on the empty list.
List.unsafeNonempty : [t] -> Nonempty t
List.unsafeNonempty xs = match xs with
  x +: rest -> x +| rest
  [] -> bug "empty list!"

-- Multiplies two naturals with the classical schoolbook algorithm
-- (Knuth, TAOCP vol. 2, 4.3.1 Algorithm M).
--
-- Digits are base 2^51, so the product of two digits needs up to 102 bits
-- and would wrap in a 64-bit Nat. Each digit is therefore split into three
-- 17-bit limbs (51 = 3 * 17). In base 2^17 every intermediate value
-- t = u_i * v_j + w_{i+j} + k stays below 2^34, well inside a Nat.
-- NOTE: the limb split assumes digits are < 2^51, the value of
-- `Natural.internal.radix`; keep the two in sync.
--
-- The result is assembled with `mkNatural` after removing most-significant
-- zero digits, so a zero product yields zero instead of crashing.
(Natural.*) : Natural -> Natural -> Natural
u Natural.* v =
  use List ++ size
  use Nat + * / ==
  limbMask = 131071
  limbBase = 131072
  -- Split each digit into three limbs, least significant first.
  toLimbs ds =
    go rest acc = match rest with
      [] -> acc
      d +: ds' ->
        lo = Nat.and d limbMask
        mid = Nat.and (Nat.shiftRight d 17) limbMask
        hi = Nat.shiftRight d 34
        go ds' (acc ++ [lo, mid, hi])
    go ds []
  -- Reassemble consecutive limb triples into digits. The product always has
  -- 3 * (digits u + digits v) limbs, so every limb belongs to some triple.
  fromLimbs ls =
    go i acc =
      if (i + 2) < size ls then
        d =
          unsafeAt i ls
            + (Nat.shiftLeft (unsafeAt (i + 1) ls) 17
              + Nat.shiftLeft (unsafeAt (i + 2) ls) 34)
        go (i + 3) (acc :+ d)
      else acc
    go 0 []
  us = toLimbs (Nonempty.toList (Natural.internal.digits u))
  vs = toLimbs (Nonempty.toList (Natural.internal.digits v))
  m = size us
  n = size vs
  -- Invariant: on entering column j, ws holds exactly m + j limbs, so every
  -- index i + j touched below is in bounds and the final carry of the column
  -- becomes limb m + j.
  column j ws =
    if j >= n then ws
    else
      vj = unsafeAt j vs
      -- Step M2: a zero multiplier limb contributes nothing but a 0 limb.
      if vj == 0 then column (j + 1) (ws :+ 0)
      else
        row i k acc =
          if i >= m then acc :+ k
          else
            t = (unsafeAt i us * vj) + (unsafeAt (i + j) acc + k)
            row (i + 1) (t / limbBase) (replace (i + j) (Nat.mod t limbBase) acc)
        column (j + 1) (row 0 0 ws)
  productDigits = fromLimbs (column 0 (fill m 0))
  Natural.internal.mkNatural
    (List.dropRightWhile (d -> d == 0) productDigits)
-- Adds two naturals digit by digit, least significant position first.
(Natural.+) : Natural -> Natural -> Natural
u Natural.+ v =
  use Nat +
  xs = Nonempty.toList (digits u)
  ys = Nonempty.toList (digits v)
  xCount = List.size xs
  yCount = List.size ys
  width = Universal.max xCount yCount
  -- Digit `pos` of a digit list of length `len`, or 0 past its end.
  digitAt ds len pos = if len > pos then unsafeAt pos ds else 0
  -- The carry is always 0 or 1; a leftover carry becomes a new top digit.
  step pos carry acc =
    if pos >= width then if carry > 0 then acc :+ carry else acc
    else
      total = digitAt xs xCount pos + digitAt ys yCount pos + carry
      carry' = if total >= radix then 1 else 0
      step (pos + 1) carry' (acc :+ Nat.mod total radix)
  Natural (unsafeNonempty (step 0 0 []))
-- Converts a machine Nat into base-`radix` digits, least significant first.
-- Zero produces the single digit 0.
Natural.fromNat : Nat -> Natural
Natural.fromNat u =
  use Nat / ==
  peel remaining acc =
    quotient = remaining / radix
    acc' = acc :+ Nat.mod remaining radix
    if quotient == 0 then acc' else peel quotient acc'
  Natural (unsafeNonempty (peel u []))
-- Builds a Natural from a raw digit list by wrapping it with `mkNatural`
-- (an empty list becomes zero) and then passing it through `normalize`.
Natural.fromNats : [Nat] -> Natural
Natural.fromNats ds =
  unnormalized = Natural.internal.mkNatural ds
  Natural.internal.normalize unnormalized
-- Exposes the underlying nonempty digit list of a Natural.
Natural.internal.digits : Natural -> Nonempty Nat
Natural.internal.digits n = match n with
  Natural ds -> ds
-- Wraps a digit list as a Natural without normalizing it.
-- The empty list is mapped to the single digit 0.
Natural.internal.mkNatural : [Nat] -> Natural
Natural.internal.mkNatural nats = match nats with
  d +: ds -> Natural (d +| ds)
  [] -> Natural (Nonempty.singleton 0)
-- Brings a little-endian digit list into canonical form.
--
-- Digits are processed from index 0 upward. Any part of a digit at or above
-- 2^51 is carried into the next position, and a final nonzero carry becomes a
-- new most-significant digit. Most-significant zero digits are then removed;
-- an all-zero list becomes zero via `mkNatural`.
--
-- The previous version had three defects. It walked the tail from the right,
-- which reversed digits 2..n and sent carries the wrong way. It prepended the
-- final carry at the least-significant end. And it computed `next + carry`,
-- which could overflow for large digits.
-- NOTE: `lmask` and the shift of 51 assume `Natural.internal.radix` is 2^51.
Natural.internal.normalize : Natural -> Natural
Natural.internal.normalize =
  use Nat + ==
  lmask = 2251799813685247
  -- Drop trailing (most-significant) zero digits.
  trim ds = match ds with
    rest :+ d | d == 0 -> trim rest
    _ -> ds
  cases
    Natural ns ->
      go rem done carry = match rem with
        [] ->
          if carry == 0 then Natural.internal.mkNatural (trim done)
          else Natural.internal.mkNatural (done :+ carry)
        d +: rest ->
          -- Split `d` before adding the carry so nothing can wrap:
          -- the low part is < 2^51 and the carry is small.
          s = Nat.and lmask d + carry
          nextCarry = Nat.shiftRight d 51 + Nat.shiftRight s 51
          go rest (done :+ Nat.and lmask s) nextCarry
      go (Nonempty.toList ns) [] 0
-- Base of the digit representation: 2^51 (= 2251799813685248).
-- `normalize` hard-codes the matching mask (2^51 - 1) and shift (51).
Natural.internal.radix : Nat
Natural.internal.radix = Nat.shiftLeft 1 51
-- Runs `laws.associative` for `Natural.+` on 100 samples from `Natural.tests.gen`.
Natural.tests.additionAssociative : [Result]
Natural.tests.additionAssociative =
  runs 100 '(laws.associative Natural.tests.gen (Natural.+))
-- Runs `laws.commutative` for `Natural.+` on 100 samples from `Natural.tests.gen`.
Natural.tests.additionCommutative : [Result]
Natural.tests.additionCommutative =
  runs 100 '(laws.commutative Natural.tests.gen (Natural.+))
-- Checks that zero is a left and right identity for `Natural.+`.
Natural.tests.additionZero : [Result]
Natural.tests.additionZero =
  zero = Natural.fromNat 0
  runs 100 'let
    a = !Natural.tests.gen
    expect (((a Natural.+ zero) === a) && ((zero Natural.+ a) === a))
-- Generator of Naturals for the property tests. It offers values from two
-- ascending sequences: one starting at 0 and one starting at radix - 1, so
-- values on both sides of the one-digit/two-digit boundary are produced.
Natural.tests.gen : '{Gen} Natural
Natural.tests.gen =
  one = Natural.fromNat 1
  start = Natural.fromNat (Nat.drop radix 1)
  go low high =
    next = '(go (low Natural.+ one) (high Natural.+ one))
    yield low <|> weight 1 '(yield high <|> weight 1 next)
  '(Gen.sample (go (Natural.fromNat 0) start))
-- Runs `laws.commutative` for `Natural.*` on 100 samples from `Natural.tests.gen`.
Natural.tests.multiplicationCommutative : [Result]
Natural.tests.multiplicationCommutative =
  runs 100 '(laws.commutative Natural.tests.gen (Natural.*))
-- Checks that multiplying any generated Natural by zero yields zero.
Natural.tests.multiplicationZero : [Result]
Natural.tests.multiplicationZero =
  zero = Natural.fromNat 0
  runs 100 'let
    n = !Natural.tests.gen
    expect ((n Natural.* zero) === zero)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment