Last active
May 24, 2017 20:47
-
-
Save niazangels/841b46dfa429c3a8b4bc56d8a0a734f2 to your computer and use it in GitHub Desktop.
3-char RNN model mostly always predicts a space
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"deletable": true, | |
"editable": true | |
}, | |
"source": [ | |
"# Lesson 6" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"collapsed": false, | |
"deletable": true, | |
"editable": true | |
}, | |
"outputs": [], | |
"source": [ | |
"from theano.sandbox import cuda\n", | |
"cuda.use('gpu0')" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": { | |
"collapsed": false, | |
"deletable": true, | |
"editable": true | |
}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"Using Theano backend.\n" | |
] | |
} | |
], | |
"source": [ | |
"%matplotlib inline\n", | |
"import utils; reload(utils)\n", | |
"from utils import *\n", | |
"from __future__ import division, print_function" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 69, | |
"metadata": { | |
"collapsed": false, | |
"deletable": true, | |
"editable": true | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"corpus length: 600901\n" | |
] | |
} | |
], | |
"source": [ | |
"path = get_file('nietzsche.txt', origin=\"https://s3.amazonaws.com/text-datasets/nietzsche.txt\")\n", | |
"# path = get_file('sherlock.txt', origin=\"https://sherlock-holm.es/stories/plain-text/cano.txt\")\n", | |
"\n", | |
"text = open(path).read()\n", | |
"\n", | |
"print('corpus length:', len(text))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 70, | |
"metadata": { | |
"collapsed": false, | |
"deletable": true, | |
"editable": true | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"total chars: 86\n" | |
] | |
} | |
], | |
"source": [ | |
"chars = sorted(list(set(text)))\n", | |
"vocab_size = len(chars)+1\n", | |
"print('total chars:', vocab_size)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 71, | |
"metadata": { | |
"collapsed": true, | |
"deletable": true, | |
"editable": true | |
}, | |
"outputs": [], | |
"source": [ | |
"chars.insert(0, \"\\0\")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 72, | |
"metadata": { | |
"collapsed": false, | |
"deletable": true, | |
"editable": true | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"'\\n !\"\\'(),-.0123456789:;=?ABCDEFGHIJKLMNOPQRSTUVWXYZ[]_abcdefghijklmnopqrstuvwxyz'" | |
] | |
}, | |
"execution_count": 72, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"''.join(chars[1:-6])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 73, | |
"metadata": { | |
"collapsed": true, | |
"deletable": true, | |
"editable": true | |
}, | |
"outputs": [], | |
"source": [ | |
"char_indices = dict((c, i) for i, c in enumerate(chars))\n", | |
"indices_char = dict((i, c) for i, c in enumerate(chars))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 74, | |
"metadata": { | |
"collapsed": true, | |
"deletable": true, | |
"editable": true | |
}, | |
"outputs": [], | |
"source": [ | |
"idx = [char_indices[c] for c in text]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 75, | |
"metadata": { | |
"collapsed": false, | |
"deletable": true, | |
"editable": true | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"[40, 42, 29, 30, 25, 27, 29, 1, 1, 1, 43, 45, 40, 40, 39, 43, 33, 38, 31, 2, 73, 61, 54, 73, 2, 44, 71, 74, 73, 61, 2, 62, 72, 2, 54, 2, 76, 68, 66, 54, 67, 9, 9, 76, 61, 54, 73, 2, 73, 61, 58, 67, 24, 2, 33, 72, 2, 73, 61, 58, 71, 58, 2, 67, 68, 73, 2, 60, 71, 68, 74, 67, 57, 1, 59, 68, 71, 2, 72, 74, 72, 69, 58, 56, 73, 62, 67, 60, 2, 73, 61, 54, 73, 2, 54, 65, 65, 2, 69, 61, 62, 65, 68, 72, 68, 69, 61, 58, 71, 72, 8, 2, 62, 67, 2, 72, 68, 2, 59, 54, 71, 2, 54, 72, 2, 73, 61, 58, 78, 2, 61, 54, 75, 58, 2, 55, 58, 58, 67, 1, 57, 68, 60, 66, 54, 73, 62, 72, 73, 72, 8, 2, 61, 54, 75, 58, 2, 59, 54, 62, 65, 58, 57, 2, 73, 68, 2, 74, 67, 57, 58, 71, 72, 73, 54, 67, 57, 2, 76, 68, 66, 58, 67, 9, 9, 73, 61, 54, 73, 2, 73, 61, 58, 2, 73, 58, 71, 71, 62, 55]\n" | |
] | |
} | |
], | |
"source": [ | |
"print (idx[:200])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 76, | |
"metadata": { | |
"collapsed": false, | |
"deletable": true, | |
"editable": true | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"'PREFACE\\n\\n\\nSUPPOSING that Truth is a woman--what then? Is there not ground\\nfor suspecting that all philosophers, in so far as they have been\\ndogmatists, have failed to understand women--that the terrib'" | |
] | |
}, | |
"execution_count": 76, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"text[:200]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 77, | |
"metadata": { | |
"collapsed": false, | |
"deletable": true, | |
"editable": true | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"PREFACE\n", | |
"\n", | |
"\n", | |
"SUPPOSING that Truth is a woman--what then? Is there not ground\n", | |
"for suspecting that all philosophers, in so far as they have been\n", | |
"dogmatists, have failed to understand women--that the terrib\n" | |
] | |
} | |
], | |
"source": [ | |
"print (text[:200])" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"deletable": true, | |
"editable": true | |
}, | |
"source": [ | |
"# 3 Char Model" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 78, | |
"metadata": { | |
"collapsed": false, | |
"deletable": true, | |
"editable": true | |
}, | |
"outputs": [], | |
"source": [ | |
"skip = 3\n", | |
"\n", | |
"c1_data = [idx[i+0] for i in range(0, len(idx)-1-skip, skip)]\n", | |
"c2_data = [idx[i+1] for i in range(0, len(idx)-1-skip, skip)]\n", | |
"c3_data = [idx[i+2] for i in range(0, len(idx)-1-skip, skip)]\n", | |
"c4_data = [idx[i+3] for i in range(0, len(idx)-1-skip, skip)]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 79, | |
"metadata": { | |
"collapsed": true, | |
"deletable": true, | |
"editable": true | |
}, | |
"outputs": [], | |
"source": [ | |
"x1 = np.stack(c1_data)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 80, | |
"metadata": { | |
"collapsed": false, | |
"deletable": true, | |
"editable": true | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"(200299,)" | |
] | |
}, | |
"execution_count": 80, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"x1.shape" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"deletable": true, | |
"editable": true | |
}, | |
"source": [ | |
"<br>Our inputs" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 81, | |
"metadata": { | |
"collapsed": true, | |
"deletable": true, | |
"editable": true | |
}, | |
"outputs": [], | |
"source": [ | |
"x1 = np.stack(c1_data)\n", | |
"x2 = np.stack(c2_data)\n", | |
"x3 = np.stack(c3_data)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"deletable": true, | |
"editable": true | |
}, | |
"source": [ | |
"And output" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 82, | |
"metadata": { | |
"collapsed": true, | |
"deletable": true, | |
"editable": true | |
}, | |
"outputs": [], | |
"source": [ | |
"y = np.stack(c4_data)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"deletable": true, | |
"editable": true | |
}, | |
"source": [ | |
"Let's check them out" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 83, | |
"metadata": { | |
"collapsed": false, | |
"deletable": true, | |
"editable": true | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"(array([40, 30, 29, 1]), array([42, 25, 1, 43]), array([29, 27, 1, 45]))" | |
] | |
}, | |
"execution_count": 83, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"x1[:4], x2[:4], x3[:4]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 84, | |
"metadata": { | |
"collapsed": false, | |
"deletable": true, | |
"editable": true | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"array([30, 29, 1, 40])" | |
] | |
}, | |
"execution_count": 84, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"y[:4]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 85, | |
"metadata": { | |
"collapsed": false, | |
"deletable": true, | |
"editable": true | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"((200299,), (200299,), (200299,), (200299,))" | |
] | |
}, | |
"execution_count": 85, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"x1.shape, x2.shape, x3.shape, y.shape" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 86, | |
"metadata": { | |
"collapsed": true, | |
"deletable": true, | |
"editable": true | |
}, | |
"outputs": [], | |
"source": [ | |
"n_fac = 42" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 87, | |
"metadata": { | |
"collapsed": true, | |
"deletable": true, | |
"editable": true | |
}, | |
"outputs": [], | |
"source": [ | |
"def embedding_input(name, n_in, n_out):\n", | |
" inp = Input(shape=(1,), dtype='int64', name=name)\n", | |
" emb = Embedding(n_in, n_out, input_length=1)(inp)\n", | |
" return inp, Flatten()(emb)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 88, | |
"metadata": { | |
"collapsed": true, | |
"deletable": true, | |
"editable": true | |
}, | |
"outputs": [], | |
"source": [ | |
"c1_in, c1 = embedding_input('c1', vocab_size, n_fac)\n", | |
"c2_in, c2 = embedding_input('c2', vocab_size, n_fac)\n", | |
"c3_in, c3 = embedding_input('c3', vocab_size, n_fac)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"deletable": true, | |
"editable": true | |
}, | |
"source": [ | |
"### Model" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 89, | |
"metadata": { | |
"collapsed": true, | |
"deletable": true, | |
"editable": true | |
}, | |
"outputs": [], | |
"source": [ | |
"n_hidden = 256" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 90, | |
"metadata": { | |
"collapsed": true, | |
"deletable": true, | |
"editable": true | |
}, | |
"outputs": [], | |
"source": [ | |
"dense_in = Dense(n_hidden, activation='relu')" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 91, | |
"metadata": { | |
"collapsed": true, | |
"deletable": true, | |
"editable": true | |
}, | |
"outputs": [], | |
"source": [ | |
"dense_hidden = Dense(n_hidden, activation='tanh')" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 92, | |
"metadata": { | |
"collapsed": true, | |
"deletable": true, | |
"editable": true | |
}, | |
"outputs": [], | |
"source": [ | |
"dense_out = Dense(vocab_size, activation='softmax')" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 93, | |
"metadata": { | |
"collapsed": false, | |
"deletable": true, | |
"editable": true | |
}, | |
"outputs": [], | |
"source": [ | |
"c1_hidden = dense_in(c1)\n", | |
"hidden_2 = dense_hidden(c1_hidden)\n", | |
"\n", | |
"c2_dense = dense_in(c2)\n", | |
"c2_hidden = merge([c2_dense, hidden_2])\n", | |
"hidden_3 = dense_hidden(c2_hidden)\n", | |
"\n", | |
"c3_dense = dense_in(c3)\n", | |
"c3_hidden = merge([c3_dense, hidden_2])\n", | |
"\n", | |
"c4_out = dense_out(c3_hidden)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 94, | |
"metadata": { | |
"collapsed": true, | |
"deletable": true, | |
"editable": true | |
}, | |
"outputs": [], | |
"source": [ | |
"model = Model([c1_in, c2_in, c3_in], c4_out)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 95, | |
"metadata": { | |
"collapsed": true, | |
"deletable": true, | |
"editable": true | |
}, | |
"outputs": [], | |
"source": [ | |
"model.compile(loss='sparse_categorical_crossentropy', optimizer=Adam())" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 96, | |
"metadata": { | |
"collapsed": true, | |
"deletable": true, | |
"editable": true | |
}, | |
"outputs": [], | |
"source": [ | |
"model.optimizer.lr=0.000001" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 97, | |
"metadata": { | |
"collapsed": false, | |
"deletable": true, | |
"editable": true | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Epoch 1/4\n", | |
"23s - loss: 4.4349\n", | |
"Epoch 2/4\n", | |
"31s - loss: 4.3801\n", | |
"Epoch 3/4\n", | |
"29s - loss: 4.2960\n", | |
"Epoch 4/4\n", | |
"30s - loss: 4.1677\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/plain": [ | |
"<keras.callbacks.History at 0x7f6e459f6610>" | |
] | |
}, | |
"execution_count": 97, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"model.fit([x1, x2, x3], y, batch_size=64, nb_epoch=4, verbose=2)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 98, | |
"metadata": { | |
"collapsed": false, | |
"deletable": true, | |
"editable": true | |
}, | |
"outputs": [], | |
"source": [ | |
"def get_next(inp):\n", | |
" idxs = [char_indices[c] for c in inp]\n", | |
" arrs = [np.stack([i]) for i in idxs]\n", | |
" p = model.predict(arrs)\n", | |
" i = np.argmax(p)\n", | |
" return chars[i]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 102, | |
"metadata": { | |
"collapsed": false, | |
"deletable": true, | |
"editable": true | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"' '" | |
] | |
}, | |
"execution_count": 102, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"get_next('phi')" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 103, | |
"metadata": { | |
"collapsed": false, | |
"deletable": true, | |
"editable": true | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"' '" | |
] | |
}, | |
"execution_count": 103, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"get_next('the')" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 104, | |
"metadata": { | |
"collapsed": false, | |
"deletable": true, | |
"editable": true | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"' '" | |
] | |
}, | |
"execution_count": 104, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"get_next(' th')" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"deletable": true, | |
"editable": true | |
}, | |
"source": [ | |
"That doesn't give us much information. Let's try to show the top 10 predictions and their probabilities." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 105, | |
"metadata": { | |
"collapsed": true, | |
"deletable": true, | |
"editable": true | |
}, | |
"outputs": [], | |
"source": [ | |
"def get_next_prob(inp):\n", | |
" idxs = [char_indices[c] for c in inp]\n", | |
" arrs = [np.array(i)[np.newaxis] for i in idxs]\n", | |
"\n", | |
" p = model.predict(arrs)\n", | |
" p_flat = np.squeeze(p)\n", | |
"\n", | |
" p_sorted = np.argsort(p)[0][::-1]\n", | |
" p_sorted_flat = np.squeeze(p_sorted)\n", | |
" \n", | |
"# i = np.argmax(p)\n", | |
" top_preds = list(np.array(chars)[p_sorted_flat][:10])\n", | |
" top_probs = list(np.array(p_flat)[p_sorted_flat][:10])\n", | |
" return (top_preds, top_probs)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 106, | |
"metadata": { | |
"collapsed": false, | |
"deletable": true, | |
"editable": true | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"([' ', 'e', 'i', 'a', 'n', 't', 's', 'o', 'h', 'l'],\n", | |
" [0.022329709,\n", | |
" 0.017263755,\n", | |
" 0.016879607,\n", | |
" 0.016484482,\n", | |
" 0.01632045,\n", | |
" 0.016156379,\n", | |
" 0.015346508,\n", | |
" 0.014879305,\n", | |
" 0.014764636,\n", | |
" 0.01461644])" | |
] | |
}, | |
"execution_count": 106, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"get_next_prob('phi')" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 107, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"([' ', 'e', 'n', 'a', 'i', 't', 's', 'd', 'o', 'h'],\n", | |
" [0.027977562,\n", | |
" 0.0207875,\n", | |
" 0.018497664,\n", | |
" 0.018191339,\n", | |
" 0.01810894,\n", | |
" 0.017898494,\n", | |
" 0.015798384,\n", | |
" 0.015642775,\n", | |
" 0.015531039,\n", | |
" 0.015266929])" | |
] | |
}, | |
"execution_count": 107, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"get_next_prob(' th')" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 108, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"([' ', 'e', 'n', 'i', 'a', 't', 's', 'h', 'l', 'o'],\n", | |
" [0.022597959,\n", | |
" 0.018363712,\n", | |
" 0.017289693,\n", | |
" 0.017275987,\n", | |
" 0.017116092,\n", | |
" 0.016938381,\n", | |
" 0.015594891,\n", | |
" 0.01518878,\n", | |
" 0.014814736,\n", | |
" 0.014639188])" | |
] | |
}, | |
"execution_count": 108, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"get_next_prob('nno')" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 109, | |
"metadata": { | |
"collapsed": true, | |
"deletable": true, | |
"editable": true | |
}, | |
"outputs": [], | |
"source": [ | |
"model.optimizer.lr=0.01" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 110, | |
"metadata": { | |
"collapsed": false, | |
"deletable": true, | |
"editable": true | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Epoch 1/4\n", | |
"27s - loss: 3.9857\n", | |
"Epoch 2/4\n", | |
"24s - loss: 3.7560\n", | |
"Epoch 3/4\n", | |
"24s - loss: 3.5167\n", | |
"Epoch 4/4\n", | |
"24s - loss: 3.3315\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/plain": [ | |
"<keras.callbacks.History at 0x7f6e30655c90>" | |
] | |
}, | |
"execution_count": 110, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"model.fit([x1, x2, x3], y, batch_size=64, nb_epoch=4, verbose=2)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 111, | |
"metadata": { | |
"collapsed": false, | |
"deletable": true, | |
"editable": true | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"([' ', 'e', 't', 'i', 'a', 'n', 's', 'o', 'h', 'r'],\n", | |
" [0.12042212,\n", | |
" 0.05575224,\n", | |
" 0.042415172,\n", | |
" 0.040973555,\n", | |
" 0.040624034,\n", | |
" 0.038294815,\n", | |
" 0.032249872,\n", | |
" 0.030022236,\n", | |
" 0.027263785,\n", | |
" 0.026214881])" | |
] | |
}, | |
"execution_count": 111, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"get_next_prob('She')" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 112, | |
"metadata": { | |
"collapsed": false, | |
"deletable": true, | |
"editable": true | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"([' ', 'e', 't', 'i', 'a', 'n', 'o', 's', 'h', 'l'],\n", | |
" [0.08146771,\n", | |
" 0.046317004,\n", | |
" 0.037694421,\n", | |
" 0.037186079,\n", | |
" 0.035755284,\n", | |
" 0.035155497,\n", | |
" 0.028711507,\n", | |
" 0.028595582,\n", | |
" 0.025967052,\n", | |
" 0.023849837])" | |
] | |
}, | |
"execution_count": 112, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"get_next_prob('Hol')" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 113, | |
"metadata": { | |
"collapsed": false, | |
"deletable": true, | |
"editable": true | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"([' ', 'e', 't', 'i', 'a', 'n', 'h', 's', 'o', 'r'],\n", | |
" [0.11406414,\n", | |
" 0.05443234,\n", | |
" 0.04464009,\n", | |
" 0.041871846,\n", | |
" 0.040271211,\n", | |
" 0.038731921,\n", | |
" 0.034342367,\n", | |
" 0.031122563,\n", | |
" 0.03032265,\n", | |
" 0.02564851])" | |
] | |
}, | |
"execution_count": 113, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"get_next_prob('Wat')" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 115, | |
"metadata": { | |
"collapsed": false, | |
"deletable": true, | |
"editable": true, | |
"scrolled": true | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"([' ', 'e', 't', 'a', 'i', 'n', 's', 'o', 'h', 'r'],\n", | |
" [0.13291049,\n", | |
" 0.060536038,\n", | |
" 0.044352487,\n", | |
" 0.044240393,\n", | |
" 0.042398464,\n", | |
" 0.039309599,\n", | |
" 0.033799037,\n", | |
" 0.031210093,\n", | |
" 0.029896479,\n", | |
" 0.026051341])" | |
] | |
}, | |
"execution_count": 115, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"get_next_prob('Doe')" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Let's do a running prediction for 500 characters" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 116, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"503 Fro \n" | |
] | |
} | |
], | |
"source": [ | |
"seed = 'Fro'\n", | |
"running_pred = seed\n", | |
"for i in range(500):\n", | |
" pred = get_next(seed)\n", | |
" running_pred += pred\n", | |
" seed = seed[1:] + pred\n", | |
"print (len(running_pred), running_pred)\n", | |
" " | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"anaconda-cloud": {}, | |
"kernelspec": { | |
"display_name": "Python 2", | |
"language": "python", | |
"name": "python2" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 1 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment