Created
January 25, 2021 05:59
-
-
Save seanbenhur/72024be5b416c70121a8741323b27dbd to your computer and use it in GitHub Desktop.
MLT .ipynb
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"nbformat": 4, | |
"nbformat_minor": 0, | |
"metadata": { | |
"colab": { | |
"name": "MLT .ipynb", | |
"provenance": [], | |
"mount_file_id": "19g3duD087i-qbuq3XutAsxUkD9rPx-Cm", | |
"authorship_tag": "ABX9TyOUYUTOfsDx2EpDyMYnS8b1", | |
"include_colab_link": true | |
}, | |
"kernelspec": { | |
"display_name": "Python 3", | |
"name": "python3" | |
}, | |
"accelerator": "GPU" | |
}, | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "view-in-github", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"<a href=\"https://colab.research.google.com/gist/seanbenhur/72024be5b416c70121a8741323b27dbd/mlt.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "XmWpFnT1jy_F", | |
"outputId": "a1cc29cc-c55f-48f6-b2d5-82fff06102dc" | |
}, | |
"source": [ | |
"%pip install torchtext==0.6.0" | |
], | |
"execution_count": 21, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"Requirement already satisfied: torchtext==0.6.0 in /usr/local/lib/python3.6/dist-packages (0.6.0)\n", | |
"Requirement already satisfied: requests in /usr/local/lib/python3.6/dist-packages (from torchtext==0.6.0) (2.23.0)\n", | |
"Requirement already satisfied: torch in /usr/local/lib/python3.6/dist-packages (from torchtext==0.6.0) (1.7.0+cu101)\n", | |
"Requirement already satisfied: six in /usr/local/lib/python3.6/dist-packages (from torchtext==0.6.0) (1.15.0)\n", | |
"Requirement already satisfied: numpy in /usr/local/lib/python3.6/dist-packages (from torchtext==0.6.0) (1.19.5)\n", | |
"Requirement already satisfied: tqdm in /usr/local/lib/python3.6/dist-packages (from torchtext==0.6.0) (4.41.1)\n", | |
"Requirement already satisfied: sentencepiece in /usr/local/lib/python3.6/dist-packages (from torchtext==0.6.0) (0.1.95)\n", | |
"Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.6/dist-packages (from requests->torchtext==0.6.0) (2.10)\n", | |
"Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.6/dist-packages (from requests->torchtext==0.6.0) (3.0.4)\n", | |
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.6/dist-packages (from requests->torchtext==0.6.0) (2020.12.5)\n", | |
"Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.6/dist-packages (from requests->torchtext==0.6.0) (1.24.3)\n", | |
"Requirement already satisfied: future in /usr/local/lib/python3.6/dist-packages (from torch->torchtext==0.6.0) (0.16.0)\n", | |
"Requirement already satisfied: typing-extensions in /usr/local/lib/python3.6/dist-packages (from torch->torchtext==0.6.0) (3.7.4.3)\n", | |
"Requirement already satisfied: dataclasses in /usr/local/lib/python3.6/dist-packages (from torch->torchtext==0.6.0) (0.8)\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "2_dl5Kn2o0Gj", | |
"outputId": "716550c1-6484-48d3-ced0-55872956b2b4" | |
}, | |
"source": [ | |
"!python -m spacy download de_core_news_sm\r\n", | |
"!python -m spacy download en_core_web_sm" | |
], | |
"execution_count": 22, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"Requirement already satisfied: de_core_news_sm==2.2.5 from https://github.com/explosion/spacy-models/releases/download/de_core_news_sm-2.2.5/de_core_news_sm-2.2.5.tar.gz#egg=de_core_news_sm==2.2.5 in /usr/local/lib/python3.6/dist-packages (2.2.5)\n", | |
"Requirement already satisfied: spacy>=2.2.2 in /usr/local/lib/python3.6/dist-packages (from de_core_news_sm==2.2.5) (2.2.4)\n", | |
"Requirement already satisfied: srsly<1.1.0,>=1.0.2 in /usr/local/lib/python3.6/dist-packages (from spacy>=2.2.2->de_core_news_sm==2.2.5) (1.0.5)\n", | |
"Requirement already satisfied: catalogue<1.1.0,>=0.0.7 in /usr/local/lib/python3.6/dist-packages (from spacy>=2.2.2->de_core_news_sm==2.2.5) (1.0.0)\n", | |
"Requirement already satisfied: numpy>=1.15.0 in /usr/local/lib/python3.6/dist-packages (from spacy>=2.2.2->de_core_news_sm==2.2.5) (1.19.5)\n", | |
"Requirement already satisfied: murmurhash<1.1.0,>=0.28.0 in /usr/local/lib/python3.6/dist-packages (from spacy>=2.2.2->de_core_news_sm==2.2.5) (1.0.5)\n", | |
"Requirement already satisfied: setuptools in /usr/local/lib/python3.6/dist-packages (from spacy>=2.2.2->de_core_news_sm==2.2.5) (51.3.3)\n", | |
"Requirement already satisfied: tqdm<5.0.0,>=4.38.0 in /usr/local/lib/python3.6/dist-packages (from spacy>=2.2.2->de_core_news_sm==2.2.5) (4.41.1)\n", | |
"Requirement already satisfied: wasabi<1.1.0,>=0.4.0 in /usr/local/lib/python3.6/dist-packages (from spacy>=2.2.2->de_core_news_sm==2.2.5) (0.8.0)\n", | |
"Requirement already satisfied: plac<1.2.0,>=0.9.6 in /usr/local/lib/python3.6/dist-packages (from spacy>=2.2.2->de_core_news_sm==2.2.5) (1.1.3)\n", | |
"Requirement already satisfied: thinc==7.4.0 in /usr/local/lib/python3.6/dist-packages (from spacy>=2.2.2->de_core_news_sm==2.2.5) (7.4.0)\n", | |
"Requirement already satisfied: preshed<3.1.0,>=3.0.2 in /usr/local/lib/python3.6/dist-packages (from spacy>=2.2.2->de_core_news_sm==2.2.5) (3.0.5)\n", | |
"Requirement already satisfied: cymem<2.1.0,>=2.0.2 in /usr/local/lib/python3.6/dist-packages (from spacy>=2.2.2->de_core_news_sm==2.2.5) (2.0.5)\n", | |
"Requirement already satisfied: requests<3.0.0,>=2.13.0 in /usr/local/lib/python3.6/dist-packages (from spacy>=2.2.2->de_core_news_sm==2.2.5) (2.23.0)\n", | |
"Requirement already satisfied: blis<0.5.0,>=0.4.0 in /usr/local/lib/python3.6/dist-packages (from spacy>=2.2.2->de_core_news_sm==2.2.5) (0.4.1)\n", | |
"Requirement already satisfied: importlib-metadata>=0.20; python_version < \"3.8\" in /usr/local/lib/python3.6/dist-packages (from catalogue<1.1.0,>=0.0.7->spacy>=2.2.2->de_core_news_sm==2.2.5) (3.3.0)\n", | |
"Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.6/dist-packages (from requests<3.0.0,>=2.13.0->spacy>=2.2.2->de_core_news_sm==2.2.5) (2.10)\n", | |
"Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.6/dist-packages (from requests<3.0.0,>=2.13.0->spacy>=2.2.2->de_core_news_sm==2.2.5) (1.24.3)\n", | |
"Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.6/dist-packages (from requests<3.0.0,>=2.13.0->spacy>=2.2.2->de_core_news_sm==2.2.5) (3.0.4)\n", | |
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.6/dist-packages (from requests<3.0.0,>=2.13.0->spacy>=2.2.2->de_core_news_sm==2.2.5) (2020.12.5)\n", | |
"Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.6/dist-packages (from importlib-metadata>=0.20; python_version < \"3.8\"->catalogue<1.1.0,>=0.0.7->spacy>=2.2.2->de_core_news_sm==2.2.5) (3.4.0)\n", | |
"Requirement already satisfied: typing-extensions>=3.6.4; python_version < \"3.8\" in /usr/local/lib/python3.6/dist-packages (from importlib-metadata>=0.20; python_version < \"3.8\"->catalogue<1.1.0,>=0.0.7->spacy>=2.2.2->de_core_news_sm==2.2.5) (3.7.4.3)\n", | |
"\u001b[38;5;2m✔ Download and installation successful\u001b[0m\n", | |
"You can now load the model via spacy.load('de_core_news_sm')\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "bnw2p2LSilxM" | |
}, | |
"source": [ | |
"import random\r\n", | |
"\r\n", | |
"import numpy as np\r\n", | |
"import spacy\r\n", | |
"import torch\r\n", | |
"import torch.nn as nn\r\n", | |
"import torch.nn.functional as F\r\n", | |
"import torch.optim as optim\r\n", | |
"import torchtext\r\n", | |
"from torchtext import data\r\n", | |
"from torchtext.data import BucketIterator, Field\r\n", | |
"from torchtext.data.metrics import bleu_score\r\n", | |
"from torchtext.datasets import Multi30k" | |
], | |
"execution_count": 23, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "DcwuzwVgFMo3" | |
}, | |
"source": [ | |
"import os\r\n", | |
"\r\n", | |
"# Debugging aid: synchronous CUDA launches make device-side errors surface at the op that caused them.\r\n", | |
"os.environ.update(CUDA_LAUNCH_BLOCKING=\"1\")" | |
], | |
"execution_count": 24, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "kTMFT1sYkIhg" | |
}, | |
"source": [ | |
"def set_seed(seed):\r\n", | |
"    \"\"\"Seed Python's `random`, NumPy's global RNG and PyTorch's CPU/CUDA RNGs, and request deterministic cuDNN.\"\"\"\r\n", | |
"    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):\r\n", | |
"        seeder(seed)\r\n", | |
"    torch.backends.cudnn.deterministic = True" | |
], | |
"execution_count": 25, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "Emi5IU8Nld1Y" | |
}, | |
"source": [ | |
"# Seed every RNG once, before any data loading or model construction.\r\n", | |
"seed = 1234\r\n", | |
"set_seed(seed=seed)" | |
], | |
"execution_count": 26, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "BhEo62Y1mknd" | |
}, | |
"source": [ | |
"# Load models by full package name; the 'en' shortcut link is deprecated (removed in spaCy v3).\r\n", | |
"spacy_en = spacy.load('en_core_web_sm')\r\n", | |
"spacy_de = spacy.load('de_core_news_sm')" | |
], | |
"execution_count": 27, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "4KE5iQ8llyEb" | |
}, | |
"source": [ | |
"def tokenize_german(text):\r\n", | |
"    \"\"\"Split German `text` into a list of token strings using spaCy's German tokenizer.\"\"\"\r\n", | |
"    doc = spacy_de.tokenizer(text)\r\n", | |
"    return [token.text for token in doc]\r\n", | |
"\r\n", | |
"\r\n", | |
"def tokenize_english(text):\r\n", | |
"    \"\"\"Split English `text` into a list of token strings using spaCy's English tokenizer.\"\"\"\r\n", | |
"    doc = spacy_en.tokenizer(text)\r\n", | |
"    return [token.text for token in doc]" | |
], | |
"execution_count": 28, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "GtT-9UtAm3WG" | |
}, | |
"source": [ | |
"# NOTE(review): SRC and TRG are not referenced by any later cell; the pipeline below uses the\r\n", | |
"# `german` / `english` Fields instead (those are sequence-first, matching the model's input layout).\r\n", | |
"SRC = Field(tokenize=tokenize_german, init_token='<sos>', eos_token='<eos>',\r\n", | |
"            lower=True, batch_first=True)\r\n", | |
"TRG = Field(tokenize=tokenize_english, init_token='<sos>', eos_token='<eos>',\r\n", | |
"            lower=True, batch_first=True)" | |
], | |
"execution_count": 29, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "Y6A7y_zQoN-y" | |
}, | |
"source": [ | |
"# Pass the tokenizer by keyword: Field's first positional parameter is `sequential`, so\r\n", | |
"# `data.Field(tokenize_english, ...)` left tokenize=None and silently used whitespace splitting.\r\n", | |
"# batch_first stays False because the Transformer below expects (seq_len, batch) tensors.\r\n", | |
"english = data.Field(tokenize=tokenize_english,\r\n", | |
"                     lower=True,\r\n", | |
"                     init_token=\"<sos>\",\r\n", | |
"                     eos_token=\"<eos>\")\r\n", | |
"\r\n", | |
"german = data.Field(tokenize=tokenize_german,\r\n", | |
"                    lower=True,\r\n", | |
"                    init_token=\"<sos>\",\r\n", | |
"                    eos_token=\"<eos>\")" | |
], | |
"execution_count": 30, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "BLs5syhsqypH" | |
}, | |
"source": [ | |
"# `fields` pairs positionally with `exts`: .de -> src (german Field), .en -> trg (english Field).\r\n", | |
"# Swapping them numericalized German text with the English vocab (and vice versa).\r\n", | |
"train_data, valid_data, test_data = Multi30k.splits(\r\n", | |
"    exts=(\".de\", \".en\"), fields=(german, english)\r\n", | |
")" | |
], | |
"execution_count": 31, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "iyJkcyM5yFoG" | |
}, | |
"source": [ | |
"# Vocabularies come from the training split only; rare tokens (seen once) map to <unk>.\r\n", | |
"english.build_vocab(train_data, min_freq=2, max_size=10000)\r\n", | |
"german.build_vocab(train_data, min_freq=2, max_size=10000)" | |
], | |
"execution_count": 32, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "nVsdO5CXrM76" | |
}, | |
"source": [ | |
"# Implementing the model\r\n", | |
"class Transformer(nn.Module):\r\n", | |
"    \"\"\"Sequence-to-sequence model built on nn.Transformer with learned positional embeddings.\r\n", | |
"\r\n", | |
"    Inputs are sequence-first: `src` is (src_len, N) and `trg` is (trg_len, N) token indices;\r\n", | |
"    `forward` returns logits of shape (trg_len, N, trg_vocab_size).\r\n", | |
"\r\n", | |
"    Args:\r\n", | |
"        embedding_size: model width (d_model) shared by embeddings and transformer layers.\r\n", | |
"        src_vocab_size, trg_vocab_size: sizes of the source / target vocabularies.\r\n", | |
"        src_pad_idx: source padding index, masked out of encoder and cross attention.\r\n", | |
"        num_heads, num_encoder_layers, num_decoder_layers: nn.Transformer sizes.\r\n", | |
"        forward_expansion: feed-forward width as a multiple of embedding_size.\r\n", | |
"        dropout: dropout probability for embeddings and transformer layers.\r\n", | |
"        max_len: number of learned positions; longer sequences raise ValueError.\r\n", | |
"        device: device on which position indices and masks are created.\r\n", | |
"    \"\"\"\r\n", | |
"\r\n", | |
"    def __init__(self,\r\n", | |
"                 embedding_size,\r\n", | |
"                 src_vocab_size,\r\n", | |
"                 trg_vocab_size,\r\n", | |
"                 src_pad_idx,\r\n", | |
"                 num_heads,\r\n", | |
"                 num_encoder_layers,\r\n", | |
"                 num_decoder_layers,\r\n", | |
"                 forward_expansion,\r\n", | |
"                 dropout,\r\n", | |
"                 max_len,\r\n", | |
"                 device):\r\n", | |
"        super().__init__()\r\n", | |
"        self.src_word_embedding = nn.Embedding(src_vocab_size, embedding_size)\r\n", | |
"        self.src_position_embedding = nn.Embedding(max_len, embedding_size)\r\n", | |
"        self.trg_word_embedding = nn.Embedding(trg_vocab_size, embedding_size)\r\n", | |
"        self.trg_position_embedding = nn.Embedding(max_len, embedding_size)\r\n", | |
"\r\n", | |
"        self.device = device\r\n", | |
"        self.max_len = max_len\r\n", | |
"        # Keyword arguments: nn.Transformer's 5th positional parameter is dim_feedforward, so\r\n", | |
"        # passing forward_expansion positionally made the feed-forward layer only that many units wide.\r\n", | |
"        self.transformer = nn.Transformer(\r\n", | |
"            d_model=embedding_size,\r\n", | |
"            nhead=num_heads,\r\n", | |
"            num_encoder_layers=num_encoder_layers,\r\n", | |
"            num_decoder_layers=num_decoder_layers,\r\n", | |
"            dim_feedforward=forward_expansion * embedding_size,\r\n", | |
"            dropout=dropout,\r\n", | |
"        )\r\n", | |
"        self.fc_out = nn.Linear(embedding_size, trg_vocab_size)\r\n", | |
"        self.dropout = nn.Dropout(dropout)\r\n", | |
"        self.src_pad_idx = src_pad_idx\r\n", | |
"\r\n", | |
"    def make_src_mask(self, src):\r\n", | |
"        \"\"\"Return an (N, src_len) bool mask that is True at source padding positions.\"\"\"\r\n", | |
"        src_mask = src.transpose(0, 1) == self.src_pad_idx\r\n", | |
"        return src_mask.to(self.device)\r\n", | |
"\r\n", | |
"    def _positions(self, seq_len, batch_size):\r\n", | |
"        \"\"\"Return (seq_len, batch_size) position indices, rejecting lengths beyond max_len.\"\"\"\r\n", | |
"        if seq_len > self.max_len:\r\n", | |
"            # Fail with a readable error instead of an opaque CUDA device-side assert in nn.Embedding.\r\n", | |
"            raise ValueError(f\"sequence length {seq_len} exceeds max_len={self.max_len}\")\r\n", | |
"        return (\r\n", | |
"            torch.arange(0, seq_len, device=self.device)\r\n", | |
"            .unsqueeze(1)\r\n", | |
"            .expand(seq_len, batch_size)\r\n", | |
"        )\r\n", | |
"\r\n", | |
"    def forward(self, src, trg):\r\n", | |
"        \"\"\"Return target-vocabulary logits of shape (trg_len, N, trg_vocab_size).\"\"\"\r\n", | |
"        src_seq_len, N = src.shape\r\n", | |
"        trg_seq_len, _ = trg.shape\r\n", | |
"\r\n", | |
"        embed_src = self.dropout(\r\n", | |
"            self.src_word_embedding(src) + self.src_position_embedding(self._positions(src_seq_len, N))\r\n", | |
"        )\r\n", | |
"        embed_trg = self.dropout(\r\n", | |
"            self.trg_word_embedding(trg) + self.trg_position_embedding(self._positions(trg_seq_len, N))\r\n", | |
"        )\r\n", | |
"\r\n", | |
"        src_padding_mask = self.make_src_mask(src)\r\n", | |
"        # Causal mask: target position i may only attend to positions <= i.\r\n", | |
"        trg_mask = self.transformer.generate_square_subsequent_mask(trg_seq_len).to(self.device)\r\n", | |
"\r\n", | |
"        out = self.transformer(\r\n", | |
"            embed_src,\r\n", | |
"            embed_trg,\r\n", | |
"            src_key_padding_mask=src_padding_mask,\r\n", | |
"            tgt_mask=trg_mask,\r\n", | |
"            # Also hide source padding from the decoder's cross-attention over encoder memory.\r\n", | |
"            memory_key_padding_mask=src_padding_mask,\r\n", | |
"        )\r\n", | |
"        return self.fc_out(out)" | |
], | |
"execution_count": 33, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "pISz5k5V-0lB" | |
}, | |
"source": [ | |
"def save_checkpoint(state, filename=\"my_checkpoint.pth.tar\"):\r\n", | |
"    \"\"\"Write `state` (e.g. a dict of model/optimizer state dicts) to `filename` via torch.save.\"\"\"\r\n", | |
"    print(\"---->Saving checkpoint\")\r\n", | |
"    torch.save(state, filename)\r\n", | |
"\r\n", | |
"\r\n", | |
"def load_checkpoint(state, model, optimizer):\r\n", | |
"    \"\"\"Restore `model` and `optimizer` from `state`, a dict with keys \"state_dict\" and \"optimizer\".\"\"\"\r\n", | |
"    print(\"----->Loading checkpoint\")\r\n", | |
"    model.load_state_dict(state[\"state_dict\"])\r\n", | |
"    optimizer.load_state_dict(state[\"optimizer\"])\r\n", | |
"\r\n", | |
"\r\n", | |
"def translate_sentence(model, sentence, german, english, device, max_length=50):\r\n", | |
"    \"\"\"Greedily decode a German sentence into a list of English tokens.\r\n", | |
"\r\n", | |
"    Args:\r\n", | |
"        model: sequence-first model whose output is indexed as (trg_len, 1, vocab) logits.\r\n", | |
"        sentence: raw German string, or an already-tokenized list of token strings.\r\n", | |
"        german, english: source / target torchtext Fields with built vocabularies.\r\n", | |
"        device: device for the input tensors.\r\n", | |
"        max_length: maximum number of tokens to generate.\r\n", | |
"\r\n", | |
"    Returns:\r\n", | |
"        Predicted tokens without the leading <sos>; the last one is <eos> if it was produced.\r\n", | |
"    \"\"\"\r\n", | |
"    # Reuse the module-level German tokenizer (same one used for training) instead of\r\n", | |
"    # reloading the spaCy model on every call; `bleu` calls this once per example.\r\n", | |
"    if isinstance(sentence, str):\r\n", | |
"        tokens = [token.text.lower() for token in spacy_de.tokenizer(sentence)]\r\n", | |
"    else:\r\n", | |
"        tokens = [token.lower() for token in sentence]\r\n", | |
"\r\n", | |
"    # Wrap with <sos>/<eos> the same way the Field does during training.\r\n", | |
"    tokens = [german.init_token] + tokens + [german.eos_token]\r\n", | |
"    src_indices = [german.vocab.stoi[tok] for tok in tokens]\r\n", | |
"    # (src_len, 1): a sequence-first batch of one sentence.\r\n", | |
"    sentence_tensor = torch.LongTensor(src_indices).unsqueeze(1).to(device)\r\n", | |
"\r\n", | |
"    eos_idx = english.vocab.stoi[english.eos_token]\r\n", | |
"    outputs = [english.vocab.stoi[english.init_token]]\r\n", | |
"    for _ in range(max_length):\r\n", | |
"        trg_tensor = torch.LongTensor(outputs).unsqueeze(1).to(device)\r\n", | |
"        with torch.no_grad():\r\n", | |
"            output = model(sentence_tensor, trg_tensor)\r\n", | |
"        # Greedy choice at the last decoded position.\r\n", | |
"        best_pred = output.argmax(2)[-1, :].item()\r\n", | |
"        outputs.append(best_pred)\r\n", | |
"        if best_pred == eos_idx:\r\n", | |
"            break\r\n", | |
"\r\n", | |
"    translated_sentence = [english.vocab.itos[idx] for idx in outputs]\r\n", | |
"    return translated_sentence[1:]  # drop the leading <sos>\r\n", | |
"\r\n", | |
"\r\n", | |
"def bleu(data, model, german, english, device):\r\n", | |
"    \"\"\"Return torchtext's corpus BLEU score for greedy translations of the examples in `data`.\r\n", | |
"\r\n", | |
"    Each example must expose tokenized `src` (German) and `trg` (English) attributes.\r\n", | |
"    \"\"\"\r\n", | |
"    targets = []\r\n", | |
"    outputs = []\r\n", | |
"    for example in data:\r\n", | |
"        src = vars(example)[\"src\"]\r\n", | |
"        trg = vars(example)[\"trg\"]\r\n", | |
"\r\n", | |
"        prediction = translate_sentence(model, src, german, english, device)\r\n", | |
"        # Strip <eos> only if decoding produced it; a max_length cutoff ends on a real token.\r\n", | |
"        if prediction and prediction[-1] == english.eos_token:\r\n", | |
"            prediction = prediction[:-1]\r\n", | |
"\r\n", | |
"        targets.append([trg])  # bleu_score takes a list of reference translations per candidate\r\n", | |
"        outputs.append(prediction)\r\n", | |
"\r\n", | |
"    return bleu_score(outputs, targets)\r\n", | |
"\r\n", | |
"\r\n", | |
"# Backward-compatible alias for the original (misspelled) function name.\r\n", | |
"blue = bleu" | |
], | |
"execution_count": 34, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "H4mJoG1MyQ8e" | |
}, | |
"source": [ | |
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\r\n", | |
"\r\n", | |
"load_model = False\r\n", | |
"save_model = True\r\n", | |
"\r\n", | |
"# Training hyperparameters\r\n", | |
"n_epochs = 10000\r\n", | |
"learning_rate = 3e-4\r\n", | |
"batch_size = 32\r\n", | |
"\r\n", | |
"# Model hyperparameters\r\n", | |
"src_vocab_size = len(german.vocab)\r\n", | |
"trg_vocab_size = len(english.vocab)\r\n", | |
"embedding_size = 512\r\n", | |
"num_heads = 8\r\n", | |
"num_encoder_layers = 3\r\n", | |
"num_decoder_layers = 3\r\n", | |
"dropout = 0.10\r\n", | |
"max_len = 100  # number of learned positions; longer sequences cannot be embedded\r\n", | |
"forward_expansion = 4\r\n", | |
"# The padding mask is built from the source (German) batch, so take the index from its vocab.\r\n", | |
"src_pad_idx = german.vocab.stoi[\"<pad>\"]" | |
], | |
"execution_count": 35, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "xmqgm0DWzq0f" | |
}, | |
"source": [ | |
"# Group examples of similar source length into the same batch to reduce padding.\r\n", | |
"train_iterator, valid_iterator, test_iterator = BucketIterator.splits(\r\n", | |
"    datasets=(train_data, valid_data, test_data),\r\n", | |
"    batch_size=batch_size,\r\n", | |
"    sort_within_batch=True,\r\n", | |
"    sort_key=lambda example: len(example.src),\r\n", | |
"    device=device,\r\n", | |
")" | |
], | |
"execution_count": 36, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 340 | |
}, | |
"id": "ZTkPS6HS0OJz", | |
"outputId": "75628736-3d92-4a9a-ae49-454a8b94f6fb" | |
}, | |
"source": [ | |
"# Keyword arguments guard against silently swapping any of the many same-typed hyperparameters.\r\n", | |
"model = Transformer(\r\n", | |
"    embedding_size=embedding_size,\r\n", | |
"    src_vocab_size=src_vocab_size,\r\n", | |
"    trg_vocab_size=trg_vocab_size,\r\n", | |
"    src_pad_idx=src_pad_idx,\r\n", | |
"    num_heads=num_heads,\r\n", | |
"    num_encoder_layers=num_encoder_layers,\r\n", | |
"    num_decoder_layers=num_decoder_layers,\r\n", | |
"    forward_expansion=forward_expansion,\r\n", | |
"    dropout=dropout,\r\n", | |
"    max_len=max_len,\r\n", | |
"    device=device,\r\n", | |
").to(device)" | |
], | |
"execution_count": 37, | |
"outputs": [ | |
{ | |
"output_type": "error", | |
"ename": "RuntimeError", | |
"evalue": "ignored", | |
"traceback": [ | |
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", | |
"\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", | |
"\u001b[0;32m<ipython-input-37-a8b2a1a1c17a>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 11\u001b[0m \u001b[0mmax_len\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 12\u001b[0m \u001b[0mdevice\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 13\u001b[0;31m ).to(device)\n\u001b[0m", | |
"\u001b[0;32m/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36mto\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 610\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdtype\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mis_floating_point\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnon_blocking\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 611\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 612\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_apply\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mconvert\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 613\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 614\u001b[0m def register_backward_hook(\n", | |
"\u001b[0;32m/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_apply\u001b[0;34m(self, fn)\u001b[0m\n\u001b[1;32m 357\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_apply\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 358\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mmodule\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mchildren\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 359\u001b[0;31m \u001b[0mmodule\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_apply\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfn\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 360\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 361\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mcompute_should_use_set_data\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtensor\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtensor_applied\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_apply\u001b[0;34m(self, fn)\u001b[0m\n\u001b[1;32m 379\u001b[0m \u001b[0;31m# `with torch.no_grad():`\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 380\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mno_grad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 381\u001b[0;31m \u001b[0mparam_applied\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mparam\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 382\u001b[0m \u001b[0mshould_use_set_data\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcompute_should_use_set_data\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mparam\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparam_applied\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 383\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mshould_use_set_data\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36mconvert\u001b[0;34m(t)\u001b[0m\n\u001b[1;32m 608\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mconvert_to_format\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdim\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m4\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 609\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdtype\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mis_floating_point\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnon_blocking\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmemory_format\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mconvert_to_format\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 610\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdtype\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mis_floating_point\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnon_blocking\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 611\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 612\u001b[0m \u001b[0;32mreturn\u001b[0m 
\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_apply\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mconvert\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;31mRuntimeError\u001b[0m: CUDA error: device-side assert triggered" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "WvpUgcRz0ekq" | |
}, | |
"source": [ | |
"def count_parameters(model):\r\n", | |
"    \"\"\"Return the total number of elements in `model`'s parameters that require gradients.\"\"\"\r\n", | |
"    trainable = (param for param in model.parameters() if param.requires_grad)\r\n", | |
"    return sum(param.numel() for param in trainable)\r\n", | |
"\r\n", | |
"\r\n", | |
"print(f\"The model has {count_parameters(model):,} trainable parameters\")" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "QvoKGVza0p7d" | |
}, | |
"source": [ | |
"# Pass the configured learning rate; without it AdamW silently uses its default (1e-3).\r\n", | |
"optimizer = optim.AdamW(model.parameters(), lr=learning_rate)\r\n", | |
"\r\n", | |
"# Divide the LR by 10 when the stepped metric has not improved for 10 consecutive steps.\r\n", | |
"schdeuler = optim.lr_scheduler.ReduceLROnPlateau(\r\n", | |
"    optimizer, factor=0.1, patience=10, verbose=True\r\n", | |
")\r\n", | |
"# Target-side (English) padding positions are excluded from the loss.\r\n", | |
"pad_idx = english.vocab.stoi[\"<pad>\"]\r\n", | |
"criterion = nn.CrossEntropyLoss(ignore_index=pad_idx)" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "27vttEJ--r2t" | |
}, | |
"source": [ | |
"if load_model:\r\n", | |
"    load_checkpoint(torch.load(\"my_checkpoint.pth.tar\", map_location=device), model, optimizer)\r\n", | |
"\r\n", | |
"sentence = \"ein pferd geht unter einer brucke neben einem boot\"\r\n", | |
"\r\n", | |
"for epoch in range(n_epochs):\r\n", | |
"    print(f\"[Epoch {epoch} / {n_epochs}]\")\r\n", | |
"\r\n", | |
"    if save_model:\r\n", | |
"        checkpoint = {\r\n", | |
"            \"state_dict\": model.state_dict(),\r\n", | |
"            \"optimizer\": optimizer.state_dict(),\r\n", | |
"        }\r\n", | |
"        # Explicit path so the save matches the path `load_model` reads above.\r\n", | |
"        save_checkpoint(checkpoint, \"my_checkpoint.pth.tar\")\r\n", | |
"\r\n", | |
"    # Track progress on a fixed example sentence.\r\n", | |
"    model.eval()\r\n", | |
"    translated_sentence = translate_sentence(\r\n", | |
"        model, sentence, german, english, device, max_length=50\r\n", | |
"    )\r\n", | |
"    print(f\"Translated example sentence:\\n {translated_sentence}\")\r\n", | |
"    model.train()\r\n", | |
"    losses = []\r\n", | |
"\r\n", | |
"    # Name the loop variable `batch`: reusing `data` shadowed the torchtext `data` module.\r\n", | |
"    for batch in train_iterator:\r\n", | |
"        inp = batch.src.to(device)\r\n", | |
"        target = batch.trg.to(device)\r\n", | |
"\r\n", | |
"        # Teacher forcing: feed target tokens [0, T-1) and predict tokens [1, T).\r\n", | |
"        output = model(inp, target[:-1])\r\n", | |
"\r\n", | |
"        # output is (trg_len - 1, N, vocab); flatten to (-1, vocab) for CrossEntropyLoss.\r\n", | |
"        output = output.reshape(-1, output.shape[2])\r\n", | |
"        target = target[1:].reshape(-1)\r\n", | |
"\r\n", | |
"        # Must be called; a bare `optimizer.zero_grad` let gradients accumulate across steps.\r\n", | |
"        optimizer.zero_grad()\r\n", | |
"        loss = criterion(output, target)\r\n", | |
"        losses.append(loss.item())\r\n", | |
"        loss.backward()\r\n", | |
"        # Clip to avoid exploding gradients.\r\n", | |
"        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1)\r\n", | |
"        optimizer.step()\r\n", | |
"\r\n", | |
"    mean_loss = sum(losses) / len(losses)\r\n", | |
"    schdeuler.step(mean_loss)\r\n", | |
"\r\n", | |
"score = bleu(test_data[1:100], model, german, english, device)\r\n", | |
"print(f\"Bleu score {score*100:.2f}\")" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "4Dzu7EtTCbbh" | |
}, | |
"source": [ | |
"" | |
], | |
"execution_count": null, | |
"outputs": [] | |
} | |
] | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment