Created
September 9, 2021 18:11
-
-
Save edraizen/1d6572b798e3cae94d4cd8961d663ca5 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from torch.utils.data import Dataset as _Dataset | |
from torch.utils.data import Subset | |
import h5pyd | |
class DistributedDataset(_Dataset):
    """Read samples from an HDF5 file accessed through ``h5pyd``.

    If ``key`` names a dataset, each row is an independent sample. If ``key``
    names a group, each child of the group is an independent sample, ordered
    by sorted child name.

    Parameters
    ----------
    path : str
        The name of the full h5 file with all groups and datasets.
    key : str
        The key to the dataset or group, specifying all intermediate groups.
    test : bool
        Set mode to testing. When ``key`` names a group, this fits a
        ``sklearn.preprocessing.LabelEncoder`` on the sorted child names and
        stores it in ``self.embedding``. Has no effect when ``key`` names a
        dataset.
    dataset_group_name : str
        If ``key`` names a group that contains a child with this name, that
        child is used as the sample source instead. Default is 'datasets'.
    file_mode : str
        Mode used to open the h5 file for reading and/or writing. Writing
        should only be used if creating data_splits.

    Notes
    -----
    The file handle stays open for the lifetime of the object. Call
    :meth:`close`, or use the instance as a context manager, to release it.
    """

    def __init__(self, path, key, test=False, dataset_group_name="datasets", file_mode="r"):
        self.path = path
        self.key = key
        self.test = test
        self.file_mode = file_mode
        self.f = h5pyd.File(path, file_mode, use_cache=False)
        # Fitted LabelEncoder over group child names in test mode; otherwise None.
        self.embedding = None

        try:
            self.data = self.f[key]
            if isinstance(self.data, h5pyd.Dataset):
                # Each row of the dataset is a sample; rows are indexed positionally.
                self.order = list(range(len(self.data)))
            else:
                if dataset_group_name in self.data.keys():
                    self.data = self.f[f'{key}/{dataset_group_name}']
                # Children are addressed by name, in sorted order.
                self.order = sorted(self.data.keys())
                if self.test:
                    # Imported lazily so sklearn is only needed in test mode.
                    from sklearn import preprocessing
                    self.embedding = preprocessing.LabelEncoder().fit(self.order)
        except Exception:
            # Don't leak the open file handle if the key lookup fails.
            self.f.close()
            raise

    def __len__(self):
        """Return the number of samples (rows or group children)."""
        return len(self.order)

    def __getitem__(self, index):
        """Return the sample at position ``index``.

        For a dataset this is row ``index``. For a group it is the child whose
        name is at position ``index`` in the sorted child names.
        """
        return self.data[self.order[index]]

    def close(self):
        """Close the underlying h5 file. Safe to call more than once."""
        f = getattr(self, "f", None)
        if f is not None:
            f.close()
            self.f = None

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment