Skip to content

Instantly share code, notes, and snippets.

@lifthrasiir
Created April 18, 2011 20:37
Show Gist options
  • Save lifthrasiir/926113 to your computer and use it in GitHub Desktop.
Save lifthrasiir/926113 to your computer and use it in GitHub Desktop.
2^(10^k)
#include <stdio.h>
#include <stdlib.h>
#include <stdint.h>
#include <stdbool.h>
#include <string.h>
#include <math.h>
/* Limb type. Stored limbs are always in 0..BASE-1, but divideby2/divideby3
   temporarily add a scaled remainder (up to 2*BASE) to a limb before dividing,
   so the type must hold values up to 3*BASE-1 (about 3e9, which fits in 32
   unsigned bits). */
typedef uint32_t digit; // should be able to represent 0..3*BASE-1
/* Double-width accumulator for schoolbook multiplication: it holds
   limb*limb plus an existing limb plus the running carry. */
typedef uint64_t digit2; // should be as large as maxdigit + maxdigit * maxdigit
/* Numbers are stored in radix 10^9, so each limb is exactly DIGITSIZE decimal
   digits when printed. */
#define BASE 1000000000u
#define DIGITSIZE 9
/* Formats used by display(): every limb after the most significant one is
   zero-padded to 9 digits, and the leading limb is printed unpadded.
   NOTE(review): limbs are uint32_t but these use %d. Values stay below BASE,
   so they fit in an int; passing them through an (int) cast (or using
   PRIu32) makes the argument type match exactly. */
#define DIGITFMT " %09d"
#define FIRSTDIGITFMT "%d"
/* Sign-magnitude arbitrary-precision integer.
   - negative: sign flag. A zero value may carry either flag.
   - ndigits:  number of limbs in use. 0 represents the value zero
               (see multiply's zero check and truncate).
   - digits:   limbs in little-endian order; digits[0] is the least
               significant. Results of add/subtract/multiply own a malloc'd
               array that is released with destroy(). Some values instead
               borrow another array (multiply_toom3's slices, main's
               constant) and must not be destroyed. */
struct num {
bool negative;
int ndigits;
digit *digits;
};
/*
 * Release the limb array owned by x and reset x to an empty (zero) number.
 *
 * Resetting digits/ndigits means a stray second destroy() is a harmless
 * free(NULL), and a later read sees a zero value instead of a dangling
 * pointer. Only call this on numbers that own their limbs (results of
 * add/subtract/multiply/power). Do not call it on borrowed views such as
 * multiply_toom3's slices or a compound-literal constant.
 */
void destroy(struct num *x) {
    free(x->digits);
    x->digits = NULL;
    x->ndigits = 0;
}
/*
 * Normalize x by discarding high-order zero limbs, so that ndigits counts
 * only significant limbs; a zero value ends up with ndigits == 0.
 * The limb array is not reallocated. Returns x so calls can be chained.
 */
const struct num *truncate(struct num *x) {
    int used = x->ndigits;
    while (used > 0 && x->digits[used - 1] == 0) {
        --used;
    }
    x->ndigits = used;
    return x;
}
const struct num *subtract(const struct num *x, const struct num *y, struct num *z);
const struct num *add(const struct num *x, const struct num *y, struct num *z) {
if (x->negative && !y->negative) {
struct num minusx = {.negative=false, .ndigits=x->ndigits, .digits=x->digits};
return subtract(y, &minusx, z);
} else if (!x->negative && y->negative) {
struct num minusy = {.negative=false, .ndigits=y->ndigits, .digits=y->digits};
return subtract(x, &minusy, z);
}
if (x->ndigits < y->ndigits) {
const struct num *t = x; x = y; y = t;
}
int xn = x->ndigits, yn = y->ndigits;
digit *xv = x->digits, *yv = y->digits, *zv;
int i;
z->negative = x->negative; // == y->negative
z->ndigits = xn + 1;
zv = z->digits = malloc(z->ndigits * sizeof(digit));
digit carry = 0;
for (i = 0; i < yn; ++i) {
zv[i] = xv[i] + yv[i] + carry;
if (zv[i] >= BASE) {
zv[i] -= BASE;
carry = 1;
} else {
carry = 0;
}
}
for (; i < xn; ++i) {
zv[i] = xv[i] + carry;
if (zv[i] >= BASE) {
zv[i] -= BASE;
carry = 1;
} else {
carry = 0;
}
}
zv[xn] = carry;
return truncate(z);
}
const struct num *subtract(const struct num *x, const struct num *y, struct num *z) {
if (y->negative) {
struct num minusy = {.negative=false, .ndigits=y->ndigits, .digits=y->digits};
if (x->negative) {
struct num minusx = {.negative=false, .ndigits=x->ndigits, .digits=x->digits};
return subtract(&minusy, &minusx, z); // tail recursive, maybe?
} else {
return add(x, &minusy, z);
}
} else if (x->negative) {
struct num minusy = {.negative=true, .ndigits=y->ndigits, .digits=y->digits};
return add(x, &minusy, z); // (-x)+(-y) is built-in to add
}
// make x->ndigits >= y->ndigits at least, even when x < y
bool swapped = (x->ndigits < y->ndigits);
if (swapped) {
const struct num *t = x; x = y; y = t;
}
int xn = x->ndigits, yn = y->ndigits;
digit *xv = x->digits, *yv = y->digits, *zv;
int i;
z->negative = swapped;
z->ndigits = xn;
zv = z->digits = malloc(z->ndigits * sizeof(digit));
digit borrow = BASE;
for (i = 0; i < yn; ++i) {
zv[i] = xv[i] + borrow - yv[i]; // avoiding underflow
if (zv[i] < BASE) {
borrow = BASE-1;
} else {
zv[i] -= BASE;
borrow = BASE;
}
}
for (; i < xn; ++i) {
zv[i] = xv[i] + borrow;
if (zv[i] < BASE) {
borrow = BASE-1;
} else {
zv[i] -= BASE;
borrow = BASE;
}
}
if (borrow < BASE) { // negative
z->negative = !z->negative;
borrow = BASE;
for (i = 0; i < xn; ++i) {
zv[i] = borrow - zv[i];
if (zv[i] < BASE) {
borrow = BASE-1;
} else {
zv[i] -= BASE;
borrow = BASE;
}
}
}
return truncate(z);
}
/* BASEMULTIPLES[r] == r * BASE: a remainder carried down one limb position. */
static const digit BASEMULTIPLES[] = {0, BASE, BASE*2, BASE*3};
/*
 * Divide |x| by 2 in place, walking from the most significant limb down.
 * The caller relies on the division being exact. Leading zero limbs are
 * not trimmed. Returns x.
 */
const struct num *divideby2(struct num *x) {
    digit carry_in = 0; /* remainder from the limb above, already scaled by BASE */
    for (int k = x->ndigits; k-- > 0; ) {
        digit limb = x->digits[k] + carry_in;
        carry_in = (limb & 1u) ? BASE : 0;
        x->digits[k] = limb >> 1;
    }
    return x;
}
/*
 * Divide |x| by 3 in place, from the most significant limb down.
 * Each limb first absorbs the scaled remainder from the limb above
 * (at most 2*BASE, so the sum stays below 3*BASE). The caller relies on
 * the division being exact. Leading zero limbs are not trimmed. Returns x.
 */
const struct num *divideby3(struct num *x) {
    digit *limb = x->digits;
    digit rem_scaled = 0;
    int k = x->ndigits;
    while (k-- > 0) {
        digit cur = limb[k] + rem_scaled;
        rem_scaled = BASEMULTIPLES[cur % 3];
        limb[k] = cur / 3;
    }
    return x;
}
const struct num *multiply(const struct num *x, const struct num *y, struct num *z);
/*
 * z = x * y using one level of Toom-Cook 3-way splitting.
 *
 * Each operand is cut into three limb slices of b limbs:
 * x = x0 + x1*B^b + x2*B^(2b), and likewise for y. The slice polynomials are
 * evaluated at 0, 1, -1, -2 and infinity, and the five point products are
 * formed with multiply(). The product coefficients r0..r4 are then recovered
 * by interpolation, and z = sum r[i] * B^(i*b).
 *
 * The slices are non-negative views that borrow x's and y's limbs (no copies,
 * never destroyed). The sign of z is x->negative ^ y->negative.
 * z receives a freshly malloc'd array of xn + yn limbs, normalized with
 * truncate(). Returns z.
 */
const struct num *multiply_toom3(const struct num *x, const struct num *y, struct num *z) {
    int xn = x->ndigits, yn = y->ndigits;
    digit *xv = x->digits, *yv = y->digits, *zv;
    int i, j;
    int b = ((xn > yn ? xn : yn) + 2) / 3; // block size

    // Slices borrow the operands' limb arrays (sign defaults to non-negative).
    struct num x0 = {.ndigits=0, .digits=NULL}, y0 = {.ndigits=0, .digits=NULL};
    struct num x1 = {.ndigits=0, .digits=NULL}, y1 = {.ndigits=0, .digits=NULL};
    struct num x2 = {.ndigits=0, .digits=NULL}, y2 = {.ndigits=0, .digits=NULL};
    if (xn > 0) { x0.ndigits = (xn < b ? xn : b); x0.digits = xv; }
    if (xn > b) { x1.ndigits = (xn < b*2 ? xn-b : b); x1.digits = xv + b; }
    if (xn > b*2) { x2.ndigits = (xn < b*3 ? xn-b*2 : b); x2.digits = xv + b*2; }
    if (yn > 0) { y0.ndigits = (yn < b ? yn : b); y0.digits = yv; }
    if (yn > b) { y1.ndigits = (yn < b*2 ? yn-b : b); y1.digits = yv + b; }
    if (yn > b*2) { y2.ndigits = (yn < b*3 ? yn-b*2 : b); y2.digits = yv + b*2; }

    struct num px, py, pz, t1, t2;
    struct num q0, q1, qm1, qm2, qinf;
    struct num r[5];

    // p(0) = m0
    multiply(&x0, &y0, &q0);
    // p(1) = (m0 + m2) + m1
    add(&x0, &x2, &t1);
    add(&y0, &y2, &t2);
    add(&t1, &x1, &px);
    add(&t2, &y1, &py);
    multiply(&px, &py, &q1);
    destroy(&px);
    destroy(&py);
    // p(-1) = (m0 + m2) - m1 (subexpression reused from p(1))
    subtract(&t1, &x1, &px);
    subtract(&t2, &y1, &py);
    multiply(&px, &py, &qm1);
    destroy(&t1);
    destroy(&t2);
    // p(-2) = (p(-1) + m2) * 2 - m0 (p(-1) result reused)
    add(&px, &x2, &t1);
    destroy(&px);
    add(&t1, &t1, &t2);
    destroy(&t1);
    subtract(&t2, &x0, &px);
    destroy(&t2); // previously leaked: t2 is overwritten below
    add(&py, &y2, &t1);
    destroy(&py);
    add(&t1, &t1, &t2);
    destroy(&t1);
    subtract(&t2, &y0, &py);
    destroy(&t2); // previously leaked: t2 is overwritten below
    multiply(&px, &py, &qm2);
    destroy(&px);
    destroy(&py);
    // p(inf) = m2
    multiply(&x2, &y2, &qinf);

    // Interpolation (all divisions below are exact).
    // x = (r(1) - r(-1))/2
    subtract(&q1, &qm1, &px);
    divideby2(&px);
    // y = r(-1) - r(0)
    subtract(&qm1, &q0, &py);
    destroy(&qm1);
    // z = (r(-2) - r(1))/3
    subtract(&qm2, &q1, &pz);
    destroy(&qm2);
    destroy(&q1);
    divideby3(&pz);
    // r4 = r(inf)
    r[4] = qinf;
    // r3 = (y - z)/2 + 2*r(inf)
    subtract(&py, &pz, &t1);
    destroy(&pz);
    divideby2(&t1);
    add(&qinf, &qinf, &t2);
    add(&t1, &t2, &r[3]);
    destroy(&t1);
    destroy(&t2);
    // r2 = x + y - r(inf)
    add(&px, &py, &t1);
    destroy(&py);
    subtract(&t1, &qinf, &r[2]);
    destroy(&t1);
    // r1 = x - r3
    subtract(&px, &r[3], &r[1]);
    destroy(&px);
    // r0 = r(0)
    r[0] = q0;

    // Recombine z = sum r[i] * BASE^(i*b). Each r[i] is a non-negative
    // coefficient of the true product, so every shifted term and every
    // partial sum is at most |x*y| < BASE^(xn+yn). Hence xn + yn limbs always
    // suffice. The old size, r[4].ndigits + 4b + 1, overflowed when x2 was
    // non-empty but y2 was empty (e.g. xn=30, yn=20).
    z->negative = x->negative ^ y->negative;
    z->ndigits = xn + yn;
    zv = z->digits = malloc(z->ndigits * sizeof(digit));
    if (r[0].ndigits > 0) {
        memcpy(zv, r[0].digits, r[0].ndigits * sizeof(digit));
    }
    memset(zv + r[0].ndigits, 0, (z->ndigits - r[0].ndigits) * sizeof(digit));
    destroy(&r[0]);
    for (i = 1; i < 5; ++i) {
        int jj = i * b;
        digit *v = r[i].digits;
        digit carry = 0;
        for (j = 0; j < r[i].ndigits; ++j, ++jj) {
            zv[jj] += v[j] + carry;
            if (zv[jj] >= BASE) {
                zv[jj] -= BASE;
                carry = 1;
            } else {
                carry = 0;
            }
        }
        // Propagate the carry out of r[i]'s top limb (it used to be dropped).
        // The bound check is defensive: the partial sum always fits.
        for (; carry && jj < z->ndigits; ++jj) {
            zv[jj] += carry;
            if (zv[jj] >= BASE) {
                zv[jj] -= BASE;
                carry = 1;
            } else {
                carry = 0;
            }
        }
        destroy(&r[i]);
    }
    return truncate(z);
}
const struct num *multiply(const struct num *x, const struct num *y, struct num *z) {
if (x->ndigits == 0 || y->ndigits == 0) {
z->negative = false;
z->ndigits = 0;
z->digits = NULL;
return z;
} else if (x->ndigits + y->ndigits > 40) {
return multiply_toom3(x, y, z);
} else if (x->ndigits < y->ndigits) {
const struct num *t = x; x = y; y = t;
}
int xn = x->ndigits, yn = y->ndigits;
digit *xv = x->digits, *yv = y->digits, *zv;
int i, j;
z->negative = x->negative ^ y->negative;
z->ndigits = xn + yn;
zv = z->digits = malloc(z->ndigits * sizeof(digit));
memset(zv, 0, sizeof(digit) * xn);
for (i = 0; i < yn; ++i) {
digit2 cur = yv[i];
digit2 carry = 0;
for (j = 0; j < xn; ++j) {
carry += cur * xv[j];
carry += zv[i+j];
zv[i+j] = carry % BASE;
carry /= BASE;
}
zv[i+xn] = carry;
}
return truncate(z);
}
const struct num *power(const struct num *x, int y, struct num *z) {
int mask = y;
while (mask != (mask | (mask >> 1))) mask |= mask >> 1; // propagate all bits
mask = (mask >> 2) + 1; // should be the second most significant bit of y
int specialcase = 0;
if (x->ndigits == 1 && x->digits[0] == 2) specialcase = 2;
struct num xx = *x;
while (mask > 0) {
struct num xxxx;
multiply(&xx, &xx, &xxxx);
if (xx.digits != x->digits) destroy(&xx);
if (y & mask) {
switch (specialcase) {
case 2: add(&xxxx, &xxxx, &xx); break;
default: multiply(x, &xxxx, &xx); break;
}
destroy(&xxxx);
} else {
xx = xxxx;
}
mask >>= 1;
}
*z = xx;
return z;
}
void display(const struct num *x, int threshold) {
int i;
if (threshold <= 0 || x->ndigits <= threshold * 2) {
printf(FIRSTDIGITFMT, x->digits[x->ndigits-1]);
for (i = x->ndigits - 2; i >= 0; --i) {
printf(DIGITFMT, x->digits[i]);
}
} else {
printf(FIRSTDIGITFMT, x->digits[x->ndigits-1]);
for (i = 1; i < threshold; ++i) {
printf(DIGITFMT, x->digits[x->ndigits-1-i]);
}
printf("...(%d digits omitted)...", DIGITSIZE * (x->ndigits - threshold*2));
for (i = threshold - 1; i >= 0; --i) {
printf(DIGITFMT, x->digits[i]);
}
}
printf("\n");
}
/* Compute 2^(10^7) and print its first and last 5 limbs (45 digits each). */
int main(void) {
    /* Constant base 2; its single limb lives in a compound literal, not on the heap. */
    struct num two = {.ndigits=1, .digits=(digit[1]){2}};
    struct num result;
    power(&two, 10000000, &result);
    display(&result, 5);
    /* result owns heap limbs produced by power(); release them. */
    destroy(&result);
    return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment