Created
February 18, 2021 13:49
-
-
Save morawi/42d9a99e135e832f9fd54fb678151efb to your computer and use it in GitHub Desktop.
Using/Testing PyTorch SubsetRandomSampler (imported as SubSetRandSampler)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Draw a random subset of the CIFAR-10 test split with SubsetRandomSampler.

Why a sampler instead of ``shuffle=True``?  Shuffling still visits every item;
here only ``my_sampler_size`` randomly chosen items are wanted.  An
``if i > my_sampler_size: break`` inside the validation loop is a poor
substitute: ``i`` from ``enumerate(loader)`` counts *batches*, not items, and
with ``shuffle=False`` it would always stop on the same leading items.
"""
import torch
from torch.utils.data import SubsetRandomSampler as SubSetRandSampler

my_sampler_size = 5000  # number of distinct items to draw from the dataset
num_preview_batches = 2  # show one image from each of the first N batches


def sample_indices(dataset_len, sample_size, generator=None):
    """Return up to ``sample_size`` distinct random indices in ``[0, dataset_len)``.

    Sampling is done *without* replacement via ``torch.randperm``.  Drawing
    with ``torch.randint`` instead would yield duplicate indices, so the
    subset would contain fewer than ``sample_size`` distinct items.

    Args:
        dataset_len: number of items in the dataset (>= 0).
        sample_size: number of indices wanted (>= 0); values larger than
            ``dataset_len`` are clipped to ``dataset_len``.
        generator: optional ``torch.Generator`` for reproducible draws.

    Returns:
        A 1-D integer tensor of unique indices.

    Raises:
        ValueError: if ``sample_size`` is negative (a negative slice bound
            would otherwise silently drop items from the end).
    """
    if sample_size < 0:
        raise ValueError("sample_size must be non-negative")
    return torch.randperm(dataset_len, generator=generator)[:sample_size]


def main():
    """Build the subset loader, preview a few images, and print the sample count."""
    # Imported here so sample_indices() is usable without torchvision/matplotlib.
    import matplotlib.pyplot as plt
    import torchvision.datasets as datasets
    import torchvision.transforms as transforms

    # Only ToTensor is applied; resize / center-crop / normalize steps can be
    # added to this Compose if the model needs them.
    val_transforms = transforms.Compose([
        transforms.ToTensor(),
    ])
    val_dataset = datasets.CIFAR10(
        '../data', train=False, download=True, transform=val_transforms,
    )

    subset_indices = sample_indices(len(val_dataset), my_sampler_size).tolist()
    validate_significance_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=100,
        num_workers=0,
        sampler=SubSetRandSampler(subset_indices),
        # The sampler already randomizes order; DataLoader does not allow
        # shuffle=True together with a custom sampler.
        shuffle=False,
        # Pinned memory only speeds up host->GPU copies; skip it on CPU-only hosts.
        pin_memory=torch.cuda.is_available(),
    )

    total = 0
    for i, (images, target) in enumerate(validate_significance_loader):
        total += len(target)
        if i < num_preview_batches:
            # matplotlib expects channels last, so move dim 0 of the image to the end.
            plt.imshow(images[0].permute(1, 2, 0))
            plt.show()
            print(target[0].item())
    # Expected to equal len(subset_indices), i.e. min(my_sampler_size, len(val_dataset)).
    print(total)


if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment