-
-
Save Miladiouss/6ba0876f0e2b65d0178be7274f61ad2f to your computer and use it in GitHub Desktop.
| import torchvision | |
| import torchvision.transforms as transforms | |
| from torchvision.datasets import CIFAR10 | |
| from torch.utils.data import Dataset, DataLoader | |
| import numpy as np | |
# ---- Transform building blocks ----
# Tensor / PIL converters.
TT = transforms.ToTensor()
TPIL = transforms.ToPILImage()

# Random augmentations. RVF is defined here but is not used by either of the
# two pipelines below.
RC = transforms.RandomCrop(size=32, padding=4)
RHF = transforms.RandomHorizontalFlip(p=0.5)
RVF = transforms.RandomVerticalFlip(p=0.5)

# Per-channel normalisation for CIFAR-10.
# NOTE(review): this std triple is the widely copied one; the per-pixel std of
# the CIFAR-10 training set is reportedly closer to (0.247, 0.243, 0.261).
# Kept unchanged so existing models stay compatible -- verify before changing.
NRM = transforms.Normalize(
    mean=(0.4914, 0.4822, 0.4465),
    std=(0.2023, 0.1994, 0.2010),
)

# Training pipeline: TPIL converts the raw image before the random crop and
# horizontal flip, then the result is turned into a normalised tensor.
transform_with_aug = transforms.Compose([
    TPIL,
    RC,
    RHF,
    TT,
    NRM,
])

# Evaluation pipeline: no augmentation, only tensor conversion + normalisation.
transform_no_aug = transforms.Compose([TT, NRM])
# Download (if not already present under ./data) and load CIFAR-10.
# No transform is attached to these torchvision datasets: only their raw
# `.data` / `.targets` are used below, and the transforms are applied per
# sample by DatasetMaker.
trainset = CIFAR10(root='./data', train=True, download=True)
testset = CIFAR10(root='./data', train=False, download=True)

# Short CIFAR-10 class name -> label index.
classDict = {'plane': 0, 'car': 1, 'bird': 2, 'cat': 3, 'deer': 4,
             'dog': 5, 'frog': 6, 'horse': 7, 'ship': 8, 'truck': 9}

# Separate images from labels for both splits. Current torchvision exposes
# them as `.data` / `.targets`; the older `train_data` / `train_labels` /
# `test_data` / `test_labels` attributes are gone and raise AttributeError.
x_train = trainset.data
x_test = testset.data
y_train = trainset.targets
y_test = testset.targets
# Define a function to separate CIFAR classes by class index
def get_class_i(x, y, i):
    """
    Collect every sample of x whose label in y equals i.

    x: samples indexable by position, e.g. trainset.data or testset.data
    y: labels aligned with x, e.g. trainset.targets or testset.targets
       (the old train_data/train_labels attributes no longer exist)
    i: class label, a number between 0 and 9
    return: x_i, a list of the matching samples in their original order
    """
    # 1-D positions where the label equals i (empty if there are none).
    positions = np.flatnonzero(np.asarray(y) == i)
    return [x[j] for j in positions]
class DatasetMaker(Dataset):
    """Map-style Dataset that concatenates several per-class image lists.

    The lists in ``datasets`` are laid end to end: with ``datasets=[cats, dogs]``
    indices ``0 .. len(cats) - 1`` address cats and the remaining indices
    address dogs. Each sample is returned as ``(transformFunc(img), label)``.
    """

    def __init__(self, datasets, transformFunc=transform_no_aug, labels=None):
        """
        datasets: a list of get_class_i outputs, i.e. a list of lists of images
            for the selected classes.
        transformFunc: callable applied to every image in ``__getitem__``.
            The default is bound once, when the class is defined.
        labels: optional sequence with one label per entry of ``datasets``.
            If None (the default), a sample's label is the position of its
            list in ``datasets`` (0, 1, ...). Pass e.g. ``[3, 5]`` for
            ``[cats, dogs]`` to return the original CIFAR-10 class ids instead.
        raises: ValueError if ``labels`` does not have one entry per dataset.
        """
        self.datasets = datasets
        self.lengths = [len(d) for d in self.datasets]
        self.transformFunc = transformFunc
        if labels is not None and len(labels) != len(self.datasets):
            raise ValueError(
                "labels must have one entry per dataset "
                f"(got {len(labels)} labels for {len(self.datasets)} datasets)"
            )
        self.labels = labels

    def __getitem__(self, i):
        """Return ``(transformed_image, label)`` for absolute index ``i``.

        Negative indices count from the end, as for Python sequences;
        out-of-range indices raise IndexError.
        """
        n = len(self)
        # Without this, i = -1 would map to the last image of the FIRST bin.
        if i < 0:
            i += n
        if not 0 <= i < n:
            raise IndexError(f"index out of range for dataset of length {n}")
        class_label, index_wrt_class = self.index_of_which_bin(self.lengths, i)
        img = self.datasets[class_label][index_wrt_class]
        img = self.transformFunc(img)
        if self.labels is not None:
            return img, self.labels[class_label]
        return img, class_label

    def __len__(self):
        """Total number of samples across all bins."""
        return sum(self.lengths)

    def index_of_which_bin(self, bin_sizes, absolute_index, verbose=False):
        """
        Given the absolute index, return which bin it falls in and which
        element of that bin it corresponds to.

        bin_sizes: sizes of the consecutive bins.
        absolute_index: index into the concatenation of all bins.
        verbose: print intermediate values when True.
        return: (bin_index, index_wrt_class)
        """
        # accum[k] = total size of bins 0..k. It is non-decreasing, so the bin
        # is the number of cumulative sizes <= absolute_index; empty bins are
        # skipped automatically.
        accum = np.add.accumulate(bin_sizes)
        if verbose:
            print("accum =", accum)
        bin_index = int(np.searchsorted(accum, absolute_index, side='right'))
        if verbose:
            print("class_label =", bin_index)
        # Offset inside the bin: subtract the total size of all earlier bins.
        index_wrt_class = absolute_index - np.insert(accum, 0, 0)[bin_index]
        if verbose:
            print("index_wrt_class =", index_wrt_class)
        return bin_index, index_wrt_class
# ================== Usage ================== #
# Cats (CIFAR class 3) and dogs (CIFAR class 5) form the trainset/testset.
# Inside these datasets a sample's label is the position of its class in the
# list below (cat -> 0, dog -> 1), not the original CIFAR id.
cat_dog_trainset = DatasetMaker(
    [get_class_i(x_train, y_train, classDict[name]) for name in ('cat', 'dog')],
    transform_with_aug,
)
cat_dog_testset = DatasetMaker(
    [get_class_i(x_test, y_test, classDict[name]) for name in ('cat', 'dog')],
    transform_no_aug,
)

# DataLoader options shared by both loaders.
kwargs = dict(num_workers=2, pin_memory=False)

# Only the training loader shuffles.
trainsetLoader = DataLoader(cat_dog_trainset, batch_size=64, shuffle=True, **kwargs)
testsetLoader = DataLoader(cat_dog_testset, batch_size=64, shuffle=False, **kwargs)
Thanks for this! I got a small error using this code: the output images need to be converted from NumPy arrays to PIL images.
@hyeongminoh
You could just change those 4 lines as follows.
# Newer torchvision exposes images/labels as `.data` / `.targets`; the old
# `train_data` / `train_labels` / `test_data` / `test_labels` attributes
# raise AttributeError.
x_train = trainset.data
x_test = testset.data
y_train = trainset.targets
y_test = testset.targets
`img = Image.fromarray(img)` — this line is needed before the transformation, as @agoel10 said.
- Running your code, I found that target values are returned according to the order of your queries: with cat and dog above, cat is re-assigned target value 0 and dog target value 1. How can I keep and return the original target values?
@Seohyeong, how do you print the trainsetLoader/testsetLoader target values? Thank you.
I just updated this to work with the latest PyTorch (1.9.1) and torchvision (0.10.1).
I did it like this, with `e` for the class and `k` for the number of elements per class (100 planes, cars and frogs):
import os
import random

import numpy as np
import torch
import torchvision.transforms as transforms
from torchvision.datasets import CIFAR10

dataset = CIFAR10(os.getcwd(), train=True, download=True, transform=transforms.ToTensor())

# Pick k distinct training images from each class e in `wanted`
# (0 = plane, 1 = car, 6 = frog).
wanted = (0, 1, 6)
k = 100
targets = np.asarray(dataset.targets)  # convert once instead of per class
# random.sample draws WITHOUT replacement. random.choices draws WITH
# replacement, so the same image could be picked twice and then land in both
# the training split and the validation/test split.
index = [
    j
    for e in wanted
    for j in random.sample(np.flatnonzero(targets == e).tolist(), k)
]
data = torch.utils.data.Subset(dataset, index)
# The split lengths must add up to len(wanted) * k == 300.
train_data, val_data, test_data = torch.utils.data.random_split(data, [200, 50, 50])
Thank you!
But I got a problem:
AttributeError: 'CIFAR10' object has no attribute 'train_data'
How can I fix it?