Created June 23, 2020 14:41
Save khalidmeister/719c3978cb4bc386ebd0af752ba6b5e9 to your computer and use it in GitHub Desktop.
Transform Image Data For The Model
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import time | |
import copy | |
import os | |
import torch | |
import torch.optim as optim | |
import torch.nn as nn | |
import torchvision | |
import matplotlib.pyplot as plt | |
from torch.optim import lr_scheduler | |
from torchvision import datasets, models, transforms | |
# ImageNet per-channel mean / std. torchvision's pretrained models expect
# their inputs to be normalized with these statistics.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Dataset splits; each one is a sub-folder of `data_dir`.
SPLITS = ['train', 'val', 'test']

# Deterministic pipeline shared by the evaluation splits ('val' and 'test'):
# resize the shorter side to 256, take the central 224x224 crop, convert to a
# tensor and normalize. Every step is stateless, so one instance can be shared.
eval_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

# Set the transformation for each dataset folder.
# Bound to `data_transforms` (not `transforms`) so the torchvision
# `transforms` module imported above is not shadowed and stays usable.
data_transforms = {
    # Training adds random augmentation: random-resized crop + horizontal flip.
    'train': transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
    ]),
    'val': eval_transform,
    'test': eval_transform,
}

# Import the dataset. ImageFolder expects one sub-directory per class inside
# each split folder, i.e. data/<split>/<class_name>/<image>.
data_dir = 'data'
image_datasets = {
    split: datasets.ImageFolder(os.path.join(data_dir, split),
                                transform=data_transforms[split])
    for split in SPLITS
}

# Shuffle the dataset and create batches of 4 images for every split.
# NOTE(review): val/test are shuffled as well (unchanged behavior). Ordering
# does not affect aggregate metrics, but restrict shuffle to 'train' if a
# reproducible evaluation order is needed.
# With num_workers > 0, iterating these loaders on spawn-based platforms
# (e.g. Windows, macOS) requires the script's entry point to be guarded by
# `if __name__ == '__main__':`.
dataloaders = {
    split: torch.utils.data.DataLoader(image_datasets[split], batch_size=4,
                                       shuffle=True, num_workers=4)
    for split in SPLITS
}

# Get the number of images in each split folder.
data_size = {split: len(image_datasets[split]) for split in SPLITS}

# Get the class names (the class sub-folder names of the training split).
class_names = image_datasets['train'].classes
# Select the compute device: the first CUDA GPU when CUDA is available,
# otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.