Created
June 21, 2019 15:24
-
-
Save radekosmulski/54dc35136133cfceb67aded2004d18c2 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# NOTE: import order is kept as in the original; later imports may shadow names
# pulled in by the earlier star imports.
from fastai.vision import *
from fastai.script import *
from torch import nn
from fastai.metrics import top_k_accuracy

# Fetch the CIFAR archive referenced by fastai's URL table; `path` is whatever
# location `untar_data` returns for it.
path = untar_data(URLs.CIFAR)
# Build the data bunch from the folder layout under `path`, with the 'test'
# folder passed as the validation split.
data = ImageDataBunch.from_folder(path, valid='test')
class block(nn.Module):
    """A single layer followed by ReLU and batch normalisation.

    With ``two_d=True`` the layer is a 3x3 ``nn.Conv2d`` (no padding) paired
    with ``nn.BatchNorm2d``; otherwise it is an ``nn.Linear`` paired with
    ``nn.BatchNorm1d``.
    """

    def __init__(self, n_in, n_out, two_d=True):
        """Create the layer and its matching batch-norm.

        Args:
            n_in: number of input channels (conv) or input features (linear).
            n_out: number of output channels (conv) or output features (linear).
            two_d: choose the convolutional variant when true, the fully
                connected variant when false.
        """
        super().__init__()
        self.op = nn.Conv2d(n_in, n_out, 3) if two_d else nn.Linear(n_in, n_out)
        self.bn = nn.BatchNorm2d(n_out) if two_d else nn.BatchNorm1d(n_out)

    def forward(self, x):
        """Apply the layer, then ReLU, then batch-norm, and return the result."""
        x = self.op(x)
        # The activation is applied *before* normalisation here.
        x = F.relu(x)
        x = self.bn(x)
        return x
# Small conv net: two conv stages (each two 3x3 conv blocks + 2x2 max-pool),
# then two fully connected blocks and a 10-way linear head.
# NOTE(review): the flatten size of 800 assumes 32x32 inputs (presumably CIFAR):
# 32 -> 30 -> 28 -> pool 14 -> 12 -> 10 -> pool 5, and 32 channels * 5 * 5 = 800.
arch = SequentialEx(
    block(3, 32),
    block(32, 32),
    nn.MaxPool2d(2),
    block(32, 32),
    block(32, 32),
    nn.MaxPool2d(2),
    Flatten(),
    block(800, 800, False),
    block(800, 800, False),
    nn.Linear(800, 10)
)
def top_3_accuracy(preds, targs):
    """Return ``top_k_accuracy`` of `preds` against `targs` with k fixed to 3."""
    return top_k_accuracy(preds, targs, 3)
# Learner over the data bunch and the custom architecture. The metric order
# (accuracy, then top-3 accuracy) matters: `train` unpacks `learn.validate()`
# as (loss, top_1, top_3).
learn = Learner(data, arch, metrics=[accuracy, top_3_accuracy])
@call_parse
def train(
    epochs: Param("Number of epochs to train", int)=1,
    max_lr: Param("Maximum lr for one cycle", float)=1e-3
):
    """Run the LR finder, train with one-cycle, validate and save the model.

    The saved model name encodes the number of epochs, the maximum learning
    rate, and the validation loss / top-1 / top-3 accuracy (two decimals each).
    """
    # Learning-rate range test and its plot, run before the actual training.
    learn.lr_find()
    learn.recorder.plot()
    learn.fit_one_cycle(epochs, max_lr)
    learn.recorder.plot_losses()
    # validate() is unpacked as loss followed by the learner's two metrics.
    loss, top_1, top_3 = learn.validate()
    learn.save(f'{epochs}_{max_lr}_{loss:.2f}_{top_1:.2f}_{top_3:.2f}')
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment