Created
August 9, 2021 17:42
-
-
Save jacobkimmel/b566f03b0c8a2328b3b549cf181f8264 to your computer and use it in GitHub Desktop.
Load dense count arrays as `scipy.sparse` matrices
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import gzip | |
import scipy.sparse as sp | |
import numpy as np | |
def load_dense_csv_as_sparse_matrix(
    file_name: str,
    dtype=np.int32,
    skip_lines: int = 0,
    batch_size: int = 2048,
    delimiter: str = "\t",
) -> "Tuple[sp.csr_matrix, List[str]]":
    """
    Load a delimited text file containing a sparse matrix encoded as a
    dense matrix into a `scipy.sparse.csr_matrix`.

    The file is streamed line by line. At most `batch_size` dense rows are
    held in memory at once; each full batch is converted to CSR and the
    batches are stacked vertically at the end.

    Parameters
    ----------
    file_name : str
        path to the dense representation of the sparse matrix. Paths
        ending in ".gz" are read through `gzip`; everything else is read
        as plain text. Both are decoded as UTF-8.
    dtype : Any
        valid `dtype` keyword argument to `np.array`.
    skip_lines : int
        number of header lines to skip.
    batch_size : int
        number of lines to load as dense arrays in memory before
        casting to sparse. Must be at least 1.
    delimiter : str
        field separator. Defaults to a tab, which is what this loader has
        always split on despite the "csv" in its name.

    Returns
    -------
    sparse_X : scipy.sparse.csr_matrix
        [rows, columns] sparse matrix. An input with no data rows yields
        an empty (0, 0) matrix.
    row_names : list
        any non-numeric row names that were extracted from the first column.

    Raises
    ------
    ValueError
        if `batch_size` is smaller than 1. `np.array` / `np.stack` also
        raise `ValueError` when a field cannot be parsed as `dtype` or
        when rows within a batch differ in length.

    Notes
    -----
    If the first column is non-numeric (per `str.isnumeric`), extracts
    these values as `row_names` and returns as a list. Blank lines are
    ignored.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    # minibatches of observations, each already converted to a csr_matrix
    sparse_batches = []
    # dense rows of the batch currently being filled
    batch_lines = []
    row_names = []

    open_fn = gzip.open if str(file_name).endswith(".gz") else open

    print(f"reading batches: {len(sparse_batches)}", end="\r")
    # "rt" makes both `open` and `gzip.open` yield `str` lines; plain "r"
    # is text mode for `open` but binary mode for `gzip.open`.
    with open_fn(file_name, "rt", encoding="utf-8") as f:
        for line in f:
            if skip_lines > 0:
                skip_lines -= 1
                continue

            # drop the line terminator so it does not end up in the last field
            line = line.rstrip("\r\n")
            if not line:
                continue

            fields = line.split(delimiter)
            if not fields[0].isnumeric():
                # leading label column: keep the label, parse the remainder
                row_names.append(fields[0])
                del fields[0]
            batch_lines.append(np.array(fields, dtype=dtype))

            if len(batch_lines) >= batch_size:
                # batch is full: compress it and start a fresh one
                sparse_batches.append(
                    sp.csr_matrix(np.stack(batch_lines, axis=0))
                )
                batch_lines = []
                print(
                    f"reading batches: {len(sparse_batches):04d}", end="\r"
                )

    # add the final, partially filled batch (if any)
    if batch_lines:
        sparse_batches.append(sp.csr_matrix(np.stack(batch_lines, axis=0)))

    if not sparse_batches:
        # nothing to stack: no data rows were found in the file
        return sp.csr_matrix((0, 0), dtype=dtype), row_names

    # stack the sparse batches to form a single CSR matrix
    sparse_X = sp.vstack(sparse_batches, format="csr")
    return sparse_X, row_names
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment