Skip to content

Instantly share code, notes, and snippets.

@biswajitcsecu
Last active April 30, 2021 01:35
Show Gist options
  • Save biswajitcsecu/939fc8eae74e2a575c1fe4c8048022c4 to your computer and use it in GitHub Desktop.
Save biswajitcsecu/939fc8eae74e2a575c1fe4c8048022c4 to your computer and use it in GitHub Desktop.
ImageClassificationTorchDemo
from __future__ import print_function, division
import os
import torch
from torchvision import transforms, datasets
import pandas as pd
from skimage import io, transform
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import torch.utils.data as data
import torchvision
from torchvision import transforms
import warnings
# NOTE(review): this silences *every* warning for the whole process (including
# library deprecation notices and matplotlib's clipping warning); consider
# narrowing it with `category=` / `module=` filters.
warnings.filterwarnings("ignore")

# Dataset root. Set the HYMENOPTERA_ROOT environment variable to run the script
# on another machine; the default is the original hard-coded location.
data_root = os.environ.get("HYMENOPTERA_ROOT", "/home/donvex/Projects/DLCNN/hymenoptera")
# The trailing "" keeps the trailing path separator of the original strings.
train_data_path = os.path.join(data_root, "train", "")
test_data_path = os.path.join(data_root, "val", "")

# Per-channel normalisation statistics (one value per image channel).
# NOTE(review): these are not the usual ImageNet values; presumably they were
# measured on this dataset -- confirm before reusing them elsewhere.
mean = [.4363, .4328, .3292]
std = [.2129, .2075, .2038]


def _build_transforms(augmentations=()):
    """Return Resize(224x224) -> *augmentations* -> ToTensor -> Normalize(mean, std).

    Both pipelines share every step except the optional augmentations, so
    building them here keeps train and test preprocessing from drifting apart.
    """
    return transforms.Compose([
        transforms.Resize([224, 224]),
        *augmentations,
        transforms.ToTensor(),
        # Normalize accepts plain sequences; it converts them to tensors itself.
        transforms.Normalize(mean, std),
    ])


# Training adds random augmentation: horizontal flips and rotations of up to
# 10 degrees in either direction.
train_transforms = _build_transforms([
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
])
# Evaluation is deterministic: resize and normalise only.
test_transforms = _build_transforms()
# One ImageFolder per split; the splits differ only in root directory and in
# which preprocessing pipeline is applied to every loaded image.
train_dataset = datasets.ImageFolder(
    root=train_data_path,
    transform=train_transforms,
)
test_dataset = datasets.ImageFolder(
    root=test_data_path,
    transform=test_transforms,
)
def show_transformed_image(dataset, *, batch_size=8, nrow=4, mean=None, std=None):
    """Plot one randomly drawn batch from *dataset* as an image grid.

    Items pass through the dataset's own transform, so this previews what a
    DataLoader over the same dataset yields (augmentations included). The
    batch's labels are printed to stdout.

    Args:
        dataset: Map-style dataset whose items are ``(image, label)`` pairs.
            NOTE(review): assumes ``image`` is a ``(C, H, W)`` float tensor so
            that batches collate to ``(B, C, H, W)`` for ``make_grid``.
        batch_size: Number of images to draw and display.
        nrow: Number of images per row of the grid.
        mean, std: Per-channel statistics that were given to
            ``transforms.Normalize``. When both are supplied, the normalisation
            is undone before plotting so colours look natural. When omitted,
            the tensors are shown as-is, clamped to ``[0, 1]``.

    Raises:
        ValueError: If only one of ``mean`` / ``std`` is supplied, or if the
            dataset yields no items.
    """
    if (mean is None) != (std is None):
        raise ValueError("mean and std must be given together")

    loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
    try:
        images, labels = next(iter(loader))
    except StopIteration:
        # A bare StopIteration escaping a function is confusing; say why.
        raise ValueError("dataset is empty") from None

    if mean is not None:
        # Invert Normalize's (x - mean) / std channel-wise. view(1, C, 1, 1)
        # broadcasts the statistics over the (B, C, H, W) batch.
        mean_t = torch.as_tensor(mean, dtype=images.dtype).view(1, -1, 1, 1)
        std_t = torch.as_tensor(std, dtype=images.dtype).view(1, -1, 1, 1)
        images = images * std_t + mean_t

    grid = torchvision.utils.make_grid(images, nrow=nrow)
    # imshow expects (H, W, C) and clips float data outside [0, 1]. Clamp
    # explicitly, then move channels last and convert to a NumPy array directly.
    grid = grid.clamp(0, 1).permute(1, 2, 0).numpy()

    fig = plt.figure(figsize=(12, 10))
    plt.imshow(grid)
    print('labels: ', labels)
    fig.tight_layout()
    plt.show()
    plt.close(fig)
# Preview one augmented training batch before the loaders are built.
show_transformed_image(train_dataset)

# Both loaders use the same batch size. Only the training split is reshuffled
# each epoch; the validation split is iterated in a fixed order.
BATCH_SIZE = 32
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment