Skip to content

Instantly share code, notes, and snippets.

@JAChapmanII
Created September 25, 2012 17:41
Show Gist options
  • Save JAChapmanII/3783355 to your computer and use it in GitHub Desktop.
Save JAChapmanII/3783355 to your computer and use it in GitHub Desktop.
integer.h
#include <limits.h>
#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include "gmp.h"
/* Values stored in integer.sign: positive is true, negative is false. */
#define positive true
#define negative false
/* Shorthand names used throughout this file. NOTE(review): these are macros, not typedefs. */
#define byte uint8_t
#define uint unsigned int
#define null NULL
/*
TODO:
* FIX MEMORY LEAKS
*/
/*
 * Arbitrary-precision decimal integer in sign-magnitude form.
 * digit[i] holds the decimal digit of weight 10^i, so digit[0] is the least
 * significant digit (printInteger prints from digit[length-1] down to digit[0]).
 */
typedef struct
{
byte *digit;    /* heap array of `length` digits; freed by freeInteger */
uint length;    /* number of entries in digit; may include leading zero digits */
bool sign;      /* positive (true) or negative (false) */
} integer;
/* Returns the larger of a and b (a when they are equal). */
int max(int a, int b)
{
    return (a >= b) ? a : b;
}
/* Logical exclusive-or: true when exactly one of a and b is true. */
bool xor(bool a, bool b)
{
    return a != b;
}
/*
 * Writes z to stdout in ordinary decimal notation (most significant digit first),
 * preceded by '-' when the sign is negative. No newline is printed.
 */
void printInteger(integer *z)
{
    if(z->sign == negative)
        printf("-");
    for(int idx = (int)z->length - 1; idx >= 0; idx--)
        printf("%d", z->digit[idx]);
}
/*
 * Releases an integer and its digit array.
 * Passing null prints a diagnostic instead of crashing.
 */
void freeInteger(integer *z)
{
    if(z != null)
    {
        free(z->digit);
        free(z);
        return;
    }
    printf("freeInteger failed\n");
}
/*
 * Frees the first `length` integers referenced by z, then the pointer array itself.
 * Passing a null array prints a diagnostic and does nothing else.
 */
void freeIntegerArray(integer **z, int length)
{
    if(z != null)
    {
        int idx = 0;
        while(idx < length)
        {
            freeInteger(z[idx]);
            ++idx;
        }
        free(z);
        return;
    }
    printf("freeIntegerArray failed\n");
}
integer* allocInteger(int digits)
{
integer *z = malloc(sizeof(integer));
if(z == null)
{
printf("allocInteger failed\n");
return (integer*)null;
}
z->digit = calloc(digits,sizeof(byte));
z->length = digits;
z->sign = positive;
return z;
}
/*
 * Allocates a zeroed array of `length` integer pointers (all null).
 * Returns null after printing a diagnostic if the allocation fails.
 */
integer** allocIntegerArray(int length)
{
    /* Each slot holds a pointer, so size elements by *z (integer*), not by the whole struct. */
    integer **z = calloc(length, sizeof *z);
    if(z == null)
    {
        printf("allocIntegerArray failed");
        return null;
    }
    return z;
}
/*
 * Resizes an array of integer pointers to hold `length` slots.
 * On success returns the (possibly moved) array. On failure, or for a negative
 * length, prints a diagnostic and returns null; the original array z is left
 * intact and still owned by the caller. A length of 0 frees z and returns null.
 */
integer** reallocIntegerArrary(integer **z, int length)
{
    if(length < 0)
    {
        printf("reallocIntegerArray failed");
        return null;
    }
    if(length == 0)
    {
        /* realloc(p, 0) is implementation-defined; make the "shrink to nothing" case explicit. */
        free(z);
        return null;
    }
    /* Never assign realloc's result straight to z: on failure that would leak the old block. */
    integer **resized = realloc(z, (size_t)length * sizeof *resized);
    if(resized == null)
    {
        printf("reallocIntegerArray failed");
        return null;
    }
    return resized;
}
/*
 * Returns a newly allocated deep copy of z (same digits, length and sign),
 * or null if the allocation fails.
 */
integer* copy(integer* z)
{
    integer *x = allocInteger(z->length);
    if(x == null)
        return null;
    for(uint i = 0; i < z->length; ++i)
        x->digit[i] = z->digit[i];
    x->sign = z->sign;
    return x;
}
/*
 * Returns a new copy of z with leading (most significant) zero digits removed.
 * At least one digit is always kept, so zero becomes the single digit 0. The sign
 * is copied. If freeMem is true, z is freed, even when the copy cannot be
 * allocated. Returns null if allocation fails.
 */
integer* trimInteger(integer *z, bool freeMem)
{
    uint n = z->length;
    while(n > 1 && z->digit[n - 1] == 0)
        --n;
    /* Only copy digits z actually has: a length-0 input has none. */
    uint kept = (n < z->length) ? n : z->length;
    if(n == 0)
        n = 1;
    integer *x = allocInteger(n);
    if(x != null)
    {
        for(uint i = 0; i < kept; ++i)
            x->digit[i] = z->digit[i];
        x->sign = z->sign;
    }
    if(freeMem)
        freeInteger(z);
    return x;
}
/*
 * Returns a new integer holding the lowest `length` digits of z and z's sign.
 * The caller supplies the length (no scan for leading zeros). Slots beyond
 * z->length stay 0 rather than being read out of bounds. If freeMem is true,
 * z is freed.
 */
integer* fastTrimInteger(integer *z, int length, bool freeMem)
{
    integer *x = allocInteger(length);
    if(x != null)
    {
        for(int i = 0; i < length && (uint)i < z->length; ++i)
            x->digit[i] = z->digit[i];
        /* Keep the sign, matching trimInteger. */
        x->sign = z->sign;
    }
    if(freeMem)
        freeInteger(z);
    return x;
}
/*
 * Builds a positive integer whose digit[i] is array[i]. array[0] is therefore
 * the least significant digit.
 */
integer* arrayToInteger(byte array[], int n)
{
    integer *z = allocInteger(n);
    int pos = 0;
    while(pos < n)
    {
        z->digit[pos] = array[pos];
        pos++;
    }
    return z;
}
/*
 * Builds a positive integer from digits given most significant first:
 * array[0] becomes digit[n-1] and array[n-1] becomes digit[0].
 */
integer* rarrayToInteger(byte array[], int n)
{
    integer *z = allocInteger(n);
    for(int src = 0; src < n; ++src)
        z->digit[(n - 1) - src] = array[src];
    return z;
}
integer* charArrayToInteger(char array[])
{
int n = strlen(array);
int i = 0;
int j = 0;
integer *z = allocInteger(n);
if(array[0] == '-')
{
++j;
z->sign = negative;
}
for(i, j; j < n; ++i, ++j)
{
z->digit[i] = (byte)(array[j] - '0');
}
return z;
}
integer* rcharArrayToInteger(char array[])
{
int n = strlen(array);
int i = n - 1;
int j = i;
integer *z = allocInteger(n);
if(array[0] == '-')
{
--j;
--i;
z->sign = negative;
}
for(i, j; j >= 0; --i, --j)
{
z->digit[i] = (byte)(array[(n - 1) - j] - '0');
}
return z;
}
/*
 * Returns a newly allocated, NUL-terminated string of z's digit characters in
 * storage order (digit[0] first, i.e. least significant first). The sign is not
 * emitted. The caller frees the string.
 */
char* integerToCharArray(integer *z)
{
    uint len = z->length;
    char *text = calloc(len + 1, sizeof(char));
    for(uint pos = 0; pos < len; pos++)
        text[pos] = (char)('0' + z->digit[pos]);
    text[len] = '\0';
    return text;
}
/*
 * Returns a newly allocated, NUL-terminated string of z's digits, most
 * significant first (ordinary decimal text). The sign is not emitted.
 * The caller frees the string.
 */
char* rintegerToCharArray(integer *z)
{
    uint len = z->length;
    char *text = calloc(len + 1, sizeof(char));
    uint src = len;
    for(uint pos = 0; pos < len; pos++)
    {
        --src;
        text[pos] = (char)('0' + z->digit[src]);
    }
    text[len] = '\0';
    return text;
}
/*
 * Like rintegerToCharArray (most significant digit first), but stores the raw
 * digit values 0-9 rather than the characters '0'-'9'. The caller frees the buffer.
 * NOTE(review): a 0 digit is byte value 0, so C string functions stop at the first zero digit.
 */
char* rintegerToCharArrayID(integer *z)
{
    uint len = z->length;
    char *raw = calloc(len + 1, sizeof(char));
    for(uint pos = 0; pos < len; pos++)
        raw[pos] = (char)z->digit[len - 1 - pos];
    raw[len] = '\0';
    return raw;
}
integer* intToInteger(int x)
{
if(x == 0)
{
integer *tempz = allocInteger(1);
tempz->digit[0] = 0;
return tempz;
}
int y = x;
int r;
int i = 0;
bool sign = positive;
integer *tempz = allocInteger(15);
if(y < 0)
{
sign = negative;
y*=-1;
}
while( y > 0)
{
r = y % 10;
y = y / 10;
tempz->digit[i] = r;
++i;
}
integer *z = fastTrimInteger(tempz,i,true);
z->sign = sign;
return z;
}
/*
 * Converts z to a machine int. If freeMem is true, z is freed after a successful
 * conversion.
 * When the value does not fit in an int, prints a diagnostic, returns 0 and does
 * NOT free z. Leading zero digits are accepted.
 */
int integerToInt(integer *z, bool freeMem)
{
    /* Largest representable magnitude: INT_MAX, or INT_MAX + 1 for negatives (INT_MIN). */
    long long limit = (long long)INT_MAX + (z->sign ? 0 : 1);
    long long magnitude = 0;
    /* Horner's rule from the most significant digit; bail out before anything can overflow. */
    for(int i = (int)z->length - 1; i >= 0; --i)
    {
        magnitude = magnitude * 10 + z->digit[i];
        if(magnitude > limit)
        {
            printf("integer is too large for integerToInt conversion, integer was not freed\n");
            return 0;
        }
    }
    int result = z->sign ? (int)magnitude : (int)(-magnitude);
    if(freeMem)
        freeInteger(z);
    return result;
}
/*
@param integer , integer
@return int: 1 if a > b, 0 if a == b, -1 if a < b (signed comparison)
@algorithm Operands of different sign are ordered by sign alone, except that
+0 and -0 compare equal. Otherwise the magnitudes are compared digit by digit,
starting at the most significant position of the longer operand and exiting at
the first difference. For two negative operands that ordering is reversed.
@complexity O(n) time
In practice the scan usually stops at the first differing digit, so it rarely
has to inspect every digit.
*/
int compare(integer *a, integer *b)
{
    uint l_a = a->length;
    uint l_b = b->length;
    if(a->sign != b->sign)
    {
        /* +0 and -0 are the same value even though their signs differ. */
        bool zero = true;
        for(uint k = 0; zero && k < l_a; ++k)
            zero = (a->digit[k] == 0);
        for(uint k = 0; zero && k < l_b; ++k)
            zero = (b->digit[k] == 0);
        if(zero)
            return 0;
        return (a->sign == positive) ? 1 : -1;
    }
    /* Same sign: compare magnitudes; positions past an operand's length count as 0. */
    int magnitude = 0;
    uint i = (l_a > l_b) ? l_a : l_b;
    while(magnitude == 0 && i > 0)
    {
        --i;
        byte d_a = (i < l_a) ? a->digit[i] : 0;
        byte d_b = (i < l_b) ? b->digit[i] : 0;
        if(d_a > d_b)
            magnitude = 1;
        else if(d_a < d_b)
            magnitude = -1;
    }
    /* For negatives the larger magnitude is the smaller value (-5 < -3). */
    return (a->sign == positive) ? magnitude : -magnitude;
}
/* Compares a with the machine int B using compare(); returns -1, 0 or 1. */
int compareToInt(integer *a, int B)
{
    integer *rhs = intToInteger(B);
    int ordering = compare(a, rhs);
    freeInteger(rhs);
    return ordering;
}
/*
 * Parity test based on the least significant digit.
 * A digit outside 0-9 prints a diagnostic and reports "not even".
 */
bool isEven(integer *a)
{
    byte lowest = a->digit[0];
    if(lowest <= 9)
        return (lowest % 2) == 0;
    printf("isEven failed");
    return false;
}
/*
@alias add
@param integer , integer
@return integer (newly allocated, trimmed; zero is always positive)
@algorithm Sign-magnitude addition in base ten. With equal signs the magnitudes
are added. With opposite signs the smaller magnitude is subtracted from the
larger one, and the result takes the sign of the larger-magnitude operand.
@complexity O(n) time
If freeMem is true, both operands are freed (once, even if a == b).
*/
integer* addition(integer *a, integer *b, bool freeMem)
{
    uint length = max(a->length, b->length) + 1;
    /* Decide which operand has the larger magnitude (a on ties). */
    int magnitude = 0;
    uint i = length - 1;
    while(magnitude == 0 && i > 0)
    {
        --i;
        byte d_a = (i < a->length) ? a->digit[i] : 0;
        byte d_b = (i < b->length) ? b->digit[i] : 0;
        if(d_a > d_b)
            magnitude = 1;
        else if(d_a < d_b)
            magnitude = -1;
    }
    integer *big = (magnitude < 0) ? b : a;
    integer *small = (magnitude < 0) ? a : b;
    bool sameSign = (a->sign == b->sign);
    integer *z = allocInteger(length);
    if(z == null)
        return null;
    int carry = 0;
    for(uint k = 0; k < length; ++k)
    {
        int d_big = (k < big->length) ? big->digit[k] : 0;
        int d_small = (k < small->length) ? small->digit[k] : 0;
        /* Subtracting the smaller magnitude from the larger never leaves a final borrow. */
        int total = (sameSign ? d_big + d_small : d_big - d_small) + carry;
        if(total >= 10)
        {
            total -= 10;
            carry = 1;
        }else if(total < 0)
        {
            total += 10;
            carry = -1;
        }else
        {
            carry = 0;
        }
        z->digit[k] = (byte)total;
    }
    z = trimInteger(z, true);
    z->sign = big->sign;
    if(z->length == 1 && z->digit[0] == 0)
        z->sign = positive;
    if(freeMem)
    {
        freeInteger(a);
        if(b != a)
            freeInteger(b);
    }
    return z;
}
/* Returns a newly allocated a + b; the operands are left untouched. */
integer* add(integer *a, integer *b)
{
    const bool freeOperands = false;
    return addition(a, b, freeOperands);
}
/* Debug flag: when true, subtraction() prints its operands whenever a < b. */
static bool PRINTSUB = false;
/*
 * Returns a newly allocated a - b. When a >= b it computes a + (-b). Otherwise
 * it computes b + (-a), which is positive, and marks the result negative.
 * If freeMem is true, both operands are freed afterwards.
 */
integer* subtraction(integer *a, integer *b, bool freeMem)
{
    bool aAtLeastB = compare(a,b) >= 0;
    if(!aAtLeastB && PRINTSUB)
    {
        printf("a < b: ");
        printInteger(a);
        printf(" ");
        printInteger(b);
        printf("\n");
    }
    integer *negated = copy(aAtLeastB ? b : a);
    negated->sign = !negated->sign;
    integer *result = addition(aAtLeastB ? a : b, negated, false);
    if(!aAtLeastB)
        result->sign = negative;
    freeInteger(negated);
    if(freeMem)
    {
        freeInteger(a);
        freeInteger(b);
    }
    return result;
}
/* Returns a newly allocated a - b; the operands are left untouched. */
integer* subtract(integer *a, integer *b)
{
    const bool freeOperands = false;
    return subtraction(a, b, freeOperands);
}
/*
 * Returns a newly allocated a + B for a machine int B. If freeMem is true, a is freed.
 */
integer* increment(integer *a, int B, bool freeMem)
{
    integer *b = intToInteger(B);
    integer *z = addition(a, b, freeMem);
    /* addition() frees its operands only when freeMem is set; otherwise the temporary b is ours to release. */
    if(!freeMem)
        freeInteger(b);
    return z;
}
/*
 * Returns a newly allocated pseudo-random integer k with 1 <= k <= n - 1, drawn
 * with rand() by rejection sampling over n's digit count. For n <= 1 no such k
 * exists and 1 is returned. Returns null if allocation fails.
 * NOTE(review): this name may clash with POSIX random() declared by <stdlib.h>
 * on some platforms.
 */
integer* random(integer *n)
{
    if(compareToInt(n, 1) <= 0)
        return intToInteger(1);
    /* Ignore leading zero digits so that the acceptance rate stays at least about 1/10. */
    uint digits = n->length;
    while(digits > 1 && n->digit[digits - 1] == 0)
        --digits;
    for(;;)
    {
        integer *candidate = allocInteger(digits);
        if(candidate == null)
            return null;
        for(uint i = 0; i < digits; ++i)
            candidate->digit[i] = rand() % 10;
        if(compareToInt(candidate, 0) > 0 && compare(candidate, n) < 0)
            return trimInteger(candidate, true);
        freeInteger(candidate);
    }
}
/*
 * Returns a newly allocated positive integer with n pseudo-random digits from rand().
 * The result is not trimmed, so its high digits may be zero.
 */
integer* randm(int n)
{
    integer *result = allocInteger(n);
    int idx = 0;
    while(idx < n)
    {
        result->digit[idx] = rand() % 10;
        idx++;
    }
    return result;
}
/*
@alias multiply
@param integer , integer
@return integer (newly allocated, trimmed; zero is always positive)
@algorithm Schoolbook long multiplication: every digit of b multiplies every
digit of a, and the partial products are accumulated in place into the result.
@complexity O(n*m) time for operands of n and m digits
It is called "slow" because it takes the elementary approach to multiplication
(the kind learned in elementary school).
If freeMem is true, both operands are freed (once, even if a == b).
*/
integer* slowMultiplication(integer *a, integer *b, bool freeMem)
{
    uint length = a->length + b->length;
    /* A product of n- and m-digit numbers has at most n + m digits; keep at least one digit for 0. */
    integer *z = allocInteger(length > 0 ? length : 1);
    if(z == null)
        return null;
    for(uint i = 0; i < b->length; ++i)
    {
        int carry = 0;
        for(uint j = 0; j < a->length; ++j)
        {
            /* At most 9 + 9*9 + 9 = 99, so the carry always stays a single digit. */
            int total = z->digit[i + j] + a->digit[j] * b->digit[i] + carry;
            z->digit[i + j] = (byte)(total % 10);
            carry = total / 10;
        }
        /* Earlier rows reach at most index i + a->length - 1, so this slot is still 0. */
        z->digit[i + a->length] = (byte)carry;
    }
    z = trimInteger(z, true);
    /* Negative exactly when the operand signs differ; zero is always positive. */
    z->sign = (a->sign == b->sign) ? positive : negative;
    if(z->length == 1 && z->digit[0] == 0)
        z->sign = positive;
    if(freeMem)
    {
        freeInteger(a);
        if(b != a)
            freeInteger(b);
    }
    return z;
}
/* Returns a newly allocated a * b; the operands are left untouched. */
integer* multiply(integer *a, integer *b)
{
    const bool freeOperands = false;
    return slowMultiplication(a, b, freeOperands);
}
/*
 * Returns a newly allocated floor(|z| / 2). The result is always positive and
 * has the same length as z (it is not trimmed). If freeMem is true, z is freed.
 */
integer* divide2(integer *z, bool freeMem)
{
    uint n = z->length;
    integer *x = allocInteger(n);
    if(x != null)
    {
        /* Long division by 2 from the most significant digit; the remainder is 0 or 1. */
        int remainder = 0;
        for(int i = (int)n - 1; i >= 0; --i)
        {
            int current = remainder * 10 + z->digit[i];
            x->digit[i] = (byte)(current / 2);
            remainder = current % 2;
        }
    }
    /* freeInteger, not free: the digit array must be released as well. */
    if(freeMem)
        freeInteger(z);
    return x;
}
/*
 * Returns |z| mod m as a machine int (z's sign is ignored). The result has the
 * sign behaviour of C's % on non-negative dividends. If freeMem is true, z is freed.
 * m == 0 prints a diagnostic and returns 0.
 */
int modInt(integer *z, int m, bool freeMem)
{
    if(m == 0)
    {
        printf("modInt by 0, returning 0\n");
        if(freeMem)
            freeInteger(z);
        return 0;
    }
    /* Horner's rule, reducing at every step; long long keeps mod * 10 + digit from overflowing. */
    long long mod = 0;
    for(int i = (int)z->length - 1; i >= 0; --i)
        mod = (mod * 10 + z->digit[i]) % m;
    if(freeMem)
        freeInteger(z);
    return (int)mod;
}
/*int modIntTest(int a, int b)
{
int left = 0, right = a;
while( left < right )
{
int m = (left + right) / 2;
if ( a - m*b >= b )
left = m + 1;
else
right = m;
}
return a - left*b;
}*/
/*
 * Returns a newly allocated a mod b, computed as a - m*b for the smallest m with
 * a - m*b < b. m is found by binary search over [0, a].
 * Intended for non-negative a and positive b. b == 0 prints a diagnostic and
 * returns 1. If freeMem is true, both operands are freed on every path
 * (once, even if a == b).
 */
integer* mod(integer *a, integer *b, bool freeMem)
{
    integer *z;
    if(compareToInt(b,0) == 0)
    {
        printf("Modular by 0, returning 1");
        z = intToInteger(1);
    }else if(compare(a,b) < 0)
    {
        z = copy(a);
    }else
    {
        integer *left = intToInteger(0);
        integer *right = copy(a);
        while(compare(left,right) < 0)
        {
            integer *m = divide2(add(left,right),true);
            integer *temp = multiply(m,b);
            integer *temp2 = subtract(a,temp);
            if(compare(temp2,b) >= 0)
            {
                freeInteger(left);
                left = increment(m,1,false);
            }else
            {
                freeInteger(right);
                right = copy(m);
            }
            freeInteger(temp2);
            freeInteger(temp);
            freeInteger(m);
        }
        integer *temp = multiply(left,b);
        z = subtract(a,temp);
        freeInteger(temp);
        freeInteger(left);
        freeInteger(right);
    }
    if(freeMem)
    {
        freeInteger(a);
        if(b != a)
            freeInteger(b);
    }
    return z;
}
/*
 * Returns a newly allocated quotient a / b, truncated toward zero like C's `/`.
 * b == 0 prints a diagnostic and returns 1. The operands are not freed.
 */
integer* divide(integer *a, integer *b)
{
    if(compareToInt(b,0) == 0)
    {
        printf("Division by 0, returning 1");
        return intToInteger(1);
    }
    if(compareToInt(a,0) == 0)
        return intToInteger(0);
    /* Work on non-negative copies; the quotient's sign is applied at the end. */
    integer *na = copy(a);
    integer *nb = copy(b);
    na->sign = positive;
    nb->sign = positive;
    integer *z;
    if(compare(na,nb) < 0)
    {
        z = intToInteger(0);
    }else
    {
        /* Binary search for the smallest m with na - m*nb <= nb. */
        integer *left = intToInteger(0);
        integer *right = copy(na);
        while(compare(left,right) < 0)
        {
            integer *m = divide2(add(left,right),true);
            integer *product = multiply(m,nb);
            integer *rest = subtract(na,product);
            if(compare(rest,nb) > 0)
            {
                freeInteger(left);
                left = increment(m,1,false);
            }else
            {
                freeInteger(right);
                right = copy(m);
            }
            freeInteger(rest);
            freeInteger(product);
            freeInteger(m);
        }
        /* That m equals floor(na/nb), except when nb divides na exactly, where it is one less. */
        integer *remainder = mod(na,nb,false);
        if(compareToInt(remainder,0) == 0)
            z = increment(left,1,false);
        else
            z = trimInteger(left,false);
        freeInteger(remainder);
        freeInteger(left);
        freeInteger(right);
    }
    if(a->sign != b->sign && compareToInt(z,0) != 0)
        z->sign = negative;
    freeInteger(na);
    freeInteger(nb);
    return z;
}
/*
 * Returns a newly allocated base^exp by recursive squaring:
 * base^e = (base^(e/2))^2 for even e, and base * (base^((e-1)/2))^2 for odd e.
 * Intended for non-negative exp; the operands are not freed.
 */
integer* expnr(integer *base, integer *exp)
{
    if(compareToInt(exp,0) == 0)
        return intToInteger(1);
    if(!isEven(exp))
    {
        integer *expMinusOne = increment(exp,-1,false);
        integer *half = divide2(expMinusOne,false);
        integer *root = expnr(base,half);
        integer *square = multiply(root,root);
        integer *result = multiply(base,square);
        freeInteger(expMinusOne);
        freeInteger(half);
        freeInteger(root);
        freeInteger(square);
        return result;
    }
    integer *half = divide2(exp,false);
    integer *root = expnr(base,half);
    integer *result = multiply(root,root);
    freeInteger(half);
    freeInteger(root);
    return result;
}
/*
 * Returns a newly allocated (base^exp) mod m by recursive squaring, reducing
 * modulo m after every squaring step. exp == 0 yields 1 without reducing it.
 * The operands are not freed.
 */
integer* powMod(integer *base, integer *exp, integer *m)
{
    if(compareToInt(exp,0) == 0)
        return intToInteger(1);
    if(!isEven(exp))
    {
        integer *expMinusOne = increment(exp,-1,false);
        integer *half = divide2(expMinusOne,false);
        integer *root = powMod(base,half,m);
        integer *square = multiply(root,root);
        integer *full = multiply(base,square);
        integer *result = mod(full,m,false);
        freeInteger(expMinusOne);
        freeInteger(half);
        freeInteger(root);
        freeInteger(square);
        freeInteger(full);
        return result;
    }
    integer *half = divide2(exp,false);
    integer *root = powMod(base,half,m);
    integer *square = multiply(root,root);
    integer *result = mod(square,m,false);
    freeInteger(half);
    freeInteger(root);
    freeInteger(square);
    return result;
}
/*
 * Fermat probable-prime test. Returns 1 if `test` passes `trials` rounds of
 * k^(test-1) mod test == 1 for random k in [1, test-1], else 0.
 * Values below 2 are never prime, and 2 is the only even prime.
 * NOTE: Carmichael numbers can pass this test for every base coprime to them.
 */
int probablyPrime(integer *test, int trials)
{
    if(compareToInt(test,2) < 0)
        return 0;
    if(isEven(test))
        return (compareToInt(test,2) == 0) ? 1 : 0;
    integer *exponent = increment(test,-1,false);
    for(int i = 0; i < trials; ++i)
    {
        integer *k = random(test);
        integer *a = powMod(k,exponent,test);
        int passed = (compareToInt(a,1) == 0);
        freeInteger(k);
        freeInteger(a);
        if(!passed)
        {
            freeInteger(exponent);
            return 0;
        }
    }
    freeInteger(exponent);
    return 1;
}
/*
function extended_gcd(a, b)
x := 0 lastx := 1
y := 1 lasty := 0
while b != 0
quotient := a div b
(a, b) := (b, a mod b)
(x, lastx) := (lastx - quotient*x, x)
(y, lasty) := (lasty - quotient*y, y)
return (lastx, lasty)
*/
/*
 * Iterative extended Euclid following the pseudocode above. Returns a newly
 * allocated lasty: the coefficient of b in a*lastx + b*lasty = gcd(a, b).
 * The result may be negative. The operands are not modified or freed.
 */
integer* extendedEuler(integer *a, integer *b)
{
    integer *x = intToInteger(0);
    integer *y = intToInteger(1);
    integer *lastx = intToInteger(1);
    integer *lasty = intToInteger(0);
    integer *aTemp = copy(a);
    integer *bTemp = copy(b);
    while(compareToInt(bTemp,0) != 0)
    {
        integer *q = divide(aTemp,bTemp);
        integer *r = mod(aTemp,bTemp,false);
        /* (a, b) := (b, a mod b); ownership moves, so no copies are needed. */
        freeInteger(aTemp);
        aTemp = bTemp;
        bTemp = r;
        /* (x, lastx) := (lastx - q*x, x) */
        integer *qx = multiply(q,x);
        integer *nextx = subtract(lastx,qx);
        freeInteger(qx);
        freeInteger(lastx);
        lastx = x;
        x = nextx;
        /* (y, lasty) := (lasty - q*y, y) */
        integer *qy = multiply(q,y);
        integer *nexty = subtract(lasty,qy);
        freeInteger(qy);
        freeInteger(lasty);
        lasty = y;
        y = nexty;
        freeInteger(q);
    }
    freeInteger(x);
    freeInteger(y);
    freeInteger(lastx);
    freeInteger(aTemp);
    freeInteger(bTemp);
    return lasty;
}
/*
 * Recursive extended Euclid. Stores in *x and *y newly allocated coefficients
 * intended to satisfy a*(*x) + b*(*y) = gcd(a, b). The caller owns (and must
 * free) *x and *y. Neither a nor b is modified or freed.
 */
static void extendedEulerRInto(integer **x, integer **y, integer *a, integer *b)
{
    if(compareToInt(b,0) == 0)
    {
        *x = intToInteger(1);
        *y = intToInteger(0);
        return;
    }
    integer *q = divide(a,b);
    integer *r = mod(a,b,false);
    integer *s = null;
    integer *t = null;
    /* b*s + r*t = g with r = a - q*b gives a*t + b*(s - q*t) = g. */
    extendedEulerRInto(&s,&t,b,r);
    integer *qt = multiply(q,t);
    *x = t;
    *y = subtract(s,qt);
    freeInteger(qt);
    freeInteger(s);
    freeInteger(q);
    freeInteger(r);
}
/*
 * Legacy entry point, kept for signature compatibility. Because x and y are
 * passed by value, no result can reach the caller through them; the
 * coefficients are computed and then released. Use extendedEulerRInto to
 * obtain them. (The previous body read uninitialized structs, which is
 * undefined behaviour, and leaked every intermediate.)
 */
void extendedEulerR(integer x, integer y, integer a, integer b)
{
    integer *rx = null;
    integer *ry = null;
    (void)x;
    (void)y;
    extendedEulerRInto(&rx,&ry,&a,&b);
    freeInteger(rx);
    freeInteger(ry);
}
/*
 * Euclid's algorithm: returns a newly allocated gcd(a, b) using repeated mod().
 * The operands are not modified or freed.
 */
integer* gcd(integer *a, integer *b)
{
    integer *x = copy(a);
    integer *y = copy(b);
    while(compareToInt(y,0) != 0)
    {
        integer *r = mod(x,y,false);
        /* (x, y) := (y, x mod y); release the value that drops out. */
        freeInteger(x);
        x = y;
        y = r;
    }
    freeInteger(y);
    return x;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment