-
-
Save asw456/5793403 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!python | |
#cython: boundscheck=False | |
#cython: wraparound=False | |
#cython: cdivision=True | |
import numpy as np | |
cimport numpy as np | |
from libc.math cimport sqrt | |
# Minimal subset of the CBLAS C interface needed by this module.
cdef extern from "cblas.h":
    # Memory layout of the matrices handed to BLAS.
    enum CBLAS_ORDER:
        CblasRowMajor
        CblasColMajor

    # How an operand enters the product: as-is, transposed, or
    # conjugate-transposed.
    enum CBLAS_TRANSPOSE:
        CblasNoTrans
        CblasTrans
        CblasConjTrans

    # Double-precision general matrix multiply:
    #     C <- alpha * op(A) * op(B) + beta * C
    # Known to Cython as ``lib_dgemm``; the linked C symbol is ``cblas_dgemm``.
    # Declared ``nogil`` so it may be called without holding the GIL.
    void lib_dgemm "cblas_dgemm"(
        CBLAS_ORDER order,
        CBLAS_TRANSPOSE trans_a,
        CBLAS_TRANSPOSE trans_b,
        int m, int n, int k,
        double alpha,
        double *a, int lda,
        double *b, int ldb,
        double beta,
        double *c, int ldc) nogil
def pairwise_cython_blas(double[:, ::1] X):
    """Return the matrix of Euclidean distances between the rows of ``X``.

    Relies on the expansion
    ``|x_i - x_j|^2 = |x_i|^2 + |x_j|^2 - 2 * <x_i, x_j>``: a single BLAS
    ``dgemm`` call fills the result with ``-2 * X * X^T``, then the squared
    norms of each pair of rows are added element by element before the
    square root is taken.

    Parameters
    ----------
    X : C-contiguous 2-D buffer of doubles, shape (M, N)
        One point per row.

    Returns
    -------
    numpy.ndarray of float64, shape (M, M)
        Symmetric distance matrix whose diagonal is exactly zero.  When
        ``M == 0`` or ``N == 0`` an all-zero (M, M) array is returned.
    """
    cdef:
        int M = X.shape[0]
        int N = X.shape[1]
        # Signed index type: avoids mixing signed bounds with unsigned
        # counters in the loop conditions.
        Py_ssize_t i, j, k
        double sq
        double[:, ::1] C

    out = np.zeros((M, M), dtype=np.float64)

    # Bounds checking is switched off for this module, so an empty input
    # must never reach the element accesses below; a zero inner dimension
    # would also hand BLAS an invalid leading dimension of 0.
    if M == 0 or N == 0:
        return out

    C = out

    # C <- -2 * X * X^T + 1 * C   (C starts as all zeros)
    lib_dgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
              M, M, N, -2.0, &X[0, 0], N,
              &X[0, 0], N, 1.0, &C[0, 0], M)

    for i in range(M):
        C[i, i] = 0.0
        for j in range(i + 1, M):
            sq = C[i, j]
            for k in range(N):
                sq += (X[i, k]**2 + X[j, k]**2)
            # Cancellation in the expanded form can leave a tiny negative
            # value for (nearly) identical rows; sqrt would turn it into NaN.
            if sq < 0.0:
                sq = 0.0
            C[i, j] = sqrt(sq)
            C[j, i] = C[i, j]

    return out
def pairwise_cython_blas2(double[:, ::1] X):
    """Return the matrix of Euclidean distances between the rows of ``X``.

    Same approach as ``pairwise_cython_blas`` -- a BLAS ``dgemm`` call
    produces ``-2 * X * X^T`` -- but the squared norm of every row is
    computed once up front and reused for every pair, instead of being
    re-accumulated inside the pair loop.

    Parameters
    ----------
    X : C-contiguous 2-D buffer of doubles, shape (M, N)
        One point per row.

    Returns
    -------
    numpy.ndarray of float64, shape (M, M)
        Symmetric distance matrix whose diagonal is exactly zero.  When
        ``M == 0`` or ``N == 0`` an all-zero (M, M) array is returned.
    """
    cdef:
        int M = X.shape[0]
        int N = X.shape[1]
        # Signed index type: avoids mixing signed bounds with unsigned
        # counters in the loop conditions.
        Py_ssize_t i, j, k
        double sq
        double[:, ::1] C
        double[::1] sx

    out = np.zeros((M, M), dtype=np.float64)

    # Bounds checking is switched off for this module, so an empty input
    # must never reach the element accesses below; a zero inner dimension
    # would also hand BLAS an invalid leading dimension of 0.
    if M == 0 or N == 0:
        return out

    C = out
    sx = np.empty((M,), dtype=np.float64)

    # C <- -2 * X * X^T + 1 * C   (C starts as all zeros)
    lib_dgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
              M, M, N, -2.0, &X[0, 0], N,
              &X[0, 0], N, 1.0, &C[0, 0], M)

    # Squared Euclidean norm of each row.
    for i in range(M):
        sx[i] = 0.0
        for k in range(N):
            sx[i] += X[i, k]**2

    for i in range(M):
        C[i, i] = 0.0
        for j in range(i + 1, M):
            sq = C[i, j] + (sx[i] + sx[j])
            # Cancellation in the expanded form can leave a tiny negative
            # value for (nearly) identical rows; sqrt would turn it into NaN.
            if sq < 0.0:
                sq = 0.0
            C[i, j] = sqrt(sq)
            C[j, i] = C[i, j]

    return out
def pairwise_cython(double[:, ::1] X):
    """Compute all pairwise Euclidean distances between the rows of ``X``.

    Straightforward triple loop: for every ordered pair of rows the squared
    coordinate differences are summed and the square root of the total is
    stored.  Both halves of the matrix are computed independently.

    Parameters
    ----------
    X : C-contiguous 2-D buffer of doubles, shape (M, N)
        One point per row.

    Returns
    -------
    numpy.ndarray of float64, shape (M, M)
        Entry ``[i, j]`` is the distance between row ``i`` and row ``j``.
    """
    cdef int n_rows = X.shape[0]
    cdef int n_cols = X.shape[1]
    cdef Py_ssize_t row_a, row_b, col
    cdef double diff, acc
    cdef double[:, ::1] dist = np.empty((n_rows, n_rows), dtype=np.float64)

    for row_a in range(n_rows):
        for row_b in range(n_rows):
            acc = 0.0
            for col in range(n_cols):
                diff = X[row_a, col] - X[row_b, col]
                acc += diff * diff
            dist[row_a, row_b] = sqrt(acc)

    return np.asarray(dist)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Build script for the ``distlib`` Cython extension (``distlib.pyx``).

Typical use: ``python setup.py build_ext --inplace``.
"""
try:
    # distutils is deprecated (PEP 632) and no longer ships with the standard
    # library from Python 3.12 on; setuptools offers the same entry points.
    from setuptools import Extension
    from setuptools import setup
except ImportError:
    # Older environments without setuptools keep working as before.
    from distutils.core import setup
    from distutils.extension import Extension

from Cython.Distutils import build_ext
import numpy as np

ext_params = {
    # Header search path: system headers, the macOS vecLib framework headers
    # (location of cblas.h on that platform) and NumPy's C headers.
    'include_dirs': [
        '/usr/include',
        '/System/Library/Frameworks/vecLib.framework/Versions/A/Headers',
        np.get_include(),
    ],
    'extra_compile_args': ["-O2"],
    'extra_link_args': [],
    # Link against the system BLAS, which must provide cblas_dgemm.
    'libraries': ['blas'],
    'library_dirs': ['/usr/lib'],
}

ext_modules = [
    Extension("distlib", ["distlib.pyx"], **ext_params),
]

setup(
    name='distlib',
    cmdclass={'build_ext': build_ext},
    ext_modules=ext_modules,
)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
In [1]: import numpy as np | |
In [2]: from scipy.spatial.distance import cdist | |
In [3]: from distlib import pairwise_cython_blas, pairwise_cython_blas2, pairwise_cython
In [4]: a = np.random.random(size=(1000,3)) | |
In [5]: %timeit cdist(a,a) | |
100 loops, best of 3: 11.3 ms per loop | |
In [6]: %timeit pairwise_cython(a) | |
100 loops, best of 3: 9.54 ms per loop | |
In [7]: %timeit pairwise_cython_blas(a) | |
100 loops, best of 3: 13.6 ms per loop | |
In [8]: %timeit pairwise_cython_blas2(a) | |
100 loops, best of 3: 13.3 ms per loop |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment