Last active
October 27, 2023 03:41
-
-
Save Miladiouss/6ba0876f0e2b65d0178be7274f61ad2f to your computer and use it in GitHub Desktop.
Create PyTorch datasets and dataset loaders for a subset of CIFAR10 classes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torchvision | |
import torchvision.transforms as transforms | |
from torchvision.datasets import CIFAR10 | |
from torch.utils.data import Dataset, DataLoader | |
import numpy as np | |
# ---- Image transforms ----
# Per-channel (mean, std) handed to Normalize.
# NOTE(review): presumably CIFAR10 channel statistics - confirm.
_NORM_MEAN = (0.4914, 0.4822, 0.4465)
_NORM_STD = (0.2023, 0.1994, 0.2010)

# Augmentation building blocks
RC = transforms.RandomCrop(32, padding=4)
RHF = transforms.RandomHorizontalFlip()
RVF = transforms.RandomVerticalFlip()  # defined, but not part of either pipeline below

# Conversion / normalization building blocks
NRM = transforms.Normalize(_NORM_MEAN, _NORM_STD)
TT = transforms.ToTensor()
TPIL = transforms.ToPILImage()

# Training pipeline: to PIL -> random crop -> random horizontal flip
# -> tensor -> normalize
transform_with_aug = transforms.Compose([TPIL, RC, RHF, TT, NRM])

# Evaluation pipeline: tensor conversion and normalization only
transform_no_aug = transforms.Compose([TT, NRM])
# Download (if missing) and load the CIFAR10 data.
# No transform is passed to CIFAR10 here; DatasetMaker applies the chosen
# transform to each image when it is accessed.
trainset = CIFAR10(root='./data', train=True, download=True)
testset = CIFAR10(root='./data', train=False, download=True)

# Class name -> class index lookup
classDict = {
    name: label
    for label, name in enumerate(
        ['plane', 'car', 'bird', 'cat', 'deer',
         'dog', 'frog', 'horse', 'ship', 'truck']
    )
}

# Split each set into its data (x) and its labels (y)
x_train, y_train = trainset.data, trainset.targets
x_test, y_test = testset.data, testset.targets
def get_class_i(x, y, i):
    """
    Pick out the samples that belong to one class.

    x: indexable collection of samples, e.g. trainset.data or testset.data
    y: labels aligned with x, e.g. trainset.targets or testset.targets
    i: class label to keep, a number between 0 and 9
    return: list with every x[j] whose label y[j] equals i, in original order
    """
    labels = np.array(y)
    # Indices along the first axis where the label matches i
    matches = np.nonzero(labels == i)[0]
    return [x[idx] for idx in matches]
class DatasetMaker(Dataset):
    """
    Dataset that presents several per-class image lists as one flat dataset.

    Samples are laid out class after class: indices 0..len(datasets[0])-1
    map to the first list, the following indices to the second list, and so
    on. The label returned with each sample is the position of its list in
    `datasets` (0, 1, ...), not the label the class had in the source data.
    """

    def __init__(self, datasets, transformFunc=transform_no_aug):
        """
        datasets: a list of get_class_i outputs, i.e. a list of list of images for selected classes
        transformFunc: callable applied to each image on access; pass None
            to get the stored images back unchanged
        """
        self.datasets = datasets
        self.lengths = [len(d) for d in self.datasets]
        self.transformFunc = transformFunc

    def __getitem__(self, i):
        """
        Return (transformed image, class label) for absolute index i.

        Negative indices count from the end of the whole dataset. An index
        outside [-len(self), len(self)) raises IndexError.
        """
        total = len(self)
        # Without this normalization a negative index would silently be
        # served from the first class only.
        position = i + total if i < 0 else i
        if not 0 <= position < total:
            raise IndexError(
                "index {} is out of range for a dataset of size {}".format(i, total)
            )
        class_label, index_wrt_class = self.index_of_which_bin(self.lengths, position)
        img = self.datasets[class_label][index_wrt_class]
        if self.transformFunc is not None:
            img = self.transformFunc(img)
        return img, class_label

    def __len__(self):
        """Total number of samples over all classes."""
        return sum(self.lengths)

    def index_of_which_bin(self, bin_sizes, absolute_index, verbose=False):
        """
        Given the absolute index, returns which bin it falls in and which element of that bin it corresponds to.

        bin_sizes: number of elements in each bin, in order
        absolute_index: non-negative index into the concatenation of all bins
        verbose: print the intermediate values when True
        return: (bin_index, index_wrt_class) as plain ints; bin_index equals
            len(bin_sizes) when absolute_index lies past the last bin
        """
        # accum[k] is the first absolute index that lies beyond bin k
        accum = np.cumsum(bin_sizes)
        if verbose:
            print("accum =", accum)
        # Number of bin ends that are <= absolute_index, i.e. the bins
        # lying entirely before it
        bin_index = int(np.searchsorted(accum, absolute_index, side="right"))
        if verbose:
            print("class_label =", bin_index)
        # Offset of the index inside its bin: subtract the bin's start
        bin_start = accum[bin_index - 1] if bin_index > 0 else 0
        index_wrt_class = int(absolute_index - bin_start)
        if verbose:
            print("index_wrt_class =", index_wrt_class)
        return bin_index, index_wrt_class
# ================== Usage ================== #
# Build a cats (class 3 of CIFAR) vs dogs (class 5 of CIFAR) trainset/testset.
# DatasetMaker labels follow the list order: cat -> 0, dog -> 1.
_selected = (classDict['cat'], classDict['dog'])

cat_dog_trainset = DatasetMaker(
    [get_class_i(x_train, y_train, label) for label in _selected],
    transform_with_aug,
)
cat_dog_testset = DatasetMaker(
    [get_class_i(x_test, y_test, label) for label in _selected],
    transform_no_aug,
)

# Loader options shared by both loaders
kwargs = {'num_workers': 2, 'pin_memory': False}

# Wrap the datasets in loaders; only the training loader shuffles
trainsetLoader = DataLoader(
    cat_dog_trainset, batch_size=64, shuffle=True, **kwargs)
testsetLoader = DataLoader(
    cat_dog_testset, batch_size=64, shuffle=False, **kwargs)
I did it like this, filtering on `e` for the class label and using `k` for the number of elements per class (here 100 planes, 100 cars and 100 frogs):
`dataset = CIFAR10(os.getcwd(), train=True, download=True, transform=transforms.ToTensor())
index = random.choices([i for i ,e in enumerate(list(dataset.targets)) if e == 0], k=100)
index = np.append(index, random.choices([i for i ,e in enumerate(list(dataset.targets)) if e == 1], k=100))
index = np.append(index, random.choices([i for i ,e in enumerate(list(dataset.targets)) if e == 6], k=100))
data = torch.utils.data.Subset(dataset, index)
train_data,val_data, test_data = torch.utils.data.random_split(data,[200,50,50])`
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
I just updated this to work with the latest PyTorch ('1.9.1') and torchvision ('0.10.1').