Last active
October 8, 2019 13:14
-
-
Save thebirk/60f376d8ddbeb078f6aa2d1405a20356 to your computer and use it in GitHub Desktop.
sdfsdf
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import glob | |
import cv2 | |
import torch | |
from PIL import Image | |
from torch import nn | |
from torchvision import transforms | |
from torch.utils.data import Dataset, DataLoader | |
class CatsAndDogsDataset(Dataset):
    """Dataset of cat and dog JPEGs stored under ``<root>/cats`` and ``<root>/dogs``.

    Every sample is a ``(flattened_image, label)`` pair. The label is a
    one-element float tensor: ``0.0`` for dogs and ``1.0`` for cats.
    """

    def __init__(self, f_name, transform):
        """Collect the image paths found under *f_name*.

        Args:
            f_name: Root directory containing ``cats/`` and ``dogs/``
                sub-directories with ``*.jpg`` files.
            transform: Callable applied to each opened PIL image; it must
                return a tensor. When falsy, the image is converted with
                ``transforms.ToTensor()`` instead.
        """
        self.transform = transform
        self.dogcat_list = []

        # glob returns matches in arbitrary order; sorting keeps the mapping
        # from index to file reproducible across runs and machines.
        dogs = sorted(glob.glob(os.path.join(f_name, "dogs", "*.jpg")))
        cats = sorted(glob.glob(os.path.join(f_name, "cats", "*.jpg")))

        # 0 = dogs
        # 1 = cats
        dog_constant = torch.zeros(1)
        cat_constant = torch.ones(1)

        # Entries are (label, path) tuples: all dogs first, then all cats.
        self.dogcat_list.extend((dog_constant, path) for path in dogs)
        self.dogcat_list.extend((cat_constant, path) for path in cats)

    def __len__(self):
        """Return the total number of collected images."""
        return len(self.dogcat_list)

    def __getitem__(self, idx):
        """Load the image at *idx* and return ``(flat_tensor, label)``.

        Returns:
            A tuple of the transformed image flattened to one dimension and
            the one-element label tensor stored for that file.
        """
        classCategory, filename = self.dogcat_list[idx]

        # Image.open is lazy and keeps the file handle open; the context
        # manager closes it once the pixel data has been consumed.
        with Image.open(filename) as im:
            if self.transform:
                im = self.transform(im)
            else:
                # Without a transform the PIL image cannot be flattened, so
                # fall back to a plain tensor conversion.
                im = transforms.ToTensor()(im)

        # reshape (rather than view) also accepts non-contiguous tensors.
        return im.reshape(-1), classCategory
class CatsAndDogsModel(nn.Module):
    """One linear layer followed by a sigmoid, applied to flattened images."""

    def __init__(self, image_size, out_dim):
        """Build the layers.

        Args:
            image_size: Indexable whose first two entries are multiplied to
                give the number of input features.
            out_dim: Number of output units of the linear layer.
        """
        super().__init__()
        n_features = image_size[0] * image_size[1]
        self.l1 = nn.Linear(n_features, out_dim)
        self.sig = nn.Sigmoid()

    def forward(self, x):
        """Return ``sigmoid(l1(x))`` for the input tensor *x*."""
        return self.sig(self.l1(x))
if __name__ == "__main__":
    # Train a single-layer classifier on the cats/dogs images, then print the
    # model output next to the true label for every sample.
    image_size = (100, 100)
    f_name = r"dataset/training_set"

    # Resize, convert to one channel, turn into a tensor, then shift and
    # scale with mean 0.5 and std 0.5.
    transform = transforms.Compose([
        transforms.Resize(image_size),
        transforms.Grayscale(),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ])

    dataset = CatsAndDogsDataset(f_name, transform)
    if len(dataset) == 0:
        # Fail early with a readable message instead of an obscure sampler
        # or division error further down.
        raise SystemExit(
            f"No .jpg images found under {f_name!r} "
            "(expected 'cats' and 'dogs' sub-directories)"
        )
    dataloader = DataLoader(dataset, batch_size=64, shuffle=True)

    model = CatsAndDogsModel(image_size, 1)

    # Find learning rate - 0.001
    # Find optimiser - SGD
    # Find loss function - MSE (mean squared error), BCE (binary cross entropy)
    learning_rate = 0.001
    optimiser = torch.optim.SGD(model.parameters(), lr=learning_rate)
    criterion = nn.MSELoss()

    model.train(True)
    epochs = 5
    for epoch in range(epochs):
        print("epoch #{}".format(epoch + 1))
        running_loss = 0.0
        for images, labels in dataloader:
            optimiser.zero_grad()
            # Call the module itself (not .forward) so registered hooks run.
            output = model(images)
            loss = criterion(output, labels)
            loss.backward()
            optimiser.step()
            # .item() yields a plain float; summing the tensor itself would
            # keep autograd history attached to the running total.
            running_loss += loss.item()
        # The reported value is the mean loss per batch for this epoch.
        print(f"Total epoch loss: {running_loss / len(dataloader)}")

    model.train(False)
    # NOTE(review): this evaluates on the same data the model was trained on.
    with torch.no_grad():  # no gradients are needed for inference
        for images, cats in dataloader:
            outputs = model(images)
            for out, is_cat in zip(outputs, cats):
                print("model out: {}, is_cat: {}".format(float(out), float(is_cat)))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment