Skip to content

Instantly share code, notes, and snippets.

@joshfp
Last active October 10, 2020 16:02
Show Gist options
  • Save joshfp/b62b76eae95e6863cb511997b5a63118 to your computer and use it in GitHub Desktop.
Save joshfp/b62b76eae95e6863cb511997b5a63118 to your computer and use it in GitHub Desktop.
Fast.ai p1v1: class 4
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Model based on tabular data only"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"%reload_ext autoreload\n",
"%autoreload 2"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"from fastai import *\n",
"from fastai.tabular import *\n",
"from fastai.metrics import accuracy"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Model"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Load the pre-built tabular dataset (feather file in the working directory).\n",
"feather_path = 'tabular-df'\n",
"df = pd.read_feather(feather_path)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Drop the 'title' column. errors='ignore' keeps this cell re-runnable: `df` is\n",
"# overwritten in place, so on a second run 'title' is already gone.\n",
"df = df.drop('title', axis=1, errors='ignore')\n",
"\n",
"# Continuous features (real column names were replaced).\n",
"cont_cols = ['col1', 'col2', 'col3', 'col4', 'col5', 'col6',\n",
"             'col7', 'col8', 'col9', 'col10', 'col11', 'col12']\n",
"\n",
"# Everything else except the target 'condition' is treated as categorical.\n",
"# Keep the DataFrame's own column order rather than going through set(): set\n",
"# iteration order for strings can change between kernel sessions, which would\n",
"# make the categorical column order (and the embedding order) irreproducible.\n",
"cat_cols = [c for c in df.columns if c not in cont_cols and c != 'condition']\n",
"\n",
"# Hold out the last `valid_sz` rows as the validation set.\n",
"valid_sz = 10000\n",
"assert valid_sz < len(df), 'valid_sz must be smaller than the number of rows'\n",
"valid_idx = range(len(df) - valid_sz, len(df))\n",
"\n",
"procs = [FillMissing, Categorify, Normalize]"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"# Build the DataBunch: wrap the frame with `procs`, hold out `valid_idx` for\n",
"# validation, and use the 'condition' column as the label.\n",
"tabular_items = TabularList.from_df(df, cat_cols, cont_cols, procs=procs)\n",
"data = (tabular_items\n",
"        .split_by_idx(valid_idx)\n",
"        .label_from_df(cols='condition')\n",
"        .databunch())"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"learn = get_tabular_learner(\n",
"    data,\n",
"    layers=[64],      # a single hidden layer with 64 units\n",
"    ps=[0.5],         # dropout after the hidden layer\n",
"    emb_drop=0.05,    # dropout applied to the embeddings\n",
"    metrics=accuracy,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"LR Finder is complete, type {learner_name}.recorder.plot() to see the graph.\n"
]
}
],
"source": [
"learn.lr_find()"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"learn.recorder.plot()"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Total time: 00:38\n",
"epoch train_loss valid_loss accuracy\n",
"1 0.253736 0.253902 0.896600 (00:08)\n",
"2 0.193293 0.231615 0.909400 (00:07)\n",
"3 0.144753 0.227158 0.913200 (00:07)\n",
"4 0.092080 0.244881 0.914300 (00:07)\n",
"5 0.067578 0.248254 0.915600 (00:07)\n",
"\n"
]
}
],
"source": [
"# Train with the one-cycle policy: 5 epochs, max learning rate 1e-2, weight decay 1e-6.\n",
"n_epochs, max_lr = 5, 1e-2\n",
"learn.fit_one_cycle(n_epochs, max_lr, wd=1e-6)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"TabularModel(\n",
" (embeds): ModuleList(\n",
" (0): Embedding(4, 3)\n",
" (1): Embedding(33283, 50)\n",
" (2): Embedding(5, 3)\n",
" (3): Embedding(5, 3)\n",
" (4): Embedding(8, 5)\n",
" (5): Embedding(3, 2)\n",
" (6): Embedding(3481, 50)\n",
" (7): Embedding(286, 50)\n",
" (8): Embedding(26, 14)\n",
" (9): Embedding(10492, 50)\n",
" (10): Embedding(304, 50)\n",
" (11): Embedding(30, 16)\n",
" (12): Embedding(1461, 50)\n",
" (13): Embedding(570, 50)\n",
" (14): Embedding(300, 50)\n",
" (15): Embedding(3, 2)\n",
" (16): Embedding(3, 2)\n",
" )\n",
" (emb_drop): Dropout(p=0.05)\n",
" (bn_cont): BatchNorm1d(12, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (layers): Sequential(\n",
" (0): Linear(in_features=462, out_features=64, bias=True)\n",
" (1): ReLU(inplace)\n",
" (2): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (3): Dropout(p=0.5)\n",
" (4): Linear(in_features=64, out_features=2, bias=True)\n",
" )\n",
")"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"learn.model"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python [conda env:fastai]",
"language": "python",
"name": "conda-env-fastai-py"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.6"
},
"varInspector": {
"cols": {
"lenName": 16,
"lenType": 16,
"lenVar": 40
},
"kernels_config": {
"python": {
"delete_cmd_postfix": "",
"delete_cmd_prefix": "del ",
"library": "var_list.py",
"varRefreshCmd": "print(var_dic_list())"
},
"r": {
"delete_cmd_postfix": ") ",
"delete_cmd_prefix": "rm(",
"library": "var_list.r",
"varRefreshCmd": "cat(var_dic_list()) "
}
},
"types_to_exclude": [
"module",
"function",
"builtin_function_or_method",
"instance",
"_Feature"
],
"window_display": false
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Display the source blob
Display the rendered blob
Raw
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Display the source blob
Display the rendered blob
Raw
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Display the source blob
Display the rendered blob
Raw
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Display the source blob
Display the rendered blob
Raw
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
@joshfp
Copy link
Author

joshfp commented Jun 29, 2019

make gist public

@kontrabas380
Copy link

kontrabas380 commented Jan 29, 2020

Hello! You've done a nice job, and I have a question. Once this model has finished training, how can you predict a single example? It doesn't work with .predict(example).
I've done this with two AWD_LSTM networks, but in the end I ran into this error when making a prediction:
AttributeError: 'ConcatDataset' object has no attribute 'set_item'

Best regards

@ascientist
Copy link

Hello! You've done a nice job, and I have a question. Once this model has finished training, how can you predict a single example? It doesn't work with .predict(example).
I've done this with two AWD_LSTM networks, but in the end I ran into this error when making a prediction:
AttributeError: 'ConcatDataset' object has no attribute 'set_item'

Best regards

Same problem here

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment