Last active
September 10, 2018 14:51
-
-
Save sizhky/692988358321fcb0a791a2452daef2b5 to your computer and use it in GitHub Desktop.
ocr/Untitled.ipynb
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2018-09-09T13:53:32.981796Z", | |
"start_time": "2018-09-09T13:53:32.957016Z" | |
} | |
}, | |
"outputs": [], | |
"source": [ | |
"%load_ext autoreload\n", | |
"%autoreload 2" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2018-09-09T13:53:35.424506Z", | |
"start_time": "2018-09-09T13:53:33.562923Z" | |
} | |
}, | |
"outputs": [], | |
"source": [ | |
"from fastai import *\n", | |
"from fastai.imports import *\n", | |
"from fastai.conv_learner import *\n", | |
"from fastai.plots import *" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2018-09-09T13:53:35.456575Z", | |
"start_time": "2018-09-09T13:53:35.427427Z" | |
} | |
}, | |
"outputs": [], | |
"source": [ | |
"DATA = 'iam-data/'" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2018-09-09T13:53:35.724285Z", | |
"start_time": "2018-09-09T13:53:35.461113Z" | |
} | |
}, | |
"outputs": [], | |
"source": [ | |
"ls {DATA}" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"### Data Pre-processing" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2018-09-09T23:24:03.004592Z", | |
"start_time": "2018-09-09T23:24:02.782613Z" | |
} | |
}, | |
"outputs": [], | |
"source": [ | |
"# truth = open(f'{DATA}words.txt').readlines()[18:] # skip header\n", | |
"# truth = {l.split()[0]: l.split()[-1] for l in truth}\n", | |
"# print(len(truth))\n", | |
"\n", | |
"# def fname(fpath): return fpath.split('/')[-1][:-4]\n", | |
"# gt = lambda x: truth[x]\n", | |
"\n", | |
"# folder = f'{DATA}words/'\n", | |
"# fpaths = glob(f'{folder}/*')\n", | |
"\n", | |
"# sample = np.random.choice(fpaths, size=30)\n", | |
"# sample_fnames = list(map(fname, sample))\n", | |
"# sample_targets = list(map(gt, sample_fnames))\n", | |
"# plots_from_files(sample, figsize=(30,20), rows=10, titles=sample_targets)\n", | |
"\n", | |
"# vocab = set()\n", | |
"# for i in truth.values():\n", | |
"# vocab = vocab.union(set(i))\n", | |
"# vocab = ''.join(sorted(list(vocab)))\n", | |
"\n", | |
"# with open(f'{DATA}vocab.txt', 'w') as f: f.write(vocab)\n", | |
"\n", | |
"# with open(f'{DATA}vocab.txt', 'r') as f: vocab = f.read()\n", | |
"# print('Vocab:', vocab)\n", | |
"\n", | |
"# for k in truth:\n", | |
"# v = truth[k]\n", | |
"# v = ' '.join([str(vocab.index(i)) for i in v])\n", | |
"# truth[k] = v\n", | |
"\n", | |
"# def fname(fpath): return fpath.split('/')[-1][:-4]\n", | |
"# gt = lambda x: truth[x]\n", | |
"\n", | |
"# folder = f'{DATA}words/'\n", | |
"# fpaths = glob(f'{folder}/*')\n", | |
"\n", | |
"# sample = np.random.choice(fpaths, size=5)\n", | |
"# sample_fnames = list(map(fname, sample))\n", | |
"# sample_targets = list(map(gt, sample_fnames))\n", | |
"# plots_from_files(sample, figsize=(10,5), rows=5, titles=sample_targets)\n", | |
"\n", | |
"# fnames = sorted(list(map(fname, fpaths)))\n", | |
"\n", | |
"# csv = pd.DataFrame({\n", | |
"# 'fname': fnames,\n", | |
"# 'truth': list(truth[fname] for fname in fnames)\n", | |
"# })\n", | |
"# print(csv.shape)\n", | |
"# csv.to_csv(f'{DATA}truth.csv', index=None)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"### Data Loading" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2018-09-09T23:24:43.706002Z", | |
"start_time": "2018-09-09T23:24:43.567533Z" | |
} | |
}, | |
"outputs": [], | |
"source": [ | |
"with open(f'{DATA}vocab.txt', 'r') as f: vocab = f.read()\n", | |
"labels = pd.read_csv(f'{DATA}truth.csv')\n", | |
"print(labels.shape)\n", | |
"labels.head(3)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2018-09-09T23:24:48.710204Z", | |
"start_time": "2018-09-09T23:24:48.654481Z" | |
} | |
}, | |
"outputs": [], | |
"source": [ | |
"np.random.seed(10)\n", | |
"val_idxs = np.random.choice(range(len(labels)), size=int(0.3*len(labels)), replace=False)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"ls {DATA}" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2018-09-09T23:43:16.345883Z", | |
"start_time": "2018-09-09T23:43:14.799129Z" | |
} | |
}, | |
"outputs": [], | |
"source": [ | |
"stats = (np.array([0]*3), np.array([255.0]*3))\n", | |
"tfm = tfms_from_stats(stats, sz=64, crop_type=CropType.NO)\n", | |
"\n", | |
"data = ImageClassifierData.from_csv(f'{DATA}', f'words/', csv_fname=f'{DATA}truth.csv',\n", | |
" suffix='.png', tfms=tfm, val_idxs=val_idxs)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2018-09-09T23:43:44.086916Z", | |
"start_time": "2018-09-09T23:43:44.040857Z" | |
} | |
}, | |
"outputs": [], | |
"source": [ | |
"class HWDataSet:\n", | |
" def __init__(self, ds, df):\n", | |
" self.ds = ds\n", | |
" self.df = df\n", | |
" self.sz = ds.sz\n", | |
" def __len__(self): return len(self.ds)\n", | |
" def __getitem__(self, idx):\n", | |
" x, y = self.ds[idx]\n", | |
" return x, self.df.truth[idx]\n", | |
" \n", | |
"(lab_val, lab_tr), = split_by_idx(val_idxs, labels)\n", | |
"data.trn_dl.dataset = HWDataSet(data.trn_ds, lab_tr.reset_index(drop=True))\n", | |
"data.val_dl.dataset = HWDataSet(data.val_ds, lab_val.reset_index(drop=True))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2018-09-09T23:43:44.730968Z", | |
"start_time": "2018-09-09T23:43:44.674315Z" | |
} | |
}, | |
"outputs": [], | |
"source": [ | |
"def show(img, tr):\n", | |
" img = img.transpose(1,2,0)\n", | |
" plt.imshow(img*255)\n", | |
" title = ''.join([str(vocab[int(i)]) for i in tr.split()])\n", | |
" plt.title(title)\n", | |
" \n", | |
"idx = 342\n", | |
"show(*data.trn_ds[idx])" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"heading_collapsed": true | |
}, | |
"source": [ | |
"#### Arch" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2018-09-09T23:28:57.374237Z", | |
"start_time": "2018-09-09T23:28:57.326334Z" | |
}, | |
"hidden": true | |
}, | |
"outputs": [], | |
"source": [ | |
"def conv_layer(ni, nf, ks=3, stride=1):\n", | |
" return nn.Sequential(\n", | |
" nn.Conv2d(ni, nf, kernel_size=ks, bias=False, stride=stride, padding=ks//2),\n", | |
" nn.BatchNorm2d(nf, momentum=0.01),\n", | |
" nn.LeakyReLU(negative_slope=0.1, inplace=True))\n", | |
"\n", | |
"class ResLayer(nn.Module):\n", | |
" def __init__(self, ni):\n", | |
" super().__init__()\n", | |
" self.conv1=conv_layer(ni, ni//2, ks=1)\n", | |
" self.conv2=conv_layer(ni//2, ni, ks=3)\n", | |
" \n", | |
" def forward(self, x): return x.add_(self.conv2(self.conv1(x)))\n", | |
"\n", | |
"class ConvHead(nn.Module):\n", | |
" def make_group_layer(self, ch_in, num_blocks, stride=1):\n", | |
" return [conv_layer(ch_in, ch_in*2,stride=stride)\n", | |
" ] + [(ResLayer(ch_in*2)) for i in range(num_blocks)]\n", | |
"\n", | |
" def __init__(self, num_blocks, nf=32):\n", | |
" super().__init__()\n", | |
" layers = [conv_layer(3, nf, ks=3, stride=1)]\n", | |
" for i,nb in enumerate(num_blocks):\n", | |
" layers += self.make_group_layer(nf, nb, stride=2-(i==1))\n", | |
" nf *= 2\n", | |
" layers += [nn.AdaptiveAvgPool2d((1, 32))]\n", | |
" self.layers = nn.Sequential(*layers)\n", | |
" \n", | |
" def forward(self, x): return self.layers(x).transpose(1, 3).squeeze(2)\n", | |
"\n", | |
" \n", | |
"class RNNHead(nn.Module):\n", | |
" def __init__(self, nh, c):\n", | |
" '''\n", | |
" nh: number of hiddens\n", | |
" c : number of output classes (vocab size)\n", | |
" '''\n", | |
" super().__init__()\n", | |
" self.lstm = nn.LSTM(128, nh, 2, batch_first=True, bidirectional=True)\n", | |
" self.lin = nn.Linear(nh*2, c+1)\n", | |
" \n", | |
" def forward(self, x):\n", | |
" # import pdb; pdb.set_trace()\n", | |
" x, _ = self.lstm(x)\n", | |
" x = self.lin(x)\n", | |
" return x\n", | |
" \n", | |
" \n", | |
" \n", | |
"m = nn.Sequential(\n", | |
" ConvHead([1], nf=64),\n", | |
" RNNHead(25, 80)\n", | |
")\n", | |
"\n", | |
"m1 = ConvHead([1], nf=64)\n", | |
"m2 = RNNHead(25, 80)\n", | |
"\n", | |
"m = nn.Sequential(m1, m2)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2018-09-09T05:03:42.890764Z", | |
"start_time": "2018-09-09T05:03:42.885745Z" | |
} | |
}, | |
"source": [ | |
"#### Learner" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2018-09-09T23:40:38.929372Z", | |
"start_time": "2018-09-09T23:40:38.895627Z" | |
} | |
}, | |
"outputs": [], | |
"source": [ | |
"learner = Learner.from_model_data(m, data)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2018-09-09T23:40:39.907996Z", | |
"start_time": "2018-09-09T23:40:39.519598Z" | |
} | |
}, | |
"outputs": [], | |
"source": [ | |
"learner.summary()" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"%debug" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"abc = m1(torch.tensor(data.trn_ds[:10][0]))\n", | |
"print(abc.shape)\n", | |
"x = m2(abc)\n", | |
"print(x.shape)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"a = nn.LSTM(128, 256, 2, batch_first=True, bidirectional=True)\n", | |
"c = nn.Linear(512, 80)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"gist": { | |
"data": { | |
"description": "ocr/Untitled.ipynb", | |
"public": true | |
}, | |
"id": "" | |
}, | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.6.5" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment