Skip to content

Instantly share code, notes, and snippets.

@archydeberker
Last active April 14, 2019 21:30
Show Gist options
Save archydeberker/6342b6a4241c8b9b1415f33a3d8ebb91 to your computer and use it in GitHub Desktop.
Subsample data for the purpose of training a CIFAR classifier with varying dataset size
def get_dataset_size(start=0.5, end=100, base=2):
    """Return an exponentially spaced list of dataset sizes.

    Starts at ``start`` and repeatedly multiplies by ``base``; the last
    element is clamped to exactly ``end``.

    Args:
        start: first (smallest) size in the sequence.
        end: upper bound; always the final element of the result.
        base: geometric growth factor between consecutive sizes.

    Returns:
        List of sizes, e.g. ``[0.5, 1.0, 2.0, ..., 64.0, 100]`` for the
        defaults.
    """
    # Degenerate input: nothing to grow through, just return the cap.
    if start >= end:
        return [end]
    sizes = [start]
    while True:
        sizes.append(sizes[-1] * base)
        # `>=` (not `>`) so a size landing exactly on `end` is clamped in
        # place rather than duplicated (the old `>` gave e.g. [1,2,4,8,8]).
        if sizes[-1] >= end:
            sizes[-1] = end
            break
    return sizes
# Convert images to tensors, then shift each RGB channel from [0, 1]
# into [-1, 1] (mean 0.5, std 0.5 per channel).
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Fetch (downloading if necessary) the CIFAR10 train and test splits.
trainset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform)
testset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform)
# Hold out a fraction of the training set for validation.
val_size = 0.2  # fraction of training examples reserved for validation
num_train = len(trainset)
indices = list(range(num_train))
np.random.shuffle(indices)
split = int(np.floor(val_size * num_train))
val_idx = indices[:split]
train_idx = indices[split:]
total_train = len(train_idx)
# For each of our train sets, we want a subset of the true train set.
dataset_size = np.array(get_dataset_size())
dataset_size /= 100  # Convert percentages to fractions of the original dataset size

train_set_samplers = dict()
trainset_loaders = dict()
for ts in dataset_size:
    # replace=False: a subset must not repeat examples. The previous call
    # used np.random.choice's default replace=True, so the sampled "subset"
    # could contain duplicate images (and fewer unique examples than ts
    # implies). ts <= 1.0, so the draw never exceeds len(train_idx).
    train_set_samplers[ts] = np.random.choice(
        train_idx, int(ts * total_train), replace=False)
    trainset_loaders[ts] = torch.utils.data.DataLoader(
        trainset, batch_size=4,
        sampler=train_set_samplers[ts], num_workers=2)

# Validation draws from the held-out slice of the *train* dataset.
val_sampler = SubsetRandomSampler(val_idx)
valloader = torch.utils.data.DataLoader(trainset, batch_size=4,
                                        sampler=val_sampler, num_workers=2)
# Test loader covers the full test split in a fixed order.
testloader = torch.utils.data.DataLoader(testset, batch_size=4,
                                         shuffle=False, num_workers=2)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment