{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"id": "3c50ef8c", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"^C\n", | |
"Requirement already satisfied: datasets in c:\\users\\maxim\\anaconda3\\lib\\site-packages (2.8.0)\n", | |
"Requirement already satisfied: pyarrow>=6.0.0 in c:\\users\\maxim\\anaconda3\\lib\\site-packages (from datasets) (10.0.1)\n", | |
"Requirement already satisfied: huggingface-hub<1.0.0,>=0.2.0 in c:\\users\\maxim\\anaconda3\\lib\\site-packages (from datasets) (0.11.1)\n", | |
"Requirement already satisfied: pandas in c:\\users\\maxim\\anaconda3\\lib\\site-packages (from datasets) (1.4.4)\n", | |
"Requirement already satisfied: packaging in c:\\users\\maxim\\anaconda3\\lib\\site-packages (from datasets) (21.3)\n", | |
"Requirement already satisfied: responses<0.19 in c:\\users\\maxim\\anaconda3\\lib\\site-packages (from datasets) (0.18.0)\n", | |
"Requirement already satisfied: multiprocess in c:\\users\\maxim\\anaconda3\\lib\\site-packages (from datasets) (0.70.14)\n", | |
"Requirement already satisfied: fsspec[http]>=2021.11.1 in c:\\users\\maxim\\anaconda3\\lib\\site-packages (from datasets) (2022.7.1)\n", | |
"Requirement already satisfied: dill<0.3.7 in c:\\users\\maxim\\anaconda3\\lib\\site-packages (from datasets) (0.3.6)\n", | |
"Requirement already satisfied: aiohttp in c:\\users\\maxim\\anaconda3\\lib\\site-packages (from datasets) (3.8.3)\n", | |
"Requirement already satisfied: tqdm>=4.62.1 in c:\\users\\maxim\\anaconda3\\lib\\site-packages (from datasets) (4.64.1)\n", | |
"Requirement already satisfied: requests>=2.19.0 in c:\\users\\maxim\\anaconda3\\lib\\site-packages (from datasets) (2.28.1)\n", | |
"Requirement already satisfied: xxhash in c:\\users\\maxim\\anaconda3\\lib\\site-packages (from datasets) (3.2.0)\n", | |
"Requirement already satisfied: numpy>=1.17 in c:\\users\\maxim\\anaconda3\\lib\\site-packages (from datasets) (1.21.5)\n", | |
"Requirement already satisfied: pyyaml>=5.1 in c:\\users\\maxim\\anaconda3\\lib\\site-packages (from datasets) (6.0)\n", | |
"Requirement already satisfied: typing-extensions>=3.7.4.3 in c:\\users\\maxim\\anaconda3\\lib\\site-packages (from huggingface-hub<1.0.0,>=0.2.0->datasets) (4.3.0)\n", | |
"Requirement already satisfied: filelock in c:\\users\\maxim\\anaconda3\\lib\\site-packages (from huggingface-hub<1.0.0,>=0.2.0->datasets) (3.6.0)\n", | |
"Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in c:\\users\\maxim\\anaconda3\\lib\\site-packages (from packaging->datasets) (3.0.9)\n", | |
"Requirement already satisfied: idna<4,>=2.5 in c:\\users\\maxim\\anaconda3\\lib\\site-packages (from requests>=2.19.0->datasets) (3.3)\n", | |
"Requirement already satisfied: charset-normalizer<3,>=2 in c:\\users\\maxim\\anaconda3\\lib\\site-packages (from requests>=2.19.0->datasets) (2.0.4)\n", | |
"Requirement already satisfied: urllib3<1.27,>=1.21.1 in c:\\users\\maxim\\anaconda3\\lib\\site-packages (from requests>=2.19.0->datasets) (1.26.11)\n", | |
"Requirement already satisfied: certifi>=2017.4.17 in c:\\users\\maxim\\anaconda3\\lib\\site-packages (from requests>=2.19.0->datasets) (2022.9.14)\n", | |
"Requirement already satisfied: colorama in c:\\users\\maxim\\anaconda3\\lib\\site-packages (from tqdm>=4.62.1->datasets) (0.4.5)\n", | |
"Requirement already satisfied: yarl<2.0,>=1.0 in c:\\users\\maxim\\anaconda3\\lib\\site-packages (from aiohttp->datasets) (1.8.2)\n", | |
"Requirement already satisfied: aiosignal>=1.1.2 in c:\\users\\maxim\\anaconda3\\lib\\site-packages (from aiohttp->datasets) (1.3.1)\n", | |
"Requirement already satisfied: frozenlist>=1.1.1 in c:\\users\\maxim\\anaconda3\\lib\\site-packages (from aiohttp->datasets) (1.3.3)\n", | |
"Requirement already satisfied: async-timeout<5.0,>=4.0.0a3 in c:\\users\\maxim\\anaconda3\\lib\\site-packages (from aiohttp->datasets) (4.0.2)\n", | |
"Requirement already satisfied: attrs>=17.3.0 in c:\\users\\maxim\\anaconda3\\lib\\site-packages (from aiohttp->datasets) (21.4.0)\n", | |
"Requirement already satisfied: multidict<7.0,>=4.5 in c:\\users\\maxim\\anaconda3\\lib\\site-packages (from aiohttp->datasets) (6.0.4)\n", | |
"Requirement already satisfied: python-dateutil>=2.8.1 in c:\\users\\maxim\\anaconda3\\lib\\site-packages (from pandas->datasets) (2.8.2)\n", | |
"Requirement already satisfied: pytz>=2020.1 in c:\\users\\maxim\\anaconda3\\lib\\site-packages (from pandas->datasets) (2022.1)\n", | |
"Requirement already satisfied: six>=1.5 in c:\\users\\maxim\\anaconda3\\lib\\site-packages (from python-dateutil>=2.8.1->pandas->datasets) (1.16.0)\n" | |
] | |
} | |
], | |
"source": [ | |
"!pip install --pre torch torchtext --force-reinstall --index-url https://download.pytorch.org/whl/nightly/cu117\n", | |
"!pip install datasets" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "40ac2394", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import re\n", | |
"\n", | |
"\n", | |
"def truncate_string(string, length):\n", | |
" if len(string) > length:\n", | |
" return string[:length]\n", | |
" else:\n", | |
" return string\n", | |
"\n", | |
"\n", | |
"def remove_docstring_from_python_string(code):\n", | |
" # Use a regular expression to find all comments and docstrings\n", | |
" comments_and_docstrings = re.compile(r\"(\\\"\\\"\\\".*?\\\"\\\"\\\")|('''.*?''')\", re.DOTALL)\n", | |
"\n", | |
" # Remove the comments and docstrings from the code\n", | |
" cleaned_code = comments_and_docstrings.sub(\"\", code)\n", | |
"\n", | |
" return cleaned_code\n" | |
] | |
}, | |
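{
"cell_type": "code",
"execution_count": null,
"id": "a1f3c9d0",
"metadata": {},
"outputs": [],
"source": [
"# Illustrative check of the helpers above (assumes the previous cell has run).\n",
"# Note that the regex strips *any* triple-quoted string literal, not only docstrings.\n",
"sample = 'def add(a, b):\\n    \"\"\"Return the sum of a and b.\"\"\"\\n    return a + b\\n'\n",
"print(remove_docstring_from_python_string(sample))\n",
"print(truncate_string('hello world', 5))  # -> 'hello'\n"
]
},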
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "0244a970", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import csv\n", | |
"from datasets import load_dataset\n", | |
"\n", | |
"\n", | |
"def save_list_of_tuples_to_csv(filename, list_of_tuples, x_header='x', y_header='y'):\n", | |
" with open(filename, 'w', newline='', encoding='utf-8') as csvfile:\n", | |
" csv.QUOTE_ALL = True\n", | |
" writer = csv.writer(csvfile, delimiter=',', quotechar='|', escapechar='%')\n", | |
" writer.writerow([x_header, y_header])\n", | |
" for x, y in list_of_tuples:\n", | |
" writer.writerow([x, y])\n", | |
"\n", | |
"\n", | |
"def load_list_of_tuples_from_csv(filename):\n", | |
" dataset = []\n", | |
" with open(filename, newline='', encoding='utf-8') as csvfile:\n", | |
" csv.QUOTE_ALL = True\n", | |
" reader = csv.reader(csvfile, delimiter=',', quotechar='|', escapechar='%')\n", | |
" next(reader) # Skip the header row\n", | |
" for row in reader:\n", | |
" if len(row) > 1:\n", | |
" x = row[0]\n", | |
" y = row[1]\n", | |
" dataset.append((x, y))\n", | |
" return dataset\n", | |
"\n", | |
"\n", | |
"def generate_csv_dataset_from_huggingface(dataset, dataset_filter, x_header,\n", | |
" y_header, train_count, validation_count, test_count, train_csv_filename,\n", | |
" validation_csv_filename, test_csv_filename,\n", | |
" prepare_x_function=None, prepare_y_function=None):\n", | |
" if train_count > 0:\n", | |
" train_data = []\n", | |
" train_dataset_iter = load_dataset(dataset, dataset_filter, streaming=True, split=\"train\")\n", | |
" t = 0\n", | |
" for i in iter(train_dataset_iter):\n", | |
" if t > train_count:\n", | |
" break\n", | |
"\n", | |
" if prepare_x_function is not None:\n", | |
" x = prepare_x_function(i[x_header])\n", | |
" else:\n", | |
" x = i[x_header]\n", | |
"\n", | |
" if prepare_y_function is not None:\n", | |
" y = prepare_y_function(i[y_header])\n", | |
" else:\n", | |
" y = i[y_header]\n", | |
"\n", | |
" train_data.append((x, y))\n", | |
" t += 1\n", | |
"\n", | |
" save_list_of_tuples_to_csv(\n", | |
" train_csv_filename, train_data, x_header, y_header)\n", | |
"\n", | |
" if validation_count > 0:\n", | |
" validation_data = []\n", | |
" validation_dataset_iter = load_dataset(dataset, dataset_filter, streaming=True, split=\"validation\")\n", | |
" t = 0\n", | |
" for i in iter(validation_dataset_iter):\n", | |
" if t > validation_count:\n", | |
" break\n", | |
"\n", | |
" if prepare_x_function is not None:\n", | |
" x = prepare_x_function(i[x_header])\n", | |
" else:\n", | |
" x = i[x_header]\n", | |
"\n", | |
" if prepare_y_function is not None:\n", | |
" y = prepare_y_function(i[y_header])\n", | |
" else:\n", | |
" y = i[y_header]\n", | |
"\n", | |
" validation_data.append((x, y))\n", | |
" t += 1\n", | |
"\n", | |
" save_list_of_tuples_to_csv(\n", | |
" validation_csv_filename, validation_data, x_header,\n", | |
" y_header)\n", | |
"\n", | |
" if test_count > 0:\n", | |
" test_data = []\n", | |
" test_dataset_iter = load_dataset(dataset, dataset_filter, streaming=True, split=\"test\")\n", | |
" t = 0\n", | |
" for i in iter(test_dataset_iter):\n", | |
" if t > test_count:\n", | |
" break\n", | |
"\n", | |
" if prepare_x_function is not None:\n", | |
" x = prepare_x_function(i[x_header])\n", | |
" else:\n", | |
" x = i[x_header]\n", | |
"\n", | |
" if prepare_y_function is not None:\n", | |
" y = prepare_y_function(i[y_header])\n", | |
" else:\n", | |
" y = i[y_header]\n", | |
"\n", | |
" test_data.append((x, y))\n", | |
" t += 1\n", | |
"\n", | |
" save_list_of_tuples_to_csv(\n", | |
" test_csv_filename, test_data, x_header, y_header)\n", | |
"\n", | |
"\n", | |
"def load_csv_data_truncated(train_csv_filename, validation_csv_filename, test_csv_filename, max_sequence_length):\n", | |
" train_data = load_list_of_tuples_from_csv(train_csv_filename)\n", | |
" validation_data = load_list_of_tuples_from_csv(validation_csv_filename)\n", | |
" test_data = load_list_of_tuples_from_csv(test_csv_filename)\n", | |
" cleaned_train_data = []\n", | |
"\n", | |
" for desc, code in train_data:\n", | |
" description = truncate_string(desc, max_sequence_length)\n", | |
" function = truncate_string(code, max_sequence_length)\n", | |
" cleaned_train_data.append((description, function))\n", | |
"\n", | |
" cleaned_validation_data = []\n", | |
"\n", | |
" for desc, code in validation_data:\n", | |
" description = truncate_string(desc, max_sequence_length)\n", | |
" function = truncate_string(code, max_sequence_length)\n", | |
" cleaned_validation_data.append((description, function))\n", | |
"\n", | |
" cleaned_test_data = []\n", | |
"\n", | |
" for desc, code in test_data:\n", | |
" description = truncate_string(desc, max_sequence_length)\n", | |
" function = truncate_string(code, max_sequence_length)\n", | |
" cleaned_test_data.append((description, function))\n", | |
"\n", | |
" return cleaned_train_data, cleaned_validation_data, cleaned_test_data" | |
] | |
}, | |
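{
"cell_type": "code",
"execution_count": null,
"id": "b2e4d8a1",
"metadata": {},
"outputs": [],
"source": [
"# Minimal round-trip sketch for the CSV helpers above; the file name is just an example.\n",
"example_pairs = [('print a greeting', \"def greet():\\n    print('hi')\")]\n",
"save_list_of_tuples_to_csv('csv_helper_check.csv', example_pairs,\n",
"                           x_header='description', y_header='code')\n",
"print(load_list_of_tuples_from_csv('csv_helper_check.csv'))\n"
]
},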
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "6762bce7", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from typing import Tuple, List\n", | |
"from torch.utils.data import Dataset\n", | |
"\n", | |
"\n", | |
"class PyTorchDataset(Dataset):\n", | |
" def __init__(self, dataset: List[Tuple], **kwargs):\n", | |
" super().__init__(**kwargs)\n", | |
" self.dataset = dataset\n", | |
"\n", | |
" def __len__(self):\n", | |
" return len(self.dataset)\n", | |
"\n", | |
" def __getitem__(self, idx):\n", | |
" return self.dataset[idx][0], self.dataset[idx][1]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "64c785fd", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"\n", | |
"import math\n", | |
"\n", | |
"import torch\n", | |
"from torch import nn as nn, Tensor\n", | |
"from torch.nn import Transformer\n", | |
"\n", | |
"\n", | |
"class PositionalEncoding(nn.Module):\n", | |
" def __init__(self,\n", | |
" emb_size: int,\n", | |
" dropout: float,\n", | |
" maxlen: int = 5000):\n", | |
" super(PositionalEncoding, self).__init__()\n", | |
" den = torch.exp(- torch.arange(0, emb_size, 2) * math.log(10000) / emb_size)\n", | |
" pos = torch.arange(0, maxlen).reshape(maxlen, 1)\n", | |
" pos_embedding = torch.zeros((maxlen, emb_size))\n", | |
" pos_embedding[:, 0::2] = torch.sin(pos * den)\n", | |
" pos_embedding[:, 1::2] = torch.cos(pos * den)\n", | |
" pos_embedding = pos_embedding.unsqueeze(-2)\n", | |
"\n", | |
" self.dropout = nn.Dropout(dropout)\n", | |
" self.register_buffer('pos_embedding', pos_embedding)\n", | |
"\n", | |
" def forward(self, token_embedding: Tensor):\n", | |
" return self.dropout(token_embedding + self.pos_embedding[:token_embedding.size(0), :])\n", | |
"\n", | |
"\n", | |
"class TokenEmbedding(nn.Module):\n", | |
" def __init__(self, vocab_size: int, emb_size):\n", | |
" super(TokenEmbedding, self).__init__()\n", | |
" self.embedding = nn.Embedding(vocab_size, emb_size)\n", | |
" self.emb_size = emb_size\n", | |
"\n", | |
" def forward(self, tokens: Tensor):\n", | |
" return self.embedding(tokens.long()) * math.sqrt(self.emb_size)\n", | |
"\n", | |
"\n", | |
"class Seq2SeqTransformer(nn.Module):\n", | |
" def __init__(self,\n", | |
" num_encoder_layers: int,\n", | |
" num_decoder_layers: int,\n", | |
" d_model: int,\n", | |
" num_heads: int,\n", | |
" src_vocab_size: int,\n", | |
" tgt_vocab_size: int,\n", | |
" dim_feedforward: int,\n", | |
" dropout: float = 0.1):\n", | |
" super(Seq2SeqTransformer, self).__init__()\n", | |
" self.transformer = Transformer(d_model=d_model,\n", | |
" nhead=num_heads,\n", | |
" num_encoder_layers=num_encoder_layers,\n", | |
" num_decoder_layers=num_decoder_layers,\n", | |
" dim_feedforward=dim_feedforward,\n", | |
" dropout=dropout)\n", | |
" self.generator = nn.Linear(d_model, tgt_vocab_size)\n", | |
" self.src_tok_emb = TokenEmbedding(src_vocab_size, d_model)\n", | |
" self.tgt_tok_emb = TokenEmbedding(tgt_vocab_size, d_model)\n", | |
" self.positional_encoding = PositionalEncoding(\n", | |
" d_model, dropout=dropout)\n", | |
"\n", | |
" def forward(self,\n", | |
" src: Tensor,\n", | |
" trg: Tensor,\n", | |
" src_mask: Tensor,\n", | |
" tgt_mask: Tensor,\n", | |
" src_padding_mask: Tensor,\n", | |
" tgt_padding_mask: Tensor,\n", | |
" memory_key_padding_mask: Tensor):\n", | |
" src_emb = self.positional_encoding(self.src_tok_emb(src))\n", | |
" tgt_emb = self.positional_encoding(self.tgt_tok_emb(trg))\n", | |
" outs = self.transformer(src_emb, tgt_emb, src_mask, tgt_mask, None,\n", | |
" src_padding_mask, tgt_padding_mask, memory_key_padding_mask)\n", | |
" return self.generator(outs)\n", | |
"\n", | |
" def encode(self, src: Tensor, src_mask: Tensor):\n", | |
" return self.transformer.encoder(self.positional_encoding(\n", | |
" self.src_tok_emb(src)), src_mask)\n", | |
"\n", | |
" def decode(self, tgt: Tensor, memory: Tensor, tgt_mask: Tensor):\n", | |
" return self.transformer.decoder(self.positional_encoding(\n", | |
" self.tgt_tok_emb(tgt)), memory, tgt_mask)\n", | |
"\n" | |
] | |
}, | |
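{
"cell_type": "code",
"execution_count": null,
"id": "c3a5f7b2",
"metadata": {},
"outputs": [],
"source": [
"# Minimal shape sanity check for Seq2SeqTransformer, with toy hyperparameters and random\n",
"# token ids. nn.Transformer defaults to batch_first=False, so tensors are (seq_len, batch).\n",
"_model = Seq2SeqTransformer(num_encoder_layers=1, num_decoder_layers=1, d_model=32,\n",
"                            num_heads=4, src_vocab_size=100, tgt_vocab_size=100,\n",
"                            dim_feedforward=64)\n",
"_src = torch.randint(0, 100, (7, 2))  # (src_len, batch)\n",
"_tgt = torch.randint(0, 100, (5, 2))  # (tgt_len, batch)\n",
"_src_mask = torch.zeros(7, 7, dtype=torch.bool)\n",
"_tgt_mask = torch.triu(torch.full((5, 5), float('-inf')), diagonal=1)  # causal mask\n",
"_src_pad = torch.zeros(2, 7, dtype=torch.bool)\n",
"_tgt_pad = torch.zeros(2, 5, dtype=torch.bool)\n",
"out = _model(_src, _tgt, _src_mask, _tgt_mask, _src_pad, _tgt_pad, _src_pad)\n",
"print(out.shape)  # expected: torch.Size([5, 2, 100])\n"
]
},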
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "ef96ce13", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from torchtext.data import get_tokenizer\n", | |
"from torchtext.vocab import build_vocab_from_iterator\n", | |
"import pickle\n", | |
"from typing import Iterable, List\n", | |
"import tokenize\n", | |
"from io import BytesIO\n", | |
"\n", | |
"\n", | |
"def tokenize_code(code_string):\n", | |
" code_string = code_string + \"\\n\"\n", | |
" code_string_bytes = code_string.encode()\n", | |
" code = BytesIO(code_string_bytes)\n", | |
" tokens = []\n", | |
" i = 0\n", | |
" try:\n", | |
" for token_info in tokenize.tokenize(code.readline):\n", | |
" token_type = token_info[0]\n", | |
" token_string = token_info[1]\n", | |
" if i == 0:\n", | |
" if token_string == 'utf-8':\n", | |
" continue\n", | |
" if token_type == tokenize.NEWLINE:\n", | |
" tokens.append('NEWLINE')\n", | |
" elif token_type == tokenize.INDENT:\n", | |
" tokens.append('INDENT')\n", | |
" elif token_type == tokenize.DEDENT:\n", | |
" tokens.append('DEDENT')\n", | |
" else:\n", | |
" # This is a regular token\n", | |
" tokens.append(token_string)\n", | |
" i += 1\n", | |
" except tokenize.TokenError:\n", | |
" return tokens\n", | |
"\n", | |
" return tokens\n", | |
"\n", | |
"\n", | |
"def load_token_and_vocab_transform(token_transform_filename, vocab_transform_filename):\n", | |
" return pickle.load(open(token_transform_filename, \"rb\")), pickle.load(\n", | |
" open(vocab_transform_filename, \"rb\"))\n", | |
"\n", | |
"\n", | |
"def save_token_and_vocab_transform(token_transform, vocab_transform, token_transform_filename,\n", | |
" vocab_transform_filename):\n", | |
" pickle.dump(token_transform, open(token_transform_filename, \"wb\"))\n", | |
" pickle.dump(vocab_transform, open(vocab_transform_filename, \"wb\"))\n", | |
"\n", | |
"\n", | |
"def get_token_and_vocab_transform_and_special_token_ids(src_language, tgt_language, data_iterator):\n", | |
" vocab_transform = {}\n", | |
" token_transform = {src_language: get_tokenizer('spacy', language='en_core_web_sm'),\n", | |
" tgt_language: tokenize_code}\n", | |
"\n", | |
" # helper function to yield list of tokens\n", | |
" def yield_tokens(data_iter: Iterable, language: str) -> List[str]:\n", | |
" language_index = {src_language: 0, tgt_language: 1}\n", | |
"\n", | |
" for data_sample in data_iter:\n", | |
" yield token_transform[language](data_sample[language_index[language]])\n", | |
"\n", | |
" UNK_IDX, PAD_IDX, BOS_IDX, EOS_IDX = 0, 1, 2, 3\n", | |
" # Make sure the tokens are in order of their indices to properly insert them in vocab\n", | |
" special_symbols = ['<unk>', '<pad>', '<bos>', '<eos>']\n", | |
"\n", | |
" for ln in [src_language, tgt_language]:\n", | |
" # Create torchtext's Vocab object\n", | |
" vocab_transform[ln] = build_vocab_from_iterator(yield_tokens(data_iterator, ln),\n", | |
" min_freq=1,\n", | |
" specials=special_symbols,\n", | |
" special_first=True)\n", | |
"\n", | |
" for ln in [src_language, tgt_language]:\n", | |
" vocab_transform[ln].set_default_index(UNK_IDX)\n", | |
"\n", | |
" return token_transform, vocab_transform, UNK_IDX, PAD_IDX, BOS_IDX, EOS_IDX\n", | |
"\n", | |
"\n", | |
"# helper function to club together sequential operations\n", | |
"def sequential_transforms(*transforms):\n", | |
" def func(txt_input):\n", | |
" for transform in transforms:\n", | |
" txt_input = transform(txt_input)\n", | |
" return txt_input\n", | |
"\n", | |
" return func\n", | |
"\n", | |
"\n", | |
"def ConvertToTokenIds(dataset, x_text_transform, y_text_transform):\n", | |
" new_dataset = []\n", | |
" for src_sample, tgt_sample in dataset:\n", | |
" new_dataset.append((x_text_transform(src_sample.rstrip(\"\\n\")), y_text_transform(tgt_sample.rstrip(\"\\n\"))))\n", | |
"\n", | |
" return new_dataset\n" | |
] | |
}, | |
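{
"cell_type": "code",
"execution_count": null,
"id": "d4b6e8c3",
"metadata": {},
"outputs": [],
"source": [
"# Quick look at the code tokenizer above (illustrative snippet only): Python's structural\n",
"# NEWLINE/INDENT/DEDENT tokens are replaced by literal marker strings.\n",
"print(tokenize_code(\"def greet(name):\\n    return 'hi ' + name\"))\n"
]
},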
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "c44182e7", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"\n", | |
"from timeit import default_timer as timer\n", | |
"from typing import List\n", | |
"\n", | |
"import torch\n", | |
"import torch.nn as nn\n", | |
"from torch.nn.utils.rnn import pad_sequence\n", | |
"from torch.utils.data import DataLoader\n", | |
"from tqdm import tqdm\n", | |
"\n", | |
"torch.manual_seed(42)\n", | |
"DEVICE = torch.device('cuda')\n", | |
"\n", | |
"SRC_LANGUAGE = 'description'\n", | |
"TGT_LANGUAGE = 'code'\n", | |
"\n", | |
"generate_csv_dataset = True\n", | |
"max_sequence_length = 64000\n", | |
"\n", | |
"train_count = 100000\n", | |
"validation_count = 1000\n", | |
"test_count = 500\n", | |
"\n", | |
"train_csv_filename = f'./code_search_net_train_{train_count}.csv'\n", | |
"validation_csv_filename = f'./code_search_net_validation_{validation_count}.csv'\n", | |
"test_csv_filename = f'./code_search_net_test_{test_count}.csv'\n", | |
"\n", | |
"if generate_csv_dataset:\n", | |
" generate_csv_dataset_from_huggingface(dataset=\"code_search_net\", dataset_filter=\"python\",\n", | |
" x_header='func_documentation_string',\n", | |
" y_header='func_code_string', train_count=train_count,\n", | |
" validation_count=validation_count, test_count=test_count,\n", | |
" train_csv_filename=train_csv_filename,\n", | |
" validation_csv_filename=validation_csv_filename,\n", | |
" test_csv_filename=test_csv_filename,\n", | |
" prepare_x_function=None,\n", | |
" prepare_y_function=remove_docstring_from_python_string)\n", | |
"\n", | |
"truncated_train_data, truncated_validation_data, truncated_test_data = load_csv_data_truncated(train_csv_filename,\n", | |
" validation_csv_filename,\n", | |
" test_csv_filename,\n", | |
" max_sequence_length)\n", | |
"\n", | |
"train_dataset_for_vocab_and_token_transform_generation = PyTorchDataset(truncated_train_data)\n", | |
"\n", | |
"token_transform, vocab_transform, UNK_IDX, PAD_IDX, BOS_IDX, EOS_IDX = get_token_and_vocab_transform_and_special_token_ids(\n", | |
" SRC_LANGUAGE, TGT_LANGUAGE, train_dataset_for_vocab_and_token_transform_generation)\n", | |
"\n", | |
"# for src_sample, tgt_sample in train_dataset_for_vocab_and_token_transform_generation:\n", | |
"# print((token_transform[SRC_LANGUAGE](src_sample), token_transform[TGT_LANGUAGE](tgt_sample)))\n", | |
"\n", | |
"\n", | |
"# Add BOS/EOS and create tensor for input sequence indices\n", | |
"def tensor_transform(token_ids: List[int]):\n", | |
" return torch.cat((torch.tensor([BOS_IDX]),\n", | |
" torch.tensor(token_ids),\n", | |
" torch.tensor([EOS_IDX])))\n", | |
"\n", | |
"\n", | |
"# Text transforms to convert raw strings into tensors indices\n", | |
"text_transform = {}\n", | |
"for ln in [SRC_LANGUAGE, TGT_LANGUAGE]:\n", | |
" text_transform[ln] = sequential_transforms(token_transform[ln], # Tokenization\n", | |
" vocab_transform[ln], # Numericalization\n", | |
" tensor_transform) # Add BOS/EOS and create tensor\n", | |
"\n", | |
"truncated_train_data = ConvertToTokenIds(truncated_train_data, text_transform[SRC_LANGUAGE],\n", | |
" text_transform[TGT_LANGUAGE])\n", | |
"truncated_validation_data = ConvertToTokenIds(truncated_validation_data, text_transform[SRC_LANGUAGE],\n", | |
" text_transform[TGT_LANGUAGE])\n", | |
"truncated_test_data = ConvertToTokenIds(truncated_test_data, text_transform[SRC_LANGUAGE], text_transform[TGT_LANGUAGE])\n", | |
"\n", | |
"\n", | |
"def generate_square_subsequent_mask(sz):\n", | |
" mask = (torch.triu(torch.ones((sz, sz), device=DEVICE)) == 1).transpose(0, 1)\n", | |
" mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))\n", | |
" return mask\n", | |
"\n", | |
"\n", | |
"# Create attention masks.\n", | |
"def create_mask(src, tgt):\n", | |
" src_seq_len = src.shape[0]\n", | |
" tgt_seq_len = tgt.shape[0]\n", | |
"\n", | |
" tgt_mask = generate_square_subsequent_mask(tgt_seq_len)\n", | |
" src_mask = torch.zeros((src_seq_len, src_seq_len), device=DEVICE).type(torch.bool)\n", | |
"\n", | |
" src_padding_mask = (src == PAD_IDX).transpose(0, 1)\n", | |
" tgt_padding_mask = (tgt == PAD_IDX).transpose(0, 1)\n", | |
" return src_mask, tgt_mask, src_padding_mask, tgt_padding_mask\n", | |
"\n", | |
"\n", | |
"# Collate data samples into batch tensors\n", | |
"def collate_fn(batch):\n", | |
" src_batch, tgt_batch = [], []\n", | |
" for src_sample, tgt_sample in batch:\n", | |
" src_batch.append(src_sample)\n", | |
" tgt_batch.append(tgt_sample)\n", | |
"\n", | |
" src_batch = pad_sequence(src_batch, padding_value=PAD_IDX)\n", | |
" tgt_batch = pad_sequence(tgt_batch, padding_value=PAD_IDX)\n", | |
" return src_batch, tgt_batch\n", | |
"\n", | |
"\n", | |
"def perform_training(model, optimizer, loss_fn, batch_size):\n", | |
" model.train()\n", | |
" losses = 0\n", | |
" train_iter = PyTorchDataset(truncated_train_data)\n", | |
" train_dataloader = DataLoader(train_iter, batch_size=batch_size, collate_fn=collate_fn, shuffle=True)\n", | |
" with tqdm(total=int(len(train_iter) / batch_size)) as pbar:\n", | |
" for src, tgt in train_dataloader:\n", | |
" src = src.to(DEVICE)\n", | |
" src = src.long()\n", | |
" tgt = tgt.to(DEVICE)\n", | |
" tgt = tgt.long()\n", | |
" tgt_input = tgt[:-1, :]\n", | |
"\n", | |
" src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_mask(src, tgt_input)\n", | |
"\n", | |
" logits = model(src, tgt_input, src_mask, tgt_mask, src_padding_mask, tgt_padding_mask, src_padding_mask)\n", | |
"\n", | |
" optimizer.zero_grad()\n", | |
"\n", | |
" tgt_out = tgt[1:, :]\n", | |
" loss = loss_fn(logits.reshape(-1, logits.shape[-1]), tgt_out.reshape(-1))\n", | |
" loss.backward()\n", | |
"\n", | |
" optimizer.step()\n", | |
" losses += loss.item()\n", | |
" pbar.update(1)\n", | |
"\n", | |
" return losses / len(train_dataloader)\n", | |
"\n", | |
"\n", | |
"def perform_validation(model, loss_fn, batch_size):\n", | |
" model.eval()\n", | |
" losses = 0\n", | |
" val_iter = PyTorchDataset(truncated_validation_data)\n", | |
" val_dataloader = DataLoader(val_iter, collate_fn=collate_fn, batch_size=batch_size)\n", | |
" with tqdm(total=int(len(val_iter) / batch_size)) as pbar:\n", | |
" for src, tgt in val_dataloader:\n", | |
" src = src.to(DEVICE)\n", | |
" src = src.long()\n", | |
" tgt = tgt.to(DEVICE)\n", | |
" tgt = tgt.long()\n", | |
"\n", | |
" tgt_input = tgt[:-1, :]\n", | |
"\n", | |
" src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_mask(src, tgt_input)\n", | |
"\n", | |
" logits = model(src, tgt_input, src_mask, tgt_mask, src_padding_mask, tgt_padding_mask, src_padding_mask)\n", | |
"\n", | |
" tgt_out = tgt[1:, :]\n", | |
" loss = loss_fn(logits.reshape(-1, logits.shape[-1]), tgt_out.reshape(-1))\n", | |
" losses += loss.item()\n", | |
" pbar.update(1)\n", | |
" return losses / len(val_dataloader)\n", | |
"\n", | |
"\n", | |
"def lr_scheduler(step_num, d_model, warmup_steps=4000):\n", | |
" if step_num == 0:\n", | |
" return d_model ** -0.5\n", | |
" # Linearly increasing the learning rate for the first warmup_steps, and\n", | |
" # decreasing it thereafter\n", | |
" arg1 = step_num ** -0.5\n", | |
" arg2 = step_num * (warmup_steps ** -1.5)\n", | |
"\n", | |
" return (d_model ** -0.5) * min(arg1, arg2)\n", | |
"\n", | |
"\n", | |
"def accuracy_fcn(target, prediction):\n", | |
" # Find equal prediction and target values, and apply the padding mask\n", | |
" accuracy = (target == torch.argmax(prediction, dim=2)).float()\n", | |
" mask = (target != 0).float()\n", | |
"\n", | |
" return torch.sum(accuracy * mask) / torch.sum(mask)\n", | |
"\n", | |
"\n", | |
"def fit_transformer_model(transformer_model, num_epochs, batch_size):\n", | |
" loss_fn = torch.nn.CrossEntropyLoss(ignore_index=PAD_IDX)\n", | |
" optimizer = torch.optim.NAdam(transformer_model.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9 )\n", | |
" # scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda step: lr_scheduler(step, d_model))\n", | |
"\n", | |
" for epoch in range(1, num_epochs + 1):\n", | |
" start_time = timer()\n", | |
" train_loss = perform_training(transformer_model, optimizer, loss_fn, batch_size)\n", | |
" end_time = timer()\n", | |
" val_loss = perform_validation(transformer_model, loss_fn, batch_size)\n", | |
" print((\n", | |
" f\"Epoch: {epoch}, Train loss: {train_loss:.3f}, Val loss: {val_loss:.3f}, \"f\"Epoch time = {(end_time - start_time):.3f}s\"))\n", | |
" if epoch % 5 == 0:\n", | |
" torch.save(transformer_model.state_dict(), f\"./checkpoint_at_epoch_{epoch}.pt\")\n", | |
" print(f\"Saved checkpoint at epoch {epoch}\")\n", | |
" # Update the learning rate\n", | |
" # scheduler.step()\n", | |
"\n", | |
"\n", | |
"def count_parameters(model):\n", | |
" return sum(p.numel() for p in model.parameters() if p.requires_grad)\n", | |
"\n", | |
"\n", | |
"# function to generate output sequence using greedy algorithm\n", | |
"def greedy_decode(model, src, src_mask, max_len, start_symbol):\n", | |
" src = src.to(DEVICE)\n", | |
" src_mask = src_mask.to(DEVICE)\n", | |
"\n", | |
" memory = model.encode(src, src_mask)\n", | |
" ys = torch.ones(1, 1).fill_(start_symbol).type(torch.long).to(DEVICE)\n", | |
" for i in range(max_len - 1):\n", | |
" memory = memory.to(DEVICE)\n", | |
" tgt_mask = (generate_square_subsequent_mask(ys.size(0))\n", | |
" .type(torch.bool)).to(DEVICE)\n", | |
" out = model.decode(ys, memory, tgt_mask)\n", | |
" out = out.transpose(0, 1)\n", | |
" prob = model.generator(out[:, -1])\n", | |
" _, next_word = torch.max(prob, dim=1)\n", | |
" next_word = next_word.item()\n", | |
"\n", | |
" ys = torch.cat([ys,\n", | |
" torch.ones(1, 1).type_as(src.data).fill_(next_word)], dim=0)\n", | |
" if next_word == EOS_IDX:\n", | |
" break\n", | |
" return ys\n", | |
"\n", | |
"\n", | |
"# actual function to translate input sentence into target language\n", | |
"def translate(model: torch.nn.Module, src_sentence: str):\n", | |
" model.eval()\n", | |
" src = text_transform[SRC_LANGUAGE](src_sentence).view(-1, 1)\n", | |
" num_tokens = src.shape[0]\n", | |
" src_mask = (torch.zeros(num_tokens, num_tokens)).type(torch.bool)\n", | |
" tgt_tokens = greedy_decode(\n", | |
" model, src, src_mask, max_len=num_tokens + 5, start_symbol=BOS_IDX).flatten()\n", | |
" return \" \".join(vocab_transform[TGT_LANGUAGE].lookup_tokens(list(tgt_tokens.cpu().numpy()))).replace(\"\",\n", | |
" \"\").replace(\n", | |
" \"\", \"\")\n", | |
"\n", | |
"\n", | |
"def test_transformer(transformer_model, sentence):\n", | |
" print(translate(transformer_model, sentence))\n", | |
"\n", | |
"\n", | |
"def train_transformer_model(fit_model, load_model, test_model):\n", | |
" src_vocab_size = len(vocab_transform[SRC_LANGUAGE])\n", | |
" tgt_vocab_size = len(vocab_transform[TGT_LANGUAGE])\n", | |
" d_model = 512\n", | |
" num_heads = 8\n", | |
" feed_forward_dim = 2048\n", | |
" num_encoder_layers = 6\n", | |
" num_decoder_layers = 6\n", | |
"\n", | |
" batch_size = 2\n", | |
" num_epochs = 120\n", | |
"\n", | |
" transformer = Seq2SeqTransformer(num_encoder_layers, num_decoder_layers, d_model,\n", | |
" num_heads, src_vocab_size, tgt_vocab_size, feed_forward_dim)\n", | |
"\n", | |
" if load_model:\n", | |
" transformer.load_state_dict(torch.load(\"./checkpoint_at_epoch_25.pt\"))\n", | |
" else:\n", | |
" for p in transformer.parameters():\n", | |
" if p.dim() > 1:\n", | |
" nn.init.xavier_uniform_(p)\n", | |
"\n", | |
" print(f'The model has {count_parameters(transformer):,} trainable parameters')\n", | |
" opt_transformer = torch.compile(transformer, backend='inductor')\n", | |
" opt_transformer = opt_transformer.to(DEVICE)\n", | |
" if fit_model:\n", | |
" fit_transformer_model(transformer, num_epochs, batch_size)\n", | |
"\n", | |
" if test_model:\n", | |
" test_transformer(transformer, 'Downloads Dailymotion videos by URL.')\n", | |
"\n", | |
"\n", | |
"train_transformer_model(True, False, True)\n" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3 (ipykernel)", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.9.13" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 5 | |
} |