-
-
Save t-bltg/65fdee53617fbf68a49e77980d813808 to your computer and use it in GitHub Desktop.
Parallel Sparse Matrix Dense Matrix Product in C/Cython/Python.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#ifndef _CS_H | |
#define _CS_H | |
#include <stdlib.h>
#include <stdint.h>     /* int32_t: required by the csi definition below */
#include <limits.h>
#include <math.h>
#include <stdio.h>
#include <stddef.h>
#ifdef MATLAB_MEX_FILE
#include "mex.h"
#endif

#define CS_VER 3                    /* CSparse Version */
#define CS_SUBVER 1
#define CS_SUBSUB 2
#define CS_DATE "April 16, 2013"    /* CSparse release date */
#define CS_COPYRIGHT "Copyright (c) Timothy A. Davis, 2006-2013"

#ifdef MATLAB_MEX_FILE
#undef csi
#define csi mwSignedIndex
#endif

/*
 * Index type used throughout the library.
 *
 * The previous default in this header was
 *     #ifndef csi
 *     #define csi ptrdiff_t
 *     #endif
 * It is deliberately overridden here: we FORCE 32-bit int offsets, because
 * that's what scipy uses, so the index arrays can be shared without a copy.
 *
 * The #undef makes the override explicit and avoids an incompatible macro
 * redefinition when csi was already defined above (MATLAB_MEX_FILE) or on
 * the compiler command line.
 */
#undef csi
#define csi int32_t
csi cs_gaxpy (const cs *A, const double *x, double *y) ; | |
/* --- primary CSparse routines and data structures ------------------------- */ | |
typedef struct cs_sparse /* matrix in compressed-column or triplet form */ | |
{ | |
csi nzmax ; /* maximum number of entries */ | |
csi m ; /* number of rows */ | |
csi n ; /* number of columns */ | |
csi *p ; /* column pointers (size n+1) or col indices (size nzmax) */ | |
csi *i ; /* row indices, size nzmax */ | |
double *x ; /* numerical values, size nzmax */ | |
csi nz ; /* # of entries in triplet matrix, -1 for compressed-col */ | |
} cs ; | |
cs *cs_add (const cs *A, const cs *B, double alpha, double beta) ; | |
csi cs_cholsol (csi order, const cs *A, double *b) ; | |
cs *cs_compress (const cs *T) ; | |
csi cs_dupl (cs *A) ; | |
csi cs_entry (cs *T, csi i, csi j, double x) ; | |
cs *cs_load (FILE *f) ; | |
csi cs_lusol (csi order, const cs *A, double *b, double tol) ; | |
cs *cs_multiply (const cs *A, const cs *B) ; | |
double cs_norm (const cs *A) ; | |
csi cs_print (const cs *A, csi brief) ; | |
csi cs_qrsol (csi order, const cs *A, double *b) ; | |
cs *cs_transpose (const cs *A, csi values) ; | |
/* utilities */ | |
void *cs_calloc (csi n, size_t size) ; | |
void *cs_free (void *p) ; | |
void *cs_realloc (void *p, csi n, size_t size, csi *ok) ; | |
cs *cs_spalloc (csi m, csi n, csi nzmax, csi values, csi triplet) ; | |
cs *cs_spfree (cs *A) ; | |
csi cs_sprealloc (cs *A, csi nzmax) ; | |
void *cs_malloc (csi n, size_t size) ; | |
/* --- secondary CSparse routines and data structures ----------------------- */ | |
typedef struct cs_symbolic /* symbolic Cholesky, LU, or QR analysis */ | |
{ | |
csi *pinv ; /* inverse row perm. for QR, fill red. perm for Chol */ | |
csi *q ; /* fill-reducing column permutation for LU and QR */ | |
csi *parent ; /* elimination tree for Cholesky and QR */ | |
csi *cp ; /* column pointers for Cholesky, row counts for QR */ | |
csi *leftmost ; /* leftmost[i] = min(find(A(i,:))), for QR */ | |
csi m2 ; /* # of rows for QR, after adding fictitious rows */ | |
double lnz ; /* # entries in L for LU or Cholesky; in V for QR */ | |
double unz ; /* # entries in U for LU; in R for QR */ | |
} css ; | |
typedef struct cs_numeric /* numeric Cholesky, LU, or QR factorization */ | |
{ | |
cs *L ; /* L for LU and Cholesky, V for QR */ | |
cs *U ; /* U for LU, R for QR, not used for Cholesky */ | |
csi *pinv ; /* partial pivoting for LU */ | |
double *B ; /* beta [0..n-1] for QR */ | |
} csn ; | |
typedef struct cs_dmperm_results /* cs_dmperm or cs_scc output */ | |
{ | |
csi *p ; /* size m, row permutation */ | |
csi *q ; /* size n, column permutation */ | |
csi *r ; /* size nb+1, block k is rows r[k] to r[k+1]-1 in A(p,q) */ | |
csi *s ; /* size nb+1, block k is cols s[k] to s[k+1]-1 in A(p,q) */ | |
csi nb ; /* # of blocks in fine dmperm decomposition */ | |
csi rr [5] ; /* coarse row decomposition */ | |
csi cc [5] ; /* coarse column decomposition */ | |
} csd ; | |
csi *cs_amd (csi order, const cs *A) ; | |
csn *cs_chol (const cs *A, const css *S) ; | |
csd *cs_dmperm (const cs *A, csi seed) ; | |
csi cs_droptol (cs *A, double tol) ; | |
csi cs_dropzeros (cs *A) ; | |
csi cs_happly (const cs *V, csi i, double beta, double *x) ; | |
csi cs_ipvec (const csi *p, const double *b, double *x, csi n) ; | |
csi cs_lsolve (const cs *L, double *x) ; | |
csi cs_ltsolve (const cs *L, double *x) ; | |
csn *cs_lu (const cs *A, const css *S, double tol) ; | |
cs *cs_permute (const cs *A, const csi *pinv, const csi *q, csi values) ; | |
csi *cs_pinv (const csi *p, csi n) ; | |
csi cs_pvec (const csi *p, const double *b, double *x, csi n) ; | |
csn *cs_qr (const cs *A, const css *S) ; | |
css *cs_schol (csi order, const cs *A) ; | |
css *cs_sqr (csi order, const cs *A, csi qr) ; | |
cs *cs_symperm (const cs *A, const csi *pinv, csi values) ; | |
csi cs_updown (cs *L, csi sigma, const cs *C, const csi *parent) ; | |
csi cs_usolve (const cs *U, double *x) ; | |
csi cs_utsolve (const cs *U, double *x) ; | |
/* utilities */ | |
css *cs_sfree (css *S) ; | |
csn *cs_nfree (csn *N) ; | |
csd *cs_dfree (csd *D) ; | |
/* --- tertiary CSparse routines -------------------------------------------- */ | |
csi *cs_counts (const cs *A, const csi *parent, const csi *post, csi ata) ; | |
double cs_cumsum (csi *p, csi *c, csi n) ; | |
csi cs_dfs (csi j, cs *G, csi top, csi *xi, csi *pstack, const csi *pinv) ; | |
csi cs_ereach (const cs *A, csi k, const csi *parent, csi *s, csi *w) ; | |
csi *cs_etree (const cs *A, csi ata) ; | |
csi cs_fkeep (cs *A, csi (*fkeep) (csi, csi, double, void *), void *other) ; | |
double cs_house (double *x, double *beta, csi n) ; | |
csi cs_leaf (csi i, csi j, const csi *first, csi *maxfirst, csi *prevleaf, | |
csi *ancestor, csi *jleaf) ; | |
csi *cs_maxtrans (const cs *A, csi seed) ; | |
csi *cs_post (const csi *parent, csi n) ; | |
csi *cs_randperm (csi n, csi seed) ; | |
csi cs_reach (cs *G, const cs *B, csi k, csi *xi, const csi *pinv) ; | |
csi cs_scatter (const cs *A, csi j, double beta, csi *w, double *x, csi mark, | |
cs *C, csi nz) ; | |
csd *cs_scc (cs *A) ; | |
csi cs_spsolve (cs *G, const cs *B, csi k, csi *xi, double *x, | |
const csi *pinv, csi lo) ; | |
csi cs_tdfs (csi j, csi k, csi *head, const csi *next, csi *post, | |
csi *stack) ; | |
/* utilities */ | |
csd *cs_dalloc (csi m, csi n) ; | |
csd *cs_ddone (csd *D, cs *C, void *w, csi ok) ; | |
cs *cs_done (cs *C, void *w, void *x, csi ok) ; | |
csi *cs_idone (csi *p, cs *C, void *w, csi ok) ; | |
csn *cs_ndone (csn *N, cs *C, void *w, void *x, csi ok) ; | |
/*
 * Helper macros.  Every macro argument is parenthesized so that expression
 * arguments (e.g. `base + 1` or `cond ? A : B`) expand with the intended
 * precedence.  Arguments may be evaluated more than once, so do not pass
 * expressions with side effects.
 */
#define CS_MAX(a,b) (((a) > (b)) ? (a) : (b))
#define CS_MIN(a,b) (((a) < (b)) ? (a) : (b))
/* CS_FLIP maps i to -(i)-2; applying it twice returns i. */
#define CS_FLIP(i) (-(i)-2)
/* CS_UNFLIP returns i unchanged when i >= 0, otherwise CS_FLIP(i). */
#define CS_UNFLIP(i) (((i) < 0) ? CS_FLIP(i) : (i))
/* An entry of w is "marked" when it is negative; CS_MARK flips w[j] in place. */
#define CS_MARKED(w,j) ((w) [j] < 0)
#define CS_MARK(w,j) { (w) [j] = CS_FLIP ((w) [j]) ; }
/* Format tests: nz == -1 means compressed-column, nz >= 0 means triplet.
   Both evaluate to false for a NULL pointer. */
#define CS_CSC(A) ((A) && ((A)->nz == -1))
#define CS_TRIPLET(A) ((A) && ((A)->nz >= 0))
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include "cs.h" | |
/* y = A*x+y */ | |
csi cs_gaxpy (const cs *A, const double *x, double *y) | |
{ | |
csi p, j, n, *Ap, *Ai ; | |
double *Ax ; | |
if (!CS_CSC (A) || !x || !y) return (0) ; /* check inputs */ | |
n = A->n ; Ap = A->p ; Ai = A->i ; Ax = A->x ; | |
for (j = 0 ; j < n ; j++) | |
{ | |
for (p = Ap [j] ; p < Ap [j+1] ; p++) | |
{ | |
y [Ai [p]] += Ax [p] * x [j] ; | |
} | |
} | |
return (1) ; | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from __future__ import division | |
import os | |
import time | |
import numpy as np | |
import scipy.sparse | |
from _psparse import pmultiply | |
# Benchmark: time `pmultiply` (the OpenMP/CSparse extension) against scipy's
# own sparse-times-dense product on the same random inputs.

n_trials = 10            # number of timed repetitions per implementation
N, M, P = 1000, 10000, 100   # X is N x M (sparse), W is M x P (dense)
RHO = 0.1                # density of the random sparse matrix

X = scipy.sparse.rand(N, M, RHO).tocsc()
# pmultiply requires a Fortran-ordered (column-major) dense operand.
W = np.asfortranarray(np.random.randn(M, P))

# Sanity check before timing.  Compare with a tolerance rather than exact
# equality: the extension is compiled with -ffast-math, so bit-for-bit
# agreement with scipy is not guaranteed.  Unlike a bare `assert`, this
# check is not stripped when Python runs with -O.
np.testing.assert_allclose(pmultiply(X, W), X.dot(W), rtol=1e-7, atol=1e-10)

t0 = time.time()
for i in range(n_trials):
    A = pmultiply(X, W)
t1 = time.time()
for i in range(n_trials):
    B = X.dot(W)
t2 = time.time()

# Single-argument print(...) produces the same output on Python 2 and 3.
print('This Code : %.5fs' % ((t1 - t0) / n_trials))
print('Scipy : %.5fs' % ((t2 - t1) / n_trials))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#----------------------------------------------------------------------------- | |
# Imports | |
#----------------------------------------------------------------------------- | |
cimport cython | |
cimport numpy as np | |
import numpy as np | |
import scipy.sparse | |
from libc.stddef cimport ptrdiff_t | |
from cython.parallel import parallel, prange | |
#-----------------------------------------------------------------------------
# Headers
#-----------------------------------------------------------------------------

# Index type shared with the C code.  It must be 4 bytes wide; this is
# checked by the assert at the end of this section.
ctypedef int csi

# Cython-side mirror of the C `cs` struct.
ctypedef struct cs:
    # matrix in compressed-column or triplet form
    csi nzmax       # maximum number of entries
    csi m           # number of rows
    csi n           # number of columns
    csi *p          # column pointers (size n+1) or col indices (size nzmax)
    csi *i          # row indices, size nzmax
    double *x       # numerical values, size nzmax
    csi nz          # # of entries in triplet matrix, -1 for compressed-col

# External C routines; both are declared callable without the GIL.
cdef extern csi cs_gaxpy (cs *A, double *x, double *y) nogil
cdef extern csi cs_print (cs *A, csi brief) nogil

# Guard against a platform where `int` is not 32 bits wide.
assert sizeof(csi) == 4
#----------------------------------------------------------------------------- | |
# Functions | |
#----------------------------------------------------------------------------- | |
@cython.boundscheck(False)
def pmultiply(X not None, np.ndarray[ndim=2, mode='fortran', dtype=np.float64_t] W not None):
    """Multiply a sparse CSC matrix by a dense matrix

    Parameters
    ----------
    X : scipy.sparse.csc_matrix
        A sparse matrix, of size N x M. Inputs that are not already in CSC
        format (other sparse formats, or anything `scipy.sparse.csc_matrix`
        accepts) are converted to CSC first, which costs a copy.
    W : np.ndarray[dtype=float64, ndim=2, mode='fortran']
        A dense matrix, of size M x P. Note, W must be contiguous and in
        fortran (column-major) order. You can ensure this using
        numpy's `asfortranarray` function.

    Returns
    -------
    A : np.ndarray[dtype=float64, ndim=2, mode='fortran']
        A dense matrix, of size N x P, the result of multiplying X by W.

    Raises
    ------
    ValueError
        If the number of columns of X differs from the number of rows of W.
        Cython's buffer type check also raises ValueError when X's index
        arrays are not C ints (32-bit) or its data is not float64.

    Notes
    -----
    This function is parallelized over the columns of W using OpenMP. You
    can control the number of threads at runtime using the OMP_NUM_THREADS
    environment variable. The internal sparse matrix code is from CSPARSE,
    a Concise Sparse matrix package. Copyright (c) 2006, Timothy A. Davis.
    http://www.cise.ufl.edu/research/sparse/CSparse, licensed under the
    GNU LGPL v2.1+.

    References
    ----------
    .. [1] Davis, Timothy A., "Direct Methods for Sparse Linear Systems
       (Fundamentals of Algorithms 2)," SIAM Press, 2006. ISBN: 0898716136
    """
    cdef int i
    cdef cs csX
    cdef np.ndarray[double, ndim=2, mode='fortran'] result
    cdef np.ndarray[csi, ndim=1, mode = 'c'] indptr
    cdef np.ndarray[csi, ndim=1, mode = 'c'] indices
    cdef np.ndarray[double, ndim=1, mode = 'c'] data

    # The C kernel interprets indptr/indices as compressed-column data and
    # runs with bounds checking disabled, so any other layout (e.g. CSR)
    # would silently produce wrong results or write out of bounds.
    if not (scipy.sparse.issparse(X) and X.format == 'csc'):
        X = scipy.sparse.csc_matrix(X)

    if X.shape[1] != W.shape[0]:
        raise ValueError('matrices are not aligned')

    result = np.zeros((X.shape[0], W.shape[1]), order='F', dtype=np.double)

    # With no stored entries the product is all zeros; returning here also
    # avoids taking the address of element 0 of empty buffers below.
    if X.data.shape[0] == 0:
        return result

    # Holding these typed references keeps the underlying buffers alive for
    # as long as csX points into them.
    indptr = X.indptr
    indices = X.indices
    data = X.data

    # Pack the scipy data into the CSparse struct. This is just copying some
    # pointers.
    csX.nzmax = X.data.shape[0]
    csX.m = X.shape[0]
    csX.n = X.shape[1]
    csX.p = &indptr[0]
    csX.i = &indices[0]
    csX.x = &data[0]
    csX.nz = -1  # to indicate CSC format

    for i in prange(W.shape[1], nogil=True):
        # W and result are in fortran format, so column i of each is a
        # contiguous vector; every iteration writes a distinct column of
        # result.
        cs_gaxpy(&csX, &W[0, i], &result[0, i])
    return result
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import glob | |
from distutils.core import setup | |
from distutils.extension import Extension | |
from Cython.Distutils import build_ext | |
# Extension module built from the Cython wrapper plus the one CSparse C file
# it calls.  OpenMP is enabled at both compile and link time.
# NOTE(review): the extension is named '_psparse' while the Cython source is
# 'psparse.pyx' -- confirm the compiled module's init symbol matches the name
# it is imported under.
ext = Extension(
    '_psparse',
    sources=['psparse.pyx', 'cs_gaxpy.c'],
    include_dirs=[np.get_include(), '.'],
    extra_compile_args=['-fopenmp', '-O3', '-ffast-math'],
    extra_link_args=['-fopenmp'],
)

# NOTE(review): py_modules names a pure-Python module 'psparse'; verify that a
# psparse.py file actually ships alongside this script.
setup(
    cmdclass={'build_ext': build_ext},
    py_modules=['psparse'],
    ext_modules=[ext],
)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment