Last active November 10, 2017 03:55
zicklag/31898ce14b4852a6a6f39cf9b849df94
Breakdown of the Keras Text Generation Example
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# Text Generation using Keras\n", | |
"\n", | |
    "This is my breakdown of a [text generation example](https://github.com/fchollet/keras/blob/master/examples/imdb_lstm.py) from the [Keras](https://github.com/fchollet/keras) library. I don't take credit for any of the code here; pretty much all of it was copied from the example. Also, although I am very interested in machine learning, I'm no authority on the subject; I'm still figuring this stuff out. :)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"Using TensorFlow backend.\n" | |
] | |
} | |
], | |
"source": [ | |
"from keras.models import Sequential\n", | |
"from keras.layers import Dense, Activation\n", | |
"from keras.layers import LSTM\n", | |
"from keras.optimizers import RMSprop\n", | |
"from keras.utils.data_utils import get_file\n", | |
"import numpy as np\n", | |
"import random\n", | |
"import sys" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Raw Training Data\n", | |
"\n", | |
"First you get the text that will be used to train the network. The example suggests ~1M characters. This is just an example from the preface of *Journey to the Center of the Earth*." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"text = \"\"\"Iceland, the starting point of the marvellous underground journey imagined in this volume, is\n", | |
"invested at the present time with a painful interest in consequence of the disastrous eruptions last\n", | |
"Easter Day, which covered with lava and ashes the poor and scanty vegetation upon which four\n", | |
"thousand persons were partly dependent for the means of subsistence. For a long time to come the\n", | |
"natives of that interesting island, who cleave to their desert home with all that amor patriae which is\n", | |
"so much more easily understood than explained, will look, and look not in vain, for the help of those\n", | |
"on whom fall the smiles of a kindlier sun in regions not torn by earthquakes nor blasted and ravaged\n", | |
"by volcanic fires. Will the readers of this little book, who, are gifted with the means of indulging in\n", | |
"the luxury of extended beneficence, remember the distress of their brethren in the far north, whom\n", | |
"distance has not barred from the claim of being counted our \"neighbours\"? And whatever their humane\n", | |
"feelings may prompt them to bestow will be gladly added to the Mansion-House Iceland Relief Fund.\"\"\".lower()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": { | |
"scrolled": true | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"corpus length: 1092\n" | |
] | |
} | |
], | |
"source": [ | |
"print('corpus length:', len(text))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"chars = sorted(list(set(text)))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"metadata": { | |
"scrolled": true | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"total chars: 32\n" | |
] | |
} | |
], | |
"source": [ | |
"print('total chars:', len(chars))" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Character Mapping\n", | |
"\n", | |
"This code maps each character in the corpus to an integer and also creates a dictionary for mapping the other direction." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"{'o': 21, 'y': 31, 'w': 29, 'f': 12, 'j': 16, '\\n': 0, 'h': 14, 'g': 13, 's': 25, 'a': 7, '?': 6, 'k': 17, 'b': 8, 'u': 27, 'i': 15, '\"': 2, 'p': 22, 'e': 11, ',': 3, 'x': 30, 'v': 28, '.': 5, 'n': 20, ' ': 1, 'r': 24, '-': 4, 'c': 9, 'd': 10, 'm': 19, 'q': 23, 'l': 18, 't': 26}\n", | |
"{0: '\\n', 1: ' ', 2: '\"', 3: ',', 4: '-', 5: '.', 6: '?', 7: 'a', 8: 'b', 9: 'c', 10: 'd', 11: 'e', 12: 'f', 13: 'g', 14: 'h', 15: 'i', 16: 'j', 17: 'k', 18: 'l', 19: 'm', 20: 'n', 21: 'o', 22: 'p', 23: 'q', 24: 'r', 25: 's', 26: 't', 27: 'u', 28: 'v', 29: 'w', 30: 'x', 31: 'y'}\n" | |
] | |
} | |
], | |
"source": [ | |
"char_indices = dict((c, i) for i, c in enumerate(chars))\n", | |
"print(char_indices)\n", | |
"indices_char = dict((i, c) for i, c in enumerate(chars))\n", | |
"print(indices_char)" | |
] | |
}, | |
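{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"As a quick sanity check (my own addition, not part of the original example), the two dictionaries should be inverses of each other, so encoding a string and decoding it back should return the original string:" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Round-trip a word through both mappings (sanity check, not in the original example)\n", | |
"encoded = [char_indices[c] for c in 'iceland']\n", | |
"decoded = ''.join(indices_char[i] for i in encoded)\n", | |
"print(encoded, '->', decoded)\n", | |
"assert decoded == 'iceland'" | |
] | |
}, | |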
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Training Data Records\n", | |
"\n", | |
    "This creates a list of x,y pairs that will be used to train the network. The goal is to predict the next letter (the y value) from the `maxlen` letters before it (the x value). Consecutive sentences overlap by `maxlen - step` characters." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"metadata": { | |
"scrolled": true | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"x: iceland, the starting point of the marve\n", | |
"y: l\n", | |
"x: land, the starting point of the marvello\n", | |
"y: u\n", | |
"x: d, the starting point of the marvellous \n", | |
"y: u\n", | |
"x: the starting point of the marvellous und\n", | |
"y: e\n", | |
"x: starting point of the marvellous underg\n", | |
"y: r\n", | |
"x: arting point of the marvellous undergrou\n", | |
"y: n\n", | |
"x: ing point of the marvellous underground \n", | |
"y: j\n" | |
] | |
} | |
], | |
"source": [ | |
"maxlen = 40\n", | |
"step = 3\n", | |
"sentences = []\n", | |
"next_chars = []\n", | |
"for i in range(0, len(text) - maxlen, step):\n", | |
" # Print the first 7 records\n", | |
" if i < 7*step:\n", | |
" print('x:',text[i:i+maxlen])\n", | |
" print('y:',text[i + maxlen])\n", | |
" sentences.append(text[i: i + maxlen])\n", | |
" next_chars.append(text[i + maxlen])" | |
] | |
}, | |
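{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"A quick check on the record count (my own addition): with `len(text) == 1092`, `maxlen == 40`, and `step == 3`, the loop yields one record for every third starting position, which works out to 351 records and matches the `351/351` shown in the training output further down." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# One record per step; consecutive records share maxlen - step = 37 characters\n", | |
"print('records:', len(sentences))\n", | |
"assert len(sentences) == len(next_chars)" | |
] | |
}, | |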
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Vectorize Training Data\n", | |
"\n", | |
"This part gets a little complicated to think about. Here we have to represent each sentence from the input value as an array. The `np.zeros` just creates an array of the given shape that is filled with zeroes or, in this case, `False`'s. Line 1 creates a 3 dimensional array that will represent the x values and line 2 creates an array for the y values. The first dimension of each is the index of the training record. That means that `(x[0], y[0])` would be the first x,y pair.\n", | |
"\n", | |
"The x value for each training record is a 2 dimensional array representing the sentence. The horizontal axis of the array is the space in the sentence and the vertical axis is the character that could go there. For example, \"hello world\" would look like this:\n", | |
"\n", | |
" | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 |\n", | |
"---|---|---|---|---|---|---|---|---|---|---|----|\n", | |
"' '| 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 |\n", | |
"'d'| 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 |\n", | |
"'e'| 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |\n", | |
"'h'| 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |\n", | |
"'l'| 0 | 0 | 1 | 1 | 0 | 0 | 0 | 0 | 0 | 1 | 0 |\n", | |
"'o'| 0 | 0 | 0 | 0 | 1 | 0 | 0 | 1 | 0 | 0 | 0 |\n", | |
"'r'| 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 |\n", | |
"'w'| 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 |\n", | |
"\n", | |
    "The code sets `dtype=np.bool`, though, so instead of 1's and 0's it stores `True`'s and `False`'s.\n", | |
"\n", | |
"The y value of the pair is just a single dimensional array representing the character that comes next. This is what it would look like if the letter 'l' came next.\n", | |
"\n", | |
" | ' ' | d | e | h | l | o | r | w |\n", | |
"--|-----|---|---|---|---|---|---|---|\n", | |
" | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 |" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 14, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"x = np.zeros((len(sentences), maxlen, len(chars)), dtype=np.bool)\n", | |
"y = np.zeros((len(sentences), len(chars)), dtype=np.bool)\n", | |
"for i, sentence in enumerate(sentences):\n", | |
" for t, char in enumerate(sentence):\n", | |
" x[i, t, char_indices[char]] = 1\n", | |
" y[i, char_indices[next_chars[i]]] = 1" | |
] | |
}, | |
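{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"To convince myself the one-hot encoding above is right (my own addition), each position along the time axis of `x[0]` should hold exactly one `True`, and reading off those positions should reproduce the first training sentence:" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Decode the one-hot matrix of the first record back into text\n", | |
"decoded = ''.join(indices_char[int(np.argmax(row))] for row in x[0])\n", | |
"print(repr(decoded))\n", | |
"assert decoded == sentences[0]\n", | |
"assert x.shape == (len(sentences), maxlen, len(chars))" | |
] | |
}, | |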
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Build Model\n", | |
"\n", | |
"The next step is to build the model that we will use to make predictions on the training data. Keras makes this really easy to do. You can include the components you want to use from the library and connect them easily.\n", | |
"\n", | |
"I'm not 100% sure how this works but here's what I gather:\n", | |
"\n", | |
    "The `Sequential` model allows you to chain different layers together. He starts it with a Long Short-Term Memory (LSTM) layer that is configured with the same input dimensions we used for the training input array. (The `128` is the number of LSTM units, i.e. the size of the layer's internal state and output, not the batch size.) He then adds a fully-connected, a.k.a. \"Dense\", layer that acts as the output layer and thus has the dimensions of the output, `len(chars)`. Last he adds a \"softmax\" activation layer, which always produces numbers that add up to 1 and can therefore be loosely interpreted as probabilities.\n", | |
"\n", | |
"The output of the network if you only had four different letters could look like this then:\n", | |
"\n", | |
" a | b | c | d |\n", | |
"---|---|---|---|\n", | |
".25|.10|.60|.05|\n", | |
"\n", | |
    "In this example, \"c\" would be the most likely letter to be picked because it has the highest probability out of all of the letters." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 15, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"model = Sequential()\n", | |
"model.add(LSTM(128, input_shape=(maxlen, len(chars))))\n", | |
"model.add(Dense(len(chars)))\n", | |
"model.add(Activation('softmax'))" | |
] | |
}, | |
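{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Calling `model.summary()` (my own addition) prints the layer shapes and parameter counts, which is a handy way to confirm the model is wired up the way the description above says:" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"model.summary()" | |
] | |
}, | |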
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
    "Next we choose an optimizer and a loss function and compile the model. RMSprop is a commonly recommended optimizer for recurrent networks, and categorical crossentropy is the standard loss for picking one class (here, one character) out of many." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 21, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"optimizer = RMSprop(lr=0.01)\n", | |
"model.compile(loss='categorical_crossentropy', optimizer=optimizer)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"### Probability selector\n", | |
"\n", | |
    "Here is a helper function that he uses to select the next character from the probability array output by the model. The math re-weights the distribution by a *temperature*: taking the log of each probability, dividing by the temperature, and re-normalizing with softmax sharpens the distribution when the temperature is below 1 (the most likely characters dominate) and flattens it when the temperature is above 1 (less likely characters get picked more often). `np.random.multinomial` then draws a single sample from the re-weighted distribution, and the return value of the function is the index of the character predicted to come next in the sentence." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 25, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def sample(preds, temperature=1.0):\n", | |
" # helper function to sample an index from a probability array\n", | |
" preds = np.asarray(preds).astype('float64')\n", | |
" preds = np.log(preds) / temperature\n", | |
" exp_preds = np.exp(preds)\n", | |
" preds = exp_preds / np.sum(exp_preds)\n", | |
" probas = np.random.multinomial(1, preds, 1)\n", | |
" return np.argmax(probas)" | |
] | |
}, | |
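{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"A small demo of the temperature effect (my own addition, using a made-up probability array): at a low temperature the most likely index wins almost every draw, while at a high temperature the samples spread out over the other indices too:" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"demo_preds = [0.6, 0.25, 0.1, 0.05]  # hypothetical softmax output\n", | |
"for t in [0.2, 1.0, 1.5]:\n", | |
"    draws = [sample(demo_preds, t) for _ in range(1000)]\n", | |
"    print('temperature', t, '->', [draws.count(i) for i in range(4)])" | |
] | |
}, | |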
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Training the Model\n", | |
"\n", | |
    "Here's the *long* step.\n", | |
    "\n", | |
    "The loop below trains the model one epoch at a time for 20 iterations. Every tenth iteration it picks a random seed sentence from the corpus and, for several diversity (temperature) settings, generates 200 characters one at a time: it one-hot encodes the current `maxlen`-character window, asks the model for the next-character probabilities, samples a character with the `sample` helper, and slides the window forward to include it." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 30, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Epoch 1/1\n", | |
"351/351 [==============================] - 1s - loss: 0.4951 \n", | |
"Epoch 1/1\n", | |
"351/351 [==============================] - 1s - loss: 0.5259 \n", | |
"Epoch 1/1\n", | |
"351/351 [==============================] - 1s - loss: 0.4541 \n", | |
"Epoch 1/1\n", | |
"351/351 [==============================] - 1s - loss: 0.3347 \n", | |
"Epoch 1/1\n", | |
"351/351 [==============================] - 1s - loss: 0.3519 \n", | |
"Epoch 1/1\n", | |
"351/351 [==============================] - 1s - loss: 0.2894 \n", | |
"Epoch 1/1\n", | |
"351/351 [==============================] - 1s - loss: 0.1903 \n", | |
"Epoch 1/1\n", | |
"351/351 [==============================] - 1s - loss: 0.1598 \n", | |
"Epoch 1/1\n", | |
"351/351 [==============================] - 1s - loss: 0.1640 \n", | |
"\n", | |
"--------------------------------------------------\n", | |
"Iteration 10\n", | |
"Epoch 1/1\n", | |
"351/351 [==============================] - 1s - loss: 0.1787 \n", | |
"\n", | |
"----- diversity: 0.2\n", | |
"----- Generating with seed: \"f the disastrous eruptions last\n", | |
"easter d\"\n", | |
"f the disastrous eruptions last\n", | |
"easter de the caine wit hl pery ert ho che he peand jart he pelrenderthe he peorthe ceetr che redestre horeand the caine wir lor and whom and whom llot the loo the ce.ten wiot lore ant tor lort lort ll the\n", | |
"\n", | |
"----- diversity: 0.5\n", | |
"----- Generating with seed: \"f the disastrous eruptions last\n", | |
"easter d\"\n", | |
"f the disastrous eruptions last\n", | |
"easter ded who aol whot heme tin lor ano the herpe tot wlbllsyr anot tor lort h hll and and ano homededed the coof llbt nn the he peardeanere who and and who a t llededededed thrt lorthe loo the ceetr\n", | |
"\n", | |
"----- diversity: 1.0\n", | |
"----- Generating with seed: \"f the disastrous eruptions last\n", | |
"easter d\"\n", | |
"f the disastrous eruptions last\n", | |
"easter d iueorr\"?edere he p or de penthime int wfok the cor coin hn che celag mulls cotthn wiichl re he thr ceiten not lart qne wimth faisesren wit tl llpre\n", | |
"ntrtheme.end the h isagrrymeots tom haimexy a t l\n", | |
"\n", | |
"----- diversity: 1.2\n", | |
"----- Generating with seed: \"f the disastrous eruptions last\n", | |
"easter d\"\n", | |
"f the disastrous eruptions last\n", | |
"easter des be llereasincheara\"n jhethy peday,rsaanr inc cort tor he peakns bl bey esdberttht cernd d\n", | |
" temes wftb aslededed whe thine if blart d ourly.ar hnc elrt he hel gncelirt he helpea\n", | |
"tn themresen the l\n", | |
"Epoch 1/1\n", | |
"351/351 [==============================] - 1s - loss: 0.2218 \n", | |
"Epoch 1/1\n", | |
"351/351 [==============================] - 1s - loss: 0.1424 \n", | |
"Epoch 1/1\n", | |
"351/351 [==============================] - 1s - loss: 0.0783 \n", | |
"Epoch 1/1\n", | |
"351/351 [==============================] - 1s - loss: 0.0602 \n", | |
"Epoch 1/1\n", | |
"351/351 [==============================] - 1s - loss: 0.0449 \n", | |
"Epoch 1/1\n", | |
"351/351 [==============================] - 1s - loss: 0.0487 \n", | |
"Epoch 1/1\n", | |
"351/351 [==============================] - 1s - loss: 0.0373 \n", | |
"Epoch 1/1\n", | |
"351/351 [==============================] - 1s - loss: 0.1062 \n", | |
"Epoch 1/1\n", | |
"351/351 [==============================] - 1s - loss: 0.4892 \n", | |
"\n", | |
"--------------------------------------------------\n", | |
"Iteration 20\n", | |
"Epoch 1/1\n", | |
"351/351 [==============================] - 1s - loss: 0.5427 \n", | |
"\n", | |
"----- diversity: 0.2\n", | |
"----- Generating with seed: \" the starting point of the marvellous un\"\n", | |
" the starting point of the marvellous unde re tho ant carandedstert ho chein ho the ce.teo wime lorthe dedted the cedtrthe cant took ant ce the cante tom dlot the cante coltrthe d de tho ant corthe redtort ho colin ho the coot the redter\n", | |
"\n", | |
"----- diversity: 0.5\n", | |
"----- Generating with seed: \" the starting point of the marvellous un\"\n", | |
" the starting point of the marvellous undeint th che red aot wirt lort hince.he whoc lookthe dedtent the aot cortre he thich fortthe celir hl caede wiot lorthe ce.took agicedirthe cante col the ce.tio lorthe dedtett tookredeirthe dlitt to\n", | |
"\n", | |
"----- diversity: 1.0\n", | |
"----- Generating with seed: \" the starting point of the marvellous un\"\n", | |
" the starting point of the marvellous unde son wime lorthq coltrt rertheis iot lled and and asol the chemederthh ceyin wiot and ant the dainn ioc and and and and anot afth ls the caian wiot lort lomt me tho d foo whoc loot the d th wlok wi\n", | |
"\n", | |
"----- diversity: 1.2\n", | |
"----- Generating with seed: \" the starting point of the marvellous un\"\n", | |
" the starting point of the marvellous undeint tor the dedkttho cant canc\n", | |
".heceand the ce.he whomeandejn the caepn ha the daian pourthy inli\"e tor wioc loltere.ne th lhe wiot inc lirt lorthqie thrt aook ande antemerton thr chets iot coet din\n" | |
] | |
} | |
], | |
"source": [ | |
"for iteration in range(1, 21):\n", | |
    " # Generate and print sample text every tenth iteration\n", | |
" if iteration % 10 == 0:\n", | |
" should_print = True\n", | |
" else:\n", | |
" should_print = False\n", | |
" \n", | |
" if should_print:\n", | |
" print()\n", | |
" print('-' * 50)\n", | |
" print('Iteration', iteration)\n", | |
" model.fit(x, y,\n", | |
" batch_size=128,\n", | |
" epochs=1)\n", | |
"\n", | |
" if should_print:\n", | |
" start_index = random.randint(0, len(text) - maxlen - 1)\n", | |
"\n", | |
" for diversity in [0.2, 0.5, 1.0, 1.2]:\n", | |
" if should_print:\n", | |
" print()\n", | |
" print('----- diversity:', diversity)\n", | |
"\n", | |
" generated = ''\n", | |
" sentence = text[start_index: start_index + maxlen]\n", | |
" generated += sentence\n", | |
" print('----- Generating with seed: \"' + sentence + '\"')\n", | |
" sys.stdout.write(generated)\n", | |
"\n", | |
" for i in range(200):\n", | |
" x_pred = np.zeros((1, maxlen, len(chars)))\n", | |
" for t, char in enumerate(sentence):\n", | |
" x_pred[0, t, char_indices[char]] = 1.\n", | |
"\n", | |
" preds = model.predict(x_pred, verbose=0)[0]\n", | |
" next_index = sample(preds, diversity)\n", | |
" next_char = indices_char[next_index]\n", | |
"\n", | |
" generated += next_char\n", | |
" sentence = sentence[1:] + next_char\n", | |
"\n", | |
" sys.stdout.write(next_char)\n", | |
" sys.stdout.flush()\n", | |
" print()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.5.2" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |