Last active
August 29, 2015 14:08
-
-
Save JonathanRaiman/07046b897709fffb49e5 to your computer and use it in GitHub Desktop.
Cython & BLAS gemm
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# How to get gemm to work in Cython
# 1. Suppose your data is Fortran-contiguous (column-major);
# then BLAS supports it out of the box (other layouts are handled by transposing):
%%cython | |
cimport numpy as np | |
import numpy as np | |
from cpython cimport PyCapsule_GetPointer # PyCObject_AsVoidPtr | |
from scipy.linalg.blas import fblas | |
from cython_lstm import vector_outer_product | |
# Scalar constants in the form BLAS routines take them.
cdef int ONE = 1
cdef float ONE_f = 1.0
cdef float ZERO_f = 0.0

# Python-level dtype and the matching C typedef for single precision.
REAL = np.float32
ctypedef np.float32_t REAL_t

# Fortran-order check from the NumPy C API.
cdef extern from "numpy/arrayobject.h":
    cdef bint PyArray_IS_F_CONTIGUOUS(np.ndarray) nogil

# Signature of the Fortran sgemm routine: every argument, scalars included,
# is passed by pointer.
ctypedef void (*sgemm_ptr)(char *transA, char *transB,
                           int *m, int *n, int *k,
                           float *alpha,
                           float *a, int *lda,
                           float *b, int *ldb,
                           float *beta,
                           float *c, int *ldc)

# Extract the raw sgemm function pointer from the capsule exposed by
# SciPy's wrapper (`_cpointer`).
cdef sgemm_ptr sgemm = <sgemm_ptr>PyCapsule_GetPointer(fblas.sgemm._cpointer, NULL)
# With that function pointer we can now wrap sgemm in a Cython function:
def dot(np.ndarray[REAL_t, ndim=2] _a, np.ndarray[REAL_t, ndim=2] _b,
        REAL_t alpha=1., REAL_t beta=0.):
    """Return ``alpha * (_a @ _b)`` computed by single-precision BLAS sgemm.

    BLAS reads column-major (Fortran) memory. A Fortran-contiguous operand
    is passed untouched with op(X) = X. Any other operand is first made
    C-contiguous (copying only if it is a strided view). Its transpose is
    then Fortran-contiguous over the same buffer, and it is passed with
    op(X) = X^T. Each operand's layout is resolved independently, so mixed
    Fortran/C inputs are handled correctly.

    Parameters
    ----------
    _a : float32 array of shape (m, k)
    _b : float32 array of shape (k, n)
    alpha : scale applied to the product.
    beta : forwarded to sgemm as the scale of the output buffer. That
        buffer starts as zeros, so beta does not affect the result.

    Returns
    -------
    A new Fortran-ordered float32 array of shape (m, n).

    Raises
    ------
    ValueError
        If the inner dimensions of ``_a`` and ``_b`` differ.
    """
    cdef char trans = b'T'      # sgemm flag: op(X) = X^T
    cdef char n_trans = b'N'    # sgemm flag: op(X) = X
    cdef char transA, transB
    cdef int m, n, k, kb, lda, ldb, ldc
    cdef REAL_t * a
    cdef REAL_t * b
    cdef REAL_t * c
    cdef np.ndarray[REAL_t, ndim=2] _c

    if PyArray_IS_F_CONTIGUOUS(_a):
        transA = n_trans
    else:
        # The transpose of a C-contiguous matrix is Fortran-contiguous over
        # the same memory; ascontiguousarray only copies strided views.
        _a = np.ascontiguousarray(_a).T
        transA = trans
    if PyArray_IS_F_CONTIGUOUS(_b):
        transB = n_trans
    else:
        _b = np.ascontiguousarray(_b).T
        transB = trans

    # Logical (post-op) dimensions: op(A) is (m, k), op(B) is (kb, n).
    # Each operand uses its own flag.
    if transA == n_trans:
        m = _a.shape[0]
        k = _a.shape[1]
    else:
        m = _a.shape[1]
        k = _a.shape[0]
    if transB == n_trans:
        kb = _b.shape[0]
        n = _b.shape[1]
    else:
        kb = _b.shape[1]
        n = _b.shape[0]
    if k != kb:
        # Without this check BLAS would read past the end of a buffer.
        raise ValueError("shapes (%d,%d) and (%d,%d) not aligned: %d (dim 1) != %d (dim 0)"
                         % (m, k, kb, n, k, kb))

    _c = np.zeros((m, n), dtype=REAL, order="F")
    if m == 0 or n == 0 or k == 0:
        # Nothing to multiply. Reference BLAS also requires every leading
        # dimension to be >= 1, which an empty axis would violate.
        return _c

    # Leading dimension = row count of each Fortran-ordered buffer.
    lda = _a.shape[0]
    ldb = _b.shape[0]
    ldc = _c.shape[0]
    a = <REAL_t *>np.PyArray_DATA(_a)
    b = <REAL_t *>np.PyArray_DATA(_b)
    c = <REAL_t *>np.PyArray_DATA(_c)
    # Only the BLAS call itself runs without the GIL; all Python-object
    # work is done above.
    with nogil:
        sgemm(&transA, &transB, &m, &n, &k, &alpha, a, &lda, b, &ldb,
              &beta, c, &ldc)
    return _c
# Helpers and a self-check that compare `dot` against `np.dot`:
def fortran_arrays(x, y):
    """Return random float32 arrays of shapes (x, y) and (y, x), both Fortran-ordered."""
    left = np.random.randn(x, y).astype(np.float32, order="F")
    right = np.random.randn(y, x).astype(np.float32, order="F")
    return left, right
def ordinary_arrays(x, y):
    """Return random float32 arrays of shapes (x, y) and (y, x) in NumPy's default C order."""
    shapes = ((x, y), (y, x))
    return tuple(np.random.randn(*shape).astype(np.float32) for shape in shapes)
def test_dot():
    """Compare ``dot`` with ``np.dot`` for every shape pair (i, j), 1 <= i, j <= 9.

    Each pair is multiplied once with Fortran-ordered inputs and once with
    ordinary C-ordered inputs, as (i, j) @ (j, i).

    Returns:
        (Fortran successes, Fortran failures, C successes, C failures)
    """
    def agrees(c, d):
        """Return True when result ``c`` has the shape of ``d`` and is close to it."""
        try:
            # Shapes must match exactly: np.allclose broadcasts, so a
            # wrongly shaped result such as (1, 1) could otherwise pass.
            return np.shape(c) == np.shape(d) and bool(np.allclose(c, d))
        except (TypeError, ValueError):
            return False

    fworks = []
    fdoesnt = 0
    works = []
    doesnt = 0
    for i in range(1, 10):
        for j in range(1, 10):
            a, b = fortran_arrays(i, j)
            if agrees(dot(a, b), np.dot(a, b)):
                fworks.append((i, j))
            else:
                fdoesnt += 1
            a, b = ordinary_arrays(i, j)
            if agrees(dot(a, b), np.dot(a, b)):
                works.append((i, j))
            else:
                doesnt += 1
    return (len(fworks), fdoesnt, len(works), doesnt)
# Run the self-check. Expected output: (81, 0, 81, 0), i.e. every shape
# pair agrees with np.dot for both Fortran and C layouts.
test_dot()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment