# coding: utf-8
# FastAICustomModelExample
# GitHub gist drscotthawley/2288e92f23b02e7fd5352708fc6cd125, created November 20, 2018 06:57.
#
# A mixture of [@eslavich's post](https://forums.fast.ai/t/learner-layer-groups-parameter/30212)
# and the Lesson 5 notebook lesson5-sgd-mnist.ipynb.
# In[ ]:
# Notebook-only setup: enable autoreload and inline matplotlib output.
# get_ipython() is injected as a builtin only when running under IPython /
# Jupyter; when this exported script is run with plain `python` the name does
# not exist, so skip the magics instead of crashing with NameError.
try:
    ipython_shell = get_ipython()
except NameError:
    ipython_shell = None
if ipython_shell is not None:
    ipython_shell.run_line_magic('reload_ext', 'autoreload')
    ipython_shell.run_line_magic('autoreload', '2')
    ipython_shell.run_line_magic('matplotlib', 'inline')
# In[2]: | |
import numpy as np | |
import torch | |
import torch.nn.functional as F | |
import torch.nn as nn | |
from torch.utils.data import DataLoader | |
from torch.utils.data.dataset import TensorDataset | |
from fastai import * | |
from fastai.vision import * | |
# In[21]:
class SimpleModel(nn.Module):
    """Two stacked fully-connected layers with no activation between them.

    Nothing non-linear sits between ``linear1`` and ``linear2``, so the network
    as a whole computes an affine map of its input. That is sufficient for the
    linear target (y = 2x) this example trains on.

    Args:
        n_in: number of input features (size of the last dimension of ``x``).
        n_hidden: width of the intermediate layer.
        n_out: number of output features.

    The defaults (1, 5, 1) reproduce the original fixed architecture, and the
    attribute names ``linear1`` / ``linear2`` are kept so ``state_dict`` keys
    stay the same.
    """

    def __init__(self, n_in=1, n_hidden=5, n_out=1):
        super().__init__()
        self.linear1 = nn.Linear(n_in, n_hidden)
        self.linear2 = nn.Linear(n_hidden, n_out)

    def forward(self, x):
        """Map ``x`` of shape ``(..., n_in)`` to a tensor of shape ``(..., n_out)``."""
        x = self.linear1(x)
        x = self.linear2(x)
        return x
def generate_data(size, slope=2.0, rng=None):
    """Sample a noise-free linear regression dataset ``y = slope * x``.

    Args:
        size: number of samples to draw.
        slope: multiplier applied to every input. The default 2.0 gives the
            original ``y = 2x`` target.
        rng: optional ``numpy.random.Generator`` used for reproducible draws.
            When omitted, NumPy's global random state is used, so
            ``np.random.seed`` still controls the output as before.

    Returns:
        Tuple ``(x, y)`` of float32 arrays, each of shape ``(size, 1)``, with
        ``x`` drawn uniformly from [0, 1).
    """
    if rng is None:
        x = np.random.uniform(size=(size, 1))
    else:
        x = rng.uniform(size=(size, 1))
    # Compute the target in float64 first, then cast both arrays once.
    y = x * slope
    return x.astype(np.float32), y.astype(np.float32)
# Build train/validation sets for the y = 2x regression problem.
x_train, y_train = generate_data(10000)
x_valid, y_valid = generate_data(1000)
# Convert the float32 NumPy arrays into torch tensors.
x_train, y_train, x_valid, y_valid = map(
    torch.tensor, (x_train, y_train, x_valid, y_valid))
# Notebook cell output: quick sanity check of shape and target range.
x_train.shape, y_train.min(), y_train.max()

# In[22]:
bs = 50
train_ds = TensorDataset(x_train, y_train)
valid_ds = TensorDataset(x_valid, y_valid)
data = DataBunch.create(train_ds, valid_ds, bs=bs)

# In[23]:
# Pull one training batch to inspect its shapes.
x, y = next(iter(data.train_dl))
x.shape, y.shape

# In[24]:
# Put the model on whatever device the batch lives on (CPU or GPU) instead of
# hard-coding .cuda(), which fails on machines without a CUDA device.
model = SimpleModel().to(x.device)

# In[25]:
# Forward-pass shape check on the batch above.
model(x).shape

# In[26]:
loss_func = nn.MSELoss()
# The Learner receives its own freshly constructed SimpleModel; `model` above
# was only used for the forward-pass shape check.
learn = Learner(data, SimpleModel(), loss_func=loss_func)

# In[27]:
learn.lr_find()
learn.recorder.plot()

# In[28]:
learn.fit_one_cycle(1, 1e-1)

# In[29]:
learn.recorder.plot_lr(show_moms=True)

# In[30]:
learn.recorder.plot_losses()

# In[ ]:
# Discussion of this example lives on the original GitHub gist; sign in to GitHub to comment there.