Last active
October 5, 2023 10:04
-
-
Save simonmoesorensen/ac590e8e25ac8b1c322519d2d8c73676 to your computer and use it in GitHub Desktop.
Creates a PyTorch sampler that samples classes evenly. Uses vectorization and PyTorch DataLoaders to calculate the per-sample weights efficiently.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
from torch.utils.data import DataLoader, sampler | |
from torchvision import datasets | |
def make_weights_for_balanced_classes(images, nclasses, batch_size, num_workers=0):
    """Return one sampling weight per sample so that classes are drawn evenly.

    Each sample receives the weight ``N / count[class]``, where ``N`` is the
    total number of samples and ``count[class]`` is the number of samples
    that share its class.

    Adapted from https://gist.github.com/srikarplus/15d7263ae2c82e82fe194fc94321f34e

    Args:
        images: Indexable dataset whose items are ``(sample, class_index)``
            pairs, e.g. ``ImageFolder.imgs``. Assumes the default DataLoader
            collate function can batch the items -- confirm for custom
            datasets.
        nclasses: Number of classes; class indices are expected to lie in
            ``[0, nclasses)``.
        batch_size: Number of items read per DataLoader batch.
        num_workers: Number of DataLoader worker processes (0 loads in the
            calling process).

    Returns:
        A 1-D float tensor with ``len(images)`` entries, in dataset order,
        placed on CUDA when available and on the CPU otherwise.
    """
    # The progress bar is optional: fall back to a transparent pass-through
    # when tqdm is not installed.
    try:
        from tqdm import tqdm
    except ImportError:
        def tqdm(iterable, **_kwargs):
            return iterable

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # No shuffling, so the batches arrive in dataset order and the
    # concatenated labels line up with dataset indices.
    loader = DataLoader(images, batch_size=batch_size, num_workers=num_workers)

    # Read the labels in a single pass; only the label half of each batch is
    # used, so the sample half may be any collatable type (paths, tensors...).
    label_batches = [
        torch.as_tensor(label).to(device=device, dtype=torch.long).reshape(-1)
        for _, label in tqdm(loader, desc="Counting classes")
    ]
    if not label_batches:
        # Empty dataset: nothing to weight.
        return torch.zeros(0, device=device)

    labels = torch.cat(label_batches)
    count = torch.bincount(labels, minlength=nclasses).to(torch.float32)
    # Classes without samples get an infinite weight here, but no sample ever
    # selects such an entry, so the returned weights stay finite.
    weight_per_class = count.sum() / count
    return weight_per_class[labels]
# Train set
dataset_train = datasets.ImageFolder(traindir)

# The dataset is unbalanced, so samples are drawn with per-sample weights
# that give every class the same expected frequency.
weights = make_weights_for_balanced_classes(
    dataset_train.imgs, len(dataset_train.classes), args.batch_size
)
# Bind the sampler instance to its own name so the imported
# `torch.utils.data.sampler` module is not shadowed.
train_sampler = sampler.WeightedRandomSampler(weights, len(weights))
# shuffle stays False: DataLoader treats `sampler` and `shuffle` as
# mutually exclusive options.
train_loader = DataLoader(
    dataset_train,
    batch_size=args.batch_size,
    shuffle=False,
    sampler=train_sampler,
    num_workers=args.workers,
    pin_memory=True,
)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment