Last active
January 20, 2024 11:47
-
-
Save codinguncut/096e45204e2324c5e08e9af8a5c25da8 to your computer and use it in GitHub Desktop.
dogs_and_cats loader
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
from torchvision import datasets, transforms | |
import pathlib | |
import random | |
def get_dogs_and_cats(*args, **kwargs):
    """Alias for ``get_dogs_and_cats_data``.

    All positional and keyword arguments are forwarded unchanged and the
    result of ``get_dogs_and_cats_data`` is returned as-is.
    """
    loaded = get_dogs_and_cats_data(*args, **kwargs)
    return loaded
def get_dogs_and_cats_data(split="train", resize=(32,32), n_images=None, batch_size=None, is_resnet=False, **kwargs):
    """Load one split of the dogs_and_cats image folder.

    Images are read from ``<directory of this file>/dogs_and_cats/<split>``
    with ``torchvision.datasets.ImageFolder``, resized and converted to
    tensors.

    Args:
        split: Name of the sub-folder to read (default ``"train"``).
        resize: Target size handed to ``transforms.Resize``; ``None`` is
            replaced by ``(256, 256)``.
        n_images: When truthy, that many dataset indices are drawn at random
            (without replacement) and only those samples are loaded.
            Ignored when ``is_resnet`` is true.
        batch_size: Batch size of the returned ``DataLoader``; only used
            when ``is_resnet`` is true.
        is_resnet: Selects the return type, see below.
        **kwargs: Forwarded to ``torch.utils.data.DataLoader``.

    Returns:
        If ``is_resnet`` is true, a ``DataLoader`` over the whole split.
        Otherwise the first batch of a loader whose batch size equals
        ``n_images`` (or the dataset length), i.e. all selected samples in
        one batch.
    """
    target_size = (256, 256) if resize is None else resize
    pipeline = transforms.Compose([
        transforms.Resize(size=target_size),
        transforms.ToTensor(),
    ])

    split_dir = pathlib.Path(__file__).resolve().parent / "dogs_and_cats" / split
    dataset = datasets.ImageFolder(split_dir, transform=pipeline)

    # Guard clause: the ResNet path hands back the loader itself.
    if is_resnet:
        return torch.utils.data.DataLoader(dataset, batch_size=batch_size, **kwargs)

    # Otherwise everything requested is materialised as a single batch.
    if n_images:
        chosen_indices = random.sample(list(range(len(dataset))), n_images)
        single_batch = n_images
    else:
        chosen_indices = None
        single_batch = len(dataset)

    loader = torch.utils.data.DataLoader(
        dataset, batch_size=single_batch, sampler=chosen_indices, **kwargs
    )
    return next(iter(loader))
def to_image_transform():
    """Return a fresh ``transforms.ToPILImage`` instance."""
    pil_converter = transforms.ToPILImage()
    return pil_converter
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
# Making sure we can find the data loader | |
import sys | |
sys.path.append('..') | |
from data import load | |
import torch.utils.tensorboard as tb | |
import logging | |
import numpy as np | |
def train(model, log_dir, batch_size=128, resize=(32,32), device="cpu", n_epochs=100):
    """Train ``model`` as a binary classifier on the dogs_and_cats data.

    The train and valid splits are each loaded as a single tensor pair via
    ``load.get_dogs_and_cats_data`` and moved to ``device`` together with
    the model. Training uses SGD with momentum and ``BCEWithLogitsLoss``.
    The per-step loss and the per-epoch train/valid accuracy are written to
    TensorBoard under ``log_dir + '/deepnet1/{train,valid}'``.

    Args:
        model: A ``torch.nn.Module``. Its output is compared against the
            float label batch by ``BCEWithLogitsLoss``, so it must have the
            same shape as the label batch; outputs ``> 0`` count as class 1.
        log_dir: String prefix of the TensorBoard log directories.
        batch_size: Mini-batch size. A trailing partial batch is dropped.
        resize: Image size forwarded to the data loader for both splits.
        device: Device the model and data are moved to.
        n_epochs: Number of passes over the training data.

    Returns:
        None. The model is updated in place.
    """
    logging.warning("loading dataset")
    # Honour the ``resize`` argument instead of a hard-coded (32, 32).
    train_data, train_label = load.get_dogs_and_cats_data(resize=resize)
    valid_data, valid_label = load.get_dogs_and_cats_data(split='valid', resize=resize)
    logging.warning("loading done")

    # Model and data must live on the same device.
    model.to(device)
    train_data, train_label = train_data.to(device), train_label.to(device)
    valid_data, valid_label = valid_data.to(device), valid_label.to(device)

    train_logger = tb.SummaryWriter(log_dir+'/deepnet1/train', flush_secs=1)
    valid_logger = tb.SummaryWriter(log_dir+'/deepnet1/valid', flush_secs=1)

    try:
        # Create the optimizer
        optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=1e-4)

        # Create the loss
        loss = torch.nn.BCEWithLogitsLoss()

        # Start training
        global_step = 0
        for epoch in range(n_epochs):
            # Shuffle the data
            permutation = torch.randperm(train_data.size(0))

            # Iterate over full mini-batches only
            train_accuracy = []
            for it in range(0, len(permutation)-batch_size+1, batch_size):
                batch_samples = permutation[it:it+batch_size]
                batch_data, batch_label = train_data[batch_samples], train_label[batch_samples]

                # Compute the loss
                o = model(batch_data)
                loss_val = loss(o, batch_label.float())
                train_logger.add_scalar('train/loss', loss_val.item(), global_step=global_step)

                # Compute the accuracy
                train_accuracy.extend(((o > 0).long() == batch_label).detach().cpu().numpy())

                optimizer.zero_grad()
                loss_val.backward()
                optimizer.step()

                # Increase the global step
                global_step += 1

            # Evaluate the model: no autograd graph is needed, and layers
            # such as dropout/batch-norm should run in inference mode. The
            # previous mode is restored afterwards.
            was_training = model.training
            model.eval()
            with torch.no_grad():
                valid_pred = model(valid_data) > 0
            model.train(was_training)
            valid_accuracy = float((valid_pred.long() == valid_label).float().mean())

            # With fewer samples than ``batch_size`` no step ran; skip the
            # train accuracy rather than logging the NaN of an empty mean.
            if train_accuracy:
                train_logger.add_scalar('train/accuracy', np.mean(train_accuracy), global_step=global_step)
            valid_logger.add_scalar('valid/accuracy', valid_accuracy, global_step=global_step)
    finally:
        # Flush pending events and release the writers' file handles.
        train_logger.close()
        valid_logger.close()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Expected directory layout: