import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler
from torchvision.datasets.folder import ImageFolder, default_loader
from torchvision.datasets.utils import check_integrity
from torchvision import transforms
from torchvision import models
import matplotlib.pyplot as plt
# project-local helper exposing freeze_all / all_frozen / get_trainable
from src.utils_cm import ModelParameters
images_dir = "data/sample"
NUM_EPOCHS = 3
IMG_SIZE = 250
BATCH_SIZE = 10
# ImageNet channel means and stds, the standard normalization values
# used with torchvision's pretrained models
normmean = [0.485, 0.456, 0.406]
normstd = [0.229, 0.224, 0.225]
def fine_tuning_model(model, n_classes=120):
    """Freeze all pretrained weights and attach a new trainable head that
    maps the 1000 ImageNet logits to n_classes (applied in the training loop)."""
    ModelParameters.freeze_all(model.parameters())
    assert ModelParameters.all_frozen(model.parameters())
    model.ft_layer = nn.Linear(1000, n_classes)
    assert model.ft_layer.weight.requires_grad
    return model
# training transforms: random crop, color jitter and horizontal flip for augmentation
train_trans = transforms.Compose([
    transforms.Resize(IMG_SIZE),
    transforms.RandomCrop(224),
    transforms.ColorJitter(.3, .3, .3),
    transforms.RandomHorizontalFlip(p=.3),
    transforms.ToTensor(),
    transforms.Normalize(normmean, normstd)
])
# validation transforms: deterministic, no augmentation; center-crop to the
# same 224px input size and normalize with the same statistics
val_trans = transforms.Compose([
    transforms.Resize(IMG_SIZE),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(normmean, normstd)
])
img_f = ImageFolder(images_dir, transform=train_trans)
n_classes = len(img_f.classes)
ds = DataLoader(img_f, batch_size=BATCH_SIZE, shuffle=True)

# load the pretrained VGG16 and attach the new classification head
VGG16 = models.vgg16(pretrained=True)
VGG16 = fine_tuning_model(VGG16, n_classes=n_classes)
# only the parameters of the new head are trainable
optimizer = torch.optim.Adam(
    ModelParameters.get_trainable(VGG16.parameters()),
    lr=0.001
)
criterion = nn.CrossEntropyLoss()
VGG16.train()
for epoch in range(NUM_EPOCHS):
    print(f"Epoch number {epoch}")
    for ix, (X, y) in enumerate(ds):
        optimizer.zero_grad()
        # run the frozen backbone, then the new trainable head
        preds = VGG16.ft_layer(VGG16(X))
        loss = criterion(preds, y)
        loss.backward()
        optimizer.step()
        print(f"Loss: {loss.item()}")