Last active
March 13, 2018 14:54
-
-
Save amorgun/932312022af498264ca9235668a676cb to your computer and use it in GitHub Desktop.
Efficient sparse CSR matrix hstack
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import scipy as sp | |
import scipy.sparse | |
import tempfile | |
def hstack(parts):
    """Horizontally stack CSR matrices using disk-backed scratch buffers.

    The ``data`` and ``indices`` arrays of the result are assembled row by
    row inside memory-mapped temporary files instead of ordinary in-memory
    arrays.

    Args:
        parts: Non-empty sequence of ``scipy.sparse.csr_matrix`` objects.
            All parts are expected to have the same number of rows as
            ``parts[0]``.

    Returns:
        A ``csr_matrix`` with ``parts[0].shape[0]`` rows and as many columns
        as all parts together. Values are stored with the dtype of
        ``parts[0]``.
    """
    n_rows = parts[0].shape[0]
    n_cols = sum(p.shape[1] for p in parts)
    nnz = sum(p.data.shape[0] for p in parts)
    value_dtype = parts[0].dtype
    result_shape = (n_rows, n_cols)

    # A zero-length file cannot be memory-mapped, so the all-empty case is
    # answered directly.
    if nnz == 0:
        return sp.sparse.csr_matrix(result_shape, dtype=value_dtype)

    # Column indices and row pointers must be integers. Use the narrowest of
    # int32/int64 that can hold every index and the total entry count.
    fits_int32 = max(n_rows, n_cols, nnz) <= np.iinfo(np.int32).max
    index_dtype = np.int32 if fits_int32 else np.int64

    # Column shift applied to each part's indices, computed once up front.
    col_offsets = []
    offset = 0
    for part in parts:
        col_offsets.append(index_dtype(offset))
        offset += part.shape[1]

    with tempfile.TemporaryFile() as data_file, \
            tempfile.TemporaryFile() as indices_file:
        data = np.memmap(data_file, dtype=value_dtype, shape=nnz)
        indices = np.memmap(indices_file, dtype=index_dtype, shape=nnz)

        cursor = 0
        for row_idx in range(n_rows):
            for part, col_offset in zip(parts, col_offsets):
                # Plain Python ints keep the running cursor from overflowing
                # a fixed-width NumPy scalar on very large outputs.
                start = int(part.indptr[row_idx])
                stop = int(part.indptr[row_idx + 1])
                end = cursor + (stop - start)
                data[cursor:end] = part.data[start:stop]
                indices[cursor:end] = part.indices[start:stop] + col_offset
                cursor = end

        # Entries per output row = sum of the entries per row of every part.
        row_nnz = np.sum(
            [np.diff(p.indptr) for p in parts],
            axis=0,
            dtype=np.int64,
        )
        indptr = np.concatenate(
            ([0], np.cumsum(row_nnz, dtype=np.int64))
        ).astype(index_dtype)

        # An argument cannot really be released from inside the function: to
        # actually free the memory this implementation would have to be
        # inlined at the call site and the references to the original arrays
        # dropped there. This only removes the local reference.
        del parts
        return sp.sparse.csr_matrix(
            (np.asarray(data), np.asarray(indices), indptr),
            shape=result_shape)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import scipy as sp | |
import scipy.sparse | |
def hstack(parts):
    """Horizontally stack single-row CSR matrices into one single-row matrix.

    This is a shortcut for row vectors only: the result always has exactly
    one row, and every stored value from every part is placed in it.

    Args:
        parts: Non-empty sequence of ``scipy.sparse.csr_matrix`` objects,
            each expected to have shape ``(1, n_i)``.

    Returns:
        A ``csr_matrix`` of shape ``(1, sum(n_i))``.
    """
    data = np.concatenate([p.data for p in parts])

    shifted_indices, col_offset = [], 0
    for part in parts:
        shifted_indices.append(part.indices + col_offset)
        # Advance by the part's width (column count), not its row count,
        # so the next part starts right after this one.
        col_offset += part.shape[1]
    indices = np.concatenate(shifted_indices)

    # One row that owns every stored entry.
    indptr = np.array([0, data.shape[0]])
    return sp.sparse.csr_matrix((data, indices, indptr),
                                shape=(1, col_offset))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment