Last active
August 25, 2019 17:12
-
-
Save devforfu/9acebd780215efe43b8b5d69ba0f3f9c to your computer and use it in GitHub Desktop.
Catalyst example. (Doesn't work as expected).
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import re | |
from pdb import set_trace | |
from multiprocessing import cpu_count | |
from pprint import pprint as pp | |
from imageio import imread | |
import numpy as np | |
import pandas as pd | |
import PIL.Image | |
import torch | |
import torch.nn as nn | |
from torch.utils.data import Dataset, DataLoader | |
import torchvision | |
import torchvision.transforms as T | |
from catalyst.contrib.schedulers import OneCycleLR | |
from catalyst.dl.callbacks import AccuracyCallback, AUCCallback, F1ScoreCallback | |
from catalyst.dl.runner import SupervisedRunner | |
import pretrainedmodels | |
from jupytools import auto_set_trace | |
# Rebind `set_trace` to the helper returned by jupytools; this deliberately
# shadows the `set_trace` imported from `pdb` above.
set_trace = auto_set_trace()

# Restrict CUDA to the second GPU (index 1) unless the caller has already
# chosen devices through the environment. `setdefault` keeps an externally
# supplied CUDA_VISIBLE_DEVICES intact instead of silently overriding it.
os.environ.setdefault('CUDA_VISIBLE_DEVICES', '1')
def list_files(folder):
    """Return the paths of all entries directly inside *folder*, sorted by name.

    A leading ``~`` in *folder* is expanded to the user's home directory.
    ``os.listdir`` yields entries in arbitrary, filesystem-dependent order, so
    the names are sorted to make the resulting list (and any dataset built
    from it) reproducible between runs and machines.

    Args:
        folder: Directory to list; may start with ``~``.

    Returns:
        A list of paths, each one the expanded directory joined with an entry
        name. The listing is not recursive and is not filtered by file type.

    Raises:
        OSError: Propagated from ``os.listdir`` (e.g. the directory is missing).
    """
    dirname = os.path.expanduser(folder)
    return [os.path.join(dirname, name) for name in sorted(os.listdir(dirname))]
def extract_labels(files):
    """Parse the integer class label out of each file name.

    Only the base name of each path is inspected. It must end in
    ``_<digits>.png``; the digits are returned as an ``int``.
    For example ``'/data/cat_3.png'`` yields ``3``.

    Args:
        files: Iterable of file paths.

    Returns:
        A list of integer labels, in the same order as *files*.

    Raises:
        ValueError: If a file name does not match the ``*_<digits>.png``
            pattern. The offending path is included in the message.
    """
    # Raw string: avoids the invalid-escape warning that '\d' triggers in a
    # normal string literal while keeping the exact same pattern.
    pattern = re.compile(r'.*_(\d+)\.png$')
    labels = []
    for fn in files:
        found = pattern.match(os.path.basename(fn))
        if found is None:
            raise ValueError(
                'cannot extract label from file name: {!r}'.format(fn))
        labels.append(int(found.group(1)))
    return labels
class ImageDataset(Dataset):
    """Map-style dataset over image files whose label is encoded in the name.

    Labels are obtained once, at construction time, by passing the file list
    to ``extract_labels``. Images themselves are opened lazily, one per
    ``__getitem__`` call.
    """

    def __init__(self, files, train=True, tr=None):
        """Store the file list, the optional transform, and the parsed labels.

        Args:
            files: Sequence of image file paths.
            train: Not used by this class; kept so existing callers that pass
                it keep working.
            tr: Optional callable applied to each opened image. When ``None``
                the opened ``PIL`` image is returned unchanged.
        """
        super().__init__()
        self.files = files
        self.tr = tr
        self.labels = extract_labels(files)

    @property
    def n_classes(self):
        """Number of distinct label values present in this dataset."""
        return len(np.unique(self.labels))

    def __len__(self):
        """Number of files (and therefore samples) in the dataset."""
        return len(self.files)

    def __getitem__(self, index):
        """Return ``(x, y)``: the (optionally transformed) image and its label."""
        x = PIL.Image.open(self.files[index])
        if self.tr is not None:
            x = self.tr(x)
        y = self.labels[index]
        return x, y
def get_model(model_name, num_classes, pretrained='imagenet'):
    """Build a ``pretrainedmodels`` network with a replaced classification head.

    Args:
        model_name: Name of a constructor exposed in the ``pretrainedmodels``
            package namespace (e.g. ``'resnet50'``).
        num_classes: Output size of the new final linear layer.
        pretrained: Forwarded unchanged to the model constructor.

    Returns:
        The constructed model, whose ``last_linear`` attribute is a fresh
        ``nn.Linear`` with ``num_classes`` outputs.

    Raises:
        KeyError: If *model_name* is not present in the package namespace.
    """
    # Look the constructor up by its string name in the package namespace.
    constructor = vars(pretrainedmodels)[model_name]
    # The backbone is always created with 1000 classes (presumably to match
    # the published pretrained weights - confirm against pretrainedmodels);
    # the head is swapped out right after.
    net = constructor(num_classes=1000, pretrained=pretrained)
    net.last_linear = nn.Linear(net.last_linear.in_features, num_classes)
    return net
# ---------------------------------------------------------------------------
# Experiment: fine-tune a pretrained ResNet-50 with Catalyst's SupervisedRunner.
# ---------------------------------------------------------------------------
model_name = 'resnet50'

# Settings published alongside the pretrained weights; the 'mean' and 'std'
# entries are used for input normalisation below.
params = pretrainedmodels.pretrained_settings[model_name]['imagenet']
pp(params)

# Same preprocessing for both splits: resize, convert to tensor, normalise.
data_tr = T.Compose([
    T.Resize((224, 224)),
    T.ToTensor(),
    T.Normalize(params['mean'], params['std']),
])

bs = 128
num_epochs = 1

trn_files = list_files('~/data/tmp/train')
trn_ds = ImageDataset(trn_files, tr=data_tr)
# Shuffle the training split: without it every epoch iterates the files in
# directory-listing order, which can feed the optimizer long runs of a single
# class and stall training.
trn_dl = DataLoader(
    trn_ds, batch_size=bs, shuffle=True, num_workers=cpu_count())

tst_files = list_files('~/data/tmp/test')
tst_ds = ImageDataset(tst_files, tr=data_tr)
# Validation order does not matter, so this loader stays unshuffled.
tst_dl = DataLoader(tst_ds, batch_size=bs, num_workers=cpu_count())

from collections import OrderedDict
# The runner receives the loaders as an ordered mapping: train first, then valid.
loaders = OrderedDict()
loaders['train'] = trn_dl
loaders['valid'] = tst_dl

resnet = get_model(model_name, trn_ds.n_classes)

# Freeze the whole network, then unfreeze only the last residual stage and the
# new classification head. All parameters of the head are unfrozen (weight
# AND bias); enabling only the weight would leave the freshly initialised
# bias untrained.
for param in resnet.parameters():
    param.requires_grad = False
for module in (resnet.layer4, resnet.last_linear):
    for param in module.parameters():
        param.requires_grad = True

loss_fn = nn.CrossEntropyLoss()
# Hand the optimizer only the parameters that are actually being trained.
trainable_params = [p for p in resnet.parameters() if p.requires_grad]
opt = torch.optim.Adam(trainable_params, lr=0.0001)

logdir = '/tmp/logs/'
runner = SupervisedRunner()
runner.train(
    model=resnet,
    criterion=loss_fn,
    optimizer=opt,
    loaders=loaders,
    logdir=logdir,
    num_epochs=num_epochs,
    callbacks=[
        AccuracyCallback(),
        AUCCallback(),
        F1ScoreCallback(activation='Softmax'),
    ],
    verbose=True,
)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment