Created
August 3, 2020 20:55
-
-
Save ajtritt/e2ef0f2aaecb255ca89f82817035221f to your computer and use it in GitHub Desktop.
Concatenate one or more .npy files into a single dataset in an HDF5 file
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import argparse
import sys

import h5py
import numpy as np
desc = '''
Concatenate one or more .npy files into a dataset into an HDF5 file
'''


def concat_shape(shapes, cat_dim):
    """Return the shape produced by concatenating arrays along ``cat_dim``.

    Args:
        shapes: non-empty sequence of ``(label, shape)`` pairs. ``label``
            identifies the array in error messages (e.g. its file path) and
            ``shape`` is a sequence of ints.
        cat_dim: dimension to concatenate along; negative values count from
            the last dimension.

    Returns:
        A tuple equal to the first shape, with the ``cat_dim`` entry replaced
        by the sum of that entry over all shapes.

    Raises:
        ValueError: if ``shapes`` is empty, if ``cat_dim`` is out of range,
            or if any shape differs from the first one in rank or in a
            dimension other than ``cat_dim``.
    """
    if not shapes:
        raise ValueError('no input arrays given')

    ref_label, ref_shape = shapes[0]
    ref_shape = tuple(ref_shape)
    ndim = len(ref_shape)
    # Also rejects 0-d arrays (ndim == 0), which cannot be concatenated.
    if not -ndim <= cat_dim < ndim:
        raise ValueError(
            f'cat_dim {cat_dim} is out of range for {ref_label} with shape {ref_shape}')

    # Normalize a negative dimension so the slicing below is unambiguous.
    axis = cat_dim % ndim
    total = 0
    for label, shape in shapes:
        shape = tuple(shape)
        compatible = (
            len(shape) == ndim
            and shape[:axis] == ref_shape[:axis]
            and shape[axis + 1:] == ref_shape[axis + 1:]
        )
        if not compatible:
            raise ValueError(
                f'shape {shape} of {label} is incompatible with shape {ref_shape} '
                f'of {ref_label} for concatenation along dimension {cat_dim}')
        total += shape[axis]
    return ref_shape[:axis] + (total,) + ref_shape[axis + 1:]


def main():
    """Parse the command line and write the concatenated dataset.

    The inputs are scanned once (memory-mapped) to work out the final shape,
    then each file is loaded and copied into its slice of the new dataset.
    Exits with status 1 if the inputs cannot be concatenated, or if the
    dataset already exists and -F/--force was not given.
    """
    parser = argparse.ArgumentParser(description=desc)
    parser.add_argument('output_h5', help='the .h5 file to write converted data to')
    parser.add_argument('dset_name', help='the dataset name in output_h5')
    parser.add_argument('input_npy', nargs='+', help='the .npy file with the data to convert')
    parser.add_argument('-D', '--cat_dim', type=int, default=0,
                        help='the dimension to concatenate along. default is to concatenate along first dimension',)
    parser.add_argument('-F', '--force', action='store_true', default=False, help='force write')

    if len(sys.argv) == 1:
        parser.print_help()

    args = parser.parse_args()

    # Collect every input's shape; mmap_mode='r' avoids reading the data here.
    shapes = []
    dtype = None
    for npy_path in args.input_npy:
        print(f'getting shape for {npy_path}')
        ar = np.load(npy_path, mmap_mode='r')
        shapes.append((npy_path, ar.shape))
        # The dataset takes the dtype of the last input listed.
        dtype = ar.dtype

    # Validate before touching the output file, so a bad input cannot cost
    # the user an existing dataset when -F is given.
    try:
        shape = concat_shape(shapes, args.cat_dim)
    except ValueError as err:
        print(err)
        sys.exit(1)

    with h5py.File(args.output_h5, 'a') as f:
        if args.dset_name in f:
            if not args.force:
                print(f'{args.dset_name} already exists. Please use a different dset_name or override behavior with -F')
                sys.exit(1)
            del f[args.dset_name]

        dset = f.create_dataset(args.dset_name, shape=shape, dtype=dtype)

        # Full slice on every dimension; only the cat_dim entry changes per file.
        sl = [slice(None)] * len(shape)
        st = 0
        for npy_path in args.input_npy:
            print(f'reading {npy_path}')
            ar = np.load(npy_path)  # load each file once
            end = st + ar.shape[args.cat_dim]
            print(f'writing {npy_path}')
            sl[args.cat_dim] = slice(st, end)
            dset[tuple(sl)] = ar
            st = end


if __name__ == '__main__':
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment