import h5py
from pathlib import Path

import torch
from torch.utils import data
class HDF5Dataset(data.Dataset):
    """Represents an abstract HDF5 dataset.

    Input params:
        file_path: Path to the folder containing the dataset (one or multiple HDF5 files).
        recursive: If True, searches for h5 files in subdirectories.
        load_data: If True, loads all the data immediately into RAM. Use this if
            the dataset fits into memory. Otherwise, leave this as False and
            the data will be loaded lazily.
        data_cache_size: Number of HDF5 files that can be held in the cache (default=3).
        transform: PyTorch transform to apply to every data instance (default=None).
    """
    def __init__(self, file_path, recursive, load_data, data_cache_size=3, transform=None):
        super().__init__()
        self.data_info = []
        self.data_cache = {}
        self.data_cache_size = data_cache_size
        self.transform = transform

        # Search for all h5 files
        p = Path(file_path)
        assert p.is_dir()
        if recursive:
            files = sorted(p.glob('**/*.h5'))
        else:
            files = sorted(p.glob('*.h5'))
        if not files:
            raise RuntimeError('No hdf5 datasets found')

        for h5dataset_fp in files:
            self._add_data_infos(str(h5dataset_fp.resolve()), load_data)
    def __getitem__(self, index):
        # get data
        x = self.get_data("data", index)
        if self.transform:
            x = self.transform(x)
        else:
            x = torch.from_numpy(x)

        # get label
        y = self.get_data("label", index)
        y = torch.from_numpy(y)
        return (x, y)

    def __len__(self):
        return len(self.get_data_infos('data'))
    def _add_data_infos(self, file_path, load_data):
        with h5py.File(file_path, 'r') as h5_file:
            # Walk through all groups, extracting datasets
            for gname, group in h5_file.items():
                for dname, ds in group.items():
                    # if data is not loaded its cache index is -1
                    idx = -1
                    if load_data:
                        # add data to the data cache; ds[()] reads the full dataset
                        # into a numpy array (ds.value was removed in h5py 3.0)
                        idx = self._add_to_cache(ds[()], file_path)

                    # type is derived from the name of the dataset; we expect the dataset
                    # to have a name such as 'data' or 'label' to identify its type.
                    # we also store the shape of the data in case we need it
                    self.data_info.append({'file_path': file_path, 'type': dname, 'shape': ds.shape, 'cache_idx': idx})
    def _load_data(self, file_path):
        """Load data to the cache given the file
        path and update the cache index in the
        data_info structure.
        """
        with h5py.File(file_path, 'r') as h5_file:
            for gname, group in h5_file.items():
                for dname, ds in group.items():
                    # add data to the data cache and retrieve
                    # the cache index
                    idx = self._add_to_cache(ds[()], file_path)

                    # find the beginning index of the hdf5 file we are looking for
                    file_idx = next(i for i, v in enumerate(self.data_info) if v['file_path'] == file_path)

                    # the data info should have the same index since we loaded it in the same way
                    self.data_info[file_idx + idx]['cache_idx'] = idx

        # remove an element from data cache if size was exceeded
        if len(self.data_cache) > self.data_cache_size:
            # evict the oldest cached file that is not the one we just loaded
            # (dicts preserve insertion order in Python 3.7+)
            removal_keys = list(self.data_cache)
            removal_keys.remove(file_path)
            self.data_cache.pop(removal_keys[0])
            # reset the cache_idx of every entry belonging to the evicted file
            self.data_info = [{'file_path': di['file_path'], 'type': di['type'], 'shape': di['shape'], 'cache_idx': -1}
                              if di['file_path'] == removal_keys[0] else di for di in self.data_info]
    def _add_to_cache(self, data, file_path):
        """Adds data to the cache and returns its index. There is one cache
        list for every file_path, containing all datasets in that file.
        """
        if file_path not in self.data_cache:
            self.data_cache[file_path] = [data]
        else:
            self.data_cache[file_path].append(data)
        return len(self.data_cache[file_path]) - 1

    def get_data_infos(self, data_type):
        """Get data infos belonging to a certain type of data.
        """
        data_info_type = [di for di in self.data_info if di['type'] == data_type]
        return data_info_type
    def get_data(self, data_type, i):
        """Call this function anytime you want to access a chunk of data from the
        dataset. This will make sure that the data is loaded in case it is
        not part of the data cache.
        """
        fp = self.get_data_infos(data_type)[i]['file_path']
        if fp not in self.data_cache:
            self._load_data(fp)

        # get the new cache_idx assigned by _load_data
        cache_idx = self.get_data_infos(data_type)[i]['cache_idx']
        return self.data_cache[fp][cache_idx]
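
Below is a minimal usage sketch, not part of the original gist. It assumes the on-disk layout that _add_data_infos walks: top-level groups in each .h5 file, with each group holding a 'data' and a 'label' dataset. The directory name, sample count, and shapes here are hypothetical and chosen only for illustration.

import h5py
import numpy as np
from pathlib import Path
from torch.utils import data

# Write a tiny HDF5 file in the expected layout (hypothetical path/shapes).
Path('demo').mkdir(exist_ok=True)
with h5py.File('demo/sample.h5', 'w') as f:
    for i in range(4):
        # one top-level group per sample, holding 'data' and 'label' datasets
        g = f.create_group(f'sample_{i}')
        g.create_dataset('data', data=np.random.rand(3, 32, 32).astype('float32'))
        g.create_dataset('label', data=np.array([i % 2], dtype='int64'))

# Load it lazily and iterate with a standard DataLoader.
dataset = HDF5Dataset('demo', recursive=False, load_data=False)
loader = data.DataLoader(dataset, batch_size=2, shuffle=True)
for x, y in loader:
    print(x.shape, y.shape)  # e.g. torch.Size([2, 3, 32, 32]), torch.Size([2, 1])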
I did not attach a license, but the code is free to use and share however you want.