Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save JadenGeller/f469776b12fed2d84d6c to your computer and use it in GitHub Desktop.
Save JadenGeller/f469776b12fed2d84d6c to your computer and use it in GitHub Desktop.
Swift Arbitrary Precision, Arbitrary Base Integers
// Use an integer literal to instantiate
let a: BigInt = 23559821412349283
// Or use a string literal if you want to start with a number that is greater than IntMax
let b: BigInt = "123456789876543234567876543234567876543"
// Perform arithmetic operations
let c: BigInt = 12456 + 321 // -> 12777
let d = c / 100 // -> 127
let e = d << d // -> 12700000000......(127 zeros)
let f: BigInt = 2 ** 2 ** 100 // -> 1606938044258990275541962092341162602522202993782792835301376
let g = (f % 2, f % 3, f % 4) // -> (0, 1, 0)
// Note that BigInt conforms to IntegerArithmeticType and a bunch of other protocols, so you can do stuff like
// for x in 0...b { println(x) }
// which will take forever and ever to run fyi.
// Huge decimal numbers aren't the only thing here.
// Turns out that BigInt is really just a typealias for Integer<Decimal>
// That's right, there are other bases too!
let h: Integer<Binary> = 1000110110101 + 1010101010110 // -> 10011100001011
let i: Integer<Binary> = h / 10001 + 100 * 1001 ** 10 // -> 11101011011
let j: Integer<Hexadecimal> = "ABCDEF" + 12345 // -> ACF134
let k = 2 * j - "A" * "B" * "-C3" + "E" // -> 15A3640
// That's pretty cool, right? -- There's more!
// Some Integer<T> can very easily be converted into some Integer<U>.
// That's right, the numbers can be automatically converted between bases!
let l: Integer<Decimal> = 13 // -> 13
let m: Integer<Binary> = l.convertBase() // -> 1101
let n: Integer<Hexadecimal> = m.convertBase() // -> D
// You can convert between any arbitrary defined base to any other base.
// The convertBase function automatically determines what base to convert to based off context,
// which is pretty damn cool! (This is a feature of Swift's called return value overloading and
// it is made possible by Swift's usage of unification to resolve types.)
// A few more examples.
let o: Integer<Hexadecimal> = (m * 1001101).convertBase() + (l ** 2).convertBase() // -> 492
let p = o * j * h.convertBase() // -> 788B8EF8B438
// You might be wondering how each of these bases is implemented, if there are any more, and if you can easily add more.
// Well, the answer is pretty cool.
// Base is a protocol defined as...
protocol Base {
static var radix: UInt8 { get }
static var representations: [(UInt8, Character)] { get }
}
// That's right, to create a new base you simply need to specify the radix (decimal->10, binary->2, etc.)
// and the representations (aka how to draw each number; 10->"A" in hexadecimal, for example)
// Check out the implementations of the built in bases below:
struct Decimal : Base {
static let radix: UInt8 = 10
static let representations = Array(digitRepresentations(0..<10))
}
struct Binary : Base {
static let radix: UInt8 = 2
static let representations = Array(digitRepresentations(0..<2))
}
struct Hexadecimal : Base {
static let radix: UInt8 = 16
static let representations = Array(digitRepresentations(0..<10) + [10: "A", 11: "B", 12: "C", 13: "D", 14: "E", 15: "F"])
}
// Note that digitRepresentations is a function that creates a sequence of tuples (x,"x") for a range of numbers.
// We could have easily specified the representations another way (like with a dictionary).
// Consider using BigInt with my Swift fraction library in order to get express arbitrarily precise real values!
// https://gist.github.com/JadenGeller/5e80ebf32442acc62e8e
// Generators
/// A zip over two sequences that keeps producing pairs until BOTH
/// underlying sequences are exhausted, padding the shorter side with nil.
struct WeakZip2<S0: SequenceType, S1: SequenceType> : SequenceType {
let s0: S0
let s1: S1
/// Starts a fresh traversal over both wrapped sequences.
func generate() -> WeakZipGenerator2<S0.Generator, S1.Generator> {
let first = s0.generate()
let second = s1.generate()
return WeakZipGenerator2(e0: first, e1: second)
}
}
/// Generator backing WeakZip2: emits (left?, right?) pairs and stops
/// only once both element streams have run dry.
struct WeakZipGenerator2<E0: GeneratorType, E1: GeneratorType> : GeneratorType {
var e0: E0
var e1: E1
mutating func next() -> (E0.Element?, E1.Element?)? {
let left = e0.next()
let right = e1.next()
if left == nil && right == nil { return nil }
return (left, right)
}
}
/// Convenience constructor for WeakZip2, mirroring the stdlib `zip`.
func weakZip<S0 : SequenceType, S1 : SequenceType>(s0: S0, s1: S1) -> WeakZip2<S0, S1> {
let zipped = WeakZip2(s0: s0, s1: s1)
return zipped
}
/// Lazily concatenates two sequences with matching element types.
struct ConcatenatedSequence<S0: SequenceType, S1: SequenceType where S0.Generator.Element == S1.Generator.Element> : SequenceType {
let s0: S0
let s1: S1
/// Starts a traversal that will drain s0 first and then s1.
func generate() -> ConcatenatedGenerator2<S0.Generator, S1.Generator> {
let first = s0.generate()
let second = s1.generate()
return ConcatenatedGenerator2(e0: first, e1: second)
}
}
/// Generator backing ConcatenatedSequence: yields every element of e0,
/// then every element of e1.
struct ConcatenatedGenerator2<E0: GeneratorType, E1: GeneratorType where E0.Element == E1.Element> : GeneratorType {
var e0: E0
var e1: E1
mutating func next() -> E0.Element? {
if let element = e0.next() { return element }
return e1.next()
}
}
/// `lhs + rhs` yields lhs's elements followed by rhs's elements.
func +<S0: SequenceType, S1: SequenceType where S0.Generator.Element == S1.Generator.Element>(lhs: S0, rhs: S1) -> ConcatenatedSequence<S0, S1> {
let joined = ConcatenatedSequence(s0: lhs, s1: rhs)
return joined
}
// Digit
/// Converts a numeric digit value into its printed character form,
/// e.g. 7 -> "7". Values above 9 produce a multi-character string and
/// therefore trap in Character's single-cluster initializer.
extension Character {
init(digit: UInt8) {
let text = String(digit)
self.init(text)
}
}
/// Pairs each value in `range` with its textual character, producing
/// tuples such as (7, "7") for use as a base's representation table.
func digitRepresentations(range: Range<UInt8>) -> Zip2<Range<UInt8>, [Character]> {
let characters = map(range, { value in Character(digit: value) })
return zip(range, characters)
}
/// Describes a positional numeral system.
protocol Base {
/// Number of distinct digit values (decimal -> 10, binary -> 2, ...).
static var radix: UInt8 { get }
/// Maps each digit value to the character used to print and parse it.
static var representations: [(UInt8, Character)] { get }
}
/// Base ten: digit values 0-9 rendered as "0"..."9".
struct Decimal : Base {
static let radix: UInt8 = 10
static let representations = Array(digitRepresentations(0..<10))
}
/// Base two: digit values 0-1 rendered as "0" and "1".
struct Binary : Base {
static let radix: UInt8 = 2
static let representations = Array(digitRepresentations(0..<2))
}
/// Base sixteen: 0-9 plus "A"-"F" for values ten through fifteen.
struct Hexadecimal : Base {
static let radix: UInt8 = 16
// The dictionary literal is appended via the sequence `+` defined above;
// the dictionary's iteration order doesn't matter since representations
// are only searched by value/character, never indexed positionally.
static let representations = Array(digitRepresentations(0..<10) + [10: "A", 11: "B", 12: "C", 13: "D", 14: "E", 15: "F"])
}
/// A single digit of base B: a value in 0..<B.radix plus digit-level
/// arithmetic with explicit carry/borrow, which Integer builds on.
struct Digit<B : Base> : IntegerArithmeticType, Printable, IntegerLiteralConvertible, Hashable, BidirectionalIndexType {
// Numeric value of the digit; invariant: backing < B.radix.
private let backing: UInt8
/// Creates a digit from its numeric value; asserts value < radix.
init(_ value: UInt8) {
assert(value < B.radix, "Digit values must be less than the radix")
self.backing = value
}
/// Creates a digit from its printed character (e.g. "A" -> 10 in
/// hexadecimal); traps on a character the base does not define.
init(_ representation: Character) {
if let digit = B.representations.filter({ (value, aRepresentation) in aRepresentation == representation }).first {
self.init(digit.0)
}
else {
fatalError("Digit value must be valid character for the base.")
}
}
init(integerLiteral value: UInt8) {
self.init(value)
}
/// The zero digit, used as the neutral carry/borrow.
static var zero: Digit<B> {
get {
return Digit(0)
}
}
// Radix complement (radix - backing); lets subtraction be computed as
// an addition in subtractWithDigitUnderflow below.
// NOTE(review): spelled "compliment" in the original source.
private var compliment: UInt8 {
get {
return B.radix - backing
}
}
// BidirectionalIndexType: wrap via the overflow-discarding helpers.
func successor() -> Digit<B> {
return Digit.addWithOverflow(self, Digit(1)).0
}
func predecessor() -> Digit<B> {
return Digit.subtractWithOverflow(self, Digit(1)).0
}
/// lhs + rhs + carry, split into (sum digit, carry-out digit).
static func addWithDigitOverflow(lhs: Digit<B>, _ rhs: Digit<B>, carry: Digit<B>) -> (Digit<B>, overflow: Digit<B>) {
let sum = lhs.backing + rhs.backing + carry.backing
return (Digit(sum % B.radix), Digit(sum / B.radix))
}
/// lhs - rhs - borrow, computed via radix complements:
/// sum = lhs + (radix - rhs) + (radix - borrow) = lhs - rhs - borrow + 2*radix.
/// The quotient sum / radix is 2 when no borrow out is needed and 1 when
/// one is, so |quotient - 2| yields the borrow-out digit (0 or 1).
static func subtractWithDigitUnderflow(lhs: Digit<B>, _ rhs: Digit<B>, borrow: Digit<B>) -> (Digit<B>, underflow: Digit<B>) {
let sum = lhs.backing + rhs.compliment + borrow.compliment
var digitValue = sum % B.radix
// NOTE(review): dead check — digitValue is an unsigned UInt8 and can
// never be negative.
if digitValue < 0 { digitValue += B.radix }
var underflow = sum / B.radix // overflow
if underflow >= 2 { underflow -= 2 } // underflow = abs(overflow - 2)
else { underflow = 2 - underflow }
return (Digit(digitValue), Digit(underflow))
}
/// lhs * rhs, split into (product digit, carry-out digit).
/// NOTE(review): the raw product is computed in UInt8, so a radix above
/// 16 could overflow ((radix-1)^2 > 255) — confirm if larger bases are
/// ever defined.
static func multiplyWithDigitOverflow(lhs: Digit<B>, _ rhs: Digit<B>) -> (Digit<B>, overflow: Digit<B>) {
let product = lhs.backing * rhs.backing
return (Digit(product % B.radix), Digit(product / B.radix))
}
// The Bool-flag variants below satisfy IntegerArithmeticType by
// collapsing the digit-valued carry/borrow into "was it nonzero".
static func addWithOverflow(lhs: Digit<B>, _ rhs: Digit<B>) -> (Digit<B>, overflow: Bool) {
let (digit, overflow) = addWithDigitOverflow(lhs, rhs, carry: Digit.zero)
return (digit, !overflow.isZero)
}
static func subtractWithOverflow(lhs: Digit<B>, _ rhs: Digit<B>) -> (Digit<B>, overflow: Bool) {
// Note: the tuple label says "overflow" but this is the borrow-out.
let (digit, overflow) = subtractWithDigitUnderflow(lhs, rhs, borrow: Digit.zero)
return (digit, !overflow.isZero)
}
static func multiplyWithOverflow(lhs: Digit<B>, _ rhs: Digit<B>) -> (Digit<B>, overflow: Bool) {
let (digit, overflow): (Digit<B>, Digit<B>) = multiplyWithDigitOverflow(lhs, rhs)
return (digit, !overflow.isZero)
}
/// Digit division; traps on a zero rhs (UInt8 division). The extra
/// % B.radix is redundant since the quotient never exceeds lhs.
static func divideWithOverflow(lhs: Digit<B>, _ rhs: Digit<B>) -> (Digit<B>, overflow: Bool) {
let quotient = lhs.backing / rhs.backing
return (Digit(quotient % B.radix), false)
}
/// Digit remainder; traps on a zero rhs.
static func remainderWithOverflow(lhs: Digit<B>, _ rhs: Digit<B>) -> (Digit<B>, overflow: Bool) {
let quotient = lhs.backing % rhs.backing
return (Digit(quotient % B.radix), false)
}
var isZero: Bool {
get {
return self.backing == 0
}
}
/// Looks up the printed character for a digit value; force-unwraps, so
/// it traps if the base's representation table is missing the value.
static func representationForValue(value: UInt8) -> Character {
return B.representations.filter({ (aValue, representation) in aValue == value }).first!.1
}
var description: String {
get {
return String(Digit.representationForValue(self.backing))
}
}
var hashValue: Int {
get {
return Int(backing)
}
}
func toIntMax() -> IntMax {
return IntMax(backing)
}
}
// Digit operators: thin wrappers over the WithOverflow variants that
// discard the flag, giving wrapping (mod-radix) semantics for + and -.
func -<B>(lhs: Digit<B>, rhs: Digit<B>) -> Digit<B> {
return Digit<B>.subtractWithOverflow(lhs, rhs).0
}
func +<B>(lhs: Digit<B>, rhs: Digit<B>) -> Digit<B> {
return Digit<B>.addWithOverflow(lhs, rhs).0
}
func *<B>(lhs: Digit<B>, rhs: Digit<B>) -> Digit<B> {
return Digit<B>.multiplyWithOverflow(lhs, rhs).0
}
// Division and remainder trap on a zero rhs (see divideWithOverflow).
func /<B>(lhs: Digit<B>, rhs: Digit<B>) -> Digit<B> {
return Digit<B>.divideWithOverflow(lhs, rhs).0
}
func %<B>(lhs: Digit<B>, rhs: Digit<B>) -> Digit<B> {
return Digit<B>.remainderWithOverflow(lhs, rhs).0
}
// Digits compare by their numeric value.
func ==<B>(lhs: Digit<B>, rhs: Digit<B>) -> Bool {
return lhs.backing == rhs.backing
}
func <<B>(lhs: Digit<B>, rhs: Digit<B>) -> Bool {
return lhs.backing < rhs.backing
}
// Integer
/// The sign carried by an Integer. Stored separately from the digit
/// magnitude, so zero may carry either sign.
enum IntegerSign {
case Positive
case Negative
/// The flipped sign.
var opposite: IntegerSign {
if self == .Positive { return .Negative }
return .Positive
}
/// Sign of a product: matching operand signs give Positive.
static func multiply(lhs: IntegerSign, rhs: IntegerSign) -> IntegerSign {
switch (lhs, rhs) {
case (.Positive, .Positive), (.Negative, .Negative): return .Positive
default: return .Negative
}
}
}
/// An arbitrary-precision, arbitrary-base signed integer.
///
/// Digits are stored least-significant first in `backing`, with leading
/// (most-significant) zero digits stripped, so the value zero is the
/// empty digit array. The sign is stored separately; since zero may
/// carry either sign, `isNegative`/`isPositive` also require `!isZero`.
struct Integer<B : Base> : Printable, DebugPrintable, IntegerLiteralConvertible, Strideable, SignedNumberType, SequenceType, IntegerArithmeticType, Hashable, BidirectionalIndexType, StringLiteralConvertible {
// Little-endian digit storage: backing[0] is the least significant digit.
private let backing: [Digit<B>]
let sign: IntegerSign
/// Absolute value: same digits with a positive sign.
var magnitude: Integer<B> {
get {
return Integer(backing: backing, sign: .Positive)
}
}
/// Strictly negative (zero is neither negative nor positive).
var isNegative: Bool {
get {
return sign == .Negative && !isZero
}
}
/// Strictly positive (zero is neither negative nor positive).
var isPositive: Bool {
get {
return sign == .Positive && !isZero
}
}
/// The radix of base B expressed as an Integer<B>.
static var radix: Integer<B> {
get {
return Integer(IntMax(B.radix))
}
}
init(integerLiteral value: IntMax) {
self.init(value)
}
init(value: IntMax) {
self.init(value)
}
/// Builds the Integer from `value` treated as an ordinary (decimal)
/// quantity, properly converted into base-B digits. Contrast with
/// `init(_:)` below, which reuses the literal's decimal digits verbatim.
init(var decimalValue value: IntMax) {
let radix = IntMax(B.radix)
let sign: IntegerSign = value < 0 ? .Negative : .Positive
value = abs(value)
var num = ""
while value > 0 {
// Emit digits least significant first, inserting at the front so the
// resulting string reads most significant first.
let digit = Digit<B>.representationForValue(UInt8(value % radix))
num.insert(digit, atIndex: num.startIndex)
value = value / radix
}
self.init(string: num, sign: sign)
}
/// Parses a run of base-B digit characters, most significant first.
/// Traps (via Digit's initializer) on a character invalid for the base.
init(string: String, sign: IntegerSign) {
self.init(backing: Array(reverse(string).map({ num in Digit(num) })), sign: sign)
}
/// String literals may start with "-" to denote a negative value.
init(var stringLiteral value: String) {
if value[value.startIndex] == "-" {
value.removeAtIndex(value.startIndex)
self.init(string: value, sign: .Negative)
}
else {
self.init(string: value, sign: .Positive)
}
}
init(extendedGraphemeClusterLiteral value: String) {
self.init(stringLiteral: value)
}
init(unicodeScalarLiteral value: String) {
self.init(stringLiteral: value)
}
/// Reinterprets the decimal digits of `value` as base-B digits, so
/// `Integer<Binary>(1101)` is thirteen. Traps if a decimal digit is not
/// a valid digit of base B (e.g. Integer<Binary>(5)).
init(var _ value: IntMax) {
// Radix 10 because the user types all numbers (even those in other bases) in decimal
// as interpreted by Swift
self.init(string: String(abs(value), radix: 10), sign: value < 0 ? .Negative : .Positive)
}
/// Designated initializer: strips leading zero digits so equality can
/// compare backing arrays directly.
private init(var backing: [Digit<B>], sign: IntegerSign) {
// Remove leading zeros
while let msb = backing.last where msb.isZero { backing.removeLast() }
self.backing = backing
self.sign = sign
}
// MARK: - Strideable / index conformances
func advancedBy(value: Integer<B>) -> Integer<B> {
return Integer.addWithOverflow(self, value).0
}
func distanceTo(other: Integer<B>) -> Integer<B> {
return Integer.subtractWithOverflow(other, self).0
}
func successor() -> Integer<B> {
return self.advancedBy(Integer(1))
}
func predecessor() -> Integer<B> {
return self.advancedBy(Integer(-1))
}
// MARK: - Digit shifting
/// Positive counts shift right (drop low digits); negative counts shift
/// left (append low zero digits); zero leaves the value unchanged.
func shift(count: Integer<B>) -> Integer<B> {
if count.isPositive { return shiftRight(count) }
else if count.isNegative { return shiftLeft(count.negated) }
else { return self }
}
/// Drops `count` least-significant digits — division by radix^count.
func shiftRight(count: Integer<B>) -> Integer<B> {
assert(count > 0, "Cannot shift right by a non-positive count")
var backing = self.backing
for _ in 1...count {
if backing.count == 0 { break }
backing.removeAtIndex(0)
}
return Integer(backing: backing, sign: sign)
}
/// Prepends `count` zero digits — multiplication by radix^count.
func shiftLeft(count: Integer<B>) -> Integer<B> {
assert(count > 0, "Cannot shift left by a non-positive count")
var backing = self.backing
for _ in 1...count {
backing.insert(Digit.zero, atIndex: 0)
}
return Integer(backing: backing, sign: sign)
}
var negated: Integer<B> {
get {
return Integer(backing: backing, sign: sign.opposite)
}
}
var isZero: Bool {
get {
// Empty backing folds to true; any nonzero digit yields false.
return backing.reduce(true, combine: { lhs, rhs in lhs && rhs.isZero })
}
}
var description: String {
get {
// Digits are stored least significant first, so each digit is
// prepended while folding to print most significant first.
var str = backing.reduce("", combine: { lhs, rhs in rhs.description + lhs })
if str == "" { str = "0" }
else if isNegative { str = "-" + str }
return str
}
}
var debugDescription: String {
get {
return description
}
}
var hashValue: Int {
get {
// Wrapping sum of digit hashes. The sign is not mixed in, so x and
// -x collide — allowed, since unequal values may share a hash.
return reduce(self, 0, { total, digit in total &+ digit.hashValue })
}
}
/// Iterates the digits least significant first.
func generate() -> GeneratorOf<Digit<B>> {
return GeneratorOf(backing.generate())
}
/// NOTE(review): unimplemented — always traps. Implementing it needs
/// overflow checking while folding digits into an IntMax.
func toIntMax() -> IntMax {
fatalError("To implement")
}
// MARK: - Arithmetic
/// Signed addition. Arbitrary precision never overflows; the flag only
/// satisfies the IntegerArithmeticType requirement and is always false.
static func addWithOverflow(lhs: Integer<B>, _ rhs: Integer<B>) -> (Integer<B>, overflow: Bool) {
switch (lhs.isNegative, rhs.isNegative) {
case (true, true): return (addWithOverflow(rhs.negated, lhs.negated).0.negated, false)
case (false, true): return (subtractWithOverflow(lhs, rhs.negated).0, false)
case (true, false): return (subtractWithOverflow(rhs, lhs.negated).0, false)
default:
// Both operands are non-negative here: schoolbook addition with a
// rippling carry, padding the shorter operand with zero digits.
// (An unreachable re-check of rhs.isNegative was removed.)
var backing = [Digit<B>]()
var carry = Digit<B>.zero
for (l, r) in weakZip(lhs, rhs) {
if l == nil && r == nil { break }
let (digit, newCarry) = Digit.addWithDigitOverflow(l ?? Digit.zero, r ?? Digit.zero, carry: carry)
carry = newCarry
backing.append(digit)
}
if !carry.isZero { backing.append(carry) }
return (Integer(backing: backing, sign: .Positive), false)
}
}
/// Signed subtraction, reduced to non-negative lhs >= rhs >= 0 before
/// doing schoolbook digit subtraction with a rippling borrow.
static func subtractWithOverflow(lhs: Integer<B>, _ rhs: Integer<B>) -> (Integer<B>, overflow: Bool) {
if rhs.isNegative { return (addWithOverflow(lhs, rhs.negated).0, false) }
if lhs < rhs { return (subtractWithOverflow(rhs, lhs).0.negated, false) }
var backing = [Digit<B>]()
var borrow = Digit<B>.zero
for (l, r) in weakZip(lhs, rhs) {
if l == nil && r == nil { break }
let (digit, newBorrow) = Digit.subtractWithDigitUnderflow(l ?? Digit.zero, r ?? Digit.zero, borrow: borrow)
borrow = newBorrow
backing.append(digit)
}
return (Integer(backing: backing, sign: .Positive), false)
}
/// Multiplies a multi-digit magnitude by a single digit (one row of the
/// schoolbook multiplication); the result is always non-negative.
private static func multiply(lhs: Integer<B>, _ rhs: Digit<B>) -> Integer<B> {
var backing = [Digit<B>]()
var carry = Digit<B>.zero
for n in lhs {
let (rawDigit, newCarryMutiplication) = Digit.multiplyWithDigitOverflow(n, rhs)
let (digit, newCarryAddition) = Digit.addWithDigitOverflow(rawDigit, carry, carry: 0)
carry = newCarryMutiplication + newCarryAddition
backing.append(digit)
}
if !carry.isZero { backing.append(carry) }
return Integer(backing: backing, sign: .Positive)
}
/// Schoolbook multiplication: one shifted row per digit of rhs, with
/// the sign recomputed from the operands' signs at the end.
static func multiplyWithOverflow(lhs: Integer<B>, _ rhs: Integer<B>) -> (Integer<B>, overflow: Bool) {
var total: Integer = 0
for (index, digit) in enumerate(rhs) {
if digit == 0 { continue } // optimization
// `<<` shifts left by `index` digits, weighting the row by radix^index.
total += multiply(lhs, digit) << Integer(decimalValue: IntMax(index))
}
return (Integer(backing: total.backing, sign: IntegerSign.multiply(lhs.sign, rhs: rhs.sign)), false)
}
/// Division by repeated subtraction. Assumes a non-negative numerator
/// and a positive denominator; used only on small windows by longDivide.
static func slowDivide(var numerator: Integer<B>, _ denominator: Integer<B>) -> (quotient: Integer<B>, remainder: Integer<B>) {
var count: Integer<B> = 0
while !numerator.isNegative {
numerator -= denominator
count++
}
// We went over once it became negative, so backtrack
count--
numerator += denominator
return (count, numerator)
}
/// Long division: slides a denominator-sized window down from the
/// numerator's most significant digits, divides each window with
/// slowDivide, and recurses on the reduced numerator.
/// NOTE(review): operates on digit magnitudes with .Positive signs, so
/// quotient/remainder signs for negative operands are not handled here.
static func longDivide(numerator: Integer<B>, _ denominator: Integer<B>) -> (quotient: Integer<B>, remainder: Integer<B>) {
var end = numerator.backing.endIndex
var start = end - denominator.backing.count
while start >= numerator.backing.startIndex {
let range = start..<end
let smallNum = Integer(backing: Array(numerator.backing[range]), sign: .Positive)
let (quotient, remainder) = slowDivide(smallNum, denominator)
if !quotient.isZero {
// We can divide!
var numeratorBacking = numerator.backing
// Replace the window's digits with the remainder (the slice may
// shrink the array when the remainder is shorter than the window).
numeratorBacking[range] = remainder.backing[0..<(remainder.backing.count)]
let newNumerator = Integer(backing: numeratorBacking, sign: .Positive)
// Weight the window quotient by its digit position.
let value = quotient << Integer(decimalValue: IntMax(start))
if newNumerator.isZero {
return (value, 0)
}
else {
let (recursiveQuotient, recursiveRemainder) = longDivide(newNumerator, denominator)
return (value + recursiveQuotient, recursiveRemainder)
}
}
start--
}
return (0, numerator)
}
/// Truncating division; traps on a zero divisor.
static func divideWithOverflow(lhs: Integer<B>, _ rhs: Integer<B>) -> (Integer<B>, overflow: Bool) {
if rhs.isZero { fatalError("division by zero") }
return (longDivide(lhs, rhs).quotient, false)
}
/// Remainder of truncating division; traps on a zero divisor.
static func remainderWithOverflow(lhs: Integer<B>, _ rhs: Integer<B>) -> (Integer<B>, overflow: Bool) {
if rhs.isZero { fatalError("division by zero") }
return (longDivide(lhs, rhs).remainder, false)
}
/// Re-expresses this value in another base T (the target is inferred
/// from the call site) by positionally evaluating the digits in the
/// target base.
/// Fixed: the sign is now preserved — previously negative values
/// converted to their positive magnitude.
func convertBase<T>() -> Integer<T> {
var num: Integer<T> = 0
var multiplier: Integer<T> = 1
let radix = Integer<T>(decimalValue: IntMax(B.radix))
for digit in backing {
num += Integer<T>(decimalValue: IntMax(digit.backing)) * multiplier
multiplier *= radix
}
return isNegative ? num.negated : num
}
}
/// Sign-aware equality for arbitrary-precision integers.
///
/// Fixed: the original compared only the digit backing, so -x == x was
/// true. `isNegative` is compared rather than `sign` so that "+0" and
/// "-0" (empty backing with either stored sign) remain equal.
func ==<B>(lhs: Integer<B>, rhs: Integer<B>) -> Bool {
return lhs.backing == rhs.backing && lhs.isNegative == rhs.isNegative
}
/// Sign-aware less-than for arbitrary-precision integers.
func <<B>(lhs: Integer<B>, rhs: Integer<B>) -> Bool {
switch (lhs.isNegative, rhs.isNegative) {
// Both negative: lhs < rhs exactly when |lhs| > |rhs|.
// Fixed: the original `!(lhs.negated < rhs.negated)` wrongly returned
// true for equal negative values.
case (true, true): return rhs.negated < lhs.negated
case (true, false): return true
case (false, true): return false
default:
// Both non-negative (leading zeros are stripped, so fewer digits
// always means a smaller magnitude).
if lhs.backing.count < rhs.backing.count { return true }
else if lhs.backing.count > rhs.backing.count { return false }
else {
// Same length; compare digits most significant first.
for (l, r) in zip(reverse(lhs.backing), reverse(rhs.backing)) {
if l < r { return true }
else if l > r { return false }
}
return false
}
}
}
/// Unary minus: returns the value with its sign flipped.
prefix func -<B>(value: Integer<B>) -> Integer<B> {
return value.negated
}
// Binary arithmetic operators: wrappers over the WithOverflow variants
// (which never actually overflow for arbitrary precision).
func -<B>(lhs: Integer<B>, rhs: Integer<B>) -> Integer<B> {
return Integer.subtractWithOverflow(lhs, rhs).0
}
func +<B>(lhs: Integer<B>, rhs: Integer<B>) -> Integer<B> {
return Integer.addWithOverflow(lhs, rhs).0
}
func *<B>(lhs: Integer<B>, rhs: Integer<B>) -> Integer<B> {
return Integer.multiplyWithOverflow(lhs, rhs).0
}
// Division and remainder trap (fatalError) on a zero divisor.
func /<B>(lhs: Integer<B>, rhs: Integer<B>) -> Integer<B> {
return Integer.divideWithOverflow(lhs, rhs).0
}
func %<B>(lhs: Integer<B>, rhs: Integer<B>) -> Integer<B> {
return Integer.remainderWithOverflow(lhs, rhs).0
}
/// Left shift appends rhs low-order zero digits (multiplies by
/// radix^rhs); `shift` treats negative counts as left shifts, hence -rhs.
func <<<B>(lhs: Integer<B>, rhs: Integer<B>) -> Integer<B> {
return lhs.shift(-rhs)
}
/// Right shift drops rhs low-order digits (divides by radix^rhs).
func >><B>(lhs: Integer<B>, rhs: Integer<B>) -> Integer<B> {
return lhs.shift(rhs)
}
// Compound-assignment operators, defined in terms of the binary
// operators above (Swift 1.x does not synthesize them).
func -=<B>(inout lhs: Integer<B>, rhs: Integer<B>) {
lhs = lhs - rhs
}
func +=<B>(inout lhs: Integer<B>, rhs: Integer<B>) {
lhs = lhs + rhs
}
func *=<B>(inout lhs: Integer<B>, rhs: Integer<B>) {
lhs = lhs * rhs
}
func /=<B>(inout lhs: Integer<B>, rhs: Integer<B>) {
lhs = lhs / rhs
}
func %=<B>(inout lhs: Integer<B>, rhs: Integer<B>) {
lhs = lhs % rhs
}
func <<=<B>(inout lhs: Integer<B>, rhs: Integer<B>) {
lhs = lhs << rhs
}
func >>=<B>(inout lhs: Integer<B>, rhs: Integer<B>) {
lhs = lhs >> rhs
}
// C-style increment/decrement. Prefix forms return the updated value;
// postfix forms return the value held before the update.
prefix func --<B>(inout value: Integer<B>) -> Integer<B> {
value -= 1
return value
}
prefix func ++<B>(inout value: Integer<B>) -> Integer<B> {
value += 1
return value
}
postfix func --<B>(inout value: Integer<B>) -> Integer<B> {
let oldValue = value
value -= 1
return oldValue
}
postfix func ++<B>(inout value: Integer<B>) -> Integer<B> {
let oldValue = value
value += 1
return oldValue
}
// Exponentiation operator.
// NOTE(review): declared left-associative, so 2 ** 2 ** 100 parses as
// (2 ** 2) ** 100 = 4^100, not 2^(2^100) — most languages make
// exponentiation right-associative and bind it tighter than *.
infix operator ** {
associativity left
precedence 150
}
/// Raises lhs to the power |rhs| by repeated multiplication — O(|rhs|)
/// multiplies (exponentiation by squaring would be O(log |rhs|)).
/// For a negative exponent the integer reciprocal 1 / result is
/// returned, which truncates to zero unless result is 1.
func **<B>(lhs: Integer<B>, rhs: Integer<B>) -> Integer<B> {
var result: Integer<B> = 1
for i in 0..<abs(rhs) {
result *= lhs
}
return rhs >= 0 ? result : 1 / result
}
/// BigInt: an arbitrary-precision integer in ordinary base ten.
typealias BigInt = Integer<Decimal>
@axman6
Copy link

axman6 commented Sep 25, 2016

Your result for 2 ** 2 ** 100 is somewhat confusing — is ** really left associative? The result shown is 4^100, not 2^(2^100) as I think most people would expect. For reference, Haskell has ^, ^^ and ** as right associative.

@axman6
Copy link

axman6 commented Sep 25, 2016

Also, looks like you could do with using the O(log n) implementation of ** instead of the linear one. Haskell's implementation:

x ^ 0            = 1
x ^ n | n > 0    = f x (n-1) x
  where
    f _ 0 y = y
    f x n y = g x n
        where g x n | even n    = g (x*x) (n `quot` 2)
                    | otherwise = f x (n-1) (x*y)
_ ^ _            = error "Prelude.^: negative exponent"

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment