Skip to content

Instantly share code, notes, and snippets.

@maedoc
Created February 21, 2017 15:08
Show Gist options
  • Save maedoc/4359a28333a5f20ef8797c29ec5a7186 to your computer and use it in GitHub Desktop.
Save maedoc/4359a28333a5f20ef8797c29ec5a7186 to your computer and use it in GitHub Desktop.
Automate loading Loopy kernels as ctypes functions
import os
import numpy as np
import ctypes
import loopy as lp
from loopy.target import c
import cgen
import subprocess
import tempfile
class Compiler:
    """Compile C source text into a shared library and load it via ctypes.

    All build artifacts live under ``self.tempdir``, a
    ``tempfile.TemporaryDirectory`` owned by the instance.

    Class attributes provide the defaults used when the corresponding
    constructor argument is omitted:

    * ``source_suffix`` -- extension of the generated source file.
    * ``default_cc`` -- compiler executable.
    * ``default_cflags`` -- flags passed to both the compile and link steps.
    * ``default_ldflags`` -- flags passed to the link step only.
    """

    source_suffix = 'c'
    default_cc = 'gcc'
    default_cflags = '-std=c99 -O3'.split()
    default_ldflags = []

    def __init__(self, cc=None, cflags=None, ldflags=None):
        """Set up the toolchain configuration and a scratch directory.

        :param cc: compiler executable; falls back to ``default_cc``.
        :param cflags: list of compiler flags; ``None`` selects a copy of
            ``default_cflags``. An explicitly empty list is honoured.
        :param ldflags: list of linker flags; ``None`` selects a copy of
            ``default_ldflags``.
        """
        self.cc = cc or self.default_cc
        # Test against None rather than truthiness so that a caller can
        # disable the default flags by passing an empty list.
        self.cflags = self.default_cflags[:] if cflags is None else cflags
        self.ldflags = self.default_ldflags[:] if ldflags is None else ldflags
        self.tempdir = tempfile.TemporaryDirectory()

    def tempname(self, name):
        """Return the path of ``name`` inside the scratch directory."""
        return os.path.join(self.tempdir.name, name)

    def build(self, code) -> ctypes.CDLL:
        """Compile ``code`` and return the loaded shared library.

        :param code: C source text.
        :returns: the ``ctypes.CDLL`` handle of the freshly built library.
        :raises subprocess.CalledProcessError: if the compiler or linker
            exits with a non-zero status.
        """
        # Each build gets its own sub-directory: loading a second library
        # from a path that was already loaded may hand back the handle of
        # the first one instead of the new code.
        build_dir = tempfile.mkdtemp(dir=self.tempdir.name)
        c_fname = os.path.join(build_dir, 'code.' + self.source_suffix)
        obj_fname = os.path.join(build_dir, 'code.o')
        dll_fname = os.path.join(build_dir, 'code.so')
        with open(c_fname, 'w') as fd:
            fd.write(code)
        # Name the object file explicitly instead of relying on the
        # compiler's default output location in the working directory.
        self._call([self.cc] + self.cflags +
                   ['-fPIC', '-c', c_fname, '-o', obj_fname])
        # Linker flags follow the object file: libraries named with -l are
        # resolved in command-line order, so they must come after the
        # objects that reference them.
        self._call([self.cc] + self.cflags +
                   ['-shared', obj_fname] + self.ldflags +
                   ['-o', dll_fname])
        return ctypes.CDLL(dll_fname)

    def _call(self, args, **kwargs):
        """Run ``args`` as a subprocess inside the scratch directory.

        Extra keyword arguments are forwarded to ``subprocess.check_call``.
        """
        subprocess.check_call(
            args, cwd=self.tempdir.name, **kwargs
        )
class CompiledKernel:
    """Callable wrapper around a Loopy kernel built for the C target.

    Construction generates C code for the kernel, builds it with a
    ``Compiler``, and configures the resulting ctypes function from the
    declaration returned by ``c.generate_header``.
    """

    def __init__(self, knl: lp.LoopKernel, comp: Compiler=None):
        """Generate, compile and load ``knl``.

        :param knl: kernel whose ``target`` is a ``c.CTarget``.
        :param comp: compiler to build with; a new ``Compiler`` is created
            when omitted.
        :raises TypeError: if the kernel does not use a C target, or an
            argument declaration has an unsupported form.
        :raises ValueError: if the generated function does not return void.
        """
        # Explicit check instead of assert, which is removed under -O.
        if not isinstance(knl.target, c.CTarget):
            raise TypeError(
                'kernel target must be a CTarget, got %r' % (knl.target, ))
        self.knl = knl
        self.code, _ = lp.generate_code(knl)
        self.comp = comp or Compiler()
        self.dll = self.comp.build(self.code)
        # Exactly one declaration is expected from the header.
        self.func_decl, = c.generate_header(knl)
        # (name, ctype) pairs in declaration order, filled in by visit().
        self.arg_info = []
        self.visit(self.func_decl)
        self.name = self.func_decl.subdecl.name
        restype = self.func_decl.subdecl.typename
        if restype == 'void':
            self.restype = None
        else:
            raise ValueError('Unhandled restype %r' % (restype, ))
        self.fn = getattr(self.dll, self.name)
        self.fn.restype = self.restype
        self.fn.argtypes = [ctype for name, ctype in self.arg_info]

    def __call__(self, *args):
        """Invoke the compiled function.

        Arguments exposing a ``ctypes`` attribute (e.g. NumPy arrays) are
        passed as pointers via ``arg.ctypes.data_as``; everything else is
        converted by calling the declared ctypes type on it.

        :raises TypeError: if the number of arguments differs from the
            number of declared parameters.
        """
        argtypes = self.fn.argtypes
        # zip() alone would silently drop surplus arguments.
        if len(args) != len(argtypes):
            raise TypeError('%s() takes %d arguments (%d given)'
                            % (self.name, len(argtypes), len(args)))
        args_ = [
            arg.ctypes.data_as(arg_t) if hasattr(arg, 'ctypes') else arg_t(arg)
            for arg, arg_t in zip(args, argtypes)
        ]
        return self.fn(*args_)

    def append_arg(self, name, dtype, pointer=False):
        """Record one parameter as a ``(name, ctype)`` pair in ``arg_info``."""
        self.arg_info.append((
            name,
            self.dtype_to_ctype(dtype, pointer=pointer)
        ))

    def visit_const(self, node: cgen.Const):
        """Record a ``const`` parameter, which wraps a pointer or a POD."""
        if isinstance(node.subdecl, cgen.RestrictPointer):
            self.visit_pointer(node.subdecl)
        else:
            pod = node.subdecl  # type: cgen.POD
            self.append_arg(pod.name, pod.dtype)

    def visit_pointer(self, node: cgen.RestrictPointer):
        """Record a restrict-pointer parameter as a pointer to its dtype."""
        pod = node.subdecl  # type: cgen.POD
        self.append_arg(pod.name, pod.dtype, pointer=True)

    def visit(self, func_decl):
        """Populate ``arg_info`` from the argument declarations.

        :raises TypeError: for a declaration that is neither ``cgen.Const``
            nor ``cgen.RestrictPointer``; skipping it would leave
            ``argtypes`` out of step with the C signature.
        """
        for arg in func_decl.arg_decls:
            if isinstance(arg, cgen.Const):
                self.visit_const(arg)
            elif isinstance(arg, cgen.RestrictPointer):
                self.visit_pointer(arg)
            else:
                raise TypeError('unhandled type for arg %r' % (arg, ))

    def dtype_to_ctype(self, dtype, pointer=False):
        """Map ``dtype`` to a ctypes type, or a pointer to it.

        The C type name is looked up in the dtype registry of the kernel's
        own target and resolved as ``ctypes.c_<typename>``.
        """
        # Use the kernel's target rather than a module-level global.
        registry = self.knl.target.get_dtype_registry().wrapped_registry
        typename = registry.dtype_to_ctype(dtype)
        basetype = getattr(ctypes, 'c_' + typename)
        if pointer:
            return ctypes.POINTER(basetype)
        return basetype
# Example: compile a kernel computing out[i] = 2*a[i] over 0 <= i < n and
# check the result against NumPy.
target = c.CTarget()
knl = lp.make_kernel(
    "{ [i]: 0<=i<n }",
    "out[i] = 2*a[i]",
    target=target,
)
typed = lp.add_dtypes(knl, {'a': 'f'})
code, _ = lp.generate_code(typed)
fn = CompiledKernel(typed)
# The two rows of one (2, 10) array serve as the input and output buffers.
a, out = np.zeros((2, 10), dtype='f')
a[:] = np.arange(a.size)
fn(a, 10, out)
assert np.allclose(out, 2 * a)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment