Created
August 30, 2020 05:59
-
-
Save sayan1999/d008ef965c72371602c399284b7ab189 to your computer and use it in GitHub Desktop.
word_seq2seq-extended.ipynb
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"nbformat": 4, | |
"nbformat_minor": 0, | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"colab": { | |
"name": "word_seq2seq-extended.ipynb", | |
"provenance": [], | |
"collapsed_sections": [], | |
"toc_visible": true, | |
"include_colab_link": true | |
}, | |
"accelerator": "GPU" | |
}, | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "view-in-github", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"<a href=\"https://colab.research.google.com/gist/sayan1999/d008ef965c72371602c399284b7ab189/word_seq2seq-extended.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "zMy7gQFI9qXy", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 900 | |
}, | |
"outputId": "a69f4c6c-39b4-4dc4-b481-ffd690d9b25a" | |
}, | |
"source": [ | |
"!pip install bidict\n", | |
"!pip install pixiedust\n", | |
"import nltk\n", | |
"nltk.download('punkt')" | |
], | |
"execution_count": 1, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"Collecting bidict\n", | |
" Downloading https://files.pythonhosted.org/packages/7a/7a/1fcfc397e61b22091267aa767266d8ab200a00b7dbf3aadead7fd41a74b9/bidict-0.21.0-py2.py3-none-any.whl\n", | |
"Installing collected packages: bidict\n", | |
"Successfully installed bidict-0.21.0\n", | |
"Collecting pixiedust\n", | |
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/16/ba/7488f06b48238205562f9d63aaae2303c060c5dfd63b1ddd3bd9d4656eb1/pixiedust-1.1.18.tar.gz (197kB)\n", | |
"\u001b[K |████████████████████████████████| 204kB 8.7MB/s \n", | |
"\u001b[?25hCollecting mpld3\n", | |
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/66/31/89bd2afd21b920e3612996623e7b3aac14d741537aa77600ea5102a34be0/mpld3-0.5.1.tar.gz (1.0MB)\n", | |
"\u001b[K |████████████████████████████████| 1.0MB 16.7MB/s \n", | |
"\u001b[?25hRequirement already satisfied: lxml in /usr/local/lib/python3.6/dist-packages (from pixiedust) (4.2.6)\n", | |
"Collecting geojson\n", | |
" Downloading https://files.pythonhosted.org/packages/e4/8d/9e28e9af95739e6d2d2f8d4bef0b3432da40b7c3588fbad4298c1be09e48/geojson-2.5.0-py2.py3-none-any.whl\n", | |
"Requirement already satisfied: astunparse in /usr/local/lib/python3.6/dist-packages (from pixiedust) (1.6.3)\n", | |
"Requirement already satisfied: markdown in /usr/local/lib/python3.6/dist-packages (from pixiedust) (3.2.2)\n", | |
"Collecting colour\n", | |
" Downloading https://files.pythonhosted.org/packages/74/46/e81907704ab203206769dee1385dc77e1407576ff8f50a0681d0a6b541be/colour-0.1.5-py2.py3-none-any.whl\n", | |
"Requirement already satisfied: requests in /usr/local/lib/python3.6/dist-packages (from pixiedust) (2.23.0)\n", | |
"Requirement already satisfied: jinja2 in /usr/local/lib/python3.6/dist-packages (from mpld3->pixiedust) (2.11.2)\n", | |
"Requirement already satisfied: matplotlib in /usr/local/lib/python3.6/dist-packages (from mpld3->pixiedust) (3.2.2)\n", | |
"Requirement already satisfied: six<2.0,>=1.6.1 in /usr/local/lib/python3.6/dist-packages (from astunparse->pixiedust) (1.15.0)\n", | |
"Requirement already satisfied: wheel<1.0,>=0.23.0 in /usr/local/lib/python3.6/dist-packages (from astunparse->pixiedust) (0.35.1)\n", | |
"Requirement already satisfied: importlib-metadata; python_version < \"3.8\" in /usr/local/lib/python3.6/dist-packages (from markdown->pixiedust) (1.7.0)\n", | |
"Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.6/dist-packages (from requests->pixiedust) (1.24.3)\n", | |
"Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.6/dist-packages (from requests->pixiedust) (2.10)\n", | |
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.6/dist-packages (from requests->pixiedust) (2020.6.20)\n", | |
"Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.6/dist-packages (from requests->pixiedust) (3.0.4)\n", | |
"Requirement already satisfied: MarkupSafe>=0.23 in /usr/local/lib/python3.6/dist-packages (from jinja2->mpld3->pixiedust) (1.1.1)\n", | |
"Requirement already satisfied: numpy>=1.11 in /usr/local/lib/python3.6/dist-packages (from matplotlib->mpld3->pixiedust) (1.18.5)\n", | |
"Requirement already satisfied: python-dateutil>=2.1 in /usr/local/lib/python3.6/dist-packages (from matplotlib->mpld3->pixiedust) (2.8.1)\n", | |
"Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.6/dist-packages (from matplotlib->mpld3->pixiedust) (0.10.0)\n", | |
"Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.6/dist-packages (from matplotlib->mpld3->pixiedust) (1.2.0)\n", | |
"Requirement already satisfied: pyparsing!=2.0.4,!=2.1.2,!=2.1.6,>=2.0.1 in /usr/local/lib/python3.6/dist-packages (from matplotlib->mpld3->pixiedust) (2.4.7)\n", | |
"Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.6/dist-packages (from importlib-metadata; python_version < \"3.8\"->markdown->pixiedust) (3.1.0)\n", | |
"Building wheels for collected packages: pixiedust, mpld3\n", | |
" Building wheel for pixiedust (setup.py) ... \u001b[?25l\u001b[?25hdone\n", | |
" Created wheel for pixiedust: filename=pixiedust-1.1.18-cp36-none-any.whl size=321727 sha256=aca85894b80a6fe25fb05217aa1095c664d98c2cb4fdd4eb17715d867b26db89\n", | |
" Stored in directory: /root/.cache/pip/wheels/e8/b1/86/c2f2e16e6bf9bfe556f9dbf8adb9f41816c476d73078c7d0eb\n", | |
" Building wheel for mpld3 (setup.py) ... \u001b[?25l\u001b[?25hdone\n", | |
" Created wheel for mpld3: filename=mpld3-0.5.1-cp36-none-any.whl size=364064 sha256=6ca44fc5e92e2085a5b69fe24c1291795c168a4c20cd1f7e5e4d6066281036b0\n", | |
" Stored in directory: /root/.cache/pip/wheels/38/68/06/d119af6c3f9a2d1e123c1f72d276576b457131b3a7bf94e402\n", | |
"Successfully built pixiedust mpld3\n", | |
"Installing collected packages: mpld3, geojson, colour, pixiedust\n", | |
"Successfully installed colour-0.1.5 geojson-2.5.0 mpld3-0.5.1 pixiedust-1.1.18\n", | |
"[nltk_data] Downloading package punkt to /root/nltk_data...\n", | |
"[nltk_data] Unzipping tokenizers/punkt.zip.\n" | |
], | |
"name": "stdout" | |
}, | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"True" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 1 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "aLt_9sNL9qX9", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"import pandas as pd\n", | |
"import numpy as np\n", | |
"import random\n", | |
"import re, os, difflib\n", | |
"from matplotlib import pyplot as plt\n", | |
"from sklearn.utils import shuffle\n", | |
"from sklearn.model_selection import train_test_split\n", | |
"from tensorflow.keras.layers import Input, LSTM, Embedding, Dense, Attention, Bidirectional, Concatenate\n", | |
"from tensorflow.keras.models import Model, Sequential\n", | |
"import tensorflow as tf\n", | |
"from nltk import word_tokenize\n", | |
"from gensim.models import Word2Vec\n", | |
"from sklearn.preprocessing import OneHotEncoder\n", | |
"from bidict import bidict\n", | |
"from tensorflow.keras.preprocessing.sequence import pad_sequences\n", | |
"from sklearn.metrics import classification_report\n", | |
"from tensorflow.keras.callbacks import TensorBoard\n", | |
"from tensorflow.keras.utils import plot_model\n", | |
"from tensorflow.keras import backend as K" | |
], | |
"execution_count": 2, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "osRUcUVb9qYA", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 92 | |
}, | |
"outputId": "a6171656-ac7d-4778-8cd5-9cb57693895e" | |
}, | |
"source": [ | |
"# parameters\n", | |
"# if running on colab turn this false, and select GPU runtime\n", | |
"batch_size=32 if not tf.test.is_gpu_available() else 256\n", | |
"colab=True\n", | |
"training=True\n", | |
"validation=True\n", | |
"ctx_vec_len=128\n", | |
"embedding_dim=128\n", | |
"epochs=25\n", | |
"# either length or list of index such as range(1, 2200)\n", | |
"training_samples=100\n", | |
"dropout=0.2\n", | |
"weight_file='word-seq2seq.hdf5'" | |
], | |
"execution_count": 3, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"WARNING:tensorflow:From <ipython-input-3-3badf1bfea45>:3: is_gpu_available (from tensorflow.python.framework.test_util) is deprecated and will be removed in a future version.\n", | |
"Instructions for updating:\n", | |
"Use `tf.config.list_physical_devices('GPU')` instead.\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "wf7r24jQ9qYH", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 225 | |
}, | |
"outputId": "71b10c13-8b88-4a4b-e0c2-fc1267cf6437" | |
}, | |
"source": [ | |
"from IPython.display import display, Markdown\n", | |
"if not colab:\n", | |
" display(Markdown('''## Architecture For Neural Machine Trans\n", | |
"![Architecture Neural Machine Trans](image/NeuralMachineTrans.jpg)'''))\n", | |
" \n", | |
"else:\n", | |
" display(Markdown('''## Architecture For Neural Machine Trans\n", | |
"![Architecture Neural Machine Trans](data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAVIAAACVCAMAAAA9kYJlAAABa1BMVEX///+S0FBEcsQAsPAAAACU004oW638/Py+vr7Pz8/39/f/3ZrY2ur/0GcAs/Pa3uYAb7OU1EhrpF1BccagqsI4a8AAgMP/7s//vQDl5vLv7+9tiZzs6+/j4+NhmmJNg2t1jKVGe26DwU+Ol7W2vc/IztopVZ1VcKd+u1LZ2dmMlsDKzeJzrlmutNKhoaG3vNdRUVFtbW1eXl59fX1qamq1tbWmpqZ3d3eLi4tHR0c3NzeXl5e5ubklJSUxMTEfHx9fg4sTExP/xiH/9eMyvPL/8s3/9+L/6sB6j7dge64AnN85X6Pf9P3/zVr/3oz/2G//56j/y0X/13dJaKRLS1KusL1FR2AAADJTc3t+mKomOyBSf0phlG9gjIORortkmWkARplRg7ZcgzJEYCVUg3yK0va8wc5/hI4SHjMJEiEwUIo6UiN2qEAxMlI6VzolQHC14vljyPUzLzgeN2JqgqQvRz9hYnOJj6xfZYZlJtHGAAAShElEQVR4nO2djX/UNprHBShYErRsgGyCiCgNoaCD9fpNfpM9w0AbaHkphe217NH29qXs7l332u719u7Pv0eemTAvdqwkkwSCfx8YO/YzsvT1oxd79IJQp06dOnXq1KlTp7dFceIfdRSOmUKNuJr429UTf0R0/gt9dtBRessleuaTpUng8CSMCM8zjdIkZSlhESkDoJ1SRLUIExf4J3GMIymSJEUk9GuAd0IyM59pjAqfDwTLUewiN0ZxyjPw3wpaDIyJL8FlQ4WokzOUUxTGJHOOOO5vqEQJH07AEcpFjpyB47ooiZJEobB0hkh57gSo5yeJyAT8mXMEVKlP0iOO+psqDoSQA2WmiPgYaVrACSfzyKgsDcFPK5qeRMzpMTTg4Lod0ibJTCdUejoTvI8YZqKnuAeHIiKw4waGKcUUkUAnHMwSFIZCRdpjKjzqqL+xYpybD6jHR/9Y9df2H6jagWPO6AP+m63TFaWd9iAax6qm0RgUhx+V46JcClGTG31y+FE5LspN4RerBKoUNyGIJiFHLNFQcxeJRiJOVXsYnaaUxUqyXoGgJaRRTDPT6PQo7xESIqXlQOwiLNk9GRnlVHJ4/oHGeCQR0gqa8DRgkPHTIPRD6e0mLNqVFkZVxjdImXnq0S5yPF4hDQ0fGewmrA5ppVwRykvklIxEJBQejUOUpqpPeE4V7bx0D5KUUgZZXjpARCBBKNT/VApudpE5Y68O6cLVIV24OqQLV4d04eqQLlwd0ntnthYbYId09e4niw2wQ3qISNnovTRCDmcTR4+bDhEp5mT4SsXx9OunB4/vENj4HYx8m1z/sJByKvucc4AkEcGCCyrgWU1w2iNwjjb4ajz65S9O2i/MTd+XYuoOsWw379AWpUNCKjISYpbEKqUeLTBVWFHMkda0JCIgblTDlCdhqJFMfIHi178AOtsfk5/wIfqwJez1CUDaE6//MMedut1F65CQ+gVCAxaq0LDhGBHPFARIu07GKOVxPv8egZUS5ZphxnKxjZRmac5ZkEYp6wVhxpHrJT5iUZgLjZWI4ATcvtTjpntRyXuCB2FvFHYUpCEKBg5KFMJhumOJsx99cvdQGlEBMa8Ow5iFvicM0gA5FVKWcRkFRTaP1LwAU9oNhQjU9u/UGUXCo3CGc9h1U9lDKCRaA1XRQ3GB5IB7ULS4qTY9DXIeutW9q+JgruH3FAoVy1h1Vw9En6wdCtJIQcoAKeyqSAyRDhhKASmD1LMapDQzSNNIKbLtpbyXhAmEU4bSvNAVXpyHYZYm5pqiRCzNUlwVAFUnOEAqPajg8LBEjSqkccYMUg5hvS0/ftcjdfLCN2WpTohHwUsLcBEgMXBRrmhOUjz/8wobCBRo0/UqI2OkUAZAeQk7EnMPHC1S
AbAUCRQrBJBWN04AKi71EKkw2WOMNIAcH1EdpxVSdmA5/8n9xYbXVONLKZCp8aVEjkQcUsmkhANCQiuAi5rkiSAMXBRHUYzUuM+P9lmaipDTkuUu8OIeEbmkGVe+6DthyDXmbsrDmARc57wUhSfc0f2ovDSiPE+KykvzxSb8tT5/vNjwDvjpibjKfLice7ELrs1cF+4ONAuQUyjHdaliSLngtYVLC6Y4oq47+mrVpyiiKMYFy5kpGtq0urp6DzZbsF11hn+vrsL23urozNwBY3rm6drq1vikMzq5NQ4HrJz6cBrLoUUjdeP64zybdOzQwuViH8C7sqqn/CRp7/dxf21t7QlsH8N2DVJ+z2y/gANnzI7J3Wa7NrKsTN8zFluPq/275uDW6KRx3M9HXxuGAwE+HX9tbFovW6TM8km0n9ffPUdOHOcDvNNlmdHIXNm3/h8/szY9UNkide26RBKMLfplaIy9HervOAWN4lVGVtc1elP6w9kiLbFNhcs8jMvWlHGMrcgjc4sGu/p18hC19aQh61siBe9z262MGcat9YjKwUuteq2yAGN7Nz1cba01PMjaIWUAwcZNicJ9ZdElyMN2RWRhbpFlF6N7C35Ub9XdhvasHdLK+6z8qt+3geVhu+zsJrgMLft4rq7Z2S1M+0MqaYmJlWP1rDq5Bbauh3Bm+9L70JGu7QspQrllVs2tqpPIFqkzyG2fQ51VS8NFqWr118gWaWaZVe3s7JGWvaN4Lb0v2SL1LCFkVnb+jk39SeXl7pGun5vXXPHBH9x4PqtPv501cy7NB7XecvkWpOtnh7r0O/zleLeucGOXRif/9aux3dlLdYkdntr88vrYqiaGzublO2P1B18Pd15cb0nJtjafn57Tw09nvn7u+cqFOa2sPJ+OzfqDhzVhPdi55b0z0ot3rpwf6vf430Z7Ly+fnbO7dPXl+Vld+ebb2Uvfujxv9t2dzdnQrl3ZWFoeaun3eGO4s/HyNzsm5PW3T184OacLF/59iim7sTJvBFr5dDLO68C9JqyVG/NetX7tV/P6w7wbXD2/vDTSH/+0Md5d+nCW6aWvTyzNa/nmL9Nmrz5YrjP77uK02eb55RM1Wn45fytrdO5hLauTFz6d5HC9yerhuQmra/XcT56+OHvV9T98dur9WZ069etZw7Mvl+rSduLE1Zm7dPVmreHS+SkHXL9TH9zSN9NFxJ36iy7dfGGDdLMBwzSsJquTpycd63mD0cqnMxd1HnwGAGv0q5li7WI9qRNLH8wY/rkewomNKQhnrzQEd3s663/YcCOXLtsgvdgIa/dIP2owunBj5qLrv64neuqzGTf9l9oMCGm7MlPzfNAAYfnqpNWt2w1mG9Ol5IcNl11+C5GeutYh7ZB2SDukHdIOaYe0Q9oh7ZB2SDukHdIOaYf0wJDyhN5qQ+qYecl2i/T97+fI7h4pE5ZIHWGDtND8AJEyDhJCSIz/0fC+9OTpV0MbEeCeakfKYz6B9ONHOyGlhFkhxYmwQ5oFXzYjZaN0uHgQiqaXe/tEuvKXLMvzXr8/MJ0S8F9roT78a569tvnbf7QhZbjUzBKpwh5pRnqdEkJUHLvadJmQzUgZpSPDEuP//KEB6d89SGw5Sgf+r3qr/SPFg7LMs8wLDNF6Wg/xoF/2jE0P496PDUGd/DGJgsDEuYpywreR/vToJ7Pz06NHP4+Rpn4UQPp6Q9s/NiD97p94Uv9oevu/fLWYMsQ/1BpufIWH6fAyc4cab9B+M77J95wxhgbZX5qsPno1skn6mn1U88tT5aX/bRIzMPRhW6bONtKPK0f9+dGp998fI83HtuBU+KsGr1q68mWUhGGqtaswzmPe7KUiGBvm+G9/2qi327g4SqtDsS8PocYvnFetNT7lO1RPzylISik4xqmYzPhA9dSp73+ayPiSVLZCFLhfPLCo8bHL7Wr8LN1sr56E6QTwVjWi0pnqyTjqz+ClH388V5YK164RVf1Ga4PUeXMaUYtE
OtQU0u8fnfp+oiyd0nFtlx4Y0qEeffzuNfUPGOlUQ6pD2j2Qdkg7pB3SDmmHtEPaIe2Qdkg7pG810uvHBOn5BrOb00gbekHPdABuUrP/TSJdb+rgfGOyN3LTrylzvaCtkV7caED6ck+9oM+9rL9FS1emRwl8a9VNvUnXTzd41vRgkc26QRInZwZJPGi4PSszoBqRvj/bC/pWQ0/w5TuzffXr3y0v3Z4K0HlRa7a0cWd65Mn61xt1oyRu/mI33v5BLdMZWID+xvagm8ahPOsP65heWHk+Nzyqsa/+7Eilq7eX5yEsLX1wa8bu3J/h6Jzd8sYMhPVfbm+MBt8sn9jeuX159rLr1z64MqvvvvnWCihC7NrD0yuVfny4Mtbp53NDfPj6SNG58d76LKpzN14HsR3UwxvzQ7XWb9iNKEHrL17evjnSD+Od89/MZ8Bbd87fnNXtly9mWbGLLy4Pded/RjuXX8yNiIObvn5ppH+Ody61jYmb0NlXm69ewX/8u1cjbV6vGdc2ktxptgV2fTuIf462m9frcovtuCeI3eZvh/rN3/93tHerLnHr14cn/++3r3V2p7k2rCahQNRycolaKbuh/AEe2JQoyi7GuxDBFrM9ItZrn3rIiOd2k0FkFtN6NMYlw9hiUjmBMW6YNWgqNK99NoxdCuJnMYRZ4Z7VvYwxHlh4ELWeA6U2LqD20b+J+Um4fehvgfEeBgjvJDNtRPvaJ2a+Dpt7yUwvgazVzDF9GKzH488qTnHPbb1vTCfYi9thmV/jbfKptZwIbmXZGj9ld8eR0B5O4tYszdwQR+7efcNuxglqg0p6JfZ2tZ5OmxypcFQ30d60Eij7AquR9hpbzVhCbYq5Jjl9Kw+ndsW6tp5CwFbS6sLconSopO1YUauplBpkOeMEtYuz3k/jo1YSBxZ1r8Ce3ewurh2r/SG1m3FC2rQLrJ1gF5KDnaZ0G4sPMrvaJLZrkuwTqdUES5ZILWO8C1kiLS0r6ENB6uVWSHOrjL94pGKQW2Rpa6TqEJDazmbJrJ4mYrzota25bm/0gFV/YNmMtEvH/pCCVJsHBrZzgDF5RAtftDe0dqW9IjX+6TjMzDI9+mNqwv6xUmaP9LBn8nSQM76HrbltPM2ozXSjBmllNh1o2zdd349RFmjPIQFKk9RnmKPe7AO1U/S1iJQL7U2q21pIZvr6Q5VHxKiJkU4sg5zUvhUIR5YWkzsjToQw9bH2J5nylpdCwpMy4x4FFyQRg2o1ISnh8807M5d9kHBPyMgskrtzmO3Pz3YiQUR1kvmhoKbZ7Pm+5n6UJVPXd6QICDOTmUtoRREmpBAMMSEyBeckn7ENlZn83EHpZJ4WFIKHABxHwNfR0AKyrUBFyRFX5i8TFBQsTPDSXKy5NqceKSjyZIVUhlC1xDRRNUUItPjMRNNEe2G0EzKdSEiMGTujIenwiOxKpPY2VS3PmFmC04OwEi+u1mWA7COm58Z2fK0HlATA1ddi4FKsSU4gXTLXPIrV5FouLEwVPIrmSidTSN2A5OCPJA35wIdvsEi7HnhGEdC4L2lAwF4FKuNBDAHzkpGAeI2loJnKm6JsiNTpGW6OV9Z4tkFKzSrDOz9rkhBRUUreMwuNpK6OeRnzbG+FK8tMshkgdQIBYUB1Dk09OY3UrEtg4l4EJtKl4NCO8gBpACnhLiLexNO8WVsjLMyCO4GcQGqSLT1qFiPgcAdUWCRwE4kpQzRcTQmhBqJnBmP5CskAvBR4subsT4OEIF+gEKKFSARfQmmdF2bgyxRFBRQTO83q5mShAxmfDxyzhEUkAxqG0rdkOCsRmLgDUqXNLR34EIlZpGbi7QS8FOkgIAZp9hopCyMdTTzIGVS6CL0wCcgEUlkmSRKg2AuUWXiDRmkWhgGpamO4Go0CgmVpLKMKaU8a17ea9nc8j21YV//ImFNwaIForHZ0OpXJCincdJowL6ZRbDlRbY2KvELqB2kvBS+l0RzS2K+81DQAZe6UUIRy
s2xMhdQsXBJNJIaaZV8KA7MQE0gFPEY45k/ZF1CLxKGG20RkZLyU5igjUCMZL5UkURDwyEsH7Q1FGo2wO3ofrTviFj7U+BxDzGgkkS65u9f30RxyeWaQirJafnsgTFk6gxRFWmNASgOapghLAblYeTQI4ITANB5MlJlOCmWpQr0i9pzJGl8HNJOQ9XXEylAFggdunJlVp5JY9GQaUR87KiKeJB6JAqjxYesd3tJnlJo6xTFrPptqFBq7fK/zqDM/icAdAlGYhzGob7jJ5rPrcDuUCmboU1MqsGpRArOOCeLUEVTyqfsJtny4rtnUU0EVVQjAAf8zg7hYtfIZr0KU5gJm/eqhDRfmChDwHhP1zklYvso6dpI+1CfpgST++C0h2eng9PjJUcegVU2zsL+pevaGrEuzg5rWCnhT9ezzo45Bm5y3DemTL446Bq1a9JLhB637X7whyycdL50xuofQltmaXHbf7IBvOGb7FJh/Mj7ztNUU3TuzOjY0S9RNGd4bXw22q2Y7aeqMTVdnTd8+ff4eCLCsmq1pAZyB7eNPgIc58wwSd9+ceQpnno1Mt2ZNnbHps7tV2ffEHAdOW49hx2B5ag6YLGwOvAfb+7DzeNL03pTp6oTpO65na0+POgrHTWcWvdBvp6cd0kXr6d2G9ZE67Vkd0U6dOnXq1Ok4ypHVL3Pzx6Pul7g9SpRE1fayCTqke5TpO8V7DCkzUst8yFgxxGPiSURjigSNO7a7U9UdbSDSQmTM1VJLj8cRigrSFyRlPqV98o7+wLxnVUhLWabpgJhOXilBLKMBZHwRRKmn6V57ar27MkhZZvr+wAaZrnjbSD2T48lCx2S+ExI9JiKFEsVdlrrCpR5TieNROhAqYUR2SHctprQ23uiaT1cLRM2oFKEJFKFEKyQOr59Yp06dOnXqdAT6f2PCMLGwm5Z8AAAAAElFTkSuQmCC)'''))" | |
], | |
"execution_count": 4, | |
"outputs": [ | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/markdown": "## Architecture For Neural Machine Trans\n![Architecture Neural Machine Trans](data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAVIAAACVCAMAAAA9kYJlAAABa1BMVEX///+S0FBEcsQAsPAAAACU004oW638/Py+vr7Pz8/39/f/3ZrY2ur/0GcAs/Pa3uYAb7OU1EhrpF1BccagqsI4a8AAgMP/7s//vQDl5vLv7+9tiZzs6+/j4+NhmmJNg2t1jKVGe26DwU+Ol7W2vc/IztopVZ1VcKd+u1LZ2dmMlsDKzeJzrlmutNKhoaG3vNdRUVFtbW1eXl59fX1qamq1tbWmpqZ3d3eLi4tHR0c3NzeXl5e5ubklJSUxMTEfHx9fg4sTExP/xiH/9eMyvPL/8s3/9+L/6sB6j7dge64AnN85X6Pf9P3/zVr/3oz/2G//56j/y0X/13dJaKRLS1KusL1FR2AAADJTc3t+mKomOyBSf0phlG9gjIORortkmWkARplRg7ZcgzJEYCVUg3yK0va8wc5/hI4SHjMJEiEwUIo6UiN2qEAxMlI6VzolQHC14vljyPUzLzgeN2JqgqQvRz9hYnOJj6xfZYZlJtHGAAAShElEQVR4nO2djX/UNprHBShYErRsgGyCiCgNoaCD9fpNfpM9w0AbaHkphe217NH29qXs7l332u719u7Pv0eemTAvdqwkkwSCfx8YO/YzsvT1oxd79IJQp06dOnXq1KlTp7dFceIfdRSOmUKNuJr429UTf0R0/gt9dtBRessleuaTpUng8CSMCM8zjdIkZSlhESkDoJ1SRLUIExf4J3GMIymSJEUk9GuAd0IyM59pjAqfDwTLUewiN0ZxyjPw3wpaDIyJL8FlQ4WokzOUUxTGJHOOOO5vqEQJH07AEcpFjpyB47ooiZJEobB0hkh57gSo5yeJyAT8mXMEVKlP0iOO+psqDoSQA2WmiPgYaVrACSfzyKgsDcFPK5qeRMzpMTTg4Lod0ibJTCdUejoTvI8YZqKnuAeHIiKw4waGKcUUkUAnHMwSFIZCRdpjKjzqqL+xYpybD6jHR/9Y9df2H6jagWPO6AP+m63TFaWd9iAax6qm0RgUhx+V46JcClGTG31y+FE5LspN4RerBKoUNyGIJiFHLNFQcxeJRiJOVXsYnaaUxUqyXoGgJaRRTDPT6PQo7xESIqXlQOwiLNk9GRnlVHJ4/oHGeCQR0gqa8DRgkPHTIPRD6e0mLNqVFkZVxjdImXnq0S5yPF4hDQ0fGewmrA5ppVwRykvklIxEJBQejUOUpqpPeE4V7bx0D5KUUgZZXjpARCBBKNT/VApudpE5Y68O6cLVIV24OqQLV4d04eqQLlwd0ntnthYbYId09e4niw2wQ3qISNnovTRCDmcTR4+bDhEp5mT4SsXx9OunB4/vENj4HYx8m1z/sJByKvucc4AkEcGCCyrgWU1w2iNwjjb4ajz65S9O2i/MTd+XYuoOsWw379AWpUNCKjISYpbEKqUeLTBVWFHMkda0JCIgblTDlCdhqJFMfIHi178AOtsfk5/wIfqwJez1CUDaE6//MMedut1F65CQ+gVCAxaq0LDhGBHPFARIu07GKOVxPv8egZUS5ZphxnKxjZRmac5ZkEYp6wVhxpHrJT5iUZgLjZWI4ATcvtTjpntRyXuCB2FvFHYUpCEKBg5KFMJhumOJsx99cvdQGlEBMa8Ow5iFvicM0gA5FVKWcRkFRTaP1LwAU9oNhQjU9u/UGUXCo3CGc9h1U9lDKCRaA1XRQ3GB5IB7ULS4qTY9DXIeutW9q+JgruH3FAoVy1h1Vw9En6wdCtJIQcoAKeyqSAyRDhhKASmD1LMapDQzSNNIKbLtpbyXhAmEU4bSvNAVXpyHYZYm5pqiRCzNUlwVAFUnOEAqPajg8LBEjSqkccYMUg5hvS0/ftcjdfLCN
2WpTohHwUsLcBEgMXBRrmhOUjz/8wobCBRo0/UqI2OkUAZAeQk7EnMPHC1SAbAUCRQrBJBWN04AKi71EKkw2WOMNIAcH1EdpxVSdmA5/8n9xYbXVONLKZCp8aVEjkQcUsmkhANCQiuAi5rkiSAMXBRHUYzUuM+P9lmaipDTkuUu8OIeEbmkGVe+6DthyDXmbsrDmARc57wUhSfc0f2ovDSiPE+KykvzxSb8tT5/vNjwDvjpibjKfLice7ELrs1cF+4ONAuQUyjHdaliSLngtYVLC6Y4oq47+mrVpyiiKMYFy5kpGtq0urp6DzZbsF11hn+vrsL23urozNwBY3rm6drq1vikMzq5NQ4HrJz6cBrLoUUjdeP64zybdOzQwuViH8C7sqqn/CRp7/dxf21t7QlsH8N2DVJ+z2y/gANnzI7J3Wa7NrKsTN8zFluPq/275uDW6KRx3M9HXxuGAwE+HX9tbFovW6TM8km0n9ffPUdOHOcDvNNlmdHIXNm3/h8/szY9UNkide26RBKMLfplaIy9HervOAWN4lVGVtc1elP6w9kiLbFNhcs8jMvWlHGMrcgjc4sGu/p18hC19aQh61siBe9z262MGcat9YjKwUuteq2yAGN7Nz1cba01PMjaIWUAwcZNicJ9ZdElyMN2RWRhbpFlF6N7C35Ub9XdhvasHdLK+6z8qt+3geVhu+zsJrgMLft4rq7Z2S1M+0MqaYmJlWP1rDq5Bbauh3Bm+9L70JGu7QspQrllVs2tqpPIFqkzyG2fQ51VS8NFqWr118gWaWaZVe3s7JGWvaN4Lb0v2SL1LCFkVnb+jk39SeXl7pGun5vXXPHBH9x4PqtPv501cy7NB7XecvkWpOtnh7r0O/zleLeucGOXRif/9aux3dlLdYkdntr88vrYqiaGzublO2P1B18Pd15cb0nJtjafn57Tw09nvn7u+cqFOa2sPJ+OzfqDhzVhPdi55b0z0ot3rpwf6vf430Z7Ly+fnbO7dPXl+Vld+ebb2Uvfujxv9t2dzdnQrl3ZWFoeaun3eGO4s/HyNzsm5PW3T184OacLF/59iim7sTJvBFr5dDLO68C9JqyVG/NetX7tV/P6w7wbXD2/vDTSH/+0Md5d+nCW6aWvTyzNa/nmL9Nmrz5YrjP77uK02eb55RM1Wn45fytrdO5hLauTFz6d5HC9yerhuQmra/XcT56+OHvV9T98dur9WZ069etZw7Mvl+rSduLE1Zm7dPVmreHS+SkHXL9TH9zSN9NFxJ36iy7dfGGDdLMBwzSsJquTpycd63mD0cqnMxd1HnwGAGv0q5li7WI9qRNLH8wY/rkewomNKQhnrzQEd3s663/YcCOXLtsgvdgIa/dIP2owunBj5qLrv64neuqzGTf9l9oMCGm7MlPzfNAAYfnqpNWt2w1mG9Ol5IcNl11+C5GeutYh7ZB2SDukHdIOaYe0Q9oh7ZB2SDukHdIOaYf0wJDyhN5qQ+qYecl2i/T97+fI7h4pE5ZIHWGDtND8AJEyDhJCSIz/0fC+9OTpV0MbEeCeakfKYz6B9ONHOyGlhFkhxYmwQ5oFXzYjZaN0uHgQiqaXe/tEuvKXLMvzXr8/MJ0S8F9roT78a569tvnbf7QhZbjUzBKpwh5pRnqdEkJUHLvadJmQzUgZpSPDEuP//KEB6d89SGw5Sgf+r3qr/SPFg7LMs8wLDNF6Wg/xoF/2jE0P496PDUGd/DGJgsDEuYpywreR/vToJ7Pz06NHP4+Rpn4UQPp6Q9s/NiD97p94Uv9oevu/fLWYMsQ/1BpufIWH6fAyc4cab9B+M77J95wxhgbZX5qsPno1skn6mn1U88tT5aX/bRIzMPRhW6bONtKPK0f9+dGp998fI83HtuBU+KsGr1q68mWUhGGqtaswzmPe7KUiGBvm+G9/2qi327g4SqtDsS8PocYvnFetNT7lO1RPzylISik4xqmYzPhA9dSp73+ayPiSVLZCFLhfPLCo8bHL7
Wr8LN1sr56E6QTwVjWi0pnqyTjqz+ClH388V5YK164RVf1Ga4PUeXMaUYtEOtQU0u8fnfp+oiyd0nFtlx4Y0qEeffzuNfUPGOlUQ6pD2j2Qdkg7pB3SDmmHtEPaIe2Qdkg7pG810uvHBOn5BrOb00gbekHPdABuUrP/TSJdb+rgfGOyN3LTrylzvaCtkV7caED6ck+9oM+9rL9FS1emRwl8a9VNvUnXTzd41vRgkc26QRInZwZJPGi4PSszoBqRvj/bC/pWQ0/w5TuzffXr3y0v3Z4K0HlRa7a0cWd65Mn61xt1oyRu/mI33v5BLdMZWID+xvagm8ahPOsP65heWHk+Nzyqsa/+7Eilq7eX5yEsLX1wa8bu3J/h6Jzd8sYMhPVfbm+MBt8sn9jeuX159rLr1z64MqvvvvnWCihC7NrD0yuVfny4Mtbp53NDfPj6SNG58d76LKpzN14HsR3UwxvzQ7XWb9iNKEHrL17evjnSD+Od89/MZ8Bbd87fnNXtly9mWbGLLy4Pded/RjuXX8yNiIObvn5ppH+Ody61jYmb0NlXm69ewX/8u1cjbV6vGdc2ktxptgV2fTuIf462m9frcovtuCeI3eZvh/rN3/93tHerLnHr14cn/++3r3V2p7k2rCahQNRycolaKbuh/AEe2JQoyi7GuxDBFrM9ItZrn3rIiOd2k0FkFtN6NMYlw9hiUjmBMW6YNWgqNK99NoxdCuJnMYRZ4Z7VvYwxHlh4ELWeA6U2LqD20b+J+Um4fehvgfEeBgjvJDNtRPvaJ2a+Dpt7yUwvgazVzDF9GKzH488qTnHPbb1vTCfYi9thmV/jbfKptZwIbmXZGj9ld8eR0B5O4tYszdwQR+7efcNuxglqg0p6JfZ2tZ5OmxypcFQ30d60Eij7AquR9hpbzVhCbYq5Jjl9Kw+ndsW6tp5CwFbS6sLconSopO1YUauplBpkOeMEtYuz3k/jo1YSBxZ1r8Ce3ewurh2r/SG1m3FC2rQLrJ1gF5KDnaZ0G4sPMrvaJLZrkuwTqdUES5ZILWO8C1kiLS0r6ENB6uVWSHOrjL94pGKQW2Rpa6TqEJDazmbJrJ4mYrzota25bm/0gFV/YNmMtEvH/pCCVJsHBrZzgDF5RAtftDe0dqW9IjX+6TjMzDI9+mNqwv6xUmaP9LBn8nSQM76HrbltPM2ozXSjBmllNh1o2zdd349RFmjPIQFKk9RnmKPe7AO1U/S1iJQL7U2q21pIZvr6Q5VHxKiJkU4sg5zUvhUIR5YWkzsjToQw9bH2J5nylpdCwpMy4x4FFyQRg2o1ISnh8807M5d9kHBPyMgskrtzmO3Pz3YiQUR1kvmhoKbZ7Pm+5n6UJVPXd6QICDOTmUtoRREmpBAMMSEyBeckn7ENlZn83EHpZJ4WFIKHABxHwNfR0AKyrUBFyRFX5i8TFBQsTPDSXKy5NqceKSjyZIVUhlC1xDRRNUUItPjMRNNEe2G0EzKdSEiMGTujIenwiOxKpPY2VS3PmFmC04OwEi+u1mWA7COm58Z2fK0HlATA1ddi4FKsSU4gXTLXPIrV5FouLEwVPIrmSidTSN2A5OCPJA35wIdvsEi7HnhGEdC4L2lAwF4FKuNBDAHzkpGAeI2loJnKm6JsiNTpGW6OV9Z4tkFKzSrDOz9rkhBRUUreMwuNpK6OeRnzbG+FK8tMshkgdQIBYUB1Dk09OY3UrEtg4l4EJtKl4NCO8gBpACnhLiLexNO8WVsjLMyCO4GcQGqSLT1qFiPgcAdUWCRwE4kpQzRcTQmhBqJnBmP5CskAvBR4subsT4OEIF+gEKKFSARfQmmdF2bgyxRFBRQTO83q5mShAxmfDxyzhEUkAxqG0rdkOCsRmLgDUqXNLR34EIlZpGbi7QS8FOkgIAZp9hopCyMdTTzIGVS6CL0wCcgEUlkmSRKg2AuUWXiDRmkWhgGpamO4Go0CgmVpLKMKaU8a17ea9nc8j21YV//ImFNwaIFor
HZ0OpXJCincdJowL6ZRbDlRbY2KvELqB2kvBS+l0RzS2K+81DQAZe6UUIRys2xMhdQsXBJNJIaaZV8KA7MQE0gFPEY45k/ZF1CLxKGG20RkZLyU5igjUCMZL5UkURDwyEsH7Q1FGo2wO3ofrTviFj7U+BxDzGgkkS65u9f30RxyeWaQirJafnsgTFk6gxRFWmNASgOapghLAblYeTQI4ITANB5MlJlOCmWpQr0i9pzJGl8HNJOQ9XXEylAFggdunJlVp5JY9GQaUR87KiKeJB6JAqjxYesd3tJnlJo6xTFrPptqFBq7fK/zqDM/icAdAlGYhzGob7jJ5rPrcDuUCmboU1MqsGpRArOOCeLUEVTyqfsJtny4rtnUU0EVVQjAAf8zg7hYtfIZr0KU5gJm/eqhDRfmChDwHhP1zklYvso6dpI+1CfpgST++C0h2eng9PjJUcegVU2zsL+pevaGrEuzg5rWCnhT9ezzo45Bm5y3DemTL446Bq1a9JLhB637X7whyycdL50xuofQltmaXHbf7IBvOGb7FJh/Mj7ztNUU3TuzOjY0S9RNGd4bXw22q2Y7aeqMTVdnTd8+ff4eCLCsmq1pAZyB7eNPgIc58wwSd9+ceQpnno1Mt2ZNnbHps7tV2ffEHAdOW49hx2B5ag6YLGwOvAfb+7DzeNL03pTp6oTpO65na0+POgrHTWcWvdBvp6cd0kXr6d2G9ZE67Vkd0U6dOnXq1Ok4ypHVL3Pzx6Pul7g9SpRE1fayCTqke5TpO8V7DCkzUst8yFgxxGPiSURjigSNO7a7U9UdbSDSQmTM1VJLj8cRigrSFyRlPqV98o7+wLxnVUhLWabpgJhOXilBLKMBZHwRRKmn6V57ar27MkhZZvr+wAaZrnjbSD2T48lCx2S+ExI9JiKFEsVdlrrCpR5TieNROhAqYUR2SHctprQ23uiaT1cLRM2oFKEJFKFEKyQOr59Yp06dOnXqdAT6f2PCMLGwm5Z8AAAAAElFTkSuQmCC)", | |
"text/plain": [ | |
"<IPython.core.display.Markdown object>" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
} | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "cJ9Cfl5D9qYN", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"if not colab:\n", | |
" # if on local machine \n", | |
" root_dir='.'\n", | |
" \n", | |
"else:\n", | |
" # if using google colab use this code\n", | |
" from google.colab import drive\n", | |
" drive.mount('/content/drive')\n", | |
" root_dir = \"/content/drive/My Drive/Colab Notebooks\"" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "Ct0UBp8y9qYS", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"data_path = os.path.join(root_dir, \"fra.csv\")\n", | |
"doc = pd.read_csv(data_path, nrows=training_samples)" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "tvVhQmsq9qYY", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"# replace contracted forms for english words\n", | |
"contracted_dict={\"won't\" : \"will not\", \"can\\'t\" : \"can not\", \"n\\'t\" : \" not\", \"\\'re\" : \" are\", \"\\'s\" : \" is\", \"\\'d\" : \" would\", \"\\'ll\" : \" will\", \"\\'t\" : \" not\", \"\\'ve\" : \" have\", \"\\'m\" : \" am\"}\n", | |
"\n", | |
"def replace_contracted(text):\n", | |
"\n", | |
" regex = re.compile(\"|\".join(map(re.escape, contracted_dict.keys( ))))\n", | |
" return regex.sub(lambda match: contracted_dict[match.group(0)], text)" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "WPWsAj-u9qYd", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"# apply decontraction and lowercase\n", | |
"doc=doc.apply(np.vectorize(lambda sent : replace_contracted(str(sent).strip().lower())))" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "54f6A85j9qYg", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"# tokenize sentences and add start_ and _end keyword to target sentences\n", | |
"source_sents=doc.Source.apply(lambda x: x + ' _END').apply(lambda sent: word_tokenize(sent))\n", | |
"target_sents=doc.Target.apply(lambda x : 'START_ '+ x + ' _END').apply(lambda sent: word_tokenize(sent))\n", | |
"temp = list(zip(source_sents, target_sents)) \n", | |
"random.shuffle(temp) \n", | |
"source_sents, target_sents = zip(*temp)\n", | |
"source_sents, target_sents = pd.Series(source_sents), pd.Series(target_sents)" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"scrolled": true, | |
"id": "c0PwKMAA9qYj", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"del(doc)\n", | |
"# building the vocabulary\n", | |
"source_vocab=set().union(*source_sents)\n", | |
"target_vocab=set().union(*target_sents)" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "KSuPWE2C9qYo", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"# max sentence length for each language in the dataset\n", | |
"max_source_len=max(source_sents.apply(len))\n", | |
"max_target_len=max(target_sents.apply(len))" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "Yhpt_X7Y9qYr", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"# numeric identity for each word in vocab\n", | |
"source_wordint_rel=bidict(enumerate(source_vocab, 1))\n", | |
"temp={0:'paddingZero'}\n", | |
"temp.update(dict(enumerate(target_vocab, 1)))\n", | |
"target_wordint_rel=bidict(temp)" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "h0_0s05X9qYv", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"# prepare inputs and outputs\n", | |
"encoder_source_arr=[list(map(lambda word : source_wordint_rel.inv[word], sent)) for sent in source_sents]\n", | |
"decoder_source_arr=[list(map(lambda word : target_wordint_rel.inv[word], sent)) for sent in target_sents]\n", | |
"decoder_output_arr=[list(map(lambda word : target_wordint_rel.inv[word], sent[1:])) for sent in target_sents]" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"scrolled": true, | |
"id": "WtaBqNMZ9qYy", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"# pad the inputs and outputs to max length\n", | |
"padded_encoder_source_arr=pad_sequences(encoder_source_arr, maxlen=max_source_len, padding='post')\n", | |
"padded_decoder_source_arr=pad_sequences(decoder_source_arr, maxlen=max_target_len, padding='post')\n", | |
"padded_decoder_output_arr=pad_sequences(decoder_output_arr, maxlen=max_target_len, padding='post')\n", | |
"onehotted_decoder_output_arr=tf.one_hot(padded_decoder_output_arr, len(target_vocab)+1).numpy()\n", | |
"\n", | |
"del encoder_source_arr, decoder_source_arr, decoder_output_arr, padded_decoder_output_arr" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "T9AF9kdc9qY1", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"# Model Preparation" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "YH3L903Z9qY2", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"# context-vector length\n", | |
"latent_dim=ctx_vec_len\n", | |
"\n", | |
"# this is the source languge consumtion layer\n", | |
"encoder_inputs = Input(shape=(None,), name='encoder_sources')\n", | |
"# embed the 2-d source into 3-d\n", | |
"enc_emb = Embedding(len(source_vocab)+1, embedding_dim, mask_zero = True, name='enc_emb')(encoder_inputs)\n", | |
"\n", | |
"# LSTM layer to encode the source sentence into context-vector representation\n", | |
"encoder_lstm = Bidirectional(LSTM(latent_dim, return_state=True, return_sequences=True, name='encoder_lstm1', dropout=dropout), name='encoder_bi-lstm1', merge_mode=\"concat\")\n", | |
"\n", | |
"encoder_outputs, forward_h, forward_c, backward_h, backward_c = encoder_lstm(enc_emb)\n", | |
"encoder_states = [forward_h, forward_c, backward_h, backward_c]\n", | |
"\n", | |
"encoder_lstm1 = Bidirectional(LSTM(latent_dim, return_state=True, name='encoder_lstm2', dropout=dropout), name='encoder_bi-lstm2', merge_mode=\"concat\")\n", | |
"encoder_outputs, forward_h, forward_c, backward_h, backward_c = encoder_lstm1(encoder_outputs, initial_state=encoder_states)\n", | |
"\n", | |
"state_h = Concatenate()([forward_h, backward_h])\n", | |
"state_c = Concatenate()([forward_c, backward_c])\n", | |
"# encoded-states tensor stores the context-vector\n", | |
"encoder_states = [state_h, state_c]" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "WOssNVT79qY6", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"# this is the target languge consumtion layer\n", | |
"decoder_inputs = Input(shape=(None,), name='decoder_sources')\n", | |
"# embed the 2-d source into 3-d\n", | |
"dec_emb_layer = Embedding(len(target_vocab)+1, embedding_dim, mask_zero = True, name='dec_emb_layer')\n", | |
"dec_emb = dec_emb_layer(decoder_inputs)\n", | |
"\n", | |
"# decoder LSTM, this takes in the context-vector and starting or so-far decoded part of the target sentence\n", | |
"decoder_lstm1 = LSTM(latent_dim, return_sequences=True, name='decoder_lstm1', dropout=dropout)\n", | |
"decoder_outputs11 = decoder_lstm1(dec_emb)\n", | |
"decoder_lstm2 = LSTM(latent_dim, return_sequences=True, name='decoder_lstm2', dropout=dropout)\n", | |
"decoder_outputs12 = decoder_lstm2(decoder_outputs11)\n", | |
"decoder_lstm3 = LSTM(latent_dim*2, return_sequences=True, return_state=True, name='decoder_lstm', dropout=dropout)\n", | |
"decoder_outputs13, _, _ = decoder_lstm3(decoder_outputs12, initial_state=encoder_states)\n", | |
"\n", | |
"# final layer that gives a probabilty distribution of the next possible words\n", | |
"decoder_dense = Dense(len(target_vocab)+1, activation='softmax', name='decoder_dense')\n", | |
"decoder_outputs14 = decoder_dense(decoder_outputs13)\n", | |
"\n", | |
"# Encode the source sequence to get the \"Context vectors\"\n", | |
"encoder_model = Model(encoder_inputs, encoder_states, name='Model_Encoder')\n", | |
"encoder_model.summary()\n", | |
"plot_model(encoder_model, show_shapes=True, show_layer_names=True)" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "rYR6yV_KZgvG", | |
"colab_type": "text" | |
}, | |
"source": [ | |
        "# Custom Loss Function that Masks Out Padding Time Steps"
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "VNE30xzaZZX7", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"vocab_len=len(onehotted_decoder_output_arr[0][0])\n", | |
"\n", | |
"def PaddedCategoricalCrossentropy(eps=1e-12):\n", | |
" def loss(y_true, y_pred):\n", | |
" mask_value = np.zeros((vocab_len))\n", | |
" mask_value[0] = 1\n", | |
" # find out which timesteps in `y_true` are not the padding character \n", | |
" mask = K.equal(y_true, mask_value)\n", | |
" mask = 1 - K.cast(mask, K.floatx())\n", | |
" mask = K.sum(mask,axis=2)/2\n", | |
" # multplying the loss by the mask. the loss for padding will be zero\n", | |
" loss = tf.keras.layers.multiply([K.categorical_crossentropy(y_true, y_pred), mask])\n", | |
" return K.sum(loss) / K.sum(mask)\n", | |
" return loss" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "tiYu7bVY9qY9", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"# model building and summary\n", | |
"model = Model([encoder_inputs, decoder_inputs], decoder_outputs14, name='Model_Translation')\n", | |
"model.compile(optimizer='Adam', loss=PaddedCategoricalCrossentropy(), metrics=['acc'])\n", | |
"model.summary()\n", | |
"plot_model(model, show_shapes=True, show_layer_names=True)" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "wjYkSeMjVz04", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"# Training" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "g_dxY6fL9qZB", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"# TensorBoard Callback \n", | |
"tbCallBack = TensorBoard(log_dir=os.path.join(root_dir, 'Graph'), histogram_freq=0, write_graph=True, write_images=True)" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"scrolled": true, | |
"id": "6SsaUZXF9qZF", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"if training:\n", | |
" # train the model\n", | |
" history=model.fit([padded_encoder_source_arr, padded_decoder_source_arr], onehotted_decoder_output_arr, epochs=epochs, validation_split=0.02, callbacks=[tbCallBack], batch_size=batch_size)" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "a6-S3Z9QH3iX", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"if training:\n", | |
" model.save_weights(os.path.join(root_dir, weight_file))\n", | |
" with plt.style.context('dark_background'):\n", | |
" plt.plot(history.history['acc'])\n", | |
" plt.plot(history.history['val_acc'])\n", | |
" plt.title('model accuracy')\n", | |
" plt.ylabel('accuracy')\n", | |
" plt.xlabel('epoch')\n", | |
" plt.legend(['train', 'val'], loc='upper left')\n", | |
" plt.show()\n", | |
" plt.plot(history.history['loss'])\n", | |
" plt.plot(history.history['val_loss'])\n", | |
" plt.title('model loss')\n", | |
" plt.ylabel('loss')\n", | |
" plt.xlabel('epoch')\n", | |
" plt.legend(['train', 'val'], loc='upper left')\n", | |
" plt.show()\n", | |
" print(f'Accuracy while saving is {model.evaluate([padded_encoder_source_arr, padded_decoder_source_arr], onehotted_decoder_output_arr)}')" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "CZpJP5Fk9qZI", | |
"colab_type": "text" | |
}, | |
"source": [ | |
        "# Inference Decoder Model"
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "Ksm-Vbao9qZJ", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"# Decoder setup\n", | |
"# Below tensors will hold the states of the previous time step\n", | |
"state_h = Input(shape=(latent_dim*2,))\n", | |
"state_c = Input(shape=(latent_dim*2,))\n", | |
"\n", | |
"\n", | |
"decoder_state_input = [state_h, state_c]\n", | |
"# Get the embeddings of the decoder sequence\n", | |
"dec_emb2= dec_emb_layer(decoder_inputs)\n", | |
"# To predict the next word in the sequence, set the initial states to the states from the previous time step\n", | |
"decoder_outputs21 = decoder_lstm1(dec_emb2)\n", | |
"decoder_outputs22 = decoder_lstm2(decoder_outputs21)\n", | |
"decoder_outputs23, state_h2, state_c2 = decoder_lstm3(decoder_outputs22, initial_state=decoder_state_input)\n", | |
"decoder_states2 = [state_h2, state_c2]\n", | |
"# A dense softmax layer to generate prob dist. over the target vocabulary\n", | |
"decoder_outputs24 = decoder_dense(decoder_outputs23)\n", | |
"# Final decoder model\n", | |
"decoder_model = Model(\n", | |
" [decoder_inputs] + decoder_state_input,\n", | |
" [decoder_outputs24] + decoder_states2, name='Model_Decoder')\n", | |
"decoder_model.summary()\n", | |
"plot_model(decoder_model, show_shapes=True, show_layer_names=True)" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "5yRFPYsS9qZM", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"# Decoding Logic" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "_hAps5WG9qZN", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"def decode_sequence(source_seq):\n", | |
" \n", | |
" # Encode the source as state vectors.\n", | |
" states_value = encoder_model.predict(source_seq)\n", | |
" # Generate empty target sequence of length 1.\n", | |
" target_seq = np.zeros((1,1))\n", | |
" # Populate the first character of \n", | |
" #target sequence with the start character.\n", | |
" target_seq[0, 0] = target_wordint_rel.inv['START_']\n", | |
" # Sampling loop for a batch of sequences\n", | |
" # (to simplify, here we assume a batch of size 1).\n", | |
" stop_condition = False\n", | |
" decoded_sentence = []\n", | |
" while not stop_condition:\n", | |
" output_tokens, h, c = decoder_model.predict([target_seq] + states_value)\n", | |
" # Sample a token\n", | |
" sampled_token_index = np.argmax(output_tokens[0, -1, :])\n", | |
" sampled_word =target_wordint_rel[sampled_token_index]\n", | |
" decoded_sentence += [sampled_word]\n", | |
" # Exit condition: either hit max length\n", | |
" # or find stop character.\n", | |
" if (sampled_word == '_END' or\n", | |
" len(decoded_sentence) > 50):\n", | |
" stop_condition = True\n", | |
" # Update the target sequence (of length 1).\n", | |
" target_seq = np.zeros((1,1))\n", | |
" target_seq[0, 0] = sampled_token_index\n", | |
" # Update states\n", | |
" states_value = [h, c]\n", | |
"\n", | |
" return decoded_sentence" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "EZFK5mg_9qZQ", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"# Prediction" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"scrolled": false, | |
"id": "mWwKlX5a9qZQ", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"start=1000\n", | |
"offset=100\n", | |
"def calc_strdiff(true, pred):\n", | |
" # return sum([1 for char in list(difflib.ndiff(true, pred)) if '+ ' in char or '- ' in char])/(len(true))\n", | |
" return nltk.translate.bleu_score.sentence_bleu([word_tokenize(true)], word_tokenize(pred))\n", | |
" \n", | |
"if validation:\n", | |
" \n", | |
" model.load_weights(os.path.join(root_dir, weight_file))\n", | |
" print(f'Accuracy after loading is {model.evaluate([padded_encoder_source_arr, padded_decoder_source_arr], onehotted_decoder_output_arr)}')\n", | |
" y_truePred = [(' '.join(source_sents[seq_index][:-1]), ' '.join(target_sents[seq_index][1:-1]), ' '.join(decode_sequence(padded_encoder_source_arr[seq_index:seq_index+1])[:-1])) for seq_index, _ in enumerate(padded_encoder_source_arr[start:start+offset], start)]\n", | |
" bleu_score=[calc_strdiff(true, pred) for _, true, pred, in y_truePred]\n", | |
" print(f'Bleu Scores are {bleu_score}')\n", | |
" print(f'Avg bleu score for {len(y_truePred)} tests was {sum(bleu_score)/len(y_truePred)}.')\n", | |
" print(f\"{pd.DataFrame(y_truePred, columns=['Source', 'Expected', 'Predicted'])}\")" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "Nsddt0BCzEJz", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"pd.DataFrame(y_truePred, columns=['Source', 'Expected', 'Predicted']).to_excel(os.path.join(root_dir, 'review.xlsx'))" | |
], | |
"execution_count": null, | |
"outputs": [] | |
} | |
] | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment