Skip to content

Instantly share code, notes, and snippets.

@zed
Created March 16, 2012 18:25
Show Gist options
  • Save zed/2051661 to your computer and use it in GitHub Desktop.
Save zed/2051661 to your computer and use it in GitHub Desktop.
Naive O(N**3) 2D np.dot() multithreaded implementation (CPython extension in Cython)
#cython: boundscheck=False, wraparound=False
import numpy as np
cimport numpy as np
from cython.parallel cimport prange
def dot(np.ndarray[np.float32_t, ndim=2] a not None,
        np.ndarray[np.float32_t, ndim=2] b not None,
        np.ndarray[np.float32_t, ndim=2] out=None):
    """Naive O(N**3) 2D np.dot() implementation.

    Computes ``a @ b`` for float32 matrices with a plain triple loop; the
    outer loop over result rows runs under ``prange`` so rows are split
    across threads when the module is built with OpenMP.

    Parameters
    ----------
    a : 2D float32 ndarray, shape (m, n)
    b : 2D float32 ndarray, shape (n, p)
    out : 2D float32 ndarray, shape (m, p), optional
        Destination array.  A new array with ``a.dtype`` is allocated when
        omitted.  Every element is overwritten.

    Returns
    -------
    ndarray, shape (m, p)
        The result; this is the very object passed as ``out`` if one was
        given.

    Raises
    ------
    ValueError
        If ``a.shape[1] != b.shape[0]`` or ``out`` is not (m, p).
    """
    cdef Py_ssize_t i, j, k
    # Running sum for one out[i, j].  Held in a local double instead of
    # accumulating directly in out[i, j]: the compiler can keep a local in
    # a register (it cannot for out[i, j], which may alias a or b), and the
    # wider type reduces rounding error over a long inner dimension.
    cdef double acc
    if out is None:
        out = np.empty((a.shape[0], b.shape[1]), dtype=a.dtype)
    # Validate before the nogil section: boundscheck/wraparound are off, so
    # mismatched shapes would otherwise read/write out of bounds.
    if (a.shape[1] != b.shape[0] or
            out.shape[0] != a.shape[0] or out.shape[1] != b.shape[1]):
        raise ValueError("wrong shape")
    with nogil:
        for i in prange(a.shape[0]):
            for j in range(b.shape[1]):
                acc = 0.0
                for k in range(a.shape[1]):
                    # Deliberately ``acc = acc + ...`` rather than ``+=``:
                    # an in-place operator inside a prange body would make
                    # Cython treat acc as a reduction variable; a plain
                    # assignment keeps it thread-private.
                    acc = acc + <double>a[i, k] * b[k, j]
                out[i, j] = <np.float32_t>acc
    return out
from distutils.extension import Extension
def make_ext(modname, pyxfilename):
    """Describe how to build the Cython source *pyxfilename* as *modname*.

    Follows the ``make_ext(modname, pyxfilename)`` hook signature used by
    pyximport ``.pyxbld`` files.  The module is compiled and linked with
    ``-fopenmp`` so that its ``prange`` loops can run in parallel.

    NOTE(review): ``-fopenmp`` is a GCC-style flag; other toolchains (e.g.
    MSVC) may need a different spelling.
    """
    openmp = ['-fopenmp']
    # Hand out independent copies so compile and link args never share one
    # mutable list object.
    return Extension(
        name=modname,
        sources=[pyxfilename],
        extra_compile_args=list(openmp),
        extra_link_args=list(openmp),
    )

Without prange() (single-threaded, i.e. with `prange` replaced by plain `range`):

python -mtimeit -s'from test_cydot import a,b,out,cydot' 'cydot.dot(a,b,out)'
10 loops, best of 3: 119 msec per loop

With prange() (number of threads == number of cores):

python -mtimeit -s'from test_cydot import a,b,out,cydot' 'cydot.dot(a,b,out)'
10 loops, best of 3: 69.9 msec per loop

numpy.dot() version for comparison:

python -mtimeit -s'from test_cydot import a,b,out,np' 'np.dot(a,b,out)'
100 loops, best of 3: 9.97 msec per loop
import pyximport; pyximport.install() # pip install cython
import numpy as np
import cydot
# Shared fixtures for test() and for the timeit benchmarks, which import
# a, b and out from this module.  Operands are float32 to match the
# np.float32_t buffer types of the Cython dot(); a is drawn before b.
a, b = (np.random.rand(*shape).astype(np.float32)
        for shape in ((50, 10000), (10000, 60)))
# Preallocated destination for the product: rows of a by columns of b.
out = np.empty((len(a), b.shape[1]), dtype=a.dtype)
def test():
    """Check cydot.dot against np.dot and its argument contract.

    Covers the allocating call, the ``out=`` call (result written into and
    returned as the given array), and rejection of mismatched shapes with
    ValueError.  Side effect: overwrites the module-level ``out`` array.
    """
    # Allocating path: result must agree with numpy's.
    assert np.allclose(np.dot(a, b), cydot.dot(a, b))

    # Fill the two destination buffers with different sentinels so that
    # agreement afterwards proves both were fully overwritten.
    out2 = out.copy()
    out[:] = -1
    out2[:] = -2
    assert np.allclose(out, -1) and np.allclose(out2, -2)

    np.dot(a, b, out)
    result = cydot.dot(a, b, out2)
    assert result is out2, "dot() must return the out array it was given"
    assert np.allclose(out, out2), (out, out2)

    # Shape validation: (a, a) has mismatched inner dimensions, and the
    # second case passes an out array of the wrong shape.
    bad_calls = (
        (a, a),
        (a, b, np.empty((1, 1), dtype=np.float32)),
    )
    for args in bad_calls:
        try:
            cydot.dot(*args)
        except ValueError:
            pass
        else:
            raise AssertionError("dot() accepted mismatched shapes")
# Run the self-check only when executed as a script; the timeit commands
# instead import a, b and out from this module without calling test().
if __name__=="__main__":
    test()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment