# Gist michelp/231c3fe771248ae856a57c46aa19e284 -- created December 23, 2019 07:26.
# (GitHub page text) Save michelp/231c3fe771248ae856a57c46aa19e284 to your
# computer and use it in GitHub Desktop.
# (GitHub page notice) This file contains hidden or bidirectional Unicode text
# that may be interpreted or compiled differently than what appears below. To
# review, open the file in an editor that reveals hidden Unicode characters.
# Learn more about bidirectional Unicode characters.
from operator import methodcaller
from textwrap import dedent

from cffi import FFI
from numba import cfunc, types, jit, carray, cffi_support
from numba.types import Record, CPointer, void, double, uint64
from pygraphblas import Matrix, BinaryOp, lib, ffi as gbffi
def type_head(name):
    """Return the opening line of a C ``typedef struct`` declaration.

    Args:
        name: Struct tag to declare.

    Returns:
        The text ``"\\ntypedef struct <name> {\\n"`` (leading and trailing
        newline, no indentation).
    """
    # dedent strips the source-level indentation of the template so the
    # generated C text starts at column 0.
    return dedent("""
    typedef struct %s {
    """ % name)
def type_body(members):
    """Render struct member declarations, one per line.

    Args:
        members: Iterable of C declarations without the trailing semicolon,
            e.g. ``['double w', 'uint64_t h']``.

    Returns:
        The declarations joined by ``";\\n"`` with a final ``";"`` appended
        (no trailing newline).
    """
    return ";\n".join(members) + ';'
def type_tail(name):
    """Return the closing line of a C ``typedef struct`` declaration.

    Args:
        name: Typedef name the struct is exposed under.

    Returns:
        The text ``"\\n} <name>;\\n"`` (leading and trailing newline, no
        indentation).
    """
    return dedent("""
    } %s;
    """ % name)
def build_type_def(typ, members):
    """Assemble a complete ``typedef struct <typ> { ... } <typ>;`` declaration.

    Args:
        typ: Name used both as the struct tag and as the typedef name.
        members: Iterable of member declarations without trailing semicolons.

    Returns:
        The C source text formed by concatenating the head, body and tail
        fragments.
    """
    return (type_head(typ) +
            type_body(members) +
            type_tail(typ))
def binop_name(typ, name):
    """Return the typedef name used for a binary-operator function pointer.

    Args:
        typ: Name of the user-defined C type the operator works on.
        name: Short name of the operator (e.g. ``'min'``).

    Returns:
        The string ``"<typ>_<name>_binop_function"``.
    """
    return '{0}_{1}_binop_function'.format(typ, name)
def build_binop_def(typ, name):
    """Build the C typedef for a binary-operator function pointer.

    The generated declaration has the shape
    ``typedef void (*<typ>_<name>_binop_function)(<typ>*, <typ>*, <typ>*);``.

    Args:
        typ: Name of the user-defined C type the operator works on.
        name: Short name of the operator.

    Returns:
        The typedef as C source text, with a leading and trailing newline.
    """
    return dedent("""
    typedef void (*{0})({1}*, {1}*, {1}*);
    """.format(binop_name(typ, name), typ))
# Module-level cffi instance that receives the struct and function-pointer
# typedefs generated in this file; it is separate from pygraphblas's own
# `gbffi`, which is used for the GrB_* handles.
ffi = FFI()
def new_binop(func, udt_type):
    """Create a GraphBLAS binary operator from a compiled C callback.

    Args:
        func: Object exposing an ``address`` attribute holding the address of
            a C function (the wrapper returned by the ``binop`` decorator
            below is used this way).
        udt_type: ``GrB_Type*`` cdata; its first element is used as the type
            of the output and of both inputs.

    Returns:
        The ``GrB_BinaryOp*`` cdata that was handed to ``GrB_BinaryOp_new``.

    Raises:
        RuntimeError: If ``GrB_BinaryOp_new`` reports anything other than
            ``GrB_SUCCESS``.
    """
    op = gbffi.new('GrB_BinaryOp*')
    typ = udt_type[0]
    callback = gbffi.cast('GxB_binary_function', func.address)
    # The return code used to be dropped; a failed creation would otherwise
    # only surface later as a use of an uninitialised operator handle.
    info = lib.GrB_BinaryOp_new(op, callback, typ, typ, typ)
    if info != lib.GrB_SUCCESS:
        raise RuntimeError(
            'GrB_BinaryOp_new failed with GrB_Info code %s' % info)
    return op
def binop_sig(type_name, func_name):
    """Declare the operator's function-pointer typedef and map it to numba.

    Args:
        type_name: Name of the user-defined C type; it must already be known
            to the module-level ``ffi`` because the typedef refers to it.
        func_name: Short operator name, used only to derive the typedef name.

    Returns:
        Whatever ``cffi_support.map_type`` yields for the function-pointer
        typedef (with ``use_record_dtype=True``); ``binop`` passes it to
        ``cfunc`` as the signature.
    """
    ffi.cdef(build_binop_def(type_name, func_name))
    pointer_type = ffi.typeof(binop_name(type_name, func_name))
    return cffi_support.map_type(pointer_type, use_record_dtype=True)
def binop(udt, func_name):
    """Decorator factory turning a Python function into a C binary-op callback.

    The decorated function is called as ``func(z, x, y)`` and is expected to
    write its result into ``z``; its return value is ignored.

    Args:
        udt: Object with a ``type_name`` attribute naming the C struct type
            (a ``UDT`` instance in this file).
        func_name: Short operator name, used to name the generated typedef.

    Returns:
        A decorator that replaces the function with a ``cfunc``-compiled
        wrapper taking three pointers ``(z_, x_, y_)``.
    """
    sig = binop_sig(udt.type_name, func_name)

    def inner(func):
        jitfunc = jit(func, nopython=True)

        @cfunc(sig)
        def wrapper(z_, x_, y_):
            # View each raw pointer as a one-element array and take that
            # element, so the user function receives records, not pointers.
            z = carray(z_, 1)[0]
            x = carray(x_, 1)[0]
            y = carray(y_, 1)[0]
            jitfunc(z, x, y)

        return wrapper

    return inner
class UDT:
    """A user-defined GraphBLAS type backed by a C struct.

    Construction declares the struct to the module-level ``ffi`` and
    registers a ``GrB_Type`` of the struct's size with GraphBLAS.

    Attributes are documented inline in ``__init__``.
    """

    def __init__(self, type_name, members):
        """Declare and register the struct type.

        Args:
            type_name: Name of the C struct / typedef to create.
            members: Iterable of member declarations of the form
                ``'<ctype> <field>'`` (no trailing semicolon).

        Raises:
            RuntimeError: If ``GrB_Type_new`` reports anything other than
                ``GrB_SUCCESS``.
        """
        # `members` is consumed twice below; materialise it once so a
        # generator argument is not exhausted after the first use.
        members = list(members)
        self.type_name = type_name
        # List of [ctype, field] pairs.  This used to be a bare map() object,
        # which is a one-shot iterator on Python 3: the first from_tuple()
        # call drained it and every later call silently set no fields.
        self.members = [member.split() for member in members]
        # cdata buffers created by from_tuple(); see the note there.
        self._keepalive = []
        ffi.cdef(build_type_def(type_name, members))
        t = gbffi.new('GrB_Type*')
        info = lib.GrB_Type_new(t, ffi.sizeof(type_name))
        if info != lib.GrB_SUCCESS:
            raise RuntimeError(
                'GrB_Type_new failed with GrB_Info code %s' % info)
        # Result is discarded, as in the original code.
        # NOTE(review): presumably called for a registration side effect
        # inside numba -- confirm.
        cffi_support.map_type(ffi.typeof(type_name), use_record_dtype=True)
        # GrB_Type* cdata; callers dereference it with [0].
        self.udt = t

    def from_tuple(self, *args):
        """Allocate one struct, fill its fields, and return its address.

        Values are assigned positionally in member-declaration order.  If
        fewer values than members are given, the remaining fields keep the
        contents ``ffi.new`` initialised them with; extra values are ignored.

        Returns:
            The address of the allocated struct as a Python ``int``.
        """
        data = ffi.new('%s[1]' % self.type_name)
        for (_, name), val in zip(self.members, args):
            setattr(data[0], name, val)
        # ffi.new() memory is released when the owning cdata object is
        # garbage-collected.  Only the integer address leaves this method, so
        # without this reference the address would dangle as soon as we
        # return.  The buffers live as long as this UDT instance.
        self._keepalive.append(data)
        return int(ffi.cast('size_t', data))

    def to_tuple(self, ptr):
        """Reinterpret an address as a pointer to this struct type.

        Despite the name, this returns a cffi ``<type_name>*`` cdata rather
        than a Python tuple; fields are read as attributes on the result.
        """
        return ffi.cast('%s*' % self.type_name, ptr)
# --- Demo: a three-field user-defined type with a custom binary operator ---

# Declare `struct foo { double w; uint64_t h; uint64_t p; }` and register it.
myudt = UDT('foo', ['double w', 'uint64_t h', 'uint64_t p'])


@binop(myudt, 'min')
def bf_min(z, x, y):
    # Store the smaller `w` of the two inputs in the output; only `w` is
    # written here.
    z.w = x.w if x.w < y.w else y.w


mymin = new_binop(bf_min, myudt.udt)
# The operator is registered under the name 'plus' for type 'foo'.
# NOTE(review): presumably this is what makes `A + B` inside `with op:` use
# bf_min -- confirm against the pygraphblas BinaryOp documentation.
op = BinaryOp('plus', 'foo', mymin[0])

A = Matrix.from_type(myudt.udt[0], 10, 10)
B = Matrix.from_type(myudt.udt[0], 10, 10)

# Only two values are supplied for z, so its third field `p` is not assigned.
z = myudt.from_tuple(0.1, 2)
x = myudt.from_tuple(0.4, 2, 2)
y = myudt.from_tuple(0.2, 3, 3)  # created but not stored in either matrix

A[0, 0] = z
B[1, 1] = x

with op:
    print((A + B).nvals)

# Bare expression: it shows a value in a REPL/notebook but is discarded when
# this file runs as a script.
myudt.to_tuple(z).p
# (GitHub page footer) Sign up for free to join this conversation on GitHub.
# Already have an account? Sign in to comment.