{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"id": "gSW4SBNiKNjK"
},
"outputs": [],
"source": [
"import tensorflow as tf\n",
"from tensorflow import keras\n",
"import tensorflow.keras.backend as K\n",
"import numpy as np\n",
"from tensorflow.keras import layers\n",
"from tensorflow.keras import layers, Model\n",
"import os\n",
"from tensorflow.keras.layers.experimental.preprocessing import TextVectorization\n",
"import string\n",
"import re"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"id": "YWZkhNsB9HkQ"
},
"outputs": [],
"source": [
"def save_dataset(dataset,fileName):\n",
" path = os.path.join('./tfDatasets/', fileName)\n",
" tf.data.experimental.save(dataset, path)\n",
"\n",
"def load_dataset(fileName):\n",
" path = os.path.join(\"./tfDatasets/\", fileName)\n",
" new_dataset = tf.data.experimental.load(path,\n",
" tf.TensorSpec(shape=(), dtype=tf.string))\n",
" return new_dataset"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "6kfAU3OORcD2",
"outputId": "9deb8276-2301-4484-de0c-ab7adf48a10a"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Length of text: 1212934 characters\n"
]
}
],
"source": [
"path_to_file2 = tf.keras.utils.get_file('2701-0.txt', 'https://www.gutenberg.org/files/2701/2701-0.txt')\n",
"text = open(path_to_file2, 'rb').read().decode(encoding='utf-8')\n",
"\n",
"# Remove some of the beginning and ending crap\n",
"text = 'CHAPTER 1. Loomins.\\n\\nCall me Ishmael' + text.split(\"Call me Ishmael\")[1]\n",
"text = text.split(\"*** END OF THE PROJECT GUTENBERG EBOOK MOBY-DICK; OR THE WHALE ***\")[0]\n",
"\n",
"# Clean up some characters for consistency\n",
"text = text.replace('“','\"')\n",
"text = text.replace('”','\"')\n",
"text = text.replace('£','E')\n",
"text = text.replace('é','E')\n",
"text = text.replace('â','a')\n",
"text = text.replace('æ','ae')\n",
"text = text.replace('è','e')\n",
"text = text.replace('œ','oe')\n",
"text = text.replace('‘',\"'\")\n",
"text = text.replace('’',\"'\")\n",
"\n",
"text = text.lower()\n",
"\n",
"print('Length of text: {} characters'.format(len(text)))"
]
},
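{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick sanity check, preview the opening of the cleaned text to confirm the Gutenberg header was stripped and the character replacements applied (the slice length below is arbitrary):\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Preview the cleaned, lowercased text; 250 is an arbitrary cut-off\n",
"print(text[:250])"
]
},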
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "-0zg8dmexEwa",
"outputId": "16295e64-73ee-466c-f3fc-49dde2b7dfe7"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Total disctinct chars: 57\n"
]
}
],
"source": [
"chars = sorted(list(set(text)))\n",
"print(\"Total disctinct chars:\", len(chars))"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"id": "yXpEadsc-g4i"
},
"outputs": [],
"source": [
"# cut the text in semi-redundant sequences of maxlen characters\n",
"maxlen = 20\n",
"step = 3\n",
"input_chars = []\n",
"next_char = []\n",
"\n",
"for i in range(0, len(text) - maxlen, step):\n",
" input_chars.append(text[i : i + maxlen])\n",
" next_char.append(text[i + maxlen])"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "_JW0DEHgHj6x",
"outputId": "c889fa33-49b9-4067-adf6-49ba1becfbc4"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Number of sequences: 404305\n",
"input X (input_chars) ---> output y (next_char) \n",
"chapter 1. loomins.\n",
" ---> \n",
"\n",
"pter 1. loomins.\n",
"\n",
"ca ---> l\n",
"r 1. loomins.\n",
"\n",
"call ---> m\n",
". loomins.\n",
"\n",
"call me ---> i\n",
"oomins.\n",
"\n",
"call me ish ---> m\n"
]
}
],
"source": [
"print(\"Number of sequences:\", len(input_chars))\n",
"print(\"input X (input_chars) ---> output y (next_char) \")\n",
"\n",
"for i in range(5):\n",
" print( input_chars[i],\" ---> \", next_char[i])"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"id": "yJqnQbP9_VtV"
},
"outputs": [],
"source": [
"X_train_ds_raw = tf.data.Dataset.from_tensor_slices(input_chars)\n",
"y_train_ds_raw = tf.data.Dataset.from_tensor_slices(next_char)"
]
},
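{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `save_dataset`/`load_dataset` helpers defined earlier can round-trip these raw string datasets; `load_dataset`'s hard-coded `TensorSpec(shape=(), dtype=tf.string)` matches their element spec. A minimal, optional sketch, assuming a writable working directory (the file name is arbitrary):\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optional: persist the raw input dataset and reload it with the earlier helpers\n",
"os.makedirs('./tfDatasets/', exist_ok=True)\n",
"save_dataset(X_train_ds_raw, 'x_train_raw')\n",
"X_train_ds_raw = load_dataset('x_train_raw')"
]
},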
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"id": "wbaD2D2g9Um0"
},
"outputs": [],
"source": [
"batch_size = 256\n",
"max_features = len(chars)\n",
"embedding_dim = 256\n",
"sequence_length = maxlen\n",
"\n",
"def char_split(input_data):\n",
" return tf.strings.unicode_split(input_data, 'UTF-8')\n",
"\n",
"# def word_split(input_data):\n",
"# return tf.strings.split(input_data)\n",
"\n",
"vectorize_layer = TextVectorization(\n",
" max_tokens=max_features,\n",
" split=char_split, # word_split or char_split\n",
" output_mode=\"int\",\n",
" output_sequence_length=sequence_length,\n",
")\n",
"\n",
"vectorize_layer.adapt(X_train_ds_raw.batch(batch_size))"
]
},
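{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick check on the adapted layer: `get_vocabulary()` is capped at `max_tokens`, and its first two entries are reserved for the padding token `''` and `[UNK]`, with the characters following in frequency order.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Inspect the learned character vocabulary\n",
"vocab = vectorize_layer.get_vocabulary()\n",
"print(\"Vocabulary size:\", len(vocab))\n",
"print(\"First tokens:\", vocab[:10])"
]
},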
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "qelaS3UgBijo",
"outputId": "9b1b3d58-0e01-46ef-866f-7da4b39faf61"
},
"outputs": [
{
"data": {
"text/plain": [
"(TensorSpec(shape=(20,), dtype=tf.int64, name=None),\n",
" TensorSpec(shape=(20,), dtype=tf.int64, name=None))"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"def vectorize_text(text):\n",
" text = tf.expand_dims(text, -1)\n",
" return tf.squeeze(vectorize_layer(text))\n",
"\n",
"# Vectorize the data.\n",
"X_train_ds = X_train_ds_raw.map(vectorize_text)\n",
"y_train_ds = y_train_ds_raw.map(vectorize_text)\n",
"\n",
"X_train_ds.element_spec, y_train_ds.element_spec"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"id": "g0_gdj6OxWuP"
},
"outputs": [],
"source": [
"AUTOTUNE = tf.data.AUTOTUNE\n",
"y_train_ds = y_train_ds.map(lambda x: x[0])\n",
"train_ds = tf.data.Dataset.zip((X_train_ds,y_train_ds))\n",
"train_ds = train_ds.shuffle(buffer_size=512).batch(batch_size, drop_remainder=True).cache().prefetch(buffer_size=AUTOTUNE)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "uc6K8t0q4r50",
"outputId": "da68eeda-3ad0-4aad-e835-75d54e58e9a4"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"input (X) dimension: (256, 20) \n",
"output (y) dimension: (256,)\n"
]
}
],
"source": [
"for sample in train_ds.take(1):\n",
" print(\"input (X) dimension: \", sample[0].numpy().shape, \"\\noutput (y) dimension: \",sample[1].numpy().shape)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "Pvvz6xIA6ihY",
"outputId": "1b1e678f-1673-4a25-b3a3-e07f319a90a9"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"input (sequence of chars): [ 2 15 3 2 4 10 5 4 2 8 4 2 11 3 28 14 8 11 3 0] \n",
"output (next char to complete the input): 9\n"
]
}
],
"source": [
"for sample in train_ds.take(1):\n",
" print(\"input (sequence of chars): \", sample[0][0].numpy(), \"\\noutput (next char to complete the input): \",sample[1][0].numpy())"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "yBSW8yTj2ugw",
"outputId": "8014b3d9-1290-47ed-e465-a0bf5477aa50"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\t lourish cato throws \n",
"input (sequence of chars): lourish cato throws \n",
"output (next char to complete the input): h\n"
]
}
],
"source": [
"def decode_sequence(encoded_sequence):\n",
" deceoded_sequence=[]\n",
" for token in encoded_sequence:\n",
" deceoded_sequence.append(vectorize_layer.get_vocabulary()[token])\n",
" \n",
" sequence= ''.join(deceoded_sequence)\n",
" print(\"\\t\",sequence)\n",
" return sequence\n",
"\n",
"for sample in train_ds.take(1):\n",
" print(\"input (sequence of chars): \", decode_sequence(sample[0][0].numpy()), \"\\noutput (next char to complete the input): \",vectorize_layer.get_vocabulary()[sample[1][0].numpy()])"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"id": "eSkEM7B1erC6"
},
"outputs": [],
"source": [
"def softmax(z):\n",
" return np.exp(z) / sum(np.exp(z))\n",
"\n",
"def greedy_search(conditional_probability):\n",
" return (np.argmax(conditional_probability))\n",
"\n",
"def temperature_sampling (conditional_probability, temperature=1.0):\n",
" conditional_probability = np.asarray(conditional_probability).astype(\"float64\")\n",
" conditional_probability = np.log(conditional_probability) / temperature\n",
" reweighted_conditional_probability = softmax(conditional_probability)\n",
" probas = np.random.multinomial(1, reweighted_conditional_probability, 1)\n",
" return np.argmax(probas)\n",
"\n",
"def top_k_sampling(conditional_probability, k):\n",
" top_k_probabilities, top_k_indices= tf.math.top_k(conditional_probability, k=k, sorted=True)\n",
" top_k_probabilities = np.asarray(top_k_probabilities).astype(\"float32\")\n",
" top_k_probabilities = np.squeeze(top_k_probabilities)\n",
" top_k_indices = np.asarray(top_k_indices).astype(\"int32\")\n",
" top_k_redistributed_probability = softmax(top_k_probabilities)\n",
" top_k_redistributed_probability = np.asarray(top_k_redistributed_probability).astype(\"float32\")\n",
" sampled_token = np.random.choice(np.squeeze(top_k_indices), p=top_k_redistributed_probability)\n",
" return sampled_token"
]
},
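{
"cell_type": "markdown",
"metadata": {},
"source": [
"To see how the three strategies differ, here is a toy illustration on a hand-made distribution (not model output): greedy search always returns the argmax, temperature sampling reshapes the distribution before drawing, and top-k sampling draws only among the k most likely tokens.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Hypothetical next-char distribution over 5 tokens, purely for illustration\n",
"toy_probs = np.array([0.5, 0.2, 0.15, 0.1, 0.05])\n",
"\n",
"print(\"greedy:\", greedy_search(toy_probs))                # always token 0\n",
"print(\"temp 0.2:\", temperature_sampling(toy_probs, 0.2))  # near-greedy\n",
"print(\"temp 1.2:\", temperature_sampling(toy_probs, 1.2))  # closer to uniform\n",
"print(\"top-3:\", top_k_sampling(toy_probs, 3))             # one of tokens 0-2"
]
},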
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"id": "dsOKHPLbOwn2"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: \"sequential\"\n",
"_________________________________________________________________\n",
" Layer (type) Output Shape Param # \n",
"=================================================================\n",
" embedding (Embedding) (None, 20, 256) 14592 \n",
" \n",
" dropout (Dropout) (None, 20, 256) 0 \n",
" \n",
" lstm (LSTM) (None, 20, 128) 197120 \n",
" \n",
" lstm_1 (LSTM) (None, 64) 49408 \n",
" \n",
" flatten (Flatten) (None, 64) 0 \n",
" \n",
" dense (Dense) (None, 57) 3705 \n",
" \n",
"=================================================================\n",
"Total params: 264,825\n",
"Trainable params: 264,825\n",
"Non-trainable params: 0\n",
"_________________________________________________________________\n",
"None\n"
]
}
],
"source": [
"model_LSTM = keras.Sequential([\n",
" keras.Input(shape=(sequence_length), dtype=\"int64\"),\n",
" layers.Embedding(max_features, embedding_dim),\n",
" layers.Dropout(0.5),\n",
" layers.LSTM(128, return_sequences=True),\n",
" layers.LSTM(64),\n",
" layers.Flatten(),\n",
" layers.Dense(max_features, activation='softmax'),\n",
"])\n",
"\n",
"# lr_schedule = keras.optimizers.schedules.ExponentialDecay(\n",
"# initial_learning_rate=0.1,\n",
"# decay_steps=200,\n",
"# decay_rate=0.9\n",
"# )\n",
"\n",
"# optimizer = keras.optimizers.Adam(learning_rate=lr_schedule)\n",
"optimizer = keras.optimizers.RMSprop(learning_rate=0.01)\n",
"\n",
"model_LSTM.compile(\n",
" optimizer=optimizer,\n",
" loss='sparse_categorical_crossentropy', \n",
" metrics=['accuracy']\n",
")\n",
"\n",
"print(model_LSTM.summary())"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "4HNYURebUBt3",
"outputId": "3e080bd8-f453-4bac-efe2-1e280acd4ab0"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/5\n",
"1579/1579 [==============================] - 226s 141ms/step - loss: 1.9880 - accuracy: 0.4050\n",
"Epoch 2/5\n",
"1579/1579 [==============================] - 254s 161ms/step - loss: 1.7793 - accuracy: 0.4560\n",
"Epoch 3/5\n",
"1579/1579 [==============================] - 231s 146ms/step - loss: 1.7495 - accuracy: 0.4646\n",
"Epoch 4/5\n",
"1579/1579 [==============================] - 226s 143ms/step - loss: 1.7373 - accuracy: 0.4676\n",
"Epoch 5/5\n",
"1579/1579 [==============================] - 234s 148ms/step - loss: 1.7299 - accuracy: 0.4696\n"
]
},
{
"data": {
"text/plain": [
"<keras.callbacks.History at 0x25d7a379910>"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model_LSTM.fit(train_ds, epochs=5)"
]
},
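{
"cell_type": "markdown",
"metadata": {},
"source": [
"Optionally persist the trained weights so the generation cells below can be re-run later without retraining. A minimal sketch; the file name is arbitrary:\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optional: save (and later restore) the trained weights\n",
"model_LSTM.save_weights('moby_lstm_weights.h5')\n",
"# model_LSTM.load_weights('moby_lstm_weights.h5')"
]
},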
{
"cell_type": "code",
"execution_count": 17,
"metadata": {
"id": "cmu8JbXP13kG"
},
"outputs": [],
"source": [
"# skip_ids = self.ids_from_chars(['','[UNK]'])\n",
"\n",
"def generate_text(model, seed_original, step):\n",
" output = []\n",
" \n",
" seed= vectorize_text(seed_original)\n",
" output.append(\"The prompt is\")\n",
" output.append(decode_sequence(seed.numpy().squeeze()))\n",
"\n",
" seed= vectorize_text(seed_original).numpy().reshape(1,-1)\n",
" #Text Generated by Greedy Search Sampling\n",
" generated_greedy_search = (seed)\n",
" for i in range(step):\n",
" predictions = model.predict(seed, verbose=1)\n",
" next_index = greedy_search(predictions.squeeze())\n",
" generated_greedy_search = np.append(generated_greedy_search, next_index)\n",
" seed = generated_greedy_search[-sequence_length:].reshape(1,sequence_length)\n",
"\n",
" output.append(\"Text Generated by Greedy Search Sampling:\")\n",
" output.append(decode_sequence(generated_greedy_search))\n",
"\n",
" #Text Generated by Temperature Sampling\n",
" output.append(\"Text Generated by Temperature Sampling:\")\n",
" for temperature in [0.2, 0.5, 1.0, 1.2]:\n",
" output.append(\"\\ttemperature: {}\".format(temperature))\n",
" seed= vectorize_text(seed_original).numpy().reshape(1,-1)\n",
" generated_temperature = (seed)\n",
"\n",
" for i in range(step):\n",
" predictions = model.predict(seed, verbose=1);\n",
" next_index = temperature_sampling(predictions.squeeze(), temperature)\n",
" generated_temperature = np.append(generated_temperature, next_index)\n",
" seed = generated_temperature[-sequence_length:].reshape(1,sequence_length)\n",
"\n",
" output.append(decode_sequence(generated_temperature))\n",
"\n",
" #Text Generated by Top-K Sampling\n",
" output.append(\"Text Generated by Top-K Sampling:\")\n",
" for k in [2, 3, 4, 5]:\n",
" print(\"\\tTop-k: \", k)\n",
" seed = vectorize_text(seed_original).numpy().reshape(1,-1)\n",
" generated_top_k = (seed)\n",
"\n",
" for i in range(step):\n",
" predictions = model.predict(seed, verbose=1);\n",
" next_index = top_k_sampling(predictions.squeeze(), k)\n",
" generated_top_k = np.append(generated_top_k, next_index)\n",
" seed = generated_top_k[-sequence_length:].reshape(1,sequence_length)\n",
"\n",
" output.append(decode_sequence(generated_top_k))\n",
" \n",
" return output\n"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "fILBboHX4r6O",
"outputId": "76e38343-3bd1-4bfe-b36f-e77e0f591864"
},
"outputs": [],
"source": [
"%%capture\n",
"response = generate_text(model_LSTM,\"I observed, however\", 100)"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The prompt is\n",
"i observed however\n",
"Text Generated by Greedy Search Sampling:\n",
"i observed however the ship the ship the ship the ship the ship the ship the ship the ship the ship the ship\n",
"Text Generated by Temperature Sampling:\n",
"\ttemperature: 0.2\n",
"i observed however the sea the ship the whale the ship and flage to the sharks of the ship a ship a strange and a\n",
"\ttemperature: 0.5\n",
"i observed however beat struck still deck and before a last cart of mind of the crusy of the ship the sea how and \n",
"\ttemperature: 1.0\n",
"i observed however mume of good keen at minking deam upon\n",
"the bropes\n",
"him halping bout into his heade worlds of hi\n",
"\ttemperature: 1.2\n",
"i observed howevernical setsar esch this\n",
"helm betidelamlanged by then mouth\n",
"hands marks breilousybows\n",
"som\n",
"Text Generated by Top-K Sampling:\n",
"i observed however hane to shand the sea the sharks strungs tower shanks strack starbounty ship that to st\n",
"cort here and ovoy that trumbent thict anso astares time thick of\n",
"a harrow hims foll to \n",
"fromserved howevers beneitytelest hearts\n",
"mank inthocted a that any\n",
"but\n",
" to hiuek\n",
"i observed howevere\n",
"chrangicabessind tusigg in his eswitceners\n"
]
}
],
"source": [
"for line in response:\n",
" print(line)"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"collapsed_sections": [
"qegEhDUZDofO",
"AvUeG0p12RfU",
"AhBCzBJr5EEx"
],
"name": "Char Level Text Generation with an LSTM Model.ipynb",
"provenance": []
},
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.7"
}
},
"nbformat": 4,
"nbformat_minor": 4
}