Created
July 19, 2017 06:08
-
-
Save lihuanshuai/f798919a324a8e4b5b70a5583b6ab462 to your computer and use it in GitHub Desktop.
Dump and load a SciPy CSR/CSC sparse matrix to and from an HDF5 file with PyTables (adapted from https://stackoverflow.com/questions/11129429/storing-numpy-sparse-matrix-in-hdf5-pytables).
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os

import tables as tb
from numpy import array
from scipy import sparse
def store_sparse_mat(path, name, m, separator='__'):
    """Store a CSR/CSC sparse matrix in an HDF5 file using PyTables.

    The matrix is written as four 1-D arrays directly under the file's root
    group, named ``<name><separator>data``, ``<name><separator>indices``,
    ``<name><separator>indptr`` and ``<name><separator>shape``. Existing
    nodes with those names are replaced; other nodes in the file are left
    untouched.

    Parameters
    ----------
    path : str
        HDF5 file to write. It is opened in append mode ('a').
    name : str
        Prefix for the node names.
    m : sparse matrix or sparse array in CSR or CSC format
        Matrix to store.
    separator : str, optional
        String placed between ``name`` and the component name.

    Raises
    ------
    TypeError
        If ``m`` is not a CSR or CSC sparse matrix/array.
    """
    # Check the sparse *format* rather than the exact class, so subclasses and
    # csr_array/csc_array are accepted too.
    if not (sparse.issparse(m) and m.format in ('csr', 'csc')):
        raise TypeError("This code only works for csr/csc matrices")
    with tb.open_file(path, 'a') as f:
        for par in ('data', 'indices', 'indptr', 'shape'):
            full_name = '%s%s%s' % (name, separator, par)
            # Overwrite semantics: drop any node left by a previous store.
            try:
                f.remove_node(f.root, full_name)
            except tb.NoSuchNodeError:
                pass
            arr = array(getattr(m, par))
            atom = tb.Atom.from_dtype(arr.dtype)
            # An EArray is used instead of a CArray because a CArray cannot
            # have a zero-length dimension, which made storing a matrix with
            # no stored entries (empty data/indices) fail.
            ds = f.create_earray(
                f.root, full_name, atom=atom, shape=(0,),
                expectedrows=max(arr.shape[0], 1))
            if arr.size:
                ds.append(arr)
def load_sparse_mat(
        path, name, type_=sparse.csr_matrix, separator='__', default=None):
    """Load a sparse matrix previously written by ``store_sparse_mat``.

    Reads the four root-level nodes ``<name><separator>data``,
    ``<name><separator>indices``, ``<name><separator>indptr`` and
    ``<name><separator>shape`` from the HDF5 file at ``path``, and assembles
    them into an instance of ``type_``.

    Parameters
    ----------
    path : str
        HDF5 file to read.
    name : str
        Prefix used for the node names when the matrix was stored.
    type_ : class, optional
        CSR or CSC sparse class (matrix or array, or a subclass) to build.
        Defaults to ``scipy.sparse.csr_matrix``.
    separator : str, optional
        String placed between ``name`` and the component name.
    default : object, optional
        Value returned when ``path`` is not an existing file or any of the
        four nodes is missing.

    Returns
    -------
    An instance of ``type_``, or ``default``.

    Raises
    ------
    TypeError
        If ``type_`` is not a CSR/CSC sparse class.
    """
    # csr_array/csc_array only exist in newer SciPy versions.
    candidate_names = ('csr_matrix', 'csc_matrix', 'csr_array', 'csc_array')
    allowed = tuple(
        getattr(sparse, cls_name) for cls_name in candidate_names
        if hasattr(sparse, cls_name))
    if not (isinstance(type_, type) and issubclass(type_, allowed)):
        raise TypeError("This code only works for csr/csc matrices")
    if not os.path.isfile(path):
        return default
    with tb.open_file(path, 'r') as f:
        pars = []
        for par in ('data', 'indices', 'indptr', 'shape'):
            try:
                node = f.get_node(f.root, '%s%s%s' % (name, separator, par))
            except tb.NoSuchNodeError:
                return default
            # read() materializes the node into memory, so the arrays stay
            # valid after the file is closed.
            pars.append(node.read())
    data, indices, indptr, shape = pars
    # The shape is stored as an integer array; pass a plain tuple of ints.
    return type_((data, indices, indptr), shape=tuple(int(s) for s in shape))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment