Created February 21, 2017 15:08
-
-
Save maedoc/4359a28333a5f20ef8797c29ec5a7186 to your computer and use it in GitHub Desktop.
Automate loading Loopy kernels as ctypes functions
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import numpy as np | |
import ctypes | |
import loopy as lp | |
from loopy.target import c | |
import cgen | |
import subprocess | |
import tempfile | |
class Compiler:
    """Compile C source into a shared library and load it with ctypes.

    All build artifacts live in a private temporary directory owned by this
    object.  Every call to :meth:`build` uses fresh file names, so several
    libraries can be built with one ``Compiler`` instance.
    """

    source_suffix = 'c'
    default_cc = 'gcc'
    default_cflags = '-std=c99 -O3'.split()
    default_ldflags = []

    def __init__(self, cc=None, cflags=None, ldflags=None):
        """Configure the toolchain.

        :param cc: compiler executable; defaults to :attr:`default_cc`.
        :param cflags: list of compile flags; ``None`` means a copy of
            :attr:`default_cflags`.  An explicit empty list means no flags.
        :param ldflags: list of extra link flags (e.g. ``['-lm']``); ``None``
            means a copy of :attr:`default_ldflags`.
        """
        self.cc = cc or self.default_cc
        # Compare against None rather than truthiness so ``cflags=[]`` is
        # honoured instead of silently replaced by the defaults.
        self.cflags = self.default_cflags[:] if cflags is None else cflags
        self.ldflags = self.default_ldflags[:] if ldflags is None else ldflags
        self.tempdir = tempfile.TemporaryDirectory()
        # Number of builds so far; used to give each build unique file names.
        self._nbuilds = 0

    def tempname(self, name):
        """Return the path of ``name`` inside this compiler's temp directory."""
        return os.path.join(self.tempdir.name, name)

    def build(self, code) -> ctypes.CDLL:
        """Compile ``code`` into a shared library and load it.

        :param code: C source text.
        :returns: the loaded library.
        :raises subprocess.CalledProcessError: if compiling or linking fails.
        """
        # Use a fresh stem per build: reusing one .so path would overwrite a
        # library already mapped into this process, and dlopen may hand back
        # the previously loaded copy for the same path.
        self._nbuilds += 1
        stem = 'code%d' % self._nbuilds
        c_fname = self.tempname(stem + '.' + self.source_suffix)
        obj_fname = self.tempname(stem + '.o')
        dll_fname = self.tempname(stem + '.so')
        with open(c_fname, 'w') as fd:
            fd.write(code)
        self._call([self.cc] + self.cflags +
                   ['-fPIC', '-c', c_fname, '-o', obj_fname])
        # Link flags go after the object file so ``-l`` libraries can resolve
        # the symbols the object references.
        self._call([self.cc] + self.cflags +
                   ['-shared', obj_fname, '-o', dll_fname] + self.ldflags)
        return ctypes.CDLL(dll_fname)

    def _call(self, args, **kwargs):
        """Run ``args`` in the temp directory, raising on a non-zero exit."""
        subprocess.check_call(args, cwd=self.tempdir.name, **kwargs)
class CompiledKernel:
    """Compile a Loopy kernel (C target) and call it through ctypes.

    The kernel's generated C header is walked to build the ctypes
    ``argtypes``.  Calling the instance converts the Python/NumPy arguments
    to those types and invokes the compiled function.
    """

    # C type names, as the dtype registry may spell them, that are not
    # reachable as ``ctypes.c_<name>`` by simple prefixing.
    _CTYPE_ALIASES = {
        'signed char': 'byte',
        'unsigned char': 'ubyte',
        'unsigned short': 'ushort',
        'unsigned': 'uint',
        'unsigned int': 'uint',
        'unsigned long': 'ulong',
        'long long': 'longlong',
        'unsigned long long': 'ulonglong',
    }

    def __init__(self, knl: lp.LoopKernel, comp: Compiler = None):
        """Generate, compile and load ``knl``.

        :param knl: Loopy kernel whose target is a ``CTarget``.
        :param comp: compiler to build with; a default :class:`Compiler` is
            created when omitted.
        :raises TypeError: if ``knl`` does not use a C target.
        :raises ValueError: if the generated function does not return
            ``void``, or has an argument declaration this class cannot map.
        """
        if not isinstance(knl.target, c.CTarget):
            raise TypeError('kernel target must be a CTarget, got %r'
                            % (knl.target, ))
        self.knl = knl
        self.code, _ = lp.generate_code(knl)
        self.comp = comp or Compiler()
        self.dll = self.comp.build(self.code)
        # Exactly one declaration is expected; the unpacking fails otherwise.
        self.func_decl, = c.generate_header(knl)
        # (name, ctypes type) per C argument, in declaration order.
        self.arg_info = []
        self.visit(self.func_decl)
        self.name = self.func_decl.subdecl.name
        restype = self.func_decl.subdecl.typename
        if restype == 'void':
            self.restype = None
        else:
            raise ValueError('Unhandled restype %r' % (restype, ))
        self.fn = getattr(self.dll, self.name)
        self.fn.restype = self.restype
        self.fn.argtypes = [ctype for _, ctype in self.arg_info]

    def __call__(self, *args):
        """Call the compiled kernel with positional arguments in C order.

        Objects with a ``ctypes`` attribute (NumPy arrays) are passed as a
        pointer to their data buffer.  Everything else is wrapped in the
        argument's ctypes type.

        :raises TypeError: if the number of arguments does not match.
        """
        argtypes = self.fn.argtypes
        if len(args) != len(argtypes):
            # zip() below would otherwise silently drop surplus arguments.
            raise TypeError('%s() takes %d arguments (%d given)'
                            % (self.name, len(argtypes), len(args)))
        # NOTE(review): array dtype / contiguity are not checked here.
        c_args = [
            arg.ctypes.data_as(arg_t) if hasattr(arg, 'ctypes') else arg_t(arg)
            for arg, arg_t in zip(args, argtypes)
        ]
        return self.fn(*c_args)

    def append_arg(self, name, dtype, pointer=False):
        """Record argument ``name`` with the ctypes type for ``dtype``."""
        self.arg_info.append((
            name,
            self.dtype_to_ctype(dtype, pointer=pointer)
        ))

    def visit_const(self, node: cgen.Const):
        """Handle a ``const`` argument: a restrict pointer or a plain value."""
        if isinstance(node.subdecl, cgen.RestrictPointer):
            self.visit_pointer(node.subdecl)
        else:
            pod = node.subdecl  # type: cgen.POD
            self.append_arg(pod.name, pod.dtype)

    def visit_pointer(self, node: cgen.RestrictPointer):
        """Handle a ``restrict`` pointer argument."""
        pod = node.subdecl  # type: cgen.POD
        self.append_arg(pod.name, pod.dtype, pointer=True)

    def visit(self, func_decl):
        """Record ctypes types for every argument of ``func_decl``.

        :raises ValueError: for an argument that is neither ``cgen.Const``
            nor ``cgen.RestrictPointer``.  Skipping it would leave
            ``argtypes`` out of step with the C signature.
        """
        for arg in func_decl.arg_decls:
            if isinstance(arg, cgen.Const):
                self.visit_const(arg)
            elif isinstance(arg, cgen.RestrictPointer):
                self.visit_pointer(arg)
            else:
                raise ValueError('unhandled type for arg %r' % (arg, ))

    def dtype_to_ctype(self, dtype, pointer=False):
        """Map ``dtype`` to a ctypes type via the kernel target's registry.

        :param pointer: return ``POINTER(type)`` instead of the base type.
        :raises TypeError: if the C type name has no ctypes counterpart.
        """
        # Use the kernel's own target rather than a module-level global.
        registry = self.knl.target.get_dtype_registry().wrapped_registry
        typename = registry.dtype_to_ctype(dtype)
        name = self._CTYPE_ALIASES.get(typename, typename)
        basetype = getattr(ctypes, 'c_' + name, None)
        if basetype is None and name.endswith('_t'):
            # C99 fixed-width names such as int32_t map to ctypes.c_int32.
            basetype = getattr(ctypes, 'c_' + name[:-2], None)
        if basetype is None:
            raise TypeError('no ctypes equivalent for C type %r (dtype %r)'
                            % (typename, dtype))
        if pointer:
            return ctypes.POINTER(basetype)
        return basetype
# C target used by the demo below.  Kept at module level (outside the
# ``__main__`` guard) so the name ``target`` stays importable.
target = c.CTarget()


def _demo():
    """Build a kernel computing ``out = 2 * a`` and check it against NumPy."""
    knl = lp.make_kernel("{ [i]: 0<=i<n }", "out[i] = 2*a[i]", target=target)
    typed = lp.add_dtypes(knl, {'a': 'f'})
    fn = CompiledKernel(typed)
    # Two float32 rows of one array: row 0 is the input, row 1 the output.
    a, out = np.zeros((2, 10), 'f')
    a[:] = np.r_[:a.size]
    # Positional order must match the generated C signature (a, n, out here).
    fn(a, a.size, out)
    assert np.allclose(out, a * 2)


if __name__ == '__main__':
    _demo()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.