@joshfp
Last active October 10, 2020 16:02
Fast.ai p1v1: class 4
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Model based on tabular data only"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"%reload_ext autoreload\n",
"%autoreload 2"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"from fastai import *\n",
"from fastai.tabular import *\n",
"from fastai.metrics import accuracy"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Model"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"df = pd.read_feather('tabular-df')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Drop 'title': this model uses the tabular columns only\n",
"df = df.drop('title', axis=1)\n",
"cont_cols = ['col1', 'col2', 'col3', 'col4', 'col5', 'col6',\n",
"             'col7', 'col8', 'col9', 'col10', 'col11', 'col12']  # real column names were replaced\n",
"# Every remaining column except the target 'condition' is treated as categorical\n",
"cat_cols = list(set(df.columns) - set(cont_cols) - {'condition'})\n",
"# Hold out the last 10,000 rows for validation\n",
"valid_sz = 10000\n",
"valid_idx = range(len(df)-valid_sz, len(df))\n",
"procs = [FillMissing, Categorify, Normalize]"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"data = (TabularList.from_df(df, cat_cols, cont_cols, procs=procs)\n",
"        .split_by_idx(valid_idx)\n",
"        .label_from_df(cols='condition')\n",
"        .databunch())"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"learn = get_tabular_learner(data, layers=[64], ps=[0.5], emb_drop=0.05, metrics=accuracy)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"LR Finder is complete, type {learner_name}.recorder.plot() to see the graph.\n"
]
}
],
"source": [
"learn.lr_find()"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"learn.recorder.plot()"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Total time: 00:38\n",
"epoch  train_loss  valid_loss  accuracy\n",
"1      0.253736    0.253902    0.896600  (00:08)\n",
"2      0.193293    0.231615    0.909400  (00:07)\n",
"3      0.144753    0.227158    0.913200  (00:07)\n",
"4      0.092080    0.244881    0.914300  (00:07)\n",
"5      0.067578    0.248254    0.915600  (00:07)\n",
"\n"
]
}
],
"source": [
"learn.fit_one_cycle(5, 1e-2, wd=1e-6)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"TabularModel(\n",
" (embeds): ModuleList(\n",
" (0): Embedding(4, 3)\n",
" (1): Embedding(33283, 50)\n",
" (2): Embedding(5, 3)\n",
" (3): Embedding(5, 3)\n",
" (4): Embedding(8, 5)\n",
" (5): Embedding(3, 2)\n",
" (6): Embedding(3481, 50)\n",
" (7): Embedding(286, 50)\n",
" (8): Embedding(26, 14)\n",
" (9): Embedding(10492, 50)\n",
" (10): Embedding(304, 50)\n",
" (11): Embedding(30, 16)\n",
" (12): Embedding(1461, 50)\n",
" (13): Embedding(570, 50)\n",
" (14): Embedding(300, 50)\n",
" (15): Embedding(3, 2)\n",
" (16): Embedding(3, 2)\n",
" )\n",
" (emb_drop): Dropout(p=0.05)\n",
" (bn_cont): BatchNorm1d(12, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (layers): Sequential(\n",
" (0): Linear(in_features=462, out_features=64, bias=True)\n",
" (1): ReLU(inplace)\n",
" (2): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (3): Dropout(p=0.5)\n",
" (4): Linear(in_features=64, out_features=2, bias=True)\n",
" )\n",
")"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"learn.model"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python [conda env:fastai]",
"language": "python",
"name": "conda-env-fastai-py"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.6"
},
"varInspector": {
"cols": {
"lenName": 16,
"lenType": 16,
"lenVar": 40
},
"kernels_config": {
"python": {
"delete_cmd_postfix": "",
"delete_cmd_prefix": "del ",
"library": "var_list.py",
"varRefreshCmd": "print(var_dic_list())"
},
"r": {
"delete_cmd_postfix": ") ",
"delete_cmd_prefix": "rm(",
"library": "var_list.r",
"varRefreshCmd": "cat(var_dic_list()) "
}
},
"types_to_exclude": [
"module",
"function",
"builtin_function_or_method",
"instance",
"_Feature"
],
"window_display": false
}
},
"nbformat": 4,
"nbformat_minor": 2
}
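
A quick sanity check on the printed `TabularModel`: the first linear layer takes 462 inputs, which should be the sum of all embedding widths plus the 12 continuous columns. The sketch below copies the embedding shapes from the output above and checks that arithmetic. The helper `guessed_emb_dim` is a hypothetical rule (`min(50, n // 2 + 1)`) inferred only from the printed sizes; it is an assumption, not something taken from the fastai source.

```python
# Embedding shapes (n_categories, emb_dim) copied from the printed model.
embeddings = [
    (4, 3), (33283, 50), (5, 3), (5, 3), (8, 5), (3, 2),
    (3481, 50), (286, 50), (26, 14), (10492, 50), (304, 50),
    (30, 16), (1461, 50), (570, 50), (300, 50), (3, 2), (3, 2),
]
n_cont = 12  # width of bn_cont in the printed model


def guessed_emb_dim(n_categories: int) -> int:
    """Hypothetical sizing rule inferred from the output (an assumption)."""
    return min(50, n_categories // 2 + 1)


# Does the guessed rule reproduce every printed embedding width?
rule_matches = all(guessed_emb_dim(n) == dim for n, dim in embeddings)

# Concatenated embeddings plus continuous columns feed the first Linear layer.
first_layer_inputs = sum(dim for _, dim in embeddings) + n_cont

print(rule_matches)
print(first_layer_inputs)
```

If the total matches `in_features` of layer `(0)` in the printout, the categorical and continuous column lists were wired into the model as intended.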
@joshfp
Author

joshfp commented Jun 29, 2019

make gist public

@kontrabas380

kontrabas380 commented Jan 29, 2020

Hello! You've done a nice job, and I have a question. Once this model has finished training, how can you predict a single example? It doesn't work with .predict(example).
I did the same with two AWD_LSTM networks, but in the end I ran into this error when making a prediction:
AttributeError: 'ConcatDataset' object has no attribute 'set_item'

Best regards

@ascientist

Same problem here
