@klazuka
Created January 25, 2025 13:27
ChatGPT Operator wrote this classifier in Jupyter
{
"cells": [
{
"cell_type": "code",
"execution_count": 2,
"id": "8570cae5-1e95-4871-9d7f-b7d8c9258438",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Using device = cuda:0\n"
]
}
],
"source": [
"import torch\n",
"# Check if CUDA is available\n",
"device = torch.device('cpu')\n",
"if torch.cuda.is_available():\n",
" device = torch.device('cuda')\n",
"\n",
"torch.set_default_device(device)\n",
"print(f\"Using device = {torch.get_default_device()}\")"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "9b474f6d-a56e-4478-8676-4553d0327f17",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"--2025-01-25 12:36:36-- https://download.pytorch.org/tutorial/data.zip\n",
"Resolving download.pytorch.org (download.pytorch.org)... 18.238.152.14, 18.238.152.71, 18.238.152.47, ...\n",
"Connecting to download.pytorch.org (download.pytorch.org)|18.238.152.14|:443... connected.\n",
"HTTP request sent, awaiting response... 200 OK\n",
"Length: 2882130 (2.7M) [application/zip]\n",
"Saving to: ‘data.zip’\n",
"\n",
"data.zip 100%[===================>] 2.75M 9.82MB/s in 0.3s \n",
"\n",
"2025-01-25 12:36:36 (9.82 MB/s) - ‘data.zip’ saved [2882130/2882130]\n",
"\n",
"/bin/bash: line 1: unzip: command not found\n"
]
}
],
"source": [
"!wget https://download.pytorch.org/tutorial/data.zip -O data.zip\n",
"!unzip data.zip"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "5085dfb3-281e-422b-af06-fec82a5feb2e",
"metadata": {},
"outputs": [],
"source": [
"import zipfile\n",
"\n",
"with zipfile.ZipFile('data.zip', 'r') as zip_ref:\n",
" zip_ref.extractall('.')"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "15005de1-79c0-413c-b228-e1dd7d556602",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"data data.zip\tname-classifier.ipynb\n"
]
}
],
"source": [
"!ls"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "49eae2ec-fafe-40a5-82d1-4c1e04cffeb6",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"eng-fra.txt names\n"
]
}
],
"source": [
"!ls data"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "905664cf-e7ef-45c9-85bb-49c5c24ce0ab",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Arabic.txt English.txt Irish.txt\tPolish.txt\tSpanish.txt\n",
"Chinese.txt French.txt Italian.txt\tPortuguese.txt\tVietnamese.txt\n",
"Czech.txt German.txt Japanese.txt\tRussian.txt\n",
"Dutch.txt Greek.txt\t Korean.txt\tScottish.txt\n"
]
}
],
"source": [
"!ls data/names"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "0d1944f7-e394-40d2-a918-b9cc76c9a6a2",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>name</th>\n",
" <th>language</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Nguyen</td>\n",
" <td>Vietnamese</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>Tron</td>\n",
" <td>Vietnamese</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>Le</td>\n",
" <td>Vietnamese</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>Pham</td>\n",
" <td>Vietnamese</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>Huynh</td>\n",
" <td>Vietnamese</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" name language\n",
"0 Nguyen Vietnamese\n",
"1 Tron Vietnamese\n",
"2 Le Vietnamese\n",
"3 Pham Vietnamese\n",
"4 Huynh Vietnamese"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import os\n",
"import pandas as pd\n",
"\n",
"# Create a list to store the data\n",
"data = []\n",
"\n",
"# Iterate over each language file in the 'names' directory\n",
"for filename in os.listdir('data/names'):\n",
" if filename.endswith('.txt'):\n",
" language = filename.split('.')[0]\n",
" with open(os.path.join('data/names', filename), 'r', encoding='utf-8') as file:\n",
" names = file.read().strip().split('\\n')\n",
" for name in names:\n",
" data.append((name, language))\n",
"\n",
"# Create a DataFrame from the data\n",
"df = pd.DataFrame(data, columns=['name', 'language'])\n",
"\n",
"df.head()"
]
},
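{
"cell_type": "markdown",
"id": "c3a9d2f1-74b6-4e0a-9c51-2f8e6a1d0b37",
"metadata": {},
"source": [
"A quick sanity check, sketched here but not executed as part of this run: the per-language files differ a lot in size, so class balance is worth a look before training."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5e7f1b29-8d4c-4a63-b0e2-91c7d3f6a852",
"metadata": {},
"outputs": [],
"source": [
"# Sketch: inspect how many names each language contributes\n",
"print(df['language'].value_counts())\n",
"print(f\"{len(df)} names across {df['language'].nunique()} languages\")"
]
},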
{
"cell_type": "code",
"execution_count": 9,
"id": "b5b8ce40-87ed-4e24-8988-bc0278baa6a8",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Requirement already satisfied: pandas in /usr/local/lib/python3.11/dist-packages (2.2.3)\n",
"Requirement already satisfied: numpy>=1.23.2 in /usr/local/lib/python3.11/dist-packages (from pandas) (1.26.3)\n",
"Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.11/dist-packages (from pandas) (2.9.0.post0)\n",
"Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.11/dist-packages (from pandas) (2024.2)\n",
"Requirement already satisfied: tzdata>=2022.7 in /usr/local/lib/python3.11/dist-packages (from pandas) (2025.1)\n",
"Requirement already satisfied: six>=1.5 in /usr/lib/python3/dist-packages (from python-dateutil>=2.8.2->pandas) (1.16.0)\n",
"\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager, possibly rendering your system unusable.It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv. Use the --root-user-action option if you know what you are doing and want to suppress this warning.\u001b[0m\u001b[33m\n",
"\u001b[0m\n",
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m24.2\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m24.3.1\u001b[0m\n",
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpython -m pip install --upgrade pip\u001b[0m\n"
]
}
],
"source": [
"!pip install pandas"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "d4455916-89b4-496c-a5b6-dcf8c94a6684",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"({'input_ids': tensor([[ 0, 28524, 3509, ..., 1, 1, 1],\n",
" [ 0, 9064, 37861, ..., 1, 1, 1],\n",
" [ 0, 104, 1250, ..., 1, 1, 1],\n",
" ...,\n",
" [ 0, 10350, 1452, ..., 1, 1, 1],\n",
" [ 0, 19897, 10790, ..., 1, 1, 1],\n",
" [ 0, 10643, 26153, ..., 1, 1, 1]], device='cuda:0'), 'attention_mask': tensor([[1, 1, 1, ..., 0, 0, 0],\n",
" [1, 1, 1, ..., 0, 0, 0],\n",
" [1, 1, 1, ..., 0, 0, 0],\n",
" ...,\n",
" [1, 1, 1, ..., 0, 0, 0],\n",
" [1, 1, 1, ..., 0, 0, 0],\n",
" [1, 1, 1, ..., 0, 0, 0]], device='cuda:0')},\n",
" {'input_ids': tensor([[ 0, 975, 1906, ..., 1, 1, 1],\n",
" [ 0, 495, 967, ..., 1, 1, 1],\n",
" [ 0, 846, 2462, ..., 1, 1, 1],\n",
" ...,\n",
" [ 0, 43517, 1627, ..., 1, 1, 1],\n",
" [ 0, 33282, 32823, ..., 1, 1, 1],\n",
" [ 0, 725, 1397, ..., 1, 1, 1]], device='cuda:0'), 'attention_mask': tensor([[1, 1, 1, ..., 0, 0, 0],\n",
" [1, 1, 1, ..., 0, 0, 0],\n",
" [1, 1, 1, ..., 0, 0, 0],\n",
" ...,\n",
" [1, 1, 1, ..., 0, 0, 0],\n",
" [1, 1, 1, ..., 0, 0, 0],\n",
" [1, 1, 1, ..., 0, 0, 0]], device='cuda:0')})"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.model_selection import train_test_split\n",
"from sklearn.preprocessing import LabelEncoder\n",
"from transformers import RobertaTokenizer\n",
"\n",
"# Split the data into training and validation sets\n",
"train_df, val_df = train_test_split(df, test_size=0.2, random_state=42)\n",
"\n",
"# Encode the language labels\n",
"label_encoder = LabelEncoder()\n",
"train_df['language_encoded'] = label_encoder.fit_transform(train_df['language'])\n",
"val_df['language_encoded'] = label_encoder.transform(val_df['language'])\n",
"\n",
"# Initialize the RoBERTa tokenizer\n",
"tokenizer = RobertaTokenizer.from_pretrained('roberta-base')\n",
"\n",
"def tokenize_names(names):\n",
" return tokenizer(names.tolist(), padding=True, truncation=True, return_tensors='pt')\n",
"\n",
"# Tokenize the names in the training and validation sets\n",
"train_encodings = tokenize_names(train_df['name']).to(device=device)\n",
"val_encodings = tokenize_names(val_df['name']).to(device=device)\n",
"\n",
"train_encodings, val_encodings"
]
},
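{
"cell_type": "markdown",
"id": "9d2e4f60-1a3b-4c78-85d9-e6b0f2a71c43",
"metadata": {},
"source": [
"An optional sketch (not executed here): decode one tokenized training example to confirm that RoBERTa's subword tokenizer round-trips a name intact."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2b8c5d17-f4e9-4a02-9361-7d0a8e5c2f19",
"metadata": {},
"outputs": [],
"source": [
"# Sketch: round-trip one encoded name through the tokenizer\n",
"sample_ids = train_encodings['input_ids'][0]\n",
"print(tokenizer.convert_ids_to_tokens(sample_ids[:8].tolist()))\n",
"print(tokenizer.decode(sample_ids, skip_special_tokens=True))"
]
},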
{
"cell_type": "code",
"execution_count": 11,
"id": "1c6377a0-57e5-4ab3-a825-25803c90cac5",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Requirement already satisfied: scikit-learn in /usr/local/lib/python3.11/dist-packages (1.6.1)\n",
"Requirement already satisfied: numpy>=1.19.5 in /usr/local/lib/python3.11/dist-packages (from scikit-learn) (1.26.3)\n",
"Requirement already satisfied: scipy>=1.6.0 in /usr/local/lib/python3.11/dist-packages (from scikit-learn) (1.15.1)\n",
"Requirement already satisfied: joblib>=1.2.0 in /usr/local/lib/python3.11/dist-packages (from scikit-learn) (1.4.2)\n",
"Requirement already satisfied: threadpoolctl>=3.1.0 in /usr/local/lib/python3.11/dist-packages (from scikit-learn) (3.5.0)\n",
"\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager, possibly rendering your system unusable.It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv. Use the --root-user-action option if you know what you are doing and want to suppress this warning.\u001b[0m\u001b[33m\n",
"\u001b[0m\n",
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m24.2\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m24.3.1\u001b[0m\n",
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpython -m pip install --upgrade pip\u001b[0m\n"
]
}
],
"source": [
"!pip install scikit-learn"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "63bbe926-746d-467e-b2de-097a7bc2a7c3",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Requirement already satisfied: transformers in /usr/local/lib/python3.11/dist-packages (4.48.1)\n",
"Requirement already satisfied: filelock in /usr/local/lib/python3.11/dist-packages (from transformers) (3.13.1)\n",
"Requirement already satisfied: huggingface-hub<1.0,>=0.24.0 in /usr/local/lib/python3.11/dist-packages (from transformers) (0.27.1)\n",
"Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.11/dist-packages (from transformers) (1.26.3)\n",
"Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.11/dist-packages (from transformers) (24.1)\n",
"Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.11/dist-packages (from transformers) (6.0.2)\n",
"Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.11/dist-packages (from transformers) (2024.11.6)\n",
"Requirement already satisfied: requests in /usr/local/lib/python3.11/dist-packages (from transformers) (2.32.3)\n",
"Requirement already satisfied: tokenizers<0.22,>=0.21 in /usr/local/lib/python3.11/dist-packages (from transformers) (0.21.0)\n",
"Requirement already satisfied: safetensors>=0.4.1 in /usr/local/lib/python3.11/dist-packages (from transformers) (0.5.2)\n",
"Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.11/dist-packages (from transformers) (4.67.1)\n",
"Requirement already satisfied: fsspec>=2023.5.0 in /usr/local/lib/python3.11/dist-packages (from huggingface-hub<1.0,>=0.24.0->transformers) (2024.2.0)\n",
"Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.11/dist-packages (from huggingface-hub<1.0,>=0.24.0->transformers) (4.9.0)\n",
"Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.11/dist-packages (from requests->transformers) (3.3.2)\n",
"Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.11/dist-packages (from requests->transformers) (3.10)\n",
"Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.11/dist-packages (from requests->transformers) (2.2.3)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.11/dist-packages (from requests->transformers) (2024.8.30)\n",
"\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager, possibly rendering your system unusable.It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv. Use the --root-user-action option if you know what you are doing and want to suppress this warning.\u001b[0m\u001b[33m\n",
"\u001b[0m\n",
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m24.2\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m24.3.1\u001b[0m\n",
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpython -m pip install --upgrade pip\u001b[0m\n"
]
}
],
"source": [
"!pip install transformers"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "644aa878-b87a-4dc4-9118-bae316c5a00e",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at roberta-base and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']\n",
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
]
},
{
"data": {
"text/plain": [
"(<torch.utils.data.dataloader.DataLoader at 0x7b994c4be810>,\n",
" <torch.utils.data.dataloader.DataLoader at 0x7b994c4bd250>,\n",
" RobertaForSequenceClassification(\n",
" (roberta): RobertaModel(\n",
" (embeddings): RobertaEmbeddings(\n",
" (word_embeddings): Embedding(50265, 768, padding_idx=1)\n",
" (position_embeddings): Embedding(514, 768, padding_idx=1)\n",
" (token_type_embeddings): Embedding(1, 768)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (encoder): RobertaEncoder(\n",
" (layer): ModuleList(\n",
" (0-11): 12 x RobertaLayer(\n",
" (attention): RobertaAttention(\n",
" (self): RobertaSdpaSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (output): RobertaSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): RobertaIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" (intermediate_act_fn): GELUActivation()\n",
" )\n",
" (output): RobertaOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" )\n",
" )\n",
" )\n",
" (classifier): RobertaClassificationHead(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" (out_proj): Linear(in_features=768, out_features=18, bias=True)\n",
" )\n",
" ))"
]
},
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import torch\n",
"from torch.utils.data import Dataset, DataLoader\n",
"from transformers import RobertaForSequenceClassification, AdamW\n",
"\n",
"class NameDataset(Dataset):\n",
" def __init__(self, encodings, labels):\n",
" self.encodings = encodings\n",
" self.labels = labels\n",
"\n",
" def __getitem__(self, idx):\n",
" item = {key: val[idx] for key, val in self.encodings.items()}\n",
" item['labels'] = torch.tensor(self.labels[idx], device=device)\n",
" return item\n",
"\n",
" def __len__(self):\n",
" return len(self.labels)\n",
"\n",
"# Create PyTorch datasets\n",
"train_dataset = NameDataset(train_encodings, train_df['language_encoded'].values)\n",
"val_dataset = NameDataset(val_encodings, val_df['language_encoded'].values)\n",
"\n",
"# Create DataLoaders for batching\n",
"train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, generator=torch.Generator(device=device))\n",
"val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False, generator=torch.Generator(device=device))\n",
"\n",
"# Initialize the RoBERTa model for sequence classification\n",
"model = RobertaForSequenceClassification.from_pretrained('roberta-base', num_labels=len(label_encoder.classes_))\n",
"model.to(device)\n",
"\n",
"# Set up the optimizer\n",
"optimizer = AdamW(model.parameters(), lr=5e-5)\n",
"\n",
"train_loader, val_loader, model"
]
},
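{
"cell_type": "markdown",
"id": "6f1a3c85-2d9e-4b07-a4c6-08e5b7d2f391",
"metadata": {},
"source": [
"A sketch, not executed in the original run: pull a single batch and check that inputs, masks, and labels share a batch dimension and device before committing to a full training loop."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d4b6e208-7c1f-4953-8a2d-36f9c0e18b75",
"metadata": {},
"outputs": [],
"source": [
"# Sketch: inspect shapes and devices of one training batch\n",
"batch = next(iter(train_loader))\n",
"for key, val in batch.items():\n",
"    print(key, tuple(val.shape), val.device)"
]
},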
{
"cell_type": "code",
"execution_count": 21,
"id": "f9370ef9-a059-4ecb-8e42-b70051c82ed2",
"metadata": {},
"outputs": [],
"source": [
"from tqdm import tqdm\n",
"\n",
"def train(model, train_loader, val_loader, optimizer, epochs=3):\n",
" print(f\"begin training for {epochs} epochs\")\n",
" model.train()\n",
" for epoch in range(epochs):\n",
" print(f'Epoch {epoch + 1}/{epochs}')\n",
" total_loss = 0\n",
" for batch in tqdm(train_loader):\n",
" optimizer.zero_grad()\n",
" outputs = model(**batch)\n",
" loss = outputs.loss\n",
" loss.backward()\n",
" optimizer.step()\n",
" total_loss += loss.item()\n",
" avg_train_loss = total_loss / len(train_loader)\n",
" print(f'Average training loss: {avg_train_loss:.4f}')\n",
" evaluate(model, val_loader)\n",
"\n",
"\n",
"def evaluate(model, val_loader):\n",
" model.eval()\n",
" total_loss = 0\n",
" correct = 0\n",
" total = 0\n",
" with torch.no_grad():\n",
" for batch in tqdm(val_loader):\n",
" outputs = model(**batch)\n",
" loss = outputs.loss\n",
" total_loss += loss.item()\n",
" logits = outputs.logits\n",
" predictions = torch.argmax(logits, dim=-1)\n",
" correct += (predictions == batch['labels']).sum().item()\n",
" total += batch['labels'].size(0)\n",
" avg_val_loss = total_loss / len(val_loader)\n",
" accuracy = correct / total\n",
" print(f'Validation loss: {avg_val_loss:.4f}, Accuracy: {accuracy:.4f}')"
]
},
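{
"cell_type": "markdown",
"id": "8e2d5a94-3f07-4c61-b8e5-19a6d4c7f028",
"metadata": {},
"source": [
"The loop above runs at a fixed learning rate. A common refinement, sketched here with illustrative hyperparameters and not used in the run below, is a linear warmup/decay schedule; `scheduler.step()` would go right after `optimizer.step()` inside `train`."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a7c3f591-6b2e-4d80-93a4-5e1d8f0b62c7",
"metadata": {},
"outputs": [],
"source": [
"# Sketch: linear warmup/decay schedule (hyperparameters are illustrative)\n",
"from transformers import get_linear_schedule_with_warmup\n",
"\n",
"num_training_steps = len(train_loader) * 3  # 3 epochs\n",
"scheduler = get_linear_schedule_with_warmup(\n",
"    optimizer, num_warmup_steps=100, num_training_steps=num_training_steps)\n",
"# in train(): call scheduler.step() after each optimizer.step()"
]
},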
{
"cell_type": "code",
"execution_count": 22,
"id": "5d095d8f-13bc-4133-82e8-3cd39d1027f3",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"begin training for 3 epochs\n",
"Epoch 1/3\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 1004/1004 [00:43<00:00, 23.24it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Average training loss: 0.8538\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 251/251 [00:02<00:00, 107.04it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Validation loss: 0.5905, Accuracy: 0.8194\n",
"Epoch 2/3\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 1004/1004 [00:40<00:00, 24.77it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Average training loss: 0.4990\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 251/251 [00:02<00:00, 110.31it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Validation loss: 0.5298, Accuracy: 0.8411\n",
"Epoch 3/3\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 1004/1004 [00:41<00:00, 24.16it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Average training loss: 0.3373\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 251/251 [00:02<00:00, 108.09it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Validation loss: 0.5763, Accuracy: 0.8314\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"train(model, train_loader, val_loader, optimizer)"
]
},
{
"cell_type": "code",
"execution_count": 24,
"id": "a1956c49-2736-43ed-bf59-db9b421594b9",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The predicted language for the name \"John\" is: ['English']\n",
"['English']\n"
]
}
],
"source": [
"def predict_language(name):\n",
" model.eval()\n",
" with torch.no_grad():\n",
" encoding = tokenizer(name, return_tensors='pt', padding=True, truncation=True)\n",
" outputs = model(**encoding)\n",
" logits = outputs.logits\n",
" prediction = torch.argmax(logits, dim=-1).item()\n",
" language = label_encoder.inverse_transform([prediction])\n",
" return language\n",
"\n",
"# Test the model with a sample name\n",
"sample_name = \"John\"\n",
"predicted_language = predict_language(sample_name)\n",
"print(f'The predicted language for the name \"{sample_name}\" is: {predicted_language}')\n",
"print(predicted_language)"
]
},
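{
"cell_type": "markdown",
"id": "1c9e7b42-5d3a-4f86-a0b7-84f2e6d1c590",
"metadata": {},
"source": [
"A sketch extending `predict_language` (not part of the original run): report the top-3 languages with softmax probabilities instead of only the argmax label. `predict_top_k` is a hypothetical helper, not defined elsewhere in this notebook."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f0d82c36-4e7a-49b1-85c0-d2a6f3e91b48",
"metadata": {},
"outputs": [],
"source": [
"# Sketch: hypothetical helper returning the k most probable languages\n",
"def predict_top_k(name, k=3):\n",
"    model.eval()\n",
"    with torch.no_grad():\n",
"        encoding = tokenizer(name, return_tensors='pt', padding=True, truncation=True)\n",
"        probs = torch.softmax(model(**encoding).logits, dim=-1).squeeze(0)\n",
"    top = torch.topk(probs, k)\n",
"    return [(label_encoder.inverse_transform([i.item()])[0], round(p.item(), 3))\n",
"            for p, i in zip(top.values, top.indices)]\n",
"\n",
"predict_top_k('John')"
]
},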
{
"cell_type": "code",
"execution_count": 25,
"id": "fc4ae875-5948-4d9a-b615-5f12b65713db",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The predicted language for the name \"Maria\" is: ['Italian']\n"
]
}
],
"source": [
"# Test the model with a different name\n",
"sample_name = \"Maria\"\n",
"predicted_language = predict_language(sample_name)\n",
"print(f'The predicted language for the name \"{sample_name}\" is: {predicted_language}')"
]
},
{
"cell_type": "code",
"execution_count": 33,
"id": "9ac1712e-c8c2-4812-b46c-d5f7aed6124f",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array(['Russian'], dtype=object)"
]
},
"execution_count": 33,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"predict_language(\"Misha\")"
]
},
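{
"cell_type": "markdown",
"id": "4a6d0f28-9c5b-4e73-b1a8-60e3c2d7f915",
"metadata": {},
"source": [
"Finally, a sketch (not executed here) for persisting the fine-tuned classifier: `save_pretrained` writes the model and tokenizer to disk, and the label order is stored alongside them so predictions can be decoded later. The `name-classifier` directory name is arbitrary."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b5e19d73-2a8c-4f06-94d1-c7f0a3e86b24",
"metadata": {},
"outputs": [],
"source": [
"# Sketch: persist model, tokenizer, and label mapping for later reuse\n",
"import json\n",
"\n",
"model.save_pretrained('name-classifier')\n",
"tokenizer.save_pretrained('name-classifier')\n",
"with open('name-classifier/labels.json', 'w') as f:\n",
"    json.dump(label_encoder.classes_.tolist(), f)"
]
},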
{
"cell_type": "code",
"execution_count": null,
"id": "b9b6fb2d-aab9-48b2-9021-c2be7131fd65",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.10"
}
},
"nbformat": 4,
"nbformat_minor": 5
}