Skip to content

Instantly share code, notes, and snippets.

@niazangels
Last active May 24, 2017 20:47
Show Gist options
  • Save niazangels/841b46dfa429c3a8b4bc56d8a0a734f2 to your computer and use it in GitHub Desktop.
Save niazangels/841b46dfa429c3a8b4bc56d8a0a734f2 to your computer and use it in GitHub Desktop.
3-char RNN model mostly always predicts a space
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"deletable": true,
"editable": true
},
"source": [
"# Lesson 6"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"from theano.sandbox import cuda\n",
"cuda.use('gpu0')"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Using Theano backend.\n"
]
}
],
"source": [
"%matplotlib inline\n",
"import utils; reload(utils)\n",
"from utils import *\n",
"from __future__ import division, print_function"
]
},
{
"cell_type": "code",
"execution_count": 69,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"corpus length: 600901\n"
]
}
],
"source": [
"path = get_file('nietzsche.txt', origin=\"https://s3.amazonaws.com/text-datasets/nietzsche.txt\")\n",
"# path = get_file('sherlock.txt', origin=\"https://sherlock-holm.es/stories/plain-text/cano.txt\")\n",
"\n",
"text = open(path).read()\n",
"\n",
"print('corpus length:', len(text))"
]
},
{
"cell_type": "code",
"execution_count": 70,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"total chars: 86\n"
]
}
],
"source": [
"chars = sorted(list(set(text)))\n",
"vocab_size = len(chars)+1\n",
"print('total chars:', vocab_size)"
]
},
{
"cell_type": "code",
"execution_count": 71,
"metadata": {
"collapsed": true,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"chars.insert(0, \"\\0\")"
]
},
{
"cell_type": "code",
"execution_count": 72,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [
{
"data": {
"text/plain": [
"'\\n !\"\\'(),-.0123456789:;=?ABCDEFGHIJKLMNOPQRSTUVWXYZ[]_abcdefghijklmnopqrstuvwxyz'"
]
},
"execution_count": 72,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"''.join(chars[1:-6])"
]
},
{
"cell_type": "code",
"execution_count": 73,
"metadata": {
"collapsed": true,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"char_indices = dict((c, i) for i, c in enumerate(chars))\n",
"indices_char = dict((i, c) for i, c in enumerate(chars))"
]
},
{
"cell_type": "code",
"execution_count": 74,
"metadata": {
"collapsed": true,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"idx = [char_indices[c] for c in text]"
]
},
{
"cell_type": "code",
"execution_count": 75,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[40, 42, 29, 30, 25, 27, 29, 1, 1, 1, 43, 45, 40, 40, 39, 43, 33, 38, 31, 2, 73, 61, 54, 73, 2, 44, 71, 74, 73, 61, 2, 62, 72, 2, 54, 2, 76, 68, 66, 54, 67, 9, 9, 76, 61, 54, 73, 2, 73, 61, 58, 67, 24, 2, 33, 72, 2, 73, 61, 58, 71, 58, 2, 67, 68, 73, 2, 60, 71, 68, 74, 67, 57, 1, 59, 68, 71, 2, 72, 74, 72, 69, 58, 56, 73, 62, 67, 60, 2, 73, 61, 54, 73, 2, 54, 65, 65, 2, 69, 61, 62, 65, 68, 72, 68, 69, 61, 58, 71, 72, 8, 2, 62, 67, 2, 72, 68, 2, 59, 54, 71, 2, 54, 72, 2, 73, 61, 58, 78, 2, 61, 54, 75, 58, 2, 55, 58, 58, 67, 1, 57, 68, 60, 66, 54, 73, 62, 72, 73, 72, 8, 2, 61, 54, 75, 58, 2, 59, 54, 62, 65, 58, 57, 2, 73, 68, 2, 74, 67, 57, 58, 71, 72, 73, 54, 67, 57, 2, 76, 68, 66, 58, 67, 9, 9, 73, 61, 54, 73, 2, 73, 61, 58, 2, 73, 58, 71, 71, 62, 55]\n"
]
}
],
"source": [
"print (idx[:200])"
]
},
{
"cell_type": "code",
"execution_count": 76,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [
{
"data": {
"text/plain": [
"'PREFACE\\n\\n\\nSUPPOSING that Truth is a woman--what then? Is there not ground\\nfor suspecting that all philosophers, in so far as they have been\\ndogmatists, have failed to understand women--that the terrib'"
]
},
"execution_count": 76,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"text[:200]"
]
},
{
"cell_type": "code",
"execution_count": 77,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"PREFACE\n",
"\n",
"\n",
"SUPPOSING that Truth is a woman--what then? Is there not ground\n",
"for suspecting that all philosophers, in so far as they have been\n",
"dogmatists, have failed to understand women--that the terrib\n"
]
}
],
"source": [
"print (text[:200])"
]
},
{
"cell_type": "markdown",
"metadata": {
"deletable": true,
"editable": true
},
"source": [
"# 3 Char Model"
]
},
{
"cell_type": "code",
"execution_count": 78,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"skip = 3\n",
"\n",
"c1_data = [idx[i+0] for i in range(0, len(idx)-1-skip, skip)]\n",
"c2_data = [idx[i+1] for i in range(0, len(idx)-1-skip, skip)]\n",
"c3_data = [idx[i+2] for i in range(0, len(idx)-1-skip, skip)]\n",
"c4_data = [idx[i+3] for i in range(0, len(idx)-1-skip, skip)]"
]
},
{
"cell_type": "code",
"execution_count": 79,
"metadata": {
"collapsed": true,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"x1 = np.stack(c1_data)"
]
},
{
"cell_type": "code",
"execution_count": 80,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [
{
"data": {
"text/plain": [
"(200299,)"
]
},
"execution_count": 80,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x1.shape"
]
},
{
"cell_type": "markdown",
"metadata": {
"deletable": true,
"editable": true
},
"source": [
"<br>Our inputs"
]
},
{
"cell_type": "code",
"execution_count": 81,
"metadata": {
"collapsed": true,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"x1 = np.stack(c1_data)\n",
"x2 = np.stack(c2_data)\n",
"x3 = np.stack(c3_data)"
]
},
{
"cell_type": "markdown",
"metadata": {
"deletable": true,
"editable": true
},
"source": [
"And output"
]
},
{
"cell_type": "code",
"execution_count": 82,
"metadata": {
"collapsed": true,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"y = np.stack(c4_data)"
]
},
{
"cell_type": "markdown",
"metadata": {
"deletable": true,
"editable": true
},
"source": [
"Let's check them out"
]
},
{
"cell_type": "code",
"execution_count": 83,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [
{
"data": {
"text/plain": [
"(array([40, 30, 29, 1]), array([42, 25, 1, 43]), array([29, 27, 1, 45]))"
]
},
"execution_count": 83,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x1[:4], x2[:4], x3[:4]"
]
},
{
"cell_type": "code",
"execution_count": 84,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [
{
"data": {
"text/plain": [
"array([30, 29, 1, 40])"
]
},
"execution_count": 84,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"y[:4]"
]
},
{
"cell_type": "code",
"execution_count": 85,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [
{
"data": {
"text/plain": [
"((200299,), (200299,), (200299,), (200299,))"
]
},
"execution_count": 85,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x1.shape, x2.shape, x3.shape, y.shape"
]
},
{
"cell_type": "code",
"execution_count": 86,
"metadata": {
"collapsed": true,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"n_fac = 42"
]
},
{
"cell_type": "code",
"execution_count": 87,
"metadata": {
"collapsed": true,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"def embedding_input(name, n_in, n_out):\n",
" inp = Input(shape=(1,), dtype='int64', name=name)\n",
" emb = Embedding(n_in, n_out, input_length=1)(inp)\n",
" return inp, Flatten()(emb)"
]
},
{
"cell_type": "code",
"execution_count": 88,
"metadata": {
"collapsed": true,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"c1_in, c1 = embedding_input('c1', vocab_size, n_fac)\n",
"c2_in, c2 = embedding_input('c2', vocab_size, n_fac)\n",
"c3_in, c3 = embedding_input('c3', vocab_size, n_fac)"
]
},
{
"cell_type": "markdown",
"metadata": {
"deletable": true,
"editable": true
},
"source": [
"### Model"
]
},
{
"cell_type": "code",
"execution_count": 89,
"metadata": {
"collapsed": true,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"n_hidden = 256"
]
},
{
"cell_type": "code",
"execution_count": 90,
"metadata": {
"collapsed": true,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"dense_in = Dense(n_hidden, activation='relu')"
]
},
{
"cell_type": "code",
"execution_count": 91,
"metadata": {
"collapsed": true,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"dense_hidden = Dense(n_hidden, activation='tanh')"
]
},
{
"cell_type": "code",
"execution_count": 92,
"metadata": {
"collapsed": true,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"dense_out = Dense(vocab_size, activation='softmax')"
]
},
{
"cell_type": "code",
"execution_count": 93,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"c1_hidden = dense_in(c1)\n",
"hidden_2 = dense_hidden(c1_hidden)\n",
"\n",
"c2_dense = dense_in(c2)\n",
"c2_hidden = merge([c2_dense, hidden_2])\n",
"hidden_3 = dense_hidden(c2_hidden)\n",
"\n",
"c3_dense = dense_in(c3)\n",
"c3_hidden = merge([c3_dense, hidden_2])\n",
"\n",
"c4_out = dense_out(c3_hidden)"
]
},
{
"cell_type": "code",
"execution_count": 94,
"metadata": {
"collapsed": true,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"model = Model([c1_in, c2_in, c3_in], c4_out)"
]
},
{
"cell_type": "code",
"execution_count": 95,
"metadata": {
"collapsed": true,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"model.compile(loss='sparse_categorical_crossentropy', optimizer=Adam())"
]
},
{
"cell_type": "code",
"execution_count": 96,
"metadata": {
"collapsed": true,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"model.optimizer.lr=0.000001"
]
},
{
"cell_type": "code",
"execution_count": 97,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/4\n",
"23s - loss: 4.4349\n",
"Epoch 2/4\n",
"31s - loss: 4.3801\n",
"Epoch 3/4\n",
"29s - loss: 4.2960\n",
"Epoch 4/4\n",
"30s - loss: 4.1677\n"
]
},
{
"data": {
"text/plain": [
"<keras.callbacks.History at 0x7f6e459f6610>"
]
},
"execution_count": 97,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.fit([x1, x2, x3], y, batch_size=64, nb_epoch=4, verbose=2)"
]
},
{
"cell_type": "code",
"execution_count": 98,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"def get_next(inp):\n",
" idxs = [char_indices[c] for c in inp]\n",
" arrs = [np.stack([i]) for i in idxs]\n",
" p = model.predict(arrs)\n",
" i = np.argmax(p)\n",
" return chars[i]"
]
},
{
"cell_type": "code",
"execution_count": 102,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [
{
"data": {
"text/plain": [
"' '"
]
},
"execution_count": 102,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"get_next('phi')"
]
},
{
"cell_type": "code",
"execution_count": 103,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [
{
"data": {
"text/plain": [
"' '"
]
},
"execution_count": 103,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"get_next('the')"
]
},
{
"cell_type": "code",
"execution_count": 104,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [
{
"data": {
"text/plain": [
"' '"
]
},
"execution_count": 104,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"get_next(' th')"
]
},
{
"cell_type": "markdown",
"metadata": {
"deletable": true,
"editable": true
},
"source": [
"That doesn't give us much information. Let's try to show the top 10 predictions and their probabilities."
]
},
{
"cell_type": "code",
"execution_count": 105,
"metadata": {
"collapsed": true,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"def get_next_prob(inp):\n",
" idxs = [char_indices[c] for c in inp]\n",
" arrs = [np.array(i)[np.newaxis] for i in idxs]\n",
"\n",
" p = model.predict(arrs)\n",
" p_flat = np.squeeze(p)\n",
"\n",
" p_sorted = np.argsort(p)[0][::-1]\n",
" p_sorted_flat = np.squeeze(p_sorted)\n",
" \n",
"# i = np.argmax(p)\n",
" top_preds = list(np.array(chars)[p_sorted_flat][:10])\n",
" top_probs = list(np.array(p_flat)[p_sorted_flat][:10])\n",
" return (top_preds, top_probs)"
]
},
{
"cell_type": "code",
"execution_count": 106,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [
{
"data": {
"text/plain": [
"([' ', 'e', 'i', 'a', 'n', 't', 's', 'o', 'h', 'l'],\n",
" [0.022329709,\n",
" 0.017263755,\n",
" 0.016879607,\n",
" 0.016484482,\n",
" 0.01632045,\n",
" 0.016156379,\n",
" 0.015346508,\n",
" 0.014879305,\n",
" 0.014764636,\n",
" 0.01461644])"
]
},
"execution_count": 106,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"get_next_prob('phi')"
]
},
{
"cell_type": "code",
"execution_count": 107,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"([' ', 'e', 'n', 'a', 'i', 't', 's', 'd', 'o', 'h'],\n",
" [0.027977562,\n",
" 0.0207875,\n",
" 0.018497664,\n",
" 0.018191339,\n",
" 0.01810894,\n",
" 0.017898494,\n",
" 0.015798384,\n",
" 0.015642775,\n",
" 0.015531039,\n",
" 0.015266929])"
]
},
"execution_count": 107,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"get_next_prob(' th')"
]
},
{
"cell_type": "code",
"execution_count": 108,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"([' ', 'e', 'n', 'i', 'a', 't', 's', 'h', 'l', 'o'],\n",
" [0.022597959,\n",
" 0.018363712,\n",
" 0.017289693,\n",
" 0.017275987,\n",
" 0.017116092,\n",
" 0.016938381,\n",
" 0.015594891,\n",
" 0.01518878,\n",
" 0.014814736,\n",
" 0.014639188])"
]
},
"execution_count": 108,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"get_next_prob('nno')"
]
},
{
"cell_type": "code",
"execution_count": 109,
"metadata": {
"collapsed": true,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"model.optimizer.lr=0.01"
]
},
{
"cell_type": "code",
"execution_count": 110,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/4\n",
"27s - loss: 3.9857\n",
"Epoch 2/4\n",
"24s - loss: 3.7560\n",
"Epoch 3/4\n",
"24s - loss: 3.5167\n",
"Epoch 4/4\n",
"24s - loss: 3.3315\n"
]
},
{
"data": {
"text/plain": [
"<keras.callbacks.History at 0x7f6e30655c90>"
]
},
"execution_count": 110,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.fit([x1, x2, x3], y, batch_size=64, nb_epoch=4, verbose=2)"
]
},
{
"cell_type": "code",
"execution_count": 111,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [
{
"data": {
"text/plain": [
"([' ', 'e', 't', 'i', 'a', 'n', 's', 'o', 'h', 'r'],\n",
" [0.12042212,\n",
" 0.05575224,\n",
" 0.042415172,\n",
" 0.040973555,\n",
" 0.040624034,\n",
" 0.038294815,\n",
" 0.032249872,\n",
" 0.030022236,\n",
" 0.027263785,\n",
" 0.026214881])"
]
},
"execution_count": 111,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"get_next_prob('She')"
]
},
{
"cell_type": "code",
"execution_count": 112,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [
{
"data": {
"text/plain": [
"([' ', 'e', 't', 'i', 'a', 'n', 'o', 's', 'h', 'l'],\n",
" [0.08146771,\n",
" 0.046317004,\n",
" 0.037694421,\n",
" 0.037186079,\n",
" 0.035755284,\n",
" 0.035155497,\n",
" 0.028711507,\n",
" 0.028595582,\n",
" 0.025967052,\n",
" 0.023849837])"
]
},
"execution_count": 112,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"get_next_prob('Hol')"
]
},
{
"cell_type": "code",
"execution_count": 113,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [
{
"data": {
"text/plain": [
"([' ', 'e', 't', 'i', 'a', 'n', 'h', 's', 'o', 'r'],\n",
" [0.11406414,\n",
" 0.05443234,\n",
" 0.04464009,\n",
" 0.041871846,\n",
" 0.040271211,\n",
" 0.038731921,\n",
" 0.034342367,\n",
" 0.031122563,\n",
" 0.03032265,\n",
" 0.02564851])"
]
},
"execution_count": 113,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"get_next_prob('Wat')"
]
},
{
"cell_type": "code",
"execution_count": 115,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true,
"scrolled": true
},
"outputs": [
{
"data": {
"text/plain": [
"([' ', 'e', 't', 'a', 'i', 'n', 's', 'o', 'h', 'r'],\n",
" [0.13291049,\n",
" 0.060536038,\n",
" 0.044352487,\n",
" 0.044240393,\n",
" 0.042398464,\n",
" 0.039309599,\n",
" 0.033799037,\n",
" 0.031210093,\n",
" 0.029896479,\n",
" 0.026051341])"
]
},
"execution_count": 115,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"get_next_prob('Doe')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's do a running prediction for 500 characters"
]
},
{
"cell_type": "code",
"execution_count": 116,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"503 Fro \n"
]
}
],
"source": [
"seed = 'Fro'\n",
"running_pred = seed\n",
"for i in range(500):\n",
" pred = get_next(seed)\n",
" running_pred += pred\n",
" seed = seed[1:] + pred\n",
"print (len(running_pred), running_pred)\n",
" "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": []
}
],
"metadata": {
"anaconda-cloud": {},
"kernelspec": {
"display_name": "Python 2",
"language": "python",
"name": "python2"
}
},
"nbformat": 4,
"nbformat_minor": 1
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment