
@syaffers
Created May 15, 2019 11:52
Create validation sets easily by splitting your custom PyTorch datasets
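Conceptually, a random train/validation split shuffles the sample indices once and cuts them into disjoint, consecutive chunks. PyTorch's `random_split` follows the same idea: it draws a random permutation of indices and wraps each slice in a `Subset`. The pure-Python sketch below illustrates that idea. `split_indices` is a hypothetical helper written for illustration; it is not part of PyTorch and is not its actual implementation.

```python
import random


def split_indices(n, lengths, seed=None):
    """Shuffle the indices 0..n-1 and cut them into consecutive, disjoint chunks."""
    if sum(lengths) != n:
        raise ValueError("lengths must sum to the dataset size")
    indices = list(range(n))
    random.Random(seed).shuffle(indices)  # one random permutation of all indices
    chunks, start = [], 0
    for length in lengths:
        chunks.append(indices[start:start + length])
        start += length
    return chunks


# Same sizes as the gist: 15593 training and 3898 validation samples.
train_idx, val_idx = split_indices(19491, [15593, 3898], seed=0)
```

Every index lands in exactly one chunk, so no sample appears in both the training and validation sets.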
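```python
import string

from torch.utils.data import DataLoader, random_split

# TESNamesDataset is the author's custom Dataset class, defined in a local datasets.py
from datasets import TESNamesDataset

data_root = '/home/syafiq/Data/tes-names/'
charset = string.ascii_letters + "-' "  # characters allowed in a name
length = 30                             # sequence length passed to TESNamesDataset

dataset = TESNamesDataset(data_root, charset, length)

# Hold out 20% of the samples for validation; the lengths must sum to len(dataset).
# For this dataset (19,491 names) this gives the same 15593 / 3898 split as before.
val_size = int(0.2 * len(dataset))
trainset, valset = random_split(dataset, [len(dataset) - val_size, val_size])

train_loader = DataLoader(trainset, batch_size=10, shuffle=True, num_workers=2)
# Validation data does not need to be shuffled
val_loader = DataLoader(valset, batch_size=10, shuffle=False, num_workers=2)

for i, batch in enumerate(train_loader):
    print(i, batch)

for i, batch in enumerate(val_loader):
    print(i, batch)
```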
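By default `random_split` draws a different permutation on every run, so the validation set changes each time the script starts. Passing a seeded `torch.Generator` through the `generator` argument makes the split reproducible. The sketch below demonstrates this on a small `TensorDataset` of random data. That toy dataset is a hypothetical stand-in for `TESNamesDataset`, used only so the example runs on its own.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, random_split

# Hypothetical stand-in for TESNamesDataset: 100 samples, labelled 0..99.
features = torch.randn(100, 3)
labels = torch.arange(100)
dataset = TensorDataset(features, labels)

# Derive the split sizes from the dataset length so they always sum correctly.
val_size = int(0.2 * len(dataset))
train_size = len(dataset) - val_size

# A seeded generator makes random_split return the same subsets on every run.
generator = torch.Generator().manual_seed(42)
trainset, valset = random_split(dataset, [train_size, val_size], generator=generator)

train_loader = DataLoader(trainset, batch_size=10, shuffle=True)
val_loader = DataLoader(valset, batch_size=10, shuffle=False)

# Collect the sample labels seen by each loader to check the subsets are disjoint.
train_ids = {int(label) for _, batch_labels in train_loader for label in batch_labels}
val_ids = {int(label) for _, batch_labels in val_loader for label in batch_labels}
```

Shuffling inside the training `DataLoader` still reorders batches every epoch. The seed only fixes which samples end up in which subset.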