Last active
July 19, 2022 14:10
-
-
Save bergkvist/e1d82caa8bb62a51161d2de3e95f0e7d to your computer and use it in GitHub Desktop.
Allow for writing C extensions inline in IPython
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from IPython.core.magic import register_cell_magic | |
import multiprocessing as mp | |
from functools import wraps | |
import importlib.util | |
import traceback | |
import sysconfig | |
import tempfile | |
import secrets | |
import sys | |
import os | |
import numpy | |
@register_cell_magic
def cfunc(line, cell):
    """Compile the cell body as a C extension and bind the resulting function.

    The magic line carries the four fields of a ``PyMethodDef`` entry,
    separated by commas::

        %%cfunc "name", c_function, METH_FLAGS, "docstring"

    The cell body is the C source that defines ``c_function``.  The compiled
    function is stored in ``globals()`` under ``name``.

    Args:
        line: The text following ``%%cfunc`` on the magic line.
        cell: The C source code of the cell.

    Raises:
        ValueError: If ``line`` does not contain the four expected fields.
    """
    fields = [entry.strip() for entry in line.split(',', 3)]
    if len(fields) != 4:
        raise ValueError(
            'usage: %%cfunc "name", c_function, METH_FLAGS, "docstring" '
            f'(got {len(fields)} field(s))'
        )
    name, ref, args, docs = fields

    # A fresh random module name per cell, so recompiling never collides
    # with a previously built extension.
    extname = secrets.token_hex()

    # NOTE: `name` and `docs` are C string literals typed by the notebook
    # user; eval() turns them into Python strings. Do not feed this magic
    # text that comes from an untrusted source.
    methodname = eval(name)
    # NULL is a legal docstring in a PyMethodDef entry; map it to "no doc".
    docstring = None if docs == 'NULL' else eval(docs)

    code = generate_c_extension(extname, f'{{{name}, {ref}, {args}, {docs}}}', cell)
    method = build_extension_method(extname, methodname, code)
    method.__name__ = methodname
    # The attribute Python actually reads for help()/? is __doc__.
    method.__doc__ = docstring
    globals()[methodname] = method
def build_extension_method(extname, methodname, c_code):
    """Compile ``c_code`` with gcc and return a callable for one of its methods.

    The C source and the shared object are written to temporary files that
    are removed once the extension has been loaded.  The returned callable is
    wrapped with ``forked``, so every call runs in a forked child process.

    Args:
        extname: Name of the extension module (must match ``PyInit_<extname>``).
        methodname: Name of the module attribute to call.
        c_code: Complete C source of the extension module.

    Returns:
        A callable forwarding ``*args, **kwargs`` to ``module.<methodname>``.

    Raises:
        ImportError: If gcc cannot be run or the compilation fails.
    """
    import subprocess  # local import: only needed for the compiler call

    with tempfile.NamedTemporaryFile('w+', suffix='.c') as f1:
        f1.write(c_code)
        # Make sure the source is on disk before the compiler reads it.
        f1.flush()
        with tempfile.NamedTemporaryFile('rb+', suffix='.so') as f2:
            # Argument list instead of a shell string: paths containing
            # spaces or shell metacharacters are passed through verbatim.
            command = [
                'gcc', '-shared', '-fPIC', '-O2', f1.name, '-o', f2.name,
                f'-I{sysconfig.get_path("include")}',
                f'-I{numpy.get_include()}',
                f'-L{sysconfig.get_path("data")}/lib',
                f'-lpython{sysconfig.get_python_version()}',
            ]
            try:
                result = subprocess.run(command, capture_output=True, text=True)
            except OSError as e:
                raise ImportError(f'could not run gcc: {e}') from e
            if result.returncode != 0:
                raise ImportError(
                    f'gcc failed with exit status {result.returncode}:\n'
                    f'{result.stderr}'
                )
            if result.stderr:
                # Surface compiler warnings from a successful build.
                sys.stderr.write(result.stderr)

            # Build the module object while the temporary .so still exists.
            spec = importlib.util.spec_from_file_location(extname, f2.name)
            module = importlib.util.module_from_spec(spec)

    @forked
    def method(*args, **kwargs):
        spec.loader.exec_module(module)
        return getattr(module, methodname)(*args, **kwargs)

    return method
def generate_c_extension(extname, methoddef, method):
    """Return the C source of a single-method CPython extension module.

    Args:
        extname: Module name; used for the module definition and for the
            ``PyInit_<extname>`` entry point.
        methoddef: A C initializer for one ``PyMethodDef`` entry, including
            its braces, e.g. ``{"add", add, METH_VARARGS, "doc"}``.
        method: C source defining the function referenced by ``methoddef``.

    Returns:
        The complete C source as a string.
    """
    return f'''
#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
#include <Python.h>
#include <numpy/arrayobject.h>
{method}
static PyMethodDef methods[] = {{
    {methoddef},
    {{NULL, NULL, 0, NULL}}  /* sentinel: marks the end of the method table */
}};
static struct PyModuleDef moduledef = {{
    PyModuleDef_HEAD_INIT,
    "{extname}",
    NULL,
    -1,
    methods
}};
PyMODINIT_FUNC PyInit_{extname}(void) {{
    /* import_array() returns NULL from this function on failure, so run it
       before the module is created to avoid leaking the module object. */
    import_array();
    return PyModule_Create(&moduledef);
}}
'''
def forked(fn):
    """Decorate ``fn`` so that every call runs in a forked child process.

    The child's return value, or the exception it raised, is sent back to
    the parent and returned / re-raised there.  A crash of the child (for
    example a segfault in native code) therefore cannot take down the
    calling process.

    Raises (from the wrapped call):
        SystemExit: With the child's exit code, if the child terminated
            without delivering a result.
        RuntimeError: If the child's result or exception could not be
            pickled for transfer.
        BaseException: Whatever ``fn`` raised in the child.
    """
    @wraps(fn)
    def call(*args, **kwargs):
        ctx = mp.get_context('fork')
        receiver, sender = ctx.Pipe(duplex=False)

        def target():
            receiver.close()
            try:
                payload = (fn(*args, **kwargs), None)
            except BaseException as e:
                payload = (None, e)
            try:
                # send() pickles before writing, so a failure here leaves
                # the pipe clean for the fallback message below.
                sender.send(payload)
            except Exception as send_error:
                sender.send((None, RuntimeError(
                    f'could not transfer result from forked call: {send_error!r}'
                )))
            finally:
                sender.close()

        p = ctx.Process(target=target)
        p.start()
        # Drop the parent's copy of the write end so recv() sees EOF if the
        # child dies without sending anything.
        sender.close()
        try:
            try:
                # Read BEFORE joining: a child blocked on writing a large
                # result would otherwise never exit and join() would hang.
                outcome = receiver.recv()
            except EOFError:
                outcome = None
        finally:
            receiver.close()
            p.join()

        if outcome is None:
            raise SystemExit(p.exitcode)
        result, error = outcome
        if error is not None:
            raise error
        return result
    return call
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment