Skip to content

Instantly share code, notes, and snippets.

@gangtao
Created March 14, 2019 16:51
Show Gist options
  • Save gangtao/b7de89d95293599e98a0b79f4a64f0de to your computer and use it in GitHub Desktop.
Save gangtao/b7de89d95293599e98a0b79f4a64f0de to your computer and use it in GitHub Desktop.
A Simple Neural Language Model
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import tensorflow as tf\n",
"import numpy as np"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Download the text of *A Tale of Two Cities* from Project Gutenberg, then remove the first 110 lines of introduction (everything before the novel itself)."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" % Total % Received % Xferd Average Speed Time Time Time Current\n",
" Dload Upload Total Spent Left Speed\n",
"100 785k 100 785k 0 0 795k 0 --:--:-- --:--:-- --:--:-- 795k\n"
]
}
],
"source": [
"# -f: fail on HTTP errors, -sS: quiet but still report errors, -L: follow any redirect instead of saving the redirect page\n",
"! curl -fsSL https://www.gutenberg.org/files/98/98-0.txt -o orginal.txt"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"# drop the first 110 lines (the introduction before the novel text)\n",
"! tail -n +111 orginal.txt > orginal_clean.txt"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### prepare the text data"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"load text"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"I. The Period\n",
"\n",
"\n",
"It was the best of times,\n",
"it was the worst of times,\n",
"it was the age of wisdom,\n",
"it was the age of foolishness,\n",
"it was the epoch of belief,\n",
"it was the epoch of incredulity,\n",
"it was the se\n"
]
}
],
"source": [
"# load doc into memory\n",
"def load_doc(filename):\n",
"\t\"\"\"Return the full contents of the text file *filename* as a single string.\"\"\"\n",
"\t# the with-block closes the file even if read() raises\n",
"\twith open(filename, 'r') as file:\n",
"\t\treturn file.read()\n",
"\n",
"in_filename = 'orginal_clean.txt'\n",
"doc = load_doc(in_filename)\n",
"print(doc[:200])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"clean text"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"['i', 'the', 'period', 'it', 'was', 'the', 'best', 'of', 'times', 'it', 'was', 'the', 'worst', 'of', 'times', 'it', 'was', 'the', 'age', 'of']\n",
"Total Tokens: 133661\n",
"Unique Tokens: 10290\n"
]
}
],
"source": [
"import string\n",
"\n",
"# turn a doc into clean tokens\n",
"def clean_doc(doc):\n",
"\t\"\"\"Split *doc* into lower-case, purely alphabetic word tokens.\n",
"\n",
"\t'--' is treated as a word separator, ASCII punctuation is stripped from\n",
"\teach whitespace-separated token, and tokens that are still not entirely\n",
"\talphabetic afterwards (e.g. numbers, empty strings) are dropped.\n",
"\t\"\"\"\n",
"\tstrip_punct = str.maketrans('', '', string.punctuation)\n",
"\twords = doc.replace('--', ' ').split()\n",
"\tstripped = (w.translate(strip_punct) for w in words)\n",
"\treturn [w.lower() for w in stripped if w.isalpha()]\n",
"\n",
"# clean document\n",
"tokens = clean_doc(doc)\n",
"print(tokens[:20])\n",
"print('Total Tokens: %d' % len(tokens))\n",
"print('Unique Tokens: %d' % len(set(tokens)))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Use a sliding window of 51 tokens (50 input words plus the 1 word to predict), moved forward one token at a time, to generate the training sequences."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Total Sequences: 133610\n"
]
}
],
"source": [
"# organize into sequences of tokens: a sliding window of seq_length input words\n",
"# plus 1 target word, advanced one token at a time\n",
"seq_length = 50\n",
"length = seq_length + 1\n",
"# the upper bound is len(tokens) + 1 so the final window, tokens[-length:], is included\n",
"sequences = [' '.join(tokens[i - length:i]) for i in range(length, len(tokens) + 1)]\n",
"print('Total Sequences: %d' % len(sequences))"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"# save lines to file, one sequence per line\n",
"def save_doc(lines, filename):\n",
"\t\"\"\"Write *lines* to *filename* separated by newlines (no trailing newline).\"\"\"\n",
"\t# the with-block flushes and closes the file even if write() raises\n",
"\twith open(filename, 'w') as file:\n",
"\t\tfile.write('\\n'.join(lines))\n",
"\n",
"# save sequences to file\n",
"out_filename = 'twocities_sequences.txt'\n",
"save_doc(sequences, out_filename)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"i the period it was the best of times it was the worst of times it was the age of wisdom it was the age of foolishness it was the epoch of belief it was the epoch of incredulity it was the season of light it was the season of darkness\r\n",
"the period it was the best of times it was the worst of times it was the age of wisdom it was the age of foolishness it was the epoch of belief it was the epoch of incredulity it was the season of light it was the season of darkness it\r\n",
"period it was the best of times it was the worst of times it was the age of wisdom it was the age of foolishness it was the epoch of belief it was the epoch of incredulity it was the season of light it was the season of darkness it was\r\n",
"it was the best of times it was the worst of times it was the age of wisdom it was the age of foolishness it was the epoch of belief it was the epoch of incredulity it was the season of light it was the season of darkness it was the\r\n",
"was the best of times it was the worst of times it was the age of wisdom it was the age of foolishness it was the epoch of belief it was the epoch of incredulity it was the season of light it was the season of darkness it was the spring\r\n",
"the best of times it was the worst of times it was the age of wisdom it was the age of foolishness it was the epoch of belief it was the epoch of incredulity it was the season of light it was the season of darkness it was the spring of\r\n",
"best of times it was the worst of times it was the age of wisdom it was the age of foolishness it was the epoch of belief it was the epoch of incredulity it was the season of light it was the season of darkness it was the spring of hope\r\n",
"of times it was the worst of times it was the age of wisdom it was the age of foolishness it was the epoch of belief it was the epoch of incredulity it was the season of light it was the season of darkness it was the spring of hope it\r\n",
"times it was the worst of times it was the age of wisdom it was the age of foolishness it was the epoch of belief it was the epoch of incredulity it was the season of light it was the season of darkness it was the spring of hope it was\r\n",
"it was the worst of times it was the age of wisdom it was the age of foolishness it was the epoch of belief it was the epoch of incredulity it was the season of light it was the season of darkness it was the spring of hope it was the\r\n"
]
}
],
"source": [
"# show the first 10 saved sequences\n",
"! head -n 10 {out_filename}"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Load sequence"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"# load the saved sequences, one per line, reusing the load_doc function defined earlier\n",
"in_filename = out_filename\n",
"doc = load_doc(in_filename)\n",
"lines = doc.split('\\n')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Integer-encode the sequences"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"# integer encode sequences of words: fit a word -> id vocabulary on the lines,\n",
"# then map every line to its list of word ids\n",
"tokenizer = tf.keras.preprocessing.text.Tokenizer()\n",
"tokenizer.fit_on_texts(lines)\n",
"sequences = tokenizer.texts_to_sequences(lines)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"10291"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# vocabulary size: word ids start at 1 (id 0 is reserved), so add one extra slot\n",
"vocab_size = 1 + len(tokenizer.word_index)\n",
"vocab_size"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"# separate into input and output: all but the last id of each row are the input,\n",
"# the last id is the word to predict\n",
"sequences = np.array(sequences)\n",
"X, y = sequences[:, :-1], sequences[:, -1]\n",
"# the Embedding layer is built with input_length=seq_length, so the widths must agree\n",
"assert X.shape[1] == seq_length, X.shape\n",
"# NOTE: one-hot y is len(X) x vocab_size float32 (about 5.5 GB for ~133k x 10291);\n",
"# integer y with sparse_categorical_crossentropy would avoid materializing it\n",
"y = tf.keras.utils.to_categorical(y, num_classes=vocab_size)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Train the model"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Build a neural network for the language model"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: \"sequential\"\n",
"_________________________________________________________________\n",
"Layer (type) Output Shape Param # \n",
"=================================================================\n",
"embedding (Embedding) (None, 50, 50) 514550 \n",
"_________________________________________________________________\n",
"unified_lstm (UnifiedLSTM) (None, 50, 100) 60400 \n",
"_________________________________________________________________\n",
"unified_lstm_1 (UnifiedLSTM) (None, 100) 80400 \n",
"_________________________________________________________________\n",
"dense (Dense) (None, 100) 10100 \n",
"_________________________________________________________________\n",
"dense_1 (Dense) (None, 10291) 1039391 \n",
"=================================================================\n",
"Total params: 1,704,841\n",
"Trainable params: 1,704,841\n",
"Non-trainable params: 0\n",
"_________________________________________________________________\n"
]
}
],
"source": [
"embedding_dims = 50\n",
"lstm_units = 100\n",
"# embedding -> two stacked LSTMs -> hidden dense layer -> softmax over the vocabulary\n",
"model = tf.keras.Sequential([\n",
"\ttf.keras.layers.Embedding(vocab_size, embedding_dims, input_length=seq_length),\n",
"\ttf.keras.layers.LSTM(lstm_units, return_sequences=True),  # emit every step for the next LSTM\n",
"\ttf.keras.layers.LSTM(lstm_units),  # keep only the final output\n",
"\ttf.keras.layers.Dense(lstm_units, activation='relu'),\n",
"\ttf.keras.layers.Dense(vocab_size, activation='softmax'),\n",
"])\n",
"model.summary()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"train"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/100\n",
"133610/133610 [==============================] - 164s 1ms/sample - loss: 6.5933 - accuracy: 0.0632\n",
"Epoch 2/100\n",
"133610/133610 [==============================] - 163s 1ms/sample - loss: 6.1577 - accuracy: 0.0834\n",
"Epoch 3/100\n",
"133610/133610 [==============================] - 169s 1ms/sample - loss: 5.9096 - accuracy: 0.1038\n",
"Epoch 4/100\n",
"133610/133610 [==============================] - 159s 1ms/sample - loss: 5.7525 - accuracy: 0.1144\n",
"Epoch 5/100\n",
"133610/133610 [==============================] - 158s 1ms/sample - loss: 5.6315 - accuracy: 0.1210\n",
"Epoch 6/100\n",
"133610/133610 [==============================] - 159s 1ms/sample - loss: 5.5271 - accuracy: 0.1254\n",
"Epoch 7/100\n",
"133610/133610 [==============================] - 158s 1ms/sample - loss: 5.4307 - accuracy: 0.1300\n",
"Epoch 8/100\n",
"133610/133610 [==============================] - 160s 1ms/sample - loss: 5.3411 - accuracy: 0.1341\n",
"Epoch 9/100\n",
"133610/133610 [==============================] - 158s 1ms/sample - loss: 5.2495 - accuracy: 0.1394\n",
"Epoch 10/100\n",
"133610/133610 [==============================] - 160s 1ms/sample - loss: 5.1597 - accuracy: 0.1439\n",
"Epoch 11/100\n",
"133610/133610 [==============================] - 161s 1ms/sample - loss: 5.0711 - accuracy: 0.1483\n",
"Epoch 12/100\n",
"133610/133610 [==============================] - 161s 1ms/sample - loss: 4.9867 - accuracy: 0.1516\n",
"Epoch 13/100\n",
"133610/133610 [==============================] - 158s 1ms/sample - loss: 4.9080 - accuracy: 0.1546\n",
"Epoch 14/100\n",
"133610/133610 [==============================] - 156s 1ms/sample - loss: 4.8309 - accuracy: 0.1579\n",
"Epoch 15/100\n",
"133610/133610 [==============================] - 160s 1ms/sample - loss: 4.7582 - accuracy: 0.1610\n",
"Epoch 16/100\n",
"133610/133610 [==============================] - 160s 1ms/sample - loss: 4.6913 - accuracy: 0.1637\n",
"Epoch 17/100\n",
"133610/133610 [==============================] - 162s 1ms/sample - loss: 4.6267 - accuracy: 0.1666\n",
"Epoch 18/100\n",
"133610/133610 [==============================] - 164s 1ms/sample - loss: 4.5645 - accuracy: 0.1705\n",
"Epoch 19/100\n",
"133610/133610 [==============================] - 161s 1ms/sample - loss: 4.5068 - accuracy: 0.1737\n",
"Epoch 20/100\n",
"133610/133610 [==============================] - 168s 1ms/sample - loss: 4.4520 - accuracy: 0.1773\n",
"Epoch 21/100\n",
"133610/133610 [==============================] - 166s 1ms/sample - loss: 4.4006 - accuracy: 0.1808\n",
"Epoch 22/100\n",
"133610/133610 [==============================] - 159s 1ms/sample - loss: 4.3531 - accuracy: 0.1843\n",
"Epoch 23/100\n",
"133610/133610 [==============================] - 153s 1ms/sample - loss: 4.3071 - accuracy: 0.1884\n",
"Epoch 24/100\n",
"133610/133610 [==============================] - 155s 1ms/sample - loss: 4.2630 - accuracy: 0.1919\n",
"Epoch 25/100\n",
"133610/133610 [==============================] - 153s 1ms/sample - loss: 4.2221 - accuracy: 0.1967\n",
"Epoch 26/100\n",
"133610/133610 [==============================] - 156s 1ms/sample - loss: 4.1823 - accuracy: 0.1997\n",
"Epoch 27/100\n",
"133610/133610 [==============================] - 168s 1ms/sample - loss: 4.1431 - accuracy: 0.2035\n",
"Epoch 28/100\n",
"133610/133610 [==============================] - 161s 1ms/sample - loss: 4.1073 - accuracy: 0.2066\n",
"Epoch 29/100\n",
"133610/133610 [==============================] - 171s 1ms/sample - loss: 4.0709 - accuracy: 0.2117\n",
"Epoch 30/100\n",
"133610/133610 [==============================] - 169s 1ms/sample - loss: 4.0354 - accuracy: 0.2145\n",
"Epoch 31/100\n",
"133610/133610 [==============================] - 182s 1ms/sample - loss: 4.0005 - accuracy: 0.2178\n",
"Epoch 32/100\n",
"133610/133610 [==============================] - 163s 1ms/sample - loss: 3.9677 - accuracy: 0.2214\n",
"Epoch 33/100\n",
"133610/133610 [==============================] - 163s 1ms/sample - loss: 3.9384 - accuracy: 0.2259\n",
"Epoch 34/100\n",
"133610/133610 [==============================] - 168s 1ms/sample - loss: 3.9051 - accuracy: 0.2295\n",
"Epoch 35/100\n",
"133610/133610 [==============================] - 168s 1ms/sample - loss: 3.8751 - accuracy: 0.2332\n",
"Epoch 36/100\n",
"133610/133610 [==============================] - 166s 1ms/sample - loss: 3.8467 - accuracy: 0.2364\n",
"Epoch 37/100\n",
"133610/133610 [==============================] - 162s 1ms/sample - loss: 3.8160 - accuracy: 0.2393\n",
"Epoch 38/100\n",
"133610/133610 [==============================] - 166s 1ms/sample - loss: 3.7887 - accuracy: 0.2432\n",
"Epoch 39/100\n",
"133610/133610 [==============================] - 165s 1ms/sample - loss: 3.7598 - accuracy: 0.2472\n",
"Epoch 40/100\n",
"133610/133610 [==============================] - 163s 1ms/sample - loss: 3.7312 - accuracy: 0.2507\n",
"Epoch 41/100\n",
"133610/133610 [==============================] - 157s 1ms/sample - loss: 3.7080 - accuracy: 0.2530\n",
"Epoch 42/100\n",
"133610/133610 [==============================] - 155s 1ms/sample - loss: 3.6801 - accuracy: 0.2581\n",
"Epoch 43/100\n",
"133610/133610 [==============================] - 160s 1ms/sample - loss: 3.6543 - accuracy: 0.2607\n",
"Epoch 44/100\n",
"133610/133610 [==============================] - 163s 1ms/sample - loss: 3.6269 - accuracy: 0.2648\n",
"Epoch 45/100\n",
"133610/133610 [==============================] - 162s 1ms/sample - loss: 3.6035 - accuracy: 0.2673\n",
"Epoch 46/100\n",
"133610/133610 [==============================] - 160s 1ms/sample - loss: 3.5799 - accuracy: 0.2698\n",
"Epoch 47/100\n",
"133610/133610 [==============================] - 162s 1ms/sample - loss: 3.5561 - accuracy: 0.2722\n",
"Epoch 48/100\n",
"133610/133610 [==============================] - 161s 1ms/sample - loss: 3.5328 - accuracy: 0.2764\n",
"Epoch 49/100\n",
"133610/133610 [==============================] - 160s 1ms/sample - loss: 3.5087 - accuracy: 0.2801\n",
"Epoch 50/100\n",
"133610/133610 [==============================] - 162s 1ms/sample - loss: 3.4885 - accuracy: 0.2820\n",
"Epoch 51/100\n",
"133610/133610 [==============================] - 162s 1ms/sample - loss: 3.4634 - accuracy: 0.2853\n",
"Epoch 52/100\n",
"133610/133610 [==============================] - 162s 1ms/sample - loss: 3.4408 - accuracy: 0.2888\n",
"Epoch 53/100\n",
"133610/133610 [==============================] - 157s 1ms/sample - loss: 3.4214 - accuracy: 0.2916\n",
"Epoch 54/100\n",
"133610/133610 [==============================] - 158s 1ms/sample - loss: 3.3984 - accuracy: 0.2961\n",
"Epoch 55/100\n",
"133610/133610 [==============================] - 160s 1ms/sample - loss: 3.3752 - accuracy: 0.2983\n",
"Epoch 56/100\n",
"133610/133610 [==============================] - 162s 1ms/sample - loss: 3.3583 - accuracy: 0.3011\n",
"Epoch 57/100\n",
"133610/133610 [==============================] - 179s 1ms/sample - loss: 3.3345 - accuracy: 0.3044\n",
"Epoch 58/100\n",
"133610/133610 [==============================] - 174s 1ms/sample - loss: 3.3137 - accuracy: 0.3076\n",
"Epoch 59/100\n",
"133610/133610 [==============================] - 166s 1ms/sample - loss: 3.2953 - accuracy: 0.3096\n",
"Epoch 60/100\n",
"133610/133610 [==============================] - 169s 1ms/sample - loss: 3.2790 - accuracy: 0.3131\n",
"Epoch 61/100\n",
"133610/133610 [==============================] - 162s 1ms/sample - loss: 3.2555 - accuracy: 0.3169\n",
"Epoch 62/100\n",
"133610/133610 [==============================] - 162s 1ms/sample - loss: 3.2346 - accuracy: 0.3196\n",
"Epoch 63/100\n",
"133610/133610 [==============================] - 162s 1ms/sample - loss: 3.2164 - accuracy: 0.3226\n",
"Epoch 64/100\n",
"133610/133610 [==============================] - 162s 1ms/sample - loss: 3.2014 - accuracy: 0.3246\n",
"Epoch 65/100\n",
"133610/133610 [==============================] - 163s 1ms/sample - loss: 3.1821 - accuracy: 0.3265\n",
"Epoch 66/100\n",
"133610/133610 [==============================] - 161s 1ms/sample - loss: 3.1594 - accuracy: 0.3306\n",
"Epoch 67/100\n",
"133610/133610 [==============================] - 161s 1ms/sample - loss: 3.1421 - accuracy: 0.3338\n",
"Epoch 68/100\n",
"133610/133610 [==============================] - 160s 1ms/sample - loss: 3.1275 - accuracy: 0.3356\n",
"Epoch 69/100\n",
"133610/133610 [==============================] - 161s 1ms/sample - loss: 3.1063 - accuracy: 0.3388\n",
"Epoch 70/100\n",
"133610/133610 [==============================] - 161s 1ms/sample - loss: 3.0853 - accuracy: 0.3424\n",
"Epoch 71/100\n",
"133610/133610 [==============================] - 162s 1ms/sample - loss: 3.0698 - accuracy: 0.3445\n",
"Epoch 72/100\n",
"133610/133610 [==============================] - 162s 1ms/sample - loss: 3.0542 - accuracy: 0.3470\n",
"Epoch 73/100\n",
"133610/133610 [==============================] - 160s 1ms/sample - loss: 3.0370 - accuracy: 0.3503\n",
"Epoch 74/100\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"133610/133610 [==============================] - 165s 1ms/sample - loss: 3.0207 - accuracy: 0.3528\n",
"Epoch 75/100\n",
"133610/133610 [==============================] - 172s 1ms/sample - loss: 3.0014 - accuracy: 0.3557\n",
"Epoch 76/100\n",
"133610/133610 [==============================] - 191s 1ms/sample - loss: 2.9868 - accuracy: 0.3584\n",
"Epoch 77/100\n",
"133610/133610 [==============================] - 228s 2ms/sample - loss: 2.9808 - accuracy: 0.3592\n",
"Epoch 78/100\n",
"133610/133610 [==============================] - 197s 1ms/sample - loss: 2.9514 - accuracy: 0.3633\n",
"Epoch 79/100\n",
"133610/133610 [==============================] - 243s 2ms/sample - loss: 2.9368 - accuracy: 0.3662\n",
"Epoch 80/100\n",
"133610/133610 [==============================] - 218s 2ms/sample - loss: 2.9171 - accuracy: 0.3693\n",
"Epoch 81/100\n",
"133610/133610 [==============================] - 208s 2ms/sample - loss: 2.9050 - accuracy: 0.3722\n",
"Epoch 82/100\n",
"133610/133610 [==============================] - 209s 2ms/sample - loss: 2.8889 - accuracy: 0.3746\n",
"Epoch 83/100\n",
"133610/133610 [==============================] - 214s 2ms/sample - loss: 2.8725 - accuracy: 0.3768\n",
"Epoch 84/100\n",
"133610/133610 [==============================] - 204s 2ms/sample - loss: 2.8564 - accuracy: 0.3818\n",
"Epoch 85/100\n",
"133610/133610 [==============================] - 196s 1ms/sample - loss: 2.8434 - accuracy: 0.3817\n",
"Epoch 86/100\n",
"133610/133610 [==============================] - 188s 1ms/sample - loss: 2.8275 - accuracy: 0.3834\n",
"Epoch 87/100\n",
"133610/133610 [==============================] - 204s 2ms/sample - loss: 2.8096 - accuracy: 0.3877\n",
"Epoch 88/100\n",
"133610/133610 [==============================] - 197s 1ms/sample - loss: 2.7986 - accuracy: 0.3893\n",
"Epoch 89/100\n",
"133610/133610 [==============================] - 199s 1ms/sample - loss: 2.7819 - accuracy: 0.3916\n",
"Epoch 90/100\n",
"133610/133610 [==============================] - 208s 2ms/sample - loss: 2.7670 - accuracy: 0.3956\n",
"Epoch 91/100\n",
"133610/133610 [==============================] - 221s 2ms/sample - loss: 2.7520 - accuracy: 0.3978\n",
"Epoch 92/100\n",
"133610/133610 [==============================] - 227s 2ms/sample - loss: 2.7356 - accuracy: 0.4006\n",
"Epoch 93/100\n",
"133610/133610 [==============================] - 226s 2ms/sample - loss: 2.7218 - accuracy: 0.4027\n",
"Epoch 94/100\n",
"133610/133610 [==============================] - 223s 2ms/sample - loss: 2.7127 - accuracy: 0.4053\n",
"Epoch 95/100\n",
"133610/133610 [==============================] - 237s 2ms/sample - loss: 2.6972 - accuracy: 0.4071\n",
"Epoch 96/100\n",
"133610/133610 [==============================] - 259s 2ms/sample - loss: 2.6771 - accuracy: 0.4102\n",
"Epoch 97/100\n",
"133610/133610 [==============================] - 237s 2ms/sample - loss: 2.6630 - accuracy: 0.4136\n",
"Epoch 98/100\n",
"133610/133610 [==============================] - 188s 1ms/sample - loss: 2.6523 - accuracy: 0.4156\n",
"Epoch 99/100\n",
"133610/133610 [==============================] - 195s 1ms/sample - loss: 2.6430 - accuracy: 0.4170\n",
"Epoch 100/100\n",
"133610/133610 [==============================] - 217s 2ms/sample - loss: 2.6220 - accuracy: 0.4204\n"
]
},
{
"data": {
"text/plain": [
"<tensorflow.python.keras.callbacks.History at 0x7f39b7cb7ef0>"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# training hyper-parameters\n",
"epochs = 100\n",
"batch_size = 128\n",
"# categorical cross-entropy matches the one-hot targets produced by to_categorical\n",
"model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])\n",
"model.fit(X, y, epochs=epochs, batch_size=batch_size)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"save the model"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"import pickle\n",
"# save the model to file\n",
"model.save('model_100.h5')\n",
"# save the tokenizer; the with-block flushes and closes the file\n",
"with open('tokenizer.pkl', 'wb') as f:\n",
"\tpickle.dump(tokenizer, f)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Use the model"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"load the model"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"lines = doc.split('\\n')\n",
"# load the model\n",
"model = tf.keras.models.load_model('model_100.h5')\n",
"# load the tokenizer; only unpickle files you trust (here: the one saved by this notebook)\n",
"with open('tokenizer.pkl', 'rb') as f:\n",
"\ttokenizer = pickle.load(f)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"text generation"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"soul i have been ashamed of you should be very beneficial to a man in your practice at the bar to be ashamed of returned sydney ought to be much obliged to shall not get off in that rejoined stryver shouldering the rejoinder at him sydney its my duty to tell\n",
"\n"
]
}
],
"source": [
"import random\n",
"# select a seed text uniformly at random; random.choice never indexes past the end,\n",
"# unlike randint(0, len(lines)) whose upper bound is inclusive\n",
"seed_text = random.choice(lines)\n",
"print(seed_text + '\\n')"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(1, 50)\n"
]
}
],
"source": [
"# encode the seed's first seq_length words (the model's input length)\n",
"encoded = np.array([tokenizer.texts_to_sequences([seed_text])[0][:seq_length]])\n",
"print(encoded.shape)\n",
"# predict the index of the most likely next word (argmax over the softmax output;\n",
"# Sequential.predict_classes is not available in newer TensorFlow releases)\n",
"yhat = np.argmax(model.predict(encoded, verbose=0), axis=-1)"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
"# generate a sequence from a language model\n",
"def generate_seq(model, tokenizer, seq_length, seed_text, n_words):\n",
"\t\"\"\"Greedily generate *n_words* words continuing *seed_text*.\n",
"\n",
"\tAt each step the current text is encoded, cut/padded to *seq_length* word ids\n",
"\t(keeping the most recent ones), and the most probable next word is appended.\n",
"\tReturns only the generated words joined by single spaces (seed not included).\n",
"\t\"\"\"\n",
"\t# reverse vocabulary, built once, instead of scanning word_index for every word\n",
"\tindex_to_word = {index: word for word, index in tokenizer.word_index.items()}\n",
"\tresult = list()\n",
"\tin_text = seed_text\n",
"\t# generate a fixed number of words\n",
"\tfor _ in range(n_words):\n",
"\t\t# encode the text as integers\n",
"\t\tencoded = tokenizer.texts_to_sequences([in_text])[0]\n",
"\t\t# keep the last seq_length ids (pad_sequences pads on the left if shorter)\n",
"\t\tencoded = tf.keras.preprocessing.sequence.pad_sequences([encoded], maxlen=seq_length, truncating='pre')\n",
"\t\t# index of the most probable next word (argmax over the softmax output)\n",
"\t\tyhat = int(np.argmax(model.predict(encoded, verbose=0), axis=-1)[0])\n",
"\t\t# map predicted word index to word ('' if the index has no word, e.g. 0)\n",
"\t\tout_word = index_to_word.get(yhat, '')\n",
"\t\t# append to input\n",
"\t\tin_text += ' ' + out_word\n",
"\t\tresult.append(out_word)\n",
"\treturn ' '.join(result)"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"you said jerry i shall be able to do you call you you will be no way to you within the said mr cruncher i am sure of business yourself long importance that i am not justified in families i have my wittles is strained on the course of an\n"
]
}
],
"source": [
"# generate new text: 50 words continuing the seed\n",
"generated = generate_seq(model, tokenizer, seq_length, seed_text, n_words=50)\n",
"print(generated)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.0"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment