{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"ExecuteTime": {
"end_time": "2018-09-09T13:53:32.981796Z",
"start_time": "2018-09-09T13:53:32.957016Z"
}
},
"outputs": [],
"source": [
"%load_ext autoreload\n",
"%autoreload 2"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"ExecuteTime": {
"end_time": "2018-09-09T13:53:35.424506Z",
"start_time": "2018-09-09T13:53:33.562923Z"
}
},
"outputs": [],
"source": [
"from fastai import *\n",
"from fastai.imports import *\n",
"from fastai.conv_learner import *\n",
"from fastai.plots import *"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"ExecuteTime": {
"end_time": "2018-09-09T13:53:35.456575Z",
"start_time": "2018-09-09T13:53:35.427427Z"
}
},
"outputs": [],
"source": [
"DATA = 'iam-data/'"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"ExecuteTime": {
"end_time": "2018-09-09T13:53:35.724285Z",
"start_time": "2018-09-09T13:53:35.461113Z"
}
},
"outputs": [],
"source": [
"ls {DATA}"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Data Pre-processing"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"ExecuteTime": {
"end_time": "2018-09-09T23:24:03.004592Z",
"start_time": "2018-09-09T23:24:02.782613Z"
}
},
"outputs": [],
"source": [
"# truth = open(f'{DATA}words.txt').readlines()[18:] # skip header\n",
"# truth = {l.split()[0]: l.split()[-1] for l in truth}\n",
"# print(len(truth))\n",
"\n",
"# def fname(fpath): return fpath.split('/')[-1][:-4]\n",
"# gt = lambda x: truth[x]\n",
"\n",
"# folder = f'{DATA}words/'\n",
"# fpaths = glob(f'{folder}/*')\n",
"\n",
"# sample = np.random.choice(fpaths, size=30)\n",
"# sample_fnames = list(map(fname, sample))\n",
"# sample_targets = list(map(gt, sample_fnames))\n",
"# plots_from_files(sample, figsize=(30,20), rows=10, titles=sample_targets)\n",
"\n",
"# vocab = set()\n",
"# for i in truth.values():\n",
"# vocab = vocab.union(set(i))\n",
"# vocab = ''.join(sorted(list(vocab)))\n",
"\n",
"# with open(f'{DATA}vocab.txt', 'w') as f: f.write(vocab)\n",
"\n",
"# with open(f'{DATA}vocab.txt', 'r') as f: vocab = f.read()\n",
"# print('Vocab:', vocab)\n",
"\n",
"# for k in truth:\n",
"# v = truth[k]\n",
"# v = ' '.join([str(vocab.index(i)) for i in v])\n",
"# truth[k] = v\n",
"\n",
"# def fname(fpath): return fpath.split('/')[-1][:-4]\n",
"# gt = lambda x: truth[x]\n",
"\n",
"# folder = f'{DATA}words/'\n",
"# fpaths = glob(f'{folder}/*')\n",
"\n",
"# sample = np.random.choice(fpaths, size=5)\n",
"# sample_fnames = list(map(fname, sample))\n",
"# sample_targets = list(map(gt, sample_fnames))\n",
"# plots_from_files(sample, figsize=(10,5), rows=5, titles=sample_targets)\n",
"\n",
"# fnames = sorted(list(map(fname, fpaths)))\n",
"\n",
"# csv = pd.DataFrame({\n",
"# 'fname': fnames,\n",
"# 'truth': list(truth[fname] for fname in fnames)\n",
"# })\n",
"# print(csv.shape)\n",
"# csv.to_csv(f'{DATA}truth.csv', index=None)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Data Loading"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"ExecuteTime": {
"end_time": "2018-09-09T23:24:43.706002Z",
"start_time": "2018-09-09T23:24:43.567533Z"
}
},
"outputs": [],
"source": [
"with open(f'{DATA}vocab.txt', 'r') as f: vocab = f.read()\n",
"labels = pd.read_csv(f'{DATA}truth.csv')\n",
"print(labels.shape)\n",
"labels.head(3)"
]
},
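{
"cell_type": "markdown",
"metadata": {},
"source": [
"Each `truth` entry is a space-separated string of `vocab` indices (see the preprocessing cell above). A minimal sketch to invert that encoding; the `decode` helper is illustrative, not part of the original notebook:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# hypothetical helper: map an encoded target back to the original word,\n",
"# assuming the ' '.join(str(vocab.index(ch))) encoding used in preprocessing\n",
"def decode(enc): return ''.join(vocab[int(i)] for i in enc.split())\n",
"\n",
"decode(labels.truth[0])  # should reproduce the first transcription"
]
},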
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"ExecuteTime": {
"end_time": "2018-09-09T23:24:48.710204Z",
"start_time": "2018-09-09T23:24:48.654481Z"
}
},
"outputs": [],
"source": [
"np.random.seed(10)\n",
"val_idxs = np.random.choice(range(len(labels)), size=int(0.3*len(labels)), replace=False)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"ExecuteTime": {
"end_time": "2018-09-09T23:43:16.345883Z",
"start_time": "2018-09-09T23:43:14.799129Z"
}
},
"outputs": [],
"source": [
"stats = (np.array([0]*3), np.array([255.0]*3))\n",
"tfm = tfms_from_stats(stats, sz=64, crop_type=CropType.NO)\n",
"\n",
"data = ImageClassifierData.from_csv(f'{DATA}', f'words/', csv_fname=f'{DATA}truth.csv',\n",
" suffix='.png', tfms=tfm, val_idxs=val_idxs)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"ExecuteTime": {
"end_time": "2018-09-09T23:43:44.086916Z",
"start_time": "2018-09-09T23:43:44.040857Z"
}
},
"outputs": [],
"source": [
"class HWDataSet:\n",
" def __init__(self, ds, df):\n",
" self.ds = ds\n",
" self.df = df\n",
" self.sz = ds.sz\n",
" def __len__(self): return len(self.ds)\n",
" def __getitem__(self, idx):\n",
" x, y = self.ds[idx]\n",
" return x, self.df.truth[idx]\n",
" \n",
"(lab_val, lab_tr), = split_by_idx(val_idxs, labels)\n",
"data.trn_dl.dataset = HWDataSet(data.trn_ds, lab_tr.reset_index(drop=True))\n",
"data.val_dl.dataset = HWDataSet(data.val_ds, lab_val.reset_index(drop=True))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"ExecuteTime": {
"end_time": "2018-09-09T23:43:44.730968Z",
"start_time": "2018-09-09T23:43:44.674315Z"
}
},
"outputs": [],
"source": [
"def show(img, tr):\n",
" img = img.transpose(1,2,0)\n",
" plt.imshow(img*255)\n",
" title = ''.join([str(vocab[int(i)]) for i in tr.split()])\n",
" plt.title(title)\n",
" \n",
"idx = 342\n",
"show(*data.trn_ds[idx])"
]
},
{
"cell_type": "markdown",
"metadata": {
"heading_collapsed": true
},
"source": [
"#### Arch"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"ExecuteTime": {
"end_time": "2018-09-09T23:28:57.374237Z",
"start_time": "2018-09-09T23:28:57.326334Z"
},
"hidden": true
},
"outputs": [],
"source": [
"def conv_layer(ni, nf, ks=3, stride=1):\n",
" return nn.Sequential(\n",
" nn.Conv2d(ni, nf, kernel_size=ks, bias=False, stride=stride, padding=ks//2),\n",
" nn.BatchNorm2d(nf, momentum=0.01),\n",
" nn.LeakyReLU(negative_slope=0.1, inplace=True))\n",
"\n",
"class ResLayer(nn.Module):\n",
" def __init__(self, ni):\n",
" super().__init__()\n",
" self.conv1=conv_layer(ni, ni//2, ks=1)\n",
" self.conv2=conv_layer(ni//2, ni, ks=3)\n",
" \n",
" def forward(self, x): return x.add_(self.conv2(self.conv1(x)))\n",
"\n",
"class ConvHead(nn.Module):\n",
" def make_group_layer(self, ch_in, num_blocks, stride=1):\n",
" return [conv_layer(ch_in, ch_in*2,stride=stride)\n",
" ] + [(ResLayer(ch_in*2)) for i in range(num_blocks)]\n",
"\n",
" def __init__(self, num_blocks, nf=32):\n",
" super().__init__()\n",
" layers = [conv_layer(3, nf, ks=3, stride=1)]\n",
" for i,nb in enumerate(num_blocks):\n",
" layers += self.make_group_layer(nf, nb, stride=2-(i==1))\n",
" nf *= 2\n",
" layers += [nn.AdaptiveAvgPool2d((1, 32))]\n",
" self.layers = nn.Sequential(*layers)\n",
" \n",
" def forward(self, x): return self.layers(x).transpose(1, 3).squeeze(2)\n",
"\n",
" \n",
"class RNNHead(nn.Module):\n",
" def __init__(self, nh, c):\n",
" '''\n",
" nh: number of hiddens\n",
" c : number of output classes (vocab size)\n",
" '''\n",
" super().__init__()\n",
" self.lstm = nn.LSTM(128, nh, 2, batch_first=True, bidirectional=True)\n",
" self.lin = nn.Linear(nh*2, c+1)\n",
" \n",
" def forward(self, x):\n",
" # import pdb; pdb.set_trace()\n",
" x, _ = self.lstm(x)\n",
" x = self.lin(x)\n",
" return x\n",
" \n",
" \n",
" \n",
"m = nn.Sequential(\n",
" ConvHead([1], nf=64),\n",
" RNNHead(25, 80)\n",
")\n",
"\n",
"m1 = ConvHead([1], nf=64)\n",
"m2 = RNNHead(25, 80)\n",
"\n",
"m = nn.Sequential(m1, m2)"
]
},
{
"cell_type": "markdown",
"metadata": {
"ExecuteTime": {
"end_time": "2018-09-09T05:03:42.890764Z",
"start_time": "2018-09-09T05:03:42.885745Z"
}
},
"source": [
"#### Learner"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"ExecuteTime": {
"end_time": "2018-09-09T23:40:38.929372Z",
"start_time": "2018-09-09T23:40:38.895627Z"
}
},
"outputs": [],
"source": [
"learner = Learner.from_model_data(m, data)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"ExecuteTime": {
"end_time": "2018-09-09T23:40:39.907996Z",
"start_time": "2018-09-09T23:40:39.519598Z"
}
},
"outputs": [],
"source": [
"learner.summary()"
]
},
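{
"cell_type": "markdown",
"metadata": {},
"source": [
"The head emits `c+1 = 81` scores per time step, the usual setup for CTC training. The original notebook stops before defining a loss; below is a hedged sketch of wiring up `nn.CTCLoss` (assumes PyTorch >= 1.1 and that the extra class, index 80, serves as the blank):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch.nn.functional as F\n",
"\n",
"# assumption: the 81st output (index 80) acts as the CTC blank\n",
"ctc = nn.CTCLoss(blank=80, zero_infinity=True)\n",
"\n",
"def ctc_loss(preds, target_strs):\n",
"    # preds: (bs, T=32, 81) raw scores from RNNHead\n",
"    log_probs = F.log_softmax(preds, dim=2).permute(1, 0, 2)  # CTCLoss expects (T, bs, C)\n",
"    targets = [torch.tensor([int(i) for i in s.split()]) for s in target_strs]\n",
"    input_lengths = torch.full((preds.size(0),), preds.size(1), dtype=torch.long)\n",
"    target_lengths = torch.tensor([len(t) for t in targets])\n",
"    return ctc(log_probs, torch.cat(targets).long(), input_lengths, target_lengths)"
]
},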
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"abc = m1(torch.tensor(data.trn_ds[:10][0]))\n",
"print(abc.shape)\n",
"x = m2(abc)\n",
"print(x.shape)"
]
},
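{
"cell_type": "markdown",
"metadata": {},
"source": [
"To turn per-step scores into text, a greedy CTC decode (take the argmax path, collapse repeats, drop blanks) is the standard baseline. A minimal sketch under the same blank-index assumption; `greedy_decode` is a hypothetical helper, not part of the original notebook:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def greedy_decode(preds):\n",
"    # preds: (bs, T, 81) raw scores; assumes index 80 is the CTC blank\n",
"    out = []\n",
"    for seq in preds.argmax(2):\n",
"        chars, prev = [], -1\n",
"        for i in seq.tolist():\n",
"            if i != prev and i != 80: chars.append(vocab[i])\n",
"            prev = i\n",
"        out.append(''.join(chars))\n",
"    return out\n",
"\n",
"greedy_decode(x)  # x: the (10, 32, 81) smoke-test output from above"
]
},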
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"a = nn.LSTM(128, 256, 2, batch_first=True, bidirectional=True)\n",
"c = nn.Linear(512, 80)"
]
}
],
"metadata": {
"gist": {
"data": {
"description": "ocr/Untitled.ipynb",
"public": true
},
"id": ""
},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.5"
}
},
"nbformat": 4,
"nbformat_minor": 2
}