Gist srikarplus/15d7263ae2c82e82fe194fc94321f34e, created December 1, 2018 08:17
Stratified Sampling in PyTorch
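```python
import torch
from torchvision import datasets, transforms


def make_weights_for_balanced_classes(images, nclasses):
    # images: list of (path, class_index) pairs, e.g. ImageFolder.imgs
    count = [0] * nclasses
    for item in images:
        count[item[1]] += 1
    # Weight of class i is N / count[i]; a class with no samples gets weight 0
    weight_per_class = [0.] * nclasses
    N = float(sum(count))
    for i in range(nclasses):
        if count[i] > 0:
            weight_per_class[i] = N / float(count[i])
    # Each sample inherits the weight of its class
    weight = [0.] * len(images)
    for idx, val in enumerate(images):
        weight[idx] = weight_per_class[val[1]]
    return weight


# Usage (traindir and args are assumed to be defined by the surrounding training script).
# Example transform (assumption): batching requires tensors of equal size
transform = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])
dataset_train = datasets.ImageFolder(traindir, transform=transform)
# For an unbalanced dataset, create a weighted sampler
weights = make_weights_for_balanced_classes(dataset_train.imgs, len(dataset_train.classes))
weights = torch.DoubleTensor(weights)
sampler = torch.utils.data.sampler.WeightedRandomSampler(weights, len(weights))
# sampler is mutually exclusive with shuffle=True, so shuffling is left to the sampler
train_loader = torch.utils.data.DataLoader(dataset_train, batch_size=args.batch_size, shuffle=False,
                                           sampler=sampler, num_workers=args.workers, pin_memory=True)
```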
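Why this balances the classes: every sample of class i gets weight N/count[i], so the weights of class i sum to count[i] * N/count[i] = N for every class, and each class is drawn with probability 1/nclasses. The sketch below checks this using only the standard library. `balanced_class_weights` is a hypothetical variant of the helper above that takes a flat list of labels instead of (path, label) pairs, the 90/10 labels are made up, and `random.choices` stands in for WeightedRandomSampler's default sampling with replacement.

```python
import random
from collections import Counter


def balanced_class_weights(labels, nclasses):
    """Hypothetical label-list variant: per-sample weight N / count[class]."""
    count = Counter(labels)  # missing classes count as 0
    n = len(labels)
    weight_per_class = [n / count[c] if count[c] else 0.0 for c in range(nclasses)]
    return [weight_per_class[y] for y in labels]


# Made-up imbalanced labels: 90 samples of class 0, 10 of class 1
labels = [0] * 90 + [1] * 10
weights = balanced_class_weights(labels, nclasses=2)

# The weights of each class sum to N (= 100), so both classes are equally likely
total_per_class = Counter()
for y, w in zip(labels, weights):
    total_per_class[y] += w
print({c: round(t, 6) for c, t in total_per_class.items()})  # {0: 100.0, 1: 100.0}

# Draw indices with replacement in proportion to the weights, as
# WeightedRandomSampler does by default (replacement=True)
rng = random.Random(0)
drawn = rng.choices(range(len(labels)), weights=weights, k=10_000)
drawn_counts = Counter(labels[i] for i in drawn)
print(drawn_counts)  # roughly 5000 draws of each class
```

Because sampling is with replacement, minority-class samples repeat within an epoch while some majority-class samples are never drawn.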
For those who need an implementation for large datasets:
https://gist.github.com/simonmoesorensen/ac590e8e25ac8b1c322519d2d8c73676
Hi,
Thanks for the code sample. However, the sampler option is mutually exclusive with the shuffle option, so you need to set shuffle=False when using a sampler.