Created April 25, 2019 11:45
Save azkalot1/35b0827d5331c6fa4f509e842ae32651 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import torch.optim as optim | |
from torch.autograd import Variable | |
from torchvision import datasets, transforms | |
from torch.optim import Optimizer | |
from torch.utils import data | |
import pretrainedmodels | |
import numpy as np | |
import os | |
import cv2 | |
from skimage.io import imread | |
from torch.utils.data.sampler import WeightedRandomSampler, BatchSampler | |
from albumentations import ( | |
HorizontalFlip, VerticalFlip, IAAPerspective, ShiftScaleRotate, CLAHE, RandomRotate90, Normalize, RandomGamma, RandomBrightnessContrast, HueSaturationValue, CLAHE, ChannelShuffle, | |
Transpose, ShiftScaleRotate, Blur, OpticalDistortion, GridDistortion, HueSaturationValue, | |
IAAAdditiveGaussianNoise, GaussNoise, MotionBlur, MedianBlur, IAAPiecewiseAffine, | |
IAASharpen, IAAEmboss, RandomContrast, RandomBrightness, Flip, OneOf, Compose, PadIfNeeded, RandomCrop, Resize | |
) | |
import pretrainedmodels.utils as utils | |
from sklearn.metrics import roc_auc_score | |
def train(model, train_loader, optimizer, epoch, log_interval, loss_f,
          samples_per_epoch, device, cycling_optimizer=False, max_grad_norm=4):
    """Train the model for one epoch using the given optimizer and loss function.

    Prints the running loss (averaged since the previous print) every
    ``log_interval`` batches.

    Args:
        model: PyTorch model to train.
        train_loader: Iterable of ``(x, target)`` batches.
        optimizer: PyTorch optimizer. When ``cycling_optimizer`` is True it must
            provide a ``batch_step()`` method, which is called instead of ``step()``.
        epoch: Current epoch (used only in log messages).
        log_interval: Print training progress every ``log_interval`` batches.
        loss_f: Loss function called as ``loss_f(output, target)``; must return a
            scalar tensor.
        samples_per_epoch: Number of samples per epoch; used only to report the
            progress percentage.
        device: PyTorch device the batches are moved to.
        cycling_optimizer: Whether the optimizer is advanced with ``batch_step()``
            instead of ``step()``.
        max_grad_norm: Maximum total gradient norm used for clipping.

    Returns:
        Mean of the per-batch losses over the whole epoch, or NaN if the loader
        yields no batches.
    """
    model.train()
    epoch_losses = []   # every batch loss, used for the epoch mean
    window_losses = []  # losses since the last progress print
    for batch_idx, (x, target) in enumerate(train_loader):
        optimizer.zero_grad()
        output = model(x.to(device, dtype=torch.float))
        loss = loss_f(output, target.to(device, dtype=torch.float))
        loss_value = loss.item()
        epoch_losses.append(loss_value)
        window_losses.append(loss_value)
        loss.backward()
        # clip_grad_norm_ is the in-place replacement for the deprecated clip_grad_norm.
        nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        if cycling_optimizer:
            optimizer.batch_step()
        else:
            optimizer.step()
        if batch_idx % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.3f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(x), samples_per_epoch,
                100. * batch_idx * len(x) / samples_per_epoch, np.mean(window_losses)))
            window_losses = []
    # Average over every batch. A mean of window means over-weights the first
    # (single-batch) window and drops the batches after the last print.
    train_loss_mean = np.mean(epoch_losses) if epoch_losses else float('nan')
    print('Mean train loss on epoch {} : {}'.format(epoch, train_loss_mean))
    return train_loss_mean
def test(model, test_loader, loss_f, epoch, device):
    """Evaluate the model on validation data.

    Args:
        model: PyTorch model to evaluate.
        test_loader: Iterable of ``(x, target)`` batches.
        loss_f: Loss function called as ``loss_f(output, target)``; must return a
            scalar tensor.
        epoch: Current epoch (unused here; kept for symmetry with ``train``).
        device: PyTorch device the batches are moved to.

    Returns:
        Tuple ``(test_loss, score)``: the mean of the per-batch losses and the
        ROC AUC of all predictions against all targets.

    Raises:
        ValueError: If ``test_loader`` yields no batches.
    """
    model.eval()
    predictions = []
    targets = []
    batch_losses = []
    with torch.no_grad():
        for x, target in test_loader:
            output = model(x.to(device, dtype=torch.float))
            batch_losses.append(loss_f(output, target.to(device, dtype=torch.float)).item())
            predictions.append(output.cpu())
            targets.append(target.cpu())
    if not predictions:
        raise ValueError('test_loader yielded no batches')
    # Concatenate along the sample axis. Unlike np.vstack, this keeps 1-D
    # per-batch outputs 1-D and tolerates a smaller final batch.
    predictions = torch.cat(predictions).numpy()
    targets = torch.cat(targets).numpy()
    score = roc_auc_score(targets, predictions)
    test_loss = np.mean(batch_losses)
    print('\nTest set: Average loss: {:.6f}, roc auc: {:.4f}\n'.format(test_loss, score))
    return test_loss, score
class Net(nn.Module):
    """Binary classifier built on the convolutional trunk of a pretrained ResNet.

    The first four children of ``base_model`` and its ``layer1``..``layer4``
    stages are reused as the feature extractor. They are followed by global
    average pooling and a dense head (n_features -> 128 -> 64 -> 1) ending in
    a sigmoid.

    Args:
        base_model: ResNet-style model (e.g. resnet34/resnet50 from
            ``pretrainedmodels``). Its first four children are used as the stem
            (presumably conv1/bn1/relu/maxpool), and it must expose ``layer1``
            through ``layer4``.
        n_features: Number of channels produced by ``layer4``, i.e. the length
            of the pooled feature vector.
    """

    def __init__(self, base_model, n_features):
        super().__init__()
        stem = list(base_model.children())[:4]
        self.layer0 = nn.Sequential(*stem)
        self.layer1 = nn.Sequential(*list(base_model.layer1))
        self.layer2 = nn.Sequential(*list(base_model.layer2))
        self.layer3 = nn.Sequential(*list(base_model.layer3))
        self.layer4 = nn.Sequential(*list(base_model.layer4))
        # NOTE(review): there is no non-linearity between dense1, dense2 and
        # classif, so together they are equivalent to a single linear layer.
        # Left unchanged so existing trained weights produce the same outputs.
        self.dense1 = nn.Sequential(nn.Linear(n_features, 128))
        self.dense2 = nn.Sequential(nn.Linear(128, 64))
        self.classif = nn.Sequential(nn.Linear(64, 1))

    def forward(self, x):
        """Return per-sample probabilities with shape ``(batch, 1)``."""
        x = self.features(x)
        # Global average pooling. For 7x7 feature maps this equals the former
        # F.avg_pool2d(x, 7), but it also works for any other spatial size.
        x = F.adaptive_avg_pool2d(x, 1)
        x = x.view(x.size(0), -1)
        x = self.dense1(x)
        x = self.dense2(x)
        x = self.classif(x)
        return torch.sigmoid(x)

    def features(self, x):
        """Run the stem and the four residual stages of the base model."""
        x = self.layer0(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        return self.layer4(x)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.