Created
July 13, 2023 09:40
-
-
Save ariG23498/b8b4c0912a0a19dfe2ef8b29b3160943 to your computer and use it in GitHub Desktop.
s2s.ipynb
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"nbformat": 4, | |
"nbformat_minor": 0, | |
"metadata": { | |
"colab": { | |
"provenance": [], | |
"toc_visible": true, | |
"machine_shape": "hm", | |
"gpuType": "V100", | |
"authorship_tag": "ABX9TyPw9oAVuXYy8M35S5dMFK9e", | |
"include_colab_link": true | |
}, | |
"kernelspec": { | |
"name": "python3", | |
"display_name": "Python 3" | |
}, | |
"language_info": { | |
"name": "python" | |
}, | |
"accelerator": "GPU" | |
}, | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "view-in-github", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"<a href=\"https://colab.research.google.com/gist/ariG23498/b8b4c0912a0a19dfe2ef8b29b3160943/s2s.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"## Download the dataset" | |
], | |
"metadata": { | |
"id": "zHMPQhKiAqR5" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"!wget http://www.manythings.org/anki/fra-eng.zip\n", | |
"!unzip fra-eng.zip" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "74ioFGuBlF3N", | |
"outputId": "8dc4a93a-3d2a-4337-da10-1b147c9a8bbd" | |
}, | |
"execution_count": 1, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"--2023-07-13 09:39:37-- http://www.manythings.org/anki/fra-eng.zip\n", | |
"Resolving www.manythings.org (www.manythings.org)... 173.254.30.110\n", | |
"Connecting to www.manythings.org (www.manythings.org)|173.254.30.110|:80... connected.\n", | |
"HTTP request sent, awaiting response... 200 OK\n", | |
"Length: 7420323 (7.1M) [application/zip]\n", | |
"Saving to: ‘fra-eng.zip’\n", | |
"\n", | |
"fra-eng.zip 100%[===================>] 7.08M 6.00MB/s in 1.2s \n", | |
"\n", | |
"2023-07-13 09:39:39 (6.00 MB/s) - ‘fra-eng.zip’ saved [7420323/7420323]\n", | |
"\n", | |
"Archive: fra-eng.zip\n", | |
" inflating: _about.txt \n", | |
" inflating: fra.txt \n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
        "## Imports and Setup" | |
], | |
"metadata": { | |
"id": "EN5FBnnGSv_D" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# !pip install -q wandb\n", | |
"!pip install -q keras-core" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "gCNwvMzMSnvy", | |
"outputId": "64c15765-be25-4eb9-a599-2d69213740b5" | |
}, | |
"execution_count": 2, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"\u001b[?25l \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m0.0/728.0 kB\u001b[0m \u001b[31m?\u001b[0m eta \u001b[36m-:--:--\u001b[0m\r\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m728.0/728.0 kB\u001b[0m \u001b[31m27.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", | |
"\u001b[?25h" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# !wandb login" | |
], | |
"metadata": { | |
"id": "VmQxlrAgsGBb" | |
}, | |
"execution_count": 3, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"import os\n", | |
"os.environ[\"KERAS_BACKEND\"] = \"tensorflow\"\n", | |
"\n", | |
"import keras_core as keras\n", | |
"\n", | |
"import numpy as np\n", | |
"# import wandb" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "CbqUfljTRfP7", | |
"outputId": "889e42f6-eca5-4737-9c6c-eeb4cafb9ebb" | |
}, | |
"execution_count": 4, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"Using TensorFlow backend\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"print(f\"{keras.backend.backend()=}\")" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "ZHIL5PloXFeL", | |
"outputId": "d23fba24-ca37-4d3c-c071-c515ac2c6678" | |
}, | |
"execution_count": 5, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"keras.backend.backend()='tensorflow'\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"## Configurations" | |
], | |
"metadata": { | |
"id": "dPMOTsz5BuT1" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
# Training hyper-parameters and data location.
batch_size = 64  # Sequence pairs per gradient update.
epochs = 100  # Full passes over the training set.
latent_dim = 256  # Width of the LSTM hidden and cell states.
num_samples = 10000  # Cap on sentence pairs used for training.

# Tab-separated English/French corpus extracted in the download step.
data_path = "fra.txt"

print(f"{data_path=}")
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "J0S8tb7WRkDM", | |
"outputId": "5e6bc2e5-ed9f-47d0-ac90-b57c7dd37a8a" | |
}, | |
"execution_count": 6, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"data_path='fra.txt'\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"## Dataset" | |
], | |
"metadata": { | |
"id": "BCZdVuReBw0W" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
# Accumulators for the parallel corpora and their character vocabularies.
input_texts = []
target_texts = []

input_characters = set()
target_characters = set()

# Read the whole tab-separated corpus; one sentence pair per line.
# Note: split("\n") keeps a trailing empty entry, which the slicing
# below (len(lines) - 1) accounts for.
with open(data_path, encoding="utf-8") as fp:
    lines = fp.read().split("\n")
], | |
"metadata": { | |
"id": "ObqzdjgyRqzf" | |
}, | |
"execution_count": 7, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
# Peek at the raw line format: english \t french \t license attribution.
print(*lines[:5], sep="\n")
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "66JFApr1lsA2", | |
"outputId": "e53285f5-b009-4ad0-d910-07dfc7efe22f" | |
}, | |
"execution_count": 8, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"Go.\tVa !\tCC-BY 2.0 (France) Attribution: tatoeba.org #2877272 (CM) & #1158250 (Wittydev)\n", | |
"Go.\tMarche.\tCC-BY 2.0 (France) Attribution: tatoeba.org #2877272 (CM) & #8090732 (Micsmithel)\n", | |
"Go.\tEn route !\tCC-BY 2.0 (France) Attribution: tatoeba.org #2877272 (CM) & #8267435 (felix63)\n", | |
"Go.\tBouge !\tCC-BY 2.0 (France) Attribution: tatoeba.org #2877272 (CM) & #9022935 (Micsmithel)\n", | |
"Hi.\tSalut !\tCC-BY 2.0 (France) Attribution: tatoeba.org #538123 (CM) & #509819 (Aiji)\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
import random

# Seed before shuffling so the sampled subset of sentence pairs (and
# therefore every downstream statistic and model run) is reproducible.
random.seed(42)
random.shuffle(lines)
], | |
"metadata": { | |
"id": "eEs3AxNXB--t" | |
}, | |
"execution_count": 9, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
# Same peek as before — order is now randomized by the shuffle.
for sample in lines[:5]:
    print(sample)
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "f0VUpQ-KCNbF", | |
"outputId": "bf4b4255-aaf8-48ec-f01e-6d55b18176fb" | |
}, | |
"execution_count": 10, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"I found Tom's diary.\tJ'ai trouvé le journal intime de Tom.\tCC-BY 2.0 (France) Attribution: tatoeba.org #2327230 (CK) & #5396516 (pititnatole)\n", | |
"The people in the room all know one another.\tLes personnes dans la salle se connaissent toutes.\tCC-BY 2.0 (France) Attribution: tatoeba.org #44226 (CK) & #11274105 (lbdx)\n", | |
"How old is this zoo?\tQuel âge a ce zoo ?\tCC-BY 2.0 (France) Attribution: tatoeba.org #436249 (lukaszpp) & #590081 (qdii)\n", | |
"Do you give lessons?\tEnseignes-tu ?\tCC-BY 2.0 (France) Attribution: tatoeba.org #3151577 (CK) & #7581847 (Micsmithel)\n", | |
"We haven't lost much.\tNous n'avons pas beaucoup perdu.\tCC-BY 2.0 (France) Attribution: tatoeba.org #5789467 (CK) & #5800044 (Toynop)\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
# Build the parallel corpora and per-side character vocabularies.
# "\t" is used as the start-of-sequence marker and "\n" as the
# end-of-sequence marker for the targets.
for line in lines[: min(num_samples, len(lines) - 1)]:
    input_text, target_text, _ = line.split("\t")
    target_text = "\t" + target_text + "\n"
    input_texts.append(input_text)
    target_texts.append(target_text)
    # set.update is idempotent, so the original per-character
    # "if char not in set" membership checks were redundant.
    input_characters.update(input_text)
    target_characters.update(target_text)
], | |
"metadata": { | |
"id": "zElLeLxoVrlV" | |
}, | |
"execution_count": 11, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
# Freeze the vocabularies into sorted lists and derive sizing stats.
# sorted() already returns a new list, so the original list() wrappers
# were redundant copies; generator expressions avoid building temporary
# lists inside max().
input_characters = sorted(input_characters)
target_characters = sorted(target_characters)
num_encoder_tokens = len(input_characters)
num_decoder_tokens = len(target_characters)
max_encoder_seq_length = max(len(txt) for txt in input_texts)
max_decoder_seq_length = max(len(txt) for txt in target_texts)

print("Number of samples:", len(input_texts))
print("Number of unique input tokens:", num_encoder_tokens)
print("Number of unique output tokens:", num_decoder_tokens)
print("Max sequence length for inputs:", max_encoder_seq_length)
print("Max sequence length for outputs:", max_decoder_seq_length)
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "kjlu6-q2mGLh", | |
"outputId": "c9515194-fb46-4dc5-9071-fd54563ead77" | |
}, | |
"execution_count": 12, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"Number of samples: 10000\n", | |
"Number of unique input tokens: 73\n", | |
"Number of unique output tokens: 96\n", | |
"Max sequence length for inputs: 128\n", | |
"Max sequence length for outputs: 166\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
# Map each character to its index in the (sorted) vocabulary.
# Dict comprehensions are the idiomatic form of dict([(k, v) for ...]).
input_token_index = {char: i for i, char in enumerate(input_characters)}
target_token_index = {char: i for i, char in enumerate(target_characters)}
], | |
"metadata": { | |
"id": "3Vp935oWm3lS" | |
}, | |
"execution_count": 13, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
# Pre-allocate the one-hot tensors, laid out (samples, timesteps, vocab).
encoder_input_data = np.zeros(
    (len(input_texts), max_encoder_seq_length, num_encoder_tokens),
    dtype="float32",
)

# Decoder arrays are sized by the *target* corpus length and vocabulary
# (fix noted by @ariG23498 in the original).
decoder_input_data = np.zeros(
    (len(target_texts), max_decoder_seq_length, num_decoder_tokens),
    dtype="float32",
)
decoder_target_data = np.zeros_like(decoder_input_data)
], | |
"metadata": { | |
"id": "BLxf-XcIm9H_" | |
}, | |
"execution_count": 14, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
# Sanity-check the tensor shapes: (samples, max_seq_len, vocab_size).
for name, arr in (
    ("encoder_input_data", encoder_input_data),
    ("decoder_input_data", decoder_input_data),
    ("decoder_target_data", decoder_target_data),
):
    print(f"{name}.shape={arr.shape}")
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "PMG4c8kpn0Fe", | |
"outputId": "e8c10e7c-69a6-4391-adcf-b0b323997027" | |
}, | |
"execution_count": 15, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"encoder_input_data.shape=(10000, 128, 73)\n", | |
"decoder_input_data.shape=(10000, 166, 96)\n", | |
"decoder_target_data.shape=(10000, 166, 96)\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
# One-hot encode every (source, target) pair. Positions past the end of
# a sentence are padded with the one-hot vector for the space character.
space_in = input_token_index[" "]
space_out = target_token_index[" "]

for sample, (src, tgt) in enumerate(zip(input_texts, target_texts)):
    # Encoder input: one 1.0 per (timestep, character).
    for step, ch in enumerate(src):
        encoder_input_data[sample, step, input_token_index[ch]] = 1.0
    # `step` holds the last index after the loop; pad the remainder.
    encoder_input_data[sample, step + 1 :, space_in] = 1.0

    # Decoder target is the decoder input shifted left by one timestep,
    # so it never contains the start-of-sequence tab character.
    for step, ch in enumerate(tgt):
        decoder_input_data[sample, step, target_token_index[ch]] = 1.0
        if step > 0:
            decoder_target_data[sample, step - 1, target_token_index[ch]] = 1.0
    # Pad the remaining decoder timesteps with spaces.
    decoder_input_data[sample, step + 1 :, space_out] = 1.0
    decoder_target_data[sample, step:, space_out] = 1.0
], | |
"metadata": { | |
"id": "ceNRvi8VmquY" | |
}, | |
"execution_count": 16, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
# Encoder: consume the source sequence; return_state=True exposes the
# final hidden and cell states alongside the last output.
encoder_inputs = keras.Input(shape=(None, num_encoder_tokens))
encoder = keras.layers.LSTM(units=latent_dim, return_state=True)
encoder_outputs, state_h, state_c = encoder(encoder_inputs)

print(f"{encoder_outputs.shape=}")
print(f"{state_h.shape=}")
print(f"{state_c.shape=}")
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "BEf_ZhvpRud3", | |
"outputId": "4d0fba81-6b12-4cb1-a89f-c128575e687d" | |
}, | |
"execution_count": 17, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"encoder_outputs.shape=(None, 256)\n", | |
"state_h.shape=(None, 256)\n", | |
"state_c.shape=(None, 256)\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
# Keep only the encoder's final (hidden, cell) states — they seed the
# decoder; `encoder_outputs` itself is discarded.
encoder_states = [state_h, state_c]
], | |
"metadata": { | |
"id": "BpWdTvEYpYK_" | |
}, | |
"execution_count": 18, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
# Decoder: starts from the encoder's final states and emits the full
# output sequence. Its returned states are unused during training but
# will be needed when the same layers are reused for inference.
decoder_inputs = keras.Input(shape=(None, num_decoder_tokens))
decoder_lstm = keras.layers.LSTM(
    units=latent_dim, return_sequences=True, return_state=True
)
decoder_outputs, _, _ = decoder_lstm(decoder_inputs, initial_state=encoder_states)

# Per-timestep projection onto the target vocabulary (no activation
# argument, so these are raw logits).
decoder_dense = keras.layers.Dense(num_decoder_tokens)
decoder_outputs = decoder_dense(decoder_outputs)
], | |
"metadata": { | |
"id": "LpZqFrjRpnzn" | |
}, | |
"execution_count": 19, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"print(f\"{decoder_outputs.shape=}\")" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "PsQJKEbfpxpv", | |
"outputId": "8f6749ac-80fa-492c-cc5e-29f544c40d27" | |
}, | |
"execution_count": 20, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"decoder_outputs.shape=(None, None, 96)\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
# Training model: maps (source one-hots, teacher-forced target one-hots)
# to the shifted target sequence `decoder_target_data`.
model = keras.Model([encoder_inputs, decoder_inputs], decoder_outputs)

model.summary()
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 417 | |
}, | |
"id": "UMmlicLgpruh", | |
"outputId": "049656e4-7aec-422e-bca2-3136ef21a6b9" | |
}, | |
"execution_count": 21, | |
"outputs": [ | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": [ | |
"\u001b[1mModel: \"functional_1\"\u001b[0m\n" | |
], | |
"text/html": [ | |
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">Model: \"functional_1\"</span>\n", | |
"</pre>\n" | |
] | |
}, | |
"metadata": {} | |
}, | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": [ | |
"┏━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━┓\n", | |
"┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mParam #\u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mConnected to \u001b[0m\u001b[1m \u001b[0m┃\n", | |
"┡━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━┩\n", | |
"│ input_layer │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m73\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ - │\n", | |
"│ (\u001b[38;5;33mInputLayer\u001b[0m) │ │ │ │\n", | |
"├─────────────────────┼───────────────────┼─────────┼──────────────────────┤\n", | |
"│ input_layer_1 │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m96\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ - │\n", | |
"│ (\u001b[38;5;33mInputLayer\u001b[0m) │ │ │ │\n", | |
"├─────────────────────┼───────────────────┼─────────┼──────────────────────┤\n", | |
"│ lstm (\u001b[38;5;33mLSTM\u001b[0m) │ [(\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m256\u001b[0m), │ \u001b[38;5;34m337,920\u001b[0m │ input_layer[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", | |
"│ │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m256\u001b[0m), │ │ │\n", | |
"│ │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m256\u001b[0m)] │ │ │\n", | |
"├─────────────────────┼───────────────────┼─────────┼──────────────────────┤\n", | |
"│ lstm_1 (\u001b[38;5;33mLSTM\u001b[0m) │ [(\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m, │ \u001b[38;5;34m361,472\u001b[0m │ input_layer_1[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m], │\n", | |
"│ │ \u001b[38;5;34m256\u001b[0m), (\u001b[38;5;45mNone\u001b[0m, │ │ lstm[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m1\u001b[0m], │\n", | |
"│ │ \u001b[38;5;34m256\u001b[0m), (\u001b[38;5;45mNone\u001b[0m, │ │ lstm[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m2\u001b[0m] │\n", | |
"│ │ \u001b[38;5;34m256\u001b[0m)] │ │ │\n", | |
"├─────────────────────┼───────────────────┼─────────┼──────────────────────┤\n", | |
"│ dense (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m96\u001b[0m) │ \u001b[38;5;34m24,672\u001b[0m │ lstm_1[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n", | |
"└─────────────────────┴───────────────────┴─────────┴──────────────────────┘\n" | |
], | |
"text/html": [ | |
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┏━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━┓\n", | |
"┃<span style=\"font-weight: bold\"> Layer (type) </span>┃<span style=\"font-weight: bold\"> Output Shape </span>┃<span style=\"font-weight: bold\"> Param # </span>┃<span style=\"font-weight: bold\"> Connected to </span>┃\n", | |
"┡━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━┩\n", | |
"│ input_layer │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">73</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │ - │\n", | |
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">InputLayer</span>) │ │ │ │\n", | |
"├─────────────────────┼───────────────────┼─────────┼──────────────────────┤\n", | |
"│ input_layer_1 │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">96</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │ - │\n", | |
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">InputLayer</span>) │ │ │ │\n", | |
"├─────────────────────┼───────────────────┼─────────┼──────────────────────┤\n", | |
"│ lstm (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">LSTM</span>) │ [(<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">256</span>), │ <span style=\"color: #00af00; text-decoration-color: #00af00\">337,920</span> │ input_layer[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>] │\n", | |
"│ │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">256</span>), │ │ │\n", | |
"│ │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">256</span>)] │ │ │\n", | |
"├─────────────────────┼───────────────────┼─────────┼──────────────────────┤\n", | |
"│ lstm_1 (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">LSTM</span>) │ [(<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, │ <span style=\"color: #00af00; text-decoration-color: #00af00\">361,472</span> │ input_layer_1[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>], │\n", | |
"│ │ <span style=\"color: #00af00; text-decoration-color: #00af00\">256</span>), (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, │ │ lstm[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">1</span>], │\n", | |
"│ │ <span style=\"color: #00af00; text-decoration-color: #00af00\">256</span>), (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, │ │ lstm[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">2</span>] │\n", | |
"│ │ <span style=\"color: #00af00; text-decoration-color: #00af00\">256</span>)] │ │ │\n", | |
"├─────────────────────┼───────────────────┼─────────┼──────────────────────┤\n", | |
"│ dense (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Dense</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">96</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">24,672</span> │ lstm_1[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>] │\n", | |
"└─────────────────────┴───────────────────┴─────────┴──────────────────────┘\n", | |
"</pre>\n" | |
] | |
}, | |
"metadata": {} | |
}, | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": [ | |
"\u001b[1m Total params: \u001b[0m\u001b[38;5;34m724,064\u001b[0m (22.10 MB)\n" | |
], | |
"text/html": [ | |
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Total params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">724,064</span> (22.10 MB)\n", | |
"</pre>\n" | |
] | |
}, | |
"metadata": {} | |
}, | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": [ | |
"\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m724,064\u001b[0m (22.10 MB)\n" | |
], | |
"text/html": [ | |
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Trainable params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">724,064</span> (22.10 MB)\n", | |
"</pre>\n" | |
] | |
}, | |
"metadata": {} | |
}, | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": [ | |
"\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n" | |
], | |
"text/html": [ | |
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Non-trainable params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> (0.00 B)\n", | |
"</pre>\n" | |
] | |
}, | |
"metadata": {} | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# wandb.init(\n", | |
"# entity=\"ariG23498\",\n", | |
"# project=\"s2s\",\n", | |
"# config={\n", | |
"# \"backend\": keras.backend.backend(),\n", | |
"# \"batch_size\": batch_size,\n", | |
"# \"epochs\": epochs,\n", | |
"# \"latent_dim\": latent_dim,\n", | |
"# \"num_samples\": num_samples,\n", | |
"# }\n", | |
"# )" | |
], | |
"metadata": { | |
"id": "4zbWNn46tBHQ" | |
}, | |
"execution_count": 22, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 23, | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 1000 | |
}, | |
"id": "GIEBcYexRZ8k", | |
"outputId": "98a98cae-0adb-49e0-9492-5ab846bc671f" | |
}, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"Epoch 1/100\n" | |
] | |
}, | |
{ | |
"output_type": "error", | |
"ename": "InvalidArgumentError", | |
"evalue": "ignored", | |
"traceback": [ | |
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", | |
"\u001b[0;31mInvalidArgumentError\u001b[0m Traceback (most recent call last)", | |
"\u001b[0;32m<ipython-input-23-aedeb93827cb>\u001b[0m in \u001b[0;36m<cell line: 7>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 5\u001b[0m )\n\u001b[1;32m 6\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 7\u001b[0;31m model.fit(\n\u001b[0m\u001b[1;32m 8\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mencoder_input_data\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdecoder_input_data\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 9\u001b[0m \u001b[0mdecoder_target_data\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m/usr/local/lib/python3.10/dist-packages/keras_core/src/utils/traceback_utils.py\u001b[0m in \u001b[0;36merror_handler\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 122\u001b[0m \u001b[0;31m# To get the full stack trace, call:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 123\u001b[0m \u001b[0;31m# `keras_core.config.disable_traceback_filtering()`\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 124\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mwith_traceback\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfiltered_tb\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 125\u001b[0m \u001b[0;32mfinally\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 126\u001b[0m \u001b[0;32mdel\u001b[0m \u001b[0mfiltered_tb\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m/usr/local/lib/python3.10/dist-packages/tensorflow/python/eager/execute.py\u001b[0m in \u001b[0;36mquick_execute\u001b[0;34m(op_name, num_outputs, inputs, attrs, ctx, name)\u001b[0m\n\u001b[1;32m 50\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 51\u001b[0m \u001b[0mctx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mensure_initialized\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 52\u001b[0;31m tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,\n\u001b[0m\u001b[1;32m 53\u001b[0m inputs, attrs, num_outputs)\n\u001b[1;32m 54\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mcore\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_NotOkStatusException\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;31mInvalidArgumentError\u001b[0m: Graph execution error:\n\nDetected unsupported operations when trying to compile graph __inference_one_step_on_data_2665[] on XLA_GPU_JIT: CudnnRNN (No registered 'CudnnRNN' OpKernel for XLA_GPU_JIT devices compatible with node {{node functional_1/lstm/CudnnRNN}}){{node functional_1/lstm/CudnnRNN}}\nThe op is created at: \nFile \"/usr/lib/python3.10/runpy.py\", line 196, in _run_module_as_main\n return _run_code(code, main_globals, None,\nFile \"/usr/lib/python3.10/runpy.py\", line 86, in _run_code\n exec(code, run_globals)\nFile \"/usr/local/lib/python3.10/dist-packages/ipykernel_launcher.py\", line 16, in <module>\n app.launch_new_instance()\nFile \"/usr/local/lib/python3.10/dist-packages/traitlets/config/application.py\", line 992, in launch_instance\n app.start()\nFile \"/usr/local/lib/python3.10/dist-packages/ipykernel/kernelapp.py\", line 619, in start\n self.io_loop.start()\nFile \"/usr/local/lib/python3.10/dist-packages/tornado/platform/asyncio.py\", line 195, in start\n self.asyncio_loop.run_forever()\nFile \"/usr/lib/python3.10/asyncio/base_events.py\", line 603, in run_forever\n self._run_once()\nFile \"/usr/lib/python3.10/asyncio/base_events.py\", line 1909, in _run_once\n handle._run()\nFile \"/usr/lib/python3.10/asyncio/events.py\", line 80, in _run\n self._context.run(self._callback, *self._args)\nFile \"/usr/local/lib/python3.10/dist-packages/tornado/ioloop.py\", line 685, in <lambda>\n lambda f: self._run_callback(functools.partial(callback, future))\nFile \"/usr/local/lib/python3.10/dist-packages/tornado/ioloop.py\", line 738, in _run_callback\n ret = callback()\nFile \"/usr/local/lib/python3.10/dist-packages/tornado/gen.py\", line 825, in inner\n self.ctx_run(self.run)\nFile \"/usr/local/lib/python3.10/dist-packages/tornado/gen.py\", line 786, in run\n yielded = self.gen.send(value)\nFile \"/usr/local/lib/python3.10/dist-packages/ipykernel/kernelbase.py\", line 377, in dispatch_queue\n yield 
self.process_one()\nFile \"/usr/local/lib/python3.10/dist-packages/tornado/gen.py\", line 250, in wrapper\n runner = Runner(ctx_run, result, future, yielded)\nFile \"/usr/local/lib/python3.10/dist-packages/tornado/gen.py\", line 748, in __init__\n self.ctx_run(self.run)\nFile \"/usr/local/lib/python3.10/dist-packages/tornado/gen.py\", line 786, in run\n yielded = self.gen.send(value)\nFile \"/usr/local/lib/python3.10/dist-packages/ipykernel/kernelbase.py\", line 361, in process_one\n yield gen.maybe_future(dispatch(*args))\nFile \"/usr/local/lib/python3.10/dist-packages/tornado/gen.py\", line 234, in wrapper\n yielded = ctx_run(next, result)\nFile \"/usr/local/lib/python3.10/dist-packages/ipykernel/kernelbase.py\", line 261, in dispatch_shell\n yield gen.maybe_future(handler(stream, idents, msg))\nFile \"/usr/local/lib/python3.10/dist-packages/tornado/gen.py\", line 234, in wrapper\n yielded = ctx_run(next, result)\nFile \"/usr/local/lib/python3.10/dist-packages/ipykernel/kernelbase.py\", line 539, in execute_request\n self.do_execute(\nFile \"/usr/local/lib/python3.10/dist-packages/tornado/gen.py\", line 234, in wrapper\n yielded = ctx_run(next, result)\nFile \"/usr/local/lib/python3.10/dist-packages/ipykernel/ipkernel.py\", line 302, in do_execute\n res = shell.run_cell(code, store_history=store_history, silent=silent)\nFile \"/usr/local/lib/python3.10/dist-packages/ipykernel/zmqshell.py\", line 539, in run_cell\n return super(ZMQInteractiveShell, self).run_cell(*args, **kwargs)\nFile \"/usr/local/lib/python3.10/dist-packages/IPython/core/interactiveshell.py\", line 2975, in run_cell\n result = self._run_cell(\nFile \"/usr/local/lib/python3.10/dist-packages/IPython/core/interactiveshell.py\", line 3030, in _run_cell\n return runner(coro)\nFile \"/usr/local/lib/python3.10/dist-packages/IPython/core/async_helpers.py\", line 78, in _pseudo_sync_runner\n coro.send(None)\nFile \"/usr/local/lib/python3.10/dist-packages/IPython/core/interactiveshell.py\", line 3257, in 
run_cell_async\n has_raised = await self.run_ast_nodes(code_ast.body, cell_name,\nFile \"/usr/local/lib/python3.10/dist-packages/IPython/core/interactiveshell.py\", line 3473, in run_ast_nodes\n if (await self.run_code(code, result, async_=asy)):\nFile \"/usr/local/lib/python3.10/dist-packages/IPython/core/interactiveshell.py\", line 3553, in run_code\n exec(code_obj, self.user_global_ns, self.user_ns)\nFile \"<ipython-input-23-aedeb93827cb>\", line 7, in <cell line: 7>\n model.fit(\nFile \"/usr/local/lib/python3.10/dist-packages/keras_core/src/utils/traceback_utils.py\", line 119, in error_handler\n return fn(*args, **kwargs)\nFile \"/usr/local/lib/python3.10/dist-packages/keras_core/src/backend/tensorflow/trainer.py\", line 306, in fit\n logs = self.train_function(iterator)\nFile \"/usr/local/lib/python3.10/dist-packages/keras_core/src/backend/tensorflow/trainer.py\", line 111, in one_step_on_iterator\n outputs = self.distribute_strategy.run(\nFile \"/usr/local/lib/python3.10/dist-packages/keras_core/src/backend/tensorflow/trainer.py\", line 98, in one_step_on_data\n return self.train_step(data)\nFile \"/usr/local/lib/python3.10/dist-packages/keras_core/src/backend/tensorflow/trainer.py\", line 51, in train_step\n y_pred = self(x, training=True)\nFile \"/usr/local/lib/python3.10/dist-packages/keras_core/src/utils/traceback_utils.py\", line 119, in error_handler\n return fn(*args, **kwargs)\nFile \"/usr/local/lib/python3.10/dist-packages/keras_core/src/layers/layer.py\", line 703, in __call__\n outputs = super().__call__(*args, **kwargs)\nFile \"/usr/local/lib/python3.10/dist-packages/keras_core/src/utils/traceback_utils.py\", line 119, in error_handler\n return fn(*args, **kwargs)\nFile \"/usr/local/lib/python3.10/dist-packages/keras_core/src/ops/operation.py\", line 41, in __call__\n return call_fn(*args, **kwargs)\nFile \"/usr/local/lib/python3.10/dist-packages/keras_core/src/utils/traceback_utils.py\", line 154, in error_handler\n return fn(*args, 
**kwargs)\nFile \"/usr/local/lib/python3.10/dist-packages/keras_core/src/models/functional.py\", line 181, in call\n outputs = self._run_through_graph(\nFile \"/usr/local/lib/python3.10/dist-packages/keras_core/src/ops/function.py\", line 127, in _run_through_graph\n outputs = operation_fn(node.operation)(*args, **kwargs)\nFile \"/usr/local/lib/python3.10/dist-packages/keras_core/src/models/functional.py\", line 549, in call\n return operation(*args, **kwargs)\nFile \"/usr/local/lib/python3.10/dist-packages/keras_core/src/utils/traceback_utils.py\", line 119, in error_handler\n return fn(*args, **kwargs)\nFile \"/usr/local/lib/python3.10/dist-packages/keras_core/src/layers/layer.py\", line 703, in __call__\n outputs = super().__call__(*args, **kwargs)\nFile \"/usr/local/lib/python3.10/dist-packages/keras_core/src/utils/traceback_utils.py\", line 119, in error_handler\n return fn(*args, **kwargs)\nFile \"/usr/local/lib/python3.10/dist-packages/keras_core/src/ops/operation.py\", line 41, in __call__\n return call_fn(*args, **kwargs)\nFile \"/usr/local/lib/python3.10/dist-packages/keras_core/src/utils/traceback_utils.py\", line 154, in error_handler\n return fn(*args, **kwargs)\nFile \"/usr/local/lib/python3.10/dist-packages/keras_core/src/layers/rnn/lstm.py\", line 526, in call\n return super().call(\nFile \"/usr/local/lib/python3.10/dist-packages/keras_core/src/layers/rnn/rnn.py\", line 390, in call\n last_output, outputs, states = self.inner_loop(\nFile \"/usr/local/lib/python3.10/dist-packages/keras_core/src/layers/rnn/lstm.py\", line 505, in inner_loop\n return backend.lstm(\nFile \"/usr/local/lib/python3.10/dist-packages/keras_core/src/backend/tensorflow/rnn.py\", line 815, in lstm\n return _cudnn_lstm(\nFile \"/usr/local/lib/python3.10/dist-packages/keras_core/src/backend/tensorflow/rnn.py\", line 926, in _cudnn_lstm\n outputs, h, c, _ = tf.raw_ops.CudnnRNN(\n\t [[StatefulPartitionedCall]] [Op:__inference_one_step_on_iterator_2730]" | |
] | |
} | |
], | |
"source": [ | |
"# Configure training: AdamW optimizer at learning rate 3e-3, with\n", | |
"# categorical cross-entropy computed from raw logits\n", | |
"# (from_logits=True) -- NOTE(review): this implies the model's final\n", | |
"# layer applies no softmax; confirm against the model definition.\n", | |
"model.compile(\n", | |
"    optimizer=keras.optimizers.AdamW(3e-3),\n", | |
"    loss=keras.losses.CategoricalCrossentropy(from_logits=True),\n", | |
"    metrics=[\"accuracy\"],\n", | |
")\n", | |
"\n", | |
"# Train on the paired encoder/decoder inputs against the decoder\n", | |
"# targets, holding out 20% of the samples for validation. The W&B\n", | |
"# metrics-logger callback is intentionally left disabled.\n", | |
"model.fit(\n", | |
"    [encoder_input_data, decoder_input_data],\n", | |
"    decoder_target_data,\n", | |
"    batch_size=batch_size,\n", | |
"    epochs=epochs,\n", | |
"    validation_split=0.2,\n", | |
"    # callbacks=[wandb.keras.WandbMetricsLogger()],\n", | |
")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Close the Weights & Biases run when experiment tracking is enabled;\n", | |
"# kept commented out to match the disabled WandbMetricsLogger callback\n", | |
"# in the training cell above.\n", | |
"# wandb.finish()" | |
], | |
"metadata": { | |
"id": "J2XmgVpPuIc5" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Save model\n", | |
"# Persist the trained network in the native Keras format so the\n", | |
"# inference (sampling) models below can be rebuilt from it.\n", | |
"model.save(\"s2s_model.keras\")" | |
], | |
"metadata": { | |
"id": "p1BFMAFjR2Yo" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Define sampling models\n", | |
"# Restore the trained model from disk; the next cell rewires its layers\n", | |
"# into separate encoder and decoder models for step-by-step inference.\n", | |
"model = keras.models.load_model(\"s2s_model.keras\")" | |
], | |
"metadata": { | |
"id": "jVMW4G97R6H-" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Rebuild standalone inference models from the trained network's layers.\n", | |
"# NOTE(review): layer positions are hard-coded (layers[2] = encoder LSTM,\n", | |
"# layers[3] = decoder LSTM, layers[4] = output Dense); this only holds\n", | |
"# for the exact layer ordering used when the model was built -- confirm\n", | |
"# if the architecture changes.\n", | |
"\n", | |
"# Encoder: maps an input sequence to its final LSTM states (h, c).\n", | |
"encoder_inputs = model.input[0] # input_1\n", | |
"encoder_outputs, state_h_enc, state_c_enc = model.layers[2].output # lstm_1\n", | |
"encoder_states = [state_h_enc, state_c_enc]\n", | |
"encoder_model = keras.Model(encoder_inputs, encoder_states)\n", | |
"\n", | |
"# Decoder: consumes one input step plus explicit (h, c) state inputs and\n", | |
"# returns the prediction for the next character together with the\n", | |
"# updated states, enabling character-by-character sampling.\n", | |
"decoder_inputs = model.input[1] # input_2\n", | |
"decoder_state_input_h = keras.Input(shape=(latent_dim,))\n", | |
"decoder_state_input_c = keras.Input(shape=(latent_dim,))\n", | |
"decoder_states_inputs = [decoder_state_input_h, decoder_state_input_c]\n", | |
"decoder_lstm = model.layers[3]\n", | |
"decoder_outputs, state_h_dec, state_c_dec = decoder_lstm(\n", | |
"    decoder_inputs, initial_state=decoder_states_inputs\n", | |
")\n", | |
"decoder_states = [state_h_dec, state_c_dec]\n", | |
"decoder_dense = model.layers[4]\n", | |
"decoder_outputs = decoder_dense(decoder_outputs)\n", | |
"decoder_model = keras.Model(\n", | |
"    [decoder_inputs] + decoder_states_inputs, [decoder_outputs] + decoder_states\n", | |
")" | |
], | |
"metadata": { | |
"id": "z5p0sER-R8rx" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Reverse-lookup token index to decode sequences back to\n", | |
"# something readable.\n", | |
"reverse_input_char_index = dict(\n", | |
" (i, char) for char, i in input_token_index.items()\n", | |
")\n", | |
"reverse_target_char_index = dict(\n", | |
" (i, char) for char, i in target_token_index.items()\n", | |
")" | |
], | |
"metadata": { | |
"id": "ZAPRidq5R-rW" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"def decode_sequence(input_seq):\n", | |
"    \"\"\"Greedily decode one encoded input sequence into a string.\n", | |
"\n", | |
"    Runs the encoder once to obtain initial LSTM states, then feeds the\n", | |
"    decoder one character at a time (starting from the tab start token),\n", | |
"    always taking the argmax of the predicted distribution, until the\n", | |
"    newline stop character is produced or the sentence grows past\n", | |
"    max_decoder_seq_length characters.\n", | |
"\n", | |
"    Args:\n", | |
"        input_seq: a single one-hot encoded input sequence, i.e. one\n", | |
"            batch-of-1 slice of encoder_input_data.\n", | |
"\n", | |
"    Returns:\n", | |
"        The decoded sentence; includes the trailing stop character when\n", | |
"        decoding ended on one.\n", | |
"    \"\"\"\n", | |
"    # Encode the input as state vectors.\n", | |
"    states_value = encoder_model.predict(input_seq, verbose=0)\n", | |
"\n", | |
"    # Generate empty target sequence of length 1.\n", | |
"    target_seq = np.zeros((1, 1, num_decoder_tokens))\n", | |
"    # Populate the first character of target sequence with the start character.\n", | |
"    target_seq[0, 0, target_token_index[\"\\t\"]] = 1.0\n", | |
"\n", | |
"    # Sampling loop for a batch of sequences\n", | |
"    # (to simplify, here we assume a batch of size 1).\n", | |
"    stop_condition = False\n", | |
"    decoded_sentence = \"\"\n", | |
"    while not stop_condition:\n", | |
"        output_tokens, h, c = decoder_model.predict(\n", | |
"            [target_seq] + states_value, verbose=0\n", | |
"        )\n", | |
"\n", | |
"        # Sample a token (greedy: argmax over the last timestep).\n", | |
"        sampled_token_index = np.argmax(output_tokens[0, -1, :])\n", | |
"        sampled_char = reverse_target_char_index[sampled_token_index]\n", | |
"        decoded_sentence += sampled_char\n", | |
"\n", | |
"        # Exit condition: either hit max length\n", | |
"        # or find stop character.\n", | |
"        if (\n", | |
"            sampled_char == \"\\n\"\n", | |
"            or len(decoded_sentence) > max_decoder_seq_length\n", | |
"        ):\n", | |
"            stop_condition = True\n", | |
"\n", | |
"        # Update the target sequence (of length 1).\n", | |
"        target_seq = np.zeros((1, 1, num_decoder_tokens))\n", | |
"        target_seq[0, 0, sampled_token_index] = 1.0\n", | |
"\n", | |
"        # Update states\n", | |
"        states_value = [h, c]\n", | |
"    return decoded_sentence" | |
], | |
"metadata": { | |
"id": "9S4PEz68SCHI" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"for seq_index in range(20):\n", | |
" # Take one sequence (part of the training set)\n", | |
" # for trying out decoding.\n", | |
" input_seq = encoder_input_data[seq_index : seq_index + 1]\n", | |
" decoded_sentence = decode_sequence(input_seq)\n", | |
" print(\"-\")\n", | |
" print(\"Input sentence:\", input_texts[seq_index])\n", | |
" print(\"Decoded sentence:\", decoded_sentence)" | |
], | |
"metadata": { | |
"id": "Vb_sRnYrRzqK" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
} | |
] | |
} |
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.