Counting Example for Language Model
Created April 26, 2018 06:31 by kevinbird15 (gist 4b832ec5c079473188ef3b51f7702b5d)
{ | |
"cells": [ | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
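"source": "import fastai\nfrom fastai import learner\nfrom fastai import dataset\nfrom fastai import model\nfrom pathlib import Path\nfrom fastai.text import *\n\n# fastai.text's star import already provides these; importing them explicitly documents where they come from\nimport collections\nimport random\nfrom collections import Counter\n\nimport pandas as pd\nimport numpy as np\nimport spacy\nimport json\nimport re\nimport html",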
"execution_count": 1, | |
"outputs": [] | |
}, | |
{ | |
"metadata": {}, | |
"cell_type": "markdown", | |
"source": "### Creating Data\nWhen people need to demonstrate a concept without a real dataset, they usually fall back on completely random data. Instead, I'm going to use a counting dataset: each row starts at a random number and counts up through ten consecutive numbers, wrapping around from \"nine\" to \"zero\"."
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "numbers = [\"zero\", \"one\", \"two\", \"three\", \"four\", \"five\", \"six\", \"seven\", \"eight\", \"nine\"]", | |
"execution_count": 2, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "myData = pd.DataFrame()", | |
"execution_count": 3, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "def DataGenerator():\n    \"\"\"Return ten consecutive number words that start at a random digit and wrap from nine to zero.\"\"\"\n    starting_num = random.randint(0, 9)\n    return \" \".join(numbers[(starting_num + i) % 10] for i in range(10))",
"execution_count": 4, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "DataGenerator()", | |
"execution_count": 5, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"execution_count": 5, | |
"data": { | |
"text/plain": "'eight nine zero one two three four five six seven'" | |
}, | |
"metadata": {} | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "# Build all 1000 rows in one go (DataFrame.append was deprecated and then removed in pandas 2.0)\nmyData = pd.DataFrame([DataGenerator() for _ in range(1000)])",
"execution_count": 6, | |
"outputs": [] | |
}, | |
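{
"metadata": {},
"cell_type": "markdown",
"source": "A quick sanity check (an added sketch that assumes `numbers` and `myData` from the cells above): every row should contain ten words that each count up by one, wrapping from nine back to zero."
},
{
"metadata": {},
"cell_type": "code",
"source": [
"# Map each word back to its digit so the counting pattern can be checked numerically\n",
"word_to_digit = {word: digit for digit, word in enumerate(numbers)}\n",
"for row in myData[0]:\n",
"    digits = [word_to_digit[word] for word in row.split()]\n",
"    assert len(digits) == 10, row\n",
"    # consecutive digits must differ by exactly 1 modulo 10, so nine -> zero is allowed\n",
"    assert all((b - a) % 10 == 1 for a, b in zip(digits, digits[1:])), row\n",
"print('all', len(myData), 'rows are valid counting sequences')"
],
"execution_count": null,
"outputs": []
},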
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "tok = Tokenizer()", | |
"execution_count": 7, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "texts = myData[0].astype(str)", | |
"execution_count": 8, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "texts.values[0:10]", | |
"execution_count": 9, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"execution_count": 9, | |
"data": { | |
"text/plain": "array(['zero one two three four five six seven eight nine',\n 'eight nine zero one two three four five six seven',\n 'zero one two three four five six seven eight nine',\n 'eight nine zero one two three four five six seven',\n 'nine zero one two three four five six seven eight',\n 'six seven eight nine zero one two three four five',\n 'two three four five six seven eight nine zero one',\n 'zero one two three four five six seven eight nine',\n 'one two three four five six seven eight nine zero',\n 'six seven eight nine zero one two three four five'], dtype=object)" | |
}, | |
"metadata": {} | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "data = tok.proc_all_mp(partition_by_cores(texts.values.astype(str)))", | |
"execution_count": 10, | |
"outputs": [] | |
}, | |
{ | |
"metadata": {}, | |
"cell_type": "markdown", | |
"source": "Now that we have the data, let's build a frequency map that counts how many times each word was seen. There are 10 number words (zero through nine), and each of the 1000 sequences contains every word exactly once, so every count here is exactly 1000. With real text the counts vary widely, and this step is what lets us filter out words that are seen only a few times."
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "freq = Counter(p for o in data for p in o)", | |
"execution_count": 11, | |
"outputs": [] | |
}, | |
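{
"metadata": {},
"cell_type": "markdown",
"source": "On this dataset every word passes the frequency filter, so the added sketch below uses a small hypothetical toy corpus where the counts differ. It applies the same `most_common(...)` plus count-threshold filter that builds `itos` below, with the threshold lowered from 2 to 1 so the effect is visible on so little data."
},
{
"metadata": {},
"cell_type": "code",
"source": [
"from collections import Counter\n",
"\n",
"# Hypothetical toy corpus: 'zero' and 'ten' appear only once\n",
"toy_data = [['zero', 'one', 'two'], ['one', 'two', 'three'], ['two', 'three', 'ten']]\n",
"toy_freq = Counter(p for o in toy_data for p in o)\n",
"print(toy_freq.most_common())\n",
"# Keep at most 10 words, and only those seen more than once\n",
"toy_vocab = [o for o, c in toy_freq.most_common(10) if c > 1]\n",
"print(toy_vocab)  # ['two', 'one', 'three']"
],
"execution_count": null,
"outputs": []
},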
{ | |
"metadata": {}, | |
"cell_type": "markdown", | |
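"source": "`itos` (\"int to string\") translates an integer index back into its word. It keeps the 10 most common words that were seen more than twice, then adds three special tokens: `_unk_` for unknown words, `_pad_` for padding and `_eos_` to mark the end of each sequence. Because each special token is inserted at position 0, the final order is `_unk_` = 0, `_pad_` = 1, `_eos_` = 2."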
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "itos = [o for o,c in freq.most_common(10) if c > 2]\nfor i in [\"_eos_\", \"_pad_\", \"_unk_\"]:\n itos.insert(0, i)", | |
"execution_count": 12, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "itos", | |
"execution_count": 13, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"execution_count": 13, | |
"data": { | |
"text/plain": "['_unk_',\n '_pad_',\n '_eos_',\n 'zero',\n 'one',\n 'two',\n 'three',\n 'four',\n 'five',\n 'six',\n 'seven',\n 'eight',\n 'nine']" | |
}, | |
"metadata": {} | |
} | |
] | |
}, | |
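{
"metadata": {},
"cell_type": "markdown",
"source": "Inserting each special token at position 0 reverses the order of the loop, which is why `_unk_` ends up first. An equivalent sketch that spells out the final order directly (`itos_alt` is a name introduced only for this comparison):"
},
{
"metadata": {},
"cell_type": "code",
"source": [
"# Special tokens first (_unk_ -> 0, _pad_ -> 1, _eos_ -> 2), then the frequent words\n",
"special_tokens = ['_unk_', '_pad_', '_eos_']\n",
"itos_alt = special_tokens + [o for o, c in freq.most_common(10) if c > 2]\n",
"assert itos_alt == itos\n",
"print(itos_alt[:4])  # ['_unk_', '_pad_', '_eos_', 'zero']"
],
"execution_count": null,
"outputs": []
},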
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "freq.most_common(10)", | |
"execution_count": 14, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"execution_count": 14, | |
"data": { | |
"text/plain": "[('zero', 1000),\n ('one', 1000),\n ('two', 1000),\n ('three', 1000),\n ('four', 1000),\n ('five', 1000),\n ('six', 1000),\n ('seven', 1000),\n ('eight', 1000),\n ('nine', 1000)]" | |
}, | |
"metadata": {} | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "itos", | |
"execution_count": 15, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"execution_count": 15, | |
"data": { | |
"text/plain": "['_unk_',\n '_pad_',\n '_eos_',\n 'zero',\n 'one',\n 'two',\n 'three',\n 'four',\n 'five',\n 'six',\n 'seven',\n 'eight',\n 'nine']" | |
}, | |
"metadata": {} | |
} | |
] | |
}, | |
{ | |
"metadata": {}, | |
"cell_type": "markdown", | |
"source": "`stoi` (\"string to int\") is the reverse translator: it maps each word to its integer index."
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "stoi = collections.defaultdict(lambda:0, {v:k for k,v in enumerate(itos)})\nlen(itos)", | |
"execution_count": 16, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"execution_count": 16, | |
"data": { | |
"text/plain": "13" | |
}, | |
"metadata": {} | |
} | |
] | |
}, | |
{ | |
"metadata": {}, | |
"cell_type": "markdown", | |
"source": "The `lambda:0` default means that looking up a word `stoi` doesn't know returns `0`, which is the index of `_unk_`. Translating that index back with `itos` therefore replaces the unknown word with `_unk_`."
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "unknownNumber = stoi[\"ten\"];print(\"unknownNumber idx: \" + str(unknownNumber))\nknownNumber = stoi[\"nine\"];print(\"knownNumber idx: \" + str(knownNumber))\nprint(itos[unknownNumber])\nprint(itos[knownNumber])", | |
"execution_count": 17, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": "unknownNumber idx: 0\nknownNumber idx: 12\n_unk_\nnine\n", | |
"name": "stdout" | |
} | |
] | |
}, | |
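{
"metadata": {},
"cell_type": "markdown",
"source": "One side effect worth knowing: a `defaultdict` stores the default value for every missing key it is asked about, so the `stoi[\"ten\"]` lookup above also added `ten` to `stoi`. The self-contained sketch below shows this with a small hypothetical vocabulary, and shows that `.get(word, 0)` returns the same index without inserting anything."
},
{
"metadata": {},
"cell_type": "code",
"source": [
"import collections\n",
"\n",
"# Small hypothetical vocabulary, built the same way as stoi above\n",
"demo_itos = ['_unk_', '_pad_', '_eos_', 'zero', 'one']\n",
"demo_stoi = collections.defaultdict(lambda: 0, {v: k for k, v in enumerate(demo_itos)})\n",
"\n",
"print(demo_stoi['ten'])    # 0, the index of _unk_\n",
"print('ten' in demo_stoi)  # True: the lookup inserted 'ten' -> 0\n",
"print(len(demo_stoi))      # 6, one more than len(demo_itos)\n",
"\n",
"# dict.get never calls the default factory, so nothing is inserted\n",
"print(demo_stoi.get('eleven', 0), 'eleven' in demo_stoi)  # 0 False"
],
"execution_count": null,
"outputs": []
},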
{ | |
"metadata": {}, | |
"cell_type": "markdown", | |
"source": "Here I feed every tokenized sequence through `stoi[word]` to turn each word into its integer index. The cell after that appends `2`, the index of `_eos_`, to the end of each sequence."
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "tokenized_data = [[stoi[o] for o in i] for i in data]", | |
"execution_count": 18, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"scrolled": true, | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "# Append the _eos_ index (2) to the end of every sequence, in place\nfor seq in tokenized_data:\n    seq.append(stoi[\"_eos_\"])",
"execution_count": 19, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "texts.values[0:10]", | |
"execution_count": 20, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"execution_count": 20, | |
"data": { | |
"text/plain": "array(['zero one two three four five six seven eight nine',\n 'eight nine zero one two three four five six seven',\n 'zero one two three four five six seven eight nine',\n 'eight nine zero one two three four five six seven',\n 'nine zero one two three four five six seven eight',\n 'six seven eight nine zero one two three four five',\n 'two three four five six seven eight nine zero one',\n 'zero one two three four five six seven eight nine',\n 'one two three four five six seven eight nine zero',\n 'six seven eight nine zero one two three four five'], dtype=object)" | |
}, | |
"metadata": {} | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "tokenized_data[0:10]", | |
"execution_count": 21, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"execution_count": 21, | |
"data": { | |
"text/plain": "[[3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 2],\n [11, 12, 3, 4, 5, 6, 7, 8, 9, 10, 2],\n [3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 2],\n [11, 12, 3, 4, 5, 6, 7, 8, 9, 10, 2],\n [12, 3, 4, 5, 6, 7, 8, 9, 10, 11, 2],\n [9, 10, 11, 12, 3, 4, 5, 6, 7, 8, 2],\n [5, 6, 7, 8, 9, 10, 11, 12, 3, 4, 2],\n [3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 2],\n [4, 5, 6, 7, 8, 9, 10, 11, 12, 3, 2],\n [9, 10, 11, 12, 3, 4, 5, 6, 7, 8, 2]]" | |
}, | |
"metadata": {} | |
} | |
] | |
}, | |
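{
"metadata": {},
"cell_type": "markdown",
"source": "To confirm that numericalization is lossless, the added check below decodes every sequence with `itos` and compares it with its source text. It assumes the spaCy-based `Tokenizer` split each row exactly on spaces, as the first ten rows printed above suggest."
},
{
"metadata": {},
"cell_type": "code",
"source": [
"# Decode every id sequence back to words; each should equal its source text plus _eos_\n",
"for text, ids in zip(texts.values, tokenized_data):\n",
"    decoded = ' '.join(itos[idx] for idx in ids)\n",
"    assert decoded == text + ' _eos_', (text, decoded)\n",
"print(' '.join(itos[idx] for idx in tokenized_data[0]))"
],
"execution_count": null,
"outputs": []
},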
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "PATH = Path(\"data/counterExample/\")", | |
"execution_count": 22, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
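"source": "em_sz, nh, nl = 8, 16, 2  # embedding size, LSTM hidden size, number of LSTM layers\nbptt = 12                 # sequence length for backpropagation through time\nbs = 4                    # batch size\nopt_fn = optim.Adam       # alternative: partial(optim.Adam, betas=(0.8, 0.99))",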
"execution_count": 23, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "combined_tokenized_data = np.concatenate(tokenized_data)", | |
"execution_count": 24, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "dataloader = LanguageModelLoader(combined_tokenized_data, bs, bptt)", | |
"execution_count": 25, | |
"outputs": [] | |
}, | |
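{
"metadata": {},
"cell_type": "markdown",
"source": "Roughly how `LanguageModelLoader` turns the single stream of ids into batches: it cuts the stream into `bs` parallel columns, then reads windows of about `bptt` steps, where the target is the input shifted one token ahead. With 11,000 ids (1000 rows of 10 words plus `_eos_`) and `bs=4`, each column holds 2,750 consecutive ids. The sketch below is a simplified, self-contained NumPy version under stated assumptions: it uses a fixed window length (the fastai loader also varies the length randomly around `bptt`), `np.arange(30)` stands in for `combined_tokenized_data`, and `batchify`/`bptt_windows` are hypothetical helper names."
},
{
"metadata": {},
"cell_type": "code",
"source": [
"import numpy as np\n",
"\n",
"def batchify(stream, bs):\n",
"    # Drop the tail that cannot fill every column, then lay the stream out as bs columns\n",
"    n_steps = len(stream) // bs\n",
"    return np.asarray(stream[:n_steps * bs]).reshape(bs, n_steps).T  # shape (n_steps, bs)\n",
"\n",
"def bptt_windows(batched, bptt):\n",
"    # Yield (input, target) pairs; the target is the input shifted one step ahead\n",
"    for start in range(0, len(batched) - 1, bptt):\n",
"        seq_len = min(bptt, len(batched) - 1 - start)\n",
"        yield batched[start:start + seq_len], batched[start + 1:start + 1 + seq_len]\n",
"\n",
"batched = batchify(np.arange(30), bs=4)\n",
"print(batched.shape)  # (7, 4): 28 ids in 4 columns, the last 2 dropped\n",
"x, y = next(bptt_windows(batched, bptt=3))\n",
"print(x)  # rows 0-2 of batched\n",
"print(y)  # rows 1-3 of batched"
],
"execution_count": null,
"outputs": []
},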
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
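"source": "# LanguageModelData arguments, in order:\n#   PATH, pad_idx (1 is the index of _pad_), number of tokens,\n#   training dataloader, validation dataloader.\n# The validation loader should normally hold different data from the training loader;\n# passing the same loader twice keeps this example short.\nmodeldata = LanguageModelData(PATH, 1, len(itos), dataloader, dataloader)",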
"execution_count": 26, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "drops=np.array([0.25, 0.1, 0.2, 0.02, 0.15])*0.7\nlearner = modeldata.get_model(opt_fn, em_sz, nh, nl,dropouti=drops[0], dropout=drops[1], wdrop=drops[2], dropoute=drops[3], dropouth=drops[4])\nlearner.metrics = [accuracy]\nlearner.unfreeze()", | |
"execution_count": 27, | |
"outputs": [] | |
}, | |
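{
"metadata": {},
"cell_type": "markdown",
"source": "The five values in `drops` are AWD-LSTM style dropout probabilities, each scaled by 0.7: `dropouti` (on the embeddings fed into the first LSTM), `dropout` (on the last LSTM output before the decoder), `wdrop` (weight dropout on the LSTM hidden-to-hidden weights), `dropoute` (drops whole words from the embedding matrix) and `dropouth` (between LSTM layers, which matches the `dropout=0.105` shown on the printed LSTMs). This small added cell prints which value is passed to which keyword."
},
{
"metadata": {},
"cell_type": "code",
"source": [
"# Show which scaled dropout probability is passed to which get_model keyword\n",
"drop_names = ['dropouti', 'dropout', 'wdrop', 'dropoute', 'dropouth']\n",
"for name, p in zip(drop_names, drops):\n",
"    print(f'{name:<9} {p:.3f}')"
],
"execution_count": null,
"outputs": []
},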
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "learner", | |
"execution_count": 28, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"execution_count": 28, | |
"data": { | |
"text/plain": "SequentialRNN(\n (0): RNN_Encoder(\n (encoder): Embedding(13, 8, padding_idx=1)\n (encoder_with_dropout): EmbeddingDropout(\n (embed): Embedding(13, 8, padding_idx=1)\n )\n (rnns): ModuleList(\n (0): WeightDrop(\n (module): LSTM(8, 16, dropout=0.105)\n )\n (1): WeightDrop(\n (module): LSTM(16, 8, dropout=0.105)\n )\n )\n (dropouti): LockedDropout(\n )\n (dropouths): ModuleList(\n (0): LockedDropout(\n )\n (1): LockedDropout(\n )\n )\n )\n (1): LinearDecoder(\n (decoder): Linear(in_features=8, out_features=13)\n (dropout): LockedDropout(\n )\n )\n)" | |
}, | |
"metadata": {} | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "learner.lr_find()", | |
"execution_count": 29, | |
"outputs": [ | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": "HBox(children=(IntProgress(value=0, description='Epoch', max=1), HTML(value='')))", | |
"text/html": "<p>Failed to display Jupyter Widget of type <code>HBox</code>.</p>\n<p>\n If you're reading this message in the Jupyter Notebook or JupyterLab Notebook, it may mean\n that the widgets JavaScript is still loading. If this message persists, it\n likely means that the widgets JavaScript library is either not installed or\n not enabled. See the <a href=\"https://ipywidgets.readthedocs.io/en/stable/user_install.html\">Jupyter\n Widgets Documentation</a> for setup instructions.\n</p>\n<p>\n If you're reading this message in another frontend (for example, a static\n rendering on GitHub or <a href=\"https://nbviewer.jupyter.org/\">NBViewer</a>),\n it may mean that your frontend doesn't currently support widgets.\n</p>\n", | |
"application/vnd.jupyter.widget-view+json": { | |
"version_major": 2, | |
"version_minor": 0, | |
"model_id": "db8411e0e2f8427cb08a2290774dc22d" | |
} | |
}, | |
"metadata": {} | |
}, | |
{ | |
"output_type": "stream", | |
"text": "epoch trn_loss val_loss accuracy \n 0 2.477381 2.56495 0.010456 \n\n", | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "learner.sched.plot()", | |
"execution_count": 30, | |
"outputs": [ | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": "<matplotlib.figure.Figure at 0x7fb6fd698fd0>", | |
"image/png": "\n" | |
}, | |
"metadata": {} | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "lr = 0.1",
"execution_count": 31, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"trusted": true, | |
"scrolled": true | |
}, | |
"cell_type": "code", | |
"source": "learner.fit(lr,10)", | |
"execution_count": 32, | |
"outputs": [ | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": "HBox(children=(IntProgress(value=0, description='Epoch', max=10), HTML(value='')))", | |
"text/html": "<p>Failed to display Jupyter Widget of type <code>HBox</code>.</p>\n<p>\n If you're reading this message in the Jupyter Notebook or JupyterLab Notebook, it may mean\n that the widgets JavaScript is still loading. If this message persists, it\n likely means that the widgets JavaScript library is either not installed or\n not enabled. See the <a href=\"https://ipywidgets.readthedocs.io/en/stable/user_install.html\">Jupyter\n Widgets Documentation</a> for setup instructions.\n</p>\n<p>\n If you're reading this message in another frontend (for example, a static\n rendering on GitHub or <a href=\"https://nbviewer.jupyter.org/\">NBViewer</a>),\n it may mean that your frontend doesn't currently support widgets.\n</p>\n", | |
"application/vnd.jupyter.widget-view+json": { | |
"version_major": 2, | |
"version_minor": 0, | |
"model_id": "1effa6fa05534ce992e2ddfd08d1a6b9" | |
} | |
}, | |
"metadata": {} | |
}, | |
{ | |
"output_type": "stream", | |
"text": "epoch trn_loss val_loss accuracy \n 0 0.877209 0.613849 0.829531 \n 1 0.756472 0.578578 0.826748 \n 2 0.802857 0.575006 0.832028 \n 3 0.831325 0.588339 0.823894 \n 4 0.738141 0.561355 0.825936 \n 5 0.916614 0.610922 0.829739 \n 6 0.945601 0.646737 0.824805 \n 7 0.929306 0.598057 0.826581 \n 8 0.836129 0.579624 0.827866 \n 9 0.76676 0.568329 0.830627 \n\n", | |
"name": "stdout" | |
}, | |
{ | |
"output_type": "execute_result", | |
"execution_count": 32, | |
"data": { | |
"text/plain": "[0.5683291, 0.83062704064344106]" | |
}, | |
"metadata": {} | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "needPrediction = np.array([[5]])\nprobs = learner.model(V(needPrediction))", | |
"execution_count": 33, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "probs[0]", | |
"execution_count": 34, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"execution_count": 34, | |
"data": { | |
"text/plain": "Variable containing:\n\nColumns 0 to 7 \n -9.8337 -10.0553 3.9059 -0.1862 -5.8219 -2.0377 7.0377 3.7314\n\nColumns 8 to 12 \n 0.0088 -7.0428 -0.0959 -2.0290 -7.9542\n[torch.cuda.FloatTensor of size 1x13 (GPU 0)]" | |
}, | |
"metadata": {} | |
} | |
] | |
}, | |
{ | |
"metadata": {}, | |
"cell_type": "markdown", | |
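"source": "These are raw scores (logits), one for each token in `itos`; the input index 5 is \"two\". Softmax exponentiates each score and divides by the sum of the exponentials, which turns the scores into the probability that each token comes next in the sequence."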
}, | |
{ | |
"metadata": { | |
"scrolled": true, | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "# probs[0] holds the raw scores with shape 1x13; softmax over the last dimension turns them into probabilities\nnext_token_probs = to_np(F.softmax(probs[0], dim=-1))[0]\nprint(next_token_probs)",
"execution_count": 35, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": "[ 0. 0. 0.0403 0.00067 0. 0.00011 0.92342 0.03384 0.00082 0. 0.00074 0.00011\n 0. ]\n", | |
"name": "stdout" | |
} | |
] | |
}, | |
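{
"metadata": {},
"cell_type": "markdown",
"source": "To see why softmax is used rather than a plain exponent, here is the same calculation in plain NumPy on the 13 raw scores printed by `probs[0]` above. Exponentiating alone gives values that sum to far more than 1; dividing by their sum (softmax) yields proper probabilities. Subtracting the maximum score first is a standard trick to avoid overflow and does not change the result."
},
{
"metadata": {},
"cell_type": "code",
"source": [
"import numpy as np\n",
"\n",
"# Raw scores copied from the probs[0] output above, in itos order\n",
"logits = np.array([-9.8337, -10.0553, 3.9059, -0.1862, -5.8219, -2.0377,\n",
"                   7.0377, 3.7314, 0.0088, -7.0428, -0.0959, -2.0290, -7.9542])\n",
"\n",
"def softmax(x):\n",
"    # Subtracting the max avoids overflow; it cancels out in the division\n",
"    e = np.exp(x - x.max())\n",
"    return e / e.sum()\n",
"\n",
"p = softmax(logits)\n",
"print(np.exp(logits).sum())  # far above 1, so exp() alone is not a probability\n",
"print(p.sum())               # 1.0 up to floating-point error\n",
"print(p.argmax(), p.max())   # index 6 (three), about 0.923"
],
"execution_count": null,
"outputs": []
},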
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
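"source": "# Predict the next token for indices 2-11 (_eos_ through eight); range(2, len(itos)) would also cover nine\nfor i in range(2, 12):\n    needPrediction = np.array([[i]])\n    probs = learner.model(V(needPrediction))\n    print(itos[i] + \"---->\" + itos[to_np(probs[0][-1].exp()).argmax()])",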
"execution_count": 36, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": "_eos_---->zero\nzero---->one\none---->two\ntwo---->three\nthree---->four\nfour---->five\nfive---->six\nsix---->seven\nseven---->eight\neight---->nine\n", | |
"name": "stdout" | |
} | |
] | |
}, | |
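{
"metadata": {},
"cell_type": "markdown",
"source": "The loop above feeds one token per call, but the LSTM hidden state carries over between calls, so each prediction also depends on the tokens fed in before it. The added sketch below uses that on purpose for greedy generation: start from one word, take the most likely next token, and feed it back in. It assumes the fastai 0.7 objects already used in this notebook (`learner`, `V`, `to_np`, `stoi`, `itos`); the start word and the 10 steps are arbitrary, and the `reset()` call is guarded with `hasattr` because its availability on the model is an assumption."
},
{
"metadata": {},
"cell_type": "code",
"source": [
"# Greedy generation: feed each predicted token back in as the next input\n",
"learner.model.eval()  # make sure dropout is off for inference\n",
"if hasattr(learner.model, 'reset'):\n",
"    learner.model.reset()  # assumed to clear the carried-over hidden state\n",
"token = stoi['five']\n",
"generated = [itos[token]]\n",
"for _ in range(10):\n",
"    out = learner.model(V(np.array([[token]])))\n",
"    # argmax of the raw scores equals argmax of the softmax probabilities\n",
"    token = int(to_np(out[0][-1]).argmax())\n",
"    generated.append(itos[token])\n",
"print(' '.join(generated))"
],
"execution_count": null,
"outputs": []
},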
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "learner.fit(lr,10)", | |
"execution_count": 37, | |
"outputs": [ | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": "HBox(children=(IntProgress(value=0, description='Epoch', max=10), HTML(value='')))", | |
"text/html": "<p>Failed to display Jupyter Widget of type <code>HBox</code>.</p>\n<p>\n If you're reading this message in the Jupyter Notebook or JupyterLab Notebook, it may mean\n that the widgets JavaScript is still loading. If this message persists, it\n likely means that the widgets JavaScript library is either not installed or\n not enabled. See the <a href=\"https://ipywidgets.readthedocs.io/en/stable/user_install.html\">Jupyter\n Widgets Documentation</a> for setup instructions.\n</p>\n<p>\n If you're reading this message in another frontend (for example, a static\n rendering on GitHub or <a href=\"https://nbviewer.jupyter.org/\">NBViewer</a>),\n it may mean that your frontend doesn't currently support widgets.\n</p>\n", | |
"application/vnd.jupyter.widget-view+json": { | |
"version_major": 2, | |
"version_minor": 0, | |
"model_id": "de1f263bda984394bac265902aa495c0" | |
} | |
}, | |
"metadata": {} | |
}, | |
{ | |
"output_type": "stream", | |
"text": "epoch trn_loss val_loss accuracy \n 0 0.771903 0.553757 0.827269 \n 1 0.806123 0.547836 0.829683 \n 2 0.791141 0.569081 0.82243 \n 3 0.804629 0.550345 0.825076 \n 4 0.763257 0.566674 0.827552 \n 5 0.754086 0.561353 0.825966 \n 6 1.055096 0.658756 0.818877 \n 7 1.028009 0.71496 0.813719 \n 8 0.927982 0.626914 0.820577 \n 9 0.781758 0.615392 0.827069 \n\n", | |
"name": "stdout" | |
}, | |
{ | |
"output_type": "execute_result", | |
"execution_count": 37, | |
"data": { | |
"text/plain": "[0.61539203, 0.82706903757756212]" | |
}, | |
"metadata": {} | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "needPrediction = np.array([[5]])  # index 5 is \"two\"\nprobs = learner.model(V(needPrediction))\nnext_token_probs = to_np(F.softmax(probs[0], dim=-1))[0]\nprint(next_token_probs)",
"execution_count": 38, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": "[ 0. 0. 0.02286 0.00023 0. 0.00002 0.96559 0.01114 0.00006 0.00007 0. 0.00001\n 0. ]\n", | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "learner.fit(lr,10)", | |
"execution_count": 39, | |
"outputs": [ | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": "HBox(children=(IntProgress(value=0, description='Epoch', max=10), HTML(value='')))", | |
"text/html": "<p>Failed to display Jupyter Widget of type <code>HBox</code>.</p>\n<p>\n If you're reading this message in the Jupyter Notebook or JupyterLab Notebook, it may mean\n that the widgets JavaScript is still loading. If this message persists, it\n likely means that the widgets JavaScript library is either not installed or\n not enabled. See the <a href=\"https://ipywidgets.readthedocs.io/en/stable/user_install.html\">Jupyter\n Widgets Documentation</a> for setup instructions.\n</p>\n<p>\n If you're reading this message in another frontend (for example, a static\n rendering on GitHub or <a href=\"https://nbviewer.jupyter.org/\">NBViewer</a>),\n it may mean that your frontend doesn't currently support widgets.\n</p>\n", | |
"application/vnd.jupyter.widget-view+json": { | |
"version_major": 2, | |
"version_minor": 0, | |
"model_id": "67d3b22fcc30484388c284e4b051651b" | |
} | |
}, | |
"metadata": {} | |
}, | |
{ | |
"output_type": "stream", | |
"text": "epoch trn_loss val_loss accuracy \n 0 0.842481 0.564832 0.823317 \n 1 0.816763 0.565989 0.829835 \n 2 0.828612 0.570158 0.827601 \n 3 0.841933 0.586223 0.826373 \n 4 0.818775 0.55845 0.834544 \n 5 0.881451 0.597358 0.827557 \n 6 0.792335 0.569497 0.826181 \n 7 0.878065 0.545774 0.828824 \n 8 0.826214 0.550325 0.833591 \n 9 0.779982 0.557214 0.822997 \n\n", | |
"name": "stdout" | |
}, | |
{ | |
"output_type": "execute_result", | |
"execution_count": 39, | |
"data": { | |
"text/plain": "[0.55721396, 0.82299670733903585]" | |
}, | |
"metadata": {} | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "needPrediction = np.array([[5]])  # index 5 is \"two\"\nprobs = learner.model(V(needPrediction))\nnext_token_probs = to_np(F.softmax(probs[0], dim=-1))[0]\nprint(next_token_probs)",
"execution_count": 40, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": "[ 0. 0. 0.02088 0.00024 0.00007 0.00195 0.95327 0.01969 0.00011 0.00009 0.00001 0.00351\n 0.00017]\n", | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
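"source": "# Predict the next token for indices 2-11 (_eos_ through eight); range(2, len(itos)) would also cover nine\nfor i in range(2, 12):\n    needPrediction = np.array([[i]])\n    probs = learner.model(V(needPrediction))\n    print(itos[i] + \"---->\" + itos[to_np(probs[0][-1].exp()).argmax()])",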
"execution_count": 41, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": "_eos_---->four\nzero---->one\none---->two\ntwo---->three\nthree---->four\nfour---->five\nfive---->six\nsix---->seven\nseven---->eight\neight---->nine\n", | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "", | |
"execution_count": null, | |
"outputs": [] | |
} | |
], | |
"metadata": { | |
"_draft": { | |
"nbviewer_url": "https://gist.github.com/dda2ffa06ede42fd0562b33b02d088e7" | |
}, | |
"gist": { | |
"id": "dda2ffa06ede42fd0562b33b02d088e7", | |
"data": { | |
"description": "Counting Example for Language Model", | |
"public": true | |
} | |
}, | |
"kernelspec": { | |
"name": "python3", | |
"display_name": "Python 3", | |
"language": "python" | |
}, | |
"language_info": { | |
"name": "python", | |
"version": "3.6.3", | |
"mimetype": "text/x-python", | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"pygments_lexer": "ipython3", | |
"nbconvert_exporter": "python", | |
"file_extension": ".py" | |
}, | |
"widgets": { | |
"application/vnd.jupyter.widget-state+json": { | |
"state": { | |
"04dba83d703c4c97973b8da5ef017ead": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_module_version": "1.1.0", | |
"model_name": "IntProgressModel", | |
"state": { | |
"bar_style": "success", | |
"description": "Epoch", | |
"layout": "IPY_MODEL_1108f27619c04d0c9d1d9d4426a6d44a", | |
"max": 10, | |
"style": "IPY_MODEL_e10ddd3114834c63afeea0c9bfee32a1", | |
"value": 10 | |
} | |
}, | |
"1108f27619c04d0c9d1d9d4426a6d44a": { | |
"model_module": "@jupyter-widgets/base", | |
"model_module_version": "1.0.0", | |
"model_name": "LayoutModel", | |
"state": {} | |
}, | |
"149c1ac37c9b499f842f63d178d2bbc0": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_module_version": "1.1.0", | |
"model_name": "HTMLModel", | |
"state": { | |
"layout": "IPY_MODEL_bfc4471dab35446fbd062cee50bc2847", | |
"style": "IPY_MODEL_2c78bfc16b424345822b8d9ebc5ac7c8", | |
"value": "100% 10/10 [00:05<00:00, 1.81it/s]" | |
} | |
}, | |
"171a3d23566a4916a2d6b38cba073108": { | |
"model_module": "@jupyter-widgets/base", | |
"model_module_version": "1.0.0", | |
"model_name": "LayoutModel", | |
"state": {} | |
}, | |
"17fab208c1f24c2da68669be65607e02": { | |
"model_module": "@jupyter-widgets/base", | |
"model_module_version": "1.0.0", | |
"model_name": "LayoutModel", | |
"state": {} | |
}, | |
"1a18ef3b0cf44c44b1b1c8fd5b3d2d76": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_module_version": "1.1.0", | |
"model_name": "HBoxModel", | |
"state": { | |
"children": [ | |
"IPY_MODEL_04dba83d703c4c97973b8da5ef017ead", | |
"IPY_MODEL_f8e46f870e7e4adf8082d957777c08f0" | |
], | |
"layout": "IPY_MODEL_eab5cdadf3dc446aaf8c51f5a5cea696" | |
} | |
}, | |
"1d883cb12ddb470cab085f26452a0799": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_module_version": "1.1.0", | |
"model_name": "HTMLModel", | |
"state": { | |
"layout": "IPY_MODEL_bba46c77a54442e89a600703618806ed", | |
"style": "IPY_MODEL_3850389523bd4061a875a4f37bcff987", | |
"value": " 0% 0/1 [00:00<?, ?it/s]" | |
} | |
}, | |
"2817a1e93ee84813a93fe3511ed5d7f7": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_module_version": "1.1.0", | |
"model_name": "DescriptionStyleModel", | |
"state": { | |
"description_width": "" | |
} | |
}, | |
"2c78bfc16b424345822b8d9ebc5ac7c8": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_module_version": "1.1.0", | |
"model_name": "DescriptionStyleModel", | |
"state": { | |
"description_width": "" | |
} | |
}, | |
"32ba039760b346578df3f09c14fe73ea": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_module_version": "1.1.0", | |
"model_name": "HTMLModel", | |
"state": { | |
"layout": "IPY_MODEL_6b473e701eca4822bb8b8b99db3c4895", | |
"style": "IPY_MODEL_c083769380ef4a46a141905cce1186e6", | |
"value": "100% 15/15 [00:08<00:00, 1.75it/s]" | |
} | |
}, | |
"3850389523bd4061a875a4f37bcff987": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_module_version": "1.1.0", | |
"model_name": "DescriptionStyleModel", | |
"state": { | |
"description_width": "" | |
} | |
}, | |
"3b46fb458f9b4dda8d895c773bc729ee": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_module_version": "1.1.0", | |
"model_name": "ProgressStyleModel", | |
"state": { | |
"description_width": "" | |
} | |
}, | |
"3f9fcf7d19264eb5b3f30e1e538ff6d8": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_module_version": "1.1.0", | |
"model_name": "ProgressStyleModel", | |
"state": { | |
"description_width": "" | |
} | |
}, | |
"46f844869b6847c29efc5d499d5d6acb": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_module_version": "1.1.0", | |
"model_name": "HBoxModel", | |
"state": { | |
"children": [ | |
"IPY_MODEL_6b051936f29a4a14b6ab740d52c2f6e5", | |
"IPY_MODEL_d8f90f2fb1bc452688b37f06cc9bcb4e" | |
], | |
"layout": "IPY_MODEL_d1273af340b74f5ba133d5bb64969035" | |
} | |
}, | |
"49e99574d3464b559337da997c06b39b": { | |
"model_module": "@jupyter-widgets/base", | |
"model_module_version": "1.0.0", | |
"model_name": "LayoutModel", | |
"state": {} | |
}, | |
"4fb7b8299d9043c0a8d3b192de40b57d": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_module_version": "1.1.0", | |
"model_name": "IntProgressModel", | |
"state": { | |
"bar_style": "success", | |
"description": "Epoch", | |
"layout": "IPY_MODEL_5b72faa4fe8c45e3a3d22c70a6c5243f", | |
"max": 10, | |
"style": "IPY_MODEL_a36ebd91b6c848f3bddb4e3d217d6d53", | |
"value": 10 | |
} | |
}, | |
"5871c1f5cc7b457cbaa797c3efdfe905": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_module_version": "1.1.0", | |
"model_name": "IntProgressModel", | |
"state": { | |
"bar_style": "success", | |
"description": "Epoch", | |
"layout": "IPY_MODEL_866a435f08b542e5b68b6a58a71584d0", | |
"max": 15, | |
"style": "IPY_MODEL_3b46fb458f9b4dda8d895c773bc729ee", | |
"value": 15 | |
} | |
}, | |
"5b72faa4fe8c45e3a3d22c70a6c5243f": { | |
"model_module": "@jupyter-widgets/base", | |
"model_module_version": "1.0.0", | |
"model_name": "LayoutModel", | |
"state": {} | |
}, | |
"63871322b0d2493785a64b1cfaab2659": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_module_version": "1.1.0", | |
"model_name": "HBoxModel", | |
"state": { | |
"children": [ | |
"IPY_MODEL_4fb7b8299d9043c0a8d3b192de40b57d", | |
"IPY_MODEL_149c1ac37c9b499f842f63d178d2bbc0" | |
], | |
"layout": "IPY_MODEL_e18888eafcc74d02ad321836ad293ca5" | |
} | |
}, | |
"65dae5ebd0c742d7a305399423c390ec": { | |
"model_module": "@jupyter-widgets/base", | |
"model_module_version": "1.0.0", | |
"model_name": "LayoutModel", | |
"state": {} | |
}, | |
"6b051936f29a4a14b6ab740d52c2f6e5": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_module_version": "1.1.0", | |
"model_name": "IntProgressModel", | |
"state": { | |
"bar_style": "success", | |
"description": "Epoch", | |
"layout": "IPY_MODEL_171a3d23566a4916a2d6b38cba073108", | |
"max": 10, | |
"style": "IPY_MODEL_a5a133ee051e49479606ac4220b033ed", | |
"value": 10 | |
} | |
}, | |
"6b473e701eca4822bb8b8b99db3c4895": { | |
"model_module": "@jupyter-widgets/base", | |
"model_module_version": "1.0.0", | |
"model_name": "LayoutModel", | |
"state": {} | |
}, | |
"6f45fc011e0b41ed9e0947ed109198d3": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_module_version": "1.1.0", | |
"model_name": "HBoxModel", | |
"state": { | |
"children": [ | |
"IPY_MODEL_fb6d77fb2c5b4b5e9c9e870ef85951c5", | |
"IPY_MODEL_c53ab2d1e299414b8d0f099d80d0d802" | |
], | |
"layout": "IPY_MODEL_65dae5ebd0c742d7a305399423c390ec" | |
} | |
}, | |
"71988522634f4cf2bb83ac497f944924": { | |
"model_module": "@jupyter-widgets/base", | |
"model_module_version": "1.0.0", | |
"model_name": "LayoutModel", | |
"state": {} | |
}, | |
"739576ba78a4474194ebe2d056a02646": { | |
"model_module": "@jupyter-widgets/base", | |
"model_module_version": "1.0.0", | |
"model_name": "LayoutModel", | |
"state": {} | |
}, | |
"7892d81a5cef498abbea3e73bd73a9ff": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_module_version": "1.1.0", | |
"model_name": "HBoxModel", | |
"state": { | |
"children": [ | |
"IPY_MODEL_c2f9d01f36b9435fad8f80abe783be47", | |
"IPY_MODEL_1d883cb12ddb470cab085f26452a0799" | |
], | |
"layout": "IPY_MODEL_919dad157b264202925bb3c01ae0c57d" | |
} | |
}, | |
"84091e10212a43b09db81bc8a97c92d5": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_module_version": "1.1.0", | |
"model_name": "DescriptionStyleModel", | |
"state": { | |
"description_width": "" | |
} | |
}, | |
"866a435f08b542e5b68b6a58a71584d0": { | |
"model_module": "@jupyter-widgets/base", | |
"model_module_version": "1.0.0", | |
"model_name": "LayoutModel", | |
"state": {} | |
}, | |
"919dad157b264202925bb3c01ae0c57d": { | |
"model_module": "@jupyter-widgets/base", | |
"model_module_version": "1.0.0", | |
"model_name": "LayoutModel", | |
"state": {} | |
}, | |
"959dbb9efd6941ca8bdf548f8de88d7c": { | |
"model_module": "@jupyter-widgets/base", | |
"model_module_version": "1.0.0", | |
"model_name": "LayoutModel", | |
"state": {} | |
}, | |
"a36ebd91b6c848f3bddb4e3d217d6d53": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_module_version": "1.1.0", | |
"model_name": "ProgressStyleModel", | |
"state": { | |
"description_width": "" | |
} | |
}, | |
"a5a133ee051e49479606ac4220b033ed": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_module_version": "1.1.0", | |
"model_name": "ProgressStyleModel", | |
"state": { | |
"description_width": "" | |
} | |
}, | |
"a9a611129ff6427d9b5cbc5c40de62bd": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_module_version": "1.1.0", | |
"model_name": "DescriptionStyleModel", | |
"state": { | |
"description_width": "" | |
} | |
}, | |
"bba46c77a54442e89a600703618806ed": { | |
"model_module": "@jupyter-widgets/base", | |
"model_module_version": "1.0.0", | |
"model_name": "LayoutModel", | |
"state": {} | |
}, | |
"bfc4471dab35446fbd062cee50bc2847": { | |
"model_module": "@jupyter-widgets/base", | |
"model_module_version": "1.0.0", | |
"model_name": "LayoutModel", | |
"state": {} | |
}, | |
"c083769380ef4a46a141905cce1186e6": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_module_version": "1.1.0", | |
"model_name": "DescriptionStyleModel", | |
"state": { | |
"description_width": "" | |
} | |
}, | |
"c2f9d01f36b9435fad8f80abe783be47": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_module_version": "1.1.0", | |
"model_name": "IntProgressModel", | |
"state": { | |
"bar_style": "danger", | |
"description": "Epoch", | |
"layout": "IPY_MODEL_959dbb9efd6941ca8bdf548f8de88d7c", | |
"max": 1, | |
"style": "IPY_MODEL_3f9fcf7d19264eb5b3f30e1e538ff6d8" | |
} | |
}, | |
"c53ab2d1e299414b8d0f099d80d0d802": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_module_version": "1.1.0", | |
"model_name": "HTMLModel", | |
"state": { | |
"layout": "IPY_MODEL_d966057d73304877be1afb753f8f87a0", | |
"style": "IPY_MODEL_2817a1e93ee84813a93fe3511ed5d7f7", | |
"value": "100% 100/100 [00:56<00:00, 1.78it/s]" | |
} | |
}, | |
"d1273af340b74f5ba133d5bb64969035": { | |
"model_module": "@jupyter-widgets/base", | |
"model_module_version": "1.0.0", | |
"model_name": "LayoutModel", | |
"state": {} | |
}, | |
"d2875ddf016045d2837f2083faa77fa3": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_module_version": "1.1.0", | |
"model_name": "HBoxModel", | |
"state": { | |
"children": [ | |
"IPY_MODEL_5871c1f5cc7b457cbaa797c3efdfe905", | |
"IPY_MODEL_32ba039760b346578df3f09c14fe73ea" | |
], | |
"layout": "IPY_MODEL_17fab208c1f24c2da68669be65607e02" | |
} | |
}, | |
"d8f90f2fb1bc452688b37f06cc9bcb4e": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_module_version": "1.1.0", | |
"model_name": "HTMLModel", | |
"state": { | |
"layout": "IPY_MODEL_49e99574d3464b559337da997c06b39b", | |
"style": "IPY_MODEL_a9a611129ff6427d9b5cbc5c40de62bd", | |
"value": "100% 10/10 [00:05<00:00, 1.77it/s]" | |
} | |
}, | |
"d966057d73304877be1afb753f8f87a0": { | |
"model_module": "@jupyter-widgets/base", | |
"model_module_version": "1.0.0", | |
"model_name": "LayoutModel", | |
"state": {} | |
}, | |
"e10ddd3114834c63afeea0c9bfee32a1": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_module_version": "1.1.0", | |
"model_name": "ProgressStyleModel", | |
"state": { | |
"description_width": "" | |
} | |
}, | |
"e18888eafcc74d02ad321836ad293ca5": { | |
"model_module": "@jupyter-widgets/base", | |
"model_module_version": "1.0.0", | |
"model_name": "LayoutModel", | |
"state": {} | |
}, | |
"eab5cdadf3dc446aaf8c51f5a5cea696": { | |
"model_module": "@jupyter-widgets/base", | |
"model_module_version": "1.0.0", | |
"model_name": "LayoutModel", | |
"state": {} | |
}, | |
"f56541c7d04f4a6b9819665e2004fb64": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_module_version": "1.1.0", | |
"model_name": "ProgressStyleModel", | |
"state": { | |
"description_width": "" | |
} | |
}, | |
"f8e46f870e7e4adf8082d957777c08f0": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_module_version": "1.1.0", | |
"model_name": "HTMLModel", | |
"state": { | |
"layout": "IPY_MODEL_739576ba78a4474194ebe2d056a02646", | |
"style": "IPY_MODEL_84091e10212a43b09db81bc8a97c92d5", | |
"value": "100% 10/10 [00:05<00:00, 1.79it/s]" | |
} | |
}, | |
"fb6d77fb2c5b4b5e9c9e870ef85951c5": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_module_version": "1.1.0", | |
"model_name": "IntProgressModel", | |
"state": { | |
"bar_style": "success", | |
"description": "Epoch", | |
"layout": "IPY_MODEL_71988522634f4cf2bb83ac497f944924", | |
"style": "IPY_MODEL_f56541c7d04f4a6b9819665e2004fb64", | |
"value": 100 | |
} | |
} | |
}, | |
"version_major": 2, | |
"version_minor": 0 | |
} | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |