NER with TensorFlow
{
"cells": [
{
"cell_type": "markdown",
"id": "d9389ddb-3d05-4ddd-bf6c-0ca640ff0557",
"metadata": {},
"source": [
"\n", | |
"\n", | |
"[перейти](https://www.bigdataschool.ru/)" | |
]
},
{
"cell_type": "markdown",
"id": "582cf530-0903-40d2-a10e-2491625a5c14",
"metadata": {},
"source": [
"Датасет:\n", | |
"https://www.kaggle.com/abhinavwalia95/entity-annotated-corpus?select=ner_dataset.csv" | |
]
},
{
"cell_type": "markdown",
"id": "a2cc7955-9630-4c5a-9ea0-4becefd879a2",
"metadata": {},
"source": [
"### Загружаем данные" | |
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "72041984-f817-4e36-adcd-71ee22b3c3a1",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
"    .dataframe tbody tr th:only-of-type {\n",
"        vertical-align: middle;\n",
"    }\n",
"\n",
"    .dataframe tbody tr th {\n",
"        vertical-align: top;\n",
"    }\n",
"\n",
"    .dataframe thead th {\n",
"        text-align: right;\n",
"    }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
"  <thead>\n",
"    <tr style=\"text-align: right;\">\n",
"      <th></th>\n",
"      <th>Sentence #</th>\n",
"      <th>Word</th>\n",
"      <th>POS</th>\n",
"      <th>Tag</th>\n",
"    </tr>\n",
"  </thead>\n",
"  <tbody>\n",
"    <tr>\n",
"      <th>0</th>\n",
"      <td>Sentence: 1</td>\n",
"      <td>Thousands</td>\n",
"      <td>NNS</td>\n",
"      <td>O</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <th>1</th>\n",
"      <td>NaN</td>\n",
"      <td>of</td>\n",
"      <td>IN</td>\n",
"      <td>O</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <th>2</th>\n",
"      <td>NaN</td>\n",
"      <td>demonstrators</td>\n",
"      <td>NNS</td>\n",
"      <td>O</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <th>3</th>\n",
"      <td>NaN</td>\n",
"      <td>have</td>\n",
"      <td>VBP</td>\n",
"      <td>O</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <th>4</th>\n",
"      <td>NaN</td>\n",
"      <td>marched</td>\n",
"      <td>VBN</td>\n",
"      <td>O</td>\n",
"    </tr>\n",
"  </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Sentence # Word POS Tag\n",
"0 Sentence: 1 Thousands NNS O\n",
"1 NaN of IN O\n",
"2 NaN demonstrators NNS O\n",
"3 NaN have VBP O\n",
"4 NaN marched VBN O"
]
},
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import pandas as pd\n",
"data = pd.read_csv('data/ner_dataset.csv', encoding='unicode_escape')\n",
"data.head()"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "fb2c35f4-f10e-4131-86c9-b5d218d89265",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"NN 145807\n",
"NNP 131426\n",
"IN 120996\n",
"DT 98454\n",
"JJ 78412\n",
"NNS 75840\n",
". 47831\n",
"VBD 39379\n",
", 32757\n",
"VBN 32328\n",
"VBZ 24960\n",
"CD 24695\n",
"VB 24211\n",
"CC 23716\n",
"TO 23061\n",
"RB 20252\n",
"VBG 19125\n",
"VBP 16158\n",
"PRP 13318\n",
"POS 11257\n",
"PRP$ 8655\n",
"MD 6973\n",
"`` 3728\n",
"WDT 3698\n",
"JJS 3034\n",
"JJR 2967\n",
"WP 2542\n",
"NNPS 2521\n",
"RP 2490\n",
"WRB 2184\n",
"$ 1149\n",
"RBR 1055\n",
": 795\n",
"RRB 679\n",
"LRB 678\n",
"EX 663\n",
"RBS 296\n",
"; 214\n",
"PDT 147\n",
"WP$ 99\n",
"UH 24\n",
"FW 1\n",
"Name: POS, dtype: int64"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"data.POS.value_counts()"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "def001f7-77b8-4f18-8223-768f56da583e",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"O 887908\n",
"B-geo 37644\n",
"B-tim 20333\n",
"B-org 20143\n",
"I-per 17251\n",
"B-per 16990\n",
"I-org 16784\n",
"B-gpe 15870\n",
"I-geo 7414\n",
"I-tim 6528\n",
"B-art 402\n",
"B-eve 308\n",
"I-art 297\n",
"I-eve 253\n",
"B-nat 201\n",
"I-gpe 198\n",
"I-nat 51\n",
"Name: Tag, dtype: int64"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"data.Tag.value_counts()"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "a3b35c9d-bff4-4981-8000-214c0bdc5e7e",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.8467758624800324"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"len(data[data['Tag']=='O'])/len(data)"
]
},
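{
"cell_type": "markdown",
"id": "b1f2a3c4-d5e6-4f70-8192-a3b4c5d6e7f8",
"metadata": {},
"source": [
"Note: roughly 84.7% of all tokens carry the `O` (outside) tag, so a model that predicts `O` everywhere already reaches about 0.85 token-level accuracy. Keep this baseline in mind when reading the accuracy numbers reported during training below."
]
},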
{
"cell_type": "markdown",
"id": "4c276f13-c3d9-47bb-b581-7216e41b2dc3",
"metadata": {},
"source": [
"### Конвертируем данные в числовой вид" | |
]
},
{
"cell_type": "markdown",
"id": "549d19c6-34c9-4d52-a818-912e10d78e0a",
"metadata": {},
"source": [
"- {token} -> {token id}: построим embeddings\n", | |
"- {tag} -> {tag id}." | |
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "8f21b18c-933a-44f6-a753-8dbed62426c5",
"metadata": {},
"outputs": [],
"source": [
"from itertools import chain\n",
"def get_dict_map(data, token_or_tag):\n",
"    tok2idx = {}\n",
"    idx2tok = {}\n",
"\n",
"    if token_or_tag == 'token':\n",
"        vocab = list(set(data['Word'].to_list()))\n",
"    else:\n",
"        vocab = list(set(data['Tag'].to_list()))\n",
"\n",
"    idx2tok = {idx: tok for idx, tok in enumerate(vocab)}\n",
"    tok2idx = {tok: idx for idx, tok in enumerate(vocab)}\n",
"    return tok2idx, idx2tok\n",
"\n",
"\n",
"token2idx, idx2token = get_dict_map(data, 'token')\n",
"tag2idx, idx2tag = get_dict_map(data, 'tag')"
]
},
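{
"cell_type": "markdown",
"id": "c2d3e4f5-a6b7-4c8d-9e0f-1a2b3c4d5e6f",
"metadata": {},
"source": [
"A quick sanity check: the forward and reverse dictionaries should be exact inverses of each other."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d3e4f5a6-b7c8-4d9e-a0f1-2b3c4d5e6f70",
"metadata": {},
"outputs": [],
"source": [
"# each mapping should round-trip: token -> id -> token and tag -> id -> tag\n",
"assert all(idx2token[token2idx[w]] == w for w in token2idx)\n",
"assert all(idx2tag[tag2idx[t]] == t for t in tag2idx)"
]
},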
{
"cell_type": "code",
"execution_count": 8,
"id": "5ded9e43-6135-4369-af85-8fbdf0cc5f68",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{0: 'I-tim',\n",
" 1: 'B-nat',\n",
" 2: 'I-org',\n",
" 3: 'I-art',\n",
" 4: 'I-geo',\n",
" 5: 'B-art',\n",
" 6: 'B-eve',\n",
" 7: 'B-tim',\n",
" 8: 'I-gpe',\n",
" 9: 'I-eve',\n",
" 10: 'I-nat',\n",
" 11: 'I-per',\n",
" 12: 'B-per',\n",
" 13: 'B-org',\n",
" 14: 'O',\n",
" 15: 'B-geo',\n",
" 16: 'B-gpe'}"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# token2idx\n", | |
"tag2idx\n", | |
"idx2tag" | |
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "6e830091-f664-4f0c-816c-6c8f1d03f067",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
"    .dataframe tbody tr th:only-of-type {\n",
"        vertical-align: middle;\n",
"    }\n",
"\n",
"    .dataframe tbody tr th {\n",
"        vertical-align: top;\n",
"    }\n",
"\n",
"    .dataframe thead th {\n",
"        text-align: right;\n",
"    }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
"  <thead>\n",
"    <tr style=\"text-align: right;\">\n",
"      <th></th>\n",
"      <th>Sentence #</th>\n",
"      <th>Word</th>\n",
"      <th>POS</th>\n",
"      <th>Tag</th>\n",
"      <th>Word_idx</th>\n",
"      <th>Tag_idx</th>\n",
"    </tr>\n",
"  </thead>\n",
"  <tbody>\n",
"    <tr>\n",
"      <th>0</th>\n",
"      <td>Sentence: 1</td>\n",
"      <td>Thousands</td>\n",
"      <td>NNS</td>\n",
"      <td>O</td>\n",
"      <td>5292</td>\n",
"      <td>14</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <th>1</th>\n",
"      <td>NaN</td>\n",
"      <td>of</td>\n",
"      <td>IN</td>\n",
"      <td>O</td>\n",
"      <td>16691</td>\n",
"      <td>14</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <th>2</th>\n",
"      <td>NaN</td>\n",
"      <td>demonstrators</td>\n",
"      <td>NNS</td>\n",
"      <td>O</td>\n",
"      <td>10027</td>\n",
"      <td>14</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <th>3</th>\n",
"      <td>NaN</td>\n",
"      <td>have</td>\n",
"      <td>VBP</td>\n",
"      <td>O</td>\n",
"      <td>22673</td>\n",
"      <td>14</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <th>4</th>\n",
"      <td>NaN</td>\n",
"      <td>marched</td>\n",
"      <td>VBN</td>\n",
"      <td>O</td>\n",
"      <td>27636</td>\n",
"      <td>14</td>\n",
"    </tr>\n",
"  </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Sentence # Word POS Tag Word_idx Tag_idx\n",
"0 Sentence: 1 Thousands NNS O 5292 14\n",
"1 NaN of IN O 16691 14\n",
"2 NaN demonstrators NNS O 10027 14\n",
"3 NaN have VBP O 22673 14\n",
"4 NaN marched VBN O 27636 14"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"data['Word_idx'] = data['Word'].map(token2idx)\n",
"data['Tag_idx'] = data['Tag'].map(tag2idx)\n",
"data.head()"
]
},
{
"cell_type": "markdown",
"id": "f57317be-466d-403e-8b38-7abc8ad40fb8",
"metadata": {},
"source": [
"### Выполняем трансформацию датасета для получения данных в виде строк" | |
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "7ac99c6b-0eb5-4420-9185-94df64d0a240",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
"    .dataframe tbody tr th:only-of-type {\n",
"        vertical-align: middle;\n",
"    }\n",
"\n",
"    .dataframe tbody tr th {\n",
"        vertical-align: top;\n",
"    }\n",
"\n",
"    .dataframe thead th {\n",
"        text-align: right;\n",
"    }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
"  <thead>\n",
"    <tr style=\"text-align: right;\">\n",
"      <th></th>\n",
"      <th>Sentence #</th>\n",
"      <th>Word</th>\n",
"      <th>POS</th>\n",
"      <th>Tag</th>\n",
"      <th>Word_idx</th>\n",
"      <th>Tag_idx</th>\n",
"    </tr>\n",
"  </thead>\n",
"  <tbody>\n",
"    <tr>\n",
"      <th>0</th>\n",
"      <td>Sentence: 1</td>\n",
"      <td>[Thousands, of, demonstrators, have, marched, ...</td>\n",
"      <td>[NNS, IN, NNS, VBP, VBN, IN, NNP, TO, VB, DT, ...</td>\n",
"      <td>[O, O, O, O, O, O, B-geo, O, O, O, O, O, B-geo...</td>\n",
"      <td>[5292, 16691, 10027, 22673, 27636, 10500, 1569...</td>\n",
"      <td>[14, 14, 14, 14, 14, 14, 15, 14, 14, 14, 14, 1...</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <th>1</th>\n",
"      <td>Sentence: 10</td>\n",
"      <td>[Iranian, officials, say, they, expect, to, ge...</td>\n",
"      <td>[JJ, NNS, VBP, PRP, VBP, TO, VB, NN, TO, JJ, J...</td>\n",
"      <td>[B-gpe, O, O, O, O, O, O, O, O, O, O, O, O, O,...</td>\n",
"      <td>[19259, 15982, 10924, 1259, 33564, 25715, 2258...</td>\n",
"      <td>[16, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 1...</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <th>2</th>\n",
"      <td>Sentence: 100</td>\n",
"      <td>[Helicopter, gunships, Saturday, pounded, mili...</td>\n",
"      <td>[NN, NNS, NNP, VBD, JJ, NNS, IN, DT, NNP, JJ, ...</td>\n",
"      <td>[O, O, B-tim, O, O, O, O, O, B-geo, O, O, O, O...</td>\n",
"      <td>[6517, 20597, 9446, 30647, 9086, 33177, 25074,...</td>\n",
"      <td>[14, 14, 7, 14, 14, 14, 14, 14, 15, 14, 14, 14...</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <th>3</th>\n",
"      <td>Sentence: 1000</td>\n",
"      <td>[They, left, after, a, tense, hour-long, stand...</td>\n",
"      <td>[PRP, VBD, IN, DT, NN, JJ, NN, IN, NN, NNS, .]</td>\n",
"      <td>[O, O, O, O, O, O, O, O, O, O, O]</td>\n",
"      <td>[15684, 19543, 3537, 27550, 5642, 8050, 21291,...</td>\n",
"      <td>[14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14]</td>\n",
"    </tr>\n",
"    <tr>\n",
"      <th>4</th>\n",
"      <td>Sentence: 10000</td>\n",
"      <td>[U.N., relief, coordinator, Jan, Egeland, said...</td>\n",
"      <td>[NNP, NN, NN, NNP, NNP, VBD, NNP, ,, NNP, ,, J...</td>\n",
"      <td>[B-geo, O, O, B-per, I-per, O, B-tim, O, B-geo...</td>\n",
"      <td>[9559, 18741, 13791, 21415, 18623, 33927, 8141...</td>\n",
"      <td>[15, 14, 14, 12, 11, 14, 7, 14, 15, 14, 16, 14...</td>\n",
"    </tr>\n",
"  </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Sentence # Word \\\n",
"0 Sentence: 1 [Thousands, of, demonstrators, have, marched, ... \n",
"1 Sentence: 10 [Iranian, officials, say, they, expect, to, ge... \n",
"2 Sentence: 100 [Helicopter, gunships, Saturday, pounded, mili... \n",
"3 Sentence: 1000 [They, left, after, a, tense, hour-long, stand... \n",
"4 Sentence: 10000 [U.N., relief, coordinator, Jan, Egeland, said... \n",
"\n",
" POS \\\n",
"0 [NNS, IN, NNS, VBP, VBN, IN, NNP, TO, VB, DT, ... \n",
"1 [JJ, NNS, VBP, PRP, VBP, TO, VB, NN, TO, JJ, J... \n",
"2 [NN, NNS, NNP, VBD, JJ, NNS, IN, DT, NNP, JJ, ... \n",
"3 [PRP, VBD, IN, DT, NN, JJ, NN, IN, NN, NNS, .] \n",
"4 [NNP, NN, NN, NNP, NNP, VBD, NNP, ,, NNP, ,, J... \n",
"\n",
" Tag \\\n",
"0 [O, O, O, O, O, O, B-geo, O, O, O, O, O, B-geo... \n",
"1 [B-gpe, O, O, O, O, O, O, O, O, O, O, O, O, O,... \n",
"2 [O, O, B-tim, O, O, O, O, O, B-geo, O, O, O, O... \n",
"3 [O, O, O, O, O, O, O, O, O, O, O] \n",
"4 [B-geo, O, O, B-per, I-per, O, B-tim, O, B-geo... \n",
"\n",
" Word_idx \\\n",
"0 [5292, 16691, 10027, 22673, 27636, 10500, 1569... \n",
"1 [19259, 15982, 10924, 1259, 33564, 25715, 2258... \n",
"2 [6517, 20597, 9446, 30647, 9086, 33177, 25074,... \n",
"3 [15684, 19543, 3537, 27550, 5642, 8050, 21291,... \n",
"4 [9559, 18741, 13791, 21415, 18623, 33927, 8141... \n",
"\n",
" Tag_idx \n",
"0 [14, 14, 14, 14, 14, 14, 15, 14, 14, 14, 14, 1... \n",
"1 [16, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 1... \n",
"2 [14, 14, 7, 14, 14, 14, 14, 14, 15, 14, 14, 14... \n",
"3 [14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14] \n",
"4 [15, 14, 14, 12, 11, 14, 7, 14, 15, 14, 16, 14... "
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Заполнение пропусков\n", | |
"data_fillna = data.fillna(method='ffill', axis=0)\n", | |
"# Группируем\n", | |
"data_group = data_fillna.groupby(\n", | |
"['Sentence #'],as_index=False\n", | |
")['Word', 'POS', 'Tag', 'Word_idx', 'Tag_idx'].agg(lambda x: list(x))\n", | |
"# Смотрим\n", | |
"data_group.head()" | |
]
},
{
"cell_type": "markdown",
"id": "dfc87665-ce74-447a-91f9-f9531d4b23b8",
"metadata": {},
"source": [
"### Разбиваем данные на выборки" | |
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "99da5248-08b4-4696-8bce-8cd179d26b78",
"metadata": {},
"outputs": [],
"source": [
"from sklearn.model_selection import train_test_split\n",
"from tensorflow.keras.preprocessing.sequence import pad_sequences\n",
"from tensorflow.keras.utils import to_categorical"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "b3493c8e-5f2f-40a5-b86e-3a4b6f2e4681",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"train_tokens length: 32372 \n", | |
"train_tokens length: 32372 \n", | |
"test_tokens length: 4796 \n", | |
"test_tags: 4796 \n", | |
"val_tokens: 10791 \n", | |
"val_tags: 10791\n" | |
]
}
],
"source": [
"def get_pad_train_test_val(data_group, data):\n",
"\n",
"    # get the number of unique tokens and tags\n",
"    n_token = len(list(set(data['Word'].to_list())))\n",
"    n_tag = len(list(set(data['Tag'].to_list())))\n",
"\n",
"    # Pad tokens (X var)\n",
"    tokens = data_group['Word_idx'].tolist()\n",
"    maxlen = max([len(s) for s in tokens])\n",
"    pad_tokens = pad_sequences(tokens, maxlen=maxlen, dtype='int32', padding='post', value=n_token - 1)\n",
"\n",
"    # Pad tags (y var) and convert them to one-hot encoding\n",
"    tags = data_group['Tag_idx'].tolist()\n",
"    pad_tags = pad_sequences(tags, maxlen=maxlen, dtype='int32', padding='post', value=tag2idx[\"O\"])\n",
"    n_tags = len(tag2idx)\n",
"    pad_tags = [to_categorical(i, num_classes=n_tags) for i in pad_tags]\n",
"\n",
"    # Split into train, test and validation sets\n",
"    tokens_, test_tokens, tags_, test_tags = train_test_split(pad_tokens, pad_tags, test_size=0.1, train_size=0.9, random_state=2020)\n",
"    train_tokens, val_tokens, train_tags, val_tags = train_test_split(tokens_, tags_, test_size=0.25, train_size=0.75, random_state=2020)\n",
"\n",
"    print(\n",
"        'train_tokens length:', len(train_tokens),\n",
"        '\\ntrain_tags length:', len(train_tags),\n",
"        '\\ntest_tokens length:', len(test_tokens),\n",
"        '\\ntest_tags:', len(test_tags),\n",
"        '\\nval_tokens:', len(val_tokens),\n",
"        '\\nval_tags:', len(val_tags),\n",
"    )\n",
"\n",
"    return train_tokens, val_tokens, test_tokens, train_tags, val_tags, test_tags\n",
"\n",
"train_tokens, val_tokens, test_tokens, train_tags, val_tags, test_tags = get_pad_train_test_val(data_group, data)"
]
},
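{
"cell_type": "markdown",
"id": "e4f5a6b7-c8d9-4e0f-b1a2-3c4d5e6f7081",
"metadata": {},
"source": [
"A short shape check on the result: `pad_sequences` returns a 2-D array of shape `(n_sentences, maxlen)`, and each one-hot tag matrix is `(maxlen, n_tags)`."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f5a6b7c8-d9e0-4f1a-82b3-4d5e6f708192",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"# expected: (32372, 104) for the tokens, (32372, 104, 17) for the one-hot tags\n",
"print(train_tokens.shape)\n",
"print(np.array(train_tags).shape)"
]
},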
{
"cell_type": "markdown",
"id": "cafaa6b5-53bb-4f2f-8dc6-d5e069d75a2a",
"metadata": {},
"source": [
"### Строим нейронную сеть" | |
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "84263348-9dbe-48bd-9f1e-b4fe23e54998",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import tensorflow\n",
"from tensorflow.keras import Sequential, Model, Input\n",
"from tensorflow.keras.layers import LSTM, Embedding, Dense, TimeDistributed, Dropout, Bidirectional\n",
"from tensorflow.keras.utils import plot_model"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "a2660448-3360-4af9-ac48-cad89437c98b",
"metadata": {},
"outputs": [],
"source": [
"# фиксируем состояния для воспроизводимости экспериментов\n", | |
"from numpy.random import seed\n", | |
"seed(1)\n", | |
"tensorflow.random.set_seed(2)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 16, | |
"id": "b9845bfc-8533-4295-b452-e3c362a80d9d", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"input_dim: 35179 \n", | |
"output_dim: 64 \n", | |
"input_length: 104 \n", | |
"n_tags: 17\n" | |
] | |
} | |
], | |
"source": [ | |
"input_dim = len(list(set(data['Word'].to_list())))+1\n", | |
"output_dim = 64\n", | |
"input_length = max([len(s) for s in data_group['Word_idx'].tolist()])\n", | |
"n_tags = len(tag2idx)\n", | |
"print('input_dim: ', input_dim, '\\noutput_dim: ', output_dim, '\\ninput_length: ', input_length, '\\nn_tags: ', n_tags)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 17, | |
"id": "1b7949ca-b632-4283-a93a-ccb7fc799028", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"\n", | |
"def get_bilstm_lstm_model():\n", | |
" model = Sequential()\n", | |
"\n", | |
" # Слой Embedding\n", | |
" model.add(Embedding(input_dim=input_dim, output_dim=output_dim, input_length=input_length))\n", | |
"\n", | |
" # Слой bidirectional LSTM\n", | |
" model.add(Bidirectional(LSTM(units=output_dim, return_sequences=True, dropout=0.2, recurrent_dropout=0.2), merge_mode = 'concat'))\n", | |
"\n", | |
" # Слой LSTM\n", | |
" model.add(LSTM(units=output_dim, return_sequences=True, dropout=0.5, recurrent_dropout=0.5))\n", | |
"\n", | |
" # Слой timeDistributed Layer (обеспечивает выход формата many-to-many)\n", | |
" model.add(TimeDistributed(Dense(n_tags, activation=\"relu\")))\n", | |
"\n", | |
" #Optimiser \n", | |
" # adam = k.optimizers.Adam(lr=0.0005, beta_1=0.9, beta_2=0.999)\n", | |
"\n", | |
" # Compile model\n", | |
" model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])\n", | |
" model.summary()\n", | |
" \n", | |
" return model" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "fef6ad84-5fed-4f7f-b832-7a478c22fdfe", | |
"metadata": {}, | |
"source": [ | |
"#### Обучение" | |
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "d38560e0-3df3-4c48-aa47-1686287da0ca",
"metadata": {},
"outputs": [],
"source": [
"def train_model(X, y, model):\n",
"    loss = list()\n",
"    for i in range(3):\n",
"        # fit model for one epoch on this sequence\n",
"        hist = model.fit(X, y, batch_size=128, verbose=1, epochs=1, validation_split=0.2)\n",
"        loss.append(hist.history['loss'][0])\n",
"    return loss"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "2c09d124-34a5-4a28-993f-c4c1a3aefd85",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2021-09-21 15:13:15.345983: I tensorflow/core/platform/cpu_feature_guard.cc:142] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 FMA\n",
"To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: \"sequential\"\n",
"_________________________________________________________________\n",
"Layer (type)                 Output Shape              Param #   \n",
"=================================================================\n",
"embedding (Embedding)        (None, 104, 64)           2251456   \n",
"_________________________________________________________________\n",
"bidirectional (Bidirectional (None, 104, 128)          66048     \n",
"_________________________________________________________________\n",
"lstm_1 (LSTM)                (None, 104, 64)           49408     \n",
"_________________________________________________________________\n",
"time_distributed (TimeDistri (None, 104, 17)           1105      \n",
"=================================================================\n",
"Total params: 2,368,017\n",
"Trainable params: 2,368,017\n",
"Non-trainable params: 0\n",
"_________________________________________________________________\n",
"('You must install pydot (`pip install pydot`) and install graphviz (see instructions at https://graphviz.gitlab.io/download/) ', 'for plot_model/model_to_dot to work.')\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"2021-09-21 15:13:15.857115: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:185] None of the MLIR Optimization Passes are enabled (registered 2)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"203/203 [==============================] - 64s 286ms/step - loss: 0.4522 - accuracy: 0.9629 - val_loss: 0.3592 - val_accuracy: 0.9681\n",
"203/203 [==============================] - 56s 274ms/step - loss: 0.2172 - accuracy: 0.9678 - val_loss: 0.1481 - val_accuracy: 0.9683\n",
"203/203 [==============================] - 57s 281ms/step - loss: 0.1357 - accuracy: 0.9685 - val_loss: 0.1204 - val_accuracy: 0.9694\n"
]
}
],
"source": [
"results = pd.DataFrame()\n",
"model_bilstm_lstm = get_bilstm_lstm_model()\n",
"plot_model(model_bilstm_lstm)\n",
"results['with_add_lstm'] = train_model(train_tokens, np.array(train_tags), model_bilstm_lstm)"
]
},
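{
"cell_type": "markdown",
"id": "a6b7c8d9-e0f1-4a2b-93c4-5d6e7f809123",
"metadata": {},
"source": [
"`train_model` returns the training loss per epoch, so the `results` frame can be plotted directly. A minimal sketch, assuming matplotlib is installed:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b7c8d9e0-f1a2-4b3c-a4d5-6e7f80912345",
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"\n",
"# one line per column of `results`, i.e. per model variant\n",
"results.plot(marker='o')\n",
"plt.xlabel('epoch')\n",
"plt.ylabel('training loss')\n",
"plt.show()"
]
},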
{
"cell_type": "markdown",
"id": "d04bb995-f9ea-4067-8d3c-dac3f93764b6",
"metadata": {},
"source": [
"### Смотрим результат работы сети " | |
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "0ff158a1-3364-451e-bfaa-941217083127",
"metadata": {},
"outputs": [],
"source": [
"predict = model_bilstm_lstm.predict(test_tokens)"
]
},
{
"cell_type": "code",
"execution_count": 31,
"id": "ca6860b5-6bd2-491d-8a19-67c26f7d528d",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([15, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14,\n",
"       14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14,\n",
"       14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14,\n",
"       14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14,\n",
"       14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14,\n",
"       14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14,\n",
"       14, 14])"
]
},
"execution_count": 31,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"num = 600\n",
"np.argmax(predict[num], axis=1)"
]
},
{
"cell_type": "code",
"execution_count": 32,
"id": "24567bd2-953d-4a1a-9c5d-419b347951d6",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([15, 14, 13,  2, 14, 15, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 16,\n",
"       14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14,\n",
"       14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14,\n",
"       14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14,\n",
"       14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14,\n",
"       14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14,\n",
"       14, 14])"
]
},
"execution_count": 32,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"np.argmax(test_tags[num], axis=1)"
]
},
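{
"cell_type": "markdown",
"id": "c8d9e0f1-a2b3-4c4d-b5e6-7f8091234567",
"metadata": {},
"source": [
"The raw index arrays are hard to read. A minimal sketch that maps the class indices back to tag names via `idx2tag` and prints only the positions where the prediction differs from the ground truth or the true tag is not `O`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d9e0f1a2-b3c4-4d5e-86f7-809123456789",
"metadata": {},
"outputs": [],
"source": [
"pred_ids = np.argmax(predict[num], axis=1)\n",
"true_ids = np.argmax(test_tags[num], axis=1)\n",
"for i, (p, t) in enumerate(zip(pred_ids, true_ids)):\n",
"    if p != t or idx2tag[t] != 'O':\n",
"        print(f'position {i:>3}: pred={idx2tag[p]:<6} true={idx2tag[t]}')"
]
},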
{
"cell_type": "markdown",
"id": "a77595b0-3f91-4338-bd24-8e13c9835d34",
"metadata": {},
"source": [
"### Задание\n", | |
"- посчитать accuracy и confusion matrix;\n", | |
"- поизменять архитектуру нейронной сети, параметры обучения." | |
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.6"
}
},
"nbformat": 4,
"nbformat_minor": 5
} |