-
-
Save Miladiouss/6ba0876f0e2b65d0178be7274f61ad2f to your computer and use it in GitHub Desktop.
import numpy as np
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from torchvision.datasets import CIFAR10
# Transformations
# Individual building blocks, composed into the two pipelines further down.
RC = transforms.RandomCrop(32, padding=4)
RHF = transforms.RandomHorizontalFlip()
# NOTE: RVF is defined for convenience but is not part of either pipeline below.
RVF = transforms.RandomVerticalFlip()
NRM = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
TT = transforms.ToTensor()
TPIL = transforms.ToPILImage()

# Transforms object for trainset with augmentation.
# TPIL comes first so that the random crop/flip run on a PIL image; the
# samples fed to this pipeline are taken from `trainset.data` further down.
transform_with_aug = transforms.Compose([TPIL, RC, RHF, TT, NRM])
# Transforms object for testset with NO augmentation
transform_no_aug = transforms.Compose([TT, NRM])
# Downloading/Loading CIFAR10 data
# No `transform` is passed to CIFAR10 here: the raw data is split by class
# first, and the transforms are applied later, per item, by DatasetMaker.
trainset = CIFAR10(root='./data', train=True, download=True)
testset = CIFAR10(root='./data', train=False, download=True)

# Short class name -> class index
classDict = {'plane': 0, 'car': 1, 'bird': 2, 'cat': 3, 'deer': 4,
             'dog': 5, 'frog': 6, 'horse': 7, 'ship': 8, 'truck': 9}

# Separating trainset/testset data/label
x_train = trainset.data
x_test = testset.data
y_train = trainset.targets
y_test = testset.targets
# Define a function to separate CIFAR classes by class index
def get_class_i(x, y, i):
    """
    Collect every sample of x whose label in y equals i.

    x: indexable collection of samples, e.g. trainset.data or testset.data
    y: labels aligned with x, e.g. trainset.targets or testset.targets
    i: class label to select, a number between 0 to 9 for CIFAR10
    return: x_i, a list holding x[j] for every j where y[j] == i,
            in their original order (empty list when nothing matches)
    """
    # Convert to a numpy array so the comparison with i is element-wise
    y = np.asarray(y)
    # Positions (along the first axis) of the labels that equal i
    pos_i = np.nonzero(y == i)[0]
    # Collect all data that match the desired label
    return [x[j] for j in pos_i]
class DatasetMaker(Dataset):
    """
    Dataset that concatenates several per-class lists of images.

    The class lists are treated as consecutive bins: absolute index i is
    mapped to (which list, position inside that list), the image found there
    is passed through transformFunc, and it is returned with its label.
    """

    def __init__(self, datasets, transformFunc=transform_no_aug, labels=None):
        """
        datasets: a list of get_class_i outputs, i.e. a list of list of images for selected classes
        transformFunc: callable applied to every image when it is fetched
        labels: optional sequence holding one label per entry of datasets.
                When omitted, the label of an item is the position of its
                list inside datasets (0, 1, ...), NOT the original class index.
        """
        self.datasets = datasets
        self.lengths = [len(d) for d in self.datasets]
        self.transformFunc = transformFunc
        if labels is None:
            labels = range(len(self.datasets))
        self.labels = list(labels)
        if len(self.labels) != len(self.datasets):
            raise ValueError("labels must contain exactly one entry per dataset")

    def __getitem__(self, i):
        """
        Return (transformed image, label) for absolute index i.

        Negative indices count from the end; anything outside the dataset
        raises IndexError instead of silently resolving to another item.
        """
        total = len(self)
        if i < 0:
            i += total
        if not 0 <= i < total:
            raise IndexError("DatasetMaker index out of range")
        bin_index, index_wrt_class = self.index_of_which_bin(self.lengths, i)
        img = self.datasets[bin_index][index_wrt_class]
        img = self.transformFunc(img)
        return img, self.labels[bin_index]

    def __len__(self):
        """Total number of images over all class lists."""
        return sum(self.lengths)

    def index_of_which_bin(self, bin_sizes, absolute_index, verbose=False):
        """
        Given the absolute index, returns which bin it falls in and which element of that bin it corresponds to.

        bin_sizes: number of elements in each bin, in order
        absolute_index: index into the concatenation of all bins
        verbose: when True, print the intermediate values
        return: (bin_index, index_wrt_class)
        """
        # Running totals: accum[k] is the absolute index one past the end of bin k
        accum = np.add.accumulate(bin_sizes)
        if verbose:
            print("accum =", accum)
        # Every bin whose end is <= absolute_index lies entirely before it,
        # so counting those bins gives the bin the index falls into.
        bin_index = int(np.searchsorted(accum, absolute_index, side="right"))
        if verbose:
            print("class_label =", bin_index)
        # Which element of that bin does the index correspond to?
        bin_start = int(accum[bin_index - 1]) if bin_index > 0 else 0
        index_wrt_class = absolute_index - bin_start
        if verbose:
            print("index_wrt_class =", index_wrt_class)
        return bin_index, index_wrt_class
# ================== Usage ================== #
# Let's choose cats (class 3 of CIFAR) and dogs (class 5 of CIFAR) as trainset/testset
# NOTE: DatasetMaker labels items by the position of their list in the
# argument below, so here cat -> 0 and dog -> 1 (not 3 and 5).
cat_dog_trainset = DatasetMaker(
    [
        get_class_i(x_train, y_train, classDict['cat']),
        get_class_i(x_train, y_train, classDict['dog']),
    ],
    transform_with_aug,
)
cat_dog_testset = DatasetMaker(
    [
        get_class_i(x_test, y_test, classDict['cat']),
        get_class_i(x_test, y_test, classDict['dog']),
    ],
    transform_no_aug,
)

# Options shared by both loaders
kwargs = {'num_workers': 2, 'pin_memory': False}

# Create datasetLoaders from trainset and testset
trainsetLoader = DataLoader(
    cat_dog_trainset, batch_size=64, shuffle=True, **kwargs)
testsetLoader = DataLoader(
    cat_dog_testset, batch_size=64, shuffle=False, **kwargs)
Thanks for this! Got a small error using this code. Need to convert the output images from numpy to PIL
@hyeongminoh
You could just change those 4 lines as follows.
x_train = trainset.data
x_test = testset.data
y_train = trainset.targets
y_test = testset.targets
img = Image.fromarray(img)
This line is needed before the transformation, as @agoel10 said.
- Running your code, I've found that target values are actually returned with respect to the ordering of your queries (such as cat and dog above), meaning that cat is re-assigned to target value 0 and dog to 1. How can I keep and return the original target values?
@Seohyeong how do you print the trainsetLoader/testsetLoader target values? Thank you.
I just updated this to work with the latest PyTorch ('1.9.1') and torch vision ('0.10.1').
I did it like this, with e for the class label and k for the number of elements per class (100 planes, 100 cars and 100 frogs):
```python
dataset = CIFAR10(os.getcwd(), train=True, download=True, transform=transforms.ToTensor())
index = random.choices([i for i, e in enumerate(list(dataset.targets)) if e == 0], k=100)
index = np.append(index, random.choices([i for i, e in enumerate(list(dataset.targets)) if e == 1], k=100))
index = np.append(index, random.choices([i for i, e in enumerate(list(dataset.targets)) if e == 6], k=100))
data = torch.utils.data.Subset(dataset, index)
train_data, val_data, test_data = torch.utils.data.random_split(data, [200, 50, 50])
```
thank you!
But I got a problem:
AttributeError: 'CIFAR10' object has no attribute 'train_data'
How can I fix it?