Skip to content

Instantly share code, notes, and snippets.

@jbedo
Created November 30, 2011 01:08
Show Gist options
  • Save jbedo/1407483 to your computer and use it in GitHub Desktop.
Save jbedo/1407483 to your computer and use it in GitHub Desktop.
Matrix multiplication with Hilbert space-filling curves
#include<u.h>
#include<libc.h>
/*
 * Larger of a and b (b when the comparison a > b is false).
 * Function-like macro: whichever argument is selected is evaluated a
 * second time, so do not pass expressions with side effects.
 */
#define MAX(a, b) ((a) > (b) ? (a) : (b))
/* Hilbert curve functions (from Wikipedia) */
/*
 * Rotate/flip the point (*x, *y) within a quadrant of side n.
 * Nothing is done unless ry is 0.  When ry is 0 and rx is 1 both
 * coordinates are first mirrored (v becomes n-1-v); in either case the
 * two coordinates are then exchanged in place.
 */
void rot(int n, int *x, int *y, int rx, int ry)
{
	int held;

	if (ry != 0)
		return;
	if (rx == 1) {
		*x = n - 1 - *x;
		*y = n - 1 - *y;
	}
	/* exchange *x and *y */
	held = *x;
	*x = *y;
	*y = held;
}
/*
 * (x,y) ↦ d: position of cell (x, y) along the Hilbert curve.
 * s starts at n/2 and is halved until it reaches 0; each pass tests the
 * bit s of x and y, adds that quadrant's contribution to d and then
 * rotates the coordinates with rot().  x and y are local copies, so the
 * caller's values are untouched.
 * NOTE(review): the bit tests presumably expect n to be a power of two.
 */
int xy2d (int n, int x, int y)
{
	int rx, ry, s, d, quadrant;

	d = 0;
	s = n / 2;
	while (s > 0) {
		rx = (x & s) > 0 ? 1 : 0;
		ry = (y & s) > 0 ? 1 : 0;
		quadrant = (3 * rx) ^ ry;
		d += s * s * quadrant;
		rot(s, &x, &y, rx, ry);
		s /= 2;
	}
	return d;
}
/*
 * d ↦ (x,y): inverse of the curve mapping; stores in *x and *y the cell
 * found at position d.  Two bits of d are consumed per pass while s
 * doubles from 1 up to (but not including) n.
 */
void d2xy(int n, int d, int *x, int *y)
{
	int rx, ry, s, rest;

	*y = 0;
	*x = 0;
	rest = d;
	s = 1;
	while (s < n) {
		rx = (rest / 2) & 1;
		ry = (rest ^ rx) & 1;
		rot(s, x, y, rx, ry);
		*x += s * rx;
		*y += s * ry;
		rest /= 4;
		s *= 2;
	}
}
/* Memory */
/*
 * malloc wrapper: returns sz bytes, or reports the failure through
 * sysfatal("emalloc: %r") when malloc yields a null pointer.
 */
void *
emalloc(ulong sz)
{
	void *mem;

	mem = malloc(sz);
	if(mem != 0)
		return mem;
	sysfatal("emalloc: %r");
	return 0;
}
/*
 * Like emalloc, but the sz bytes returned are cleared to zero.
 */
void *
cmalloc(ulong sz)
{
	void *mem;

	mem = emalloc(sz);
	/* memset hands back its first argument */
	return memset(mem, 0, sz);
}
/*
 * realloc wrapper: resizes p to sz bytes and returns the new pointer, or
 * reports the failure through sysfatal("erealloc: %r") when realloc
 * yields a null pointer.
 */
void *
erealloc(void *p, ulong sz)
{
	void *grown;

	grown = realloc(p, sz);
	if(grown != 0)
		return grown;
	sysfatal("erealloc: %r");
	return 0;
}
/* Matrices */
/*
 * Matrix whose cells are stored in Hilbert-curve order: element (i, j)
 * lives at data[xy2d(n, i, j)].
 */
typedef struct matrix{
	double *data;	/* cell storage; nmatrix sizes it for n*n doubles */
	int nr;		/* number of rows */
	int nc;		/* number of columns */
	int n;		/* side of the square the curve covers (set by nmatrix) */
	int *Δi;	/* 0-terminated index steps for a walk with i varying fastest */
	int *Δj;	/* 0-terminated index steps for a walk with j varying fastest */
} matrix;
/*
 * Cursor for walking the cells of a matrix through one of its delta
 * tables (see niter and next).
 */
typedef struct miter{
	int *pj;	/* next step to apply; a 0 entry marks the end of the walk */
	double *value;	/* cell the cursor currently points at */
	matrix *x;	/* NOTE(review): presumably the matrix being walked */
	int valid;	/* nonzero while value refers to a cell of the walk */
} miter;
/*
 * Size matrix x for nr rows and nc columns and build its delta tables.
 *
 * If x is nil a zero-filled matrix header is allocated first.  nr and nc
 * are always stored.  When both are positive:
 *   - x->n becomes the smallest power of two >= MAX(nr, nc);
 *   - x->data is (re)allocated for n*n doubles (contents are not
 *     initialised here);
 *   - Δj is rebuilt so that Δj[k] is the change in xy2d index between
 *     consecutive cells of a walk with j varying fastest from (0, 0);
 *   - Δi is rebuilt the same way for a walk with i varying fastest.
 * Each table holds nr*nc - 1 steps followed by a 0 terminator.
 * Otherwise data, n and the tables are left as they were.
 *
 * Returns x, or the newly allocated matrix when x was nil.
 */
matrix *
nmatrix(matrix *x, int nr, int nc)
{
	int i, j, d, prev, side;
	int *pi, *pj;

	if(x == nil)
		x = cmalloc(sizeof(*x));
	x->nr = nr;
	x->nc = nc;
	if(nr > 0 && nc > 0){
		/*
		 * Round up to a power of two with integer doubling: shifting
		 * by the double returned by ceil() is not valid standard C,
		 * and log() rounding can overshoot an exact power of two.
		 */
		side = MAX(nr, nc);
		for(x->n = 1; x->n < side; x->n *= 2)
			;
		x->data = erealloc(x->data, sizeof(*x->data) * x->n * x->n);
		x->Δi = erealloc(x->Δi, sizeof(*x->Δi) * (x->nr * x->nc + 1));
		x->Δj = erealloc(x->Δj, sizeof(*x->Δj) * (x->nr * x->nc + 1));
		/* Populate delta arrays; prev starts at xy2d(n, 0, 0) == 0 */
		pj = x->Δj;
		prev = 0;
		for(i = 0; i < x->nr; i++){
			for(j = 1; j < x->nc; j++){
				d = xy2d(x->n, i, j);
				*pj++ = d - prev;
				prev = d;
			}
			/* step from the end of this row to the start of the next */
			if(i < x->nr - 1){
				d = xy2d(x->n, i + 1, 0);
				*pj++ = d - prev;
				prev = d;
			}
		}
		*pj = 0;
		pi = x->Δi;
		prev = 0;
		for(j = 0; j < x->nc; j++){
			for(i = 1; i < x->nr; i++){
				d = xy2d(x->n, i, j);
				*pi++ = d - prev;
				prev = d;
			}
			/* step from the end of this column to the start of the next */
			if(j < x->nc - 1){
				d = xy2d(x->n, 0, j + 1);
				*pi++ = d - prev;
				prev = d;
			}
		}
		*pi = 0;
	}
	return x;
}
/*
 * Release a matrix: its cell storage, both delta tables and the header.
 * A nil x is accepted and ignored.  Members that were never allocated
 * are nil, and freeing nil is legal, so no per-member guard is needed.
 */
void
dmatrix(matrix *x)
{
	if(x == nil)
		return;
	free(x->data);
	free(x->Δi);
	free(x->Δj);
	free(x);
}
/*
 * Cell (i, j) of matrix x as an lvalue, located through xy2d.
 * x is expanded twice, so pass an expression without side effects.
 */
#define mget(x, i, j) (*((x)->data + xy2d((x)->n, (i), (j))))
/*
 * Point iterator i at the first cell of matrix x and arm it to follow
 * x's Δj table.  i->valid is set only when x has at least one row and
 * one column.  i->x records the matrix so that no field of the iterator
 * is left uninitialised.
 */
void
niter(miter *i, matrix *x)
{
	i->x = x;
	i->pj = x->Δj;
	i->valid = x->nr > 0 && x->nc > 0;
	i->value = x->data;
}
/*
 * Advance iterator i by one step of its delta table.  A step of 0 is
 * the table terminator: the iterator is marked invalid and left where
 * it is.
 */
void
next(miter *i)
{
	int step;

	step = *i->pj;
	if(step != 0){
		i->value += step;
		i->pj++;
		return;
	}
	i->valid = 0;
}
/*
 * Clear all n*n cells of a's storage to zero bytes.
 */
void
zero(matrix *a)
{
	double *cells;

	cells = a->data;
	memset(cells, 0, sizeof(*cells) * a->n * a->n);
}
/*
 * Matrix product c = a · b using the precomputed delta tables.
 *
 * c may be nil (a new matrix is allocated) or an existing matrix, which
 * is resized to a->nr × b->nc by nmatrix and zeroed.  c must not be a or
 * b, because it is resized and cleared before the operands are read.
 * The operands must be conformable (a->nc == b->nr); anything else
 * would walk past the end of the delta tables, so it is reported with
 * sysfatal.  Returns c.
 *
 * The walk: cp visits the cells of c with j varying fastest (c->Δj).
 * For each cell, ap follows a->Δj along one row of a while bp follows
 * b->Δi down one column of b.  After a->nc steps ap sits at the start
 * of the next row of a and bp at the start of the next column of b.
 */
matrix *
mdot(matrix *a, matrix *b, matrix *c)
{
	double *ap, *arow, *bp, *cp;
	int *pi, *irow, *pj, *pk;
	int col, icol;

	if(a->nc != b->nr)
		sysfatal("mdot: dimension mismatch: %dx%d * %dx%d",
			a->nr, a->nc, b->nr, b->nc);
	c = nmatrix(c, a->nr, b->nc);
	zero(c);
	/* an empty product has no cells, and possibly no delta table, to walk */
	if(c->nr <= 0 || c->nc <= 0)
		return c;
	arow = a->data;		/* first cell of the current row of a */
	bp = b->data;
	cp = c->data;
	irow = a->Δj;		/* delta cursor matching arow */
	pj = b->Δi;
	pk = c->Δj;
	/* col is the 1-based column of the cell cp points at */
	for(col = 1;; cp += *pk++, col++){
		ap = arow;
		pi = irow;
		for(icol = 0; icol < a->nc; icol++, ap += *pi++, bp += *pj++){
			*cp += *ap * *bp;
		}
		if(col == c->nc){
			/* row of c finished: move to the next row of a, rewind b */
			arow = ap;
			irow = pi;
			bp = b->data;
			pj = b->Δi;
			col = 0;
		}
		/* 0 terminator of c->Δj: the last cell has been computed */
		if(*pk == 0)
			break;
	}
	return c;
}
/*
 * Reference product of two n×n matrices stored as flat arrays in which
 * element (i, j) is at index i*n + j.  c is resized with erealloc (so it
 * may be nil), cleared, filled with a·b using the plain triple loop, and
 * returned.
 */
double *
tdot(double *a, double *b, double *c, uint n)
{
	uint row, col, k;
	double *out;

	c = erealloc(c, sizeof(*c) * n * n);
	memset(c, 0, sizeof(*c) * n * n);
	for(row = 0; row < n; row++){
		for(col = 0; col < n; col++){
			out = &c[row * n + col];
			for(k = 0; k < n; k++){
				*out += a[row * n + k] * b[k * n + col];
			}
		}
	}
	return c;
}
/*
 * Print the command synopsis on file descriptor 2 and exit with the
 * status string "usage".
 */
void
usage(void)
{
	enum { Stderr = 2 };

	fprint(Stderr, "%s: [-n msize]\n", argv0);
	exits("usage");
}
/*
 * Benchmark driver.  Option -n msize sets the matrix side (default 128,
 * must be at least 2); -h or any unknown option prints the usage line.
 *
 * Prints three tab-separated differences of cycles() readings on one
 * line: building the n×n matrix (nmatrix), multiplying it by itself with
 * mdot, and multiplying the same values with the flat-array tdot.
 */
void
main(int argc, char **argv)
{
	int n = 128;
	uvlong begin, end;
	matrix *x, *y;
	miter it;
	double *sx, *sy, *px;

	ARGBEGIN{
	case 'n':
		n = atoi(EARGF(usage()));
		break;
	case 'h':
	default:
		usage();
	}ARGEND;
	/* no trailing "\n": sysfatal appends a newline to the message itself */
	if(n <= 1)
		sysfatal("Matrix size must be ≥ 2");

	/* time construction of the matrix and its delta tables */
	cycles(&begin);
	x = nmatrix(nil, n, n);
	cycles(&end);
	print("%ulld\t", end - begin);

	/* fill x and the flat copy sx with the same values, walked in the same order */
	sx = emalloc(sizeof(*sx) * n * n);
	for(niter(&it, x), px = sx; it.valid; next(&it), px++)
		*it.value = *px = frand() - 0.5;

	/* time the delta-table product */
	cycles(&begin);
	y = mdot(x, x, nil);
	cycles(&end);
	print("%ulld\t", end - begin);
	dmatrix(x);
	dmatrix(y);

	/* time the plain triple-loop product */
	cycles(&begin);
	sy = tdot(sx, sx, nil, n);
	cycles(&end);
	print("%ulld\n", end - begin);
	free(sy);
	free(sx);
	exits(0);
}
@jbedo
Copy link
Author

jbedo commented Mar 30, 2021

Thanks for your interest! It's a bit complicated because it's written for plan 9 not linux/mac. You'll need to compile it with plan9port (https://github.com/9fans/plan9port) with:
9 9c hilbert.c && 9 9l -o hilbert hilbert.o

There are a few caveats:

  1. GCC doesn't accept the Unicode prime character (′) used in identifiers such as d′ on line 101, so those variables will have to be renamed.
  2. On line 110, ceil() returns a double, and its result needs to be explicitly cast to an int before it is used as a shift count.
  3. The cycles() function used to time the function invocations doesn't exist in p9p; some other timing function will need to be substituted.

@ofou
Copy link

ofou commented Jul 6, 2021

Thanks @jbedo! 🙌🏻

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment