Created
January 25, 2025 13:27
-
-
Save klazuka/a9298037da775e2121fb1ea9cef5dd1b to your computer and use it in GitHub Desktop.
ChatGpt Operator wrote this classifier in Jupyter
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"id": "8570cae5-1e95-4871-9d7f-b7d8c9258438", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Using device = cuda:0\n" | |
] | |
} | |
], | |
"source": [ | |
"import torch\n", | |
"# Check if CUDA is available\n", | |
"device = torch.device('cpu')\n", | |
"if torch.cuda.is_available():\n", | |
" device = torch.device('cuda')\n", | |
"\n", | |
"torch.set_default_device(device)\n", | |
"print(f\"Using device = {torch.get_default_device()}\")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"id": "9b474f6d-a56e-4478-8676-4553d0327f17", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"--2025-01-25 12:36:36-- https://download.pytorch.org/tutorial/data.zip\n", | |
"Resolving download.pytorch.org (download.pytorch.org)... 18.238.152.14, 18.238.152.71, 18.238.152.47, ...\n", | |
"Connecting to download.pytorch.org (download.pytorch.org)|18.238.152.14|:443... connected.\n", | |
"HTTP request sent, awaiting response... 200 OK\n", | |
"Length: 2882130 (2.7M) [application/zip]\n", | |
"Saving to: ‘data.zip’\n", | |
"\n", | |
"data.zip 100%[===================>] 2.75M 9.82MB/s in 0.3s \n", | |
"\n", | |
"2025-01-25 12:36:36 (9.82 MB/s) - ‘data.zip’ saved [2882130/2882130]\n", | |
"\n", | |
"/bin/bash: line 1: unzip: command not found\n" | |
] | |
} | |
], | |
"source": [ | |
"!wget https://download.pytorch.org/tutorial/data.zip -O data.zip\n", | |
"!unzip data.zip" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"id": "5085dfb3-281e-422b-af06-fec82a5feb2e", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import zipfile\n", | |
"\n", | |
"with zipfile.ZipFile('data.zip', 'r') as zip_ref:\n", | |
" zip_ref.extractall('.')" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"id": "15005de1-79c0-413c-b228-e1dd7d556602", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"data data.zip\tname-classifier.ipynb\n" | |
] | |
} | |
], | |
"source": [ | |
"!ls" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"id": "49eae2ec-fafe-40a5-82d1-4c1e04cffeb6", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"eng-fra.txt names\n" | |
] | |
} | |
], | |
"source": [ | |
"!ls data" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"id": "905664cf-e7ef-45c9-85bb-49c5c24ce0ab", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Arabic.txt English.txt Irish.txt\tPolish.txt\tSpanish.txt\n", | |
"Chinese.txt French.txt Italian.txt\tPortuguese.txt\tVietnamese.txt\n", | |
"Czech.txt German.txt Japanese.txt\tRussian.txt\n", | |
"Dutch.txt Greek.txt\t Korean.txt\tScottish.txt\n" | |
] | |
} | |
], | |
"source": [ | |
"!ls data/names" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"id": "0d1944f7-e394-40d2-a918-b9cc76c9a6a2", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<style scoped>\n", | |
" .dataframe tbody tr th:only-of-type {\n", | |
" vertical-align: middle;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>name</th>\n", | |
" <th>language</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>0</th>\n", | |
" <td>Nguyen</td>\n", | |
" <td>Vietnamese</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1</th>\n", | |
" <td>Tron</td>\n", | |
" <td>Vietnamese</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>2</th>\n", | |
" <td>Le</td>\n", | |
" <td>Vietnamese</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>3</th>\n", | |
" <td>Pham</td>\n", | |
" <td>Vietnamese</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>4</th>\n", | |
" <td>Huynh</td>\n", | |
" <td>Vietnamese</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" name language\n", | |
"0 Nguyen Vietnamese\n", | |
"1 Tron Vietnamese\n", | |
"2 Le Vietnamese\n", | |
"3 Pham Vietnamese\n", | |
"4 Huynh Vietnamese" | |
] | |
}, | |
"execution_count": 8, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"import os\n", | |
"import pandas as pd\n", | |
"\n", | |
"# Create a list to store the data\n", | |
"data = []\n", | |
"\n", | |
"# Iterate over each language file in the 'names' directory\n", | |
"for filename in os.listdir('data/names'):\n", | |
" if filename.endswith('.txt'):\n", | |
" language = filename.split('.')[0]\n", | |
" with open(os.path.join('data/names', filename), 'r', encoding='utf-8') as file:\n", | |
" names = file.read().strip().split('\\n')\n", | |
" for name in names:\n", | |
" data.append((name, language))\n", | |
"\n", | |
"# Create a DataFrame from the data\n", | |
"df = pd.DataFrame(data, columns=['name', 'language'])\n", | |
"\n", | |
"df.head()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"id": "b5b8ce40-87ed-4e24-8988-bc0278baa6a8", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Requirement already satisfied: pandas in /usr/local/lib/python3.11/dist-packages (2.2.3)\n", | |
"Requirement already satisfied: numpy>=1.23.2 in /usr/local/lib/python3.11/dist-packages (from pandas) (1.26.3)\n", | |
"Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.11/dist-packages (from pandas) (2.9.0.post0)\n", | |
"Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.11/dist-packages (from pandas) (2024.2)\n", | |
"Requirement already satisfied: tzdata>=2022.7 in /usr/local/lib/python3.11/dist-packages (from pandas) (2025.1)\n", | |
"Requirement already satisfied: six>=1.5 in /usr/lib/python3/dist-packages (from python-dateutil>=2.8.2->pandas) (1.16.0)\n", | |
"\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager, possibly rendering your system unusable.It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv. Use the --root-user-action option if you know what you are doing and want to suppress this warning.\u001b[0m\u001b[33m\n", | |
"\u001b[0m\n", | |
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m24.2\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m24.3.1\u001b[0m\n", | |
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpython -m pip install --upgrade pip\u001b[0m\n" | |
] | |
} | |
], | |
"source": [ | |
"!pip install pandas" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"id": "d4455916-89b4-496c-a5b6-dcf8c94a6684", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"({'input_ids': tensor([[ 0, 28524, 3509, ..., 1, 1, 1],\n", | |
" [ 0, 9064, 37861, ..., 1, 1, 1],\n", | |
" [ 0, 104, 1250, ..., 1, 1, 1],\n", | |
" ...,\n", | |
" [ 0, 10350, 1452, ..., 1, 1, 1],\n", | |
" [ 0, 19897, 10790, ..., 1, 1, 1],\n", | |
" [ 0, 10643, 26153, ..., 1, 1, 1]], device='cuda:0'), 'attention_mask': tensor([[1, 1, 1, ..., 0, 0, 0],\n", | |
" [1, 1, 1, ..., 0, 0, 0],\n", | |
" [1, 1, 1, ..., 0, 0, 0],\n", | |
" ...,\n", | |
" [1, 1, 1, ..., 0, 0, 0],\n", | |
" [1, 1, 1, ..., 0, 0, 0],\n", | |
" [1, 1, 1, ..., 0, 0, 0]], device='cuda:0')},\n", | |
" {'input_ids': tensor([[ 0, 975, 1906, ..., 1, 1, 1],\n", | |
" [ 0, 495, 967, ..., 1, 1, 1],\n", | |
" [ 0, 846, 2462, ..., 1, 1, 1],\n", | |
" ...,\n", | |
" [ 0, 43517, 1627, ..., 1, 1, 1],\n", | |
" [ 0, 33282, 32823, ..., 1, 1, 1],\n", | |
" [ 0, 725, 1397, ..., 1, 1, 1]], device='cuda:0'), 'attention_mask': tensor([[1, 1, 1, ..., 0, 0, 0],\n", | |
" [1, 1, 1, ..., 0, 0, 0],\n", | |
" [1, 1, 1, ..., 0, 0, 0],\n", | |
" ...,\n", | |
" [1, 1, 1, ..., 0, 0, 0],\n", | |
" [1, 1, 1, ..., 0, 0, 0],\n", | |
" [1, 1, 1, ..., 0, 0, 0]], device='cuda:0')})" | |
] | |
}, | |
"execution_count": 10, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"from sklearn.model_selection import train_test_split\n", | |
"from sklearn.preprocessing import LabelEncoder\n", | |
"from transformers import RobertaTokenizer\n", | |
"\n", | |
"# Split the data into training and validation sets\n", | |
"train_df, val_df = train_test_split(df, test_size=0.2, random_state=42)\n", | |
"\n", | |
"# Encode the language labels\n", | |
"label_encoder = LabelEncoder()\n", | |
"train_df['language_encoded'] = label_encoder.fit_transform(train_df['language'])\n", | |
"val_df['language_encoded'] = label_encoder.transform(val_df['language'])\n", | |
"\n", | |
"# Initialize the RoBERTa tokenizer\n", | |
"tokenizer = RobertaTokenizer.from_pretrained('roberta-base')\n", | |
"\n", | |
"def tokenize_names(names):\n", | |
" return tokenizer(names.tolist(), padding=True, truncation=True, return_tensors='pt')\n", | |
"\n", | |
"# Tokenize the names in the training and validation sets\n", | |
"train_encodings = tokenize_names(train_df['name']).to(device=device)\n", | |
"val_encodings = tokenize_names(val_df['name']).to(device=device)\n", | |
"\n", | |
"train_encodings, val_encodings" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 11, | |
"id": "1c6377a0-57e5-4ab3-a825-25803c90cac5", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Requirement already satisfied: scikit-learn in /usr/local/lib/python3.11/dist-packages (1.6.1)\n", | |
"Requirement already satisfied: numpy>=1.19.5 in /usr/local/lib/python3.11/dist-packages (from scikit-learn) (1.26.3)\n", | |
"Requirement already satisfied: scipy>=1.6.0 in /usr/local/lib/python3.11/dist-packages (from scikit-learn) (1.15.1)\n", | |
"Requirement already satisfied: joblib>=1.2.0 in /usr/local/lib/python3.11/dist-packages (from scikit-learn) (1.4.2)\n", | |
"Requirement already satisfied: threadpoolctl>=3.1.0 in /usr/local/lib/python3.11/dist-packages (from scikit-learn) (3.5.0)\n", | |
"\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager, possibly rendering your system unusable.It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv. Use the --root-user-action option if you know what you are doing and want to suppress this warning.\u001b[0m\u001b[33m\n", | |
"\u001b[0m\n", | |
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m24.2\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m24.3.1\u001b[0m\n", | |
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpython -m pip install --upgrade pip\u001b[0m\n" | |
] | |
} | |
], | |
"source": [ | |
"!pip install scikit-learn" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 12, | |
"id": "63bbe926-746d-467e-b2de-097a7bc2a7c3", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Requirement already satisfied: transformers in /usr/local/lib/python3.11/dist-packages (4.48.1)\n", | |
"Requirement already satisfied: filelock in /usr/local/lib/python3.11/dist-packages (from transformers) (3.13.1)\n", | |
"Requirement already satisfied: huggingface-hub<1.0,>=0.24.0 in /usr/local/lib/python3.11/dist-packages (from transformers) (0.27.1)\n", | |
"Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.11/dist-packages (from transformers) (1.26.3)\n", | |
"Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.11/dist-packages (from transformers) (24.1)\n", | |
"Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.11/dist-packages (from transformers) (6.0.2)\n", | |
"Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.11/dist-packages (from transformers) (2024.11.6)\n", | |
"Requirement already satisfied: requests in /usr/local/lib/python3.11/dist-packages (from transformers) (2.32.3)\n", | |
"Requirement already satisfied: tokenizers<0.22,>=0.21 in /usr/local/lib/python3.11/dist-packages (from transformers) (0.21.0)\n", | |
"Requirement already satisfied: safetensors>=0.4.1 in /usr/local/lib/python3.11/dist-packages (from transformers) (0.5.2)\n", | |
"Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.11/dist-packages (from transformers) (4.67.1)\n", | |
"Requirement already satisfied: fsspec>=2023.5.0 in /usr/local/lib/python3.11/dist-packages (from huggingface-hub<1.0,>=0.24.0->transformers) (2024.2.0)\n", | |
"Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.11/dist-packages (from huggingface-hub<1.0,>=0.24.0->transformers) (4.9.0)\n", | |
"Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.11/dist-packages (from requests->transformers) (3.3.2)\n", | |
"Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.11/dist-packages (from requests->transformers) (3.10)\n", | |
"Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.11/dist-packages (from requests->transformers) (2.2.3)\n", | |
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.11/dist-packages (from requests->transformers) (2024.8.30)\n", | |
"\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager, possibly rendering your system unusable.It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv. Use the --root-user-action option if you know what you are doing and want to suppress this warning.\u001b[0m\u001b[33m\n", | |
"\u001b[0m\n", | |
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m24.2\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m24.3.1\u001b[0m\n", | |
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpython -m pip install --upgrade pip\u001b[0m\n" | |
] | |
} | |
], | |
"source": [ | |
"!pip install transformers" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 20, | |
"id": "644aa878-b87a-4dc4-9118-bae316c5a00e", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at roberta-base and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']\n", | |
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/plain": [ | |
"(<torch.utils.data.dataloader.DataLoader at 0x7b994c4be810>,\n", | |
" <torch.utils.data.dataloader.DataLoader at 0x7b994c4bd250>,\n", | |
" RobertaForSequenceClassification(\n", | |
" (roberta): RobertaModel(\n", | |
" (embeddings): RobertaEmbeddings(\n", | |
" (word_embeddings): Embedding(50265, 768, padding_idx=1)\n", | |
" (position_embeddings): Embedding(514, 768, padding_idx=1)\n", | |
" (token_type_embeddings): Embedding(1, 768)\n", | |
" (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", | |
" (dropout): Dropout(p=0.1, inplace=False)\n", | |
" )\n", | |
" (encoder): RobertaEncoder(\n", | |
" (layer): ModuleList(\n", | |
" (0-11): 12 x RobertaLayer(\n", | |
" (attention): RobertaAttention(\n", | |
" (self): RobertaSdpaSelfAttention(\n", | |
" (query): Linear(in_features=768, out_features=768, bias=True)\n", | |
" (key): Linear(in_features=768, out_features=768, bias=True)\n", | |
" (value): Linear(in_features=768, out_features=768, bias=True)\n", | |
" (dropout): Dropout(p=0.1, inplace=False)\n", | |
" )\n", | |
" (output): RobertaSelfOutput(\n", | |
" (dense): Linear(in_features=768, out_features=768, bias=True)\n", | |
" (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", | |
" (dropout): Dropout(p=0.1, inplace=False)\n", | |
" )\n", | |
" )\n", | |
" (intermediate): RobertaIntermediate(\n", | |
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n", | |
" (intermediate_act_fn): GELUActivation()\n", | |
" )\n", | |
" (output): RobertaOutput(\n", | |
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n", | |
" (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", | |
" (dropout): Dropout(p=0.1, inplace=False)\n", | |
" )\n", | |
" )\n", | |
" )\n", | |
" )\n", | |
" )\n", | |
" (classifier): RobertaClassificationHead(\n", | |
" (dense): Linear(in_features=768, out_features=768, bias=True)\n", | |
" (dropout): Dropout(p=0.1, inplace=False)\n", | |
" (out_proj): Linear(in_features=768, out_features=18, bias=True)\n", | |
" )\n", | |
" ))" | |
] | |
}, | |
"execution_count": 20, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"import torch\n", | |
"from torch.utils.data import Dataset, DataLoader\n", | |
"from transformers import RobertaForSequenceClassification, AdamW\n", | |
"\n", | |
"class NameDataset(Dataset):\n", | |
" def __init__(self, encodings, labels):\n", | |
" self.encodings = encodings\n", | |
" self.labels = labels\n", | |
"\n", | |
" def __getitem__(self, idx):\n", | |
" item = {key: val[idx] for key, val in self.encodings.items()}\n", | |
" item['labels'] = torch.tensor(self.labels[idx], device=device)\n", | |
" return item\n", | |
"\n", | |
" def __len__(self):\n", | |
" return len(self.labels)\n", | |
"\n", | |
"# Create PyTorch datasets\n", | |
"train_dataset = NameDataset(train_encodings, train_df['language_encoded'].values)\n", | |
"val_dataset = NameDataset(val_encodings, val_df['language_encoded'].values)\n", | |
"\n", | |
"# Create DataLoaders for batching\n", | |
"train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, generator=torch.Generator(device=device))\n", | |
"val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False, generator=torch.Generator(device=device))\n", | |
"\n", | |
"# Initialize the RoBERTa model for sequence classification\n", | |
"model = RobertaForSequenceClassification.from_pretrained('roberta-base', num_labels=len(label_encoder.classes_))\n", | |
"model.to(device)\n", | |
"\n", | |
"# Set up the optimizer\n", | |
"optimizer = AdamW(model.parameters(), lr=5e-5)\n", | |
"\n", | |
"train_loader, val_loader, model" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 21, | |
"id": "f9370ef9-a059-4ecb-8e42-b70051c82ed2", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from tqdm import tqdm\n", | |
"\n", | |
"def train(model, train_loader, val_loader, optimizer, epochs=3):\n", | |
" print(f\"begin training for {epochs} epochs\")\n", | |
" model.train()\n", | |
" for epoch in range(epochs):\n", | |
" print(f'Epoch {epoch + 1}/{epochs}')\n", | |
" total_loss = 0\n", | |
" for batch in tqdm(train_loader):\n", | |
" optimizer.zero_grad()\n", | |
" outputs = model(**batch)\n", | |
" loss = outputs.loss\n", | |
" loss.backward()\n", | |
" optimizer.step()\n", | |
" total_loss += loss.item()\n", | |
" avg_train_loss = total_loss / len(train_loader)\n", | |
" print(f'Average training loss: {avg_train_loss:.4f}')\n", | |
" evaluate(model, val_loader)\n", | |
"\n", | |
"\n", | |
"def evaluate(model, val_loader):\n", | |
" model.eval()\n", | |
" total_loss = 0\n", | |
" correct = 0\n", | |
" total = 0\n", | |
" with torch.no_grad():\n", | |
" for batch in tqdm(val_loader):\n", | |
" outputs = model(**batch)\n", | |
" loss = outputs.loss\n", | |
" total_loss += loss.item()\n", | |
" logits = outputs.logits\n", | |
" predictions = torch.argmax(logits, dim=-1)\n", | |
" correct += (predictions == batch['labels']).sum().item()\n", | |
" total += batch['labels'].size(0)\n", | |
" avg_val_loss = total_loss / len(val_loader)\n", | |
" accuracy = correct / total\n", | |
" print(f'Validation loss: {avg_val_loss:.4f}, Accuracy: {accuracy:.4f}')" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 22, | |
"id": "5d095d8f-13bc-4133-82e8-3cd39d1027f3", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"begin training for 3 epochs\n", | |
"Epoch 1/3\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"100%|██████████| 1004/1004 [00:43<00:00, 23.24it/s]\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Average training loss: 0.8538\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"100%|██████████| 251/251 [00:02<00:00, 107.04it/s]\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Validation loss: 0.5905, Accuracy: 0.8194\n", | |
"Epoch 2/3\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"100%|██████████| 1004/1004 [00:40<00:00, 24.77it/s]\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Average training loss: 0.4990\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"100%|██████████| 251/251 [00:02<00:00, 110.31it/s]\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Validation loss: 0.5298, Accuracy: 0.8411\n", | |
"Epoch 3/3\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"100%|██████████| 1004/1004 [00:41<00:00, 24.16it/s]\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Average training loss: 0.3373\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"100%|██████████| 251/251 [00:02<00:00, 108.09it/s]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Validation loss: 0.5763, Accuracy: 0.8314\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"\n" | |
] | |
} | |
], | |
"source": [ | |
"train(model, train_loader, val_loader, optimizer)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 24, | |
"id": "a1956c49-2736-43ed-bf59-db9b421594b9", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"The predicted language for the name \"John\" is: ['English']\n", | |
"['English']\n" | |
] | |
} | |
], | |
"source": [ | |
"def predict_language(name):\n", | |
" model.eval()\n", | |
" with torch.no_grad():\n", | |
" encoding = tokenizer(name, return_tensors='pt', padding=True, truncation=True)\n", | |
" outputs = model(**encoding)\n", | |
" logits = outputs.logits\n", | |
" prediction = torch.argmax(logits, dim=-1).item()\n", | |
" language = label_encoder.inverse_transform([prediction])\n", | |
" return language\n", | |
"\n", | |
"# Test the model with a sample name\n", | |
"sample_name = \"John\"\n", | |
"predicted_language = predict_language(sample_name)\n", | |
"print(f'The predicted language for the name \"{sample_name}\" is: {predicted_language}')\n", | |
"print(predicted_language)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 25, | |
"id": "fc4ae875-5948-4d9a-b615-5f12b65713db", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"The predicted language for the name \"Maria\" is: ['Italian']\n" | |
] | |
} | |
], | |
"source": [ | |
"# Test the model with a different name\n", | |
"sample_name = \"Maria\"\n", | |
"predicted_language = predict_language(sample_name)\n", | |
"print(f'The predicted language for the name \"{sample_name}\" is: {predicted_language}')" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 33, | |
"id": "9ac1712e-c8c2-4812-b46c-d5f7aed6124f", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"array(['Russian'], dtype=object)" | |
] | |
}, | |
"execution_count": 33, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"predict_language(\"Misha\")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "b9b6fb2d-aab9-48b2-9021-c2be7131fd65", | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3 (ipykernel)", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.11.10" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 5 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment