Last active: September 23, 2022 09:43
text_processing_lstm.ipynb
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/Createdd/69da98fe885d034fc459f62922d5ba72/text_processing_lstm.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"source": [
"# Disclaimer\n",
"\n",
"I wrote this during my studies in the AI Master's program at JKU Linz in Austria.\n",
"\n",
"This was one of the exercises we had to work on. Due to copyright, I rewrote the instructions. Nevertheless, I want to give credit to the institute for coming up with the idea for this assignment. Feel free to check out this [study program](https://www.jku.at/en/degree-programs/types-of-degree-programs/masters-degree-programs/ma-artificial-intelligence/) if you are interested. I think it is an amazing program.\n",
"\n",
"I am not associated with the institute, and this notebook does not reflect the quality of the study program. It covers only parts of the original exercise. These are my own solutions and not necessarily the best possible ones.\n",
"\n",
"\n",
"---\n",
"\n",
"\n",
"Furthermore, I will not provide the text dataset that was used. However, transcripts of Trump speeches are easy to find.\n",
"\n",
"\n",
"--- \n"
],
"metadata": {
"id": "b83dOFUNDxWv"
},
"id": "b83dOFUNDxWv"
},
{
"cell_type": "markdown",
"source": [
"# Introduction -> Text processing with LSTM and PyTorch\n",
"\n",
"The goal is to train an LSTM model to generate text.\n",
"\n",
"To do so, we need to represent text in a form the network can learn from. Here we use character embeddings. We need to\n",
"- define an alphabet (a set of characters)\n",
"- represent each character by its position (index) in the alphabet\n",
"- let the neural network learn a weight vector (embedding) for each character position\n",
"\n",
"\n"
],
"metadata": {
"id": "eiTxHpMT_1Gm"
},
"id": "eiTxHpMT_1Gm"
},
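The last bullet point can be made concrete. Looking up a character's embedding gives the same result as multiplying its one-hot vector by a weight matrix, and that matrix is what `nn.Embedding` stores and learns. Below is a minimal pure-Python sketch. It assumes made-up weights and a tiny embedding size of 3, whereas in the notebook the weights are learned and stored in a PyTorch tensor.

```python
# Sketch: an embedding lookup equals (one-hot vector) @ (weight matrix).
# ALPHABET matches ALL_CHARS_NUMS; the weights are made-up illustration values.
ALPHABET = 'abcdefghijklmnopqrstuvwxyz0123456789 .!?'

# Hypothetical weight matrix: one row of 3 values per character.
weights = [[0.1 * i, -0.05 * i, 1.0] for i in range(len(ALPHABET))]


def one_hot(index, size):
    """Vector of zeros with a single 1.0 at `index`."""
    vec = [0.0] * size
    vec[index] = 1.0
    return vec


def row_vector_times_matrix(vec, matrix):
    """Compute vec @ matrix, with matrix given as a list of rows."""
    n_cols = len(matrix[0])
    return [sum(vec[r] * matrix[r][c] for r in range(len(matrix))) for c in range(n_cols)]


def embed(char):
    """Embedding lookup: return the row of `weights` at the character's index."""
    return weights[ALPHABET.index(char)]


idx = ALPHABET.index('c')  # position 2 in the alphabet
via_one_hot = row_vector_times_matrix(one_hot(idx, len(ALPHABET)), weights)
via_lookup = embed('c')
print(via_one_hot == via_lookup)  # True: both select row 2
```

The lookup skips the multiplication with all the zero entries. This is why an embedding layer is preferred over feeding one-hot vectors into a linear layer.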
{
"cell_type": "markdown",
"source": [
"## 1. Encode Characters\n",
"\n",
"Encode characters by transforming them into index tensors, and decode tensors back into strings."
],
"metadata": {
"id": "lYeW5oLz_x2W"
},
"id": "lYeW5oLz_x2W"
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "17f16ae5",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "e858a461-029f-435c-ed51-c186f50bfef8"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"cuda:0 True\n"
]
}
],
"source": [
"import re\n",
"import torch\n",
"\n",
"ALL_CHARS_NUMS = 'abcdefghijklmnopqrstuvwxyz0123456789 .!?'\n",
"\n",
"# Check whether a GPU is available, otherwise fall back to the CPU\n",
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
"print(device, torch.cuda.is_available())\n",
"\n",
"class Encoder:\n",
"    def __init__(self, all_chars):\n",
"        # character -> index (used for encoding)\n",
"        self.encode_map = {char: i for i, char in enumerate(all_chars)}\n",
"        # index -> character (used for decoding)\n",
"        self.decode_map = {i: char for i, char in enumerate(all_chars)}\n",
"\n",
"    def __call__(self, input):\n",
"        # If the input is a string, encode it. Characters that are not in the alphabet\n",
"        # are removed (the alternative would be to map them to a character of the alphabet).\n",
"        if isinstance(input, str):\n",
"            encoded = []\n",
"            for char in input:\n",
"                if char in self.encode_map:\n",
"                    encoded.append(self.encode_map[char])\n",
"\n",
"            # use dtype long because nn.Embedding and CrossEntropyLoss expect integer (long) indices\n",
"            encode_res = torch.tensor(encoded, dtype=torch.long).to(device)\n",
"\n",
"            return encode_res\n",
"\n",
"        # If the argument is a torch.Tensor, return its string representation,\n",
"        # i.e. the method works as a decoder.\n",
"        else:\n",
"            string_rep = ''\n",
"            for i in input:\n",
"                char = self.decode_map[i.item()]\n",
"                string_rep += char\n",
"\n",
"            return string_rep"
],
"id": "17f16ae5"
},
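The comment in `__call__` names two ways to handle characters outside the alphabet. They can be removed, which is what `Encoder` does, or mapped to a similar character. Below is a minimal pure-Python sketch of both options, without tensors. Two assumptions to note: the `FALLBACK` table is a hypothetical example that is not part of the exercise, and unlike `Encoder` this sketch lowercases its input itself.

```python
ALPHABET = 'abcdefghijklmnopqrstuvwxyz0123456789 .!?'
CHAR_TO_IDX = {char: i for i, char in enumerate(ALPHABET)}
IDX_TO_CHAR = {i: char for i, char in enumerate(ALPHABET)}

# Hypothetical fallback table: characters outside the alphabet that are
# mapped to a similar in-alphabet character instead of being dropped.
FALLBACK = {',': ' ', ';': '.', ':': '.', '\n': ' ', '-': ' '}


def encode(text, fallback=None):
    """Lowercase `text` and return a list of alphabet indices.

    Characters not in the alphabet are replaced via `fallback` when possible,
    otherwise they are dropped (the behaviour of the notebook's Encoder).
    """
    fallback = fallback or {}
    indices = []
    for char in text.lower():
        char = fallback.get(char, char)
        if char in CHAR_TO_IDX:
            indices.append(CHAR_TO_IDX[char])
    return indices


def decode(indices):
    """Map a list of indices back to a string."""
    return ''.join(IDX_TO_CHAR[i] for i in indices)


print(encode('ab22a'))                            # [0, 1, 28, 28, 0]
print(decode(encode('Hello, world!')))            # hello world!  (comma dropped)
print(decode(encode('Hello, world!', FALLBACK)))  # hello  world!  (comma -> space)
```

Dropping characters keeps the alphabet small, but it can merge words, for example across line breaks. A fallback table avoids that at the cost of some manual curation.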
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "wTQG93vJ4xs7",
"outputId": "cb32c65c-af3c-4653-ad7d-68d6444f4bd5"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Encode ab22a = tensor([ 0, 1, 28, 28, 0], device='cuda:0')\n",
"Decode tensor([ 0, 1, 28, 28, 0], device='cuda:0') = ab22a\n"
]
}
],
"source": [
"# test encoder class\n",
"encode = Encoder(ALL_CHARS_NUMS)\n",
"sample = 'ab22a'\n",
"test_encode = encode(sample)\n",
"test_decode = encode(test_encode)\n",
"\n",
"print(f'Encode {sample} = ', test_encode)\n",
"print(f'Decode {test_encode} = ', test_decode)\n",
"\n",
"if test_decode != sample:\n",
"    raise AssertionError('Encoder not working')"
],
"id": "wTQG93vJ4xs7"
},
{
"cell_type": "markdown",
"source": [
"## 2. Create Dataset from text file\n",
"\n",
"We want to load a text file, treat it as one long sequence of characters, split it into sequences of a fixed length `l`, and return each sequence encoded as a tensor of character indices."
],
"metadata": {
"id": "NcsbGmhsEKJ-"
},
"id": "NcsbGmhsEKJ-"
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "9df917ce"
},
"outputs": [],
"source": [
"from torch.utils.data import Dataset\n",
"\n",
"class TextDataset(Dataset):\n",
"    def __init__(self, path, l):\n",
"        # load text data from the specified path (lowercased, as the alphabet has no uppercase letters)\n",
"        self.l = l\n",
"        with open(path, \"r\", encoding=\"utf-8\") as f:\n",
"            data_str = f.read().lower()\n",
"\n",
"        # encode the text with the previously defined Encoder\n",
"        encoder = Encoder(ALL_CHARS_NUMS)\n",
"        data = encoder(data_str).to(device)\n",
"\n",
"        # split into sequences of the pre-defined length l\n",
"        # torch.split splits the tensor into chunks; each chunk is a view of the original tensor.\n",
"        data = torch.split(data, l)\n",
"\n",
"        # stack the chunks into one tensor of shape (num_sequences, l);\n",
"        # torch.stack concatenates a sequence of tensors along a new dimension.\n",
"        # The last chunk is dropped if it is shorter than l.\n",
"        self.data = torch.stack(data[:-1]) if len(data[-1]) < l else torch.stack(data)\n",
"\n",
"    def __len__(self):\n",
"        return len(self.data)\n",
"\n",
"    def __getitem__(self, i):\n",
"        return self.data[i]"
],
"id": "9df917ce"
},
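The combination of `torch.split` and dropping a short last chunk can be mirrored in plain Python, which makes the dataset size easy to reason about. A text with N encoded characters yields N // l sequences. Below is a sketch, where `split_into_sequences` is a hypothetical helper name.

```python
def split_into_sequences(indices, l):
    """Split a flat list of character indices into chunks of length l.

    Mirrors torch.split(data, l) followed by dropping the last chunk
    when it is shorter than l, so that all chunks can be stacked.
    """
    chunks = [indices[start:start + l] for start in range(0, len(indices), l)]
    if chunks and len(chunks[-1]) < l:
        chunks = chunks[:-1]
    return chunks


print(split_into_sequences(list(range(10)), 3))  # [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
print(split_into_sequences(list(range(9)), 3))   # [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
```

With `l=100`, the dataset length of 8422 printed in the next cell therefore corresponds to between 842,200 and 842,299 encoded characters in the training file, with any remainder dropped.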
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "ChxpwLlf68XD",
"outputId": "0ac66d84-c167-4892-d517-bda07f8cad98"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Encoded data sample:\n",
" tensor([18, 15, 4, 4, 2, 7, 36, 27, 37, 37, 37, 19, 7, 0, 13, 10, 36, 24,\n",
" 14, 20, 36, 18, 14, 36, 12, 20, 2, 7, 37, 36, 36, 19, 7, 0, 19, 18,\n",
" 36, 18, 14, 36, 13, 8, 2, 4, 37, 36, 36, 8, 18, 13, 19, 36, 7, 4,\n",
" 36, 0, 36, 6, 17, 4, 0, 19, 36, 6, 20, 24, 37, 36, 36, 7, 4, 36,\n",
" 3, 14, 4, 18, 13, 19, 36, 6, 4, 19, 36, 0, 36, 5, 0, 8, 17, 36,\n",
" 15, 17, 4, 18, 18, 36, 7, 4, 36, 3], device='cuda:0')\n",
"Sample text: speech 1...thank you so much. thats so nice. isnt he a great guy. he doesnt get a fair press he d\n",
"length of text 100\n",
"length of encoded 8422\n"
]
}
],
"source": [
"# test encoded dataset\n",
"text_data_encoded = TextDataset('data/trump_train.txt', l=100)\n",
"print('Encoded data sample:\\n', text_data_encoded[0])\n",
"print(f'Sample text: {encode(text_data_encoded[0])}')\n",
"print(f'length of text {len(encode(text_data_encoded[0]))}')\n",
"# number of sequences of length l in the dataset\n",
"print(f'length of encoded {len(text_data_encoded)}')"
],
"id": "ChxpwLlf68XD"
},
{
"cell_type": "markdown",
"source": [
"## 3. LSTM Model \n",
"\n",
"We create a module that consists of \n",
"- an embedding layer that maps the characters of the alphabet to embeddings\n",
"- an LSTM layer that maps the embeddings to hidden states\n",
"- a linear layer that maps the hidden states back to (logits over) the alphabet\n",
"\n",
"In the forward pass, the input sequence of character indices is mapped to logits over the alphabet for every time step."
],
"metadata": {
"id": "eYM_Wr6Vx8OU"
},
"id": "eYM_Wr6Vx8OU"
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "459fe907"
},
"outputs": [],
"source": [
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"\n",
"class NextCharLSTM(nn.Module):\n",
"    def __init__(self, alphabet_size, embedding_dim, hidden_dim):\n",
"        super(NextCharLSTM, self).__init__()\n",
"\n",
"        # define layers as instructed\n",
"        # nn.Embedding is a simple lookup table that stores embeddings of a fixed dictionary and size.\n",
"        # It is often used to store word embeddings and retrieve them using indices: the input is a tensor of indices, and the output is the corresponding embeddings.\n",
"        self.embeddings = nn.Embedding(alphabet_size, embedding_dim)\n",
"\n",
"        # By default, the PyTorch LSTM module expects the batch dimension at index 1 (seq_len, batch, features).\n",
"        # Our batches have shape (batch, seq_len), which is why batch_first=True is set.\n",
"        # nn.LSTM applies a (multi-layer) long short-term memory RNN to an input sequence.\n",
"        self.lstm = nn.LSTM(embedding_dim, hidden_dim, batch_first=True)\n",
"\n",
"        # linear transformation that maps hidden states to logits over the alphabet\n",
"        self.linear = nn.Linear(hidden_dim, alphabet_size)\n",
"\n",
"        # move the layers to the GPU if available\n",
"        self.embeddings.to(device)\n",
"        self.lstm.to(device)\n",
"        self.linear.to(device)\n",
"\n",
"    def forward(self, inputs):\n",
"        # inputs: (batch, seq_len) -> embeds: (batch, seq_len, embedding_dim)\n",
"        embeds = self.embeddings(inputs)\n",
"\n",
"        # nn.LSTM returns output, (h_n, c_n); output holds the hidden state of every time step\n",
"        output, hidden = self.lstm(embeds)\n",
"\n",
"        # apply the linear layer to the hidden states of all time steps, not only the last one\n",
"        # logits: (batch, seq_len, alphabet_size)\n",
"        logits = self.linear(output)\n",
"        return logits"
],
"id": "459fe907"
},
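At every time step, `nn.LSTM` updates its cell state c and hidden state h through four gates: input (i), forget (f), candidate (g) and output (o). Below is a single step of that update as a pure-Python sketch. The PyTorch documentation gives each gate two bias vectors, one for the input and one for the hidden state; this sketch folds them into a single bias per gate, which is mathematically equivalent. The toy weights are all zero and have nothing to do with the trained model.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def matvec(W, v):
    """Matrix-vector product with W given as a list of rows."""
    return [sum(w * x for w, x in zip(row, v)) for row in W]


def lstm_step(x, h, c, params):
    """One LSTM time step using the gate equations documented for torch.nn.LSTM.

    params maps each gate name ('i', 'f', 'g', 'o') to (W_x, W_h, b):
    input weights, recurrent weights and one combined bias vector.
    """
    pre = {}
    for gate, (W_x, W_h, b) in params.items():
        pre[gate] = [a + r + bias for a, r, bias in zip(matvec(W_x, x), matvec(W_h, h), b)]
    i = [sigmoid(z) for z in pre['i']]    # input gate
    f = [sigmoid(z) for z in pre['f']]    # forget gate
    g = [math.tanh(z) for z in pre['g']]  # candidate cell values
    o = [sigmoid(z) for z in pre['o']]    # output gate
    c_new = [fj * cj + ij * gj for fj, cj, ij, gj in zip(f, c, i, g)]
    h_new = [oj * math.tanh(cj) for oj, cj in zip(o, c_new)]
    return h_new, c_new


# Toy example: input size 2, hidden size 1, all weights and biases zero,
# so every sigmoid gate is 0.5 and the candidate g is tanh(0) = 0.
zero_gate = ([[0.0, 0.0]], [[0.0]], [0.0])
params = {gate: zero_gate for gate in 'ifgo'}
h, c = lstm_step([1.0, -1.0], [0.0], [1.0], params)
print(c)  # [0.5]: half of the previous cell state is kept
print(h)  # 0.5 * tanh(0.5)
```

With `batch_first=True`, `nn.LSTM` repeats this step along dimension 1 of a `(batch, seq_len, embedding_dim)` input and returns h for every step. `self.linear` then maps each of these hidden states to logits.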
{
"cell_type": "markdown",
"source": [
"## 4. Epoch function\n",
"\n",
"We write an epoch function that\n",
"- validates the model if no optimizer is given\n",
"- trains the model if an optimizer is given (many-to-many setting by default, many-to-one if `many_to_one=True`)\n",
"- performs one training/validation step per mini-batch\n"
],
"metadata": {
"id": "lpoDSoyl0B7R"
},
"id": "lpoDSoyl0B7R"
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "13f33250"
},
"outputs": [],
"source": [
"from torch.utils.data import DataLoader\n",
"import numpy as np\n",
"\n",
"\n",
"# define an epoch function that takes a data loader, the LSTM model, an optimizer and\n",
"# a flag whether to use the many-to-one setting (needed for the bonus task further down)\n",
"\n",
"def epoch(data_loader, lstm_model, optimizer, many_to_one=False):\n",
"    # define the loss function and the list of batch losses\n",
"    loss_function = torch.nn.CrossEntropyLoss()\n",
"    loss_function.to(device)\n",
"    batch_losses = []\n",
"\n",
"    # if an optimizer is given, we are in training mode, otherwise in validation mode (as instructed)\n",
"    # model.train() sets the modules in the network to training mode.\n",
"    # Layers such as dropout or batch normalization behave differently during training and evaluation;\n",
"    # model.eval() switches them to evaluation behaviour.\n",
"    train_mode = optimizer is not None\n",
"    if train_mode:\n",
"        lstm_model.train()\n",
"    else:\n",
"        lstm_model.eval()\n",
"\n",
"    for i, batch in enumerate(data_loader):\n",
"        # PyTorch accumulates gradients on subsequent backward passes, so for every mini-batch we\n",
"        # explicitly reset the gradients to zero before backpropagation (i.e., before updating the weights and biases).\n",
"        if train_mode:\n",
"            optimizer.zero_grad()\n",
"\n",
"        with torch.set_grad_enabled(train_mode):\n",
"            # input: all characters of each sequence except the last one\n",
"            logits = lstm_model(batch[:,:-1])\n",
"\n",
"            # CrossEntropyLoss expects the class dimension at index 1, so the logits are transposed\n",
"            # from (batch, seq_len, classes) to (batch, classes, seq_len).\n",
"            # many-to-one: only the prediction of the last time step and the last character are used.\n",
"            # many-to-many: every position predicts the next character (targets shifted by one).\n",
"            if many_to_one:\n",
"                x = torch.transpose(logits, 1, 2)[:,:,-1]\n",
"                y = batch[:,-1]\n",
"            else:\n",
"                x = torch.transpose(logits, 1, 2)\n",
"                y = batch[:,1:]\n",
"\n",
"            loss = loss_function(x, y)\n",
"\n",
"        batch_losses.append(loss.item())\n",
"\n",
"        # compute gradients and update the parameters when training\n",
"        # optimizer.step() performs a single optimization step (parameter update).\n",
"        if train_mode:\n",
"            loss.backward()\n",
"            optimizer.step()\n",
"\n",
"    all_losses = np.array(batch_losses)\n",
"    return all_losses"
],
"id": "13f33250"
},
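The slicing in `epoch` sets up next-character prediction. The inputs are `seq[:-1]`, and the targets are either `seq[1:]` (many-to-many) or only `seq[-1]` (many-to-one). `nn.CrossEntropyLoss` applies log-softmax followed by negative log-likelihood to each prediction and, by default, averages the result. Below is a pure-Python sketch of both the slicing and the per-position loss. The helper names are hypothetical.

```python
import math


def next_char_targets(seq, many_to_one=False):
    """Build (inputs, targets) the way epoch() slices one row of a batch."""
    inputs = seq[:-1]
    targets = seq[-1] if many_to_one else seq[1:]
    return inputs, targets


def cross_entropy(logits, target):
    """-log(softmax(logits)[target]), computed with the log-sum-exp trick."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]


seq = [7, 4, 11, 11, 14]  # 'hello' encoded with ALL_CHARS_NUMS
print(next_char_targets(seq))                    # ([7, 4, 11, 11], [4, 11, 11, 14])
print(next_char_targets(seq, many_to_one=True))  # ([7, 4, 11, 11], 14)

# With uniform logits over the 40-character alphabet, the loss is ln(40).
print(cross_entropy([0.0] * 40, target=3))
```

As a reference point for the logged losses, a model that predicts all 40 characters uniformly has a loss of ln 40, about 3.69 nats per character.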
{
"cell_type": "markdown",
"source": [
"## 5. Train model and visualize results\n",
"\n",
"Putting everything together with a pre-defined set of hyperparameters.\n",
"We validate on a separate dataset to check whether the learned model generalizes."
],
"metadata": {
"id": "t_N6JpfG5SYg"
},
"id": "t_N6JpfG5SYg"
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "AQxsNd7WeYsy"
},
"outputs": [],
"source": [
"# To show the results of the learning, we use the generate_text function (developed further down in the notebook).\n",
"# It shows whether the training improves not only the loss but also the generated text.\n",
"\n",
"####################################################\n",
"# Copied from Exercise 8\n",
"####################################################\n",
"from torch.distributions import Categorical\n",
"\n",
"def generate_text(seed_text, encoder, lstm_model, text_length, top_k_characters=1):\n",
"    # set up the model for evaluation\n",
"    lstm_model.eval()\n",
"    result = encoder(seed_text.lower())\n",
"\n",
"    # disable gradient computations\n",
"    with torch.no_grad():\n",
"        # generate one character per iteration\n",
"        for i in range(text_length):\n",
"            logits = lstm_model(result.view(1, -1))\n",
"            # softmax turns the logits into probabilities for the top-k selection\n",
"            softmax = torch.nn.functional.softmax(logits, dim=2)\n",
"            topk = torch.topk(softmax, top_k_characters, 2)\n",
"\n",
"            # sample from the top-k probabilities of the last time step\n",
"            # (Categorical normalizes them to sum to 1)\n",
"            categorical = Categorical(topk.values[:,-1:])\n",
"            sample = categorical.sample()\n",
"\n",
"            # append the sampled character index to the sequence\n",
"            result = torch.cat((result, topk.indices[0, -1, sample].view(-1)))\n",
"\n",
"    return encoder(result)\n",
"####################################################"
],
"id": "AQxsNd7WeYsy"
},
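`generate_text` samples each new character from only the k most probable characters, and the notebook calls it with `top_k_characters=4`. Below is the same idea as a pure-Python sketch. Note that `random.choices` is used here instead of `torch.distributions.Categorical`, and the logits are made up.

```python
import math
import random


def softmax(logits):
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def sample_top_k(logits, k, rng=random):
    """Sample an index from the k most probable entries of softmax(logits).

    The k probabilities are renormalized, as Categorical does with the
    topk values in generate_text.
    """
    probs = softmax(logits)
    top = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:k]
    weights = [probs[i] for i in top]  # random.choices normalizes the weights itself
    return rng.choices(top, weights=weights, k=1)[0]


logits = [2.0, 1.0, 0.5, -1.0]
rng = random.Random(0)
draws = [sample_top_k(logits, k=2, rng=rng) for _ in range(1000)]
print(sorted(set(draws)))       # only indices 0 and 1 can appear
print(sample_top_k(logits, 1))  # k=1 is greedy decoding: always index 0
```

With k=1, generation is deterministic and tends to repeat itself. A small k > 1 adds variety while excluding unlikely characters.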
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "8987ae83",
"outputId": "88ce0236-1099-441d-ae73-b064d57a420d"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Initially:\n",
"-> america is n.n.bn1yn1n.yln11ny.y.lm1an.nb.nnbn1nnb1yl1n.y..ylnb..19lqmymymlal11y.nnb.y111919ymmabn.lqnyn1.n1ynl \n",
" ----------------------------------------------------------------------------------------------------\n"
]
}
],
"source": [
"import matplotlib.pyplot as plt\n",
"import os\n",
"from tqdm import trange\n",
"\n",
"\n",
"# Set all parameters for training the LSTM network\n",
"sequence_length = 100\n",
"batch_size = 256\n",
"embedding_dim = 8\n",
"hidden_dim = 512\n",
"learning_rate = 1e-3\n",
"num_epochs = 50\n",
"\n",
"start_text = 'America is '\n",
"best_valid_loss = None\n",
"many_to_one = False\n",
"output_file = \"best_model_m2m.pt\"\n",
"encoder = Encoder(ALL_CHARS_NUMS)\n",
"\n",
"train_data = TextDataset(os.path.join('data/', 'trump_train.txt'), l=sequence_length)\n",
"valid_data = TextDataset(os.path.join('data/', 'trump_val.txt'), l=sequence_length)\n",
"train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)\n",
"valid_loader = DataLoader(valid_data, batch_size=batch_size, shuffle=True)\n",
"\n",
"all_chars_nums_length = len(ALL_CHARS_NUMS)\n",
"m2m_model = NextCharLSTM(all_chars_nums_length, embedding_dim, hidden_dim).to(device)\n",
"optim = torch.optim.Adam(params=m2m_model.parameters(), lr=learning_rate)\n",
"\n",
"generated_text = generate_text(start_text, encoder, m2m_model, 100, 4)\n",
"print(f'Initially:\\n-> {generated_text} \\n {\"-\"*100}')"
],
"id": "8987ae83"
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "LfTsh39YqCxG"
},
"outputs": [],
"source": [
"# Define training by model, corresponding loaders, optimizer, number of epochs, output file and many-to-one flag\n",
"def train(model, train_loader, valid_loader, optimizer, num_epochs, output_file, many_to_one):\n",
"    train_losses = []\n",
"    valid_losses = []\n",
"\n",
"    for i in range(num_epochs):\n",
"        # compute the losses in the chosen setting\n",
"        train_loss = epoch(train_loader, model, optimizer, many_to_one)\n",
"        train_losses.append(np.mean(train_loss).item())\n",
"\n",
"        valid_loss = epoch(valid_loader, model, None, many_to_one)\n",
"        valid_losses.append(np.mean(valid_loss).item())\n",
"\n",
"        print(f'epoch {i}\\ntrain loss: {train_losses[-1]}, validation loss: {valid_losses[-1]}\\n')\n",
"\n",
"        # store the best model (lowest validation loss so far) in the pre-defined file;\n",
"        # after the first epoch there is nothing to compare with, so always save\n",
"        if len(valid_losses) == 1 or valid_losses[-1] < min(valid_losses[:-1]):\n",
"            torch.save(model.state_dict(), output_file)\n",
"\n",
"        # every 20th epoch, generate text to see the improvements\n",
"        if i % 20 == 0:\n",
"            generated_text = generate_text(start_text, encoder, model, 100, 4)\n",
"            print(f'Improved text:\\n-> {generated_text} \\n {\"-\"*20}')\n",
"\n",
"    return train_losses, valid_losses"
],
"id": "LfTsh39YqCxG"
},
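The checkpoint condition in `train` saves the model after the first epoch and then whenever the validation loss is strictly lower than every earlier one. Keeping a running minimum is an equivalent way to write this. The sketch below uses a hypothetical helper, `checkpoint_epochs`, to list the epochs at which a checkpoint would be written.

```python
def checkpoint_epochs(valid_losses):
    """Return the epochs at which train() would save the model.

    Equivalent to `len(valid_losses) == 1 or valid_losses[-1] < min(valid_losses[:-1])`
    evaluated after every epoch, but using a running minimum.
    """
    saved = []
    best = float('inf')
    for epoch_idx, loss in enumerate(valid_losses):
        if loss < best:  # strict: ties do not overwrite the checkpoint
            best = loss
            saved.append(epoch_idx)
    return saved


print(checkpoint_epochs([2.0, 1.5, 1.6, 1.4, 1.4]))  # [0, 1, 3]
```

In the training log that follows, the validation loss rises only at epochs 42 and 46. Those two epochs are therefore the only ones not saved, and the stored model comes from epoch 49.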
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Gn1VQ8TVKif6",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "3142fcf1-084e-452a-c2c3-3ce14e6c8fdf"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"epoch 0\n",
"train loss: 3.0722065260916045, validation loss: 2.8938543796539307/n\n",
"Improved text:\n",
"-> america is to te ee et atee ae eta e at taa aaa t e a ee t e to t eetea et etee a a e \n",
" --------------------\n",
"epoch 1\n",
"train loss: 2.847360726558801, validation loss: 2.778059959411621/n\n",
"epoch 2\n",
"train loss: 2.660712401072184, validation loss: 2.6005918979644775/n\n",
"epoch 3\n",
"train loss: 2.496720465746793, validation loss: 2.477595806121826/n\n",
"epoch 4\n",
"train loss: 2.3691409934650767, validation loss: 2.3729653358459473/n\n",
"epoch 5\n",
"train loss: 2.25620529868386, validation loss: 2.266242742538452/n\n",
"epoch 6\n",
"train loss: 2.1481954328941577, validation loss: 2.1730096340179443/n\n",
"epoch 7\n",
"train loss: 2.0482463403181597, validation loss: 2.0945851802825928/n\n",
"epoch 8\n",
"train loss: 1.9594271616502241, validation loss: 2.028244972229004/n\n",
"epoch 9\n",
"train loss: 1.8832954240567756, validation loss: 1.9642455577850342/n\n",
"epoch 10\n",
"train loss: 1.8153246207670732, validation loss: 1.9098560810089111/n\n",
"epoch 11\n",
"train loss: 1.7568744529377331, validation loss: 1.857060194015503/n\n",
"epoch 12\n",
"train loss: 1.7044664765849258, validation loss: 1.8210113048553467/n\n",
"epoch 13\n",
"train loss: 1.6578938888780999, validation loss: 1.7762720584869385/n\n",
"epoch 14\n",
"train loss: 1.6164502555673772, validation loss: 1.7423654794692993/n\n",
"epoch 15\n",
"train loss: 1.5778644699038882, validation loss: 1.7043498754501343/n\n",
"epoch 16\n",
"train loss: 1.5420833826065063, validation loss: 1.6718835830688477/n\n",
"epoch 17\n",
"train loss: 1.5101951974810977, validation loss: 1.6475307941436768/n\n",
"epoch 18\n",
"train loss: 1.480200590509357, validation loss: 1.621835708618164/n\n",
"epoch 19\n",
"train loss: 1.4530761169664788, validation loss: 1.5967477560043335/n\n",
"epoch 20\n",
"train loss: 1.4271232214840976, validation loss: 1.5742368698120117/n\n",
"Improved text:\n",
"-> america is one to git. they say thank you.w. and you know there. tougs. we will not to did that are trodu lige \n",
" --------------------\n",
"epoch 21\n",
"train loss: 1.4043170141451287, validation loss: 1.5603150129318237/n\n",
"epoch 22\n",
"train loss: 1.382313695820895, validation loss: 1.5382187366485596/n\n",
"epoch 23\n",
"train loss: 1.3619441986083984, validation loss: 1.5201952457427979/n\n",
"epoch 24\n",
"train loss: 1.3428273778973203, validation loss: 1.5054054260253906/n\n",
"epoch 25\n",
"train loss: 1.3253240007342715, validation loss: 1.49391508102417/n\n",
"epoch 26\n",
"train loss: 1.3081143877723, validation loss: 1.4762858152389526/n\n",
"epoch 27\n",
"train loss: 1.292138959422256, validation loss: 1.4650746583938599/n\n",
"epoch 28\n",
"train loss: 1.2775824106100835, validation loss: 1.4516977071762085/n\n",
"epoch 29\n",
"train loss: 1.2638496991359827, validation loss: 1.4478967189788818/n\n",
"epoch 30\n",
"train loss: 1.2505281007651128, validation loss: 1.4315167665481567/n\n",
"epoch 31\n",
"train loss: 1.2370703148119377, validation loss: 1.4247608184814453/n\n",
"epoch 32\n",
"train loss: 1.2247026400132612, validation loss: 1.4173210859298706/n\n",
"epoch 33\n",
"train loss: 1.213089599753871, validation loss: 1.4089386463165283/n\n",
"epoch 34\n",
"train loss: 1.2022016734787913, validation loss: 1.3989789485931396/n\n",
"epoch 35\n",
"train loss: 1.1902846350814358, validation loss: 1.3936214447021484/n\n",
"epoch 36\n",
"train loss: 1.1808778155933728, validation loss: 1.3870627880096436/n\n",
"epoch 37\n",
"train loss: 1.1707697853897556, validation loss: 1.380878210067749/n\n",
"epoch 38\n",
"train loss: 1.1600048289154514, validation loss: 1.373460292816162/n\n",
"epoch 39\n",
"train loss: 1.1505673365159468, validation loss: 1.3698058128356934/n\n",
"epoch 40\n",
"train loss: 1.1412737514033462, validation loss: 1.363592267036438/n\n",
"Improved text:\n",
"-> america is an interests. and i watced the poll that he didnt were the reason i cant do their family. im not a c \n",
" --------------------\n",
"epoch 41\n",
"train loss: 1.1321798707499648, validation loss: 1.3571323156356812/n\n",
"epoch 42\n",
"train loss: 1.1239526813680476, validation loss: 1.3577945232391357/n\n",
"epoch 43\n",
"train loss: 1.1159283645225293, validation loss: 1.3480942249298096/n\n",
"epoch 44\n",
"train loss: 1.1068917910257976, validation loss: 1.347411870956421/n\n",
"epoch 45\n",
"train loss: 1.09801472678329, validation loss: 1.3405795097351074/n\n",
"epoch 46\n",
"train loss: 1.0899495283762615, validation loss: 1.3408769369125366/n\n",
"epoch 47\n",
"train loss: 1.081646922862891, validation loss: 1.3338264226913452/n\n",
"epoch 48\n",
"train loss: 1.0735088081070872, validation loss: 1.3323742151260376/n\n",
"epoch 49\n",
"train loss: 1.0657582355268074, validation loss: 1.3278988599777222/n\n"
]
}
],
"source": [
"train_losses, valid_losses = train(m2m_model, train_loader, valid_loader, optim, num_epochs, output_file, many_to_one)"
],
"id": "Gn1VQ8TVKif6"
},
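The logged losses are mean cross-entropies in nats per character, computed with the natural log. Since they are averages of per-batch means, they are approximately per-character averages. Perplexity, exp(loss), is often easier to read. The final validation loss of about 1.328 (epoch 49) corresponds to a perplexity of about 3.77. Roughly speaking, the model is as uncertain as a uniform choice among fewer than four characters, compared with 40 for uniform guessing. Below is a small sketch of the conversion, using hypothetical helper names.

```python
import math


def loss_to_perplexity(mean_ce_nats):
    """Perplexity = exp(mean cross-entropy), with the loss measured in nats."""
    return math.exp(mean_ce_nats)


def loss_to_bits_per_char(mean_ce_nats):
    """Convert nats per character to bits per character."""
    return mean_ce_nats / math.log(2)


# Final validation loss taken from the training log above (epoch 49).
final_valid_loss = 1.3278988599777222
print(round(loss_to_perplexity(final_valid_loss), 2))
print(round(loss_to_bits_per_char(final_valid_loss), 2))

# For comparison: uniform guessing over the 40-character alphabet.
print(round(loss_to_perplexity(math.log(40)), 6))
```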
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 281
},
"id": "f7M5Et3PuEaa",
"outputId": "0810f61a-6c0f-4ea4-af9d-e01df7e56057"
},
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
],
], | |
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXoAAAEICAYAAABRSj9aAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nO3deXhV1dX48e/KPM8hCQkhzBBmiIiiDKKIE86z1morr0Ortg5FX6t9bX1tf761ra0TWlptVVTUOqGCLQLKoEFmkDlAgJABMpKEDOv3xzmBAAlJIOEmN+vzPOe59+59zr3raFh333322VtUFWOMMd7Lx9MBGGOMaVuW6I0xxstZojfGGC9nid4YY7ycJXpjjPFyluiNMcbLWaI3xhgvZ4nedAgikiUi53o6DmM6Ikv0xpwCIuLn6RhM52WJ3nRoIhIoIn8Ukd3u9kcRCXTr4kTkYxEpFJF9IrJQRHzcul+IyC4RKRGRDSIysZH3DxaR34vIdhEpEpGv3LLxIpJ91L6HfnWIyK9EZJaI/FNEioFHRKRcRGLq7T9cRPJFxN99fZuIrBeR/SLyuYh0d8tFRP4gIrkiUiwiq0VkUJv8BzVeyRK96ej+GxgNDAOGAqOAR926+4FsIB5IAB4BVET6AT8BTlPVcOB8IKuR9/8/YCRwJhADPATUNjO2S4FZQBTwNLAYuLJe/Q3ALFWtEpFL3fiucONdCLzp7jcJGAv0BSKBa4CCZsZgjCV60+HdCDyhqrmqmgf8D3CzW1cFJAHdVbVKVReqM7lTDRAIpIuIv6pmqeqWo9/Ybf3fBtyrqrtUtUZVF6lqZTNjW6yq/1LVWlUtB94ArnffW4Dr3DKAO4CnVHW9qlYD/wsMc1v1VUA40B8Qd589LfvPZDozS/Smo+sKbK/3ertbBk4rejMwR0S2isg0AFXdDNwH/ArIFZGZItKVY8UBQcAxXwLNtPOo1+8CZ4hIEk4LvRan5Q7QHfiT281UCOwDBEhW1f8AfwGec+OdLiIRJxiT6YQs0ZuObjdOkqyT6pahqiWqer+q9gSmAD+v64tX1TdU9Sz3WAV+18B75wMVQK8G6sqAkLoXIuKL0+VS3xFTw6rqfmAOcC1Ot81MPTx97E7gv1Q1qt4WrKqL3GOfVdWRQDpOF86Dx/uPYkx9luhNR+IvIkH1Nj+cfuxHRSReROKAx4B/AojIxSLS2+0mKcLpsqkVkX4ico570bYCKKeBfndVrQVmAM+ISFcR8RWRM9zjNgJBInKRezH1UZzuoKa8AfwAuIrD3TYALwIPi8hAN/ZIEbnafX6aiJzufk6ZG3NzrxMYY4nedCizcZJy3fYr4DdAJrAKWA1855YB9AG+AEpxLoQ+r6rzcBLyb3Fa7DlAF+DhRj7zAfd9v8XpTvkd4KOqRcBdwCvALpwEnN3Ie9T3oRtXjqqurCtU1ffd957pjtJZA1zgVkcALwP7cbqmCnC6pYxpFrGFR4wxxrtZi94YY7ycJXpjjPFyluiNMcbLWaI3xhgv1y4nWoqLi9O0tDRPh2GMMR3GsmXL8lX16Hs5gHaa6NPS0sjMzPR0GMYY02GIyPbG6qzrxhhjvJwlemOM8XKW6I0xxsu1yz56Y4xpiaqqKrKzs6moqPB0KG0uKCiIlJQU/P39m32MJXpjTIeXnZ1NeHg4aWlpOHPYeSdVpaCggOzsbHr06NHs46zrxhjT4VVUVBAbG+vVSR5ARIiNjW3xLxdL9MYYr+DtSb7OiZyn1yT6iqoaXpq/ha825Xs6FGOMaVe8JtEH+PowfcFWZi07evU2Y4xpW4WFhTz//PMtPu7CCy+ksLCwDSI6ktckeh8fYWzfeBZsyqe21ubYN8acOo0l+urq6uMeN3v2bKKiotoqrEO8JtEDjOsbz76yg6zZXeTpUIwxnci0adPYsmULw4YN47TTTuPss89mypQppKenA3DZZZcxcuRIBg4cyP
Tp0w8dl5aWRn5+PllZWQwYMIDbb7+dgQMHMmnSJMrLy1stPq8aXnl2nzhEYP6GPIaktP23pDGm/fmfj9aybndxq75netcIHr9kYKP1v/3tb1mzZg0rVqzgyy+/5KKLLmLNmjWHhkDOmDGDmJgYysvLOe2007jyyiuJjY094j02bdrEm2++ycsvv8w111zDu+++y0033dQq8TfZoncXYf5GRFaKyFoR+Z8G9gkUkbdEZLOILBWRtHp1D7vlG0Tk/FaJuhGxYYEM6hrJ/I15bfkxxhhzXKNGjTpinPuzzz7L0KFDGT16NDt37mTTpk3HHNOjRw+GDRsGwMiRI8nKymq1eJrToq8EzlHVUncV+q9E5FNVXVJvnx8B+1W1t4hch7PI8bUikg5cBwwEugJfiEhfVa1ptTM4yri+8bwwfwtF5VVEBjf/zjFjjHc4Xsv7VAkNDT30/Msvv+SLL75g8eLFhISEMH78+AbHwQcGBh567uvr26pdN0226NVR6r70d7ejr3ZeCrzqPp8FTBRnsOelwExVrVTVbcBmYFSrRN6Icf3iqalVFm22YZbGmFMjPDyckpKSBuuKioqIjo4mJCSE77//niVLljS4X1tq1sVYEfEVkRVALjBXVZcetUsysBNAVauBIiC2frkr2y1r6DOmikimiGTm5Z1418vwblGEB/lZ940x5pSJjY1lzJgxDBo0iAcffPCIusmTJ1NdXc2AAQOYNm0ao0ePPuXxNetirNvVMkxEooD3RWSQqq5pzUBUdTowHSAjI+OEx0f6+fpwVu845m/MQ1U7zd1yxhjPeuONNxosDwwM5NNPP22wrq4fPi4ujjVrDqfUBx54oFVja9HwSlUtBOYBk4+q2gV0AxARPyASKKhf7kpxy9rU2L7x7CmqYFNuadM7G2OMl2vOqJt4tyWPiAQD5wHfH7Xbh8At7vOrgP+oqrrl17mjcnoAfYBvWiv4xozt6yybuMC6b4wxplkt+iRgnoisAr7F6aP/WESeEJEp7j5/BWJFZDPwc2AagKquBd4G1gGfAXe35YibOslRwfTpEmb99MYYQzP66FV1FTC8gfLH6j2vAK5u5PgngSdPIsYTMq5vPK8t3s6Bg9WEBHjVfWHGGNMiXjUFQn3j+sVzsKaWpVv3eToUY4zxKK9N9KelxRDk72PdN8aYTs9rE32Qvy+je8baBVljTLsUFhYGwO7du7nqqqsa3Gf8+PFkZmae9Gd5baIHp59+a34ZOwoOeDoUY4xpUNeuXZk1a1abfoZ3JfrSPNj870Mvx7nDLOdvsla9MaZtTZs2jeeee+7Q61/96lf85je/YeLEiYwYMYLBgwfzwQcfHHNcVlYWgwYNAqC8vJzrrruOAQMGcPnll7fafDfeNRzl04dg8xdwzwoIjaVHXCjdYoKZvyGPm0d393R0xphT5W8XHVs28DIYdTscPACvNzBIcNgNMPxGKCuAt39wZN2tnzT5kddeey333Xcfd999NwBvv/02n3/+Offccw8RERHk5+czevRopkyZ0ugd+y+88AIhISGsX7+eVatWMWLEiCY/tzm8q0U/fhocLIUFTwPOIrrj+sazaEs+B6trPRycMcabDR8+nNzcXHbv3s3KlSuJjo4mMTGRRx55hCFDhnDuueeya9cu9u7d2+h7LFiw4NAc9EOGDGHIkCGtEpt3tejj+8Hwm+HbV+D0/4KYHoztE88/l+wgc/s+zuwV5+kIjTGnwvFa4AEhx68PjW1WC74hV199NbNmzSInJ4drr72W119/nby8PJYtW4a/vz9paWkNTlHc1ryrRQ8w/mHw8YP//BqAM3vH4ecjLNho0xYbY9rWtddey8yZM5k1axZXX301RUVFdOnSBX9/f+bNm8f27duPe/zYsWMPTY62Zs0aVq1a1SpxeV+ij0iCM38KqlBTTVigH6elxfDF+sZ/LhljTGsYOHAgJSUlJCcnk5SUxI033khmZiaDBw/mtddeo3///sc9/s4776
S0tJQBAwbw2GOPMXLkyFaJy7u6bupMeATqXey4cHAiv/xgLRtySuiXGO7BwIwx3m716tWHnsfFxbF48eIG9ystdWbXTUtLOzRFcXBwMDNnzmz1mLyvRQ+Hk3zuesjOZPKgJHwEPlm127NxGWOMB3hnogeorYW3boYPf0p8qB+je8by8eo9OLMnG2NM5+G9id7Hx+nCyV0HK2dy0ZAktuaVsX5Pw+s6GmM6ts7SiDuR8/TeRA8w8HLoOgLmPcnkvpH4+gifrLbuG2O8TVBQEAUFBV6f7FWVgoICgoKCWnScd16MrSMC5z0Br15M7Nq/cWavM/l41R4emNTP1pI1xoukpKSQnZ1NXp73T3cSFBRESkpKi45pMtGLSDfgNSABUGC6qv7pqH0eBG6s954DgHhV3SciWUAJUANUq2pGiyI8WT3OhkFXQVUFFw1OYtp7q1m7u5hByZGnNAxjTNvx9/enR48eng6j3WpO1001cL+qpgOjgbtFJL3+Dqr6tKoOU9VhwMPAfFWtv+LHBLf+1Cb5Olf9FSY8zPkDE/HzET6y0TfGmE6kyUSvqntU9Tv3eQmwHkg+ziHXA2+2TnitSJXog7sZ0zuOT1bZ6BtjTOfRoouxIpKGs37s0kbqQ4DJwLv1ihWYIyLLRGTqcd57qohkikhmm/Szff7f8OJYpgyMIXt/OSuzi1r/M4wxph1qdqIXkTCcBH6fqhY3stslwNdHdducpaojgAtwun3GNnSgqk5X1QxVzYiPj29uWM3XeyJUFjE5cDX+vmI3TxljOo1mJXoR8cdJ8q+r6nvH2fU6juq2UdVd7mMu8D4w6sRCPUk9xkFoF0I3vMvYPvF8smoPtbXWfWOM8X5NJnpxxiH+FVivqs8cZ79IYBzwQb2yUBEJr3sOTALWnGzQJ8TXDwZdARs/57IBoewuqmD5zkKPhGKMMadSc1r0Y4CbgXNEZIW7XSgid4jIHfX2uxyYo6pl9coSgK9EZCXwDfCJqn7WatG31OBroOYg5/INAX4+fGzdN8aYTkDa4+iTjIwMbY2Vz4+hCpvmQs9x3P7GalZlF7J42kR8fOzmKWNMxyYiyxobwu7dUyAcTQT6TgK/QC4eksTe4kqW7djv6aiMMaZNda5ED86sll/+lvOr5xHo58PHK637xhjj3TpfovfxgY2fE5T5Iuf078Inq3OorrGFw40x3qvzJXqAIddAzmpu7FVOfmklCzZ5/0RIxpjOq3Mm+oFXgPhwRum/iQsL4K1vd3o6ImOMaTOdM9GHJ0DP8fiuncXlw7ry7/W55JdWejoqY4xpE50z0QMMuxESh3DdkAiqa5X3v9vl6YiMMaZNdN5EP/gquO51eqV2Y3hqFG9n7rQZLY0xXqnzJvo6ud9zw7BYNuWW2pQIxhiv1LkTfcEWeH40Uw5+QrC/L+9k2kVZY4z36dyJPrYX9D6XwKV/5vL0CD5auYcDB6s9HZUxxrSqzp3oASY8AuX7uSt4DqWV1cxenePpiIwxplVZok8eAf0vJnn9Xxkcq7xtY+qNMV7GEj3A+IeR2mqm9tzHN1n72JpX6umIjDGm1ViiB0gcBPd/z6jzrsFHYNaybE9HZIwxraY5K0x1E5F5IrJORNaKyL0N7DNeRIrqLUzyWL26ySKyQUQ2i8i01j6BVhMUSUJ4INf2rGLWsmyb6MwY4zWa06KvBu5X1XRgNM4C3+kN7LdQVYe52xMAIuILPIezMHg6cH0jx7YPC/+P3+T8F1qy1yY6M8Z4jSYTvaruUdXv3OclwHoguZnvPwrYrKpbVfUgMBO49ESDbXMDr8CntoqfB39sE50ZY7xGi/roRSQNGA4sbaD6DBFZKSKfishAtywZqJ8xs2n+l8SpF9sLGXYDVzOXtevXk1tS4emIjDHmpDU70YtIGPAucJ+qFh9V/R3QXVWHAn8G/tXSQERkqohkikhmXp4Hu03GPYSvwJ2+7/PKwm2ei8MYY1pJsxK9iPjjJPnXVfW9o+tVtV
hVS93nswF/EYkDdgHd6u2a4pYdQ1Wnq2qGqmbEx8e38DRaUVQqMvIWJgetZebiLRTY9MXGmA6uOaNuBPgrsF5Vn2lkn0R3P0RklPu+BcC3QB8R6SEiAcB1wIetFXybmfDfFN7yJSXVwsvWqjfGdHB+zdhnDHAzsFpEVrhljwCpAKr6InAVcKeIVAPlwHXqzPlbLSI/AT4HfIEZqrq2lc+h9YXE0CsELh0cz/zFi5g6ticxoQGejsoYY06ItMc52DMyMjQzM9PTYVDy2g0UbVnKO6e/y88uHObpcIwxplEiskxVMxqqsztjjyN83E9IkXxClv6JwgMHPR2OMcacEEv0x9P9TIr6XskP+ZD35n7p6WiMMeaEWKJvQuQlT1HrG0i/735DkbXqjTEdkCX6poQnUDj6ISK1iDfmr/R0NMYY02KW6Jsh6dyf8myv6bywdB/FFVWeDscYY1rEEn1z+Phyz7n9kYpCFn7wd09HY4wxLWKJvpkGJUfyTPwnTFr/C8p2tf9bAYwxpo4l+hZIvOQxDmgg+e/cB+3w/gNjjGmIJfoWGNi3N593uY3uhd9Q+N0xU/4YY0y7ZIm+hU6/5iG+127UfvYIVJV7OhxjjGmSJfoW6h4fydL+0/i2IoWNO3M8HY4xxjSpOZOamaNceuk1jNvYhSHzcnmtRxruxJ3GGNMuWYv+BESFBHDPxD5kb17N1vd/7elwjDHmuCzRn6CbR3fn+rDl9Fr1e2o2fuHpcIwxplGW6E9QgJ8P3S9+iG21CZR+cD9U2zw4xpj2yRL9SZg0JJWZMXcRWZZFxdfPeTocY4xpUHOWEuwmIvNEZJ2IrBWRexvY50YRWSUiq0VkkYgMrVeX5ZavEBHPrybSikSEC6/8IV/UDMdn/v+D4j2eDskYY47RnBZ9NXC/qqYDo4G7RST9qH22AeNUdTDwa2D6UfUTVHVYY6ufdGRDu0WxqM8D/KN6AnvKfT0djjHGHKPJRK+qe1T1O/d5CbAeSD5qn0Wqut99uQRIae1A27PbLjmH39XezG/n7fJ0KMYYc4wW9dGLSBowHFh6nN1+BHxa77UCc0RkmYhMPc57TxWRTBHJzMvLa0lYHpcSHcLUs3uyeeXXlD5/DpTmejokY4w5pNmJXkTCgHeB+1S1uJF9JuAk+l/UKz5LVUcAF+B0+4xt6FhVna6qGaqaER8f3+wTaC/umtCL8LBwAnJXoR/da5OeGWPajWYlehHxx0nyr6tqg7N5icgQ4BXgUlUtqCtX1V3uYy7wPjDqZINuj0IC/Lj+onN5uupqZMNsWPWWp0MyxhigeaNuBPgrsF5Vn2lkn1TgPeBmVd1YrzxURMLrngOTgDWtEXh7NGVoV1al3MBy+qOzH4Ti3Z4OyRhjmtWiHwPcDJzjDpFcISIXisgdInKHu89jQCzw/FHDKBOAr0RkJfAN8ImqftbaJ9FeiAi/nDKEn1VOparqICx53tMhGWNM05OaqepXwHFn7VLVHwM/bqB8KzD02CO816DkSM44bRRXLXuM3w++kT6eDsgY0+nZnbFt4IFJfcny780TszegpXnWhWOM8ShL9G0gNiyQn53Xl0Wb9lL+wjnw7o+hpsrTYRljOilL9G3kptHd6ZUQyTNVl8P2r2HOLz0dkjGmk7JE30b8fX14/JKBvFJ8OiuSr4elL8CKNz0dljGmE7JE34bG9I5j8sBEbtpxERUpY+Cje2H3Ck+HZYzpZCzRt7HHp6SDTwA/15+hQ6+D6DRPh2SM6WQs0bexpMhgHjy/H7O3HOTD1F9AcBRUVdhCJcaYU8YS/Slw0+juDOsWxRMfraOwuBj+fiF8/rCnwzLGdBKW6E8BXx/hqSsGU1hexf/O2Qbdz4RvX4Flr3o6NGNMJ2CJ/hQZkBTB7Wf35O3MbBb3+Cn0nACf3A+bbGFxY0zbskR/Ct07sQ/dYoL57w/WU3H5DOgyAN66EbYv8nRoxhgvZon+FAoO8OXJywazNb+M55fkw83vQ4
9xEJHc9MHGGHOCLNGfYmP7xnPZsK688OVmNpUGwo1vQ3R3qK2FomxPh2eM8UKW6D3g0YvTCQ304+H3VlNT665E9cXjMH0C5G/2bHDGGK9jid4D4sICefSidDK37+flhVudwuE3g9bCa1Ng/3bPBmiM8SqW6D3kyhHJXDAokd/P2cCaXUUQ3xd+8C84WAqvXQrFezwdojHGSzRnKcFuIjJPRNaJyFoRubeBfUREnhWRzSKySkRG1Ku7RUQ2udstrX0CHZWIM7Y+NjSQe2Yup/xgDSQOhpveg7I8eP1qqK3xdJjGGC/QnBZ9NXC/qqYDo4G7RST9qH0uAPq421TgBQARiQEeB07HWRT8cRGJbqXYO7yokAB+f81QtuaV8ZtP1jmFKRlw4yyY9AT4+Ho2QGOMV2gy0avqHlX9zn1eAqwHjh4PeCnwmjqWAFEikgScD8xV1X2quh+YC0xu1TPo4Mb0jmPq2J68vnQHc9ftdQq7nwG9znGer/0XlOV7LkBjTIfXoj56EUkDhgNLj6pKBnbWe53tljVW3tB7TxWRTBHJzMvLa0lYHd79k/qSnhTBL95dRW5JxeGK0lz4153w6hQoK/BcgMaYDq3ZiV5EwoB3gftUtbi1A1HV6aqaoaoZ8fHxrf327Vqgny/PXj+MsspqHnhnFbV1Qy7DusB1b8C+LfCPS+HAPs8GaozpkJqV6EXEHyfJv66q7zWwyy6gW73XKW5ZY+XmKL27hPPoRQNYsDGPVxdnHa7oNcFJ9nkb4R+XWbI3xrRYc0bdCPBXYL2qPtPIbh8CP3BH34wGilR1D/A5MElEot2LsJPcMtOAm0Z355z+XXjq0++dIZd1ek+E616H3PWwYbbnAjTGdEiiqsffQeQsYCGwGqh1ix8BUgFU9UX3y+AvOBdaDwC3qmqme/xt7v4AT6rq35oKKiMjQzMzM1t+Nl4gv7SSKX/+CgU++MkYuoQHHa4s2AKxvZznBw9AQIhHYjTGtD8iskxVMxqsayrRe0JnTvQAa3YVcfWLi+mfFM6bt48myP+oYZY5a5xunIuegfQpngnSGNOuHC/R252x7dCg5EieuWYoy3cU8sh7qznmyzg80Vl79u0fwKI/Qzv8sjbGtB+W6NupCwYn8fPz+vLe8l28tGDrkZWhcXDLR05rfs6j8MnPoabaM4EaY9o9S/Tt2E/P6c3FQ5L43Wff80XdzVR1/IPhqr/DmHshc4azNKExxjTAEn07JiI8fdVQBidHcu/M5WzIKTlyBx8fOO8JZ/hlxm1OmXXjGGOOYom+nQsO8GX6zRmEBvrxo1e/paC08tid+l8EfgHO3bOvTLSlCY0xR7BE3wEkRgbx8g8yyCup5La/f0tpZSP98QdLoKLYmTJh+eunNkhjTLtlib6DGNotiuduGMGa3cVMfS2TiqoGpjCOToMfz4W0MfDBXTDnlzbVsTHGEn1Hcm56Ak9fNYRFWwq4d+Zyqmtqj90pONqZ5vi0H8OiZ53NGNOpWaLvYK4YkcLjl6Tz+dq9PNzQGHsAX3+46Pdw8R9gyLVO2f7tUF54aoM1xrQLfp4OwLTcrWN6UHigij/9exNRIf48cuEAnFkojlI3Egfgw5/A3rVwzqMw4hZb1MSYTsRa9B3Ufef24ZYzuvPywm08/+WWpg8479cQ1w8+/hm8eDZsnd/2QRpj2gVL9B2UiPD4JQO5bFhXnv58A68uyjr+AV2Hwa2z4epXndE5r02BNQ3NOG2M8TbWddOB+fgIT189lLKDNTz+4VpKK6u5a3yvhrtxAERg4GXQdzIseR76XeCU11Q5/frGGK9kLfoOzt/Xh+dvHMHlw5N5+vMNPPnJ+oYv0B5xUBCc/XNnGoXKUqcr5+tnobaBUTzGmA7PWvRewN/Xh99fPZTIYH9e+WobReVVPHXFYPx8m/E9XlvlzHE/95ew+Qu4/EWI6Nr2QRtjTpnmrDA1Q0RyRWRNI/UPisgKd1sjIjUiEuPWZYnIareu804wfw
r4+AiPX5LOfef24Z1l2dz1+ncN31R1tOBouPafMOXPkP0tPH8GLJ1uN1oZ40Wa03Xzd5yVoxqkqk+r6jBVHQY8DMxX1foLm05w6xucEN+0HhHhvnP78qtL0pmzbu/xp0s48kAY8QO44ytIHAxr3wexXj1jvEWT/5pVdQHQ3BWprwfePKmIzEn74Zge/PHaYSzdto/rpi8mp6iieQfG9nLmub/hLSf5l+TAjMmw5T9tG7Axpk21WrNNREJwWv7v1itWYI6ILBORqa31WaZplw1P5pUfZLAtr4wpf/mKlTubeVesCARFOM8Ld0LxLvjH5c5EaTmr2y5gY0ybac3f55cAXx/VbXOWqo4ALgDuFpGxjR0sIlNFJFNEMvPy8loxrM5rQv8uvHvXmQT4+XDNS4v5aOXulr1Bt9PgJ5kw+bdOkn9pLHw6zUbnGNPBtGaiv46jum1UdZf7mAu8D4xq7GBVna6qGaqaER8f34phdW79EyP44O4xDE2J4qdvLueZORuorW3B4iR+gTD6TrjnO2dKhcoSZ8ETY0yH0Sr/YkUkEhgHfFCvLFREwuueA5OABkfumLYVGxbIP398OtdkpPDsfzZz9xvfceBgC9eYDY52Jkq79C/O690r4O8XO/PnGGPateYMr3wTWAz0E5FsEfmRiNwhInfU2+1yYI6qltUrSwC+EpGVwDfAJ6r6WWsGb5ovwM+H3105hEcvGsBna3O4+sXFbC8oa/rAo9XddVuyx0nyL54Nsx+Ekr3HP84Y4zHS5F2UHpCRkaGZmTbsvq3M+z6Xe2cup1bhqSsGc8nQE7xB6sA++PcT8N1r4Bvg3G077qHWDdYY0ywisqyxYezW2doJTejfhdn3nk3fhDB++uZyHn5vFeUHT+AGqZAYuOSP8JNvIX3K4YXJVaGiqHWDNsacMGvRd2JVNbU8M3cjL3y5hb4JYTx3wwj6JISf+BuqOl076z+CD+6G0Xc5N2LZlArGtDlr0ZsG+fv68IvJ/XnttlHsKzvIJX/5ire+3dH0pGiNqeu/j+kFqWfCl0/BHwbC61fDug9tWgVjPMRa9AaA3OIKfvb2Cr7eXMB56Qk8edkgukQEndybFmyBFW84m18g3LPc+TIoK4DQ2NYJ3OobbsYAABbkSURBVBgDHL9Fb4neHFJTq8z4ahv/N2cDgX4+/PLidK4amdL4/PbNVVsDRTshOg2qKuCZ/hA/wBmXnz7F+RIwxpwU67oxzeLrI9w+tief3TeW/okRPDhrFT/827fsKiw/uTf28XWSPIDWwFk/c4ZnvvdjeGYAzH0MinaddPzGmIZZi940qLZW+ceS7fzus+/xEeHhC/tzw6jUk2/dH/4A2PYlfPtX2PAp3PoppJ4O5fshIMxWvDKmhazrxpywnfsOMO29VXy9uYBRaTH85vJB9D2ZkTkNKcmBsASn//6j+2DDbBh+kzNip+6XgDHmuKzrxpywbjEh/PNHp/O7KwezMbeEC/+0kKdmr6esOfPcN1d44uEROwMugeSR8NUf4E9D4bXLnBa/MeaEWaI3TRIRrj0tlf/cP54rRiTz0oKtnPfMfD5bk3PiQzEb03siXP8m3LcGxj8C+Ztg01ynThV2fmOzZxrTQtZ1Y1osM2sfj/5rDd/nlHBO/y786pKBpMaGtM2H1dbAwTJnjvwdS2HGJIhMhcFXwZBroUv/tvlcYzoY66M3ra6qppZXF2Xxh7kbqapRbjurB3dP6EV4UBteRK0she8/gVVvwdZ5oLUQ2xtufAdiejpTKPuHOKN8jOlkLNGbNpNTVMH/+/x73vtuF3FhAdw/qR/XZHTD16eVRuc0pjQX1rwH2xbAVTPAPwjm/BIyZ0DX4dB9DAy6AuL7tW0cxrQTluhNm1u5s5Bff7yOzO376Z8Yzi8vTmdM77hTG8SW/zgXbnd+AzmrnBZ/6plw6+zDF3uN8VKW6M0poarMXp3D/85ez67Ccib278LPzuvLoOTIUx9MyV5Y+77TnTPuQa
fsX3dB8gjoewFEJp/6mIxpQ5bozSlVUVXDjK+38eKXWyiuqOb8gQncd25fBiRFeC6oA/tgxmTI3+C87pIOvc+FYTdAlwGei8uYVnJS4+hFZIaI5IpIg8sAish4ESkSkRXu9li9uskiskFENovItBM/BdORBPn7ctf43iz8xTncO7EPizYXcMGfFnLX68vYkFPimaBCYuDupXDXEjjv1xAaB0teOLwU4r6tsPQl2LMSalrxHgFj2oEmW/QiMhYoBV5T1UEN1I8HHlDVi48q9wU2AucB2cC3wPWquq6poKxF712KDlTxyldb+dvXWZQdrOaiwUncPaG3Z1v44C507gf+wbDsVfjoHqc8IAxSToPUM2DU7c6XhDHt3PFa9H5NHayqC0Qk7QQ+dxSwWVW3ukHMBC4Fmkz0xrtEhvhz/6R+3DamBy8v3Mqri7L4eNUeJvbvwl0TejOye7RnAgusN5XDyFucm7V2LIEdi53HBf8PRt/p1C95AbK+goRBkDAQEgdBVBr42D2Hpv1rMtE30xnuIuC7cVr3a4FkYGe9fbKB0xt7AxGZCkwFSE1NbaWwTHsSHRrAQ5P7819je/Hq4ixmfL2NK19YxOieMdw9oTdn9Y5rvUnTTkRkinMj1uCrnNeVpRAY5jyvroS8Dc48POremRvZDe5d5ST7qgpniKcx7VCzLsa6LfqPG+m6iQBqVbVURC4E/qSqfUTkKmCyqv7Y3e9m4HRV/UlTn2ddN51DWWU1b36zg5cXbmVvcSVDUiL50Vk9uHBwEv6+7bSlfPAA5K2HnDVQWQxn/tQpf2mccxdv74nQcxwkDrXFVcwpddKjbo6X6BvYNwvIAPoAv1LV893yhwFU9amm3sMSfedSWV3D+9/t4qUFW9mWX0ZCRCA3nd6d609PJS6sAyxKogpf/xE2fQE7l0CtezH3tB/DRb936le/A/H9nRE+NgWzaQNtmuhFJBHYq6oqIqOAWUB3oO5i7ERgF87F2Bvcbp3jskTfOdXWKvM35vG3RVks2JhHgK8Plwztyq1j0jwzFv9EVBRD9reQu85J7H3OcxZV+UO6U+8XBElDITkDhlwDXYd5Nl7jNU4q0YvIm8B4IA7YCzwO+AOo6osi8hPgTqAaKAd+rqqL3GMvBP6Ik/RnqOqTzQnYEr3ZnFvKa4uzmLUsmwMHaxiRGsXNZ3TngkFJBPl3sLlsamugYDPkrIbdyyE7E/asgEufc64H7FgKn9zvTNwWFAmB7uOo2yGuj/OLwO7sNU2wG6ZMh1VUXsWsZdn8c8l2tuWXER3izzWndePGUd3bbsbMU6Gmyrmo6xcIu5bB/KedPv+KYqgogsoiuP4t6H4GrHrHGQHUY6yzpZ1tQz7NMSzRmw6vtlZZtKWAfy7Zztz1e6lVZVzfeK4flcqEfl0I8GunF29bw+YvYMmLsH0RVJUB4szaeefXzhfFpi+gMAsikiGiqzMaKDjafgV0MpbojVfJKargzW92MPPbHewtriQqxJ+LhyRx+fBkRqRGe3aIZluqPgi7v4Ot852uoCtfdspn/QjWzDpy3/AkuP975/myV53ZPqO7O0szRqdBaLx9EXgZS/TGK1XX1LJwcz7vf7eLOetyqKiqJTUmhMuGJ3P58GR6xIV6OsRTo6YaynKheA8UZzsXf2sOwln3OfWvXwObPj/ymK7DYeqXzvPMGeDjD3F9Ib6v82vAdDiW6I3XK62s5rM1Ofxr+S6+3pKPKgxJiWTK0K5cPKQriZGd/GamqnIo3AH7s5zNLxBG/tCp++Ngp65OaLyzetf57tiJb152poUIjXO3Ls5i7r6tdb+laQ2W6E2nklNUwUcrd/Phyt2s3lWECJzeI4ZLhyVzwaBEokICPB1i+1JbA4XbIW8j5LtbwkBn+ofaWvh1HGjNkcdk/Agufsb5NfH2zc4C7+FJzmNYIiSkO3cam1PGEr3ptLbmlfKhm/S35pXh5yOc2TuOyQMTmTQwoWPckOVJqnCwFMryoC
zfeSzd63TzpJ3lTP/86iVQvBvK9x0+buJjcPb9UJQNr5zrXCiOTHG3bs4dxHF9nNFHVQcgINzmDTpJluhNp6eqrN1dzEerdvPZmhy2FxzARyAjLYYLBiUyeVAiSZHBng6zY6uudL4ESvZCeAJEpTpfAP950r124G7VFXD5dBh6rTOS6G8XAOJMMhcY4dxPMPkp6DnemT56/cfuL4WEw78aAsPtYvJRLNEbU4+qsn5PCZ+tzeGzNXvYuLcUgMHJkZw7IIFz07uQnhThvaN3PEkVDhQ4dwgHhkHhzsMrgdXdR1BZDOMecu4gXvMuzLrt2Pe59TPnHoOtX8LyfzrXDcC5N0FrYNw0Z66hTV849ynE9XG22N7OtNReyBK9McexJa+Uz9bk8O/1e1m+sxBV6BoZxMQBCUwc0IXRPWM73t243kLV+RIoyYHSHOexJMdZGSw0zrmZbN5vnG4lAPF1Wvr/Nd8ZRrr0Jfj0oXpvKE7X0R0LITgKVrzpLDDv4+Me6+N8AZ37hFNWuNN5v7DEdn/x2RK9Mc2UV1LJvO9z+WL9XhZuyqe8qoYgfx9G94xlbJ94xvaNp1d8qLX2O5KDB2DfFvdC82bnwvOlzzkJ/N+/hlVvORektcb5RSC+8IC75OQ7t8La95wvgKAo55iIrnDHV079zBuddQrCEpx1iCO6OmsW1K1jkLPaGerqG+Bu/s4IprAurX6aluiNOQEVVTUs3lLAlxtyWbApn235ZQAkRwUztm884/rGcUavOCKDbTZKr1J/bqGd3zjLTRbvdrqcwJl+4pxHneffvgK53zu/Nop3O/cwRHeHH81x6l84C/auPvL9e4yFWz5ynr88Ecr3O/cuhMTAhf/nHH8CLNEb0wp27jvA/I15LNiYx6ItBZRWVuMjMLRbFGf3iWdsnziGdotqv3Ppm1OjtvbwCKIdS6Gi0GnV1xx0RhmFxEGfc536uY9D0U4n2R/YB9e9fsLDUi3RG9PKqmpqWbGzkIUb81iwKZ9V2YXUKoQH+jG6Vyxn9IzljF6x9EsIx8fHunlM27NEb0wbKzxwkEVbCli4KY+vNuezc185ANEh/ox2k/7onrH0jg+zxG/axEktDm6MaVpUSAAXDk7iwsFJAGTvP8DiLQUs3lrAki0FfLomx93Pn5Gp0YxMiyajewxDUiJtRI9pc9aiN6aNqSo795WzZGsBmdv3kbl9P1vznAu7Ab4+DEqOYERqNMNToxmeGkVSZJCN6jEtdrIrTM0ALgZyG1lK8EbgF4AAJcCdqrrSrctyy2qA6saCOJoleuPtCkorWbZ9P8u27ydz+35W7yriYHUtAF3CAxmeGsXw1GiGdYticHIkoYH249sc38l23fwd+AvwWiP124BxqrpfRC4ApgOn16ufoKr5LYjXGK8XGxbIpIGJTBqYCMDB6lq+zylm+Y5Clu/Yz4qdhXy+di8APgJ9uoQztFskQ7tFMaxbFH0Twm10j2m2k14c/Kj9ooE1qprsvs4CMlqa6K1FbwzsKzvIyuxCVu50t+wi9pUdBCDI34dBXSMZ1i3qUPJPiQ62Lp9O7KRH3bQg0T8A9FfVH7uvtwH7AQVeUtXpxzl2KjAVIDU1deT27dubjMuYzkRVyd5fznI38a/YWciaXUVUul0+cWEBDE2JYlBypLtFkBhh/f2dxSlJ9CIyAXgeOEtVC9yyZFXdJSJdgLnAT1V1QVOfZy16Y5qnqqaWDTklLN9ZyIodhazKLmRLXim17j/r2NAABiZHMqhrBOldIxiQFEFabCi+NsTT67T58EoRGQK8AlxQl+QBVHWX+5grIu8Do4AmE70xpnn8fX0OteBvHu3cOn/gYDXr95SwZleRs+0uZvqCrVS72T/Y35d+ieEMSIogPSmc/kkR9O0STmSITeXgrU460YtIKvAecLOqbqxXHgr4qGqJ+3wS8MTJfp4x5vhCAvwY2T2akd0Pr/1aUVXD5txS1u0pZr27zV69hze/ObyEYGJEEH0Tw+
mXEEbfBOeLoE9CGIF+Ns6/o2sy0YvIm8B4IE5EsoHHAX8AVX0ReAyIBZ53+wLrhlEmAO+7ZX7AG6r6WRucgzGmCUH+voda/nVUlT1FFWzIKWHD3hI2uo+vbi04NNTTz0fo3SWM9K4RDOwaSXpSBOlJEdb672DshiljzBFqapWsgjK+31PC2t1FrNtTzLrdxeSWVB7aJyY0gB5xofSMC6VHvPsYF0ZaXIj9AvAQm+vGGHPS8koqWbenmA05xWzLP8DWvFK25Zcd8QXgI5AaE0LvLuH07hJ2aOvTJcxu+mpjNteNMeakxYcHMi48nnF9448oL62sJiu/jC15pWzOPbzN35hLVc3hhmRKdDD9EsLpkxBOv0TnOkDPuDCCA+wXQFuzRG+MOSlhgX7H9P+DM/Rzx74DbNpbyubcEjbsLWVjTgkLNuUd8QWQHBXsdAPFh9IjLtTtEgojOTrYhoG2Ekv0xpg24e/rQ6/4MHrFhwGJh8qramrJyi9jw94StuaVHeoCev+7XZRUVh/aL8DXh9TYkEPJv0dcKGmxzmNCRKDdCNYCluiNMaeUv68PfdwunPpUlfzSg2zNKyWroIyt+WVk5ZexLb+M+RvzDo0EAudegO7ul0BaXCg9YkNJiQkmNSaEpEj7JXA0S/TGmHZBRIgPDyQ+PJDTe8YeUVdTq+wpKmfboeR/gKyCMjbklDB33d5DN4OBMyS0a5ST9LvFBJMSHUK3mBBSooPpFh1CXFhAp/s1YIneGNPu+foIKdEhpESHcHafIy8GV9fUsquwnJ37ytm5/wA79x1gx74D7Nxfzpy1eylwJ4KrE+TvQ0p0CN1jQkiLCyUttu4xlK5R3vlrwBK9MaZD8/P1oXtsKN1jQxusL6usJnt/Odnul8DO/eWHvgy+3pJPRdXhLqEAXx+6xQST5r5fWlwIabF1XwJB+HXQqaEt0RtjvFpooB/9EsPplxh+TF1trZJbUsm2/DK2F5SxraCM7W630KItBZRX1Rza189HSIoKolv04W6gui6h1JgQ4sPb7wViS/TGmE7Lx0dIjAwiMTKIM3odeV1AVckrqSSrwEn82wvKyHZ/DczbkEdevRvF4HCXUGqMs6VEB9MtJsT9QggmPMhz00ZYojfGmAaICF0igugSEcSoHjHH1FdU1bjdQUdeG9ixr5xvtu2jtN5QUYDoEP9DiT8lJpiUqGCSo52LxclRwW1657AlemOMOQFB/r7uVA/HdgmpKvsPVLnXBI78Mli3p5i56/ZysKb2iGOiQvzp0yWMd+44s9VjtURvjDGtTESICQ0gJjSAod2ijqmvrVXySisPXSTeVVhO9v5yamvbZu4xS/TGGHOK+fgICRFBJEQEHbFuQJt9Xpt/gjHGGI9qVqIXkRkikisiaxqpFxF5VkQ2i8gqERlRr+4WEdnkbre0VuDGGGOap7kt+r8Dk49TfwHQx92mAi8AiEgMzopUp+OsF/u4iLT97xRjjDGHNCvRq+oCYN9xdrkUeE0dS4AoEUkCzgfmquo+Vd0PzOX4XxjGGGNaWWv10ScDO+u9znbLGis/hohMFZFMEcnMy8trpbCMMca0m4uxqjpdVTNUNSM+Pr7pA4wxxjRLayX6XUC3eq9T3LLGyo0xxpwirZXoPwR+4I6+GQ0Uqeoe4HNgkohEuxdhJ7llxhhjTpFm3TAlIm8C44E4EcnGGUnjD6CqLwKzgQuBzcAB4Fa3bp+I/Br41n2rJ1T1eBd1AVi2bFm+iGxv2akcEgfkn+CxHZmdd+di5925NOe8uzdWIaptc8utp4hIpqpmeDqOU83Ou3Ox8+5cTva8283FWGOMMW3DEr0xxng5b0z00z0dgIfYeXcudt6dy0mdt9f10RtjjDmSN7bojTHG1GOJ3hhjvJzXJHoRmSwiG9ypkqd5Op621NC00SISIyJz3emg53rbLKEi0k1E5onIOhFZKyL3uuVefd4AIhIkIt+IyEr33P/HLe8hIkvdv/m3RCTA07G2Nh
HxFZHlIvKx+9rrzxlARLJEZLWIrBCRTLfshP/WvSLRi4gv8BzOdMnpwPUiku7ZqNrU3zl2FtBpwL9VtQ/wb/e1N6kG7lfVdGA0cLf7/9jbzxugEjhHVYcCw4DJ7h3ovwP+oKq9gf3AjzwYY1u5F1hf73VnOOc6E1R1WL3x8yf8t+4ViR5nrvvNqrpVVQ8CM3GmTvZKjUwbfSnwqvv8VeCyUxpUG1PVPar6nfu8BOcffzJeft4A7vTfpe5Lf3dT4BxgllvudecuIinARcAr7mvBy8+5CSf8t+4tib7Z0yF7sQR3fiGAHCDBk8G0JRFJA4YDS+kk5+12YawAcnHWddgCFKpqtbuLN/7N/xF4CKh1X8fi/edcR4E5IrJMRKa6ZSf8t26Lg3shVVUR8cpxsyISBrwL3KeqxU4jz+HN562qNcAwEYkC3gf6ezikNiUiFwO5qrpMRMZ7Oh4POEtVd4lIF2CuiHxfv7Klf+ve0qK36ZBhr7uqF+5jrofjaXUi4o+T5F9X1ffcYq8/7/pUtRCYB5yBs5JbXWPN2/7mxwBTRCQLpyv2HOBPePc5H6Kqu9zHXJwv9lGcxN+6tyT6b4E+7hX5AOA6nKmTO5MPgbrF128BPvBgLK3O7Z/9K7BeVZ+pV+XV5w0gIvFuSx4RCQbOw7lGMQ+4yt3Nq85dVR9W1RRVTcP59/wfVb0RLz7nOiISKiLhdc9xpndfw0n8rXvNnbEiciFOn54vMENVn/RwSG2m/rTRwF6caaP/BbwNpALbgWuaMyV0RyEiZwELgdUc7rN9BKef3mvPG0BEhuBcfPPFaZy9rapPiEhPnNZuDLAcuElVKz0Xadtwu24eUNWLO8M5u+f4vvvSD3hDVZ8UkVhO8G/daxK9McaYhnlL140xxphGWKI3xhgvZ4neGGO8nCV6Y4zxcpbojTHGy1miN8YYL2eJ3hhjvNz/B6SecKeiOP9nAAAAAElFTkSuQmCC\n" | |
},
"metadata": {
"needs_background": "light"
}
}
],
"source": [
"# Plot the training results\n",
"plt.title(\"Loss curves\")\n",
"loss_curve, = plt.plot(np.array(train_losses), label='train')\n",
"plt.plot(np.array(valid_losses), linestyle='--', label='valid')\n",
"plt.legend();"
],
"id": "f7M5Et3PuEaa"
},
{
"cell_type": "markdown",
"source": [
"## 6. Create a top-k accuracy function\n",
"\n",
"Top-k accuracy checks whether the true label is among the k classes with the highest predicted probability.\n",
"Since the top-k set is always contained in the top-(k+1) set, the accuracy can only increase or stay the same as k grows.\n",
"\n"
],
"metadata": {
"id": "oSn04x9E9iVG"
},
"id": "oSn04x9E9iVG"
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "fe1f70cf"
},
"outputs": [],
"source": [
"# define the top-k accuracy function\n",
"def topk_accuracy(k_list, lstm_model, data_loader, many_to_one):\n",
"    # put the LSTM model in evaluation mode and create a dict for the accuracies\n",
"    lstm_model.eval()\n",
"    accuracies = {k: [] for k in k_list}\n",
"\n",
"    # Disabling gradient calculation is useful for inference:\n",
"    # it reduces memory consumption, and we only evaluate performance here,\n",
"    # so no gradients are needed\n",
"    with torch.no_grad():\n",
"        for batch in data_loader:\n",
"            # the output at position t predicts the character at position t+1\n",
"            # (the convention text_generate_probabilistic relies on),\n",
"            # so the targets are the inputs shifted by one\n",
"            logits = lstm_model(batch[:, :-1])\n",
"            softmax = torch.nn.functional.softmax(logits, dim=2)\n",
"            targets = batch[:, 1:]\n",
"\n",
"            # in the many-to-one setting only the prediction for the\n",
"            # last time step is evaluated\n",
"            if many_to_one:\n",
"                softmax = softmax[:, -1:, :]\n",
"                targets = targets[:, -1:]\n",
"\n",
"            for k in k_list:\n",
"                # return the k largest probabilities (and their indices) along the class dimension\n",
"                topk = torch.topk(softmax, k, dim=2)\n",
"\n",
"                # a prediction is correct if the target is among the top-k indices;\n",
"                # unsqueeze(-1) turns the targets from (batch, seq) into (batch, seq, 1)\n",
"                # so they broadcast against the (batch, seq, k) indices.\n",
"                # .item() returns the value of a one-element tensor as a Python number\n",
"                correct = (topk.indices == targets.unsqueeze(-1)).sum().item()\n",
"\n",
"                # normalize by the number of evaluated positions: the batch size\n",
"                # in the many-to-one setting, batch size * sequence length otherwise\n",
"                accuracies[k].append(correct / targets.numel())\n",
"\n",
"    # average the per-batch accuracies for every k\n",
"    for k, accs in accuracies.items():\n",
"        accuracies[k] = np.mean(np.array(accs))\n",
"\n",
"    return list(accuracies.values())"
],
"id": "fe1f70cf"
},
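{
"cell_type": "markdown",
"metadata": {
"id": "topk-toy-example-md"
},
"source": [
"As a quick sanity check of the top-k comparison that `topk_accuracy` performs, here is a minimal sketch on a hand-made probability tensor instead of real model output. The probabilities and targets are hypothetical and chosen without ties, so the top-k sets are unambiguous. The printed accuracies never decrease with k, because the top-k set is always contained in the top-(k+1) set."
],
"id": "topk-toy-example-md"
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "topk-toy-example"
},
"outputs": [],
"source": [
"import torch\n",
"\n",
"# hypothetical softmax output: 2 sequences x 3 time steps x 4 classes\n",
"probs = torch.tensor([\n",
"    [[0.05, 0.60, 0.25, 0.10], [0.70, 0.05, 0.15, 0.10], [0.25, 0.05, 0.60, 0.10]],\n",
"    [[0.40, 0.30, 0.20, 0.10], [0.10, 0.05, 0.15, 0.70], [0.35, 0.15, 0.45, 0.05]],\n",
"])\n",
"# hypothetical true next-character indices, shape (2, 3)\n",
"targets = torch.tensor([[1, 2, 0], [3, 3, 1]])\n",
"\n",
"toy_accuracies = []\n",
"for k in range(1, 5):\n",
"    topk = torch.topk(probs, k, dim=2)\n",
"    # broadcast the (2, 3, 1) targets against the (2, 3, k) indices\n",
"    correct = (topk.indices == targets.unsqueeze(-1)).sum().item()\n",
"    toy_accuracies.append(correct / targets.numel())\n",
"\n",
"print(toy_accuracies)  # [0.3333333333333333, 0.6666666666666666, 0.8333333333333334, 1.0]"
],
"id": "topk-toy-example"
},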
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 281
},
"id": "LU0xl0EJLYPE",
"outputId": "6aa2aabb-e845-4bd7-febc-01316a08076f"
},
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
],
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAEICAYAAABPgw/pAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nO3deXhU5f3+8fcne9iXICBhX0QEREgC1g3cfqBWKnVDURSEuqDWr627Vq1tLVrt16+iIgLixlJEQbEq7kslBGRfA6IE0IQthED25/fHDHSIgYQw5MxM7td1zZWZc05mbg7h5uSZM+cx5xwiIhL+orwOICIiwaFCFxGJECp0EZEIoUIXEYkQKnQRkQihQhcRiRAqdAkrZtbfzLK8zlETzOx9MxvudQ4JHyp0CRoz2xNwKzOzfQGPr/YwV38zc2Z2t1cZqsM5N8g594rXOSR8qNAlaJxz9fbfgB+BXwcse93DaMOBHcC1Nfmi5qN/Y1Jj9MMmx5yZxZvZP81si//2TzOL96/rb2ZZZnafmW0zs41HcjRvZreZ2UozSz7E+rrApcAtQGczSym3fpSZrTKzPP/z9PYvb21mb5lZjpltN7Nn/csfNrPXAr6/nf/oP8b/+DMz+4uZfQ3sBTqY2fUBr7HBzH5XLsNgM1tsZrvNbL2ZDQx4rhsCthvhf56dZvaBmbX1Lzcze9rMsv3PsczMuld1H0rkUKFLTbgf6Af0Ak4G0oAHAta3AJKAVviOpseb2QmVPamZPQRcB5zlnDvUuPoQYA8wA/jA//z7v/8y4GF8R+4NgIuB7WYWDbwL/AC08+eaWpU/qN81wGigvv85soGL/K9xPfB0wH8cacAU4I9AI+BMYGMFf9bBwH3+P08z4EvgTf/q8/3f1wVoCFwObD+CvBIhVOhSE64GHnXOZTvncoBH8JVeoAedc4XOuc+B9/CV0qGYmT2Fr8gG+J/zUIYD05xzpcAbwJVmFutfdwMw1jm3wPlkOud+wPcfzvHAH51z+c65AufcV0fw553snFvhnCtxzhU7595zzq33v8bnwIfAGf5tRwITnXMfOefKnHObnXOrK3jOG4G/OedWOedKgL8CvfxH6cX4/vPoCph/m61HkFcihApdasLx+I5U9/vBv2y/nc65/PLrzaxN4ButAesb4TsC/ptzLvdQL2pmrYEBwP7x+3eABOBC/+PWwPoKvrU18IO/OKtjU7kcg8zsWzPbYWa7gAvw/UZyuAzltQX+18x2+Z9jB2BAK+fcJ8CzwHNAtpmNN7MG1cwuYUyFLjVhC75C2q+Nf9l+jf1j3Qetd879WO6N1v124hvCmGRmpx3mda/B9zM+x8x+AjbgK/T9wy6bgI4VfN8moM3+cfFy8oE6AY9bVLDNgUuY+t8rmAk8CTR3zjUC5uIr48NlqCjT75xzjQJuic65bwCcc8845/oA3fANvfyxCs8pEUaFLjXhTeABM2tmZknAQ8Br5bZ5xMzizOwMfGU943BP6Jz7DN9Qzlv+ceiKDMc3vNMr4PZb4AIzawpMAP5gZn38byx28g9hpANbgcfNrK6ZJQT8x7EYONP/20ND4N5K/uxxQDyQA5SY2SB8Q0X7vQxcb2bnmFmUmbUys64VPM8LwL1mdhKAmTX0vweAmaWaWV//UFI+UACUVZJLIpAKXWrCY0AGsBRYBizyL9vvJ3xH3VvwDY/ceIhx5IM45z4CRuA7Au8duM7M+uH7reA559xPAbfZQCYw1Dk3A/gLvrH1POBtoIl/vP3XQCd8p19mAVcEvOY0/59lIb43Tw+XMQ+4DZju/zNeBcwOWJ+O/41SIBf4nIN/m9m/3Szg78BUM9sNLAcG+Vc3AF7yP/8P+N4QfeJwuSQymSa4EC+ZWX/gNedchacdikjV6QhdRCRCqNBFRCKEhlxERCKEjtBFRCJERefZ1oikpCTXrl07r15eRCQsLVy4cJtzrllF6z
wr9Hbt2pGRkeHVy4uIhCUz++FQ6zTkIiISIVToIiIRQoUuIhIhPBtDr0hxcTFZWVkUFBR4HSVsJSQkkJycTGxsbOUbi0hECalCz8rKon79+rRr1w4zq/wb5CDOObZv305WVhbt27f3Oo6I1LBKh1zMbKJ/aqvlh1hvZvaMmWWa2dLyF0k6EgUFBTRt2lRlXk1mRtOmTfUbjkgtVZUx9MnAwMOsHwR09t9GA88fTSCV+dHR/hOpvSodcnHOfWFm7Q6zyWBgivNdQ+BbM2tkZi01BZaIRJqyMkdRaRmFJWUUlpRSWOz7WlAcsKykjMLiMopKyygq2X8rPejxOSc25+TWjYKeLxhj6K04eMqtLP+yXxS6mY3GdxRPmzZtgvDSwbVr1y7eeOMNbr755qA83+TJk8nIyODZZ58NyvOJiI9zvmLdW1hKflEJ+f6v+x/vLSqhqKSM4lJHSanva3FZGSWljuJSX9kWFJWyr7iUfcVl7CsqpaDY/9h/v7Dk4FIuLi2jpCw41746rkFCyBZ6lTnnxgPjAVJSUkLuqmC7du1i3LhxQSt0kdrMOUdBcRm7C4rJKyhmd0EJu/cVk1dQQl5BCXsKi9lXVMbe4pID5bo3oFgLigOPgv979FtU4ntcWs1yNYO46CgS46JJjPXdEmKjDzxuXCeO+Ngo4mN8t7joKOJi/Lfo6AP3E2KjiI+JPrBdfOx/78cd+N7/bh/nf67YaDtmQ6PBKPTN+Ca63S/Zvyzs3HPPPaxfv55evXpx3nnnMXbsWO666y7ef/99zIwHHniAK664gs8++4yHHnqI+vXrk5mZyYABAxg3bhxRUYd+S+K9997jscceY86cOSQlJR1Ynp6ezu23305BQQGJiYlMmjSJE044gdLSUu6++27+/e9/ExUVxahRo7j11ltZsGABt99+O/n5+cTHx/Pxxx9Tv379mtg9IoBv2GFL7j425OSTnVfI9j2FbNtTyPY9RWzLL2JbXiHb8wvZkV9EcWnlpRsXE3WgWOvE+crV9zWKRomxxMf6ijA+JvpA0cbFRFEnLoY6cdHUjY+hblwMdeKjfV/jfN8fHxtNbJQRGx1FTLTva2x0FNFRkfs+UzAKfTYwxsymAn2B3GCMnz8yZwUrt+w+6nCBuh3fgD/9+qRDrn/88cdZvnw5ixcvBmDmzJksXryYJUuWsG3bNlJTUznzzDMBXxGvXLmStm3bMnDgQN566y0uvfTSCp931qxZPPXUU8ydO5fGjRsftK5r1658+eWXxMTEMG/ePO677z5mzpzJ+PHj2bhxI4sXLyYmJoYdO3ZQVFTEFVdcwbRp00hNTWX37t0kJiYGae+IHKyguJTvt+WzPmcP67PzyczZw/rsPWzYtoeC4oOnLI2PiSKpXjxJ9eJo2TCB7q0a0LRePA0SYqmfEEP9hBgaJMbSICGG+gmxNEiIpW68r8RjovX5xmCptNDN7E2gP5BkZlnAn4BYAOfcC/hmML8A3zyNe/HNjxgRvvrqK4YOHUp0dDTNmzfnrLPOYsGCBTRo0IC0tDQ6dOgAwNChQ/nqq68qLPRPPvmEjIwMPvzwQxo0aPCL9bm5uQwfPpx169ZhZhQXFwMwb948brzxRmJifH9FTZo0YdmyZbRs2ZLU1FSACp9PpDqcc2zetY9FP+5i0Q87+e7HnazYsvvAmLEZJDdOpGOzepzasSkdm9WjQ7O6tGiQQFL9eOrGResMqxBQlbNchlay3gG3BC2R3+GOpENB+R9eM2PWrFk88sgjAEyYMAGAjh07smHDBtauXUtKSsovnufBBx9kwIABzJo1i40bN9K/f/9jnl0EYO3PeXy+JoeFP+xk0Y87yc4rBCAxNpqeyQ0ZdWYHTjq+AR2b1aN9Ul0SYqM9TiyVCalPinqtfv365OXlHXh8xhln8OKLLzJ8+HB27NjBF198wRNPPMHq1atJT0/n+++/p23btkybNo3Ro0dzySWXcMkllxz4/uXLl9O2bVueeOIJhg
wZwowZMzjppIP/o8rNzaVVq1aA76yY/c477zxefPFFBgwYcGDI5YQTTmDr1q0sWLCA1NRU8vLySExMPHAUL1KZktIyPlr5M6/8ZyPfbtgBQOsmiZzasSl92jamd5vGdG1RX8MgYUpNEKBp06acdtppdO/enUGDBjF27Fj+85//cPLJJ2NmjB07lhYtWrB69WpSU1MZM2bMgTdFA4u8vK5du/L6669z2WWXMWfOHDp27Hhg3V133cXw4cN57LHHuPDCCw8sv+GGG1i7di09e/YkNjaWUaNGMWbMGKZNm8att97Kvn37SExMZN68edSrV++Y7hcJf9v3FDJ1wSZe+/YHtuYW0KpRIvcM6sqQU1pxXIMEr+NJkHg2p2hKSoorP8HFqlWrOPHEEz3JcyQ+++wznnzySd59912vo1QoXPajHHtLs3Yx+ZuNvLtkK0WlZZzeKYlrT23LOSc2j+izPSKZmS10zv1y/BYdoYtEHOccn67J5vnP1rNg407qxkVzZVprrj21LZ2O0ymukUyFXg39+/fXm5cSckpKy3h36VZe+Hw9q3/Ko1WjRB66qBuXpiTTIEGXU64NQq7QnXM6/ekoeDWEJt7ZV1TK9IxNjP9iA5t37aNL83o8dfnJ/Prk44nVm5u1SkgVekJCAtu3b9cldKtp//XQExL0JldtsKewhElffc+kbzayI7+IPm0b8+jgkxhwwnFEaXy8VgqpQk9OTiYrK4ucnByvo4St/TMWSeRyzvH+8p94dM5KftpdwNldj+Om/h1JbdfE62jisZAq9NjYWM20I3IYG7fl86fZK/h8bQ7dWjZg3LDe9G7TuPJvlFohpApdRCpWUFzKC5+vZ9xn64mLjuJPv+7GNf3a6gNAchAVukiI+3JdDg++vZyN2/dyUc+WPHhRN5rrw0BSARW6SIjatqeQh2ev4N2lW2mfVJdXR6ZxRudmXseSEKZCFwlB7y3dyoPvLGdPQQl3nNuF353VQRfHkkqp0EVCyM78Ih58ZznvLt1Kz+SG/OOyk+ncXJ/ulKpRoYuEiA9X/MR9s5aTu6+IP5zfhRvP6qg3PeWIqNBFPJa7t5hH5qzgre82c2LLBkwZkUa34zV5iRw5FbqIhz5fm8Nd/1rCtj1F3HZOZ8YM6ERcjI7KpXpU6CIemfz19zzy7ko6H1ePCdem0iO5odeRJMyp0EVqWFmZY+wHa3jh8/Wc3605zww9RWewSFCo0EVqUFFJGXfPXMqs7zYzrF8bHrm4uyaakKBRoYvUkD2FJdz02kK+XLeNP5zfhVsGdNJVRSWoVOgiNSAnr5DrJ6ezamseYy/tyeUprb2OJBFIhS5yjH2/LZ/hE9PJyStkwrUpDOh6nNeRJEKp0EWOoSWbdnH95AUAvDm6H71aN/I4kUQyFbrIMfLJ6p+55fXvSKofx5QRfWmfVNfrSBLhVOgix8DU9B+5/+3ldGvZgInXpdKsfrzXkaQWUKGLBJFzjqfnreOZj9fR/4RmPHdVb+rG65+Z1Az9pIkESXFpGfe9tYwZC7O4PCWZv1zSg1hdXEtqkApdJAjyC0u4+fVFfL42h9vP6czvz+2sc8ylxqnQRY5Sdl4BIyYvYNXWPB4f0oMr09p4HUlqKRW6yFFYn7OH6yalsy2viJeu7cPZXZt7HUlqsSoN8JnZQDNbY2aZZnZPBevbmNmnZvadmS01swuCH1UktMzfsJ0h475hb2EpU0f3U5mL5yotdDOLBp4DBgHdgKFm1q3cZg8A051zpwBXAuOCHVQklLyzeDPXvJxO03pxzLr5NE7WB4YkBFRlyCUNyHTObQAws6nAYGBlwDYO2D/FSkNgSzBDioQK5xzPfpLJPz5aS9/2TXjxmj40qhPndSwRoGqF3grYFPA4C+hbbpuHgQ/N7FagLnBuUNKJhJDA0xJ/0+t4/n5pT+JjdB1zCR3BOkl2KDDZOZcMXAC8ama/eG4zG21mGWaWkZOTE6SXFjn2cvcVc92kdGYszOK2sz
vx9BW9VOYScqpyhL4ZCLzWZ7J/WaCRwEAA59x/zCwBSAKyAzdyzo0HxgOkpKS4amYWqVFZO/cyYvICNuTk88SlPblMl76VEFWVI/QFQGcza29mcfje9JxdbpsfgXMAzOxEIAHQIbiEveWbc7lk3DdszS3glRFpKnMJaZUeoTvnSsxsDPABEA1MdM6tMLNHgQzn3GzgTuAlM7sD3xuk1znndAQuYe2rddv43asZNKoTx+s39KVL8/peRxI5rCp9sMg5NxeYW27ZQwH3VwKnBTeaiHdmL9nCndMX0yGpHq+MSKNFwwSvI4lUSp8UFSln0tff88iclaS1a8JL16bQsE6s15FEqkSFLuLnnGPsB2t4/rP1nN+tOc8MPYWEWJ3JIuFDhS4ClJSWca//HPOhaW348+CTiNGlbyXMqNCl1ttXVMqYNxbx8epsbjunM3fo0rcSplToUqvt2lvEiMkL+G7TLv78m+5c06+t15FEqk2FLrXW5l37GD4xnR+372XcVb0Z1KOl15FEjooKXWql1T/tZvjEdPYWlTJlZBr9OjT1OpLIUVOhS60zf8N2bpiSQZ24aGbceCpdWzSo/JtEwoAKXWqVfy/fym1TF5PcOJEpI9JIblzH60giQaNCl1rj1W9/4KF3ltOrdSMmDk+lcV1dx1wiiwpdIp5zjqc/Wsszn2RyTtfjePaq3iTG6QNDEnlU6BLRSkrLeODt5UxdsInLU5L56yU99IEhiVgqdIlY+YUljHljEZ+uyWHMgE7ceX4XfWBIIpoKXSJSdl4BIyYvYOWW3fz1kh5c1beN15FEjjkVukSczOw8hk9cwI78IiYMT+Hsrs29jiRSI1ToElHSv9/BqCkZxEYb037Xj57JjbyOJFJjVOgSMeYs2cKd05eQ3CSRV65Po3UTnWMutYsKXcKec46XvtzAX+euJrVdY8Zfk6JzzKVWUqFLWCsrczz67komf7ORC3u05B+Xn6xJKaTWUqFL2CouLeOPM5bw9uItjDitPQ9ceCJRUTotUWovFbqEpYLiUm553TcpxR//3wnc3L+jzjGXWk+FLmEnr6CYG17JIH3jDk1KIRJAhS5hZfueQoZPSmf11jz+eUUvBvdq5XUkkZChQpewsWXXPoa9PJ/NO/cx/to++sCQSDkqdAkLG3L2MGzCfPIKSnh1ZF/S2jfxOpJIyFGhS8hbvjmX4RPTAXhzdD+6t2rocSKR0KRCl5D2xdocbnptIY3qxPHqyDQ6NKvndSSRkKVCl5A1I2MT9761jM7N6zP5+lSaN0jwOpJISFOhS8hxzvHMx5k8PW8tp3dK4vlhvamfEOt1LJGQp0KXkFJcWsYDs5YzLWMTQ3q34vEhPYmL0QxDIlWhQpeQkV9Ywi1vLOKzNTncdnYn7jhPMwyJHAkVuoSE/TMMrdqax9+G9GBommYYEjlSVfpd1swGmtkaM8s0s3sOsc3lZrbSzFaY2RvBjSmRLDN7D0PGfcP67HxeuraPylykmio9QjezaOA54DwgC1hgZrOdcysDtukM3Auc5pzbaWbHHavAElkWbNzBDa9ohiGRYKjKEXoakOmc2+CcKwKmAoPLbTMKeM45txPAOZcd3JgSid5bupWrJ8ynad043rrpNJW5yFGqSqG3AjYFPM7yLwvUBehiZl+b2bdmNrCiJzKz0WaWYWYZOTk51UssEWHClxsY8+YierZqyMybfkWbppouTuRoBetN0RigM9AfSAa+MLMezrldgRs558YD4wFSUlJckF5bwkhpmeOx91Yy6euNDOregqev6KUZhkSCpCqFvhloHfA42b8sUBYw3zlXDHxvZmvxFfyCoKSUiFBQXMrvpy7m3yt+YuTp7bn/As0wJBJMVRlyWQB0NrP2ZhYHXAnMLrfN2/iOzjGzJHxDMBuCmFPC3I78Iq6eMJ8PVv7Egxd148GLuqnMRYKs0iN051yJmY0BPgCigYnOuRVm9iiQ4Zyb7V93vpmtBEqBPzrnth/L4BI+tu
bu4+qX5pO1ax/PXdWbC3q09DqSSEQy57wZyk5JSXEZGRmevLbUnG17Crn8xf+QvbuQSdenktpO1zEXORpmttA5l1LROl0kQ46Z3L3FXPNyOlt27WPidSpzkWNNhS7HRH5hCddNTmd99h7GX5OiGYZEaoCu5SJBV1BcyqgpGSzNyuW5q3pzZpdmXkcSqRV0hC5BVVxaxi2vL+I/G7bz5GU9Gdi9hdeRRGoNFboETWmZ445pi/l4dTZ/HtydS05J9jqSSK2iQpegKCtz3PvWUt5dupV7B3VlWL+2XkcSqXVU6HLUnHP8+b2VTM/I4razO/G7szp6HUmkVlKhy1F7et46Jn29kRGnteeO87p4HUek1lKhy1GZ8OUGnvl4HZenJPPgRSdqyjgRD6nQpdqmpv/IY++t4sIeLfnbkJ4qcxGPqdClWuYs2cK9s5ZxVpdmPH1FL6J1oS0Rz6nQ5Yh9ujqbO6YtJrVtE14Y1oe4GP0YiYQC/UuUI/Lthu3c+NpCurasz4TrUkiM0+QUIqFChS5VtjRrFze8kkHrJnWYMqIvDRJivY4kIgFU6FIla3/O49qJ6TSqE8trI/vSpG6c15FEpBwVulTq590FDJ+YTlx0FK/f0JcWDRO8jiQiFdDVFuWw9hX5rpyYu6+YGTeeStumdb2OJCKHoEKXQyorc9w5YzHLNufy0jUpnHR8Q68jichhaMhFDukfH61h7rKfuP+CEzm3W3Ov44hIJVToUqGZC7N47tP1DE1rzcjT23sdR0SqQIUuv5D+/Q7ueWspv+rYlEcHd9dH+kXChApdDvLD9nx+92oGrRvX4fmr+xAbrR8RkXChf61yQO6+Yka+koEDXr4ulYZ19MEhkXCiQhcASkrLGPPGIn7Yns8Lw/rQPkmnJ4qEG522KAA89t4qvly3jbGX9qRfh6ZexxGRatARujB7yRYmf7ORkae35/KU1l7HEZFqUqHXcpnZedwzcykpbRtzz6CuXscRkaOgQq/F8gtLuOm1RSTGRvPsVb11RotImNMYei3lnOP+WctYn7OHV0fqglsikUCHZLXUa/N/5O3FW/if87pwWqckr+OISBCo0GuhJZt28ec5KxlwQjNu7t/J6zgiEiRVKnQzG2hma8ws08zuOcx2vzUzZ2YpwYsowbQzv4ibX19Es/rxPH1FL6I0ubNIxKi00M0sGngOGAR0A4aaWbcKtqsP3A7MD3ZICY6yMscd0xeTnVfAuKt706iOZh0SiSRVOUJPAzKdcxucc0XAVGBwBdv9Gfg7UBDEfBJE4z7L5LM1OTx0UTdObt3I6zgiEmRVKfRWwKaAx1n+ZQeYWW+gtXPuvcM9kZmNNrMMM8vIyck54rBSfV9nbuOpj9YyuNfxDOvX1us4InIMHPWbomYWBTwF3FnZts658c65FOdcSrNmzY72paWKNu/ax61vfkfHZvX46yU9dDlckQhVlULfDAR+HjzZv2y/+kB34DMz2wj0A2brjdHQUFBcys2vLaSopIwXrulD3Xh99EAkUlWl0BcAnc2svZnFAVcCs/evdM7lOueSnHPtnHPtgG+Bi51zGccksRyRR+asYElWLk9edjIdm9XzOo6IHEOVFrpzrgQYA3wArAKmO+dWmNmjZnbxsQ4o1TdtwY+8mb6Jm/p3ZGD3Fl7HEZFjrEq/fzvn5gJzyy176BDb9j/6WHK0lmbt4sF3VnB6pyT+cP4JXscRkRqgT4pGoB35Rdz02iKa1YvnmaGnEK0PD4nUCnqHLMKUljlue/M7cvIKmXHjqTSpqw8PidQWKvQI89RHa/gqcxuPD+mhDw+J1DIacokgH674iec+Xc+Vqa25Mq2N13FEpIap0CPE+pw93Dl9CT2TG/LwxSd5HUdEPKBCjwC5e4u54ZUM4mKiGHd1bxJio72OJCIe0Bh6mCspLeOWNxaRtXMvb4zqR3LjOl5HEhGPqNDD3GPvreKrzG38/bc9SG3XxOs4IuIhDbmEsanpPzL5m42MOK
09V6TqTVCR2k6FHqbmb9jOg+8s58wuzbjvgq5exxGREKBCD0ObduzlptcX0bpxHf5v6CnEROuvUURU6GFnT2EJo6ZkUFJaxoThKTRMjPU6koiECL0pGkbKyhx3TFvMuuw9TL4+lQ66HK6IBNARehh56qO1fLTyZx648ETO6KwZn0TkYCr0MDF7yRae/TSTK1Nbc92v2nkdR0RCkAo9DCzLyuWPM5aQ1q4Jjw7urjlBRaRCKvQQl727gFFTMkiqF8/zw3oTF6O/MhGpmN4UDWEFxaWMfnUhufuKmXnTr2haL97rSCISwlToIco5x/2zlrN40y5eGNabbsc38DqSiIQ4/f4eoiZ8+T0zF2Vxx7ldGNi9pddxRCQMqNBD0Kdrsvnb+6u4oEcLbj27k9dxRCRMqNBDTGb2Hm574zu6tmjAk5edTJQmeBaRKlKhh5DcvcWMmuKbqOKl4SnUidNbHCJSdWqMEFFSWsaYN30TVbw5qh+tGiV6HUlEwowKPUT8Ze4qvly3jbG/7UmKJqoQkWrQkEsIeDP9RyZ9vZGRp7fn8tTWXscRkTClQvfYtxu28+DbyzmrSzPuHaSJKkSk+lToHtq0Yy83vbaQtk3r8H9XaaIKETk6ahCP5BUUM/KVBZQ5mDA8lQYJmqhCRI6O3hT1QGmZ4/dTF7M+J58pI9Jon1TX60giEgF0hO6BJz5Yw8ers/nTr7txWqckr+OISISoUqGb2UAzW2NmmWZ2TwXr/8fMVprZUjP72MzaBj9qZJj1XRYvfL6eq/u24Zp+2k0iEjyVFrqZRQPPAYOAbsBQM+tWbrPvgBTnXE/gX8DYYAeNBIt+3MndM5fRr0MTHr74JE1UISJBVZUj9DQg0zm3wTlXBEwFBgdu4Jz71Dm31//wWyA5uDHD3+qfdjNi8gJaNkzg+av7EKszWkQkyKrSKq2ATQGPs/zLDmUk8H5FK8xstJllmFlGTk5O1VOGuQ05exg2IZ2EmGheG9mXxnXjvI4kIhEoqIeJZjYMSAGeqGi9c268cy7FOZfSrFntmLV+0469XD1hPs45XruhL62b1PE6kohEqKqctrgZCPw8erJ/2UHM7FzgfuAs51xhcOKFt+zdBQx7eT75hSVMHX0qneunL8MAAAlrSURBVI6r53UkEYlgVTlCXwB0NrP2ZhYHXAnMDtzAzE4BXgQuds5lBz9m+NmRX8TVE+azLa+QySPSNIWciBxzlRa6c64EGAN8AKwCpjvnVpjZo2Z2sX+zJ4B6wAwzW2xmsw/xdLVC7r5irnl5Pj/u2MuE4an0btPY60giUgtU6ZOizrm5wNxyyx4KuH9ukHOFrfzCEkZMXsDan/MYf20Kp3Zs6nUkEakldO5cEBUUlzJqSgbf/biTZ648hQEnHOd1JBGpRXQtlyApKS3j1je/45v123nq8pMZ1KOl15FEpJbREXoQOOe4b9YyPlr5M49cfBJDeutzVSJS81ToQfD4v1czPSOL287pzPBftfM6jojUUir0ozT+i/W8+PkGru7bhjvO7ex1HBGpxVToR+FfC7P469zVXNizJY8O7q6LbYmIp1To1TRv5c/cPXMpp3dK4qnLTyY6SmUuIt5SoVfD/A3bueWNRXQ/vgEvXNOH+JhoryOJiKjQj9TKLbu5YUoGrRonMun6NOrF68xPEQkNKvQjsGJLLsMnpVMvPoZXR/aliS6DKyIhRIVeRTMyNjFk3DdEGUwZkUarRoleRxIROYjGCypRUFzKI3NW8Gb6Jk7t0JT/u+oUkurFex1LROQXVOiHsWnHXm5+fRHLNudyU/+O3HleF2I0dZyIhCgV+iF8tiab309bTGmpY/w1fTj/pBZeRxIROSwVejllZY5nPlnH/368jhOa1+eFYX1ol1TX61giIpVSoQfYvqeQ/5m+hM/X5jDklFb85ZIeJMbpHHMRCQ8qdL8v1uZw54wl5O4r5rHfdOfqvm30UX4RCSu1vtALS0p58oM1vPTl93Q+rh5TRq
RxYkvN/yki4adWF/r6nD3c9uZ3rNiym2H92vDAhd1IiNUQi4iEp1pZ6M45pmds4uHZK4mPjdJZLCISEWpdoefuLebeWUuZu+wnftWxKU9d3osWDRO8jiUictRqTaGXlJYxPSOLp+etZWd+EfcM6sroMzoQpcveikiEiPhCd87xyepsHn9/Neuy99CnbWMmDk+lR3JDr6OJiARVRBf6sqxc/jJ3Jd9u2EH7pLq8MKw3/++kFjodUUQiUkQWetbOvTz5wRreXryFJnXjeOTik7iqbxtidR0WEYlgEVXoBcWlPPtJJuO/3IABN/XvyE39O9IgIdbraCIix1zEFPrXmdu4f9YyNm7fy296Hc9dA7tyvK5ZLiK1SNgX+o78Ih57byVvLdpMu6Z1eP2GvpzWKcnrWCIiNS5sC905x1uLNvPYeyvJKyhhzIBOjDm7kz7pKSK1VlgW+sZt+dz/9jK+ztxO7zaN+NuQnpzQor7XsUREPBV2hT49YxMPvr2cuOgoHvtNd65Ka6MPB4mIUMVJos1soJmtMbNMM7ungvXxZjbNv36+mbULdtD92ifV5ZwTj2PenWcxrF9blbmIiF+lR+hmFg08B5wHZAELzGy2c25lwGYjgZ3OuU5mdiXwd+CKYxE4tV0TUts1ORZPLSIS1qpyhJ4GZDrnNjjnioCpwOBy2wwGXvHf/xdwjunjmCIiNaoqhd4K2BTwOMu/rMJtnHMlQC7QNBgBRUSkamr0s/BmNtrMMswsIycnpyZfWkQk4lWl0DcDrQMeJ/uXVbiNmcUADYHt5Z/IOTfeOZfinEtp1qxZ9RKLiEiFqlLoC4DOZtbezOKAK4HZ5baZDQz3378U+MQ554IXU0REKlPpWS7OuRIzGwN8AEQDE51zK8zsUSDDOTcbeBl41cwygR34Sl9ERGpQlT5Y5JybC8wtt+yhgPsFwGXBjSYiIkdCFwgXEYkQ5tVQt5nlAD9U89uTgG1BjBNMylY9ylY9ylY94ZytrXOuwrNKPCv0o2FmGc65FK9zVETZqkfZqkfZqidSs2nIRUQkQqjQRUQiRLgW+nivAxyGslWPslWPslVPRGYLyzF0ERH5pXA9QhcRkXJU6CIiESLsCr2y2ZO8ZGYbzWyZmS02swyPs0w0s2wzWx6wrImZfWRm6/xfG4dQtofNbLN/3y02sws8ytbazD41s5VmtsLMbvcv93zfHSab5/vOzBLMLN3MlvizPeJf3t4/i1mmf1azuBDKNtnMvg/Yb71qOltAxmgz+87M3vU/rt5+c86FzQ3ftWTWAx2AOGAJ0M3rXAH5NgJJXufwZzkT6A0sD1g2FrjHf/8e4O8hlO1h4A8hsN9aAr399+sDa4FuobDvDpPN830HGFDPfz8WmA/0A6YDV/qXvwDcFELZJgOXev0z58/1P8AbwLv+x9Xab+F2hF6V2ZMEcM59ge9CaYECZ5Z6BfhNjYbyO0S2kOCc2+qcW+S/nweswjeBi+f77jDZPOd89vgfxvpvDjgb3yxm4N1+O1S2kGBmycCFwAT/Y6Oa+y3cCr0qsyd5yQEfmtlCMxvtdZgKNHfObfXf/wlo7mWYCowxs6X+IRlPhoMC+Sc7PwXfEV1I7bty2SAE9p1/2GAxkA18hO+36V3ON4sZePjvtXw259z+/fYX/3572szivcgG/BO4CyjzP25KNfdbuBV6qDvdOdcbGATcYmZneh3oUJzvd7mQOUoBngc6Ar2ArcA/vAxjZvWAmcDvnXO7A9d5ve8qyBYS+845V+qc64VvEpw0oKsXOSpSPpuZdQfuxZcxFWgC3F3TuczsIiDbObcwGM8XboVeldmTPOOc2+z/mg3MwvdDHUp+NrOWAP6v2R7nOcA597P/H10Z8BIe7jszi8VXmK87597yLw6JfVdRtlDad/48u4BPgVOBRv5ZzCAE/r0GZBvoH8JyzrlCYBLe7LfTgIvNbCO+IeSzgf+lmvst3Aq9KrMnecLM6ppZ/f33gfOB5Yf/rh
oXOLPUcOAdD7McZH9Z+l2CR/vOP375MrDKOfdUwCrP992hsoXCvjOzZmbWyH8/ETgP3xj/p/hmMQPv9ltF2VYH/Adt+Maoa3y/Oefudc4lO+fa4euzT5xzV1Pd/eb1u7vVeDf4Anzv7q8H7vc6T0CuDvjOulkCrPA6G/Amvl+/i/GNwY3ENzb3MbAOmAc0CaFsrwLLgKX4yrOlR9lOxzecshRY7L9dEAr77jDZPN93QE/gO3+G5cBD/uUdgHQgE5gBxIdQtk/8+2058Br+M2G8ugH9+e9ZLtXab/rov4hIhAi3IRcRETkEFbqISIRQoYuIRAgVuohIhFChi4hECBW6iEiEUKGLiESI/w80/V1aMcaJBwAAAABJRU5ErkJggg==\n" | |
},
"metadata": {
"needs_background": "light"
}
}
],
"source": [
"m2m_model = NextCharLSTM(len(ALL_CHARS_NUMS), embedding_dim, hidden_dim)\n",
"m2m_model.load_state_dict(torch.load('best_model_m2m.pt'))\n",
"\n",
"# evaluate on the validation data for every k from 1 to the vocabulary size\n",
"valid_loader = DataLoader(valid_data, 1, shuffle=False)\n",
"k = [i + 1 for i in range(len(ALL_CHARS_NUMS))]\n",
"top_k = topk_accuracy(k, m2m_model, valid_loader, many_to_one=False)\n",
"\n",
"# plot the results against k (k starts at 1, so pass it explicitly as x-values)\n",
"plt.title(\"Top-k Accuracies\")\n",
"accuracies, = plt.plot(k, np.array(top_k), label='top-k acc')\n",
"plt.xlabel('k')\n",
"plt.legend();"
],
"id": "LU0xl0EJLYPE"
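},
{
"cell_type": "markdown",
"source": [
"## 8. Create Text (probabilistic)\n",
"\n",
"Let's create a function that feeds a seed text into the model and generates new text one character at a time. At each step, the next character is sampled from the `top_k_characters` most likely candidates, with probabilities proportional to their softmax scores. With `top_k_characters=1` this is greedy decoding: the most likely character is always chosen."
],
"metadata": {
"id": "HP1urFvRr_Hb"
},
"id": "HP1urFvRr_Hb"
},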
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"id": "ac5d1a15" | |
}, | |
"outputs": [], | |
"source": [ | |
"from torch.distributions import Categorical\n", |
"\n", |
"def text_generate_probabilistic(seed_text, encoder, lstm_model, text_length, top_k_characters=1):\n", |
"    # set up model for evaluation (switches layers such as dropout to inference mode)\n", |
"    lstm_model.eval()\n", |
"    result = encoder(seed_text.lower())\n", |
"\n", |
"    # disable gradient computations\n", |
"    with torch.no_grad():\n", |
"        # generate text_length new characters, one per iteration\n", |
"        for i in range(text_length):\n", |
"            logits = lstm_model(result.view(1, -1))\n", |
"            # softmax turns the logits into probabilities over all characters\n", |
"            softmax = torch.nn.functional.softmax(logits, dim=2)\n", |
"            # keep only the top-k most likely characters at each position\n", |
"            topk = torch.topk(softmax, top_k_characters, dim=2)\n", |
"\n", |
"            # sample among the top-k characters of the last position;\n", |
"            # Categorical renormalizes the k probabilities so they sum to 1\n", |
"            categorical = Categorical(probs=topk.values[0, -1])\n", |
"            sample = categorical.sample()\n", |
"\n", |
"            # concatenate the sampled character index to the sequence\n", |
"            result = torch.cat((result, topk.indices[0, -1, sample].view(-1)))\n", |
"\n", |
"    return encoder(result)" |
], | |
"id": "ac5d1a15" | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"text_generate_probabilistic(start_text, encoder, m2m_model, 500, top_k_characters=1)" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 89 | |
}, | |
"id": "hQnHLzVbprjJ", | |
"outputId": "f0d62c7f-9c18-4c9f-b800-3bc56b86e16e" | |
}, | |
"id": "hQnHLzVbprjJ", | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"'america is not going to happen anymore. we have to be smart that were going to have to do it. i said what they do it is the worst things that were going to be so much more than they want to do it. i dont want to be the way it was a tough and i was going to have to make our country great again. we have to be so much money that we have to do it and i said what they do it is the worst things that were going to be so much more than they want to do it. i dont want to be the way it was a tough and i was going to'" | |
], | |
"application/vnd.google.colaboratory.intrinsic+json": { | |
"type": "string" | |
} | |
}, | |
"metadata": {}, | |
"execution_count": 22 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 89 | |
}, | |
"id": "-vosoMjFbtRM", | |
"outputId": "444cb3d8-a30b-4c68-8b16-9fa5ddb99fe5" | |
}, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"'america is an amazing all the things outside of this country is an annotance.i will start all of it. because i will tell you that. were going to be six long. and we also done. we have to be true what they did a great thing and i will be saying whats had to do this. i would have had the biggest place.were going to waik number ones whine was going to get into the most incredible. the people that want to be strong. its all over the country. its got to be obama adminisared that have been. what the pollstaria a'" | |
], | |
"application/vnd.google.colaboratory.intrinsic+json": { | |
"type": "string" | |
} | |
}, | |
"metadata": {}, | |
"execution_count": 23 | |
} | |
], | |
"source": [ | |
"text_generate_probabilistic(start_text, encoder, m2m_model, 500, top_k_characters=4)" | |
], | |
"id": "-vosoMjFbtRM" | |
} | |
], | |
"metadata": { | |
"accelerator": "GPU", | |
"colab": { | |
"collapsed_sections": [], | |
"machine_shape": "hm", | |
"name": "text_processing_lstm.ipynb", | |
"provenance": [], | |
"include_colab_link": true | |
}, | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.8.11" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 5 | |
} |
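`text_generate_probabilistic` keeps the `top_k_characters` most likely characters at the last position and samples among them with `torch.distributions.Categorical`. Below is a minimal, dependency-free sketch of that top-k sampling step using plain Python lists instead of tensors; the helper names `softmax` and `sample_top_k` are hypothetical and not part of the notebook. It also illustrates why the `top_k_characters=1` run above is repetitive: with k=1 only the most likely character survives, so generation becomes greedy and deterministic.

```python
import math
import random

def softmax(logits):
    """Convert raw scores into probabilities (subtract the max for numerical stability)."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def sample_top_k(logits, k, rng=random):
    """Sample one index, restricted to the k most probable entries."""
    probs = softmax(logits)
    # indices of the k largest probabilities, most likely first
    top = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:k]
    # random.choices renormalizes the weights, like Categorical does in the notebook
    return rng.choices(top, weights=[probs[i] for i in top], k=1)[0]

logits = [0.1, 2.0, 1.5, -1.0]  # toy scores for a 4-character vocabulary
print(sample_top_k(logits, 1))  # prints 1: with k=1 the argmax always wins
rng = random.Random(0)  # seeded for reproducibility
print([sample_top_k(logits, 2, rng) for _ in range(10)])  # only 1s and 2s
```

Larger k trades coherence for variety, which matches the difference between the `top_k_characters=1` and `top_k_characters=4` outputs above.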