Created November 4, 2014 03:55
-
-
Save JonathanRaiman/f2ce5331750da7b2d4e9 to your computer and use it in GitHub Desktop.
Outer product in Cython
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
%%cython
# Cython cell: binds the single-precision BLAS rank-1 update `sger` so that
# `outer_prod` (defined below in the same cell) can call it without Python overhead.
import cython
import numpy as np
cimport numpy as np
from libc.math cimport exp
from libc.string cimport memset
from cpython cimport PyCapsule_GetPointer  # PyCObject_AsVoidPtr on Python 2
try:
    # Older SciPy exposed the f2py BLAS wrappers as a `fblas` submodule.
    from scipy.linalg.blas import fblas
except ImportError:
    # Newer SciPy no longer provides `scipy.linalg.blas.fblas`; the same f2py
    # wrappers (each carrying a `_cpointer` capsule) are attributes of
    # `scipy.linalg.blas` itself, so bind that module under the old name.
    import scipy.linalg.blas as fblas

# Initialise NumPy's C-API table before any np.PyArray_* call is made.
np.import_array()

# Python-level dtype of every array handed to BLAS; must agree with REAL_t.
REAL = np.float32
ctypedef np.float32_t REAL_t
ctypedef np.uint32_t INT_t

# Fortran BLAS takes every scalar argument by pointer, so keep addressable copies.
cdef int ONE = 1
cdef REAL_t ONEF = <REAL_t>1.0

# Signature of Fortran SGER: A := alpha * X * Y**T + A, where A is an M x N
# column-major matrix with leading dimension LDA. X and Y are only read.
ctypedef void (*sger_ptr) (const int *M, const int *N, const float *alpha, const float *X, const int *incX, const float *Y, const int *incY, float *A, const int * LDA) nogil
# f2py stores the raw Fortran entry point in an unnamed capsule (hence NULL).
cdef sger_ptr sger = <sger_ptr>PyCapsule_GetPointer(fblas.sger._cpointer, NULL)  # A := alpha*x*y.T + A
def _check_real_array(arr, int ndim, name):
    """Raise unless ``arr`` is a C-contiguous REAL ndarray with ``ndim`` dimensions."""
    if not isinstance(arr, np.ndarray):
        raise TypeError("%s must be a numpy.ndarray, got %s" % (name, type(arr).__name__))
    if arr.dtype != REAL:
        raise ValueError("%s must have dtype %s, got %s" % (name, np.dtype(REAL), arr.dtype))
    if arr.ndim != ndim:
        raise ValueError("%s must be %d-dimensional, got %d dimensions" % (name, ndim, arr.ndim))
    if not arr.flags.c_contiguous:
        raise ValueError("%s must be C-contiguous" % name)


def outer_prod(_x, _y, _output, REAL_t alpha=1.0):
    """Accumulate the scaled outer product of two vectors into ``_output``.

    Performs ``_output[i, j] += alpha * _x[i] * _y[j]`` in place with one BLAS
    ``sger`` call, i.e. ``_output += alpha * np.outer(_x, _y)``. The result is
    *added* to ``_output``; pass a zeroed array to obtain the plain outer product.

    Parameters
    ----------
    _x : numpy.ndarray
        1-D, C-contiguous ``REAL`` (float32) array of length N.
    _y : numpy.ndarray
        1-D, C-contiguous ``REAL`` (float32) array of length M.
    _output : numpy.ndarray
        Writeable, C-contiguous ``REAL`` (float32) array of shape (N, M),
        updated in place.
    alpha : float, optional
        Scale applied to the outer product before accumulation (default 1.0).

    Returns
    -------
    None

    Raises
    ------
    TypeError
        If any array argument is not a numpy.ndarray.
    ValueError
        If a dtype, dimensionality, contiguity, shape or writeability
        requirement is violated. BLAS works on raw memory, so without these
        checks a mismatched array would be misread or written out of bounds.
    """
    cdef int M, N
    cdef REAL_t *x
    cdef REAL_t *y
    cdef REAL_t *output

    _check_real_array(_x, 1, "_x")
    _check_real_array(_y, 1, "_y")
    _check_real_array(_output, 2, "_output")

    N = _x.shape[0]
    M = _y.shape[0]
    if _output.shape[0] != N or _output.shape[1] != M:
        raise ValueError("_output must have shape (%d, %d), got %r" % (N, M, _output.shape))
    if not _output.flags.writeable:
        raise ValueError("_output must be writeable")
    if N == 0 or M == 0:
        # Nothing to add. Also, BLAS requires LDA >= max(1, M), and LDA is M here.
        return

    x = <REAL_t *>(np.PyArray_DATA(_x))
    y = <REAL_t *>(np.PyArray_DATA(_y))
    output = <REAL_t *>(np.PyArray_DATA(_output))
    # BLAS is column-major: a C-contiguous (N, M) array is the same memory as
    # a column-major M x N matrix (LDA = M). Updating that matrix with
    # y * x**T therefore yields output[i, j] += alpha * x[i] * y[j].
    sger(&M, &N, &alpha, y, &ONE, x, &ONE, output, &M)
a = np.arange(0, 3).astype(REAL) | |
b = np.arange(0, 5).astype(REAL) | |
%timeit np.outer(a, b) # 8.16 µs | |
%%timeit | |
output = np.zeros([3,5], dtype=REAL) | |
outer_prod(a,b, output) | |
# 1.57 µs |
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.