-
-
Save den-run-ai/0d4e714c5b4e7123663e to your computer and use it in GitHub Desktop.
Example of multithreading a numba function by releasing the GIL through ctypes
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import ast | |
from timeit import repeat | |
import threading | |
from ctypes import pythonapi, c_void_p | |
import math | |
import numpy as np | |
try: | |
import numexpr as ne | |
nthreads = ne.ncores | |
except ImportError: | |
ne = None | |
nthreads = 2 | |
from numba import jit, double, autojit, void | |
size = 1e6 | |
def timefunc(correct, func, *args, **kwargs): | |
print func.__name__.ljust(20), | |
# Warming up | |
res = func(*args, **kwargs) | |
if correct is not None: | |
assert np.allclose(res, correct) | |
# time it | |
print '{:>5.0f} ms'.format(min(repeat(lambda: func(*args, **kwargs), | |
number=5, repeat=2)) * 1000) | |
return res | |
def make_singlethread(inner_func): | |
def func(*args): | |
length = len(args[0]) | |
result = np.empty(length, dtype=np.float64) | |
inner_func(result, *args) | |
return result | |
return func | |
def make_multithread(inner_func, numthreads): | |
def func_mt(*args): | |
length = len(args[0]) | |
result = np.empty(length, dtype=np.float64) | |
args = (result,) + args | |
chunklen = (length + 1) // numthreads | |
chunks = [[arg[i * chunklen:(i + 1) * chunklen] for arg in args] | |
for i in range(numthreads)] | |
threads = [threading.Thread(target=inner_func, args=chunk) | |
for chunk in chunks[:-1]] | |
for thread in threads: | |
thread.start() | |
# the main thread handles the last chunk | |
inner_func(*chunks[-1]) | |
for thread in threads: | |
thread.join() | |
return result | |
return func_mt | |
savethread = pythonapi.PyEval_SaveThread | |
savethread.argtypes = [] | |
savethread.restype = c_void_p | |
restorethread = pythonapi.PyEval_RestoreThread | |
restorethread.argtypes = [c_void_p] | |
restorethread.restype = None | |
def test_inner_func(result, a, b, c): | |
threadstate = savethread() | |
for i in range(len(result)): | |
result[i] = 2.1 * a[i] + 3.2 * b[i] * b[i] + 4.3 * c[i] * c[i] * c[i] | |
restorethread(threadstate) | |
inner_func_nb = autojit(test_inner_func, nopython=True) | |
test_func = make_singlethread(inner_func_nb) | |
test_func_mt = make_multithread(inner_func_nb, nthreads) | |
def np_nopow(a, b, c): | |
return 2.1 * a + 3.2 * b * b + 4.3 * c * c * c | |
def ne_pow(a, b, c): | |
return ne.evaluate('2.1 * a + 3.2 * b ** 2 + 4.3 * c ** 3') | |
a = np.random.rand(size) | |
b = np.random.rand(size) | |
c = np.random.rand(size) | |
print "one thread" | |
correct = timefunc(None, np_nopow, a, b, c) | |
timefunc(correct, test_func, a, b, c) | |
print "using {} threads".format(nthreads) | |
timefunc(correct, test_func_mt, a, b, c) | |
if ne is not None: | |
timefunc(correct, ne_pow, a, b, c) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# results on my core2duo | |
C:\Users\gdm\devel\test_numba>python mt.py | |
one thread | |
np_nopow 361 ms # numpy | |
func 58 ms # numba 1 thread | |
using 2 threads | |
func_mt 49 ms # numba 2 threads | |
ne_pow 92 ms # numexpr 2 threads |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment