Last active
December 25, 2015 04:18
-
-
Save jnothman/6915908 to your computer and use it in GitHub Desktop.
`csr_row_norms` for scikit-learn, written in Cython using fused types and typed memoryviews.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| from libc.math cimport sqrt | |
| cimport cython | |
| cimport numpy as np | |
| import numpy as np | |
# Element types accepted for the CSR ``data`` array; Cython compiles one
# specialization of the kernel per listed type.  ``cython.long`` is the C
# ``long``, which is only 32 bits wide on Windows, so ``cython.longlong`` is
# listed as well so that int64 data has a matching specialization there too.
ctypedef fused my_fused_type:
    cython.short
    cython.int
    cython.long
    cython.longlong
    cython.float
    cython.double
# Index types accepted for ``indptr``.  scipy.sparse stores CSR index arrays
# as int32 when they fit and as int64 otherwise, so the kernel accepts both
# (int64 is C ``long`` on most Unix platforms and ``long long`` on Windows).
ctypedef fused _csr_index_type:
    cython.int
    cython.long
    cython.longlong


def _csr_row_norms(my_fused_type[:] data, _csr_index_type[:] indptr,
                   bint squared=False):
    """Compute the Euclidean norm of every row of a CSR matrix.

    Parameters
    ----------
    data : 1-d typed memoryview of ``my_fused_type``
        Stored values of the CSR matrix (``X.data``).
    indptr : 1-d typed memoryview of int, long or long long
        Row pointer array (``X.indptr``); the stored values of row ``i`` are
        ``data[indptr[i]:indptr[i + 1]]``.
    squared : bool, default False
        If True, return squared norms and skip the square root.

    Returns
    -------
    norms : ndarray of float64, shape (len(indptr) - 1,)
        Per-row norms, accumulated in double precision whatever the element
        type of ``data``.
    """
    cdef:
        # Py_ssize_t rather than int: row counts and positions in ``data``
        # can exceed 2**31 - 1 for large matrices.
        Py_ssize_t num_rows = indptr.shape[0] - 1
        Py_ssize_t i, j
        double total, value
        double[:] norms_view

    # Keep the ndarray so it can be returned; write through the typed view so
    # the inner loop stays in C.
    norms = np.zeros(num_rows, dtype=np.float64)
    norms_view = norms
    for i in range(num_rows):
        total = 0.
        for j in range(indptr[i], indptr[i + 1]):
            value = <double>data[j]
            total += value * value
        if not squared:
            total = sqrt(total)
        norms_view[i] = total
    return norms
def csr_row_norms(X, bint squared=False):
    """Return the Euclidean norm of each row of the sparse matrix ``X``.

    Parameters
    ----------
    X : sparse matrix
        Anything exposing ``data`` and ``indptr`` in CSR layout.  Objects with
        a ``tocsr`` method (e.g. scipy.sparse matrices) are converted to CSR
        first.
    squared : bool, default False
        If True, return squared norms instead of norms.

    Returns
    -------
    norms : ndarray of float64, shape (n_rows,)
    """
    if hasattr(X, "tocsr"):
        # Running the CSR kernel on another layout (e.g. CSC) would silently
        # produce column norms; tocsr() is a no-op for matrices already in CSR.
        X = X.tocsr()
    if not getattr(X, "has_canonical_format", True):
        # Duplicate (row, col) entries stand for their sum; squaring them one
        # by one gives the wrong norm.  Merge them on a copy so the caller's
        # matrix is not mutated.
        X = X.copy()
        X.sum_duplicates()

    data = X.data
    if data.dtype.kind == "c":
        # |z|**2 == abs(z)**2, so complex rows reduce to their moduli.
        data = np.abs(data)
    if not (data.dtype.isnative and data.dtype.char in "hilfd"):
        # Compiled specializations exist for native-byte-order short, int,
        # long, float and double buffers; widen anything else (bool, int8,
        # unsigned, float16, byte-swapped, ...) to float64 instead of failing
        # with "no matching signature".
        data = data.astype(np.float64)
    return _csr_row_norms(data, X.indptr, squared)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment