@MachineLearningIsEasy
Created October 28, 2021 12:32
Using BERT in classification
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "XXcdygSLmF7y"
},
"source": [
"[Visit the site](https://www.bigdataschool.ru/)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "IZ6SNYq_tVVC"
},
"source": [
"# Classification with BERT\n",
"\n",
"Notebook outline:\n",
"\n",
"- load the IMDB dataset;\n",
"- load BERT from TensorFlow Hub;\n",
"- build a BERT-based neural network for classification;\n",
"- train the network;\n",
"- save the network and run classification.\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "2PHBpLPuQdmK"
},
"source": [
"## About BERT\n",
"\n",
"[BERT](https://arxiv.org/abs/1810.04805) and other Transformer encoder architectures have been highly successful on a variety of NLP tasks. They compute vector-space representations of text. The BERT family of models uses the Transformer encoder architecture to process each token of the input text in the full context of all tokens before and after it.\n",
"\n",
"BERT models are usually pre-trained on a large corpus of text and then fine-tuned for a specific task.\n"
]
},
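{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the phrase *full context of all tokens before and after* concrete, here is a minimal sketch in plain Python. It is only an illustration, not BERT's actual attention code: `visibility_matrix` is a hypothetical helper that marks which positions each token may use as context, for an encoder such as BERT versus a left-to-right model.\n",
"\n",
"```python\n",
"def visibility_matrix(n_tokens, bidirectional=True):\n",
"    # Row i lists which positions token i may use as context (1 = visible).\n",
"    # Hypothetical helper for illustration; not part of TensorFlow or BERT.\n",
"    return [\n",
"        [1 if bidirectional or j <= i else 0 for j in range(n_tokens)]\n",
"        for i in range(n_tokens)\n",
"    ]\n",
"\n",
"\n",
"encoder_view = visibility_matrix(4, bidirectional=True)\n",
"left_to_right_view = visibility_matrix(4, bidirectional=False)\n",
"\n",
"for row in encoder_view:\n",
"    print(row)\n",
"for row in left_to_right_view:\n",
"    print(row)\n",
"```\n",
"\n",
"With `bidirectional=True` every row is all ones, so each token sees the whole sentence. With `bidirectional=False` the matrix is lower triangular, so each token sees only itself and the tokens before it."
]
},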
{
"cell_type": "markdown",
"metadata": {
"id": "SCjmX4zTCkRK"
},
"source": [
"## Setting up the environment"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "q-YbjCkzw0yU",
"outputId": "896f4500-6d47-4638-c3ca-421d0e1479b5"
},
"outputs": [],
"source": [
"# A dependency of the preprocessing for BERT inputs\n",
"!pip install -q -U tensorflow-text"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "b-P1ZOA0FkVJ",
"outputId": "3909ca9d-4bb6-44be-ffd7-3dd1aeb90813"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" Building wheel for py-cpuinfo (setup.py) ... done\n",
" Building wheel for seqeval (setup.py) ... done\n"
]
}
],
"source": [
"!pip install -q tf-models-official"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "_XgTpm9ZxoN9"
},
"outputs": [],
"source": [
"import os\n",
"import shutil\n",
"\n",
"import tensorflow as tf\n",
"import tensorflow_hub as hub\n",
"import tensorflow_text as text\n",
"from official.nlp import optimization # to create AdamW optimizer\n",
"\n",
"import matplotlib.pyplot as plt\n",
"\n",
"tf.get_logger().setLevel('ERROR')"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "q6MugfEgDRpY"
},
"source": [
"## Sentiment analysis of IMDB reviews\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Vnvd4mrtPHHV"
},
"source": [
"### Loading the IMDB dataset\n",
"\n",
"Let's download the data and look at its structure."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "pOdqCMoQDRJL",
"outputId": "e3ffcbc5-71a6-4cf4-ee04-77d9fe8ed3bf"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Downloading data from https://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz\n",
"84131840/84125825 [==============================] - 3s 0us/step\n",
"84140032/84125825 [==============================] - 3s 0us/step\n"
]
}
],
"source": [
"url = 'https://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz'\n",
"\n",
"dataset = tf.keras.utils.get_file('aclImdb_v1.tar.gz', url,\n",
"                                  untar=True, cache_dir='.',\n",
"                                  cache_subdir='')\n",
"\n",
"dataset_dir = os.path.join(os.path.dirname(dataset), 'aclImdb')\n",
"\n",
"train_dir = os.path.join(dataset_dir, 'train')\n",
"\n",
"# remove unused folders to make it easier to load the data\n",
"remove_dir = os.path.join(train_dir, 'unsup')\n",
"shutil.rmtree(remove_dir)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "MbQq7ZU20aYL",
"outputId": "95948c01-d1d4-4d1a-d0f8-b943fed081a6"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"labeledBow.feat pos\t\turls_neg.txt urls_unsup.txt\n",
"neg\t\t unsupBow.feat\turls_pos.txt\n"
]
}
],
"source": [
"!ls aclImdb/train"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "lN9lWCYfPo7b"
},
"source": [
"Using `text_dataset_from_directory`, we create a `tf.data.Dataset` object.\n",
"\n",
"The IMDB dataset is already split into train and test sets, but it has no validation set. We carve one out of the training data with an 80:20 split."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "6IwI_2bcIeX8",
"outputId": "d2cd206a-25ba-486e-abc1-18328242d8ed"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Found 25000 files belonging to 2 classes.\n",
"Using 20000 files for training.\n",
"Found 25000 files belonging to 2 classes.\n",
"Using 5000 files for validation.\n",
"Found 25000 files belonging to 2 classes.\n"
]
}
],
"source": [
"AUTOTUNE = tf.data.AUTOTUNE\n",
"batch_size = 32\n",
"seed = 42\n",
"\n",
"raw_train_ds = tf.keras.preprocessing.text_dataset_from_directory(\n",
"    'aclImdb/train',\n",
"    batch_size=batch_size,\n",
"    validation_split=0.2,\n",
"    subset='training',\n",
"    seed=seed)\n",
"\n",
"class_names = raw_train_ds.class_names\n",
"train_ds = raw_train_ds.cache().prefetch(buffer_size=AUTOTUNE)\n",
"\n",
"val_ds = tf.keras.preprocessing.text_dataset_from_directory(\n",
"    'aclImdb/train',\n",
"    batch_size=batch_size,\n",
"    validation_split=0.2,\n",
"    subset='validation',\n",
"    seed=seed)\n",
"\n",
"val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)\n",
"\n",
"test_ds = tf.keras.preprocessing.text_dataset_from_directory(\n",
"    'aclImdb/test',\n",
"    batch_size=batch_size)\n",
"\n",
"test_ds = test_ds.cache().prefetch(buffer_size=AUTOTUNE)"
]
},
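{
"cell_type": "markdown",
"metadata": {},
"source": [
"The 80:20 split above can be sketched in plain Python. This illustrates the idea only (shuffle with a fixed seed, then hold out a fraction). It is a simplified model under our own assumptions, not the exact algorithm Keras uses, and `split_indices` is a hypothetical helper.\n",
"\n",
"```python\n",
"import random\n",
"\n",
"\n",
"def split_indices(n_samples, validation_split, seed):\n",
"    # Hypothetical helper: shuffle sample indices reproducibly, then hold out\n",
"    # the last `validation_split` fraction for validation.\n",
"    indices = list(range(n_samples))\n",
"    random.Random(seed).shuffle(indices)\n",
"    n_val = int(n_samples * validation_split)\n",
"    return indices[:n_samples - n_val], indices[n_samples - n_val:]\n",
"\n",
"\n",
"train_idx, val_idx = split_indices(25000, validation_split=0.2, seed=42)\n",
"print(len(train_idx), len(val_idx))  # 20000 5000\n",
"```\n",
"\n",
"The sizes match the 20000 training and 5000 validation files reported above. Passing the same `seed` to both the `training` and the `validation` call matters for the same reason: the two subsets stay disjoint only if both calls shuffle the files identically."
]
},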
{
"cell_type": "markdown",
"metadata": {
"id": "HGm10A5HRGXp"
},
"source": [
"Let's look at a few example reviews."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "JuxDkcvVIoev",
"outputId": "8615615d-1799-47fa-a520-b5ac755b759b"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Review: b'\"Pandemonium\" is a horror movie spoof that comes off more stupid than funny. Believe me when I tell you, I love comedies. Especially comedy spoofs. \"Airplane\", \"The Naked Gun\" trilogy, \"Blazing Saddles\", \"High Anxiety\", and \"Spaceballs\" are some of my favorite comedies that spoof a particular genre. \"Pandemonium\" is not up there with those films. Most of the scenes in this movie had me sitting there in stunned silence because the movie wasn\\'t all that funny. There are a few laughs in the film, but when you watch a comedy, you expect to laugh a lot more than a few times and that\\'s all this film has going for it. Geez, \"Scream\" had more laughs than this film and that was more of a horror film. How bizarre is that?<br /><br />*1/2 (out of four)'\n",
"Label : 0 (neg)\n",
"Review: b\"David Mamet is a very interesting and a very un-equal director. His first movie 'House of Games' was the one I liked best, and it set a series of films with characters whose perspective of life changes as they get into complicated situations, and so does the perspective of the viewer.<br /><br />So is 'Homicide' which from the title tries to set the mind of the viewer to the usual crime drama. The principal characters are two cops, one Jewish and one Irish who deal with a racially charged area. The murder of an old Jewish shop owner who proves to be an ancient veteran of the Israeli Independence war triggers the Jewish identity in the mind and heart of the Jewish detective.<br /><br />This is were the flaws of the film are the more obvious. The process of awakening is theatrical and hard to believe, the group of Jewish militants is operatic, and the way the detective eventually walks to the final violent confrontation is pathetic. The end of the film itself is Mamet-like smart, but disappoints from a human emotional perspective.<br /><br />Joe Mantegna and William Macy give strong performances, but the flaws of the story are too evident to be easily compensated.\"\n",
"Label : 0 (neg)\n",
"Review: b'Great documentary about the lives of NY firefighters during the worst terrorist attack of all time.. That reason alone is why this should be a must see collectors item.. What shocked me was not only the attacks, but the\"High Fat Diet\" and physical appearance of some of these firefighters. I think a lot of Doctors would agree with me that,in the physical shape they were in, some of these firefighters would NOT of made it to the 79th floor carrying over 60 lbs of gear. Having said that i now have a greater respect for firefighters and i realize becoming a firefighter is a life altering job. The French have a history of making great documentary\\'s and that is what this is, a Great Documentary.....'\n",
"Label : 1 (pos)\n"
]
}
],
"source": [
"for text_batch, label_batch in train_ds.take(1):\n",
"    for i in range(3):\n",
"        print(f'Review: {text_batch.numpy()[i]}')\n",
"        label = label_batch.numpy()[i]\n",
"        print(f'Label : {label} ({class_names[label]})')"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "dX8FtlpGJRE6"
},
"source": [
"## Loading models from TensorFlow Hub\n",
"\n",
"The following BERT models are available:\n",
"\n",
" - [BERT-Base](https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/3) and the other models in the [BERT collection](https://tfhub.dev/google/collections/bert/1).\n",
" - [Small BERTs](https://tfhub.dev/google/collections/bert/1).\n",
" - [ALBERT](https://tfhub.dev/google/collections/albert/1).\n",
" - [BERT Experts](https://tfhub.dev/google/collections/experts/bert/1).\n",
" - [Electra](https://tfhub.dev/google/collections/electra/1).\n",
" - BERT with Talking-Heads Attention and Gated GELU [[base](https://tfhub.dev/tensorflow/talkheads_ggelu_bert_en_base/1), [large](https://tfhub.dev/tensorflow/talkheads_ggelu_bert_en_large/1)].\n",
"\n",
"We will start with a Small BERT.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "y8_ctG55-uTX",
"outputId": "8eb09fd7-f7a4-4e5b-dd81-00c2a6e27f11"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"BERT model selected : https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-4_H-512_A-8/1\n",
"Preprocess model auto-selected: https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3\n"
]
}
],
"source": [
"#@title Choose a BERT model to fine-tune\n",
"\n",
"bert_model_name = 'small_bert/bert_en_uncased_L-4_H-512_A-8' #@param [\"bert_en_uncased_L-12_H-768_A-12\", \"bert_en_cased_L-12_H-768_A-12\", \"bert_multi_cased_L-12_H-768_A-12\", \"small_bert/bert_en_uncased_L-2_H-128_A-2\", \"small_bert/bert_en_uncased_L-2_H-256_A-4\", \"small_bert/bert_en_uncased_L-2_H-512_A-8\", \"small_bert/bert_en_uncased_L-2_H-768_A-12\", \"small_bert/bert_en_uncased_L-4_H-128_A-2\", \"small_bert/bert_en_uncased_L-4_H-256_A-4\", \"small_bert/bert_en_uncased_L-4_H-512_A-8\", \"small_bert/bert_en_uncased_L-4_H-768_A-12\", \"small_bert/bert_en_uncased_L-6_H-128_A-2\", \"small_bert/bert_en_uncased_L-6_H-256_A-4\", \"small_bert/bert_en_uncased_L-6_H-512_A-8\", \"small_bert/bert_en_uncased_L-6_H-768_A-12\", \"small_bert/bert_en_uncased_L-8_H-128_A-2\", \"small_bert/bert_en_uncased_L-8_H-256_A-4\", \"small_bert/bert_en_uncased_L-8_H-512_A-8\", \"small_bert/bert_en_uncased_L-8_H-768_A-12\", \"small_bert/bert_en_uncased_L-10_H-128_A-2\", \"small_bert/bert_en_uncased_L-10_H-256_A-4\", \"small_bert/bert_en_uncased_L-10_H-512_A-8\", \"small_bert/bert_en_uncased_L-10_H-768_A-12\", \"small_bert/bert_en_uncased_L-12_H-128_A-2\", \"small_bert/bert_en_uncased_L-12_H-256_A-4\", \"small_bert/bert_en_uncased_L-12_H-512_A-8\", \"small_bert/bert_en_uncased_L-12_H-768_A-12\", \"albert_en_base\", \"electra_small\", \"electra_base\", \"experts_pubmed\", \"experts_wiki_books\", \"talking-heads_base\"]\n",
"\n",
"map_name_to_handle = {\n",
" 'bert_en_uncased_L-12_H-768_A-12':\n",
" 'https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/3',\n",
" 'bert_en_cased_L-12_H-768_A-12':\n",
" 'https://tfhub.dev/tensorflow/bert_en_cased_L-12_H-768_A-12/3',\n",
" 'bert_multi_cased_L-12_H-768_A-12':\n",
" 'https://tfhub.dev/tensorflow/bert_multi_cased_L-12_H-768_A-12/3',\n",
" 'small_bert/bert_en_uncased_L-2_H-128_A-2':\n",
" 'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-2_H-128_A-2/1',\n",
" 'small_bert/bert_en_uncased_L-2_H-256_A-4':\n",
" 'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-2_H-256_A-4/1',\n",
" 'small_bert/bert_en_uncased_L-2_H-512_A-8':\n",
" 'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-2_H-512_A-8/1',\n",
" 'small_bert/bert_en_uncased_L-2_H-768_A-12':\n",
" 'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-2_H-768_A-12/1',\n",
" 'small_bert/bert_en_uncased_L-4_H-128_A-2':\n",
" 'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-4_H-128_A-2/1',\n",
" 'small_bert/bert_en_uncased_L-4_H-256_A-4':\n",
" 'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-4_H-256_A-4/1',\n",
" 'small_bert/bert_en_uncased_L-4_H-512_A-8':\n",
" 'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-4_H-512_A-8/1',\n",
" 'small_bert/bert_en_uncased_L-4_H-768_A-12':\n",
" 'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-4_H-768_A-12/1',\n",
" 'small_bert/bert_en_uncased_L-6_H-128_A-2':\n",
" 'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-6_H-128_A-2/1',\n",
" 'small_bert/bert_en_uncased_L-6_H-256_A-4':\n",
" 'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-6_H-256_A-4/1',\n",
" 'small_bert/bert_en_uncased_L-6_H-512_A-8':\n",
" 'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-6_H-512_A-8/1',\n",
" 'small_bert/bert_en_uncased_L-6_H-768_A-12':\n",
" 'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-6_H-768_A-12/1',\n",
" 'small_bert/bert_en_uncased_L-8_H-128_A-2':\n",
" 'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-8_H-128_A-2/1',\n",
" 'small_bert/bert_en_uncased_L-8_H-256_A-4':\n",
" 'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-8_H-256_A-4/1',\n",
" 'small_bert/bert_en_uncased_L-8_H-512_A-8':\n",
" 'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-8_H-512_A-8/1',\n",
" 'small_bert/bert_en_uncased_L-8_H-768_A-12':\n",
" 'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-8_H-768_A-12/1',\n",
" 'small_bert/bert_en_uncased_L-10_H-128_A-2':\n",
" 'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-10_H-128_A-2/1',\n",
" 'small_bert/bert_en_uncased_L-10_H-256_A-4':\n",
" 'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-10_H-256_A-4/1',\n",
" 'small_bert/bert_en_uncased_L-10_H-512_A-8':\n",
" 'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-10_H-512_A-8/1',\n",
" 'small_bert/bert_en_uncased_L-10_H-768_A-12':\n",
" 'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-10_H-768_A-12/1',\n",
" 'small_bert/bert_en_uncased_L-12_H-128_A-2':\n",
" 'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-12_H-128_A-2/1',\n",
" 'small_bert/bert_en_uncased_L-12_H-256_A-4':\n",
" 'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-12_H-256_A-4/1',\n",
" 'small_bert/bert_en_uncased_L-12_H-512_A-8':\n",
" 'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-12_H-512_A-8/1',\n",
" 'small_bert/bert_en_uncased_L-12_H-768_A-12':\n",
" 'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-12_H-768_A-12/1',\n",
" 'albert_en_base':\n",
" 'https://tfhub.dev/tensorflow/albert_en_base/2',\n",
" 'electra_small':\n",
" 'https://tfhub.dev/google/electra_small/2',\n",
" 'electra_base':\n",
" 'https://tfhub.dev/google/electra_base/2',\n",
" 'experts_pubmed':\n",
" 'https://tfhub.dev/google/experts/bert/pubmed/2',\n",
" 'experts_wiki_books':\n",
" 'https://tfhub.dev/google/experts/bert/wiki_books/2',\n",
" 'talking-heads_base':\n",
" 'https://tfhub.dev/tensorflow/talkheads_ggelu_bert_en_base/1',\n",
"}\n",
"\n",
"map_model_to_preprocess = {\n",
" 'bert_en_uncased_L-12_H-768_A-12':\n",
" 'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3',\n",
" 'bert_en_cased_L-12_H-768_A-12':\n",
" 'https://tfhub.dev/tensorflow/bert_en_cased_preprocess/3',\n",
" 'small_bert/bert_en_uncased_L-2_H-128_A-2':\n",
" 'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3',\n",
" 'small_bert/bert_en_uncased_L-2_H-256_A-4':\n",
" 'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3',\n",
" 'small_bert/bert_en_uncased_L-2_H-512_A-8':\n",
" 'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3',\n",
" 'small_bert/bert_en_uncased_L-2_H-768_A-12':\n",
" 'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3',\n",
" 'small_bert/bert_en_uncased_L-4_H-128_A-2':\n",
" 'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3',\n",
" 'small_bert/bert_en_uncased_L-4_H-256_A-4':\n",
" 'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3',\n",
" 'small_bert/bert_en_uncased_L-4_H-512_A-8':\n",
" 'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3',\n",
" 'small_bert/bert_en_uncased_L-4_H-768_A-12':\n",
" 'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3',\n",
" 'small_bert/bert_en_uncased_L-6_H-128_A-2':\n",
" 'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3',\n",
" 'small_bert/bert_en_uncased_L-6_H-256_A-4':\n",
" 'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3',\n",
" 'small_bert/bert_en_uncased_L-6_H-512_A-8':\n",
" 'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3',\n",
" 'small_bert/bert_en_uncased_L-6_H-768_A-12':\n",
" 'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3',\n",
" 'small_bert/bert_en_uncased_L-8_H-128_A-2':\n",
" 'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3',\n",
" 'small_bert/bert_en_uncased_L-8_H-256_A-4':\n",
" 'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3',\n",
" 'small_bert/bert_en_uncased_L-8_H-512_A-8':\n",
" 'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3',\n",
" 'small_bert/bert_en_uncased_L-8_H-768_A-12':\n",
" 'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3',\n",
" 'small_bert/bert_en_uncased_L-10_H-128_A-2':\n",
" 'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3',\n",
" 'small_bert/bert_en_uncased_L-10_H-256_A-4':\n",
" 'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3',\n",
" 'small_bert/bert_en_uncased_L-10_H-512_A-8':\n",
" 'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3',\n",
" 'small_bert/bert_en_uncased_L-10_H-768_A-12':\n",
" 'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3',\n",
" 'small_bert/bert_en_uncased_L-12_H-128_A-2':\n",
" 'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3',\n",
" 'small_bert/bert_en_uncased_L-12_H-256_A-4':\n",
" 'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3',\n",
" 'small_bert/bert_en_uncased_L-12_H-512_A-8':\n",
" 'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3',\n",
" 'small_bert/bert_en_uncased_L-12_H-768_A-12':\n",
" 'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3',\n",
" 'bert_multi_cased_L-12_H-768_A-12':\n",
" 'https://tfhub.dev/tensorflow/bert_multi_cased_preprocess/3',\n",
" 'albert_en_base':\n",
" 'https://tfhub.dev/tensorflow/albert_en_preprocess/3',\n",
" 'electra_small':\n",
" 'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3',\n",
" 'electra_base':\n",
" 'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3',\n",
" 'experts_pubmed':\n",
" 'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3',\n",
" 'experts_wiki_books':\n",
" 'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3',\n",
" 'talking-heads_base':\n",
" 'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3',\n",
"}\n",
"\n",
"tfhub_handle_encoder = map_name_to_handle[bert_model_name]\n",
"tfhub_handle_preprocess = map_model_to_preprocess[bert_model_name]\n",
"\n",
"print(f'BERT model selected : {tfhub_handle_encoder}')\n",
"print(f'Preprocess model auto-selected: {tfhub_handle_preprocess}')"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "7WrcxxTRDdHi"
},
"source": [
"## Preprocessing model\n",
"\n",
"Text inputs need to be converted to numeric token ids and arranged in several tensors before they are fed to BERT. TensorFlow Hub provides a matching preprocessing model for each of the BERT models listed above, which implements this transformation using TF ops from the TF.text library. There is no need to run pure Python code outside your TensorFlow model to preprocess the text.\n",
"\n",
"The preprocessing model must be the one referenced in the documentation of the BERT model, which you can read at the URL printed above. For the BERT models in the drop-down above, the preprocessing model is selected automatically.\n",
"\n",
"Note: you will load the preprocessing model into a [hub.KerasLayer](https://www.tensorflow.org/hub/api_docs/python/hub/KerasLayer) to compose your fine-tuned model. This is the preferred API for loading a TF2-style SavedModel from TF Hub into a Keras model."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "0SQi-jWd_jzq"
},
"outputs": [],
"source": [
"bert_preprocess_model = hub.KerasLayer(tfhub_handle_preprocess)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "x4naBiEE_cZX"
},
"source": [
"Let's run the preprocessing model on a sample text"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "r9-zCzJpnuwS",
"outputId": "a2cd0182-7905-47e1-805c-8e440828ba29"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Keys : ['input_word_ids', 'input_mask', 'input_type_ids']\n",
"Shape : (1, 128)\n",
"Word Ids : [ 101 2023 2003 2107 2019 6429 3185 999 102 0 0 0]\n",
"Input Mask : [1 1 1 1 1 1 1 1 1 0 0 0]\n",
"Type Ids : [0 0 0 0 0 0 0 0 0 0 0 0]\n"
]
}
],
"source": [
"text_test = ['this is such an amazing movie!']\n",
"text_preprocessed = bert_preprocess_model(text_test)\n",
"\n",
"print(f'Keys : {list(text_preprocessed.keys())}')\n",
"print(f'Shape : {text_preprocessed[\"input_word_ids\"].shape}')\n",
"print(f'Word Ids : {text_preprocessed[\"input_word_ids\"][0, :12]}')\n",
"print(f'Input Mask : {text_preprocessed[\"input_mask\"][0, :12]}')\n",
"print(f'Type Ids : {text_preprocessed[\"input_type_ids\"][0, :12]}')"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "EqL7ihkN_862"
},
"source": [
"The preprocessing produces three outputs (`input_word_ids`, `input_mask` and `input_type_ids`).\n",
"\n",
"Note:\n",
"- The input is truncated or padded to 128 tokens.\n",
"- `input_type_ids` contains only the value 0 because the input is a single sentence.\n"
]
},
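{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the three tensors more concrete, here is a minimal NumPy sketch of how a list of token ids could be padded to a fixed length and paired with a mask. The helper `pad_to_length` and its parameters are hypothetical and are not part of the TF Hub preprocessing model; the real model also keeps the closing `[SEP]` token when it truncates, which this simplified version does not."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"\n",
"def pad_to_length(token_ids, seq_length=128, pad_id=0):\n",
"    # Hypothetical helper: illustrates the tensor layout only, not the TF Hub implementation.\n",
"    ids = list(token_ids)[:seq_length]  # naive truncation\n",
"    n_pad = seq_length - len(ids)\n",
"    return {\n",
"        'input_word_ids': np.array(ids + [pad_id] * n_pad, dtype=np.int32),\n",
"        # 1 marks a real token, 0 marks padding\n",
"        'input_mask': np.array([1] * len(ids) + [0] * n_pad, dtype=np.int32),\n",
"        # a single-segment input uses type id 0 everywhere\n",
"        'input_type_ids': np.zeros(seq_length, dtype=np.int32),\n",
"    }\n",
"\n",
"\n",
"# Token ids taken from the output printed above\n",
"sketch = pad_to_length([101, 2023, 2003, 2107, 2019, 6429, 3185, 999, 102])\n",
"print(sketch['input_word_ids'][:12])\n",
"print(sketch['input_mask'][:12])"
]
},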
{
"cell_type": "markdown",
"metadata": {
"id": "DKnLPSEmtp9i"
},
"source": [
"## Using BERT"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "tXxYpK8ixL34"
},
"outputs": [],
"source": [
"bert_model = hub.KerasLayer(tfhub_handle_encoder)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "_OoF9mebuSZc"
},
"outputs": [],
"source": [
"bert_results = bert_model(text_preprocessed)\n",
"\n",
"print(f'Loaded BERT: {tfhub_handle_encoder}')\n",
"print(f'Pooled Outputs Shape:{bert_results[\"pooled_output\"].shape}')\n",
"print(f'Pooled Outputs Values:{bert_results[\"pooled_output\"][0, :12]}')\n",
"print(f'Sequence Outputs Shape:{bert_results[\"sequence_output\"].shape}')\n",
"print(f'Sequence Outputs Values:{bert_results[\"sequence_output\"][0, :12]}')"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "sm61jDrezAll"
},
"source": [
"The BERT models return a map with 3 important keys: `pooled_output`, `sequence_output`, `encoder_outputs`:\n",
"\n",
"- `pooled_output` represents each input sequence as a whole. The shape is `[batch_size, H]`. You can think of this as an embedding for the entire movie review.\n",
"- `sequence_output` represents each input token in the context. The shape is `[batch_size, seq_length, H]`. You can think of this as a contextual embedding for every token in the movie review.\n",
"- `encoder_outputs` are the intermediate activations of the `L` Transformer blocks. `outputs[\"encoder_outputs\"][i]` is a Tensor of shape `[batch_size, seq_length, H]` with the outputs of the i-th Transformer block, for `0 <= i < L`. The last value of the list is equal to `sequence_output`.\n",
"\n",
"For the fine-tuning you are going to use the `pooled_output` array."
]
},
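{
"cell_type": "markdown",
"metadata": {},
"source": [
"The relationship between the two main outputs can be sketched with random NumPy arrays. In the original BERT design the pooled output is a dense layer with `tanh` activation applied to the hidden state of the first (`[CLS]`) token. The weights below are random placeholders and every name is hypothetical, so only the shapes are meaningful."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"rng = np.random.default_rng(0)\n",
"n_batch, n_tokens, hidden = 1, 128, 512  # hidden=512 is an arbitrary illustrative size\n",
"\n",
"# Stand-in for sequence_output: one vector per token\n",
"sequence_output_sketch = rng.normal(size=(n_batch, n_tokens, hidden))\n",
"\n",
"# Placeholder pooler weights (random, not trained)\n",
"w_pool = rng.normal(size=(hidden, hidden)) * 0.02\n",
"b_pool = np.zeros(hidden)\n",
"\n",
"# Pooling: take the first token's vector, apply a dense layer and tanh\n",
"cls_state = sequence_output_sketch[:, 0, :]\n",
"pooled_output_sketch = np.tanh(cls_state @ w_pool + b_pool)\n",
"\n",
"print(sequence_output_sketch.shape)  # (1, 128, 512)\n",
"print(pooled_output_sketch.shape)  # (1, 512)"
]
},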
{
"cell_type": "markdown",
"metadata": {
"id": "pDNKfAXbDnJH"
},
"source": [
"## Building the neural network\n",
"\n",
"Preprocessing model -> BERT -> Dropout -> Dense\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "aksj743St9ga"
},
"outputs": [],
"source": [
"def build_classifier_model():\n",
"    text_input = tf.keras.layers.Input(shape=(), dtype=tf.string, name='text')\n",
"    preprocessing_layer = hub.KerasLayer(tfhub_handle_preprocess, name='preprocessing')\n",
"    encoder_inputs = preprocessing_layer(text_input)\n",
"    encoder = hub.KerasLayer(tfhub_handle_encoder, trainable=True, name='BERT_encoder')\n",
"    outputs = encoder(encoder_inputs)\n",
"    net = outputs['pooled_output']\n",
"    net = tf.keras.layers.Dropout(0.1)(net)\n",
"    net = tf.keras.layers.Dense(1, activation=None, name='classifier')(net)\n",
"    return tf.keras.Model(text_input, net)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Zs4yhFraBuGQ"
},
"source": [
"Run inference"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "mGMF8AZcB2Zy"
},
"outputs": [],
"source": [
"classifier_model = build_classifier_model()\n",
"bert_raw_result = classifier_model(tf.constant(text_test))\n",
"print(tf.sigmoid(bert_raw_result))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ZTUzNV2JE2G3"
},
"source": [
"The model still needs to be trained, so the score above carries no meaning yet.\n",
"\n",
"Let's look at the model structure"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "0EmzyHZXKIpm"
},
"outputs": [],
"source": [
"tf.keras.utils.plot_model(classifier_model)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "WbUWoZMwc302"
},
"source": [
"## Training the model\n"
]
},
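{
"cell_type": "markdown",
"metadata": {
"id": "WpJ3xcwDT56v"
},
"source": [
"### Loss function\n",
"\n",
"This is a binary classification problem and the model outputs a single logit (the `classifier` layer has no activation), so binary cross-entropy with `from_logits=True` is used.\n"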
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "OWPOZE-L3AgE"
},
"outputs": [],
"source": [
"loss = tf.keras.losses.BinaryCrossentropy(from_logits=True)\n",
"# The model outputs logits, so the decision boundary is 0.0 rather than the default 0.5\n",
"metrics = tf.metrics.BinaryAccuracy(threshold=0.0)"
]
},
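{
"cell_type": "markdown",
"metadata": {},
"source": [
"Because the classifier's last layer has no activation, the loss works on raw logits. A minimal NumPy sketch of the numerically stable form of binary cross-entropy on logits, `max(x, 0) - x*y + log(1 + exp(-|x|))`, is shown below. The function name `bce_from_logits` is hypothetical, and this is an illustration of the formula rather than the Keras implementation."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"\n",
"def bce_from_logits(logits, labels):\n",
"    # Hypothetical helper: mean binary cross-entropy computed directly from logits.\n",
"    x = np.asarray(logits, dtype=np.float64)\n",
"    y = np.asarray(labels, dtype=np.float64)\n",
"    # Stable rewrite of -y*log(sigmoid(x)) - (1-y)*log(1-sigmoid(x))\n",
"    per_example = np.maximum(x, 0) - x * y + np.log1p(np.exp(-np.abs(x)))\n",
"    return per_example.mean()\n",
"\n",
"\n",
"# A logit of 0 means probability 0.5, so the loss is log(2)\n",
"print(bce_from_logits([0.0], [1.0]))\n",
"# Confident and correct predictions give a small loss\n",
"print(bce_from_logits([5.0, -5.0], [1.0, 0.0]))"
]
},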
{
"cell_type": "markdown",
"metadata": {
"id": "77psrpfzbxtp"
},
"source": [
"### Optimizer\n",
"For fine-tuning it is best to use the same optimizer that BERT was originally trained with: Adam with decoupled weight decay, also known as [AdamW](https://arxiv.org/abs/1711.05101)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "P9eP2y9dbw32"
},
"outputs": [],
"source": [
"epochs = 5\n",
"steps_per_epoch = tf.data.experimental.cardinality(train_ds).numpy()\n",
"num_train_steps = steps_per_epoch * epochs\n",
"num_warmup_steps = int(0.1*num_train_steps)\n",
"\n",
"init_lr = 3e-5\n",
"optimizer = optimization.create_optimizer(init_lr=init_lr,\n",
" num_train_steps=num_train_steps,\n",
" num_warmup_steps=num_warmup_steps,\n",
" optimizer_type='adamw')"
]
},
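{
"cell_type": "markdown",
"metadata": {},
"source": [
"The learning-rate schedule can be pictured with a few lines of plain Python. This sketch assumes a linear warm-up toward `init_lr` over the first `num_warmup_steps` steps, followed by a linear decay that reaches zero at `num_train_steps`. The function `lr_at_step` is hypothetical, the step counts are toy values, and the exact schedule built by `optimization.create_optimizer` may differ in detail."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def lr_at_step(step, init_lr, num_train_steps, num_warmup_steps):\n",
"    # Hypothetical helper, not the Model Garden implementation.\n",
"    if step < num_warmup_steps:\n",
"        # linear warm-up from 0 toward init_lr\n",
"        return init_lr * step / num_warmup_steps\n",
"    # linear decay that reaches 0 at num_train_steps\n",
"    return init_lr * max(num_train_steps - step, 0) / num_train_steps\n",
"\n",
"\n",
"# Toy values: 100 training steps, 10 of them warm-up\n",
"for step in (0, 5, 10, 50, 100):\n",
"    print(step, lr_at_step(step, init_lr=3e-5, num_train_steps=100, num_warmup_steps=10))"
]
},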
{
"cell_type": "markdown",
"metadata": {
"id": "SqlarlpC_v0g"
},
"source": [
"### Compiling and training the model\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "-7GPDhR98jsD"
},
"outputs": [],
"source": [
"classifier_model.compile(optimizer=optimizer,\n",
" loss=loss,\n",
" metrics=metrics)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "HtfDFAnN_Neu"
},
"outputs": [],
"source": [
"print(f'Training model with {tfhub_handle_encoder}')\n",
"history = classifier_model.fit(x=train_ds,\n",
" validation_data=val_ds,\n",
" epochs=epochs)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "uBthMlTSV8kn"
},
"source": [
"### Evaluating the model\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "slqB-urBV9sP"
},
"outputs": [],
"source": [
"loss, accuracy = classifier_model.evaluate(test_ds)\n",
"\n",
"print(f'Loss: {loss}')\n",
"print(f'Accuracy: {accuracy}')"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "uttWpgmSfzq9"
},
"source": [
"### Plotting the metrics and the loss\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "fiythcODf0xo"
},
"outputs": [],
"source": [
"history_dict = history.history\n",
"print(history_dict.keys())\n",
"\n",
"acc = history_dict['binary_accuracy']\n",
"val_acc = history_dict['val_binary_accuracy']\n",
"loss = history_dict['loss']\n",
"val_loss = history_dict['val_loss']\n",
"\n",
"epochs = range(1, len(acc) + 1)\n",
"fig = plt.figure(figsize=(10, 6))\n",
"fig.tight_layout()\n",
"\n",
"plt.subplot(2, 1, 1)\n",
"# 'r' draws a solid red line\n",
"plt.plot(epochs, loss, 'r', label='Training loss')\n",
"# 'b' draws a solid blue line\n",
"plt.plot(epochs, val_loss, 'b', label='Validation loss')\n",
"plt.title('Training and validation loss')\n",
"# plt.xlabel('Epochs')\n",
"plt.ylabel('Loss')\n",
"plt.legend()\n",
"\n",
"plt.subplot(2, 1, 2)\n",
"plt.plot(epochs, acc, 'r', label='Training acc')\n",
"plt.plot(epochs, val_acc, 'b', label='Validation acc')\n",
"plt.title('Training and validation accuracy')\n",
"plt.xlabel('Epochs')\n",
"plt.ylabel('Accuracy')\n",
"plt.legend(loc='lower right')"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Rtn7jewb6dg4"
},
"source": [
"## Saving the model for inference\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ShcvqJAgVera"
},
"outputs": [],
"source": [
"dataset_name = 'imdb'\n",
"saved_model_path = './{}_bert'.format(dataset_name.replace('/', '_'))\n",
"\n",
"classifier_model.save(saved_model_path, include_optimizer=False)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "gUEWVskZjEF0"
},
"outputs": [],
"source": [
"reloaded_model = tf.saved_model.load(saved_model_path)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "oyTappHTvNCz"
},
"source": [
"Let's test the model on custom examples"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "VBWzH6exlCPS"
},
"outputs": [],
"source": [
"def print_my_examples(inputs, results):\n",
"    result_for_printing = \\\n",
"        [f'input: {inputs[i]:<30} : score: {results[i][0]:.6f}'\n",
"         for i in range(len(inputs))]\n",
"    print(*result_for_printing, sep='\\n')\n",
"    print()\n",
"\n",
"\n",
"examples = [\n",
" 'this is such an amazing movie!',\n",
" 'The movie was great!',\n",
" 'The movie was meh.',\n",
" 'The movie was okish.',\n",
" 'The movie was terrible...'\n",
"]\n",
"\n",
"reloaded_results = tf.sigmoid(reloaded_model(tf.constant(examples)))\n",
"original_results = tf.sigmoid(classifier_model(tf.constant(examples)))\n",
"\n",
"print('Results from the saved model:')\n",
"print_my_examples(examples, reloaded_results)\n",
"print('Results from the model in memory:')\n",
"print_my_examples(examples, original_results)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "3cOmih754Y_M"
},
"source": [
"To use the model with [TF Serving](https://www.tensorflow.org/tfx/guide/serving), call the SavedModel through one of its named signatures."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "0FdVD3973S-O"
},
"outputs": [],
"source": [
"serving_results = reloaded_model \\\n",
" .signatures['serving_default'](tf.constant(examples))\n",
"\n",
"serving_results = tf.sigmoid(serving_results['classifier'])\n",
"\n",
"print_my_examples(examples, serving_results)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "4DmhXTlQzg4q"
},
"source": [
"# Assignment\n",
"\n",
"Build a neural network that classifies the tags of Stack Overflow questions\n",
"https://www.tensorflow.org/tutorials/load_data/text"
]
}
],
"metadata": {
"colab": {
"collapsed_sections": [],
"name": "2_bert_classification.ipynb",
"provenance": [],
"toc_visible": true
},
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.6"
}
},
"nbformat": 4,
"nbformat_minor": 1
}