import h5py
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader

class H5Dataset(Dataset):
    def __init__(self, h5_path):
        self.h5_path = h5_path
        # Handle is opened in the parent process -- this is what goes wrong
        # once the DataLoader forks worker processes below
        self.h5_file = h5py.File(h5_path, 'r')
        self.length = len(self.h5_file)

    def __getitem__(self, index):
        record = self.h5_file[str(index)]
        return (
            record['data'][()],    # `.value` in the original; removed in h5py 3.x
            record['target'][()],
        )

    def __len__(self):
        return self.length
# --
# Make data

f = h5py.File('test.h5', 'w')
for i in range(256):
    f['%s/data' % i] = np.random.uniform(0, 1, (1024, 1024))
    f['%s/target' % i] = np.random.choice(1000)

f.close()
# Runs correctly
dataloader = DataLoader(
    H5Dataset('test.h5'),
    batch_size=32,
    num_workers=0,
    shuffle=True
)
for i, (data, target) in enumerate(dataloader):
    print(data.shape)
    if i > 10:
        break
# Throws error (sometimes; may have to restart python)
dataloader = DataLoader(
    H5Dataset('test.h5'),
    batch_size=32,
    num_workers=8,
    shuffle=True
)
for i, (data, target) in enumerate(dataloader):
    print(data.shape)
    if i > 10:
        break
# KeyError: 'Traceback (most recent call last):
#   File "/home/bjohnson/.anaconda/lib/python2.7/site-packages/torch/utils/data/dataloader.py", line 55, in _worker_loop
#     samples = collate_fn([dataset[i] for i in batch_indices])
#   File "<stdin>", line 11, in __getitem__
#   File "h5py/_objects.pyx", line 54, in h5py._objects.with_phil.wrapper
#   File "h5py/_objects.pyx", line 55, in h5py._objects.with_phil.wrapper
#   File "/home/bjohnson/.anaconda/lib/python2.7/site-packages/h5py/_hl/group.py", line 167, in __getitem__
#     oid = h5o.open(self.id, self._e(name), lapl=self._lapl)
#   File "h5py/_objects.pyx", line 54, in h5py._objects.with_phil.wrapper
#   File "h5py/_objects.pyx", line 55, in h5py._objects.with_phil.wrapper
#   File "h5py/h5o.pyx", line 190, in h5py.h5o.open
# KeyError: Unable to open object (bad object header version number)'
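A common workaround, sketched below (this variant is mine, not part of the original gist): keep only the path in `__init__` and open the file lazily on the first `__getitem__` call, which runs inside the worker process, so each worker ends up with its own HDF5 handle.

```python
class LazyH5Dataset(Dataset):
    # Hypothetical variant of H5Dataset above: no handle is created in the
    # parent process, so nothing stale is inherited across the fork.
    def __init__(self, h5_path):
        self.h5_path = h5_path
        self.h5_file = None
        # Open briefly just to count records, then close again
        with h5py.File(h5_path, 'r') as f:
            self.length = len(f)

    def __getitem__(self, index):
        # First call in each worker process creates that worker's own handle
        if self.h5_file is None:
            self.h5_file = h5py.File(self.h5_path, 'r')
        record = self.h5_file[str(index)]
        return record['data'][()], record['target'][()]

    def __len__(self):
        return self.length
```

With this pattern, the `num_workers=0` and `num_workers=8` loaders above should behave the same, since the handle is only ever created in the process that actually reads from it.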
@ssnl
According to the answer in this link: https://stackoverflow.com/questions/46045512/h5py-hdf5-database-randomly-returning-nans-and-near-very-small-data-with-multi/52438133?noredirect=1#comment91819173_52438133
I encountered the very same issue, and after spending a day trying to marry the PyTorch DataLoader with HDF5 via h5py, I found that it is crucial to open the h5py.File inside the new process, rather than opening it in the main process and hoping it gets inherited by the underlying multiprocessing implementation.
Since PyTorch seems to initialize its workers lazily, this means that the actual file opening has to happen inside the __getitem__ function of the Dataset wrapper.
```python
class DeephomographyDataset(Dataset):

    def __init__(self, hdf5file, imgs_key='images', labels_key='labels',
                 transform=None):
        self.hdf5file = hdf5file
        self.imgs_key = imgs_key
        self.labels_key = labels_key
        self.transform = transform

    def __len__(self):
        # return len(self.db[self.labels_key])
        with h5py.File(self.hdf5file, 'r') as db:
            lens = len(db[self.labels_key])
        return lens

    def __getitem__(self, idx):
        with h5py.File(self.hdf5file, 'r') as db:
            image = db[self.imgs_key][idx]
            label = db[self.labels_key][idx]
        sample = {'images': image, 'labels': label}
        if self.transform:
            sample = self.transform(sample)
        return sample
```
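For reference, a usage sketch of the class above (the file name is hypothetical, and an HDF5 file with `images` and `labels` datasets is assumed); since the file is opened inside `__getitem__`, it runs inside each worker and multi-worker loading should be safe:

```python
dataset = DeephomographyDataset('homography.h5')
dataloader = DataLoader(dataset, batch_size=32, num_workers=8, shuffle=True)

for batch in dataloader:
    print(batch['images'].shape, batch['labels'].shape)
    break
```

Note that opening the file on every `__getitem__` call trades some per-sample overhead for safety; the lazy per-worker handle sketched earlier avoids that cost.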
But can you give me a detailed explanation of the answer in that link?
I open the H5 file in the __getitem__ function, but I still ran into this problem. How did you handle it?
@soumith Is this right?