@joshfp
Last active October 10, 2020 16:02
Fast.ai p1v1: class 4
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Model based on tabular data + NLP (title)"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"%reload_ext autoreload\n",
"%autoreload 2"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"from fastai import *\n",
"from fastai.tabular import *\n",
"from fastai.metrics import accuracy"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Model"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"df = pd.read_feather('tabular-df')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Add the NLP model's prediction from the title as an extra feature\n",
"df['title_isnew_prob'] = pd.read_feather('title-df')"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"cont_cols = ['col1', 'col2', 'col3', 'col4', 'col5', 'col6',\n",
"             'col7', 'col8', 'col9', 'col10', 'col11', 'col12', 'title_isnew_prob']\n",
"cat_cols = list(set(df.columns) - set(cont_cols) - {'condition'})\n",
"valid_sz = 10000\n",
"valid_idx = range(len(df)-valid_sz, len(df))\n",
"procs = [FillMissing, Categorify, Normalize]"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"data = (TabularList.from_df(df, cat_cols, cont_cols, procs=procs)\n",
"        .split_by_idx(valid_idx)\n",
"        .label_from_df(cols='condition')\n",
"        .databunch())"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"learn = get_tabular_learner(data, layers=[64], ps=[0.5], emb_drop=0.05, metrics=accuracy)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"LR Finder is complete, type {learner_name}.recorder.plot() to see the graph.\n"
]
}
],
"source": [
"learn.lr_find()"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAYsAAAEKCAYAAADjDHn2AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4xLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvDW2N/gAAIABJREFUeJzt3Xl4VdW5+PHvm4mEkISEnEAgILMMMgecB6hatBa01oHapw612gG91mqv/tpr78Xb0Q63trb3KrVaLU60VVQsWmdRhCBzEAhhSoDMIfP8/v44O3AMgQQ4O/ucnPfzPOfh7HXWPuddZHiz1tprbVFVjDHGmOOJ8joAY4wxoc+ShTHGmC5ZsjDGGNMlSxbGGGO6ZMnCGGNMlyxZGGOM6ZIlC2OMMV2yZGGMMaZLliyMMcZ0KcbrAIIlPT1dhw8f7nUYxhgTVtauXVuqqr6u6rmaLERkLvBbIBpYrKo/6/D6b4DZzmFfIENV+zuv3Qj80Hntv1X1yeN91vDhw8nJyQlm+MYY0+uJyJ7u1HMtWYhINPAIcAlQAKwRkWWqmtteR1W/G1D/DmCa8zwN+BGQDSiw1jm3wq14jTHGHJubcxazgDxVzVfVJuBZYP5x6i8AnnGefx54Q1XLnQTxBjDXxViNMcYch5vJYgiwL+C4wCk7ioicBowA3jqRc0XkNhHJEZGckpKSoARtjDHmaG4mC+mk7Fj7oV8PLFXV1hM5V1UfVdVsVc32+bqcnzHGGHOS3EwWBcDQgOMsYP8x6l7PkSGoEz3XGGOMy9xMFmuAMSIyQkTi8CeEZR0ricjpQCrwUUDxCuBSEUkVkVTgUqfMGGOMB1y7GkpVW0RkIf5f8tHA46q6RUQWATmq2p44FgDPasAt+1S1XEQexJ9wABaparlbsRpjjDk+6S23Vc3OztaTWWfR1qb87J+f8tUzT2PYgL4uRGY6U9PYQu7+KpITYkjrG0f/vnHExdiGAsb0NBFZq6rZXdXrNSu4T9buslqeXb2XZz7ey0++NIkvThkc1PfP2V1OSXUjl03KDOr79iRV5fXcIv7wdh7D0xP5/txxDOmfcFLvtangEEtW72XZ+kJqm1o/89qQ/gl85cxhfGXWMFIT44IRujEmSCK+ZwFQUFHHnc+s45O9lSyYNZQHrphIQlz0Kcf08ob93P38eppblWuzs1g0/wziY0/9fXvS5sJDPPhKLh/vKmf4gL4cONSACHzzwlHcfsGobv8/vbe9hIdWbGNT4SHiY6O4YvJg5k4cRGNLGxV1TVTUNrFqVxkr88qIj43iqmlZ3Hr+CEb5+rncQmMiW3d7FpYsHM2tbfzmje388d2djPL1Y96UwYwdmMTpg5IYltaX6KjOruY9tqdW7eGBlzYzc3ga2ael8od3djJxcDJ/vGFGSA93tbS28enBalbvKufDnWW8+WkRaX3j+O4lY7l+5lCKqhv5yfKtvLrxAEP6J3D5pEGMz0xmfGYyozP6ERv92aGkoqoGFr2Sy6sbD3DagL7ccu4Irpw2hJSE2E4/f9vBav68chd/X1eIAD/90iS+ND2rB1puTGSyZHGS3t9Rwn8u28LOktrDZX1iohgzsJ8/eQxMYsrQ/pw1ckCn56sqv3srj1+/sZ2Lxw/k91+ZRnxsNG9uLeK7z60H4N8uHssZg5MZOzCJ1MQ4VJUDhxrYUVzDzuIaBvSLY9rQVIamJSByYknqZLW2KQ++kssLOfsODw8N6Z/AF6cM5tuzR5Ec/9lf7h/nl/Gr17ezvqCSppY2AGKjhZHp/Q7/Xwnwf+/l09TaxsLZo7ntgpHd7lkVVzdw5zPrWJVfzs3nDuf/XT7+qERkjDl1lixOUW1jC3nFNWwrqmbbwWq2F/kfRVWNACyYNYz/nDeBPjHRnznnR8u2sHRtAVdPz+LnV08iJuAX3N6yOr69ZC2bC6sOl6X3i6OxuY3qxpajYkhLjGNKVgpnjR
zAuaPTmZCZTNQJ9nC6o7VN+fe/bWTp2gKunDqYOeMHkn1aKoO7MS/R0tpGfmktWw9UsfVANTuKqtleXM2+8noALhzrY9H8iZw2IPGE42pubeMny7fy55W7OWtkGr//ynTS+/U54fcxxhybJQuXVNY18eh7+fzhnZ3MOC2VP94wnYzkeNbvq+SuZ9exp7yOO2aP5q6Lx3b6i729F7G9qJodRTXsKK4mPjaaMQOTGJPRj1G+fhRVNbB+XyUb9lXyyd6Kw72c1L6xnDs6ndsvGMWkrJSgtCcwUdx18RjuunhsUN63rqmFspomslJPvXf0908KuO/vm0Bh1og0Zo/LYM64DJLjY8gvrSW/pIb8klqKqhooq21y5kCaSUmIZWhaAlmpfRncP4HGllYqapsor22mvrmFc0al84VJmUdNptc2tlBR10RWaugOFxoTLJYsXPbqxgPc88IGkhNiuGLyYJ74cDeDkuP59bVTOPMYQ1Qnq6iqgZV5pXyQV8rbnxZzqL6Zr551Gt+79PRjjv13VN/Uyi9WfMqq/HKmDfMPo80ansYvX98W9EThhu1F1SxdW8BbnxaTV1xz1Otx0VFkJPdhQGIcaYlxpPaNo6KuiX0V9RRU1NHQ7B8qS4iNJs1JDoWV9cRGCxeO9XHhWB95xTWs3VvB1gPVtLYp549J587PjWHm8LQebasxPcmSRQ/YeqCK257KYV95PfOnDmbR/DO6/cv7ZB2qb+bXr2/jqVV7SEuM4/9dPp75U4ccdwJ+w75KvvvcevJLa5k5PJWtB6qpCRj2CvVE0dHesjre2V5Mc6sy0pfIqPR+DElNOOb/gapSUddMQmz04au3VJXcA1W8uK6QZRv2U1TVSN+4aKYO7U/2aanERkfxxIe7Katt4swRadx6/kimDetvw2Cm17Fk0UMO1Tezvai6x//63Fx4iB+8uJkN+yrJSk3ghjNP49rsLAY4v8xUlbLaJp76aA+/fzuPgUl9+OU1UzhndDotrW3kHqhiVX4ZA5PjmT+1082AI0Zrm7KvvI6s1ITPzDHVN7WyZPVe/u/dnRRX++eqfEl9GJ+ZzHmjB3DjOcM/M2dlTDiyZBEB2tqUf245yF8+2s2q/HLiYqI4d9QAyuua2VVSQ1WDv/fwpWlD+NG8ia73enqrhuZW1u6pYOuBKnIPVJG7v4pPD1Yz0pfIj6+cxNmjgjvsaExPsmQRYbYXVfP0qj18sKOUzP7xjEzvx4j0RCZnpZBtY+5B9+72En744ib2ldfz5RlZ/PvccfiSbIjKhB9LFsa4rL6plYff2sFj7+XTqsrpA5M4c0Qas0YMYNKQlOPOoxgTKixZGNND8oqrWb7pIKt3lbN2TwX1zf5FjbHRwtC0vowYkMi8qYOZN2Vwjy2yNKa7LFkY44Hm1jY2Fx5ie1E1u0rr2F1ay9aDVewpq2P26T5+fNWkbi12NKanWLIwJkS0tilPfribh1ZsIzpKuO+ycXxl1jBXVuMbc6K6myxssx1jXBYdJdxy3ghW3HUBU4am8MMXN3PF7z7gn5sP0tbWO/5YM72fJQtjesiwAX15+utn8pvrplDX1MI3n17L5Q+/z2ubDljSMCHP1WQhInNFZJuI5InIfceoc62I5IrIFhFZElD+C6dsq4g8LDYzaHoBEeGqaVn86+4L+fW1U2hqaeNbf/2E6x79iG0Hq70Oz5hjci1ZiEg08AhwGTABWCAiEzrUGQPcD5yrqhOBu5zyc4BzgcnAGcBM4EK3YjWmp8VER/Gl6Vm8cfeF/PzqSeQV13D5w+/zk+Vbqe1kB2JjvOZmz2IWkKeq+araBDwLzO9Q5xvAI6paAaCqxU65AvFAHNAHiAWKXIzVGE9ERwnXzRzGW9+7iGtmZPHoe/lc8ut32bCv0uvQjPkMN5PFEGBfwHGBUxZoLDBWRFaKyCoRmQugqh8BbwMHnMcKVd3a8QNE5DYRyRGRnJKSElcaYUxPSE2M42dXT+Zv3zqbqCjhhsUfs2
Z3uddhGXOYm8miszmGjrN4McAY4CJgAbBYRPqLyGhgPJCFP8HMEZELjnoz1UdVNVtVs30+X1CDN8YLM05L44Vvnk1GUh++9qfVfLCj1OuQjAHcTRYFwNCA4yxgfyd1XlLVZlXdBWzDnzyuAlapao2q1gCvAWe5GKsxISMzJYHnbj/bf8/yJ9fw1qc2Amu852ayWAOMEZERIhIHXA8s61DnRWA2gIik4x+Wygf2AheKSIyIxOKf3D5qGMqY3sqX1IdnbzuLcYOSuO0va/lwp/UwjLdcSxaq2gIsBFbg/0X/vKpuEZFFIjLPqbYCKBORXPxzFPeqahmwFNgJbAI2ABtU9WW3YjUmFPXvG8fTt57J8PRE7liyjv2V9V6HZCKYbfdhTIjLK67hykdWMsqXyHO3n018rN1wyQSPbfdhTC8xOqMfv7xmChsKDvFfL2/xOhwToSxZGBMG5p4xiO/MHsUzq/fxzOq9XodjIpAlC2PCxN2XnM75Y9L50UtbWLunwutwTISxZGFMmIiOEn63YBqZ/eO5/am1NuFtepQlC2PCSP++cSz+WjYNza3c9lQO9U2tXodkIoQlC2PCzJiBSTy8YCpb9ldxz9IN9JYrGk1os2RhTBiaM24g/z53HK9uPMDv38rzOhwTAWK8DsAYc3Juv2Ak2w9W86s3tjNsQF/mT+24T6cxwWPJwpgwJSL89OpJFFbWc88LG/D168M5o9O9Dsv0UjYMZUwY6xMTzaNfy2ZEeiK3P7WWrQeqvA7J9FKWLIwJcykJsTxx8ywS+8Rw059XU2iX1PZ6qsq2g9Usfj+fGx9fzR3PrHP9My1ZGNMLDO6fwBO3zKSusZWvP7GGhma7pLa3ennDfs78yZt8/n/e479f3UpBRR3D0hJc/1ybszCmlxg3KJmHF0zj5ifW8PCbO/j+3HFeh2Rc8NL6/bQp/PzqSZw3xseQ/u4nCrCehTG9yuxxGVybncX/vrvT7uPdSxVW1jMlK4XrZg7rsUQBliyM6XV+8IUJZCTFc+/SDTS22HBUb1NQUceQ1J5LEu0sWRjTy6QkxPLTqyexvaiGh9/c4XU4JoiqGpqpbmghy5KFMSYYZp+ewTUzsvjfd/PZWGDDUb1FYYX/Srch/fv2+Ge7mixEZK6IbBORPBG57xh1rhWRXBHZIiJLAsqHicjrIrLVeX24m7Ea09v88IoJ+Pr14e7nN1Db2OJ1OCYICtqTRW/qWYhINPAIcBkwAVggIhM61BkD3A+cq6oTgbsCXv4L8JCqjgdmAcVuxWpMb5SSEMuvr51CfkkNP/jHJttwsBcorKgD6HXDULOAPFXNV9Um4Flgfoc63wAeUdUKAFUtBnCSSoyqvuGU16hqnYuxGtMrnTM6nbsvGcuL6/fz14/tDnvhrrCynvjYKAYkxvX4Z7uZLIYA+wKOC5yyQGOBsSKyUkRWicjcgPJKEfm7iKwTkYecnspniMhtIpIjIjklJSWuNMKYcPfti0Zz0ek+Fr2ca/MXYa6gop7B/RMQkR7/bDeTRWet6dgPjgHGABcBC4DFItLfKT8fuAeYCYwEbjrqzVQfVdVsVc32+XzBi9yYXiQqSvjNtVNJ7xfHt57+hMq6Jq9DMiepsLKerNSen9wGd5NFATA04DgL2N9JnZdUtVlVdwHb8CePAmCdM4TVArwITHcxVmN6tdTEOB65YTrF1Q38x0tbvA7HnKTCivoeXYgXyM1ksQYYIyIjRCQOuB5Y1qHOi8BsABFJxz/8lO+cmyoi7d2FOUCui7Ea0+tNG5bKNy8cxcsb9tvutGGorqmFstomTya3wcVk4fQIFgIrgK3A86q6RUQWicg8p9oKoExEcoG3gXtVtUxVW/EPQb0pIpvwD2k95lasxkSKW88bSVKfGH77L1usF272O7sJe5UsXN1IUFWXA8s7lD0Q8FyBu51Hx3PfACa7GZ8xkSalbyw3nzeCh9/cwZb9h5g4OMXrkEw37Tu8IK+X9SyMMaHp6+eNIC
neehfhptDDBXlgycKYiJOSEMvXzxvB67lFbC485HU4ppsKK+uJjRYykuI9+XxLFsZEoJvPdXoXttFg2CioqCczJYHoqJ5fYwGWLIyJSCkJsdx63kjesN5F2CisqPNsvgIsWRgTsW4+bzgpCbH896u5tLXZvlGhzr8gz5KFMaaHJcfHcv9l41iVX86S1bZvVChrbGmlqKrRs8ltsGRhTES7buZQzhudzk+Xb6XQuY7fhJ4DlQ2Ad5fNgiULYyKaiPDTL01Cgfv+ttG2MQ9RhYcX5HmzLxRYsjAm4g1N68t9l43j/R2lvLC2wOtwTCcKPLyPRTtLFsYYvnrmacwakcaDr+Ry8FCD1+GYDgor6okSGJTizRoLsGRhjMG/jfkvrp5MY0sbD79lay9CTUFlPYOS44mN9u5XtiULYwwAw9MTuWJyJsvW77d7doeYgop6T6+EAksWxpgAC2YNo6axhVc3HfA6FBPAy/tYtLNkYYw5LPu0VEb6EnnW1l2EjJbWNg5WNXh6JRRYsjDGBBARrp85lE/2VrK9qNrrcAxwsKqB1ja1YShjTGj50vQsYqOF59bs8zoUQ8DW5DYMZYwJJen9+nDJhIH8/ZMCGltavQ4n4hVUeHuHvHauJgsRmSsi20QkT0TuO0ada0UkV0S2iMiSDq8li0ihiPzezTiNMZ91/cxhVNQ18/qWIq9DiXjtq7cH99aehYhEA48AlwETgAUiMqFDnTHA/cC5qjoRuKvD2zwIvOtWjMaYzp03Op0h/RN4do1NdHvtwKEGBiTGER8b7WkcbvYsZgF5qpqvqk3As8D8DnW+ATyiqhUAqlrc/oKIzAAGAq+7GKMxphNRUcJ1M4eyMq+MvWV1XocT0YqqGshI9m7ldjs3k8UQIHCGrMApCzQWGCsiK0VklYjMBRCRKOBXwL3H+wARuU1EckQkp6SkJIihG2Ouyc4iOkp44sPdXocS0YqqGhiU3MfrMFxNFp3d+6/jlpYxwBjgImABsFhE+gPfBpar6nEvx1DVR1U1W1WzfT5fEEI2xrTLTElg/tTBLFm9h9KaRq/DiVhFVQ2e7gnVzs1kUQAMDTjOAvZ3UuclVW1W1V3ANvzJ42xgoYjsBn4JfE1EfuZirMaYTnxn9mgaW9p47P18r0OJSM2tbZTWNJGR1LuTxRpgjIiMEJE44HpgWYc6LwKzAUQkHf+wVL6q3qCqw1R1OHAP8BdV7fRqKmOMe0b5+vHFyYN56qM9VNQ2eR1OxCmu9vfoenXPQlVbgIXACmAr8LyqbhGRRSIyz6m2AigTkVzgbeBeVS1zKyZjzIlbOGc0dU2tPL5yl9ehRJz27eIHhcAEd4ybb66qy4HlHcoeCHiuwN3O41jv8QTwhDsRGmO6MnZgEpedMYgnVu7m1vNHkpIQ63VIEaO4yp8sMnr5BLcxppdYOGc01Y0tPLFyt9ehRJSDVaHTs7BkYYzp0sTBKVw8fiCPr9xFdUOz1+FEjINVDcRGC2mJcV6HYsnCGNM9d35uNIfqm3lq1R6vQ4kYxVWNZCTFI9LZSoSeZcnCGNMtk7P6c+FYH4vf30Vdk91JryccPBQaayzAkoUx5gTcMWc05bVNLPnY9ozqCf7V25YsjDFhJnt4GmePHMCj7+XT0Gzbl7vNvy+U91dCgSULY8wJumPOaIqrG3lhbYHXofRq1Q3N1Da1Ws/CGBOezh41gBmnpfK/7+ykqaXN63B6raL2y2ZtzsIYE45EhIVzRlNYWc+L6wq9DqfXKqryb/URCvtCgSULY8xJuGisj0lDUnjknTxaWq134YbDW31Yz8IYE65EhO/MHsWesjre22H3knFD++rtgTbBbYwJZ3PGDSSpT4zdp9slxVUNJMXH0DfO1S38uq1byUJERolIH+f5RSJyp3OTImNMhIqLieKicRn8a2sRrW0d72tmTtXBEFpjAd3vWfwNaBWR0cCfgBHAEteiMsaEhUsnDKS0pol1eyu8DqXXOVjVGDLzFdD9ZN
Hm3J/iKuB/VPW7QKZ7YRljwsFFp/uIjRZez7WhqGArrmoImSuhoPvJollEFgA3Aq84ZbapvTERLik+lnNGpbNiy0H8t6cxwdDaphRXNzIoJTQmt6H7yeJm/PfF/rGq7hKREcDT7oVljAkXl04cyJ6yOnYU13gdSq9RVtNIa5uG35yFquaq6p2q+oyIpAJJqvqzrs4Tkbkisk1E8kSk03toi8i1IpIrIltEZIlTNlVEPnLKNorIdSfUKmNMj7lk/EAAXt9y0ONIeo/DC/LCLVmIyDsikiwiacAG4M8i8usuzokGHgEuAyYAC0RkQoc6Y4D7gXNVdSJwl/NSHfA1p2wu8D929ZUxoSkjOZ5pw/rbvEUQhdId8tp1dxgqRVWrgC8Bf1bVGcDFXZwzC8hT1XxVbQKeBeZ3qPMN4BFVrQBQ1WLn3+2qusN5vh8oBnzdjNUY08MunTCIjQWH2F9Z73UovcLBENsXCrqfLGJEJBO4liMT3F0ZAuwLOC5wygKNBcaKyEoRWSUiczu+iYjMAuKAnZ28dpuI5IhITkmJrSI1xiuXTvQPRb1hvYugKK5qIEpgQAjcTrVdd5PFImAFsFNV14jISGBHF+d0dh/AjpdLxABjgIuABcDiwOEmJ0E9BdysqkdtQKOqj6pqtqpm+3zW8TDGK6N8/RjlS+T1XJu3CIaDhxrwJfUhJjp0Ntno7gT3C6o6WVW/5Rznq+rVXZxWAAwNOM4C9ndS5yVVbVbVXcA2/MkDEUkGXgV+qKqruhOnMcY7l04cxKr8cirrmrwOJewVVTeG1HwFdH+CO0tE/iEixSJSJCJ/E5GsLk5bA4wRkREiEgdcDyzrUOdFYLbzGen4h6Xynfr/AP6iqi+cSIOMMd647IxBtLapTXQHQdGhhpC6Egq6Pwz1Z/y/6Afjn3d42Sk7JmfF90L8w1dbgedVdYuILBKReU61FUCZiOQCbwP3qmoZ/rmRC4CbRGS985h6gm0zxvSgSUNSGJbWl1c2HvA6lLAXavtCgX/OoDt8qhqYHJ4QkbuOWduhqsuB5R3KHgh4rsDdziOwztPYoj9jwoqI8IXJmTz6Xj7ltU2khdDkbDhpaG7lUH1zSF0JBd3vWZSKyFdFJNp5fBUoczMwY0z4+cKkTFrblBW2QO+ktd9ONSMpdLb6gO4ni1vwDw0dBA4AX8a/BYgxxhw2cXAyI9ITeWVjx2tZTHeF2h3y2nX3aqi9qjpPVX2qmqGqV+JfoGeMMYeJCFdMzuSjnWWU1jR6HU5YKqr2/7+F2pzFqVzEe3fXVYwxkeYLkzNpU3htsw1FnYwip2cRrldDdaazRXfGmAh3+sAkRmf041UbijopxdUNxMdGkRwfGrdTbXcqycI2rzfGHKV9KOrjXeUUO5O1pvuKqhrJSIpHJLT+Hj9ushCRahGp6uRRjX/NhTHGHOWKyZmoDUWdlOLqBgYmh9aVUNBFslDVJFVN7uSRpKqh1UcyxoSM0RlJjBuUZFdFnYRip2cRakJnlypjTK9yxeRM1uyuoKCizutQwkpxdSMZ4dazMMaYk3XltCGIwN/WFnodStiobWyhprHFehbGmMiRldqXc0YNYOkn+2hrs+thuqPYWWMRdnMWxhhzKr48I4t95fWs3l3udShh4chWH9azMMZEkLkTM+nXJ4alawu8DiUsWM/CGBOREuKiuWJyJss3HaC2scXrcEJesfUsjDGR6prsLOqaWnl1k93noivF1Y30iYkiOSH0ViZYsjDGuGr6sFRGpifaUFQ3FFU1kJHcJ+RWb4MlC2OMy0SEq2dksXpXOXvKar0OJ6QVVzUyMASHoMDlZCEic0Vkm4jkich9x6hzrYjkisgWEVkSUH6jiOxwHje6Gacxxl1XT88iSrDeRReKqhtCckEeuJgsRCQaeAS4DJgALBCRCR3qjAHuB85V1YnAXU55GvAj4ExgFvAjEUl1K1ZjjLsGpcRz3hgff/+k0NZcHEdJiG71Ae72LGYBeaqar6
pNwLPA/A51vgE8oqoVAKpa7JR/HnhDVcud194A5roYqzHGZVdOHUxhZT3r9lV6HUpIqmtqobqxJfJ6FsAQYF/AcYFTFmgsMFZEVorIKhGZewLnIiK3iUiOiOSUlJQEMXRjTLBdPGEgcdFRLLerojpVXOWssYjAnkVn0/kd+58xwBjgImABsFhE+nfzXFT1UVXNVtVsn893iuEaY9yUHB/L+WPSeW3TARuK6sTh1dsR2LMoAIYGHGcBHfcrLgBeUtVmVd0FbMOfPLpzrjEmzFw+KZP9hxpYX2BDUR0dWb0deT2LNcAYERkhInHA9cCyDnVeBGYDiEg6/mGpfGAFcKmIpDoT25c6ZcaYMNY+FPXqRhuK6ujIvlAR1rNQ1RZgIf5f8luB51V1i4gsEpF5TrUVQJmI5AJvA/eqapmqlgMP4k84a4BFTpkxJoylJNhQ1LGUVDcSFxNFSkKs16F0ytU15aq6HFjeoeyBgOcK3O08Op77OPC4m/EZY3re5ZMyefPTYtYXVDJ9mF0R366oqoGMpNBcvQ22gtsY08MunjCQ2GhhuQ1FfUZxdWPIzleAJQtjTA/zD0X5eG3zQfyDCwaO9CxClSULY0yPu3xSJoWV9ay3BXqHFVc3WrIwxphAlzhDUXZVlF99UyvVDS1k2DCUMcYcYUNRn1VcHdqXzYIlC2OMRz43PoPCynp2FNd4HYrniqpCe0EeWLIwxnhkzrgMAN76tLiLmr3f4Z5FiG71AZYsjDEeyUxJYHxmsiULAnoWIbqJIFiyMMZ4aM44H2v3VHCortnrUDxVXN1AXHQU/fuG5uptsGRhjPHQnHEDaW1T3t0R2bcYKK5qxBfCq7fBkoUxxkNTh/YnLTGOtyN8KKo4hG+n2s6ShTHGM9FRwoVjfbyzrZjWCN5YsKiqMaTnK8CShTHGY3PGZVBR18z6fRVeh+KZ4irrWRhjzHFdMNZHdJRE7FVRDc2tVDW0hPQaC7BkYYzxWEpCLNmnpfLm1shMFu333vaF8OptsGRhjAkBc8Zl8OnBavZX1nsdSo8rchbkWc/CGGO60L6a++1tkde7aO9ZhPK+UOByshCRuSKyTUTyROS+Tl6/SURKRGS987g14LVfiMgWEdkqIg9LKF/M01NAAAATxElEQVSAbIw5JaMz+jE0LYF/5RZ5HUqPK6v1J4v0fhGaLEQkGngEuAyYACwQkQmdVH1OVac6j8XOuecA5wKTgTOAmcCFbsVqjPGWiDB34iA+yCuNuNXcpTVNiEBqCK/eBnd7FrOAPFXNV9Um4FlgfjfPVSAeiAP6ALFA5P3JYUwEuWLyYJpblRVbDnodSo8qrWkktW8cMdGhPSvgZnRDgH0BxwVOWUdXi8hGEVkqIkMBVPUj4G3ggPNYoapbO54oIreJSI6I5JSURPZ2AcaEu8lZKQxL68vLG/d7HUqPKqtpJL1fnNdhdMnNZNHZHEPHJZovA8NVdTLwL+BJABEZDYwHsvAnmDkicsFRb6b6qKpmq2q2z+cLavDGmJ4lInxxSiYf7iyjrKbR63B6TFlNEwMSQ3u+AtxNFgXA0IDjLOAzfzKoapmqtn9XPAbMcJ5fBaxS1RpVrQFeA85yMVZjTAi4YvJgWtuU1zZHzlBUWW0TAyK8Z7EGGCMiI0QkDrgeWBZYQUQyAw7nAe1DTXuBC0UkRkRi8U9uHzUMZYzpXcYNSmJ0Rj9e3hA5Q1GlNY0hfyUUuJgsVLUFWAiswP+L/nlV3SIii0RknlPtTufy2A3AncBNTvlSYCewCdgAbFDVl92K1RgTGkSEL04ezOrd5RRVNXgdjusamlupbmhhQGLo9yxi3HxzVV0OLO9Q9kDA8/uB+zs5rxW43c3YjDGh6YopmfzmX9t5deMBbjlvhNfhuKq8tgmA9BBfkAe2gtsYE2JG+foxITM5Iq6KKqvxJ4tw6FlYsjDGhJwvThnMur2V7Cuv8zoUV5U6q7cHRPKchTHGnKwrJvuvfXll4wGPI3FXaX
X7Vh/WszDGmBM2NK0vZ45I408f7KKmscXrcFxT1j5nYT0LY4w5OfdfPp7Smkb++E6e16G4pqymkfjYKPrGRXsdSpcsWRhjQtLUof25cupgHnt/FwUVvXPuon31djhsqm3JwhgTsr4/dxxRAr/45zavQ3FFaW1TWMxXgCULY0wIG9w/gdvOH8myDfv5ZG+F1+EEXWl1eKzeBksWxpgQd/uFo8hI6sODr+Si2nEv0vBWVtsYFvtCgSULY0yIS+wTwz2fP511eyt71aW0quqfs7CehTHGBMeXp2cxypfIkx/u9jqUoKmqb6GlTcNi9TZYsjDGhIGoKOGa7KHk7KlgV2mt1+EERWmY3Hu7nSULY0xYuGraEKIElq7d13XlMHBk9bYlC2OMCZqByfFcMNbH3z8ppLUt/Ce621dv2wS3McYE2TUzhnLgUAMf7iz1OpRT1n7rWEsWxhgTZJ8bn0FKQixL1xZ4HcopK61pQgTS+lqyMMaYoIqPjWbelMH8c/NBqhqavQ7nlJTVNpLaN46Y6PD4NexqlCIyV0S2iUieiNzXyes3iUiJiKx3HrcGvDZMRF4Xka0ikisiw92M1RgTHr48I4vGljZeDVhz8e72Er66+GO2F1V7GNmJKa1uCpvLZsHFZCEi0cAjwGXABGCBiEzopOpzqjrVeSwOKP8L8JCqjgdmAcVuxWqMCR+Ts1IYk9GPF3L2UV7bxN3PrefGx1fzQV4pT320x+vwui2cVm+Duz2LWUCequarahPwLDC/Oyc6SSVGVd8AUNUaVe2d204aY06IiHBNdhaf7K3kc796h2Ub9nPHnNFcPH4gr20+GDZXSoXT6m1wN1kMAQIviC5wyjq6WkQ2ishSERnqlI0FKkXk7yKyTkQecnoqnyEit4lIjojklJSUBL8FxpiQdOW0IfSNi2bYgEReufM8vnfp6Vw5bTClNY2s2V3udXjdUlrTiM+SBQCdbdDeMeW/DAxX1cnAv4AnnfIY4HzgHmAmMBK46ag3U31UVbNVNdvn8wUrbmNMiMtIiuej+z/HP751DuMGJQMwZ1wG8bFRLN8U+vtHNba0UtXQYnMWjgJgaMBxFrA/sIKqlqlqo3P4GDAj4Nx1zhBWC/AiMN3FWI0xYSYlIZaoqCN/k/aNi2H26RlhMRRVfnhBnvUsANYAY0RkhIjEAdcDywIriEhmwOE8YGvAuaki0t5dmAPkuhirMaYXuHxSJiXVoT8UVVYTXqu3wcVk4fQIFgIr8CeB51V1i4gsEpF5TrU7RWSLiGwA7sQZalLVVvxDUG+KyCb8Q1qPuRWrMaZ3CJehqNKa9n2hwidZxLj55qq6HFjeoeyBgOf3A/cf49w3gMluxmeM6V0S+xwZivrRFycSHRWa97Zu71mEyyaCYCu4jTG9TPtQVE4ID0WVHt4XypKFMcZ4Ys64DPrEhPZQVFltE31iokiMO2pFQMiyZGGM6VUCh6JC9aqo0ppG0vv1QSQ0h8k6Y8nCGNPrXDElk+LqRi586G3+felGXlpfSEl1Y9cn9pCymqawmtwGlye4jTHGC1+YlEnt1S289Wkxr20+wHM5+4iNFn63YDpzzxjkdXiU1TaSkRTvdRgnxJKFMabXERGumzmM62YOo7VNyd1fxY+WbWbhkk/4ww3TuXSitwmjtLqJ8c7K83Bhw1DGmF4tOkqYlJXCk7fM4owhKXxnySf8K7fIs3hU1dlxNnyuhAJLFsaYCJEUH8tfvj6LCYNT+NZf1/LmVm8SRlVDC82tGnZzFpYsjDERIzk+lr/cMovxmcl866+fsLOkpsdjKDu8ett6FsYYE7JSEmL5040z6RMTxX+8uBnVoy+vXbunnMc/2MWBQ/VB/ewN+yr5t2fXAzA8PTGo7+02m+A2xkQcX1Ifvj93HP/x4maWbdjP/KlHbrWzo6iaGx9fQ01jCw++msuZI9KYP3UIF48fiC/p5HoDlXVNPLRiG0tW78XXrw8PL5jG1KH9g9WcHiGdZdVwlJ
2drTk5OV6HYYwJE61typf+sJLCynre/N5FpCTEUlHbxJV/WEltYyt/uGE6q/LLeHF9IfkltQBkpsQzaUgKk7NSuGLy4G71DnL3V/G1xz+moq6ZG88ezncvGUNSfKzbzes2EVmrqtld1rNkYYyJVJsLDzHv9x9ww5mn8cAXJ3Dj46vJ2V3Bs7efxfRhqYD/6qUt+6tYlV/GpsJDbCo4RH5pLUnxMTz2tWzOGjngmO+fV1zDdf/3EXExUSy+MZuJg1N6qmnd1t1kYcNQxpiIdcaQFL529nCe/Gg3hZX1fLizjF9dM+VwogD/mo0zhqRwxpAjv+j3lddxyxNr+NqfVvPr66ZwxeTBR733vvI6vrr4Y0SEv956JiN9/XqiSa6xCW5jTET73qVj8fXrw1ufFnPbBSO5ekZWl+cMTevLC988m6lD+7NwyToWv5//mdcPHmrgK4tXUd/cytO3zgr7RAE2DGWMMazeVc77O0q46+KxJ3QPjIbmVu5+fj3LNx3El9SHuOgoYqKFQ/XNtLQqf731TKaE+ES2DUMZY0w3zRqRxqwRaSd8XnxsNL9bMJ3pw3aRV1xDc6vS0tYGwI3nDA/5RHEiXE0WIjIX+C0QDSxW1Z91eP0m4CGg0Cn6vaouDng9Gf8tWf+hqgvdjNUYY05GdJRw6/kjvQ7Dda4lCxGJBh4BLgEKgDUiskxVcztUfe44ieBB4F23YjTGGNM9bk5wzwLyVDVfVZuAZ4H53T1ZRGYAA4HXXYrPGGNMN7mZLIYA+wKOC5yyjq4WkY0islREhgKISBTwK+De432AiNwmIjkiklNSUhKsuI0xxnTgZrLo7JKCjpdevQwMV9XJwL+AJ53ybwPLVXUfx6Gqj6pqtqpm+3y+Uw7YGGNM59yc4C4AhgYcZwH7AyuoalnA4WPAz53nZwPni8i3gX5AnIjUqOp9LsZrjDHmGNxMFmuAMSIyAv/VTtcDXwmsICKZqnrAOZyH/8onVPWGgDo3AdmWKIwxxjuuJQtVbRGRhcAK/JfOPq6qW0RkEZCjqsuAO0VkHtAClAM3uRWPMcaYk2cruI0xJoJF3K6zIlIC7OnkpRTgUBdlgcedPW//Nx0oPckQO4uju3W6asOx2tNZHTfbcLzXj/d/3vG4q+detCEY30eBz0+2DW5+H3U8Pt7PAoRmG7rTnlD7ee7usVs/C6epatdXCKlqr34Aj3ZVFnjc2fOAf3OCGUd363TVhmO15xhtca0Nx3v9eP/n3fkaeN2GYHwfBaMNbn4fdTPuwLKQa0N32hNqP8/dPe7pn4WOj0jYdfblbpS93MXzzt4jGHF0t05XbThWe45X52R09R7He/14/+cdj7vz/GSdbBuC8X3Unc/vipvfRx2Pe9PPQuDzUGtDd497+mfhM3rNMFRPEJEc7cbYXiizNoQGa4P3wj1+6Nk2RELPIpge9TqAILA2hAZrg/fCPX7owTZYz8IYY0yXrGdhjDGmSxGbLETkcREpFpHNJ3HuDBHZJCJ5IvKwiEjAa3eIyDYR2SIivwhu1EfFEfQ2iMh/ikihiKx3HpcHP/LPxOHK18F5/R4RURFJD17EncbhxtfhQWeDzfUi8rqIHH2T5yBxKf6HRORTpw3/EBFX7wLkUhuucX6O20TEtXmBU4n9GO93o4jscB43BpQf9+elSyd72VW4P4ALgOnA5pM4dzX+/asEeA24zCmfjX9DxD7OcUYYtuE/gXvC+evgvDYU/+4Be4D0cGsDkBxQ507gf8Ms/kuBGOf5z4Gfh+HXYDxwOvAO/i2HQip2J67hHcrSgHzn31Tneerx2tndR8T2LFT1PfxbjBwmIqNE5J8islZE3heRcR3PE5FM/D/IH6n/K/AX4Ern5W8BP1PVRuczisOwDT3KxTb8Bvg+R+90HHRutEFVqwKqJuJiO1yK/3VVbXGqrsK/kahrXGrDVlXd5mbcpxL7MXweeENVy1W1AngDmBuMn/mITRbH8Chwh6rOAO4B/tBJnS
H4d9RtF3ifjrH4d8v9WETeFZGZrkbbuVNtA8BCZ/jgcRFJdS/UYzqlNoh/v7FCVd3gdqDHccpfBxH5sYjsA24AHnAx1s4E4/uo3S34/5LtacFsQ0/rTuydOdZ9hE65na7egzuciEg/4BzghYChvD6dVe2krP2vvhj8Xb+zgJnA8yIy0snkrgtSG/6I/3a26vz7K/w/7D3iVNsgIn2BH+AfBvFEkL4OqOoPgB+IyP3AQuBHQQ61U8GK33mvH+DfKPSvwYyxK8FsQ087XuwicjPwb07ZaGC5iDQBu1T1Ko7dnlNupyWLI6KASlWdGlgo/nuJr3UOl+H/ZRrYpQ68T0cB8HcnOawWkTb8e7f01G38TrkNqloUcN5jwCtuBtyJU23DKGAEsMH5QcsCPhGRWap60OXY2wXjeynQEuBVeihZEKT4ncnVK4DP9dQfTAGC/TXoSZ3GDqCqfwb+DCAi7wA3qerugCoFwEUBx1n45zYKONV2ujVpEw4PYDgBk0rAh8A1znMBphzjvDX4ew/tE0WXO+XfBBY5z8fi7w5KmLUhM6DOd4Fnw+3r0KHOblye4Hbp6zAmoM4dwNIwi38ukAv43P6/d/v7CJcnuE82do49wb0L/whHqvM8rTvt7DLGnvpChtoDeAY4ADTjz7pfx/8X6T+BDc43+gPHODcb2AzsBH7PkcWNccDTzmufAHPCsA1PAZuAjfj/8soMtzZ0qLMb96+GcuPr8DenfCP+PXyGhFn8efj/WFrvPFy7msvFNlzlvFcjUASsCKXY6SRZOOW3OP//ecDNJ/LzcryHreA2xhjTJbsayhhjTJcsWRhjjOmSJQtjjDFdsmRhjDGmS5YsjDHGdMmShenVRKSmhz9vsYhMCNJ7tYp/19nNIvJyVzu3ikh/Efl2MD7bmI7s0lnTq4lIjar2C+L7xeiRDfJcFRi7iDwJbFfVHx+n/nDgFVU9oyfiM5HFehYm4oiIT0T+JiJrnMe5TvksEflQRNY5/57ulN8kIi+IyMvA6yJykYi8IyJLxX/Phr+23xvAKc92ntc4mwFuEJFVIjLQKR/lHK8RkUXd7P18xJGNEvuJyJsi8on4708w36nzM2CU0xt5yKl7r/M5G0Xkv4L432gijCULE4l+C/xGVWcCVwOLnfJPgQtUdRr+XV5/EnDO2cCNqjrHOZ4G3AVMAEYC53byOYnAKlWdArwHfCPg83/rfH6X+/M4+xl9Dv+KeoAG4CpVnY7/Hiq/cpLVfcBOVZ2qqveKyKXAGGAWMBWYISIXdPV5xnTGNhI0kehiYELAjp7JIpIEpABPisgY/Dtyxgac84aqBt5zYLWqFgCIyHr8e/t80OFzmjiyEeNa4BLn+dkcuZfAEuCXx4gzIeC91+K/NwH49/b5ifOLvw1/j2NgJ+df6jzWOcf98CeP947xecYckyULE4migLNVtT6wUER+B7ytqlc54//vBLxc2+E9GgOet9L5z1KzHpkUPFad46lX1akikoI/6XwHeBj//S18wAxVbRaR3UB8J+cL8FNV/b8T/FxjjmLDUCYSvY7//hAAiEj7VtApQKHz/CYXP38V/uEvgOu7qqyqh/DfWvUeEYnFH2exkyhmA6c5VauBpIBTVwC3OPdHQESGiEhGkNpgIowlC9Pb9RWRgoDH3fh/8WY7k765+LeWB/gF8FMRWQlEuxjTXcDdIrIayAQOdXWCqq7DvwPp9fhvJJQtIjn4exmfOnXKgJXOpbYPqerr+Ie5PhKRTcBSPptMjOk2u3TWmB7m3M2vXlVVRK4HFqjq/K7OM8ZLNmdhTM+bAfzeuYKpkh68ba0xJ8t6FsYYY7pkcxbGGGO6ZMnCGGNMlyxZGGOM6ZIlC2OMMV2yZGGMMaZLliyMMcZ06f8DPa2e9trMnqQAAAAASUVORK5CYII=\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"learn.recorder.plot()"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Total time: 01:10\n",
"epoch train_loss valid_loss accuracy\n",
"1 0.211267 0.210265 0.917200 (00:06)\n",
"2 0.186647 0.207828 0.916400 (00:06)\n",
"3 0.159084 0.217404 0.917100 (00:06)\n",
"4 0.114898 0.226102 0.921200 (00:07)\n",
"5 0.091307 0.221410 0.923400 (00:07)\n",
"6 0.088051 0.214281 0.924000 (00:07)\n",
"7 0.078463 0.230735 0.924700 (00:07)\n",
"8 0.073674 0.242582 0.924100 (00:07)\n",
"9 0.054637 0.244048 0.924100 (00:07)\n",
"10 0.055157 0.254020 0.923700 (00:07)\n",
"\n"
]
}
],
"source": [
"learn.fit_one_cycle(10, 5e-3, wd=1e-6)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python [conda env:fastai]",
"language": "python",
"name": "conda-env-fastai-py"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.6"
},
"varInspector": {
"cols": {
"lenName": 16,
"lenType": 16,
"lenVar": 40
},
"kernels_config": {
"python": {
"delete_cmd_postfix": "",
"delete_cmd_prefix": "del ",
"library": "var_list.py",
"varRefreshCmd": "print(var_dic_list())"
},
"r": {
"delete_cmd_postfix": ") ",
"delete_cmd_prefix": "rm(",
"library": "var_list.r",
"varRefreshCmd": "cat(var_dic_list()) "
}
},
"types_to_exclude": [
"module",
"function",
"builtin_function_or_method",
"instance",
"_Feature"
],
"window_display": false
}
},
"nbformat": 4,
"nbformat_minor": 2
}
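
The column split and validation indices used in the notebook can be sketched in plain Python. This is only an illustration, not the author's code: the column names `cat_a` and `cat_b`, the helpers `split_columns` and `last_rows_as_validation`, and the tiny sizes are hypothetical. One difference from the notebook is deliberate: `list(set(...))` returns the categorical columns in arbitrary order, while this sketch keeps the original column order, which may make runs easier to reproduce.

```python
from typing import List, Sequence


def split_columns(columns: Sequence[str], cont_cols: Sequence[str],
                  target: str) -> List[str]:
    """Return the categorical columns: everything that is neither
    continuous nor the target, in the original column order."""
    excluded = set(cont_cols) | {target}
    return [c for c in columns if c not in excluded]


def last_rows_as_validation(n_rows: int, valid_sz: int) -> range:
    """Indices of the last `valid_sz` rows, mirroring
    `range(len(df)-valid_sz, len(df))` from the notebook."""
    if not 0 < valid_sz < n_rows:
        raise ValueError("valid_sz must be between 1 and n_rows - 1")
    return range(n_rows - valid_sz, n_rows)


# Hypothetical layout: two categorical columns, three continuous, one target.
columns = ['cat_a', 'col1', 'col2', 'cat_b', 'title_isnew_prob', 'condition']
cont_cols = ['col1', 'col2', 'title_isnew_prob']

cat_cols = split_columns(columns, cont_cols, target='condition')
valid_idx = last_rows_as_validation(n_rows=12, valid_sz=3)

print(cat_cols)         # ['cat_a', 'cat_b']
print(list(valid_idx))  # [9, 10, 11]
```

Taking the last rows as the validation set is a reasonable choice when the row order is meaningful (for example, sorted by time) or already shuffled; the notebook does not say which applies here.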
@joshfp
Author

joshfp commented Jun 29, 2019

make gist public

@kontrabas380

kontrabas380 commented Jan 29, 2020

Hello! You've done a nice job, and I have a question. Once this model is trained, how can you predict a single example? It doesn't work with .predict(example).
I did this with two AWD_LSTM networks, but in the end I ran into this error when making a prediction:
AttributeError: 'ConcatDataset' object has no attribute 'set_item'

Best regards

@ascientist

> Hello! You've done a nice job, and I have a question. Once this model is trained, how can you predict a single example? It doesn't work with .predict(example).
> I did this with two AWD_LSTM networks, but in the end I ran into this error when making a prediction:
> AttributeError: 'ConcatDataset' object has no attribute 'set_item'
>
> Best regards

Same problem here
