Last active
June 23, 2020 14:44
-
-
Save sharma0611/81e895698564bf804f05f001fe3807ef to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torchvision.transforms as transforms | |
import torchvision.datasets as datasets | |
import torchvision | |
from modules.cifar10 import data_loader | |
import matplotlib.pyplot as plt | |
# modules.utils.py
class DeNormalize:
    """Invert a per-channel ``Normalize`` transform in place.

    ``Normalize`` computes ``(x - mean) / std`` for each channel; this class
    applies the inverse, ``x * std + mean``, so that normalized images can be
    viewed with their original colour values.

    Args:
        mean: Per-channel means that were subtracted during normalization.
        std: Per-channel standard deviations that were divided out.
    """

    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, tensor):
        """Denormalize ``tensor`` in place and return it.

        Args:
            tensor (Tensor): Image of shape ``(C, H, W)`` or a batch of shape
                ``(..., C, H, W)``, with ``C == len(self.mean)``.

        Returns:
            Tensor: The same tensor object, now denormalized.
        """
        # Shape the statistics as (C, 1, 1) in the input's dtype/device so they
        # broadcast over H, W and any leading batch dimensions. Zipping over
        # ``tensor`` instead would walk dim 0, which for a batch is the batch
        # axis rather than the channel axis.
        mean = tensor.new_tensor(self.mean).view(-1, 1, 1)
        std = tensor.new_tensor(self.std).view(-1, 1, 1)
        # Inverse of the normalize step ``t.sub_(m).div_(s)``.
        return tensor.mul_(std).add_(mean)
# modules.cifar10.py
def denormalize_transform():
    """Build a ``DeNormalize`` transform with the CIFAR-10 channel statistics.

    Returns:
        DeNormalize: Transform that maps normalized CIFAR-10 images back to
        (approximately) their original RGB range.
    """
    # NOTE(review): these are presumably the same mean/std that data_loader's
    # Normalize uses; the inversion is only exact if they match — verify.
    cifar10_mean = [0.4914, 0.4822, 0.4465]
    cifar10_std = [0.2023, 0.1994, 0.2010]
    return DeNormalize(mean=cifar10_mean, std=cifar10_std)
# Sanity check in a notebook cell: display one denormalized training batch.
train_loader, val_loader = data_loader('./data', batch_size=3)
# Use the next() builtin: recent PyTorch DataLoader iterators no longer
# provide the Python-2-style ``.next()`` method.
image, target = next(iter(train_loader))
# Our data loader normalizes its output, so undo that before plotting.
# Apply the transform one image at a time (assumes ``image`` is a
# (B, C, H, W) batch): each item is a (C, H, W) view into the batch, so the
# in-place update writes straight back into ``image``.
transform = denormalize_transform()
for single_image in image:
    transform(single_image)
grid_img = torchvision.utils.make_grid(image, nrow=3)
# Float RGB values outside [0, 1] would otherwise be clipped by imshow with a
# warning; clamp explicitly so the displayed result is the same but quiet.
plt.imshow(grid_img.permute(1, 2, 0).clamp(0, 1))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment