Skip to content

Instantly share code, notes, and snippets.

@phimachine
Last active April 10, 2020 12:29
Show Gist options
Save phimachine/12f2321e6c8fa53058f5eb23aeddb6ab to your computer and use it in GitHub Desktop.
This is a generic PyTorch function that takes a torch.utils.data.Dataset object and efficiently splits it into training and validation subsets.
import numpy as np
from torch.utils.data import Dataset
class GenHelper(Dataset):
    """Read-only view over a subset of another dataset.

    Position ``i`` of this view is served by ``mother[mapping[i]]``; no
    samples are copied, only the index translation list is stored.

    Args:
        mother: The underlying dataset that actually holds the samples.
        length: The value reported by ``len()`` for this view. It should
            equal ``len(mapping)``; this is not checked.
        mapping: Sequence translating a view index into an index of ``mother``.
    """

    def __init__(self, mother, length, mapping):
        # Keep a reference to the parent dataset plus the index translation table.
        self.mother, self.length, self.mapping = mother, length, mapping

    def __getitem__(self, index):
        """Fetch the sample of ``mother`` that view position ``index`` maps to."""
        mother_index = self.mapping[index]
        return self.mother[mother_index]

    def __len__(self):
        """Return the stored ``length`` (it is not recomputed from ``mapping``)."""
        return self.length
def train_valid_split(ds, split_fold=10, random_seed=None):
    """Randomly split a dataset into training and validation views.

    The indices ``0 .. len(ds) - 1`` are shuffled. The first
    ``len(ds) // split_fold`` shuffled indices form the validation set and
    the remaining ones form the training set. Both results are ``GenHelper``
    views over ``ds``, so no samples are copied.

    Args:
        ds: A dataset supporting ``len()`` and integer indexing, e.g. a
            ``torch.utils.data.Dataset``.
        split_fold: Inverse of the validation fraction; ``10`` puts
            ``len(ds) // 10`` samples into validation. Must be >= 1.
        random_seed: If not ``None``, seeds a private ``RandomState`` so the
            split is reproducible without modifying NumPy's global RNG.
            If ``None``, NumPy's global RNG is used for the shuffle.

    Returns:
        A ``(train, valid)`` tuple of ``GenHelper`` views over ``ds``.

    Raises:
        ValueError: If ``split_fold`` is less than 1.
    """
    if split_fold < 1:
        # 0 would divide by zero; a negative value gives a negative valid_size,
        # which makes the two slices below overlap and report wrong lengths.
        raise ValueError("split_fold must be >= 1, got %r" % (split_fold,))

    # A dedicated RandomState seeded with random_seed produces exactly the same
    # shuffle as np.random.seed(random_seed) + np.random.shuffle, but without
    # resetting the global RNG that the caller may rely on elsewhere.
    if random_seed is None:
        rng = np.random
    else:
        rng = np.random.RandomState(random_seed)

    dslen = len(ds)
    indices = list(range(dslen))
    rng.shuffle(indices)

    valid_size = dslen // split_fold
    valid_mapping = indices[:valid_size]
    train_mapping = indices[valid_size:]

    # Derive each view's length from its own mapping so they can never disagree.
    train = GenHelper(ds, len(train_mapping), train_mapping)
    valid = GenHelper(ds, len(valid_mapping), valid_mapping)
    return train, valid
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment