Last active
October 10, 2020 16:02
-
-
Save joshfp/b62b76eae95e6863cb511997b5a63118 to your computer and use it in GitHub Desktop.
Fast.ai p1v1: class 4
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# NLP model to predict from title (ULMFiT)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"%load_ext autoreload\n", | |
"%autoreload 2" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from fastai import *\n", | |
"from fastai.text import *" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"pretrained_path = Path('~/datasets/wikimedia').expanduser()" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Load data" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"df = pd.read_feather('tabular-df')" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"columns = ['title', 'target']\n", | |
"df = df[columns]\n", | |
"N = -10000\n", | |
"train_df = df[:N]\n", | |
"valid_df = df[N:]" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Language model" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 35, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"data_lm = TextLMDataBunch.from_df('.', train_df, valid_df, tokenizer=Tokenizer(lang='es'), text_cols='title')" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 40, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# pretrained model (spanish wiki)\n", | |
"pretraind_fnames = (pretrained_path/'models/weights-6', pretrained_path/'itos')\n", | |
"learn = language_model_learner(data_lm, drop_mult=0.3, pretrained_fnames=pretraind_fnames)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 41, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"LR Finder is complete, type {learner_name}.recorder.plot() to see the graph.\n" | |
] | |
} | |
], | |
"source": [ | |
"learn.lr_find()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"image/png": "\n", | |
"text/plain": [ | |
"<Figure size 432x288 with 1 Axes>" | |
] | |
}, | |
"metadata": { | |
"needs_background": "light" | |
}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"learn.recorder.plot()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 11, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Total time: 00:23\n", | |
"epoch train_loss valid_loss accuracy\n", | |
"1 5.177172 4.802299 0.309109 (00:23)\n", | |
"\n" | |
] | |
} | |
], | |
"source": [ | |
"learn.fit_one_cycle(1, 3e-2, moms=(0.8, 0.7))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 12, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"learn.save('fit-head')" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 28, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"learn.load('fit-head');" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 29, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Total time: 02:52\n", | |
"epoch train_loss valid_loss accuracy\n", | |
"1 4.552735 4.594989 0.326132 (00:27)\n", | |
"2 4.448153 5.098676 0.257221 (00:28)\n", | |
"3 4.126873 4.391115 0.337241 (00:29)\n", | |
"4 3.803120 4.354942 0.343525 (00:29)\n", | |
"5 3.485168 4.389339 0.344239 (00:29)\n", | |
"6 3.294815 4.425426 0.342907 (00:29)\n", | |
"\n" | |
] | |
} | |
], | |
"source": [ | |
"learn.unfreeze()\n", | |
"learn.fit_one_cycle(6, 3e-3, moms=(0.8, 0.7))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 30, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"learn.save('fine-tuned')" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"learn.load('fine-tuned');" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 35, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Total time: 00:00\n", | |
"\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/plain": [ | |
"'samsung galaxy tab , lo mejor'" | |
] | |
}, | |
"execution_count": 35, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"learn.predict('samsung', 5, temperature=1.1, min_p=0.001)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 36, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"learn.save_encoder('fine-tuned-enc')" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 37, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"pickle.dump(data_lm.vocab, open('itos', 'wb'))" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Title classifier" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"vocab = pickle.load(open('itos', 'rb'))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"data_clf = TextClasDataBunch.from_df('.', train_df, valid_df, df, tokenizer=Tokenizer(lang='es'), \n", | |
" vocab=vocab, text_cols='title', label_cols='condition')" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 19, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"learn = text_classifier_learner(data_clf, max_len=100, drop_mult=0.5)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 20, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"learn.load_encoder('fine-tuned-enc')\n", | |
"learn.freeze()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"LR Finder is complete, type {learner_name}.recorder.plot() to see the graph.\n" | |
] | |
} | |
], | |
"source": [ | |
"learn.lr_find()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 14, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"image/png": "\n", | |
"text/plain": [ | |
"<Figure size 432x288 with 1 Axes>" | |
] | |
}, | |
"metadata": { | |
"needs_background": "light" | |
}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"learn.recorder.plot(skip_end=8)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 21, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Total time: 00:25\n", | |
"epoch train_loss valid_loss accuracy\n", | |
"1 0.502706 0.497329 0.769300 (00:25)\n", | |
"\n" | |
] | |
} | |
], | |
"source": [ | |
"learn.fit_one_cycle(1, 3e-3, moms=(0.8, 0.7))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 22, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"learn.save('first')" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 23, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Total time: 00:29\n", | |
"epoch train_loss valid_loss accuracy\n", | |
"1 0.478203 0.451679 0.795000 (00:29)\n", | |
"\n" | |
] | |
} | |
], | |
"source": [ | |
"learn.freeze_to(-2)\n", | |
"learn.fit_one_cycle(1, slice(1e-3/(2.6**4), 1e-3), moms=(0.8, 0.7))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 24, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"learn.save('second')" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 25, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Total time: 00:42\n", | |
"epoch train_loss valid_loss accuracy\n", | |
"1 0.472015 0.437852 0.802800 (00:42)\n", | |
"\n" | |
] | |
} | |
], | |
"source": [ | |
"learn.freeze_to(-3)\n", | |
"learn.fit_one_cycle(1, slice(5e-4/(2.6**4), 5e-4), moms=(0.8, 0.7))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 26, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"learn.save('third')" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 27, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Total time: 10:32\n", | |
"epoch train_loss valid_loss accuracy\n", | |
"1 0.447866 0.431794 0.805600 (01:02)\n", | |
"2 0.434827 0.429153 0.805000 (01:02)\n", | |
"3 0.433098 0.418664 0.811900 (01:04)\n", | |
"4 0.440408 0.415776 0.811500 (01:03)\n", | |
"5 0.440411 0.414932 0.812300 (01:02)\n", | |
"6 0.431286 0.417159 0.811300 (01:02)\n", | |
"7 0.421277 0.405521 0.819000 (01:02)\n", | |
"8 0.422464 0.410475 0.816700 (01:05)\n", | |
"9 0.415283 0.423141 0.808500 (01:02)\n", | |
"10 0.433402 0.412326 0.815200 (01:03)\n", | |
"\n" | |
] | |
} | |
], | |
"source": [ | |
"learn.unfreeze()\n", | |
"learn.fit_one_cycle(10, slice(1e-4/(2.6**4), 1e-4), moms=(0.8, 0.7))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 28, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"learn.save('final')" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Predict" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 142, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"learn.load('final');" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 29, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"\n", | |
" <div>\n", | |
" <progress value='0' max='1', style='width:300px; height:20px; vertical-align: middle;'></progress>\n", | |
" 0.00% [0/1 00:00<00:00]\n", | |
" </div>\n", | |
" " | |
], | |
"text/plain": [ | |
"<IPython.core.display.HTML object>" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"text/plain": [ | |
"('used', tensor(1), tensor([0.0055, 0.9945]))" | |
] | |
}, | |
"execution_count": 29, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"learn.predict('iphone 7 como nuevo') # 'iphone 7 as new'" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 30, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"preds, _ = learn.get_preds(DatasetType.Test, ordered=True)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 32, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<style scoped>\n", | |
" .dataframe tbody tr th:only-of-type {\n", | |
" vertical-align: middle;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>title_isnew_prob</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>0</th>\n", | |
" <td>0.972612</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1</th>\n", | |
" <td>0.840828</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>2</th>\n", | |
" <td>0.000209</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>3</th>\n", | |
" <td>0.890611</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>4</th>\n", | |
" <td>0.145972</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" title_isnew_prob\n", | |
"0 0.972612\n", | |
"1 0.840828\n", | |
"2 0.000209\n", | |
"3 0.890611\n", | |
"4 0.145972" | |
] | |
}, | |
"execution_count": 32, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"preds_df = pd.DataFrame({'title_isnew_prob': preds[:,0]})\n", | |
"preds_df.head()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# save predictions\n", | |
"preds_df.to_feather(PATH/'title-df')" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python [conda env:fastai]", | |
"language": "python", | |
"name": "conda-env-fastai-py" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.6.6" | |
}, | |
"varInspector": { | |
"cols": { | |
"lenName": 16, | |
"lenType": 16, | |
"lenVar": 40 | |
}, | |
"kernels_config": { | |
"python": { | |
"delete_cmd_postfix": "", | |
"delete_cmd_prefix": "del ", | |
"library": "var_list.py", | |
"varRefreshCmd": "print(var_dic_list())" | |
}, | |
"r": { | |
"delete_cmd_postfix": ") ", | |
"delete_cmd_prefix": "rm(", | |
"library": "var_list.r", | |
"varRefreshCmd": "cat(var_dic_list()) " | |
} | |
}, | |
"types_to_exclude": [ | |
"module", | |
"function", | |
"builtin_function_or_method", | |
"instance", | |
"_Feature" | |
], | |
"window_display": false | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Hello! You've done a nice job, and I've got a question. Once you finish training this model, how can you predict a single example? It's not working with .predict(example).
I've done this with two AWD_LSTM networks, but in the end I ran into this error while making a prediction:
AttributeError: 'ConcatDataset' object has no attribute 'set_item'
Best regards
Hello! You've done a nice job, and I've got a question. Once you finish training this model, how can you predict a single example? It's not working with .predict(example).
I've done this with two AWD_LSTM networks, but in the end I ran into this error while making a prediction:
AttributeError: 'ConcatDataset' object has no attribute 'set_item'
Best regards
Same problem here
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
make gist public