-
-
Save ericmjl/d249fd604a2cc4901c79e06551e80e5e to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# distutils: extra_compile_args = -O2 -w | |
# cython: boundscheck=False, nonecheck = False, wraparound=False, cdivision=True | |
import numpy as np | |
cimport numpy as np | |
import operator as op | |
from cython.parallel import prange | |
from cython import boundscheck, nonecheck, wraparound | |
def csr_dot(rows, cols, B):
    """Return ``A @ B`` for a 0/1 sparse matrix ``A`` given as coordinate pairs.

    ``A`` has a one at ``(rows[n], cols[n])`` for every ``n``; repeated pairs
    add up. The result has the same shape as ``B``, so ``A`` is treated as a
    square ``B.shape[0] x B.shape[0]`` matrix.

    Parameters
    ----------
    rows, cols : array_like of int
        Equal-length 1-D row / column indices (e.g. the output of
        ``to_sparse_format``). Every index must lie in ``[0, B.shape[0])``.
    B : array_like, 2-D
        Dense right-hand operand; converted to C-contiguous float64.

    Returns
    -------
    numpy.ndarray
        float64 array with the same shape as ``B``.

    Raises
    ------
    ValueError
        If ``B`` is not 2-D, if ``rows``/``cols`` are not 1-D arrays of equal
        length, or if any index is out of range.
    """
    B = np.ascontiguousarray(B, dtype=np.float64)
    rows = np.asarray(rows)
    cols = np.asarray(cols)
    if B.ndim != 2:
        raise ValueError("B must be 2-D, got shape %r" % (B.shape,))
    if rows.ndim != 1 or rows.shape != cols.shape:
        raise ValueError(
            "rows and cols must be 1-D with equal length, got shapes %r and %r"
            % (rows.shape, cols.shape)
        )
    n = B.shape[0]
    # The kernel runs with boundscheck/wraparound disabled, so a bad index
    # would touch memory outside the arrays instead of raising. Validate
    # here, before narrowing to C int (which could wrap large values).
    for label, idx in (("rows", rows), ("cols", cols)):
        if idx.size and (idx.min() < 0 or idx.max() >= n):
            raise ValueError("%s indices must lie in [0, %d)" % (label, n))
    # Match the kernel's typed memoryviews: int[::1] and double[:, ::1].
    rows = np.ascontiguousarray(rows, dtype=np.intc)
    cols = np.ascontiguousarray(cols, dtype=np.intc)
    out = np.zeros_like(B)
    _csr_dot(rows, cols, B, out)
    return out
def to_sparse_format(dct):
    """Convert an adjacency dict into parallel int32 (row, col) index arrays.

    Parameters
    ----------
    dct : dict
        Maps each row index to a sized sequence (e.g. a list) of the column
        indices of that row's nonzero entries.

    Returns
    -------
    rows, cols : numpy.ndarray of int32
        ``rows[n]`` and ``cols[n]`` give the n-th nonzero. Entries are grouped
        by ascending row key; within a row, columns keep their order in
        ``dct``. An empty dict yields two empty arrays.
    """
    if not dct:
        # zip(*[]) would produce nothing to unpack; return empty index arrays.
        return np.empty(0, dtype='int32'), np.empty(0, dtype='int32')
    row_keys, col_lists = zip(*sorted(dct.items(), key=op.itemgetter(0)))
    # np.repeat needs a concrete sequence of counts; on Python 3 ``map``
    # returns a lazy iterator, which np.repeat does not accept.
    counts = [len(c) for c in col_lists]
    rows = np.repeat(row_keys, counts)
    cols = np.concatenate(col_lists)
    return rows.astype('int32'), cols.astype('int32')
@nonecheck(False)
@wraparound(False)
@boundscheck(False)
cdef inline void _csr_dot(int[::1] rows, int[::1] cols, double[:,::1] B, double[:,::1] out):
    # Accumulate out[rows[n], :] += B[cols[n], :] for every stored pair n,
    # i.e. out += A @ B where A is the 0/1 matrix with ones at (rows[n], cols[n]).
    #
    # Bounds checking and negative-index wraparound are disabled, so the
    # caller must guarantee: len(cols) >= len(rows), 0 <= rows[n] < out.shape[0],
    # 0 <= cols[n] < B.shape[0], and out.shape[1] >= B.shape[1].
    #
    # Deliberately a serial loop (not prange): several pairs can share the
    # same row i, so concurrent `+=` on out[i, j] would race.
    cdef Py_ssize_t idx, j
    # Loop counters use Py_ssize_t, the type of memoryview shapes, so very
    # large inputs cannot overflow a C int index.
    cdef Py_ssize_t n_pairs = rows.shape[0]
    cdef Py_ssize_t n_cols = B.shape[1]
    cdef int i, k
    for idx in range(n_pairs):
        i = rows[idx]
        k = cols[idx]
        for j in range(n_cols):
            out[i, j] += B[k, j]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from __future__ import division | |
import numpy as np | |
import numpy.random as npr | |
from bindot import to_sparse_format, csr_dot | |
def dct_to_dense(dct):
    """Build the dense 0/1 matrix described by an adjacency dict.

    Used as the reference for checking ``csr_dot``. It fills the matrix
    directly from ``dct`` rather than through ``to_sparse_format``, so a bug
    in that conversion cannot show up identically on both sides of the
    comparison.

    Parameters
    ----------
    dct : dict
        Maps each row index to an iterable of column indices.

    Returns
    -------
    numpy.ndarray
        float64 array of shape ``(1 + max row key, 1 + max column index)``
        with 1 at every listed ``(row, col)`` and 0 elsewhere. Rows with no
        columns are allowed; an empty dict gives shape ``(0, 0)``.
    """
    n_rows = 1 + max(dct) if dct else 0
    all_cols = [j for row in dct.values() for j in row]
    # Empty rows contribute no columns; guard so max() never sees nothing.
    n_cols = 1 + max(all_cols) if all_cols else 0
    out = np.zeros((n_rows, n_cols))
    for i, row in dct.items():
        for j in row:
            out[i, j] = 1
    return out
if __name__ == "__main__":
    # Smoke check: the Cython sparse product must match a dense NumPy
    # reference on a small, fixed-seed example. Prints True on success.
    npr.seed(0)
    dct = {0: [0, 1, 2], 1: [1, 0, 3], 2: [2, 0, 3], 3: [3, 1, 2], 4: [4, 5], 5: [5, 4]}
    B = npr.randn(6, 6)
    rows, cols = to_sparse_format(dct)
    # Single-argument print(...) gives identical output on Python 2 and 3;
    # the original print statement is a SyntaxError on Python 3.
    print(np.allclose(csr_dot(rows, cols, B), np.dot(dct_to_dense(dct), B)))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from setuptools import setup | |
import numpy as np | |
from Cython.Build import cythonize | |
# Cythonize every .pyx source below this directory. The Cython sources
# `cimport numpy`, so NumPy's C headers must be on the include path.
extension_modules = cythonize('**/*.pyx')
numpy_include = np.get_include()
setup(ext_modules=extension_modules, include_dirs=[numpy_include])
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment