Created
March 2, 2018 20:58
-
-
Save bkj/f448025fdef08c0609029489fa26ea2a to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import h5py | |
import numpy as np | |
import torch | |
from torch.utils.data import Dataset, DataLoader | |
class H5Dataset(Dataset): | |
def __init__(self, h5_path): | |
self.h5_path = h5_path | |
self.h5_file = h5py.File(h5_path, 'r') | |
self.length = len(h5py.File(h5_path, 'r')) | |
def __getitem__(self, index): | |
record = self.h5_file[str(index)] | |
return ( | |
record['data'].value, | |
record['target'].value, | |
) | |
def __len__(self): | |
return self.length | |
# -- | |
# Make data | |
f = h5py.File('test.h5') | |
for i in range(256): | |
f['%s/data' % i] = np.random.uniform(0, 1, (1024, 1024)) | |
f['%s/target' % i] = np.random.choice(1000) | |
# Runs correctly | |
dataloader = torch.utils.data.DataLoader( | |
H5Dataset('test.h5'), | |
batch_size=32, | |
num_workers=0, | |
shuffle=True | |
) | |
for i,(data,target) in enumerate(dataloader): | |
print(data.shape) | |
if i > 10: | |
break | |
# Throws error (sometimes, may have to restart python) | |
dataloader = torch.utils.data.DataLoader( | |
H5Dataset('test.h5'), | |
batch_size=32, | |
num_workers=8, | |
shuffle=True | |
) | |
for i,(data,target) in enumerate(dataloader): | |
print(data.shape) | |
if i > 10: | |
break | |
# KeyError: 'Traceback (most recent call last): | |
# File "/home/bjohnson/.anaconda/lib/python2.7/site-packages/torch/utils/data/dataloader.py", line 55, in _worker_loop | |
# samples = collate_fn([dataset[i] for i in batch_indices]) | |
# File "<stdin>", line 11, in __getitem__ | |
# File "h5py/_objects.pyx", line 54, in h5py._objects.with_phil.wrapper | |
# File "h5py/_objects.pyx", line 55, in h5py._objects.with_phil.wrapper | |
# File "/home/bjohnson/.anaconda/lib/python2.7/site-packages/h5py/_hl/group.py", line 167, in __getitem__ | |
# oid = h5o.open(self.id, self._e(name), lapl=self._lapl) | |
# File "h5py/_objects.pyx", line 54, in h5py._objects.with_phil.wrapper | |
# File "h5py/_objects.pyx", line 55, in h5py._objects.with_phil.wrapper | |
# File "h5py/h5o.pyx", line 190, in h5py.h5o.open | |
# KeyError: Unable to open object (bad object header version number) |
Hersue
commented
Jan 1, 2020
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment