Created
May 18, 2020 17:11
-
-
Save wush978/3a0e02b64c554546868402a517cc3c92 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
%%cython --annotate --cplus --compile-args=-fopenmp --link-args=-fopenmp | |
cimport cython | |
from cython.parallel cimport prange, parallel, threadid | |
import numpy as np | |
cimport numpy as np | |
import scipy | |
cimport openmp | |
#cdef extern from "<algorithm>" namespace "std": | |
# cdef void sort[RI](RI first, RI last) | |
cdef extern from "<parallel/algorithm>" namespace "__gnu_parallel": | |
cdef void sort[RI](RI first, RI last) except + | |
cdef void unique_copy[II, OI](II first, II last, OI d_first) except + | |
from libcpp.vector cimport vector | |
from libcpp.utility cimport pair | |
ctypedef pair[np.int64_t,np.int64_t] Index | |
cdef void __csr_to_csc(
        np.int64_t *src_indptr,
        np.int64_t *src_indices,
        np.int64_t *dst_indptr,
        np.int64_t *dst_indices,
        size_t nrow,
        size_t ncol,
        size_t nnz,
    ):
    """Build the CSC index arrays for a pattern-only CSR matrix.

    Reads ``src_indptr[0..nrow]`` and ``src_indices`` and writes
    ``dst_indptr[0..ncol]`` and ``dst_indices[0..nnz-1]``.  Values are not
    handled here; only the sparsity pattern is converted.

    The conversion materialises one (col, row) pair per stored entry, sorts
    the pairs, and then counts entries per column.

    NOTE(review): assumes ``src_indptr[nrow] == nnz`` and every column index
    is in ``[0, ncol)``; neither is checked here -- confirm at the caller.
    """
    cdef vector[Index] index
    # One per-thread histogram slot for every thread OpenMP may start.
    cdef vector[vector[np.int64_t]] buffer = vector[vector[np.int64_t]](openmp.omp_get_max_threads())
    cdef size_t i, j
    cdef np.int64_t count

    # Step 1: expand CSR into (col, row) pairs, one slot per stored entry.
    with nogil:
        index.resize(nnz)
        for i in prange(nrow):
            for j in range(src_indptr[i], src_indptr[i+1]):
                index[j].second = i # row
                index[j].first = src_indices[j] # col

    # Step 2: sort pairs, i.e. by column first and by row within a column.
    # The extern declaration of sort is not marked nogil, so it is called
    # with the GIL held.
    sort[vector[Index].iterator](index.begin(), index.end())

    # Step 3: emit the row indices in sorted order and count entries per
    # column.  Each thread counts into its own histogram to avoid write
    # conflicts on shared counters.
    with nogil, parallel():
        buffer[threadid()].resize(ncol)
        for i in prange(nnz):
            buffer[threadid()][index[i].first] += 1
            dst_indices[i] = index[i].second
        # Merge the per-thread histograms.  The count is assigned rather than
        # accumulated so the result does not depend on the caller having
        # zero-filled dst_indptr.  ``count = count + ...`` (not ``+=``) keeps
        # it a thread-private variable instead of a prange reduction.
        for i in prange(ncol):
            count = 0
            for j in range(buffer.size()):
                # Slots of threads that never ran stay empty.
                if buffer[j].size() > 0:
                    count = count + buffer[j][i]
            dst_indptr[i+1] = count

    # Step 4: sequential prefix sum turns per-column counts into offsets.
    dst_indptr[0] = 0
    for i in range(ncol):
        dst_indptr[i+1] = dst_indptr[i+1] + dst_indptr[i]
cdef np.int64_t* getp(np.ndarray[np.int64_t, ndim = 1] arr):
    """Return a raw pointer to the data of a 1-D int64 NumPy array.

    The pointer is only valid while ``arr`` is alive.
    """
    # PyArray_DATA does not index into the buffer, so unlike ``&arr[0]`` it
    # does not trip the bounds check when the array has zero elements.
    return <np.int64_t*> np.PyArray_DATA(arr)
def csr_to_csc(m):
    """Convert a pattern-only CSR matrix to CSC format.

    Parameters
    ----------
    m : scipy.sparse.csr_matrix
        Matrix whose ``indptr`` and ``indices`` are int64 and whose stored
        values are all equal to 1.

    Returns
    -------
    scipy.sparse.csc_matrix
        Matrix with the same shape as ``m``.  ``m.data`` is passed through
        unchanged, which is only valid because every stored value is 1.

    Raises
    ------
    RuntimeError
        If ``m`` is not a ``csr_matrix``, if its index arrays are not int64,
        or if any stored value differs from 1.
    """
    # A bare ``import scipy`` does not guarantee the submodule is loaded.
    import scipy.sparse

    if not isinstance(m, scipy.sparse.csr_matrix):
        raise RuntimeError("m is not a csr_matrix")
    if m.indptr.dtype != np.int64:
        raise RuntimeError("The indptr is not int64")
    # Raised explicitly instead of asserted so the check survives ``-O``.
    if m.indices.dtype != np.int64:
        raise RuntimeError("The indices is not int64")
    if not np.all(m.data == 1):
        raise RuntimeError("The data is not all 1")

    nnz = len(m.indices)
    dst_indptr = np.zeros(m.shape[1] + 1, dtype=np.int64)
    dst_indices = np.zeros(nnz, dtype=np.int64)
    # With no stored entries the zero-filled outputs are already the answer,
    # so the native kernel is skipped entirely.
    if nnz > 0:
        __csr_to_csc(
            getp(m.indptr),
            getp(m.indices),
            getp(dst_indptr),
            getp(dst_indices),
            m.shape[0],
            m.shape[1],
            nnz,
        )
    return scipy.sparse.csc_matrix((m.data, dst_indices, dst_indptr), shape=m.shape)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment