Last active
January 24, 2018 14:15
-
-
Save poppingtonic/8cbc1bd8be6a77c6efde53abc62dfeb8 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from fastai.imports import * | |
from fastai.transforms import * | |
from fastai.conv_learner import * | |
from fastai.model import * | |
from fastai.dataset import * | |
from fastai.sgdr import * | |
from fastai.plots import * | |
# Fine-tune a pretrained ResNeXt-50 on a two-class (spider vs. scorpion) image
# dataset using the fastai v0.7 API, then evaluate with test-time augmentation.

sz = 224   # input image size in pixels
arch = resnext50  # architecture, from https://github.com/facebookresearch/ResNeXt
bs = 64    # batch size
PATH = 'data/spiderscorpions/'  # assumes train/ and valid/ subdirs, one folder per class — TODO confirm layout

# Enable data augmentation, and precompute=True.
# transforms_side_on flips the image along the vertical axis;
# max_zoom=1.1 makes images up to 10% larger.
tfms = tfms_from_model(arch, sz, aug_tfms=transforms_side_on, max_zoom=1.1)
# Pass bs explicitly — the original defined bs=64 but never used it
# (64 is also the fastai default, so behavior is unchanged).
data = ImageClassifierData.from_paths(PATH, tfms=tfms, bs=bs)
learn = ConvLearner.pretrained(arch, data, precompute=True)

# Use lr_find() to find the highest learning rate where loss is still clearly improving
learn.lr_find()
# Check the plot to find the learning rate where the loss is still improving
learn.sched.plot()

# Assuming the optimal learning rate is 0.01, train the new final layer for 3 epochs
learn.fit(0.01, 3)

# Train the last layer with data augmentation (i.e. precompute=False),
# using SGDR cycles of one epoch (cycle_len=1)
learn.precompute = False
learn.fit(1e-2, 3, cycle_len=1)

# Unfreeze all layers, thus opening up resnext50's original ImageNet weights for
# fine-tuning on the features in the two spider and scorpion classes
learn.unfreeze()

# fastai groups the layers in all of the pre-packaged pretrained convolutional
# networks into three groups; retrain them with differential learning rates.
# Earlier groups get a 3x/9x lower rate than the head, since early conv features
# are general and need only small adjustments.
lr = 0.01
lrs = np.array([lr / 9, lr / 3, lr])
learn.fit(lrs, 3)

# Use lr_find() again after unfreezing to sanity-check the chosen rate
learn.lr_find()
learn.sched.plot()
learn.fit(1e-2, 3, cycle_len=1, cycle_mult=2)

# Evaluate with test-time augmentation: TTA() returns log-probabilities for
# several augmented versions of each validation image, so average the
# exponentiated probabilities across the augmentation axis before scoring.
log_preds, y = learn.TTA()
preds = np.mean(np.exp(log_preds), 0)
# BUG FIX: the original called accuracy(log_preds, y), which ignored the
# TTA-averaged `preds` and fed raw log-probs (with an extra augmentation axis)
# to the metric. Score the averaged probabilities with the numpy metric instead.
accuracy_np(preds, y)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment