Created
April 12, 2023 12:33
-
-
Save keuv-grvl/6efe35e769be80020d60ea83a034e491 to your computer and use it in GitHub Desktop.
Wrap a HF `datasets.Dataset` into `torch.utils.data.Dataset`
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import datasets
import torch
class HFDataset(torch.utils.data.Dataset):
    """Expose a Hugging Face ``datasets.Dataset`` as a ``torch.utils.data.Dataset``.

    The wrapper holds a reference to the given dataset and forwards indexing
    and length queries to it unchanged, so the object can be handed to a
    ``torch.utils.data.DataLoader``.
    """

    def __init__(self, dset: "datasets.Dataset"):
        """Store the dataset to wrap.

        Args:
            dset: The dataset that every lookup is delegated to.
        """
        self.dset = dset

    def __getitem__(self, idx):
        """Return ``self.dset[idx]``.

        ``idx`` is passed through as-is, so any key the wrapped dataset
        accepts (an integer, a slice, ...) is accepted here as well, and the
        result is whatever the wrapped dataset returns for that key.
        """
        return self.dset[idx]

    def __len__(self):
        """Return the number of items reported by the wrapped dataset."""
        return len(self.dset)
if __name__ == "__main__":
    # Usage sketch: every ``...`` below is a placeholder to be replaced with
    # real arguments before this script can run.

    # load a dataset from HF hub
    hf_ds = datasets.load_dataset("the-dataset")

    # process your data
    def trsfm_fn(example) -> dict:
        return {...: ...}

    hf_ds = hf_ds.map(...).sort(...).filter(...).remove_columns(...).with_transform(trsfm_fn)

    # wrap as a pytorch Dataset
    train_ds = HFDataset(hf_ds["train"])
    train_ds[123:132]  # the transform function is called once, good

    # build dataloader
    train_dl = torch.utils.data.DataLoader(train_ds, shuffle=True, batch_size=3, drop_last=True)
    ii = iter(train_dl)
    # Advance the iterator created above: a DataLoader is iterable but is not
    # itself an iterator, so ``next`` must be called on ``ii``.
    example_batch = next(ii)
    # NOTE: .with_transform(...) is applied to each row
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment