Created
April 10, 2023 05:07
-
-
Save snakers4/89eab250c404d71e82a913b0b751488c to your computer and use it in GitHub Desktop.
test2.ipynb
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 59, | |
"metadata": { | |
"id": "1XEMm5oo36Sm" | |
}, | |
"outputs": [], | |
"source": [ | |
"import os\n", | |
"import time\n", | |
"import glob\n", | |
"import numba\n", | |
"import random\n", | |
"import itertools\n", | |
"import torch\n", | |
"import torch.nn as nn\n", | |
"import matplotlib.pyplot as plt\n", | |
"\n", | |
"import numpy as np\n", | |
"import pandas as pd\n", | |
"\n", | |
"from tqdm import tqdm\n", | |
"from sklearn.model_selection import train_test_split\n", | |
"from torch.nn.utils.rnn import pad_sequence\n", | |
"from torch.utils.data import Dataset, DataLoader" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 60, | |
"metadata": { | |
"id": "a6-yAg4536Sp" | |
}, | |
"outputs": [], | |
"source": [ | |
"RANDOM_SEED = 42" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"data = pd.read_csv('/content/train.csv')" | |
], | |
"metadata": { | |
"id": "U6Vg9a-0KNqY" | |
}, | |
"execution_count": 61, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"data" | |
], | |
"metadata": { | |
"id": "YAlwu5KAKUh4", | |
"outputId": "7c4d63f3-7130-4014-8baf-2e3dbdff67d1", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 423 | |
} | |
}, | |
"execution_count": 62, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
" id word stress num_syllables lemma\n", | |
"0 0 румяной 2 3 румяный\n", | |
"1 1 цифрами 1 3 цифра\n", | |
"2 2 слугами 1 3 слуга\n", | |
"3 3 выбирает 3 4 выбирать\n", | |
"4 4 управдом 3 3 управдом\n", | |
"... ... ... ... ... ...\n", | |
"63433 63433 экзамена 2 4 экзамен\n", | |
"63434 63434 культурой 2 3 культура\n", | |
"63435 63435 объемной 2 3 объемный\n", | |
"63436 63436 участком 2 3 участок\n", | |
"63437 63437 ташкента 2 3 ташкент\n", | |
"\n", | |
"[63438 rows x 5 columns]" | |
], | |
"text/html": [ | |
"\n", | |
" <div id=\"df-edc4ed79-b743-47a8-a595-ff6279343978\">\n", | |
" <div class=\"colab-df-container\">\n", | |
" <div>\n", | |
"<style scoped>\n", | |
" .dataframe tbody tr th:only-of-type {\n", | |
" vertical-align: middle;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>id</th>\n", | |
" <th>word</th>\n", | |
" <th>stress</th>\n", | |
" <th>num_syllables</th>\n", | |
" <th>lemma</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>0</th>\n", | |
" <td>0</td>\n", | |
" <td>румяной</td>\n", | |
" <td>2</td>\n", | |
" <td>3</td>\n", | |
" <td>румяный</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1</th>\n", | |
" <td>1</td>\n", | |
" <td>цифрами</td>\n", | |
" <td>1</td>\n", | |
" <td>3</td>\n", | |
" <td>цифра</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>2</th>\n", | |
" <td>2</td>\n", | |
" <td>слугами</td>\n", | |
" <td>1</td>\n", | |
" <td>3</td>\n", | |
" <td>слуга</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>3</th>\n", | |
" <td>3</td>\n", | |
" <td>выбирает</td>\n", | |
" <td>3</td>\n", | |
" <td>4</td>\n", | |
" <td>выбирать</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>4</th>\n", | |
" <td>4</td>\n", | |
" <td>управдом</td>\n", | |
" <td>3</td>\n", | |
" <td>3</td>\n", | |
" <td>управдом</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>...</th>\n", | |
" <td>...</td>\n", | |
" <td>...</td>\n", | |
" <td>...</td>\n", | |
" <td>...</td>\n", | |
" <td>...</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>63433</th>\n", | |
" <td>63433</td>\n", | |
" <td>экзамена</td>\n", | |
" <td>2</td>\n", | |
" <td>4</td>\n", | |
" <td>экзамен</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>63434</th>\n", | |
" <td>63434</td>\n", | |
" <td>культурой</td>\n", | |
" <td>2</td>\n", | |
" <td>3</td>\n", | |
" <td>культура</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>63435</th>\n", | |
" <td>63435</td>\n", | |
" <td>объемной</td>\n", | |
" <td>2</td>\n", | |
" <td>3</td>\n", | |
" <td>объемный</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>63436</th>\n", | |
" <td>63436</td>\n", | |
" <td>участком</td>\n", | |
" <td>2</td>\n", | |
" <td>3</td>\n", | |
" <td>участок</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>63437</th>\n", | |
" <td>63437</td>\n", | |
" <td>ташкента</td>\n", | |
" <td>2</td>\n", | |
" <td>3</td>\n", | |
" <td>ташкент</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"<p>63438 rows × 5 columns</p>\n", | |
"</div>\n", | |
" <button class=\"colab-df-convert\" onclick=\"convertToInteractive('df-edc4ed79-b743-47a8-a595-ff6279343978')\"\n", | |
" title=\"Convert this dataframe to an interactive table.\"\n", | |
" style=\"display:none;\">\n", | |
" \n", | |
" <svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n", | |
" width=\"24px\">\n", | |
" <path d=\"M0 0h24v24H0V0z\" fill=\"none\"/>\n", | |
" <path d=\"M18.56 5.44l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94zm-11 1L8.5 8.5l.94-2.06 2.06-.94-2.06-.94L8.5 2.5l-.94 2.06-2.06.94zm10 10l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94z\"/><path d=\"M17.41 7.96l-1.37-1.37c-.4-.4-.92-.59-1.43-.59-.52 0-1.04.2-1.43.59L10.3 9.45l-7.72 7.72c-.78.78-.78 2.05 0 2.83L4 21.41c.39.39.9.59 1.41.59.51 0 1.02-.2 1.41-.59l7.78-7.78 2.81-2.81c.8-.78.8-2.07 0-2.86zM5.41 20L4 18.59l7.72-7.72 1.47 1.35L5.41 20z\"/>\n", | |
" </svg>\n", | |
" </button>\n", | |
" \n", | |
" <style>\n", | |
" .colab-df-container {\n", | |
" display:flex;\n", | |
" flex-wrap:wrap;\n", | |
" gap: 12px;\n", | |
" }\n", | |
"\n", | |
" .colab-df-convert {\n", | |
" background-color: #E8F0FE;\n", | |
" border: none;\n", | |
" border-radius: 50%;\n", | |
" cursor: pointer;\n", | |
" display: none;\n", | |
" fill: #1967D2;\n", | |
" height: 32px;\n", | |
" padding: 0 0 0 0;\n", | |
" width: 32px;\n", | |
" }\n", | |
"\n", | |
" .colab-df-convert:hover {\n", | |
" background-color: #E2EBFA;\n", | |
" box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n", | |
" fill: #174EA6;\n", | |
" }\n", | |
"\n", | |
" [theme=dark] .colab-df-convert {\n", | |
" background-color: #3B4455;\n", | |
" fill: #D2E3FC;\n", | |
" }\n", | |
"\n", | |
" [theme=dark] .colab-df-convert:hover {\n", | |
" background-color: #434B5C;\n", | |
" box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n", | |
" filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n", | |
" fill: #FFFFFF;\n", | |
" }\n", | |
" </style>\n", | |
"\n", | |
" <script>\n", | |
" const buttonEl =\n", | |
" document.querySelector('#df-edc4ed79-b743-47a8-a595-ff6279343978 button.colab-df-convert');\n", | |
" buttonEl.style.display =\n", | |
" google.colab.kernel.accessAllowed ? 'block' : 'none';\n", | |
"\n", | |
" async function convertToInteractive(key) {\n", | |
" const element = document.querySelector('#df-edc4ed79-b743-47a8-a595-ff6279343978');\n", | |
" const dataTable =\n", | |
" await google.colab.kernel.invokeFunction('convertToInteractive',\n", | |
" [key], {});\n", | |
" if (!dataTable) return;\n", | |
"\n", | |
" const docLinkHtml = 'Like what you see? Visit the ' +\n", | |
" '<a target=\"_blank\" href=https://colab.research.google.com/notebooks/data_table.ipynb>data table notebook</a>'\n", | |
" + ' to learn more about interactive tables.';\n", | |
" element.innerHTML = '';\n", | |
" dataTable['output_type'] = 'display_data';\n", | |
" await google.colab.output.renderOutput(dataTable, element);\n", | |
" const docLink = document.createElement('div');\n", | |
" docLink.innerHTML = docLinkHtml;\n", | |
" element.appendChild(docLink);\n", | |
" }\n", | |
" </script>\n", | |
" </div>\n", | |
" </div>\n", | |
" " | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 62 | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"### Preprocessing" | |
], | |
"metadata": { | |
"id": "bKe_fTixbfIO" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"def stress_pos(word, stress):\n", | |
"    \"\"\"Return a one-hot float vector marking the stressed vowel of `word`.\n", | |
"\n", | |
"    `stress` is the 1-based index of the stressed vowel among the word's\n", | |
"    vowels. The result has len(word) entries: 1.0 at the character position\n", | |
"    of that vowel, 0.0 everywhere else. If the word has fewer than `stress`\n", | |
"    vowels (or `stress` < 1) the vector is all zeros.\n", | |
"    \"\"\"\n", | |
"    vowels = ('а', 'о', 'у', 'ы', 'э', 'е', 'ё', 'и', 'ю', 'я')\n", | |
"    marker = np.zeros(len(word))\n", | |
"    vowel_no = 0\n", | |
"    for pos, char in enumerate(word):\n", | |
"        if char not in vowels:\n", | |
"            continue\n", | |
"        vowel_no += 1\n", | |
"        if vowel_no == stress:\n", | |
"            marker[pos] = 1\n", | |
"            break\n", | |
"    return marker" | |
], | |
"metadata": { | |
"id": "LUaoofm8KaUS" | |
}, | |
"execution_count": 63, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"%%time\n", | |
"\n", | |
"data['word_stress_pos'] = data.apply(lambda x: stress_pos(x.word, x.stress), axis=1)" | |
], | |
"metadata": { | |
"id": "MCL0OC5AQ10d", | |
"outputId": "3ea44f1a-50ef-448c-edad-2a97877665c3", | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
} | |
}, | |
"execution_count": 64, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"CPU times: user 1.19 s, sys: 10.7 ms, total: 1.2 s\n", | |
"Wall time: 1.2 s\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 65, | |
"metadata": { | |
"id": "QMJeYC9636St", | |
"outputId": "6c803ce0-0227-4003-87fa-e93738697237", | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
} | |
}, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"CPU times: user 37.9 ms, sys: 26.2 ms, total: 64.2 ms\n", | |
"Wall time: 65.6 ms\n" | |
] | |
} | |
], | |
"source": [ | |
"%%time\n", | |
"\n", | |
"data['word_list'] = data.word.map(list)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 66, | |
"metadata": { | |
"id": "FV6qIcC336Su", | |
"outputId": "00d1f8c4-41a2-4c90-ab61-d599dd8fcdf5", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 206 | |
} | |
}, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
" id word stress num_syllables lemma \\\n", | |
"49671 49671 священником 2 4 священник \n", | |
"50707 50707 игрушечный 2 4 игрушечный \n", | |
"29760 29760 полмиллиарда 4 5 полмиллиард \n", | |
"61955 61955 байгора 1 3 байгора \n", | |
"28681 28681 цепочки 2 3 цепочка \n", | |
"\n", | |
" word_stress_pos \\\n", | |
"49671 [0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, ... \n", | |
"50707 [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... \n", | |
"29760 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, ... \n", | |
"61955 [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0] \n", | |
"28681 [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0] \n", | |
"\n", | |
" word_list \n", | |
"49671 [с, в, я, щ, е, н, н, и, к, о, м] \n", | |
"50707 [и, г, р, у, ш, е, ч, н, ы, й] \n", | |
"29760 [п, о, л, м, и, л, л, и, а, р, д, а] \n", | |
"61955 [б, а, й, г, о, р, а] \n", | |
"28681 [ц, е, п, о, ч, к, и] " | |
], | |
"text/html": [ | |
"\n", | |
" <div id=\"df-cf066de4-3646-4f64-afb8-4077910571cf\">\n", | |
" <div class=\"colab-df-container\">\n", | |
" <div>\n", | |
"<style scoped>\n", | |
" .dataframe tbody tr th:only-of-type {\n", | |
" vertical-align: middle;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>id</th>\n", | |
" <th>word</th>\n", | |
" <th>stress</th>\n", | |
" <th>num_syllables</th>\n", | |
" <th>lemma</th>\n", | |
" <th>word_stress_pos</th>\n", | |
" <th>word_list</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>49671</th>\n", | |
" <td>49671</td>\n", | |
" <td>священником</td>\n", | |
" <td>2</td>\n", | |
" <td>4</td>\n", | |
" <td>священник</td>\n", | |
" <td>[0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, ...</td>\n", | |
" <td>[с, в, я, щ, е, н, н, и, к, о, м]</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>50707</th>\n", | |
" <td>50707</td>\n", | |
" <td>игрушечный</td>\n", | |
" <td>2</td>\n", | |
" <td>4</td>\n", | |
" <td>игрушечный</td>\n", | |
" <td>[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...</td>\n", | |
" <td>[и, г, р, у, ш, е, ч, н, ы, й]</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>29760</th>\n", | |
" <td>29760</td>\n", | |
" <td>полмиллиарда</td>\n", | |
" <td>4</td>\n", | |
" <td>5</td>\n", | |
" <td>полмиллиард</td>\n", | |
" <td>[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, ...</td>\n", | |
" <td>[п, о, л, м, и, л, л, и, а, р, д, а]</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>61955</th>\n", | |
" <td>61955</td>\n", | |
" <td>байгора</td>\n", | |
" <td>1</td>\n", | |
" <td>3</td>\n", | |
" <td>байгора</td>\n", | |
" <td>[0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0]</td>\n", | |
" <td>[б, а, й, г, о, р, а]</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>28681</th>\n", | |
" <td>28681</td>\n", | |
" <td>цепочки</td>\n", | |
" <td>2</td>\n", | |
" <td>3</td>\n", | |
" <td>цепочка</td>\n", | |
" <td>[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0]</td>\n", | |
" <td>[ц, е, п, о, ч, к, и]</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>\n", | |
" <button class=\"colab-df-convert\" onclick=\"convertToInteractive('df-cf066de4-3646-4f64-afb8-4077910571cf')\"\n", | |
" title=\"Convert this dataframe to an interactive table.\"\n", | |
" style=\"display:none;\">\n", | |
" \n", | |
" <svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n", | |
" width=\"24px\">\n", | |
" <path d=\"M0 0h24v24H0V0z\" fill=\"none\"/>\n", | |
" <path d=\"M18.56 5.44l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94zm-11 1L8.5 8.5l.94-2.06 2.06-.94-2.06-.94L8.5 2.5l-.94 2.06-2.06.94zm10 10l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94z\"/><path d=\"M17.41 7.96l-1.37-1.37c-.4-.4-.92-.59-1.43-.59-.52 0-1.04.2-1.43.59L10.3 9.45l-7.72 7.72c-.78.78-.78 2.05 0 2.83L4 21.41c.39.39.9.59 1.41.59.51 0 1.02-.2 1.41-.59l7.78-7.78 2.81-2.81c.8-.78.8-2.07 0-2.86zM5.41 20L4 18.59l7.72-7.72 1.47 1.35L5.41 20z\"/>\n", | |
" </svg>\n", | |
" </button>\n", | |
" \n", | |
" <style>\n", | |
" .colab-df-container {\n", | |
" display:flex;\n", | |
" flex-wrap:wrap;\n", | |
" gap: 12px;\n", | |
" }\n", | |
"\n", | |
" .colab-df-convert {\n", | |
" background-color: #E8F0FE;\n", | |
" border: none;\n", | |
" border-radius: 50%;\n", | |
" cursor: pointer;\n", | |
" display: none;\n", | |
" fill: #1967D2;\n", | |
" height: 32px;\n", | |
" padding: 0 0 0 0;\n", | |
" width: 32px;\n", | |
" }\n", | |
"\n", | |
" .colab-df-convert:hover {\n", | |
" background-color: #E2EBFA;\n", | |
" box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n", | |
" fill: #174EA6;\n", | |
" }\n", | |
"\n", | |
" [theme=dark] .colab-df-convert {\n", | |
" background-color: #3B4455;\n", | |
" fill: #D2E3FC;\n", | |
" }\n", | |
"\n", | |
" [theme=dark] .colab-df-convert:hover {\n", | |
" background-color: #434B5C;\n", | |
" box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n", | |
" filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n", | |
" fill: #FFFFFF;\n", | |
" }\n", | |
" </style>\n", | |
"\n", | |
" <script>\n", | |
" const buttonEl =\n", | |
" document.querySelector('#df-cf066de4-3646-4f64-afb8-4077910571cf button.colab-df-convert');\n", | |
" buttonEl.style.display =\n", | |
" google.colab.kernel.accessAllowed ? 'block' : 'none';\n", | |
"\n", | |
" async function convertToInteractive(key) {\n", | |
" const element = document.querySelector('#df-cf066de4-3646-4f64-afb8-4077910571cf');\n", | |
" const dataTable =\n", | |
" await google.colab.kernel.invokeFunction('convertToInteractive',\n", | |
" [key], {});\n", | |
" if (!dataTable) return;\n", | |
"\n", | |
" const docLinkHtml = 'Like what you see? Visit the ' +\n", | |
" '<a target=\"_blank\" href=https://colab.research.google.com/notebooks/data_table.ipynb>data table notebook</a>'\n", | |
" + ' to learn more about interactive tables.';\n", | |
" element.innerHTML = '';\n", | |
" dataTable['output_type'] = 'display_data';\n", | |
" await google.colab.output.renderOutput(dataTable, element);\n", | |
" const docLink = document.createElement('div');\n", | |
" docLink.innerHTML = docLinkHtml;\n", | |
" element.appendChild(docLink);\n", | |
" }\n", | |
" </script>\n", | |
" </div>\n", | |
" </div>\n", | |
" " | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 66 | |
} | |
], | |
"source": [ | |
"data.sample(5)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 67, | |
"metadata": { | |
"id": "q5LCEprm36Su", | |
"outputId": "6f278542-1266-4a19-e0a6-412a62953379", | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
} | |
}, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"20" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 67 | |
} | |
], | |
"source": [ | |
"max_sequence_len = np.max(data.word.str.len())\n", | |
"max_sequence_len" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 68, | |
"metadata": { | |
"id": "T1ln0-f436Su" | |
}, | |
"outputs": [], | |
"source": [ | |
"def flatten(array):\n", | |
"    \"\"\"Yield the leaves of an arbitrarily nested list, depth-first.\n", | |
"\n", | |
"    Only `list` instances are descended into; anything else (including str\n", | |
"    and tuple) is yielded unchanged.\n", | |
"    \"\"\"\n", | |
"    for element in array:\n", | |
"        if not isinstance(element, list):\n", | |
"            yield element\n", | |
"            continue\n", | |
"        yield from flatten(element)\n", | |
"\n", | |
"\n", | |
"class SequenceTokenizer:\n", | |
"    \"\"\"Map tokens (in this notebook: characters) to integer ids.\n", | |
"\n", | |
"    fit() assigns ids starting at 1: id 1 is the out-of-vocabulary token\n", | |
"    '<UNK>', and the remaining distinct tokens follow in sorted order.\n", | |
"    Id 0 is never assigned by fit(), so it stays free for padding.\n", | |
"    \"\"\"\n", | |
"\n", | |
"    def __init__(self):\n", | |
"        self.word2index = {}\n", | |
"        self.index2word = {}\n", | |
"        self.oov_token = '<UNK>'\n", | |
"        self.oov_token_index = 0\n", | |
"\n", | |
"    def fit(self, sequence):\n", | |
"        \"\"\"Build the vocabulary from an iterable of token lists; returns self.\"\"\"\n", | |
"        vocab = [self.oov_token]\n", | |
"        vocab.extend(sorted(set(flatten(sequence))))\n", | |
"        self.index2word = {idx: token for idx, token in enumerate(vocab, start=1)}\n", | |
"        self.word2index = {token: idx for idx, token in self.index2word.items()}\n", | |
"        self.oov_token_index = self.word2index.get(self.oov_token)\n", | |
"        return self\n", | |
"\n", | |
"    def transform(self, X):\n", | |
"        \"\"\"Encode each token list in X; unknown tokens map to oov_token_index.\"\"\"\n", | |
"        lookup = self.word2index\n", | |
"        fallback = self.oov_token_index\n", | |
"        return [[lookup.get(token, fallback) for token in line] for line in X]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 69, | |
"metadata": { | |
"id": "h4Rl1qJd36Sv" | |
}, | |
"outputs": [], | |
"source": [ | |
"tokenizer = SequenceTokenizer()\n", | |
"tokenizer.fit(data.word_list);" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 70, | |
"metadata": { | |
"id": "xj1bSdhr36Sv", | |
"outputId": "08ef9c2b-69dc-4c1e-b0a2-311f2709fb0b", | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
} | |
}, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"{'<UNK>': 1,\n", | |
" 'а': 2,\n", | |
" 'б': 3,\n", | |
" 'в': 4,\n", | |
" 'г': 5,\n", | |
" 'д': 6,\n", | |
" 'е': 7,\n", | |
" 'ж': 8,\n", | |
" 'з': 9,\n", | |
" 'и': 10,\n", | |
" 'й': 11,\n", | |
" 'к': 12,\n", | |
" 'л': 13,\n", | |
" 'м': 14,\n", | |
" 'н': 15,\n", | |
" 'о': 16,\n", | |
" 'п': 17,\n", | |
" 'р': 18,\n", | |
" 'с': 19,\n", | |
" 'т': 20,\n", | |
" 'у': 21,\n", | |
" 'ф': 22,\n", | |
" 'х': 23,\n", | |
" 'ц': 24,\n", | |
" 'ч': 25,\n", | |
" 'ш': 26,\n", | |
" 'щ': 27,\n", | |
" 'ъ': 28,\n", | |
" 'ы': 29,\n", | |
" 'ь': 30,\n", | |
" 'э': 31,\n", | |
" 'ю': 32,\n", | |
" 'я': 33,\n", | |
" 'ё': 34}" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 70 | |
} | |
], | |
"source": [ | |
"tokenizer.word2index" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 71, | |
"metadata": { | |
"id": "46RLP83Z36Sv" | |
}, | |
"outputs": [], | |
"source": [ | |
"X = tokenizer.transform(data.word_list)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 72, | |
"metadata": { | |
"id": "yInv9QA636Sv" | |
}, | |
"outputs": [], | |
"source": [ | |
"def pad_sequence(lst, max_seq=max_sequence_len):\n", | |
"    \"\"\"Right-pad token-id sequences with 0 up to `max_seq`.\n", | |
"\n", | |
"    NOTE: this shadows torch.nn.utils.rnn.pad_sequence imported in the first\n", | |
"    cell. The default `max_seq` is bound to max_sequence_len at the moment\n", | |
"    this cell is executed.\n", | |
"\n", | |
"    lst: a batch (list of lists of ints) or a single sequence (list of ints).\n", | |
"    Returns a 2-D np.ndarray of shape (len(lst), max_seq) for a batch whose\n", | |
"    sequences all fit, or the padded list for a single sequence. Sequences\n", | |
"    longer than max_seq are not truncated (in a batch that yields ragged rows).\n", | |
"    \"\"\"\n", | |
"    # Guard lst[0]: an empty input is treated as an empty single sequence\n", | |
"    # instead of raising IndexError.\n", | |
"    if len(lst) > 0 and isinstance(lst[0], list):\n", | |
"        return np.array([seq + [0] * (max_seq - len(seq)) for seq in lst])\n", | |
"    # Bug fix: this branch used to build the padded list and discard it,\n", | |
"    # so single-sequence calls returned None.\n", | |
"    return lst + [0] * (max_seq - len(lst))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 73, | |
"metadata": { | |
"id": "YmDQ0jGP36Sw", | |
"outputId": "b191b59e-76bd-48b1-a2a9-e693a8379adc", | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
} | |
}, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"CPU times: user 129 ms, sys: 5.91 ms, total: 134 ms\n", | |
"Wall time: 136 ms\n" | |
] | |
} | |
], | |
"source": [ | |
"%%time\n", | |
"\n", | |
"input_seq = pad_sequence(X)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 74, | |
"metadata": { | |
"id": "vdMK4Yrd36Sw", | |
"outputId": "18424cdb-f387-4380-a769-978e91d25b30", | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
} | |
}, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"(63438, 20)\n" | |
] | |
}, | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"array([[18, 21, 14, ..., 0, 0, 0],\n", | |
" [24, 10, 22, ..., 0, 0, 0],\n", | |
" [19, 13, 21, ..., 0, 0, 0],\n", | |
" ...,\n", | |
" [16, 3, 28, ..., 0, 0, 0],\n", | |
" [21, 25, 2, ..., 0, 0, 0],\n", | |
" [20, 2, 26, ..., 0, 0, 0]])" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 74 | |
} | |
], | |
"source": [ | |
"print(input_seq.shape)\n", | |
"input_seq" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 75, | |
"metadata": { | |
"id": "x5ATpkXb36Sx" | |
}, | |
"outputs": [], | |
"source": [ | |
"y = data.word_stress_pos.values" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 76, | |
"metadata": { | |
"id": "izwtrpIy36Sx" | |
}, | |
"outputs": [], | |
"source": [ | |
"output_seq = zip(*itertools.zip_longest(*y, fillvalue=0))\n", | |
"output_seq = list(map(list, output_seq))\n", | |
"output_seq = np.array(output_seq).astype(int)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 77, | |
"metadata": { | |
"id": "bqjHJ85O36Sx", | |
"outputId": "283c2a17-e121-4fd5-955c-bf17e1c04906", | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
} | |
}, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"(63438, 20)\n" | |
] | |
}, | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"array([[0, 0, 0, ..., 0, 0, 0],\n", | |
" [0, 1, 0, ..., 0, 0, 0],\n", | |
" [0, 0, 1, ..., 0, 0, 0],\n", | |
" ...,\n", | |
" [0, 0, 0, ..., 0, 0, 0],\n", | |
" [0, 0, 1, ..., 0, 0, 0],\n", | |
" [0, 0, 0, ..., 0, 0, 0]])" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 77 | |
} | |
], | |
"source": [ | |
"print(output_seq.shape)\n", | |
"output_seq" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 78, | |
"metadata": { | |
"id": "xJaSZ-mf36Sx" | |
}, | |
"outputs": [], | |
"source": [ | |
"(input_seq_train, input_seq_val, \n", | |
" output_seq_train, output_seq_val) = train_test_split(input_seq, \n", | |
" output_seq, \n", | |
" test_size=0.5, \n", | |
" random_state=RANDOM_SEED)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 79, | |
"metadata": { | |
"id": "ezujxONV36Sy" | |
}, | |
"outputs": [], | |
"source": [ | |
"input_seq_train = torch.tensor(input_seq_train, dtype=torch.long).cuda()\n", | |
"input_seq_val = torch.tensor(input_seq_val, dtype=torch.long).cuda()\n", | |
"output_seq_train = torch.tensor(output_seq_train, dtype=torch.float).cuda()\n", | |
"output_seq_val = torch.tensor(output_seq_val, dtype=torch.float).cuda()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 80, | |
"metadata": { | |
"id": "gRsPsPpj36Sy" | |
}, | |
"outputs": [], | |
"source": [ | |
"class MyDataset(Dataset):\n", | |
"    \"\"\"Wrap an indexable dataset of (data, target) pairs so that every item\n", | |
"    is returned together with its own position: (data, target, index).\n", | |
"    \"\"\"\n", | |
"\n", | |
"    def __init__(self, dataset):\n", | |
"        self.dataset = dataset\n", | |
"\n", | |
"    def __len__(self):\n", | |
"        return len(self.dataset)\n", | |
"\n", | |
"    def __getitem__(self, index):\n", | |
"        sample, label = self.dataset[index]\n", | |
"        return sample, label, index" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"### Model" | |
], | |
"metadata": { | |
"id": "92jorRHXcGmz" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 81, | |
"metadata": { | |
"id": "UZ5BFgo736Sy" | |
}, | |
"outputs": [], | |
"source": [ | |
"class LSTM_model(nn.Module):\n", | |
"    \"\"\"Bidirectional 3-layer LSTM that scores each of `target_size` positions.\n", | |
"\n", | |
"    The LSTM outputs at four time steps (first, 1/4, 3/4 and last) are\n", | |
"    concatenated (2 directions x 4 steps = hidden_dim * 8 features),\n", | |
"    batch-normalised and fed through a 64-unit ReLU/dropout head.\n", | |
"    forward() returns softmax probabilities (each row sums to 1), so pair it\n", | |
"    with a probability-space loss such as nn.BCELoss rather than a\n", | |
"    *WithLogits loss, which would apply a second sigmoid.\n", | |
"    \"\"\"\n", | |
"\n", | |
"    def __init__(self, embedding_dim, hidden_dim, vocab_size, target_size):\n", | |
"        super().__init__()\n", | |
"        self.hidden_dim = hidden_dim\n", | |
"        pooled_dim = hidden_dim * 8\n", | |
"\n", | |
"        # Keep sub-module creation order unchanged: under a fixed seed the\n", | |
"        # parameter initialisation then consumes the RNG identically.\n", | |
"        self.embeddings = nn.Embedding(vocab_size, embedding_dim)\n", | |
"        self.lstm = nn.LSTM(\n", | |
"            input_size=embedding_dim,\n", | |
"            hidden_size=hidden_dim,\n", | |
"            num_layers=3,\n", | |
"            batch_first=True,\n", | |
"            bidirectional=True,\n", | |
"            dropout=0.05,\n", | |
"        )\n", | |
"        self.linear = nn.Linear(pooled_dim, 64)\n", | |
"        self.batch_norm = nn.BatchNorm1d(pooled_dim, affine=False)\n", | |
"        self.relu = nn.ReLU()\n", | |
"        self.dropout = nn.Dropout(0.1)\n", | |
"        self.out = nn.Linear(64, target_size)\n", | |
"\n", | |
"    def forward(self, x):\n", | |
"        \"\"\"Map a (batch, seq_len) tensor of token ids to (batch, target_size) probabilities.\"\"\"\n", | |
"        hidden, _ = self.lstm(self.embeddings(x))\n", | |
"        steps = hidden.shape[1]\n", | |
"        picks = (0, steps // 4, steps * 3 // 4, -1)\n", | |
"        pooled = torch.cat([hidden[:, t, :] for t in picks], 1)\n", | |
"        z = self.batch_norm(pooled)\n", | |
"        z = self.dropout(self.relu(self.linear(z)))\n", | |
"        logits = self.out(z)\n", | |
"        return nn.functional.softmax(logits, dim=1)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 82, | |
"metadata": { | |
"id": "uA2SJd7736Sy", | |
"outputId": "bae9801c-448d-4d31-b0d2-48be63acf28b", | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
} | |
}, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"LSTM_model(\n", | |
" (embeddings): Embedding(35, 64)\n", | |
" (lstm): LSTM(64, 64, num_layers=3, batch_first=True, dropout=0.05, bidirectional=True)\n", | |
" (linear): Linear(in_features=512, out_features=64, bias=True)\n", | |
" (batch_norm): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=False, track_running_stats=True)\n", | |
" (relu): ReLU()\n", | |
" (dropout): Dropout(p=0.1, inplace=False)\n", | |
" (out): Linear(in_features=64, out_features=20, bias=True)\n", | |
")" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 82 | |
} | |
], | |
"source": [ | |
"model = LSTM_model(embedding_dim=64, \n", | |
" hidden_dim=64, \n", | |
" vocab_size=len(tokenizer.word2index) + 1, \n", | |
" target_size=max_sequence_len)\n", | |
"model.cuda()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 83, | |
"metadata": { | |
"id": "y6hRGxF836Sz" | |
}, | |
"outputs": [], | |
"source": [ | |
"# LSTM_model.forward ends with a softmax, so the model already outputs\n", | |
"# probabilities in [0, 1]. BCEWithLogitsLoss would apply a sigmoid on top of\n", | |
"# them (squeezing every prediction into roughly [0.5, 0.73] and flattening\n", | |
"# the gradients), so use the probability-space BCELoss instead.\n", | |
"loss_function = nn.BCELoss()\n", | |
"optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-3)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 84, | |
"metadata": { | |
"id": "NZD-Vfjh36Sz" | |
}, | |
"outputs": [], | |
"source": [ | |
"BATCH_SIZE = 256 * 2\n", | |
"\n", | |
"train = MyDataset(torch.utils.data.TensorDataset(input_seq_train, output_seq_train))\n", | |
"valid = MyDataset(torch.utils.data.TensorDataset(input_seq_val, output_seq_val))\n", | |
"\n", | |
"train_loader = torch.utils.data.DataLoader(train, batch_size=BATCH_SIZE, shuffle=True)\n", | |
"valid_loader = torch.utils.data.DataLoader(valid, batch_size=BATCH_SIZE, shuffle=False)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "jDEMfycQ36Sz" | |
}, | |
"source": [ | |
"### K-Fold training" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 85, | |
"metadata": { | |
"id": "xIltVQnU36S0" | |
}, | |
"outputs": [], | |
"source": [ | |
"from sklearn.model_selection import KFold" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 86, | |
"metadata": { | |
"id": "fN3FS28a36S1" | |
}, | |
"outputs": [], | |
"source": [ | |
"n_folds = 5" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 87, | |
"metadata": { | |
"id": "p3-X0BHZ36S1" | |
}, | |
"outputs": [], | |
"source": [ | |
"kf = KFold(n_splits=n_folds, shuffle=True, random_state=RANDOM_SEED)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 88, | |
"metadata": { | |
"id": "m2x5XeCL36S1", | |
"outputId": "818fb1f8-c1c5-4293-8065-a42249b8bbd4", | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
} | |
}, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"(50750,) (12688,)\n", | |
"(50750,) (12688,)\n", | |
"(50750,) (12688,)\n", | |
"(50751,) (12687,)\n", | |
"(50751,) (12687,)\n" | |
] | |
} | |
], | |
"source": [ | |
"for train_index, test_index in kf.split(X=input_seq, y=output_seq):\n", | |
" print(train_index.shape, test_index.shape)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 89, | |
"metadata": { | |
"id": "zonYLRzp36S1", | |
"outputId": "f0aed232-79a7-4cc7-93cc-a40c0402421a", | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
} | |
}, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"array([[18, 21, 14, ..., 0, 0, 0],\n", | |
" [24, 10, 22, ..., 0, 0, 0],\n", | |
" [19, 13, 21, ..., 0, 0, 0],\n", | |
" ...,\n", | |
" [16, 3, 28, ..., 0, 0, 0],\n", | |
" [21, 25, 2, ..., 0, 0, 0],\n", | |
" [20, 2, 26, ..., 0, 0, 0]])" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 89 | |
} | |
], | |
"source": [ | |
"input_seq" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "m-qOODva36S2" | |
}, | |
"source": [ | |
"### Training loop" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 90, | |
"metadata": { | |
"id": "G41070T636S2", | |
"outputId": "88019e40-3214-4b5e-8db5-5285d059ef0a", | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
} | |
}, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"Epoch [1/50] progress = 98% \t loss=0.7073 \t acc=32.47% \n", | |
"Epoch [1/50] results:\t\t loss=0.7073\t acc=32.47%\t val_loss=0.7030\t val_acc=44.64%\t time=1.42s\n", | |
"------------------------------------------------------------------------------\n", | |
"Epoch [2/50] progress = 98% \t loss=0.6954 \t acc=57.92% \n", | |
"Epoch [2/50] results:\t\t loss=0.6954\t acc=57.92%\t val_loss=0.6914\t val_acc=65.59%\t time=1.47s\n", | |
"------------------------------------------------------------------------------\n", | |
"Epoch [3/50] progress = 98% \t loss=0.6900 \t acc=68.33% \n", | |
"Epoch [3/50] results:\t\t loss=0.6900\t acc=68.33%\t val_loss=0.6890\t val_acc=70.17%\t time=1.35s\n", | |
"------------------------------------------------------------------------------\n", | |
"Epoch [4/50] progress = 98% \t loss=0.6882 \t acc=71.96% \n", | |
"Epoch [4/50] results:\t\t loss=0.6882\t acc=71.96%\t val_loss=0.6883\t val_acc=71.45%\t time=1.68s\n", | |
"------------------------------------------------------------------------------\n", | |
"Epoch [5/50] progress = 98% \t loss=0.6875 \t acc=73.19% \n", | |
"Epoch [5/50] results:\t\t loss=0.6875\t acc=73.19%\t val_loss=0.6876\t val_acc=72.75%\t time=1.56s\n", | |
"------------------------------------------------------------------------------\n", | |
"Epoch [6/50] progress = 98% \t loss=0.6866 \t acc=75.19% \n", | |
"Epoch [6/50] results:\t\t loss=0.6866\t acc=75.19%\t val_loss=0.6870\t val_acc=74.08%\t time=1.32s\n", | |
"------------------------------------------------------------------------------\n", | |
"Epoch [7/50] progress = 98% \t loss=0.6860 \t acc=76.11% \n", | |
"Epoch [7/50] results:\t\t loss=0.6860\t acc=76.11%\t val_loss=0.6864\t val_acc=75.45%\t time=1.31s\n", | |
"------------------------------------------------------------------------------\n", | |
"Epoch [8/50] progress = 98% \t loss=0.6855 \t acc=77.21% \n", | |
"Epoch [8/50] results:\t\t loss=0.6855\t acc=77.21%\t val_loss=0.6861\t val_acc=75.99%\t time=1.47s\n", | |
"------------------------------------------------------------------------------\n", | |
"Epoch [9/50] progress = 98% \t loss=0.6852 \t acc=77.76% \n", | |
"Epoch [9/50] results:\t\t loss=0.6852\t acc=77.76%\t val_loss=0.6857\t val_acc=76.71%\t time=1.31s\n", | |
"------------------------------------------------------------------------------\n", | |
"Epoch [10/50] progress = 98% \t loss=0.6848 \t acc=78.58% \n", | |
"Epoch [10/50] results:\t\t loss=0.6848\t acc=78.58%\t val_loss=0.6857\t val_acc=76.74%\t time=1.35s\n", | |
"------------------------------------------------------------------------------\n", | |
"Epoch [11/50] progress = 98% \t loss=0.6842 \t acc=79.83% \n", | |
"Epoch [11/50] results:\t\t loss=0.6842\t acc=79.83%\t val_loss=0.6849\t val_acc=78.53%\t time=1.32s\n", | |
"------------------------------------------------------------------------------\n", | |
"Epoch [12/50] progress = 98% \t loss=0.6838 \t acc=80.68% \n", | |
"Epoch [12/50] results:\t\t loss=0.6838\t acc=80.68%\t val_loss=0.6848\t val_acc=78.61%\t time=1.48s\n", | |
"------------------------------------------------------------------------------\n", | |
"Epoch [13/50] progress = 98% \t loss=0.6838 \t acc=80.61% \n", | |
"Epoch [13/50] results:\t\t loss=0.6838\t acc=80.61%\t val_loss=0.6847\t val_acc=78.81%\t time=1.66s\n", | |
"------------------------------------------------------------------------------\n", | |
"Epoch [14/50] progress = 98% \t loss=0.6833 \t acc=81.56% \n", | |
"Epoch [14/50] results:\t\t loss=0.6833\t acc=81.56%\t val_loss=0.6844\t val_acc=79.43%\t time=1.51s\n", | |
"------------------------------------------------------------------------------\n", | |
"Epoch [15/50] progress = 98% \t loss=0.6832 \t acc=81.91% \n", | |
"Epoch [15/50] results:\t\t loss=0.6832\t acc=81.91%\t val_loss=0.6844\t val_acc=79.35%\t time=1.34s\n", | |
"------------------------------------------------------------------------------\n", | |
"Epoch [16/50] progress = 98% \t loss=0.6829 \t acc=82.49% \n", | |
"Epoch [16/50] results:\t\t loss=0.6829\t acc=82.49%\t val_loss=0.6841\t val_acc=79.96%\t time=1.35s\n", | |
"------------------------------------------------------------------------------\n", | |
"Epoch [17/50] progress = 98% \t loss=0.6827 \t acc=82.89% \n", | |
"Epoch [17/50] results:\t\t loss=0.6827\t acc=82.89%\t val_loss=0.6843\t val_acc=79.54%\t time=1.33s\n", | |
"------------------------------------------------------------------------------\n", | |
"Epoch [18/50] progress = 98% \t loss=0.6825 \t acc=83.25% \n", | |
"Epoch [18/50] results:\t\t loss=0.6825\t acc=83.25%\t val_loss=0.6840\t val_acc=80.11%\t time=1.34s\n", | |
"------------------------------------------------------------------------------\n", | |
"Epoch [19/50] progress = 98% \t loss=0.6823 \t acc=83.72% \n", | |
"Epoch [19/50] results:\t\t loss=0.6823\t acc=83.72%\t val_loss=0.6839\t val_acc=80.42%\t time=1.37s\n", | |
"------------------------------------------------------------------------------\n", | |
"Epoch [20/50] progress = 98% \t loss=0.6822 \t acc=83.85% \n", | |
"Epoch [20/50] results:\t\t loss=0.6822\t acc=83.85%\t val_loss=0.6838\t val_acc=80.65%\t time=1.50s\n", | |
"------------------------------------------------------------------------------\n", | |
"Epoch [21/50] progress = 98% \t loss=0.6820 \t acc=84.19% \n", | |
"Epoch [21/50] results:\t\t loss=0.6820\t acc=84.19%\t val_loss=0.6840\t val_acc=80.10%\t time=1.61s\n", | |
"------------------------------------------------------------------------------\n", | |
"Epoch [22/50] progress = 98% \t loss=0.6819 \t acc=84.51% \n", | |
"Epoch [22/50] results:\t\t loss=0.6819\t acc=84.51%\t val_loss=0.6836\t val_acc=80.95%\t time=1.51s\n", | |
"------------------------------------------------------------------------------\n", | |
"Epoch [23/50] progress = 98% \t loss=0.6817 \t acc=84.92% \n", | |
"Epoch [23/50] results:\t\t loss=0.6817\t acc=84.92%\t val_loss=0.6835\t val_acc=81.17%\t time=1.71s\n", | |
"------------------------------------------------------------------------------\n", | |
"Epoch [24/50] progress = 98% \t loss=0.6815 \t acc=85.27% \n", | |
"Epoch [24/50] results:\t\t loss=0.6815\t acc=85.27%\t val_loss=0.6836\t val_acc=80.90%\t time=2.25s\n", | |
"------------------------------------------------------------------------------\n", | |
"Epoch [25/50] progress = 98% \t loss=0.6815 \t acc=85.42% \n", | |
"Epoch [25/50] results:\t\t loss=0.6815\t acc=85.42%\t val_loss=0.6834\t val_acc=81.36%\t time=1.41s\n", | |
"------------------------------------------------------------------------------\n", | |
"Epoch [26/50] progress = 98% \t loss=0.6814 \t acc=85.62% \n", | |
"Epoch [26/50] results:\t\t loss=0.6814\t acc=85.62%\t val_loss=0.6834\t val_acc=81.44%\t time=1.49s\n", | |
"------------------------------------------------------------------------------\n", | |
"Epoch [27/50] progress = 98% \t loss=0.6812 \t acc=85.88% \n", | |
"Epoch [27/50] results:\t\t loss=0.6812\t acc=85.88%\t val_loss=0.6835\t val_acc=81.20%\t time=1.33s\n", | |
"------------------------------------------------------------------------------\n", | |
"Epoch [28/50] progress = 98% \t loss=0.6811 \t acc=86.17% \n", | |
"Epoch [28/50] results:\t\t loss=0.6811\t acc=86.17%\t val_loss=0.6832\t val_acc=81.71%\t time=1.33s\n", | |
"------------------------------------------------------------------------------\n", | |
"Epoch [29/50] progress = 98% \t loss=0.6809 \t acc=86.44% \n", | |
"Epoch [29/50] results:\t\t loss=0.6809\t acc=86.44%\t val_loss=0.6832\t val_acc=81.77%\t time=3.03s\n", | |
"------------------------------------------------------------------------------\n", | |
"Epoch [30/50] progress = 98% \t loss=0.6807 \t acc=86.77% \n", | |
"Epoch [30/50] results:\t\t loss=0.6807\t acc=86.77%\t val_loss=0.6832\t val_acc=81.76%\t time=1.34s\n", | |
"------------------------------------------------------------------------------\n", | |
"Epoch [31/50] progress = 98% \t loss=0.6806 \t acc=87.00% \n", | |
"Epoch [31/50] results:\t\t loss=0.6806\t acc=87.00%\t val_loss=0.6831\t val_acc=82.11%\t time=1.51s\n", | |
"------------------------------------------------------------------------------\n", | |
"Epoch [32/50] progress = 98% \t loss=0.6805 \t acc=87.17% \n", | |
"Epoch [32/50] results:\t\t loss=0.6805\t acc=87.17%\t val_loss=0.6829\t val_acc=82.33%\t time=1.33s\n", | |
"------------------------------------------------------------------------------\n", | |
"Epoch [33/50] progress = 98% \t loss=0.6805 \t acc=87.32% \n", | |
"Epoch [33/50] results:\t\t loss=0.6805\t acc=87.32%\t val_loss=0.6829\t val_acc=82.48%\t time=1.32s\n", | |
"------------------------------------------------------------------------------\n", | |
"Epoch [34/50] progress = 98% \t loss=0.6803 \t acc=87.62% \n", | |
"Epoch [34/50] results:\t\t loss=0.6803\t acc=87.62%\t val_loss=0.6829\t val_acc=82.31%\t time=1.31s\n", | |
"------------------------------------------------------------------------------\n", | |
"Epoch [35/50] progress = 98% \t loss=0.6802 \t acc=87.90% \n", | |
"Epoch [35/50] results:\t\t loss=0.6802\t acc=87.90%\t val_loss=0.6828\t val_acc=82.62%\t time=1.49s\n", | |
"------------------------------------------------------------------------------\n", | |
"Epoch [36/50] progress = 98% \t loss=0.6800 \t acc=88.25% \n", | |
"Epoch [36/50] results:\t\t loss=0.6800\t acc=88.25%\t val_loss=0.6827\t val_acc=82.84%\t time=1.33s\n", | |
"------------------------------------------------------------------------------\n", | |
"Epoch [37/50] progress = 98% \t loss=0.6799 \t acc=88.48% \n", | |
"Epoch [37/50] results:\t\t loss=0.6799\t acc=88.48%\t val_loss=0.6828\t val_acc=82.47%\t time=1.63s\n", | |
"------------------------------------------------------------------------------\n", | |
"Epoch [38/50] progress = 98% \t loss=0.6798 \t acc=88.66% \n", | |
"Epoch [38/50] results:\t\t loss=0.6798\t acc=88.66%\t val_loss=0.6828\t val_acc=82.68%\t time=2.02s\n", | |
"------------------------------------------------------------------------------\n", | |
"Epoch [39/50] progress = 98% \t loss=0.6797 \t acc=88.83% \n", | |
"Epoch [39/50] results:\t\t loss=0.6797\t acc=88.83%\t val_loss=0.6826\t val_acc=83.10%\t time=1.33s\n", | |
"------------------------------------------------------------------------------\n", | |
"Epoch [40/50] progress = 98% \t loss=0.6797 \t acc=89.00% \n", | |
"Epoch [40/50] results:\t\t loss=0.6797\t acc=89.00%\t val_loss=0.6826\t val_acc=82.97%\t time=1.48s\n", | |
"------------------------------------------------------------------------------\n", | |
"Epoch [41/50] progress = 98% \t loss=0.6797 \t acc=88.96% \n", | |
"Epoch [41/50] results:\t\t loss=0.6797\t acc=88.96%\t val_loss=0.6825\t val_acc=83.18%\t time=1.31s\n", | |
"------------------------------------------------------------------------------\n", | |
"Epoch [42/50] progress = 98% \t loss=0.6795 \t acc=89.27% \n", | |
"Epoch [42/50] results:\t\t loss=0.6795\t acc=89.27%\t val_loss=0.6825\t val_acc=83.26%\t time=1.31s\n", | |
"------------------------------------------------------------------------------\n", | |
"Epoch [43/50] progress = 98% \t loss=0.6795 \t acc=89.32% \n", | |
"Epoch [43/50] results:\t\t loss=0.6795\t acc=89.32%\t val_loss=0.6825\t val_acc=83.19%\t time=1.32s\n", | |
"------------------------------------------------------------------------------\n", | |
"Epoch [44/50] progress = 98% \t loss=0.6794 \t acc=89.50% \n", | |
"Epoch [44/50] results:\t\t loss=0.6794\t acc=89.50%\t val_loss=0.6824\t val_acc=83.39%\t time=1.49s\n", | |
"------------------------------------------------------------------------------\n", | |
"Epoch [45/50] progress = 98% \t loss=0.6794 \t acc=89.50% \n", | |
"Epoch [45/50] results:\t\t loss=0.6794\t acc=89.50%\t val_loss=0.6824\t val_acc=83.42%\t time=1.49s\n", | |
"------------------------------------------------------------------------------\n", | |
"Epoch [46/50] progress = 98% \t loss=0.6793 \t acc=89.73% \n", | |
"Epoch [46/50] results:\t\t loss=0.6793\t acc=89.73%\t val_loss=0.6824\t val_acc=83.43%\t time=1.65s\n", | |
"------------------------------------------------------------------------------\n", | |
"Epoch [47/50] progress = 98% \t loss=0.6792 \t acc=89.94% \n", | |
"Epoch [47/50] results:\t\t loss=0.6792\t acc=89.94%\t val_loss=0.6824\t val_acc=83.47%\t time=1.48s\n", | |
"------------------------------------------------------------------------------\n", | |
"Epoch [48/50] progress = 98% \t loss=0.6792 \t acc=89.92% \n", | |
"Epoch [48/50] results:\t\t loss=0.6792\t acc=89.92%\t val_loss=0.6823\t val_acc=83.59%\t time=1.32s\n", | |
"------------------------------------------------------------------------------\n", | |
"Epoch [49/50] progress = 98% \t loss=0.6791 \t acc=90.01% \n", | |
"Epoch [49/50] results:\t\t loss=0.6791\t acc=90.01%\t val_loss=0.6825\t val_acc=83.20%\t time=1.31s\n", | |
"------------------------------------------------------------------------------\n", | |
"Epoch [50/50] progress = 98% \t loss=0.6791 \t acc=90.14% \n", | |
"Epoch [50/50] results:\t\t loss=0.6791\t acc=90.14%\t val_loss=0.6823\t val_acc=83.66%\t time=1.52s\n", | |
"------------------------------------------------------------------------------\n" | |
] | |
} | |
], | |
"source": [ | |
"n_epochs = 50\n", | |
"history = {'train': {}, 'val': {}}\n", | |
"teacher_forcing_ratio = 0.5\n", | |
"\n", | |
"for epoch in range(1, n_epochs + 1):\n", | |
" start_time = time.time()\n", | |
" \n", | |
" model.train()\n", | |
"\n", | |
" avg_loss, total_loss, avg_acc, total_acc = 0., 0., 0., 0.\n", | |
" for i, (x_batch, y_batch, index) in enumerate(train_loader):\n", | |
" y_pred = model(x_batch)\n", | |
" loss = loss_function(y_pred, y_batch)\n", | |
" \n", | |
" optimizer.zero_grad()\n", | |
" loss.backward()\n", | |
" optimizer.step()\n", | |
" \n", | |
" equal = torch.eq(torch.argmax(y_pred, axis=1), torch.argmax(y_batch, axis=1))\n", | |
" batch_acc = int(equal.sum(-1)) / y_batch.shape[0]\n", | |
" batch_loss = loss.item()\n", | |
" \n", | |
" total_acc += batch_acc\n", | |
" total_loss += batch_loss\n", | |
" print(f\"\\rEpoch [{epoch}/{n_epochs}] \"\n", | |
" f\" progress = {round(i/len(train_loader)*100)}% \"\n", | |
" f\"\\t loss={total_loss / (i + 1):.4f} \"\n", | |
" f\"\\t acc={total_acc / (i + 1) * 100:.2f}% \", end='')\n", | |
" avg_loss = total_loss / len(train_loader)\n", | |
" avg_acc = total_acc / len(train_loader)\n", | |
" history['train']['loss'] = history.get('train', {}).get('loss', []) + [avg_loss]\n", | |
" history['train']['accuracy'] = history.get('train', {}).get('accuracy', []) + [avg_acc]\n", | |
" \n", | |
" model.eval()\n", | |
"\n", | |
" \n", | |
" avg_val_loss, total_val_loss, avg_val_acc, total_val_acc = 0., 0., 0., 0.\n", | |
" for i, (x_batch, y_batch, index) in enumerate(valid_loader):\n", | |
" y_pred = model(x_batch).detach()\n", | |
" val_loss = loss_function(y_pred, y_batch)\n", | |
" \n", | |
" equal = torch.eq(torch.argmax(y_pred, axis=1), torch.argmax(y_batch, axis=1))\n", | |
" batch_val_acc = int(equal.sum(-1)) / y_batch.shape[0]\n", | |
" batch_val_loss = val_loss.item()\n", | |
" \n", | |
" total_val_acc += batch_val_acc\n", | |
" total_val_loss += batch_val_loss\n", | |
" avg_val_loss = total_val_loss / len(valid_loader)\n", | |
" avg_val_acc = total_val_acc / len(valid_loader)\n", | |
" history['val']['loss'] = history.get('val', {}).get('loss', []) + [avg_val_loss]\n", | |
" history['val']['accuracy'] = history.get('val', {}).get('accuracy', []) + [avg_val_acc]\n", | |
" \n", | |
" elapsed_time = time.time() - start_time \n", | |
" print(f\"\\nEpoch [{epoch}/{n_epochs}] results:\"\n", | |
" f\"\\t\\t loss={avg_loss:.4f}\"\n", | |
" f\"\\t acc={avg_acc * 100:.2f}%\"\n", | |
" f\"\\t val_loss={avg_val_loss:.4f}\"\n", | |
" f\"\\t val_acc={avg_val_acc * 100:.2f}%\"\n", | |
" f\"\\t time={elapsed_time:.2f}s\")\n", | |
" print(\"-\"*78)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Save only the trained weights (state_dict) to accentor.pt in the working directory\n", | |
"torch.save(model.state_dict(),\"accentor.pt\")" | |
], | |
"metadata": { | |
"id": "t_KzF99MbYZ2" | |
}, | |
"execution_count": 93, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 94, | |
"metadata": { | |
"id": "gLj87-3T36S2", | |
"outputId": "7d666ca6-0865-4342-b10b-9320d92331b4", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 718 | |
} | |
}, | |
"outputs": [ | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": [ | |
"<Figure size 1200x800 with 1 Axes>" | |
], | |
"image/png": "\n" | |
}, | |
"metadata": {} | |
} | |
], | |
"source": [ | |
"plt.figure(figsize=(12, 8))\n", | |
"\n", | |
"plt.plot(history['train']['loss'])\n", | |
"plt.plot(history['val']['loss'])\n", | |
"plt.title('model loss')\n", | |
"plt.ylabel('val')\n", | |
"plt.xlabel('epoch')\n", | |
"plt.xticks(np.arange(len(history['train']['loss'])), np.arange(1, len(history['train']['loss']) + 1))\n", | |
"plt.legend(['train', 'val'], loc='upper right')\n", | |
"\n", | |
"plt.show()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 95, | |
"metadata": { | |
"id": "149EM4q-36S2", | |
"outputId": "b11e27f9-7cb1-4690-8468-f41a9dbfe8f4", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 718 | |
} | |
}, | |
"outputs": [ | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": [ | |
"<Figure size 1200x800 with 1 Axes>" | |
], | |
"image/png": "\n" | |
}, | |
"metadata": {} | |
} | |
], | |
"source": [ | |
"plt.figure(figsize=(12, 8))\n", | |
"\n", | |
"plt.plot(history['train']['accuracy'])\n", | |
"plt.plot(history['val']['accuracy'])\n", | |
"plt.title('model accuracy')\n", | |
"plt.ylabel('accuracy')\n", | |
"plt.xlabel('epoch')\n", | |
"plt.xticks(np.arange(len(history['train']['accuracy'])), np.arange(1, len(history['train']['accuracy']) + 1))\n", | |
"plt.legend(['train', 'val'], loc='lower right')\n", | |
"\n", | |
"plt.show()" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "5DaCsCk836S3" | |
}, | |
"source": [ | |
"### Predict" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Colab path; the test set has id/word/num_syllables/lemma but no stress column\n", | |
"test = pd.read_csv('/content/test.csv')" | |
], | |
"metadata": { | |
"id": "SZuNwt_0VD4H" | |
}, | |
"execution_count": 96, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Preview the raw test set (pandas truncates the display of long frames)\n", | |
"test" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 423 | |
}, | |
"id": "6ZkZMm41WD6t", | |
"outputId": "6b99b007-964f-4bf6-bd3e-1ea4be497869" | |
}, | |
"execution_count": 97, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
" id word num_syllables lemma\n", | |
"0 0 эпилепсия 5 эпилепсия\n", | |
"1 1 относящейся 5 относиться\n", | |
"2 2 размышлениями 6 размышление\n", | |
"3 3 модемы 3 модем\n", | |
"4 4 солнц 1 солнце\n", | |
"... ... ... ... ...\n", | |
"29955 29955 донбасса 3 донбасс\n", | |
"29956 29956 обложка 3 обложка\n", | |
"29957 29957 правителя 4 правитель\n", | |
"29958 29958 шерстяной 3 шерстяной\n", | |
"29959 29959 оптимизации 6 оптимизация\n", | |
"\n", | |
"[29960 rows x 4 columns]" | |
], | |
"text/html": [ | |
"\n", | |
" <div id=\"df-298679aa-e94c-4753-88bc-8080a262015b\">\n", | |
" <div class=\"colab-df-container\">\n", | |
" <div>\n", | |
"<style scoped>\n", | |
" .dataframe tbody tr th:only-of-type {\n", | |
" vertical-align: middle;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>id</th>\n", | |
" <th>word</th>\n", | |
" <th>num_syllables</th>\n", | |
" <th>lemma</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>0</th>\n", | |
" <td>0</td>\n", | |
" <td>эпилепсия</td>\n", | |
" <td>5</td>\n", | |
" <td>эпилепсия</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1</th>\n", | |
" <td>1</td>\n", | |
" <td>относящейся</td>\n", | |
" <td>5</td>\n", | |
" <td>относиться</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>2</th>\n", | |
" <td>2</td>\n", | |
" <td>размышлениями</td>\n", | |
" <td>6</td>\n", | |
" <td>размышление</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>3</th>\n", | |
" <td>3</td>\n", | |
" <td>модемы</td>\n", | |
" <td>3</td>\n", | |
" <td>модем</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>4</th>\n", | |
" <td>4</td>\n", | |
" <td>солнц</td>\n", | |
" <td>1</td>\n", | |
" <td>солнце</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>...</th>\n", | |
" <td>...</td>\n", | |
" <td>...</td>\n", | |
" <td>...</td>\n", | |
" <td>...</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>29955</th>\n", | |
" <td>29955</td>\n", | |
" <td>донбасса</td>\n", | |
" <td>3</td>\n", | |
" <td>донбасс</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>29956</th>\n", | |
" <td>29956</td>\n", | |
" <td>обложка</td>\n", | |
" <td>3</td>\n", | |
" <td>обложка</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>29957</th>\n", | |
" <td>29957</td>\n", | |
" <td>правителя</td>\n", | |
" <td>4</td>\n", | |
" <td>правитель</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>29958</th>\n", | |
" <td>29958</td>\n", | |
" <td>шерстяной</td>\n", | |
" <td>3</td>\n", | |
" <td>шерстяной</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>29959</th>\n", | |
" <td>29959</td>\n", | |
" <td>оптимизации</td>\n", | |
" <td>6</td>\n", | |
" <td>оптимизация</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"<p>29960 rows × 4 columns</p>\n", | |
"</div>\n", | |
" <button class=\"colab-df-convert\" onclick=\"convertToInteractive('df-298679aa-e94c-4753-88bc-8080a262015b')\"\n", | |
" title=\"Convert this dataframe to an interactive table.\"\n", | |
" style=\"display:none;\">\n", | |
" \n", | |
" <svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n", | |
" width=\"24px\">\n", | |
" <path d=\"M0 0h24v24H0V0z\" fill=\"none\"/>\n", | |
" <path d=\"M18.56 5.44l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94zm-11 1L8.5 8.5l.94-2.06 2.06-.94-2.06-.94L8.5 2.5l-.94 2.06-2.06.94zm10 10l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94z\"/><path d=\"M17.41 7.96l-1.37-1.37c-.4-.4-.92-.59-1.43-.59-.52 0-1.04.2-1.43.59L10.3 9.45l-7.72 7.72c-.78.78-.78 2.05 0 2.83L4 21.41c.39.39.9.59 1.41.59.51 0 1.02-.2 1.41-.59l7.78-7.78 2.81-2.81c.8-.78.8-2.07 0-2.86zM5.41 20L4 18.59l7.72-7.72 1.47 1.35L5.41 20z\"/>\n", | |
" </svg>\n", | |
" </button>\n", | |
" \n", | |
" <style>\n", | |
" .colab-df-container {\n", | |
" display:flex;\n", | |
" flex-wrap:wrap;\n", | |
" gap: 12px;\n", | |
" }\n", | |
"\n", | |
" .colab-df-convert {\n", | |
" background-color: #E8F0FE;\n", | |
" border: none;\n", | |
" border-radius: 50%;\n", | |
" cursor: pointer;\n", | |
" display: none;\n", | |
" fill: #1967D2;\n", | |
" height: 32px;\n", | |
" padding: 0 0 0 0;\n", | |
" width: 32px;\n", | |
" }\n", | |
"\n", | |
" .colab-df-convert:hover {\n", | |
" background-color: #E2EBFA;\n", | |
" box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n", | |
" fill: #174EA6;\n", | |
" }\n", | |
"\n", | |
" [theme=dark] .colab-df-convert {\n", | |
" background-color: #3B4455;\n", | |
" fill: #D2E3FC;\n", | |
" }\n", | |
"\n", | |
" [theme=dark] .colab-df-convert:hover {\n", | |
" background-color: #434B5C;\n", | |
" box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n", | |
" filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n", | |
" fill: #FFFFFF;\n", | |
" }\n", | |
" </style>\n", | |
"\n", | |
" <script>\n", | |
" const buttonEl =\n", | |
" document.querySelector('#df-298679aa-e94c-4753-88bc-8080a262015b button.colab-df-convert');\n", | |
" buttonEl.style.display =\n", | |
" google.colab.kernel.accessAllowed ? 'block' : 'none';\n", | |
"\n", | |
" async function convertToInteractive(key) {\n", | |
" const element = document.querySelector('#df-298679aa-e94c-4753-88bc-8080a262015b');\n", | |
" const dataTable =\n", | |
" await google.colab.kernel.invokeFunction('convertToInteractive',\n", | |
" [key], {});\n", | |
" if (!dataTable) return;\n", | |
"\n", | |
" const docLinkHtml = 'Like what you see? Visit the ' +\n", | |
" '<a target=\"_blank\" href=https://colab.research.google.com/notebooks/data_table.ipynb>data table notebook</a>'\n", | |
" + ' to learn more about interactive tables.';\n", | |
" element.innerHTML = '';\n", | |
" dataTable['output_type'] = 'display_data';\n", | |
" await google.colab.output.renderOutput(dataTable, element);\n", | |
" const docLink = document.createElement('div');\n", | |
" docLink.innerHTML = docLinkHtml;\n", | |
" element.appendChild(docLink);\n", | |
" }\n", | |
" </script>\n", | |
" </div>\n", | |
" </div>\n", | |
" " | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 97 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"def stress_pred(words):\n", | |
" tokens = pad_sequence(tokenizer.transform(words))\n", | |
" sequences = torch.tensor(tokens, dtype=torch.long).cuda()\n", | |
" preds = model(sequences)\n", | |
" indeces = torch.argmax(preds, axis=1)\n", | |
" indeces = indeces.to('cpu').numpy()\n", | |
" res = []\n", | |
" for i in range(len(indeces)):\n", | |
" pos = indeces[i]\n", | |
" coun = 1\n", | |
" for j in range(pos - 1,-1,-1,):\n", | |
" if words[i][j] in {'а', 'о', 'у', 'ы', 'э', 'е', 'ё', 'и', 'ю', 'я'}:\n", | |
" coun += 1\n", | |
" res.append(coun)\n", | |
" return res" | |
], | |
"metadata": { | |
"id": "xK1ZTbtRd644" | |
}, | |
"execution_count": 98, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Plain Python list: stress_pred indexes each word character by character\n", | |
"words = test['word'].tolist()" | |
], | |
"metadata": { | |
"id": "ajqABWgCYJ0n" | |
}, | |
"execution_count": 99, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Runs every test word through the model in a single batch\n", | |
"pred_stress = stress_pred(words)" | |
], | |
"metadata": { | |
"id": "BnQIjobDlnoP" | |
}, | |
"execution_count": 100, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Attach predictions as a new column (modifies `test` in place)\n", | |
"test['pred_stress'] = pred_stress" | |
], | |
"metadata": { | |
"id": "1UClWDbuZYWK" | |
}, | |
"execution_count": 101, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Test set with the predicted stressed-syllable number for each word\n", | |
"test" | |
], | |
"metadata": { | |
"id": "eaHU-R2LmII3", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 423 | |
}, | |
"outputId": "5d143eb5-f0c3-41c7-d982-ece4bf5596f8" | |
}, | |
"execution_count": 102, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
" id word num_syllables lemma pred_stress\n", | |
"0 0 эпилепсия 5 эпилепсия 3\n", | |
"1 1 относящейся 5 относиться 3\n", | |
"2 2 размышлениями 6 размышление 3\n", | |
"3 3 модемы 3 модем 2\n", | |
"4 4 солнц 1 солнце 1\n", | |
"... ... ... ... ... ...\n", | |
"29955 29955 донбасса 3 донбасс 2\n", | |
"29956 29956 обложка 3 обложка 2\n", | |
"29957 29957 правителя 4 правитель 2\n", | |
"29958 29958 шерстяной 3 шерстяной 2\n", | |
"29959 29959 оптимизации 6 оптимизация 4\n", | |
"\n", | |
"[29960 rows x 5 columns]" | |
], | |
"text/html": [ | |
"\n", | |
" <div id=\"df-58d65a88-788e-4cef-aa6e-c99c96833dc8\">\n", | |
" <div class=\"colab-df-container\">\n", | |
" <div>\n", | |
"<style scoped>\n", | |
" .dataframe tbody tr th:only-of-type {\n", | |
" vertical-align: middle;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>id</th>\n", | |
" <th>word</th>\n", | |
" <th>num_syllables</th>\n", | |
" <th>lemma</th>\n", | |
" <th>pred_stress</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>0</th>\n", | |
" <td>0</td>\n", | |
" <td>эпилепсия</td>\n", | |
" <td>5</td>\n", | |
" <td>эпилепсия</td>\n", | |
" <td>3</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1</th>\n", | |
" <td>1</td>\n", | |
" <td>относящейся</td>\n", | |
" <td>5</td>\n", | |
" <td>относиться</td>\n", | |
" <td>3</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>2</th>\n", | |
" <td>2</td>\n", | |
" <td>размышлениями</td>\n", | |
" <td>6</td>\n", | |
" <td>размышление</td>\n", | |
" <td>3</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>3</th>\n", | |
" <td>3</td>\n", | |
" <td>модемы</td>\n", | |
" <td>3</td>\n", | |
" <td>модем</td>\n", | |
" <td>2</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>4</th>\n", | |
" <td>4</td>\n", | |
" <td>солнц</td>\n", | |
" <td>1</td>\n", | |
" <td>солнце</td>\n", | |
" <td>1</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>...</th>\n", | |
" <td>...</td>\n", | |
" <td>...</td>\n", | |
" <td>...</td>\n", | |
" <td>...</td>\n", | |
" <td>...</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>29955</th>\n", | |
" <td>29955</td>\n", | |
" <td>донбасса</td>\n", | |
" <td>3</td>\n", | |
" <td>донбасс</td>\n", | |
" <td>2</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>29956</th>\n", | |
" <td>29956</td>\n", | |
" <td>обложка</td>\n", | |
" <td>3</td>\n", | |
" <td>обложка</td>\n", | |
" <td>2</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>29957</th>\n", | |
" <td>29957</td>\n", | |
" <td>правителя</td>\n", | |
" <td>4</td>\n", | |
" <td>правитель</td>\n", | |
" <td>2</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>29958</th>\n", | |
" <td>29958</td>\n", | |
" <td>шерстяной</td>\n", | |
" <td>3</td>\n", | |
" <td>шерстяной</td>\n", | |
" <td>2</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>29959</th>\n", | |
" <td>29959</td>\n", | |
" <td>оптимизации</td>\n", | |
" <td>6</td>\n", | |
" <td>оптимизация</td>\n", | |
" <td>4</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"<p>29960 rows × 5 columns</p>\n", | |
"</div>\n", | |
" <button class=\"colab-df-convert\" onclick=\"convertToInteractive('df-58d65a88-788e-4cef-aa6e-c99c96833dc8')\"\n", | |
" title=\"Convert this dataframe to an interactive table.\"\n", | |
" style=\"display:none;\">\n", | |
" \n", | |
" <svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n", | |
" width=\"24px\">\n", | |
" <path d=\"M0 0h24v24H0V0z\" fill=\"none\"/>\n", | |
" <path d=\"M18.56 5.44l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94zm-11 1L8.5 8.5l.94-2.06 2.06-.94-2.06-.94L8.5 2.5l-.94 2.06-2.06.94zm10 10l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94z\"/><path d=\"M17.41 7.96l-1.37-1.37c-.4-.4-.92-.59-1.43-.59-.52 0-1.04.2-1.43.59L10.3 9.45l-7.72 7.72c-.78.78-.78 2.05 0 2.83L4 21.41c.39.39.9.59 1.41.59.51 0 1.02-.2 1.41-.59l7.78-7.78 2.81-2.81c.8-.78.8-2.07 0-2.86zM5.41 20L4 18.59l7.72-7.72 1.47 1.35L5.41 20z\"/>\n", | |
" </svg>\n", | |
" </button>\n", | |
" \n", | |
" <style>\n", | |
" .colab-df-container {\n", | |
" display:flex;\n", | |
" flex-wrap:wrap;\n", | |
" gap: 12px;\n", | |
" }\n", | |
"\n", | |
" .colab-df-convert {\n", | |
" background-color: #E8F0FE;\n", | |
" border: none;\n", | |
" border-radius: 50%;\n", | |
" cursor: pointer;\n", | |
" display: none;\n", | |
" fill: #1967D2;\n", | |
" height: 32px;\n", | |
" padding: 0 0 0 0;\n", | |
" width: 32px;\n", | |
" }\n", | |
"\n", | |
" .colab-df-convert:hover {\n", | |
" background-color: #E2EBFA;\n", | |
" box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n", | |
" fill: #174EA6;\n", | |
" }\n", | |
"\n", | |
" [theme=dark] .colab-df-convert {\n", | |
" background-color: #3B4455;\n", | |
" fill: #D2E3FC;\n", | |
" }\n", | |
"\n", | |
" [theme=dark] .colab-df-convert:hover {\n", | |
" background-color: #434B5C;\n", | |
" box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n", | |
" filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n", | |
" fill: #FFFFFF;\n", | |
" }\n", | |
" </style>\n", | |
"\n", | |
" <script>\n", | |
" const buttonEl =\n", | |
" document.querySelector('#df-58d65a88-788e-4cef-aa6e-c99c96833dc8 button.colab-df-convert');\n", | |
" buttonEl.style.display =\n", | |
" google.colab.kernel.accessAllowed ? 'block' : 'none';\n", | |
"\n", | |
" async function convertToInteractive(key) {\n", | |
" const element = document.querySelector('#df-58d65a88-788e-4cef-aa6e-c99c96833dc8');\n", | |
" const dataTable =\n", | |
" await google.colab.kernel.invokeFunction('convertToInteractive',\n", | |
" [key], {});\n", | |
" if (!dataTable) return;\n", | |
"\n", | |
" const docLinkHtml = 'Like what you see? Visit the ' +\n", | |
" '<a target=\"_blank\" href=https://colab.research.google.com/notebooks/data_table.ipynb>data table notebook</a>'\n", | |
" + ' to learn more about interactive tables.';\n", | |
" element.innerHTML = '';\n", | |
" dataTable['output_type'] = 'display_data';\n", | |
" await google.colab.output.renderOutput(dataTable, element);\n", | |
" const docLink = document.createElement('div');\n", | |
" docLink.innerHTML = docLinkHtml;\n", | |
" element.appendChild(docLink);\n", | |
" }\n", | |
" </script>\n", | |
" </div>\n", | |
" </div>\n", | |
" " | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 102 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"test.to_csv('pred.csv')" | |
], | |
"metadata": { | |
"id": "p9M0Oei0Z16J" | |
}, | |
"execution_count": 103, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"id": "ujctqO-d36S4" | |
}, | |
"outputs": [], | |
"source": [] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"id": "--8Nn7QD36S4" | |
}, | |
"outputs": [], | |
"source": [] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"id": "Zy4DN7mL36S4" | |
}, | |
"outputs": [], | |
"source": [] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"id": "tqzDD25l36S4" | |
}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3.8.10 64-bit", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.8.10" | |
}, | |
"vscode": { | |
"interpreter": { | |
"hash": "767d51c1340bd893661ea55ea3124f6de3c7a262a8b4abca0554b478b1e2ff90" | |
} | |
}, | |
"colab": { | |
"provenance": [] | |
}, | |
"accelerator": "GPU", | |
"gpuClass": "standard" | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 0 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment