Created
July 8, 2016 12:10
-
-
Save tel/f93dafa81b11077125577c91eb05ca31 to your computer and use it in GitHub Desktop.
Scala Nimbers
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import scala.language.implicitConversions
/** A nimber: an integer equipped with nim-addition (bitwise XOR) and
  * nim-multiplication.
  *
  * The wrapped `value` is assumed to be non-negative. It is not validated
  * (a value class cannot run constructor checks), and negative values are
  * not meaningful here: for example, squaring `Nimber(-1)` never terminates.
  *
  * @param value the underlying integer representation
  */
class Nimber(val value: BigInt) extends AnyVal {

  /** The bit width at which this nimber is split into a high and a low half.
    *
    * For values of 4 or more this is the largest power of two `k` such that
    * `2^k <= value`; for every smaller value it is 1. The result is therefore
    * always a power of two.
    *
    * Computed with integer arithmetic only: a floating-point
    * `log(x) / log(2)` is not guaranteed to be exact when `x` is a power of
    * two, and flooring a result that lands just below the true value would
    * select the wrong width.
    */
  def size: Int =
    if (value <= 3) 1
    else Integer.highestOneBit(value.bitLength - 1)

  /** Decimal rendering of the underlying integer. */
  override def toString: String =
    value.toString

  /** Nim-sum: bitwise XOR of the two underlying values. */
  def +(other: Nimber): Nimber =
    Nimber(value ^ other.value)

  /** Nim-product; delegates to [[Nimber.multiply]]. */
  def *(other: Nimber): Nimber =
    Nimber.multiply(this, other)

  /** Splits this nimber at its own [[size]].
    *
    * @return `(high, low)` as described by `split(at)`
    */
  def split: (Nimber, Nimber) =
    split(size)

  /** Splits the underlying value at bit position `at`.
    *
    * @param at number of low-order bits that go into the second component
    * @return `(high, low)` where `high` is `value >> at` and `low` is the
    *         lowest `at` bits of `value`
    */
  def split(at: Int): (Nimber, Nimber) =
    (Nimber(value >> at), Nimber(value & ~(BigInt(-1) << at)))

  /** Nim-square of this nimber.
    *
    * 0 and 1 are their own squares. Otherwise, with `k = size` and
    * `this = high * 2^k + low`, the cross terms cancel under XOR and the
    * result is `join(k, high², low² + high² * 2^(k-1))`.
    */
  def square: Nimber =
    if (value == 0 || value == 1) this
    else {
      val k = size
      val (high, low) = split(k)
      val highSquared = high.square
      Nimber.join(k, highSquared, low.square + (highSquared * Nimber.bit(k - 1)))
    }
}

/** Constructors, implicit conversions and the multiplication routine for
  * [[Nimber]].
  */
object Nimber {

  /** Wraps a `BigInt` as a nimber; also available as an implicit conversion. */
  implicit def apply(value: BigInt): Nimber =
    new Nimber(value)

  /** Wraps an `Int` as a nimber; also available as an implicit conversion. */
  implicit def apply(value: Int): Nimber =
    Nimber(BigInt(value))

  /** The nimber whose only set bit is at `position`, i.e. `2^position`. */
  def bit(position: Int): Nimber =
    Nimber(BigInt(0).flipBit(position))

  /** Inverse of `split(at)`: places `left` above the low `at` bits taken by `right`.
    *
    * The halves are combined with bitwise OR, so `right` is expected to fit
    * in `at` bits.
    */
  def join(at: Int, left: Nimber, right: Nimber): Nimber =
    Nimber((left.value << at) | right.value)

  /** Nim-product of `a` and `b`.
    *
    * Trivial cases (a zero or one operand, or equal operands) are answered
    * directly. Otherwise both operands are split at `k`, the larger of their
    * two sizes, into `(aHigh, aLow)` and `(bHigh, bLow)`, and the result is
    * assembled from three recursive products:
    *
    *  - high half: `aLow*bLow + (aHigh + aLow) * (bHigh + bLow)`
    *  - low half:  `aLow*bLow + aHigh*bHigh * 2^(k-1)`
    */
  def multiply(a: Nimber, b: Nimber): Nimber =
    if (a.value == 0 || b.value == 0) Nimber(0)
    else if (a.value == 1) b
    else if (b.value == 1) a
    else if (a.value == b.value) a.square
    else {
      val k = Math.max(a.size, b.size)
      val (aHigh, aLow) = a.split(k)
      val (bHigh, bLow) = b.split(k)
      val highProduct = aHigh * bHigh
      val lowProduct = aLow * bLow
      join(
        k,
        lowProduct + (aHigh + aLow) * (bHigh + bLow),
        lowProduct + (highProduct * bit(k - 1))
      )
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment