Last active
April 13, 2022 11:56
-
-
Save jamm1985/7fd9f2aecee17614af54aa805191446f to your computer and use it in GitHub Desktop.
Lab_15_Intro_to_ML_Classification_part_II.ipynb
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"nbformat": 4, | |
"nbformat_minor": 0, | |
"metadata": { | |
"colab": { | |
"name": "Lab_15_Intro_to_ML_Classification_part_II.ipynb", | |
"provenance": [], | |
"authorship_tag": "ABX9TyPeHhv6+euwp/n4WwUNatil", | |
"include_colab_link": true | |
}, | |
"kernelspec": { | |
"name": "python3", | |
"display_name": "Python 3" | |
}, | |
"language_info": { | |
"name": "python" | |
} | |
}, | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "view-in-github", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"<a href=\"https://colab.research.google.com/gist/jamm1985/7fd9f2aecee17614af54aa805191446f/lab_15_intro_to_ml_classification_part_ii.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"Видео лабораторной: https://youtu.be/lkiFy6LQnSk\n", | |
"\n", | |
"TG: https://t.me/data_science_news\n", | |
"\n", | |
"\n", | |
"\n", | |
"---" | |
], | |
"metadata": { | |
"id": "0t6oCX9hxKL0" | |
} | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"## Введение в ML классификация\n", | |
"\n", | |
"### Часть 1\n", | |
"- Логистическая регрессия (бинарный вариант)\n", | |
"- Оценка качества классификации для бинарного случая\n", | |
"\n", | |
"### Часть 2\n", | |
"- Логистическая регрессия для трёх и более классов\n", | |
"- Линейный/Квадратный дискриминативный анализ \n", | |
"- Оценка качества классификации для бинарного случая и выбор модели" | |
], | |
"metadata": { | |
"id": "AJx5fQI7JsbZ" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"id": "oIq0XfXXrfqY" | |
}, | |
"outputs": [], | |
"source": [ | |
"import numpy as np\n", | |
"import pandas as pd\n", | |
"\n", | |
"import matplotlib.pylab as plt\n", | |
"\n", | |
"from sklearn.linear_model import LogisticRegression\n", | |
"from sklearn.discriminant_analysis import LinearDiscriminantAnalysis\n", | |
"from sklearn.discriminant_analysis import QuadraticDiscriminantAnalysis\n", | |
"from sklearn.model_selection import train_test_split\n", | |
"from sklearn.model_selection import cross_val_score\n", | |
"from sklearn.metrics import classification_report\n", | |
"from sklearn.metrics import confusion_matrix\n", | |
"from sklearn.metrics import precision_score\n", | |
"from sklearn.metrics import roc_auc_score\n", | |
"from sklearn.metrics import roc_curve, auc\n", | |
"from sklearn.preprocessing import label_binarize\n", | |
"\n", | |
"from itertools import cycle" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"# Данные (геометрические параметры разных сортов пшеницы)" | |
], | |
"metadata": { | |
"id": "b7hvIP10L55h" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# download dataset\n", | |
"# https://archive.ics.uci.edu/ml/datasets/seeds\n", | |
"!wget https://archive.ics.uci.edu/ml/machine-learning-databases/00236/seeds_dataset.txt\n", | |
"!head seeds_dataset.txt\n" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "6apfctsWrvdA", | |
"outputId": "77e6f198-abbc-43e6-f9e2-1ea0a7c0243e" | |
}, | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"--2022-04-13 06:01:15-- https://archive.ics.uci.edu/ml/machine-learning-databases/00236/seeds_dataset.txt\n", | |
"Resolving archive.ics.uci.edu (archive.ics.uci.edu)... 128.195.10.252\n", | |
"Connecting to archive.ics.uci.edu (archive.ics.uci.edu)|128.195.10.252|:443... connected.\n", | |
"HTTP request sent, awaiting response... 200 OK\n", | |
"Length: 9300 (9.1K) [application/x-httpd-php]\n", | |
"Saving to: ‘seeds_dataset.txt’\n", | |
"\n", | |
"seeds_dataset.txt 100%[===================>] 9.08K --.-KB/s in 0s \n", | |
"\n", | |
"2022-04-13 06:01:15 (99.1 MB/s) - ‘seeds_dataset.txt’ saved [9300/9300]\n", | |
"\n", | |
"15.26\t14.84\t0.871\t5.763\t3.312\t2.221\t5.22\t1\n", | |
"14.88\t14.57\t0.8811\t5.554\t3.333\t1.018\t4.956\t1\n", | |
"14.29\t14.09\t0.905\t5.291\t3.337\t2.699\t4.825\t1\n", | |
"13.84\t13.94\t0.8955\t5.324\t3.379\t2.259\t4.805\t1\n", | |
"16.14\t14.99\t0.9034\t5.658\t3.562\t1.355\t5.175\t1\n", | |
"14.38\t14.21\t0.8951\t5.386\t3.312\t2.462\t4.956\t1\n", | |
"14.69\t14.49\t0.8799\t5.563\t3.259\t3.586\t5.219\t1\n", | |
"14.11\t14.1\t0.8911\t5.42\t3.302\t2.7\t\t5\t\t1\n", | |
"16.63\t15.46\t0.8747\t6.053\t3.465\t2.04\t5.877\t1\n", | |
"16.44\t15.25\t0.888\t5.884\t3.505\t1.969\t5.533\t1\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# load the data\n", | |
"DATA = pd.read_csv(\n", | |
" \"https://archive.ics.uci.edu/ml/machine-learning-databases/00236/seeds_dataset.txt\",\n", | |
" sep='\\t',\n", | |
" header=None,\n", | |
" names=[\"area\", \"perimeter\", \"compactness\", \"length\", \"width\", \"asymmetry\", \"length_groove\", \"class\"],\n", | |
" on_bad_lines='skip')\n", | |
"DATA" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 423 | |
}, | |
"id": "sI7VPb84nF0I", | |
"outputId": "977e47c3-a445-4c73-f984-f6de95da93cd" | |
}, | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
" area perimeter compactness length width asymmetry length_groove \\\n", | |
"0 15.26 14.84 0.8710 5.763 3.312 2.221 5.220 \n", | |
"1 14.88 14.57 0.8811 5.554 3.333 1.018 4.956 \n", | |
"2 14.29 14.09 0.9050 5.291 3.337 2.699 4.825 \n", | |
"3 13.84 13.94 0.8955 5.324 3.379 2.259 4.805 \n", | |
"4 16.14 14.99 0.9034 5.658 3.562 1.355 5.175 \n", | |
".. ... ... ... ... ... ... ... \n", | |
"205 12.19 13.20 0.8783 5.137 2.981 3.631 4.870 \n", | |
"206 11.23 12.88 0.8511 5.140 2.795 4.325 5.003 \n", | |
"207 13.20 13.66 0.8883 5.236 3.232 8.315 5.056 \n", | |
"208 11.84 13.21 0.8521 5.175 2.836 3.598 5.044 \n", | |
"209 12.30 13.34 0.8684 5.243 2.974 5.637 5.063 \n", | |
"\n", | |
" class \n", | |
"0 1.0 \n", | |
"1 1.0 \n", | |
"2 1.0 \n", | |
"3 1.0 \n", | |
"4 1.0 \n", | |
".. ... \n", | |
"205 3.0 \n", | |
"206 3.0 \n", | |
"207 3.0 \n", | |
"208 3.0 \n", | |
"209 3.0 \n", | |
"\n", | |
"[210 rows x 8 columns]" | |
], | |
"text/html": [ | |
"\n", | |
" <div id=\"df-b5862fc4-4d38-4e60-b8d6-898d46a1447c\">\n", | |
" <div class=\"colab-df-container\">\n", | |
" <div>\n", | |
"<style scoped>\n", | |
" .dataframe tbody tr th:only-of-type {\n", | |
" vertical-align: middle;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>area</th>\n", | |
" <th>perimeter</th>\n", | |
" <th>compactness</th>\n", | |
" <th>length</th>\n", | |
" <th>width</th>\n", | |
" <th>asymmetry</th>\n", | |
" <th>length_groove</th>\n", | |
" <th>class</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>0</th>\n", | |
" <td>15.26</td>\n", | |
" <td>14.84</td>\n", | |
" <td>0.8710</td>\n", | |
" <td>5.763</td>\n", | |
" <td>3.312</td>\n", | |
" <td>2.221</td>\n", | |
" <td>5.220</td>\n", | |
" <td>1.0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1</th>\n", | |
" <td>14.88</td>\n", | |
" <td>14.57</td>\n", | |
" <td>0.8811</td>\n", | |
" <td>5.554</td>\n", | |
" <td>3.333</td>\n", | |
" <td>1.018</td>\n", | |
" <td>4.956</td>\n", | |
" <td>1.0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>2</th>\n", | |
" <td>14.29</td>\n", | |
" <td>14.09</td>\n", | |
" <td>0.9050</td>\n", | |
" <td>5.291</td>\n", | |
" <td>3.337</td>\n", | |
" <td>2.699</td>\n", | |
" <td>4.825</td>\n", | |
" <td>1.0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>3</th>\n", | |
" <td>13.84</td>\n", | |
" <td>13.94</td>\n", | |
" <td>0.8955</td>\n", | |
" <td>5.324</td>\n", | |
" <td>3.379</td>\n", | |
" <td>2.259</td>\n", | |
" <td>4.805</td>\n", | |
" <td>1.0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>4</th>\n", | |
" <td>16.14</td>\n", | |
" <td>14.99</td>\n", | |
" <td>0.9034</td>\n", | |
" <td>5.658</td>\n", | |
" <td>3.562</td>\n", | |
" <td>1.355</td>\n", | |
" <td>5.175</td>\n", | |
" <td>1.0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>...</th>\n", | |
" <td>...</td>\n", | |
" <td>...</td>\n", | |
" <td>...</td>\n", | |
" <td>...</td>\n", | |
" <td>...</td>\n", | |
" <td>...</td>\n", | |
" <td>...</td>\n", | |
" <td>...</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>205</th>\n", | |
" <td>12.19</td>\n", | |
" <td>13.20</td>\n", | |
" <td>0.8783</td>\n", | |
" <td>5.137</td>\n", | |
" <td>2.981</td>\n", | |
" <td>3.631</td>\n", | |
" <td>4.870</td>\n", | |
" <td>3.0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>206</th>\n", | |
" <td>11.23</td>\n", | |
" <td>12.88</td>\n", | |
" <td>0.8511</td>\n", | |
" <td>5.140</td>\n", | |
" <td>2.795</td>\n", | |
" <td>4.325</td>\n", | |
" <td>5.003</td>\n", | |
" <td>3.0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>207</th>\n", | |
" <td>13.20</td>\n", | |
" <td>13.66</td>\n", | |
" <td>0.8883</td>\n", | |
" <td>5.236</td>\n", | |
" <td>3.232</td>\n", | |
" <td>8.315</td>\n", | |
" <td>5.056</td>\n", | |
" <td>3.0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>208</th>\n", | |
" <td>11.84</td>\n", | |
" <td>13.21</td>\n", | |
" <td>0.8521</td>\n", | |
" <td>5.175</td>\n", | |
" <td>2.836</td>\n", | |
" <td>3.598</td>\n", | |
" <td>5.044</td>\n", | |
" <td>3.0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>209</th>\n", | |
" <td>12.30</td>\n", | |
" <td>13.34</td>\n", | |
" <td>0.8684</td>\n", | |
" <td>5.243</td>\n", | |
" <td>2.974</td>\n", | |
" <td>5.637</td>\n", | |
" <td>5.063</td>\n", | |
" <td>3.0</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"<p>210 rows × 8 columns</p>\n", | |
"</div>\n", | |
" <button class=\"colab-df-convert\" onclick=\"convertToInteractive('df-b5862fc4-4d38-4e60-b8d6-898d46a1447c')\"\n", | |
" title=\"Convert this dataframe to an interactive table.\"\n", | |
" style=\"display:none;\">\n", | |
" \n", | |
" <svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n", | |
" width=\"24px\">\n", | |
" <path d=\"M0 0h24v24H0V0z\" fill=\"none\"/>\n", | |
" <path d=\"M18.56 5.44l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94zm-11 1L8.5 8.5l.94-2.06 2.06-.94-2.06-.94L8.5 2.5l-.94 2.06-2.06.94zm10 10l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94z\"/><path d=\"M17.41 7.96l-1.37-1.37c-.4-.4-.92-.59-1.43-.59-.52 0-1.04.2-1.43.59L10.3 9.45l-7.72 7.72c-.78.78-.78 2.05 0 2.83L4 21.41c.39.39.9.59 1.41.59.51 0 1.02-.2 1.41-.59l7.78-7.78 2.81-2.81c.8-.78.8-2.07 0-2.86zM5.41 20L4 18.59l7.72-7.72 1.47 1.35L5.41 20z\"/>\n", | |
" </svg>\n", | |
" </button>\n", | |
" \n", | |
" <style>\n", | |
" .colab-df-container {\n", | |
" display:flex;\n", | |
" flex-wrap:wrap;\n", | |
" gap: 12px;\n", | |
" }\n", | |
"\n", | |
" .colab-df-convert {\n", | |
" background-color: #E8F0FE;\n", | |
" border: none;\n", | |
" border-radius: 50%;\n", | |
" cursor: pointer;\n", | |
" display: none;\n", | |
" fill: #1967D2;\n", | |
" height: 32px;\n", | |
" padding: 0 0 0 0;\n", | |
" width: 32px;\n", | |
" }\n", | |
"\n", | |
" .colab-df-convert:hover {\n", | |
" background-color: #E2EBFA;\n", | |
" box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n", | |
" fill: #174EA6;\n", | |
" }\n", | |
"\n", | |
" [theme=dark] .colab-df-convert {\n", | |
" background-color: #3B4455;\n", | |
" fill: #D2E3FC;\n", | |
" }\n", | |
"\n", | |
" [theme=dark] .colab-df-convert:hover {\n", | |
" background-color: #434B5C;\n", | |
" box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n", | |
" filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n", | |
" fill: #FFFFFF;\n", | |
" }\n", | |
" </style>\n", | |
"\n", | |
" <script>\n", | |
" const buttonEl =\n", | |
" document.querySelector('#df-b5862fc4-4d38-4e60-b8d6-898d46a1447c button.colab-df-convert');\n", | |
" buttonEl.style.display =\n", | |
" google.colab.kernel.accessAllowed ? 'block' : 'none';\n", | |
"\n", | |
" async function convertToInteractive(key) {\n", | |
" const element = document.querySelector('#df-b5862fc4-4d38-4e60-b8d6-898d46a1447c');\n", | |
" const dataTable =\n", | |
" await google.colab.kernel.invokeFunction('convertToInteractive',\n", | |
" [key], {});\n", | |
" if (!dataTable) return;\n", | |
"\n", | |
" const docLinkHtml = 'Like what you see? Visit the ' +\n", | |
" '<a target=\"_blank\" href=https://colab.research.google.com/notebooks/data_table.ipynb>data table notebook</a>'\n", | |
" + ' to learn more about interactive tables.';\n", | |
" element.innerHTML = '';\n", | |
" dataTable['output_type'] = 'display_data';\n", | |
" await google.colab.output.renderOutput(dataTable, element);\n", | |
" const docLink = document.createElement('div');\n", | |
" docLink.innerHTML = docLinkHtml;\n", | |
" element.appendChild(docLink);\n", | |
" }\n", | |
" </script>\n", | |
" </div>\n", | |
" </div>\n", | |
" " | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 3 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"DATA.describe()" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 300 | |
}, | |
"id": "vKXjajLiIEkH", | |
"outputId": "1c09125e-a7a0-494e-9502-10b003d00729" | |
}, | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
" area perimeter compactness length width \\\n", | |
"count 210.000000 210.000000 207.000000 210.000000 209.000000 \n", | |
"mean 14.847524 14.559286 0.871280 5.563918 3.281440 \n", | |
"std 2.909699 1.305959 0.023306 0.719594 0.419907 \n", | |
"min 10.590000 12.410000 0.808100 0.818900 2.630000 \n", | |
"25% 12.270000 13.450000 0.857700 5.244750 2.956000 \n", | |
"50% 14.355000 14.320000 0.873500 5.518000 3.245000 \n", | |
"75% 17.305000 15.715000 0.887650 5.979750 3.566000 \n", | |
"max 21.180000 17.250000 0.918300 6.675000 5.325000 \n", | |
"\n", | |
" asymmetry length_groove class \n", | |
"count 210.000000 206.000000 206.000000 \n", | |
"mean 3.693530 5.407529 2.084039 \n", | |
"std 1.495112 0.532330 0.948211 \n", | |
"min 0.765100 3.485000 1.000000 \n", | |
"25% 2.600250 5.045000 1.000000 \n", | |
"50% 3.599000 5.226000 2.000000 \n", | |
"75% 4.768750 5.879000 3.000000 \n", | |
"max 8.456000 6.735000 5.439000 " | |
], | |
"text/html": [ | |
"\n", | |
" <div id=\"df-eaa45a61-5d9e-4f32-ba5b-0f48135b3f04\">\n", | |
" <div class=\"colab-df-container\">\n", | |
" <div>\n", | |
"<style scoped>\n", | |
" .dataframe tbody tr th:only-of-type {\n", | |
" vertical-align: middle;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>area</th>\n", | |
" <th>perimeter</th>\n", | |
" <th>compactness</th>\n", | |
" <th>length</th>\n", | |
" <th>width</th>\n", | |
" <th>asymmetry</th>\n", | |
" <th>length_groove</th>\n", | |
" <th>class</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>count</th>\n", | |
" <td>210.000000</td>\n", | |
" <td>210.000000</td>\n", | |
" <td>207.000000</td>\n", | |
" <td>210.000000</td>\n", | |
" <td>209.000000</td>\n", | |
" <td>210.000000</td>\n", | |
" <td>206.000000</td>\n", | |
" <td>206.000000</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>mean</th>\n", | |
" <td>14.847524</td>\n", | |
" <td>14.559286</td>\n", | |
" <td>0.871280</td>\n", | |
" <td>5.563918</td>\n", | |
" <td>3.281440</td>\n", | |
" <td>3.693530</td>\n", | |
" <td>5.407529</td>\n", | |
" <td>2.084039</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>std</th>\n", | |
" <td>2.909699</td>\n", | |
" <td>1.305959</td>\n", | |
" <td>0.023306</td>\n", | |
" <td>0.719594</td>\n", | |
" <td>0.419907</td>\n", | |
" <td>1.495112</td>\n", | |
" <td>0.532330</td>\n", | |
" <td>0.948211</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>min</th>\n", | |
" <td>10.590000</td>\n", | |
" <td>12.410000</td>\n", | |
" <td>0.808100</td>\n", | |
" <td>0.818900</td>\n", | |
" <td>2.630000</td>\n", | |
" <td>0.765100</td>\n", | |
" <td>3.485000</td>\n", | |
" <td>1.000000</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>25%</th>\n", | |
" <td>12.270000</td>\n", | |
" <td>13.450000</td>\n", | |
" <td>0.857700</td>\n", | |
" <td>5.244750</td>\n", | |
" <td>2.956000</td>\n", | |
" <td>2.600250</td>\n", | |
" <td>5.045000</td>\n", | |
" <td>1.000000</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>50%</th>\n", | |
" <td>14.355000</td>\n", | |
" <td>14.320000</td>\n", | |
" <td>0.873500</td>\n", | |
" <td>5.518000</td>\n", | |
" <td>3.245000</td>\n", | |
" <td>3.599000</td>\n", | |
" <td>5.226000</td>\n", | |
" <td>2.000000</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>75%</th>\n", | |
" <td>17.305000</td>\n", | |
" <td>15.715000</td>\n", | |
" <td>0.887650</td>\n", | |
" <td>5.979750</td>\n", | |
" <td>3.566000</td>\n", | |
" <td>4.768750</td>\n", | |
" <td>5.879000</td>\n", | |
" <td>3.000000</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>max</th>\n", | |
" <td>21.180000</td>\n", | |
" <td>17.250000</td>\n", | |
" <td>0.918300</td>\n", | |
" <td>6.675000</td>\n", | |
" <td>5.325000</td>\n", | |
" <td>8.456000</td>\n", | |
" <td>6.735000</td>\n", | |
" <td>5.439000</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>\n", | |
" <button class=\"colab-df-convert\" onclick=\"convertToInteractive('df-eaa45a61-5d9e-4f32-ba5b-0f48135b3f04')\"\n", | |
" title=\"Convert this dataframe to an interactive table.\"\n", | |
" style=\"display:none;\">\n", | |
" \n", | |
" <svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n", | |
" width=\"24px\">\n", | |
" <path d=\"M0 0h24v24H0V0z\" fill=\"none\"/>\n", | |
" <path d=\"M18.56 5.44l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94zm-11 1L8.5 8.5l.94-2.06 2.06-.94-2.06-.94L8.5 2.5l-.94 2.06-2.06.94zm10 10l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94z\"/><path d=\"M17.41 7.96l-1.37-1.37c-.4-.4-.92-.59-1.43-.59-.52 0-1.04.2-1.43.59L10.3 9.45l-7.72 7.72c-.78.78-.78 2.05 0 2.83L4 21.41c.39.39.9.59 1.41.59.51 0 1.02-.2 1.41-.59l7.78-7.78 2.81-2.81c.8-.78.8-2.07 0-2.86zM5.41 20L4 18.59l7.72-7.72 1.47 1.35L5.41 20z\"/>\n", | |
" </svg>\n", | |
" </button>\n", | |
" \n", | |
" <style>\n", | |
" .colab-df-container {\n", | |
" display:flex;\n", | |
" flex-wrap:wrap;\n", | |
" gap: 12px;\n", | |
" }\n", | |
"\n", | |
" .colab-df-convert {\n", | |
" background-color: #E8F0FE;\n", | |
" border: none;\n", | |
" border-radius: 50%;\n", | |
" cursor: pointer;\n", | |
" display: none;\n", | |
" fill: #1967D2;\n", | |
" height: 32px;\n", | |
" padding: 0 0 0 0;\n", | |
" width: 32px;\n", | |
" }\n", | |
"\n", | |
" .colab-df-convert:hover {\n", | |
" background-color: #E2EBFA;\n", | |
" box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n", | |
" fill: #174EA6;\n", | |
" }\n", | |
"\n", | |
" [theme=dark] .colab-df-convert {\n", | |
" background-color: #3B4455;\n", | |
" fill: #D2E3FC;\n", | |
" }\n", | |
"\n", | |
" [theme=dark] .colab-df-convert:hover {\n", | |
" background-color: #434B5C;\n", | |
" box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n", | |
" filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n", | |
" fill: #FFFFFF;\n", | |
" }\n", | |
" </style>\n", | |
"\n", | |
" <script>\n", | |
" const buttonEl =\n", | |
" document.querySelector('#df-eaa45a61-5d9e-4f32-ba5b-0f48135b3f04 button.colab-df-convert');\n", | |
" buttonEl.style.display =\n", | |
" google.colab.kernel.accessAllowed ? 'block' : 'none';\n", | |
"\n", | |
" async function convertToInteractive(key) {\n", | |
" const element = document.querySelector('#df-eaa45a61-5d9e-4f32-ba5b-0f48135b3f04');\n", | |
" const dataTable =\n", | |
" await google.colab.kernel.invokeFunction('convertToInteractive',\n", | |
" [key], {});\n", | |
" if (!dataTable) return;\n", | |
"\n", | |
" const docLinkHtml = 'Like what you see? Visit the ' +\n", | |
" '<a target=\"_blank\" href=https://colab.research.google.com/notebooks/data_table.ipynb>data table notebook</a>'\n", | |
" + ' to learn more about interactive tables.';\n", | |
" element.innerHTML = '';\n", | |
" dataTable['output_type'] = 'display_data';\n", | |
" await google.colab.output.renderOutput(dataTable, element);\n", | |
" const docLink = document.createElement('div');\n", | |
" docLink.innerHTML = docLinkHtml;\n", | |
" element.appendChild(docLink);\n", | |
" }\n", | |
" </script>\n", | |
" </div>\n", | |
" </div>\n", | |
" " | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 126 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# set int type for dependent variable\n", | |
"DATA = DATA.dropna()\n", | |
"DATA[\"class\"] = DATA[\"class\"].astype('int32')\n", | |
"DATA\n" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 531 | |
}, | |
"id": "g_g_e98PsF-N", | |
"outputId": "e143c03d-31d6-4442-e735-b31ee3c36df7" | |
}, | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stderr", | |
"text": [ | |
"/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py:3: SettingWithCopyWarning: \n", | |
"A value is trying to be set on a copy of a slice from a DataFrame.\n", | |
"Try using .loc[row_indexer,col_indexer] = value instead\n", | |
"\n", | |
"See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", | |
" This is separate from the ipykernel package so we can avoid doing imports until\n" | |
] | |
}, | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
" area perimeter compactness length width asymmetry length_groove \\\n", | |
"0 15.26 14.84 0.8710 5.763 3.312 2.221 5.220 \n", | |
"1 14.88 14.57 0.8811 5.554 3.333 1.018 4.956 \n", | |
"2 14.29 14.09 0.9050 5.291 3.337 2.699 4.825 \n", | |
"3 13.84 13.94 0.8955 5.324 3.379 2.259 4.805 \n", | |
"4 16.14 14.99 0.9034 5.658 3.562 1.355 5.175 \n", | |
".. ... ... ... ... ... ... ... \n", | |
"205 12.19 13.20 0.8783 5.137 2.981 3.631 4.870 \n", | |
"206 11.23 12.88 0.8511 5.140 2.795 4.325 5.003 \n", | |
"207 13.20 13.66 0.8883 5.236 3.232 8.315 5.056 \n", | |
"208 11.84 13.21 0.8521 5.175 2.836 3.598 5.044 \n", | |
"209 12.30 13.34 0.8684 5.243 2.974 5.637 5.063 \n", | |
"\n", | |
" class \n", | |
"0 1 \n", | |
"1 1 \n", | |
"2 1 \n", | |
"3 1 \n", | |
"4 1 \n", | |
".. ... \n", | |
"205 3 \n", | |
"206 3 \n", | |
"207 3 \n", | |
"208 3 \n", | |
"209 3 \n", | |
"\n", | |
"[199 rows x 8 columns]" | |
], | |
"text/html": [ | |
"\n", | |
" <div id=\"df-d76e5518-c1e2-4ecc-872b-e7cf3ba929cc\">\n", | |
" <div class=\"colab-df-container\">\n", | |
" <div>\n", | |
"<style scoped>\n", | |
" .dataframe tbody tr th:only-of-type {\n", | |
" vertical-align: middle;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>area</th>\n", | |
" <th>perimeter</th>\n", | |
" <th>compactness</th>\n", | |
" <th>length</th>\n", | |
" <th>width</th>\n", | |
" <th>asymmetry</th>\n", | |
" <th>length_groove</th>\n", | |
" <th>class</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>0</th>\n", | |
" <td>15.26</td>\n", | |
" <td>14.84</td>\n", | |
" <td>0.8710</td>\n", | |
" <td>5.763</td>\n", | |
" <td>3.312</td>\n", | |
" <td>2.221</td>\n", | |
" <td>5.220</td>\n", | |
" <td>1</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1</th>\n", | |
" <td>14.88</td>\n", | |
" <td>14.57</td>\n", | |
" <td>0.8811</td>\n", | |
" <td>5.554</td>\n", | |
" <td>3.333</td>\n", | |
" <td>1.018</td>\n", | |
" <td>4.956</td>\n", | |
" <td>1</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>2</th>\n", | |
" <td>14.29</td>\n", | |
" <td>14.09</td>\n", | |
" <td>0.9050</td>\n", | |
" <td>5.291</td>\n", | |
" <td>3.337</td>\n", | |
" <td>2.699</td>\n", | |
" <td>4.825</td>\n", | |
" <td>1</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>3</th>\n", | |
" <td>13.84</td>\n", | |
" <td>13.94</td>\n", | |
" <td>0.8955</td>\n", | |
" <td>5.324</td>\n", | |
" <td>3.379</td>\n", | |
" <td>2.259</td>\n", | |
" <td>4.805</td>\n", | |
" <td>1</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>4</th>\n", | |
" <td>16.14</td>\n", | |
" <td>14.99</td>\n", | |
" <td>0.9034</td>\n", | |
" <td>5.658</td>\n", | |
" <td>3.562</td>\n", | |
" <td>1.355</td>\n", | |
" <td>5.175</td>\n", | |
" <td>1</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>...</th>\n", | |
" <td>...</td>\n", | |
" <td>...</td>\n", | |
" <td>...</td>\n", | |
" <td>...</td>\n", | |
" <td>...</td>\n", | |
" <td>...</td>\n", | |
" <td>...</td>\n", | |
" <td>...</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>205</th>\n", | |
" <td>12.19</td>\n", | |
" <td>13.20</td>\n", | |
" <td>0.8783</td>\n", | |
" <td>5.137</td>\n", | |
" <td>2.981</td>\n", | |
" <td>3.631</td>\n", | |
" <td>4.870</td>\n", | |
" <td>3</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>206</th>\n", | |
" <td>11.23</td>\n", | |
" <td>12.88</td>\n", | |
" <td>0.8511</td>\n", | |
" <td>5.140</td>\n", | |
" <td>2.795</td>\n", | |
" <td>4.325</td>\n", | |
" <td>5.003</td>\n", | |
" <td>3</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>207</th>\n", | |
" <td>13.20</td>\n", | |
" <td>13.66</td>\n", | |
" <td>0.8883</td>\n", | |
" <td>5.236</td>\n", | |
" <td>3.232</td>\n", | |
" <td>8.315</td>\n", | |
" <td>5.056</td>\n", | |
" <td>3</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>208</th>\n", | |
" <td>11.84</td>\n", | |
" <td>13.21</td>\n", | |
" <td>0.8521</td>\n", | |
" <td>5.175</td>\n", | |
" <td>2.836</td>\n", | |
" <td>3.598</td>\n", | |
" <td>5.044</td>\n", | |
" <td>3</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>209</th>\n", | |
" <td>12.30</td>\n", | |
" <td>13.34</td>\n", | |
" <td>0.8684</td>\n", | |
" <td>5.243</td>\n", | |
" <td>2.974</td>\n", | |
" <td>5.637</td>\n", | |
" <td>5.063</td>\n", | |
" <td>3</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"<p>199 rows × 8 columns</p>\n", | |
"</div>\n", | |
" <button class=\"colab-df-convert\" onclick=\"convertToInteractive('df-d76e5518-c1e2-4ecc-872b-e7cf3ba929cc')\"\n", | |
" title=\"Convert this dataframe to an interactive table.\"\n", | |
" style=\"display:none;\">\n", | |
" \n", | |
" <svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n", | |
" width=\"24px\">\n", | |
" <path d=\"M0 0h24v24H0V0z\" fill=\"none\"/>\n", | |
" <path d=\"M18.56 5.44l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94zm-11 1L8.5 8.5l.94-2.06 2.06-.94-2.06-.94L8.5 2.5l-.94 2.06-2.06.94zm10 10l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94z\"/><path d=\"M17.41 7.96l-1.37-1.37c-.4-.4-.92-.59-1.43-.59-.52 0-1.04.2-1.43.59L10.3 9.45l-7.72 7.72c-.78.78-.78 2.05 0 2.83L4 21.41c.39.39.9.59 1.41.59.51 0 1.02-.2 1.41-.59l7.78-7.78 2.81-2.81c.8-.78.8-2.07 0-2.86zM5.41 20L4 18.59l7.72-7.72 1.47 1.35L5.41 20z\"/>\n", | |
" </svg>\n", | |
" </button>\n", | |
" \n", | |
" <style>\n", | |
" .colab-df-container {\n", | |
" display:flex;\n", | |
" flex-wrap:wrap;\n", | |
" gap: 12px;\n", | |
" }\n", | |
"\n", | |
" .colab-df-convert {\n", | |
" background-color: #E8F0FE;\n", | |
" border: none;\n", | |
" border-radius: 50%;\n", | |
" cursor: pointer;\n", | |
" display: none;\n", | |
" fill: #1967D2;\n", | |
" height: 32px;\n", | |
" padding: 0 0 0 0;\n", | |
" width: 32px;\n", | |
" }\n", | |
"\n", | |
" .colab-df-convert:hover {\n", | |
" background-color: #E2EBFA;\n", | |
" box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n", | |
" fill: #174EA6;\n", | |
" }\n", | |
"\n", | |
" [theme=dark] .colab-df-convert {\n", | |
" background-color: #3B4455;\n", | |
" fill: #D2E3FC;\n", | |
" }\n", | |
"\n", | |
" [theme=dark] .colab-df-convert:hover {\n", | |
" background-color: #434B5C;\n", | |
" box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n", | |
" filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n", | |
" fill: #FFFFFF;\n", | |
" }\n", | |
" </style>\n", | |
"\n", | |
" <script>\n", | |
" const buttonEl =\n", | |
" document.querySelector('#df-d76e5518-c1e2-4ecc-872b-e7cf3ba929cc button.colab-df-convert');\n", | |
" buttonEl.style.display =\n", | |
" google.colab.kernel.accessAllowed ? 'block' : 'none';\n", | |
"\n", | |
" async function convertToInteractive(key) {\n", | |
" const element = document.querySelector('#df-d76e5518-c1e2-4ecc-872b-e7cf3ba929cc');\n", | |
" const dataTable =\n", | |
" await google.colab.kernel.invokeFunction('convertToInteractive',\n", | |
" [key], {});\n", | |
" if (!dataTable) return;\n", | |
"\n", | |
" const docLinkHtml = 'Like what you see? Visit the ' +\n", | |
" '<a target=\"_blank\" href=https://colab.research.google.com/notebooks/data_table.ipynb>data table notebook</a>'\n", | |
" + ' to learn more about interactive tables.';\n", | |
" element.innerHTML = '';\n", | |
" dataTable['output_type'] = 'display_data';\n", | |
" await google.colab.output.renderOutput(dataTable, element);\n", | |
" const docLink = document.createElement('div');\n", | |
" docLink.innerHTML = docLinkHtml;\n", | |
" element.appendChild(docLink);\n", | |
" }\n", | |
" </script>\n", | |
" </div>\n", | |
" </div>\n", | |
" " | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 4 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"DATA[\"class\"].hist(figsize=(14,10))" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 610 | |
}, | |
"id": "bj-DuuZlpGp1", | |
"outputId": "70bdab88-c017-4cb6-f704-6724eef745cc" | |
}, | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"<matplotlib.axes._subplots.AxesSubplot at 0x7f90b1d37d50>" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 5 | |
}, | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": [ | |
"<Figure size 1008x720 with 1 Axes>" | |
], | |
"image/png": "\n" | |
}, | |
"metadata": { | |
"needs_background": "light" | |
} | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# data preparation\n", | |
"X = DATA.loc[:, DATA.columns != 'class'].to_numpy()\n", | |
"y = DATA[\"class\"].to_numpy()\n", | |
"X, y" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "CTD6HwQUpOYi", | |
"outputId": "842a4e0c-0368-4314-f991-bd977e561732" | |
}, | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"(array([[15.26 , 14.84 , 0.871 , ..., 3.312 , 2.221 , 5.22 ],\n", | |
" [14.88 , 14.57 , 0.8811, ..., 3.333 , 1.018 , 4.956 ],\n", | |
" [14.29 , 14.09 , 0.905 , ..., 3.337 , 2.699 , 4.825 ],\n", | |
" ...,\n", | |
" [13.2 , 13.66 , 0.8883, ..., 3.232 , 8.315 , 5.056 ],\n", | |
" [11.84 , 13.21 , 0.8521, ..., 2.836 , 3.598 , 5.044 ],\n", | |
" [12.3 , 13.34 , 0.8684, ..., 2.974 , 5.637 , 5.063 ]]),\n", | |
" array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", | |
" 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", | |
" 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", | |
" 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,\n", | |
" 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,\n", | |
" 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,\n", | |
" 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,\n", | |
" 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,\n", | |
" 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,\n", | |
" 3], dtype=int32))" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 6 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# split data\n", | |
"X_train, X_test, y_train, y_test = train_test_split(\n", | |
" X, y, test_size=0.33, random_state=0, shuffle=True)\n", | |
"X_train.shape, X_test.shape" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "XvCQEKwJ0c42", | |
"outputId": "6a849e3c-cab0-4f94-c437-66ccf04e2790" | |
}, | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"((133, 7), (66, 7))" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 7 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"plt.hist(y_train)" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 318 | |
}, | |
"id": "llsHJnoz1U_f", | |
"outputId": "d42a22e2-780f-4a67-b36e-81afa809226e" | |
}, | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"(array([44., 0., 0., 0., 0., 49., 0., 0., 0., 40.]),\n", | |
" array([1. , 1.2, 1.4, 1.6, 1.8, 2. , 2.2, 2.4, 2.6, 2.8, 3. ]),\n", | |
" <a list of 10 Patch objects>)" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 8 | |
}, | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": [ | |
"<Figure size 432x288 with 1 Axes>" | |
], | |
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXAAAAD4CAYAAAD1jb0+AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAOTUlEQVR4nO3dbYwd1X3H8e8vtgkp0BjHG9fCEFMFtYKqPGhFSUBVAk1DIamphBAoiqzKkqWWSkSt2pK8SJW0L+BNk1ZqFVkB1akSHkRCQTRJsQhR2tKYLATCUyiOAy0WYIeHAG2VyujfF3dclvWud3b3PnCS70da3Zkz5+78fXz889yZO/emqpAktectky5AkrQ8BrgkNcoAl6RGGeCS1CgDXJIatXqcO1u/fn1t3rx5nLuUpObdd999P6qqqbntYw3wzZs3MzMzM85dSlLzkjw1X7unUCSpUb2OwJM8CbwCvAYcrKrpJOuAm4DNwJPAZVX14mjKlCTNtZQj8PdX1RlVNd2tXw3cVVWnAHd165KkMVnJKZQtwM5ueSdwycrLkST11TfAC7gzyX1JtndtG6rqmW75WWDDfE9Msj3JTJKZAwcOrLBcSdIhfd+Fcl5V7UvyTmBXku/P3lhVlWTeT8Wqqh3ADoDp6Wk/OUuShqTXEXhV7ese9wO3AmcDzyXZCNA97h9VkZKkwy0a4EmOSXLcoWXgN4GHgduBrV23rcBtoypSknS4PqdQNgC3JjnU/0tV9fUk3wFuTrINeAq4bHRlSpLmWjTAq2ovcPo87c8DF4yiKGncNl/9jxPb95PXXDyxfatt3okpSY0ywCWpUQa4JDXKAJekRhngktQoA1ySGmWAS1KjDHBJapQBLkmNMsAlqVEGuCQ1ygCXpEYZ4JLUKANckhplgEtSowxwSWqUAS5JjTLAJalRBrgkNcoAl6RG9flW+jeFSX3prF84K+nNyiNwSWqUAS5JjTLAJalRBrgkNcoAl6RGGeCS1CgDXJIaZYBLUqMMcElqlAEuSY0ywCWpUQa4JDXKAJekRvUO8CSrknw3yR3d+slJdifZk+SmJEeNrkxJ0lxLOQK/Cnhs1vq1wGeq6t3Ai8C2YRYmSTqyXgGeZBNwMfD5bj3A+cAtXZedwCWjKFCSNL++X+jwWeBPgOO69XcAL1XVwW79aeCE+Z6YZDuwHeCkk05afqWStEI/bV8Ms+gReJIPAfur6r7l7KCqdlTVdFVNT01NLedXSJLm0ecI/Fzgt5NcBBwN/DzwV8DaJKu7o/BNwL7RlSlJmmvRI/Cq+nhVbaqqzcDlwDeq6iPA3cClXbetwG0jq1KSdJiVvA/8T4E/TLKHwTnx64ZTkiSpjyV9K31VfRP4Zre8Fzh7+CVJkvrwTkxJapQBLkmNMsAlqVEGuCQ1ygCXpEYZ4JLUKANckhplgEtSowxwSWqUAS5JjTLAJalRBrgkNcoAl6RGGeCS1CgDXJIaZYBLUqMMcElqlAEuSY0ywCWpUQa4JDXKAJekRhngktQoA1ySGmWAS1KjDHBJapQBLkmNMsAlqVEGuCQ1ygCXpEYZ4JLUKANckhplgEtSowxwSWrUogGe5Ogk9yZ5MMkjST7VtZ+cZHeSPUluSnLU6MuVJB3S5wj8J8D5VXU6cAZwYZJzgGuBz1TVu4EXgW2jK1OSNNeiAV4Dr3ara7qfAs4HbunadwKXjKRCSdK8ep0DT7IqyQPAfmAX8APgpao62HV5GjhhgeduTzKTZObAgQPDqFmSRM8Ar6rXquoMYBNwNvDLfXdQVTuqarqqpqemppZZpiRpriW9C6WqXgLuBt4DrE2yutu0Cdg35NokSUfQ510oU0nWdstvAz4APMYgyC/tum0FbhtVkZKkw61evAsbgZ1JVjEI/Jur6o4kjwI3JvkL4LvAdSOsU5I0x6IBXlXfA86cp30vg/PhkqQJ8E5MSWqUAS5JjTLAJalRBr
gkNcoAl6RGGeCS1CgDXJIaZYBLUqMMcElqlAEuSY0ywCWpUQa4JDXKAJekRhngktQoA1ySGmWAS1KjDHBJapQBLkmNMsAlqVEGuCQ1ygCXpEYZ4JLUKANckhplgEtSowxwSWqUAS5JjTLAJalRBrgkNcoAl6RGGeCS1CgDXJIaZYBLUqMMcElq1KIBnuTEJHcneTTJI0mu6trXJdmV5Inu8fjRlytJOqTPEfhB4I+q6lTgHODKJKcCVwN3VdUpwF3duiRpTBYN8Kp6pqru75ZfAR4DTgC2ADu7bjuBS0ZVpCTpcEs6B55kM3AmsBvYUFXPdJueBTYs8JztSWaSzBw4cGAFpUqSZusd4EmOBb4MfKyqXp69raoKqPmeV1U7qmq6qqanpqZWVKwk6XW9AjzJGgbh/cWq+krX/FySjd32jcD+0ZQoSZpPn3ehBLgOeKyq/nLWptuBrd3yVuC24ZcnSVrI6h59zgU+CjyU5IGu7RPANcDNSbYBTwGXjaZESdJ8Fg3wqvoXIAtsvmC45UiS+vJOTElqlAEuSY0ywCWpUQa4JDXKAJekRhngktQoA1ySGmWAS1KjDHBJapQBLkmNMsAlqVEGuCQ1ygCXpEYZ4JLUKANckhplgEtSowxwSWqUAS5JjTLAJalRBrgkNcoAl6RGGeCS1CgDXJIaZYBLUqMMcElqlAEuSY0ywCWpUQa4JDXKAJekRhngktQoA1ySGmWAS1KjDHBJatSiAZ7k+iT7kzw8q21dkl1Jnugejx9tmZKkufocgf8dcOGctquBu6rqFOCubl2SNEaLBnhVfQt4YU7zFmBnt7wTuGTIdUmSFrHcc+AbquqZbvlZYMOQ6pEk9bTii5hVVUAttD3J9iQzSWYOHDiw0t1JkjrLDfDnkmwE6B73L9SxqnZU1XRVTU9NTS1zd5KkuZYb4LcDW7vlrcBtwylHktRXn7cR3gD8G/BLSZ5Osg24BvhAkieA3+jWJUljtHqxDlV1xQKbLhhyLZKkJfBOTElqlAEuSY0ywCWpUQa4JDXKAJekRhngktQoA1ySGmWAS1KjDHBJapQBLkmNMsAlqVEGuCQ1ygCXpEYZ4JLUKANckhplgEtSowxwSWqUAS5JjTLAJalRBrgkNcoAl6RGGeCS1CgDXJIaZYBLUqMMcElqlAEuSY0ywCWpUQa4JDXKAJekRhngktQoA1ySGmWAS1KjDHBJapQBLkmNWlGAJ7kwyeNJ9iS5elhFSZIWt+wAT7IK+Bvgt4BTgSuSnDqswiRJR7aSI/CzgT1Vtbeq/he4EdgynLIkSYtZvYLnngD856z1p4Ffm9spyXZge7f6apLHl7m/9cCPlvncZcu1i3aZSF09WNfSTKyuReaY47U0b8q6cu2K63rXfI0rCfBeqmoHsGOlvyfJTFVND6GkobKupbGupbGupflZq2slp1D2ASfOWt/UtUmSxmAlAf4d4JQkJyc5CrgcuH04ZUmSFrPsUyhVdTDJHwD/BKwCrq+qR4ZW2eFWfBpmRKxraaxraaxraX6m6kpVjeL3SpJGzDsxJalRBrgkNWriAZ7k+iT7kzy8wPYk+evudv3vJTlr1ratSZ7ofraOua6PdPU8lOSeJKfP2vZk1/5Akpkx1/W+JD/u9v1Akk/O2jayjz7oUdcfz6rp4SSvJVnXbRvleJ2Y5O4kjyZ5JMlV8/QZ+xzrWdfY51jPusY+x3rWNfY5luToJPcmebCr61Pz9Hlrkpu6MdmdZPOsbR/v2h9P8sElF1BVE/0Bfh04C3h4ge0XAV8DApwD7O7a1wF7u8fju+Xjx1jXew/tj8HHCeyete1JYP2Exut9wB3ztK8CfgD8InAU8CBw6rjqmtP3w8A3xjReG4GzuuXjgH+f++eexBzrWdfY51jPusY+x/rUNYk51s2ZY7vlNcBu4Jw5fX4f+Fy3fDlwU7d8ajdGbwVO7sZu1VL2P/Ej8Kr6FvDCEbpsAb5QA98G1ibZCHwQ2FVVL1TVi8
Au4MJx1VVV93T7Bfg2g/fBj1yP8VrISD/6YIl1XQHcMKx9H0lVPVNV93fLrwCPMbiLeLaxz7E+dU1ijvUcr4WMbI4to66xzLFuzrzara7pfua+M2QLsLNbvgW4IEm69hur6idV9UNgD4Mx7G3iAd7DfLfsn3CE9knYxuAI7pAC7kxyXwYfJTBu7+le0n0tyWld25tivJL8HIMQ/PKs5rGMV/fS9UwGR0mzTXSOHaGu2cY+xxapa2JzbLHxGvccS7IqyQPAfgb/4S84v6rqIPBj4B0MYbxGfiv9T7sk72fwj+u8Wc3nVdW+JO8EdiX5fneEOg73A++qqleTXAT8A3DKmPbdx4eBf62q2UfrIx+vJMcy+Af9sap6eZi/eyX61DWJObZIXRObYz3/Hsc6x6rqNeCMJGuBW5P8SlXNey1o2Fo4Al/olv2J38qf5FeBzwNbqur5Q+1Vta973A/cyhJfFq1EVb186CVdVX0VWJNkPW+C8epczpyXtqMeryRrGPyj/2JVfWWeLhOZYz3qmsgcW6yuSc2xPuPVGfsc6373S8DdHH6a7f/HJclq4O3A8wxjvIZ9Un85P8BmFr4odzFvvMB0b9e+Dvghg4tLx3fL68ZY10kMzlm9d077McBxs5bvAS4cY12/wOs3aJ0N/Ec3dqsZXIQ7mdcvMJ02rrq67W9ncJ78mHGNV/dn/wLw2SP0Gfsc61nX2OdYz7rGPsf61DWJOQZMAWu75bcB/wx8aE6fK3njRcybu+XTeONFzL0s8SLmxE+hJLmBwVXt9UmeBv6MwYUAqupzwFcZvEtgD/DfwO92215I8ucMPpMF4NP1xpdMo67rkwzOY/3t4HoEB2vwaWMbGLyMgsGE/lJVfX2MdV0K/F6Sg8D/AJfXYLaM9KMPetQF8DvAnVX1X7OeOtLxAs4FPgo81J2nBPgEg3Cc5BzrU9ck5lifuiYxx/rUBeOfYxuBnRl8wc1bGITzHUk+DcxU1e3AdcDfJ9nD4D+Xy7uaH0lyM/AocBC4sganY3rzVnpJalQL58AlSfMwwCWpUQa4JDXKAJekRhngktQoA1ySGmWAS1Kj/g/YevG17WakAgAAAABJRU5ErkJggg==\n" | |
}, | |
"metadata": { | |
"needs_background": "light" | |
} | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"plt.hist(y_test)" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 318 | |
}, | |
"id": "X6qMKxiy1ZO4", | |
"outputId": "46a1f5e8-cc14-4f19-e6c7-74309f71cf88" | |
}, | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"(array([22., 0., 0., 0., 0., 19., 0., 0., 0., 25.]),\n", | |
" array([1. , 1.2, 1.4, 1.6, 1.8, 2. , 2.2, 2.4, 2.6, 2.8, 3. ]),\n", | |
" <a list of 10 Patch objects>)" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 9 | |
}, | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": [ | |
"<Figure size 432x288 with 1 Axes>" | |
], | |
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXAAAAD4CAYAAAD1jb0+AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAN0ElEQVR4nO3df6zd9V3H8edrtJsKZLT2WhuEXVyISTGukBvEQRYW1DFw6UiMgRhsDEkXhQSSxaTyxzbnPyxxmzHRLZ2QdYaxEQeODDZpkAQnrvOWdKOACLJOaTpaxhygRlP29o/zrRwu9/ace88vPsvzkZyc7/l8v+d83/30c1/3ez7f8z03VYUkqT1vmnUBkqS1McAlqVEGuCQ1ygCXpEYZ4JLUqHXT3NmmTZtqfn5+mruUpObt37//+aqaW9o+1QCfn59ncXFxmruUpOYl+e5y7U6hSFKjDHBJapQBLkmNMsAlqVEGuCQ1ygCXpEYNDPAkZyV5MMnjSR5LcmPX/pEkh5Mc6G5XTL5cSdIJw3wO/Djwwap6JMnpwP4ke7t1n6yqP5lceZKklQwM8Ko6Ahzpll9K8gRw5qQLkySd3KquxEwyD5wP7AMuBm5I8jvAIr2j9B8s85ydwE6As88+e8RyJWnt5nfdO7N9H7rlyrG/5tAnMZOcBnwJuKmqXgQ+Bbwd2EbvCP3jyz2vqnZX1UJVLczNve5SfknSGg0V4EnW0wvv26vqLoCqeq6qXqmqHwGfAS6cXJmSpKWG+RRKgFuBJ6rqE33tW/o2uwo4OP7yJEkrGWYO/GLgWuDRJAe6tpuBa5JsAwo4BHxgIhVKkpY1zKdQvg5kmVX3jb8cSdKwvBJTkhplgEtSowxwSWqUAS5JjTLAJalRBrgkNcoAl6RGGeCS1CgDXJIaZYBLUqMMcElqlAEuSY0ywCWpUQa4JDXKAJekRhngktQoA1ySGmWAS1KjDHBJatQwf9T4DWF+170z2/ehW66c2b4laSUegUtSowxwSWqUAS5JjTLAJalRBrgkNcoAl6RGGeCS1CgDXJIaZYBLUqMMcElqlAEuSY0ywCWpUQMDPMlZSR5M8niSx5Lc2LVvTLI3yVPd/YbJlytJOmGYI/DjwAeraitwEXB9kq3ALuCBqjoXeKB7LEmakoEBXlVHquqRbvkl4AngTGA7sKfbbA/w/kkVKUl6vVXNgSeZB84H9gGbq+pIt+p7wOYVnrMzyWKSxWPHjo1QqiSp39ABnuQ04EvATVX1Yv+6qiqglnteVe2uqoWqWpibmxupWEnSq4YK8CTr6YX37VV1V9f8XJIt3fotwNHJlChJWs4wn0IJcCvwRFV9om/VPcCObnkH8OXxlydJWskwfxPzYuBa4NEkB7q2m4FbgDuTXAd8F/ityZQoSVrOwACvqq8DWWH1ZeMtR5I0LK/ElKRGGeCS1Khh5sClH3vzu+6d2b4P3XLlzPattnkELkmNMsAlqVEGuCQ1ygCXpEYZ4JLUKANckhplgEtSowxwSWqUAS5JjTLAJalRBrgkNcoAl6RGGeCS1CgDXJIaZYBLUqMMcElqlAEuSY0ywCWpUQa4JDXKAJekRhngktQoA1ySGmWAS1KjDHBJapQBLkmNMsAlqVEGuCQ1ygCXpEYZ4JLUqIEBnuS2JEeTHOxr+0iSw0kOdLcrJlumJGmpYY7APwtcvkz7J6tqW3e7b7xlSZIGGRjgVfUQ8MIUapEkrcIoc+A3JPl2N8WyYWwVSZKGstYA/xTwdmAbcAT4+EobJtmZZDHJ4rFjx9a4O0nSUmsK8Kp6rqpeqaofAZ8BLjzJtruraqGqFubm5tZapyRpiTUFeJItfQ+vAg6utK0kaTLWDdogyR3ApcCmJM8CHwYuTbINKOAQ8IEJ1ihJWsbAAK+qa5ZpvnUCtUiSVsErMSWpUQa4JDXKAJekRhngktQoA1ySGmWAS1KjDHBJapQBLkmNMs
AlqVEGuCQ1ygCXpEYZ4JLUKANckhplgEtSowxwSWqUAS5JjTLAJalRBrgkNcoAl6RGGeCS1CgDXJIaZYBLUqMMcElqlAEuSY0ywCWpUQa4JDXKAJekRhngktQoA1ySGmWAS1KjDHBJapQBLkmNMsAlqVEDAzzJbUmOJjnY17Yxyd4kT3X3GyZbpiRpqWGOwD8LXL6kbRfwQFWdCzzQPZYkTdHAAK+qh4AXljRvB/Z0y3uA94+5LknSAGudA99cVUe65e8Bm1faMMnOJItJFo8dO7bG3UmSlhr5JGZVFVAnWb+7qhaqamFubm7U3UmSOmsN8OeSbAHo7o+OryRJ0jDWGuD3ADu65R3Al8dTjiRpWMN8jPAO4B+BX0jybJLrgFuAX0vyFPCr3WNJ0hStG7RBVV2zwqrLxlyLJGkVvBJTkhplgEtSowxwSWqUAS5JjTLAJalRBrgkNcoAl6RGGeCS1CgDXJIaZYBLUqMMcElqlAEuSY0ywCWpUQa4JDXKAJekRhngktQoA1ySGmWAS1KjDHBJapQBLkmNMsAlqVEGuCQ1ygCXpEYZ4JLUKANckhplgEtSowxwSWqUAS5JjTLAJalRBrgkNcoAl6RGGeCS1Kh1ozw5ySHgJeAV4HhVLYyjKEnSYCMFeOfdVfX8GF5HkrQKTqFIUqNGDfAC7k+yP8nO5TZIsjPJYpLFY8eOjbg7SdIJowb4JVV1AfBe4Pok71q6QVXtrqqFqlqYm5sbcXeSpBNGCvCqOtzdHwXuBi4cR1GSpMHWHOBJTk1y+oll4NeBg+MqTJJ0cqN8CmUzcHeSE6/z+ar62liqkiQNtOYAr6pngHeMsRZJ0ir4MUJJapQBLkmNMsAlqVEGuCQ1ygCXpEYZ4JLUKANckhplgEtSowxwSWqUAS5JjTLAJalRBrgkNcoAl6RGGeCS1CgDXJIaZYBLUqMMcElqlAEuSY0ywCWpUQa4JDXKAJekRhngktQoA1ySGmWAS1KjDHBJapQBLkmNMsAlqVEGuCQ1ygCXpEYZ4JLUKANckhplgEtSowxwSWrUSAGe5PIkTyZ5OsmucRUlSRpszQGe5BTgz4H3AluBa5JsHVdhkqSTG+UI/ELg6ap6pqr+F/gCsH08ZUmSBlk3wnPPBP697/GzwC8v3SjJTmBn9/DlJE+ucX+bgOfX+NyR5GMnXT2zugawrtVxfK2Oda1SPjZSbW9brnGUAB9KVe0Gdo/6OkkWq2phDCWNlXWtjnWtjnWtzhu1LphMbaNMoRwGzup7/HNdmyRpCkYJ8H8Czk1yTpI3A1cD94ynLEnSIGueQqmq40luAP4WOAW4raoeG1tlrzfyNMyEWNfqWNfqWNfqvFHrggnUlqoa92tKkqbAKzElqVEGuCQ1auYBnuS2JEeTHFxhfZL8WXe5/reTXNC3bkeSp7rbjinX9dtdPY8meTjJO/rWHeraDyRZnHJdlyb5YbfvA0k+1LduYl99MERdf9BX08EkryTZ2K2bZH+dleTBJI8neSzJjctsM/UxNmRdUx9jQ9Y19TE2ZF1TH2NJfiLJN5N8q6vrj5bZ5i1Jvtj1yb4k833r/rBrfzLJe1ZdQFXN9Aa8C7gAOLjC+iuArwIBLgL2de0bgWe6+w3d8oYp1vXOE/uj93UC+/rWHQI2zai/LgW+skz7KcC/Aj8PvBn4FrB1WnUt2fZ9wN9Nqb+2ABd0y6cD/7L03z2LMTZkXVMfY0PWNfUxNkxdsxhj3Zg5rVteD+wDLlqyze8Dn+6Wrwa+2C1v7froLcA5Xd+dspr9z/wIvKoeAl44ySbbgc9VzzeAM5JsAd4D7K2qF6rqB8Be4PJp1VVVD3f7BfgGvc/BT9wQ/bWSiX71wSrruga4Y1z7PpmqOlJVj3TLLwFP0LuKuN/Ux9gwdc1ijA3ZXyuZ2BhbQ11TGWPdmHm5e7i+uy39ZMh2YE+3/NfAZUnStX+hqv6nqr4DPE2vD4c28wAfwnKX7J95kvZZuI7eEd
wJBdyfZH96XyUwbb/SvaX7apLzurY3RH8l+Sl6Ifilvuap9Ff31vV8ekdJ/WY6xk5SV7+pj7EBdc1sjA3qr2mPsSSnJDkAHKX3C3/F8VVVx4EfAj/NGPpr4pfS/7hL8m56P1yX9DVfUlWHk/wMsDfJP3dHqNPwCPC2qno5yRXA3wDnTmnfw3gf8A9V1X+0PvH+SnIavR/om6rqxXG+9iiGqWsWY2xAXTMbY0P+P051jFXVK8C2JGcAdyf5xapa9lzQuLVwBL7SJfszv5Q/yS8Bfwlsr6rvn2ivqsPd/VHgblb5tmgUVfXiibd0VXUfsD7JJt4A/dW5miVvbSfdX0nW0/uhv72q7lpmk5mMsSHqmskYG1TXrMbYMP3VmfoY6177P4AHef002//3S5J1wFuB7zOO/hr3pP5absA8K5+Uu5LXnmD6Zte+EfgOvZNLG7rljVOs62x6c1bvXNJ+KnB63/LDwOVTrOtnefUCrQuBf+v6bh29k3Dn8OoJpvOmVVe3/q305slPnVZ/df/2zwF/epJtpj7Ghqxr6mNsyLqmPsaGqWsWYwyYA87oln8S+HvgN5Zscz2vPYl5Z7d8Hq89ifkMqzyJOfMplCR30DurvSnJs8CH6Z0IoKo+DdxH71MCTwP/Bfxut+6FJH9M7ztZAD5ar33LNOm6PkRvHusveucjOF69bxrbTO9tFPQG9Oer6mtTrOs3gd9Lchz4b+Dq6o2WiX71wRB1AVwF3F9V/9n31In2F3AxcC3waDdPCXAzvXCc5Rgbpq5ZjLFh6prFGBumLpj+GNsC7EnvD9y8iV44fyXJR4HFqroHuBX4qyRP0/vlcnVX82NJ7gQeB44D11dvOmZoXkovSY1qYQ5ckrQMA1ySGmWAS1KjDHBJapQBLkmNMsAlqVEGuCQ16v8AuLTKyeyYAKYAAAAASUVORK5CYII=\n" | |
}, | |
"metadata": { | |
"needs_background": "light" | |
} | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"# Логистическая регрессия для трех и более классов\n", | |
"\n", | |
"Пусть нам даны пары наблюдений $(Y_1,X_1),...,(Y_n,X_n)$, где $Y_i$ представленна в виде $k$ категорий/классов $\\{1,...,K\\}$, а $X \\in R^d$.\n", | |
"\n", | |
"## Вариант 1 - обобщённая бинарная модель (multinomial logit)\n", | |
"\n", | |
"Модель логистической регрессия применяется для моделирования вероятностей возникновения $K$ взаимоисключающих классов через линейную завистмость от $X$.\n", | |
"\n", | |
"Обобщённый вартант логистической регрессии для $K>2$ классов это мультиномиальная логистическая регрессия.\n", | |
"\n", | |
"$\\mathrm{log}\\frac{\\mathbb{P}(Y=1|X)}{\\mathbb{P}(Y=K|X)}=\\beta_{10}+\\beta_1^TX$\n", | |
"\n", | |
"$\\mathrm{log}\\frac{\\mathbb{P}(Y=2|X)}{\\mathbb{P}(Y=K|X)}=\\beta_{20}+\\beta_2^TX$\n", | |
"\n", | |
"...\n", | |
"\n", | |
"$\\mathrm{log}\\frac{\\mathbb{P}(Y=K-1|X)}{\\mathbb{P}(Y=K|X)}=\\beta_{(K-1)0}+\\beta_{K-1}^TX$\n", | |
"\n", | |
"То есть, модель задаёт $K-1$ уравнений log отношения шансов к $K$ классу в условиях суммы вероятностей возникновения $K$ классов равной 1, то есть:\n", | |
"\n", | |
"$\\mathbb{P}(Y=k|X)=\\frac{\\exp{\\beta_{k0}+\\beta_k^TX}}{1+\\sum_{i=1}^{K-1}\\exp{\\beta_{i0}+\\beta_i^TX}}$, где $k=1,...,K-1$\n", | |
"\n", | |
"$\\mathbb{P}(Y=K|X)=\\frac{1}{1+\\sum_{i=1}^{K-1}\\exp{\\beta_{i0}+\\beta_i^TX}}$.\n", | |
"\n", | |
"Лог-функция максимального правдоподобия **(log-likelihood)** имеет вид:\n", | |
"\n", | |
"$$\\ell(\\theta) = \\Sigma_{i=1}^N \\log p_{gi}(X_i,\\theta)$$\n", | |
"\n", | |
"где $p_k(X_i;\\theta)=\\mathbb{P}(Y=k|X;\\theta)$\n", | |
"\n", | |
"Решения в общем виде не существует, значения $\\theta$ минимизирующие $\\ell(\\theta)$ могут быть получены методами численной оптимизации.\n", | |
"\n", | |
"__Мультиномиальная логистическая регрессия предполагает__ независимость нерелеватных вариантов [https://en.wikipedia.org/wiki/Independence_of_irrelevant_alternatives](https://en.wikipedia.org/wiki/Independence_of_irrelevant_alternatives)\n", | |
"\n", | |
"## Вариант 2 - One vs Rest (OVR)\n", | |
"\n", | |
"OVR это обобщённый метод раширения любого бинарного классификатора для $K>3$ классов. По сути, для каждого отдельного класса, задача сводится к бинарной, то есть для кажого $k$ из $1,...,K$ конструируется новая зависимая переменная $z$, такая что:\n", | |
"\n", | |
"$z_i=1$, если $Y_i=k$\n", | |
"\n", | |
"$z_i=0$, если $Y_i \\neq k$\n", | |
"\n", | |
"Далее, для каждого класса подбираются оптимальные коэффициенты на основе бинарного классификатора (в нашем случае это логистическая регрессия).\n", | |
"\n", | |
"Формируется функция $f_k$ отражающая набор $K$ класисфикаторов (например, $K$ бинарных логистических регрессий). \n", | |
"\n", | |
"Тогда, пронозное значение может быть получено: $\\hat{y}=\\mathrm{argmax}_{k \\in 1...K}f_k(X)$\n", | |
"\n", | |
"## Вартант 3 - One Vs One (OVO)\n", | |
"\n", | |
"В этом подходе, задача сводится к бинарной путём формирования группу классов \"каждый с каждым\", таким образом, получается $\\frac{K(K-1)}{2}$ классификаторов. На этапе вывода, аналогично OVR, прогнозное значение присваивается класисифкатору с наибольшей вероятностью.\n", | |
"\n", | |
"\n" | |
], | |
"metadata": { | |
"id": "F9J2hxvtMlqK" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# logistic regression\n", | |
"# https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.LogisticRegression.html\n", | |
"clf = LogisticRegression(penalty='none', multi_class='ovr', solver='newton-cg')\n", | |
"clf.fit(X_train,y_train)" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "iok-Q9oTsnKP", | |
"outputId": "add021b7-4157-43e4-ff4a-4f0c16bbfc89" | |
}, | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"LogisticRegression(multi_class='ovr', penalty='none', solver='newton-cg')" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 10 | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"# Оценка производительности модели" | |
], | |
"metadata": { | |
"id": "Vca6QMOhknMn" | |
} | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"## Бинарный случай" | |
], | |
"metadata": { | |
"id": "AVjffmCclo86" | |
} | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"\n", | |
"\n", | |
"Fawcett, Tom. \"An introduction to ROC analysis.\" Pattern recognition letters 27.8 (2006): 861-874.\n" | |
], | |
"metadata": { | |
"id": "QOokyUwk0jU7" | |
} | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"## K>3" | |
], | |
"metadata": { | |
"id": "a_Oeji3HlvZR" | |
} | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"$\\mathrm{precision}_k=\\frac{TP_k}{TP_k+\\sum_{i=1}^K FP_{k \\rightarrow i}}$\n", | |
"\n", | |
"$\\mathrm{recall}_k=\\frac{TP_k}{TP_k+\\sum_{i=1}^K FN_{k \\rightarrow i}}$\n", | |
"\n", | |
"$\\mathrm{tpr}_k=\\frac{TP_K}{TP_k+\\sum_{i=1}^K FN_{k \\rightarrow i}}$\n", | |
"\n", | |
"$\\mathrm{fpr}_k=\\frac{FP_K}{|Y \\neq k|}$\n", | |
"\n", | |
"$\\mathrm{macro\\ precision}=\\frac{\\sum_{i=1}^K\\mathrm{precision}_K}{K}$\n", | |
"\n", | |
"$\\mathrm{micro\\ precision}= \\frac{\\sum_{i=1}^K TP_i}{\\sum_{i=1}^K TP_i+\\sum_{i=1}^K\\sum FP_i}$\n", | |
"\n", | |
"\n", | |
"\n" | |
], | |
"metadata": { | |
"id": "fXJLbaPvlzeO" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Model performance metrics\n", | |
"y_true = y_test\n", | |
"y_pred = clf.predict(X_test)\n", | |
"# precision, recall, f1, overall accuracy\n", | |
"print(classification_report(y_true,y_pred))\n", | |
"\n", | |
"y_true, y_pred" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "tY6FIRbZxcYX", | |
"outputId": "9cf19d6d-a8e6-4092-ac46-f3cbed66f453" | |
}, | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
" precision recall f1-score support\n", | |
"\n", | |
" 1 1.00 0.91 0.95 22\n", | |
" 2 1.00 0.95 0.97 19\n", | |
" 3 0.89 1.00 0.94 25\n", | |
"\n", | |
" accuracy 0.95 66\n", | |
" macro avg 0.96 0.95 0.96 66\n", | |
"weighted avg 0.96 0.95 0.95 66\n", | |
"\n" | |
] | |
}, | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"(array([1, 3, 2, 2, 3, 3, 1, 3, 1, 3, 1, 2, 3, 3, 2, 1, 1, 2, 1, 2, 3, 3,\n", | |
" 1, 3, 1, 2, 3, 3, 2, 1, 1, 1, 2, 2, 1, 3, 3, 1, 3, 3, 3, 2, 1, 2,\n", | |
" 3, 1, 1, 2, 2, 3, 2, 3, 2, 2, 3, 1, 1, 2, 1, 1, 2, 3, 3, 1, 3, 3],\n", | |
" dtype=int32),\n", | |
" array([3, 3, 2, 2, 3, 3, 1, 3, 1, 3, 1, 2, 3, 3, 2, 1, 1, 3, 1, 2, 3, 3,\n", | |
" 1, 3, 1, 2, 3, 3, 2, 1, 1, 1, 2, 2, 3, 3, 3, 1, 3, 3, 3, 2, 1, 2,\n", | |
" 3, 1, 1, 2, 2, 3, 2, 3, 2, 2, 3, 1, 1, 2, 1, 1, 2, 3, 3, 1, 3, 3],\n", | |
" dtype=int32))" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 11 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# confusion matrix for the best threshold\n", | |
"confusion_matrix(y_pred, y_true)" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "viMOSJh-zQdv", | |
"outputId": "fe332d4c-3199-438e-a598-e779ee4ab383" | |
}, | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"array([[20, 0, 0],\n", | |
" [ 0, 18, 0],\n", | |
" [ 2, 1, 25]])" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 12 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"precision_score(y_true, y_pred, average='micro')" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "eYrVQrBAqbvn", | |
"outputId": "7712e062-87a6-4057-e0ee-ca799a1de7a2" | |
}, | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"0.9545454545454546" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 13 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"precision_score(y_true, y_pred, average='macro')" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "Bx0O4BIIqkSF", | |
"outputId": "47643920-fdf1-4107-8f20-2bcca6cd89f3" | |
}, | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"0.9642857142857143" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 14 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"y_score = clf.predict_proba(X_test)\n", | |
"y_score" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "LbSG6zvBtDek", | |
"outputId": "b08d80c4-2607-4a76-8cb7-d0a31abba47e" | |
}, | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"array([[2.28972552e-001, 8.39075373e-020, 7.71027448e-001],\n", | |
" [3.43475032e-007, 1.69079145e-014, 9.99999657e-001],\n", | |
" [2.75974676e-005, 9.99972403e-001, 3.51162994e-169],\n", | |
" [1.58112169e-009, 9.99999998e-001, 1.03042150e-129],\n", | |
" [4.91270080e-006, 7.71746322e-019, 9.99995087e-001],\n", | |
" [1.72650019e-006, 1.41571111e-013, 9.99998273e-001],\n", | |
" [1.00000000e+000, 5.75086375e-021, 7.95483592e-084],\n", | |
" [1.64773367e-014, 4.52900586e-008, 9.99999955e-001],\n", | |
" [1.00000000e+000, 1.77712925e-031, 2.62594104e-098],\n", | |
" [1.19169001e-008, 7.30929275e-006, 9.99992679e-001],\n", | |
" [1.00000000e+000, 2.24123006e-034, 2.90202093e-089],\n", | |
" [3.74722376e-009, 9.99999996e-001, 5.19403888e-055],\n", | |
" [1.36196231e-011, 6.55523544e-009, 9.99999993e-001],\n", | |
" [1.76699473e-010, 2.06544627e-007, 9.99999793e-001],\n", | |
" [2.16016691e-007, 9.99999784e-001, 1.22329341e-168],\n", | |
" [7.36151405e-001, 2.63848595e-001, 9.03610140e-093],\n", | |
" [1.00000000e+000, 4.37951222e-017, 5.83057434e-078],\n", | |
" [4.50412842e-010, 5.00000000e-001, 5.00000000e-001],\n", | |
" [9.22469156e-001, 1.72575458e-009, 7.75308420e-002],\n", | |
" [5.46144888e-009, 9.99999995e-001, 1.12391788e-131],\n", | |
" [3.17234026e-003, 2.85212171e-015, 9.96827660e-001],\n", | |
" [4.73558633e-011, 3.76867641e-006, 9.99996231e-001],\n", | |
" [1.00000000e+000, 2.40780612e-016, 4.54270534e-069],\n", | |
" [7.40021255e-014, 2.00615585e-005, 9.99979938e-001],\n", | |
" [9.95507827e-001, 2.13296334e-017, 4.49217338e-003],\n", | |
" [5.76144721e-006, 9.99994239e-001, 6.87045991e-128],\n", | |
" [5.36476396e-008, 4.29945746e-019, 9.99999946e-001],\n", | |
" [1.79234723e-010, 5.13039854e-017, 1.00000000e+000],\n", | |
" [1.92081625e-006, 9.99998079e-001, 2.68782561e-128],\n", | |
" [1.00000000e+000, 4.73836907e-027, 1.18295040e-146],\n", | |
" [1.00000000e+000, 7.04778950e-019, 4.45722580e-137],\n", | |
" [1.00000000e+000, 1.26167039e-014, 2.45255932e-084],\n", | |
" [2.89390984e-009, 9.99999997e-001, 1.01816448e-099],\n", | |
" [4.93370281e-002, 9.50662972e-001, 6.69597067e-209],\n", | |
" [9.32990414e-002, 3.40566906e-020, 9.06700959e-001],\n", | |
" [1.02520226e-001, 1.16442761e-011, 8.97479774e-001],\n", | |
" [1.98863317e-008, 7.87527503e-003, 9.92124705e-001],\n", | |
" [1.00000000e+000, 2.65228151e-018, 5.74834249e-147],\n", | |
" [7.25070885e-010, 3.57432611e-006, 9.99996425e-001],\n", | |
" [2.90744068e-004, 5.46738088e-016, 9.99709256e-001],\n", | |
" [9.49038371e-012, 3.10919628e-007, 9.99999689e-001],\n", | |
" [1.67326130e-006, 9.99998327e-001, 1.29953842e-070],\n", | |
" [1.00000000e+000, 7.57301928e-034, 9.52198152e-075],\n", | |
" [4.33637871e-004, 9.99566362e-001, 1.42245387e-037],\n", | |
" [9.70380910e-009, 4.35432677e-008, 9.99999947e-001],\n", | |
" [1.00000000e+000, 1.18748137e-027, 5.24436451e-120],\n", | |
" [9.99999999e-001, 6.80830071e-010, 2.67382696e-127],\n", | |
" [3.79666874e-005, 9.99962033e-001, 3.95139006e-093],\n", | |
" [1.14489614e-011, 1.00000000e+000, 5.67720040e-131],\n", | |
" [1.09683226e-013, 2.28316978e-012, 1.00000000e+000],\n", | |
" [4.21498605e-007, 9.99999579e-001, 8.37169587e-041],\n", | |
" [3.21654366e-009, 1.32685082e-003, 9.98673146e-001],\n", | |
" [6.83514839e-007, 9.99999316e-001, 1.18329843e-224],\n", | |
" [6.25196037e-007, 9.99999375e-001, 7.81430666e-026],\n", | |
" [2.12543529e-012, 6.67308916e-010, 9.99999999e-001],\n", | |
" [1.00000000e+000, 9.72878131e-020, 5.69707414e-196],\n", | |
" [1.00000000e+000, 3.22446057e-018, 4.20152706e-096],\n", | |
" [7.13452624e-004, 9.99286547e-001, 1.35074386e-185],\n", | |
" [1.00000000e+000, 2.57762112e-024, 5.52655892e-043],\n", | |
" [1.00000000e+000, 3.52168085e-019, 6.71078451e-099],\n", | |
" [7.00190201e-004, 9.99299810e-001, 3.14739566e-083],\n", | |
" [2.99078090e-007, 7.04513952e-018, 9.99999701e-001],\n", | |
" [2.80203168e-009, 2.18169543e-020, 9.99999997e-001],\n", | |
" [1.00000000e+000, 1.47261853e-023, 7.44901568e-077],\n", | |
" [8.67660479e-008, 1.85651340e-014, 9.99999913e-001],\n", | |
" [1.96150822e-008, 3.41343291e-001, 6.58656689e-001]])" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 15 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Binarize the output\n", | |
"y = label_binarize(y_test, classes=[1, 2, 3])\n", | |
"n_classes = y.shape[1]\n", | |
"\n", | |
"y" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "mSjZx9xwv-UQ", | |
"outputId": "b8ea6785-dc4c-4951-b9ed-8a43e1fb3de6" | |
}, | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"array([[1, 0, 0],\n", | |
" [0, 0, 1],\n", | |
" [0, 1, 0],\n", | |
" [0, 1, 0],\n", | |
" [0, 0, 1],\n", | |
" [0, 0, 1],\n", | |
" [1, 0, 0],\n", | |
" [0, 0, 1],\n", | |
" [1, 0, 0],\n", | |
" [0, 0, 1],\n", | |
" [1, 0, 0],\n", | |
" [0, 1, 0],\n", | |
" [0, 0, 1],\n", | |
" [0, 0, 1],\n", | |
" [0, 1, 0],\n", | |
" [1, 0, 0],\n", | |
" [1, 0, 0],\n", | |
" [0, 1, 0],\n", | |
" [1, 0, 0],\n", | |
" [0, 1, 0],\n", | |
" [0, 0, 1],\n", | |
" [0, 0, 1],\n", | |
" [1, 0, 0],\n", | |
" [0, 0, 1],\n", | |
" [1, 0, 0],\n", | |
" [0, 1, 0],\n", | |
" [0, 0, 1],\n", | |
" [0, 0, 1],\n", | |
" [0, 1, 0],\n", | |
" [1, 0, 0],\n", | |
" [1, 0, 0],\n", | |
" [1, 0, 0],\n", | |
" [0, 1, 0],\n", | |
" [0, 1, 0],\n", | |
" [1, 0, 0],\n", | |
" [0, 0, 1],\n", | |
" [0, 0, 1],\n", | |
" [1, 0, 0],\n", | |
" [0, 0, 1],\n", | |
" [0, 0, 1],\n", | |
" [0, 0, 1],\n", | |
" [0, 1, 0],\n", | |
" [1, 0, 0],\n", | |
" [0, 1, 0],\n", | |
" [0, 0, 1],\n", | |
" [1, 0, 0],\n", | |
" [1, 0, 0],\n", | |
" [0, 1, 0],\n", | |
" [0, 1, 0],\n", | |
" [0, 0, 1],\n", | |
" [0, 1, 0],\n", | |
" [0, 0, 1],\n", | |
" [0, 1, 0],\n", | |
" [0, 1, 0],\n", | |
" [0, 0, 1],\n", | |
" [1, 0, 0],\n", | |
" [1, 0, 0],\n", | |
" [0, 1, 0],\n", | |
" [1, 0, 0],\n", | |
" [1, 0, 0],\n", | |
" [0, 1, 0],\n", | |
" [0, 0, 1],\n", | |
" [0, 0, 1],\n", | |
" [1, 0, 0],\n", | |
" [0, 0, 1],\n", | |
" [0, 0, 1]])" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 16 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Compute ROC curve and ROC area for each class\n", | |
"fpr = dict()\n", | |
"tpr = dict()\n", | |
"roc_auc = dict()\n", | |
"for i in range(n_classes):\n", | |
" fpr[i], tpr[i], _ = roc_curve(y[:, i], y_score[:, i])\n", | |
" roc_auc[i] = auc(fpr[i], tpr[i])\n" | |
], | |
"metadata": { | |
"id": "KKNGX8eaycQs" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Compute micro-average ROC curve and ROC area\n", | |
"fpr[\"micro\"], tpr[\"micro\"], _ = roc_curve(y.ravel(), y_score.ravel())\n", | |
"roc_auc[\"micro\"] = auc(fpr[\"micro\"], tpr[\"micro\"])\n", | |
"\n", | |
"fpr,tpr" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "xYFSM-ZXzn8z", | |
"outputId": "4b129a8a-69a9-4c23-a93a-00ea7963a390" | |
}, | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"({0: array([0. , 0. , 0. , 0.02272727, 0.02272727,\n", | |
" 1. ]),\n", | |
" 1: array([0., 0., 0., 1.]),\n", | |
" 2: array([0. , 0. , 0. , 0.02439024, 0.02439024,\n", | |
" 0.04878049, 0.04878049, 1. ]),\n", | |
" 'micro': array([0. , 0. , 0. , 0.00757576, 0.00757576,\n", | |
" 0.01515152, 0.01515152, 0.02272727, 0.02272727, 0.03787879,\n", | |
" 0.03787879, 0.04545455, 0.04545455, 1. ])},\n", | |
" {0: array([0. , 0.63636364, 0.95454545, 0.95454545, 1. ,\n", | |
" 1. ]),\n", | |
" 1: array([0. , 0.05263158, 1. , 1. ]),\n", | |
" 2: array([0. , 0.04, 0.92, 0.92, 0.96, 0.96, 1. , 1. ]),\n", | |
" 'micro': array([0. , 0.21212121, 0.90909091, 0.90909091, 0.92424242,\n", | |
" 0.92424242, 0.95454545, 0.95454545, 0.96969697, 0.96969697,\n", | |
" 0.98484848, 0.98484848, 1. , 1. ])})" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 18 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"roc_auc" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "217DZZGj1PMj", | |
"outputId": "c0693d0f-f27f-4dec-acc4-b4be8e0a72fe" | |
}, | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"{0: 0.9989669421487604,\n", | |
" 1: 1.0,\n", | |
" 2: 0.9970731707317073,\n", | |
" 'micro': 0.9978191000918274}" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 19 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# source https://scikit-learn.org/stable/auto_examples/model_selection/plot_roc.html\n", | |
"\n", | |
"lw = 2\n", | |
"\n", | |
"# First aggregate all false positive rates\n", | |
"all_fpr = np.unique(np.concatenate([fpr[i] for i in range(n_classes)]))\n", | |
"\n", | |
"# Then interpolate all ROC curves at this points\n", | |
"mean_tpr = np.zeros_like(all_fpr)\n", | |
"for i in range(n_classes):\n", | |
" mean_tpr += np.interp(all_fpr, fpr[i], tpr[i])\n", | |
"\n", | |
"# Finally average it and compute AUC\n", | |
"mean_tpr /= n_classes\n", | |
"\n", | |
"fpr[\"macro\"] = all_fpr\n", | |
"tpr[\"macro\"] = mean_tpr\n", | |
"roc_auc[\"macro\"] = auc(fpr[\"macro\"], tpr[\"macro\"])\n", | |
"\n", | |
"# Plot all ROC curves\n", | |
"plt.figure(figsize=(14, 10))\n", | |
"plt.plot(\n", | |
" fpr[\"micro\"],\n", | |
" tpr[\"micro\"],\n", | |
" label=\"micro-average ROC curve (area = {0:0.4f})\".format(roc_auc[\"micro\"]),\n", | |
" color=\"deeppink\",\n", | |
" linestyle=\":\",\n", | |
" linewidth=4,\n", | |
")\n", | |
"\n", | |
"plt.plot(\n", | |
" fpr[\"macro\"],\n", | |
" tpr[\"macro\"],\n", | |
" label=\"macro-average ROC curve (area = {0:0.4f})\".format(roc_auc[\"macro\"]),\n", | |
" color=\"navy\",\n", | |
" linestyle=\":\",\n", | |
" linewidth=4,\n", | |
")\n", | |
"\n", | |
"colors = cycle([\"aqua\", \"darkorange\", \"cornflowerblue\"])\n", | |
"for i, color in zip(range(n_classes), colors):\n", | |
" plt.plot(\n", | |
" fpr[i],\n", | |
" tpr[i],\n", | |
" color=color,\n", | |
" lw=lw,\n", | |
" label=\"ROC curve of class {0} (area = {1:0.4f})\".format(i, roc_auc[i]),\n", | |
" )\n", | |
"\n", | |
"plt.plot([0, 1], [0, 1], \"k--\", lw=lw)\n", | |
"plt.xlim([0.0, 1.0])\n", | |
"plt.ylim([0.0, 1.05])\n", | |
"plt.xlabel(\"False Positive Rate\")\n", | |
"plt.ylabel(\"True Positive Rate\")\n", | |
"plt.title(\"Some extension of Receiver operating characteristic to multiclass\")\n", | |
"plt.legend(loc=\"lower right\")\n", | |
"plt.show()" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 621 | |
}, | |
"id": "uOGMsA0o0DdV", | |
"outputId": "e8b9c007-a512-42c8-db9c-18fd78aa6b11" | |
}, | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": [ | |
"<Figure size 1008x720 with 1 Axes>" | |
], | |
"image/png": "\n" | |
}, | |
"metadata": { | |
"needs_background": "light" | |
} | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"roc_auc_score(y_true, clf.predict_proba(X_test), multi_class='ovr')" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "rnlbHMGwtNRd", | |
"outputId": "d9ee75a3-e5a5-4538-8544-1575c6c511e8" | |
}, | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"0.9986800376268224" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 145 | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"# Линейный/квадратичный дискриминантный анализ" | |
], | |
"metadata": { | |
"id": "0aFi7zTd1wxw" | |
} | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"Документация scikit-learn по LDA/QDA: https://scikit-learn.org/stable/modules/lda_qda.html#lda-qda" | |
], | |
"metadata": { | |
"id": "17lpONwUFfL6" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"#LDA \n", | |
"clf_lda = LinearDiscriminantAnalysis()\n", | |
"clf_lda.fit(X_train,y_train)" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "QrGXgkVT13jA", | |
"outputId": "0fa0f69a-adac-4f19-cb4e-988f70f1b3d9" | |
}, | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"LinearDiscriminantAnalysis()" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 21 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Model performance metrics\n", | |
"y_true = y_test\n", | |
"y_pred = clf_lda.predict(X_test)\n", | |
"# precision, recall, f1, overall accuracy\n", | |
"print(classification_report(y_true,y_pred))\n", | |
"\n", | |
"y_true, y_pred" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "9CXdp8I00ivu", | |
"outputId": "6c8c62c1-204b-45bd-bf63-c52c76a6eb3d" | |
}, | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
" precision recall f1-score support\n", | |
"\n", | |
" 1 0.95 0.91 0.93 22\n", | |
" 2 0.95 1.00 0.97 19\n", | |
" 3 0.96 0.96 0.96 25\n", | |
"\n", | |
" accuracy 0.95 66\n", | |
" macro avg 0.95 0.96 0.95 66\n", | |
"weighted avg 0.95 0.95 0.95 66\n", | |
"\n" | |
] | |
}, | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"(array([1, 3, 2, 2, 3, 3, 1, 3, 1, 3, 1, 2, 3, 3, 2, 1, 1, 2, 1, 2, 3, 3,\n", | |
" 1, 3, 1, 2, 3, 3, 2, 1, 1, 1, 2, 2, 1, 3, 3, 1, 3, 3, 3, 2, 1, 2,\n", | |
" 3, 1, 1, 2, 2, 3, 2, 3, 2, 2, 3, 1, 1, 2, 1, 1, 2, 3, 3, 1, 3, 3],\n", | |
" dtype=int32),\n", | |
" array([1, 3, 2, 2, 3, 3, 1, 3, 1, 3, 1, 2, 3, 3, 2, 2, 1, 2, 1, 2, 3, 3,\n", | |
" 1, 3, 1, 2, 3, 3, 2, 1, 1, 1, 2, 2, 3, 1, 3, 1, 3, 3, 3, 2, 1, 2,\n", | |
" 3, 1, 1, 2, 2, 3, 2, 3, 2, 2, 3, 1, 1, 2, 1, 1, 2, 3, 3, 1, 3, 3],\n", | |
" dtype=int32))" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 22 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"y_score = clf_lda.predict_proba(X_test)\n", | |
"\n", | |
"# Compute ROC curve and ROC area for each class\n", | |
"fpr = dict()\n", | |
"tpr = dict()\n", | |
"roc_auc = dict()\n", | |
"for i in range(n_classes):\n", | |
" fpr[i], tpr[i], _ = roc_curve(y[:, i], y_score[:, i])\n", | |
" roc_auc[i] = auc(fpr[i], tpr[i])\n", | |
"\n", | |
"# Compute micro-average ROC curve and ROC area\n", | |
"fpr[\"micro\"], tpr[\"micro\"], _ = roc_curve(y.ravel(), y_score.ravel())\n", | |
"roc_auc[\"micro\"] = auc(fpr[\"micro\"], tpr[\"micro\"])\n", | |
"\n", | |
"lw = 2\n", | |
"\n", | |
"# First aggregate all false positive rates\n", | |
"all_fpr = np.unique(np.concatenate([fpr[i] for i in range(n_classes)]))\n", | |
"\n", | |
"# Then interpolate all ROC curves at this points\n", | |
"mean_tpr = np.zeros_like(all_fpr)\n", | |
"for i in range(n_classes):\n", | |
" mean_tpr += np.interp(all_fpr, fpr[i], tpr[i])\n", | |
"\n", | |
"# Finally average it and compute AUC\n", | |
"mean_tpr /= n_classes\n", | |
"\n", | |
"fpr[\"macro\"] = all_fpr\n", | |
"tpr[\"macro\"] = mean_tpr\n", | |
"roc_auc[\"macro\"] = auc(fpr[\"macro\"], tpr[\"macro\"])\n", | |
"\n", | |
"# Plot all ROC curves\n", | |
"plt.figure(figsize=(14, 10))\n", | |
"plt.plot(\n", | |
" fpr[\"micro\"],\n", | |
" tpr[\"micro\"],\n", | |
" label=\"micro-average ROC curve (area = {0:0.4f})\".format(roc_auc[\"micro\"]),\n", | |
" color=\"deeppink\",\n", | |
" linestyle=\":\",\n", | |
" linewidth=4,\n", | |
")\n", | |
"\n", | |
"plt.plot(\n", | |
" fpr[\"macro\"],\n", | |
" tpr[\"macro\"],\n", | |
" label=\"macro-average ROC curve (area = {0:0.4f})\".format(roc_auc[\"macro\"]),\n", | |
" color=\"navy\",\n", | |
" linestyle=\":\",\n", | |
" linewidth=4,\n", | |
")\n", | |
"\n", | |
"colors = cycle([\"aqua\", \"darkorange\", \"cornflowerblue\"])\n", | |
"for i, color in zip(range(n_classes), colors):\n", | |
" plt.plot(\n", | |
" fpr[i],\n", | |
" tpr[i],\n", | |
" color=color,\n", | |
" lw=lw,\n", | |
" label=\"ROC curve of class {0} (area = {1:0.4f})\".format(i, roc_auc[i]),\n", | |
" )\n", | |
"\n", | |
"plt.plot([0, 1], [0, 1], \"k--\", lw=lw)\n", | |
"plt.xlim([0.0, 1.0])\n", | |
"plt.ylim([0.0, 1.05])\n", | |
"plt.xlabel(\"False Positive Rate\")\n", | |
"plt.ylabel(\"True Positive Rate\")\n", | |
"plt.title(\"Some extension of Receiver operating characteristic to multiclass\")\n", | |
"plt.legend(loc=\"lower right\")\n", | |
"plt.show()" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 621 | |
}, | |
"id": "CVYA0dFbAqPt", | |
"outputId": "ece28a01-2e3d-40fc-c402-8a9b787c0227" | |
}, | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": [ | |
"<Figure size 1008x720 with 1 Axes>" | |
], | |
"image/png": "\n" | |
}, | |
"metadata": { | |
"needs_background": "light" | |
} | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"roc_auc_score(y_true, clf_lda.predict_proba(X_test), multi_class='ovr')" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "hR-vZGLr-Av0", | |
"outputId": "51cabe1a-7fa8-4d13-e61d-4e15ac78f9eb" | |
}, | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"0.9983165356446952" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 24 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"#QDA \n", | |
"clf_qda = QuadraticDiscriminantAnalysis()\n", | |
"clf_qda.fit(X_train,y_train)" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "DkUyDyTN2t1n", | |
"outputId": "e44ab814-2aa5-45c6-eb7f-2b2ac6693580" | |
}, | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"QuadraticDiscriminantAnalysis()" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 25 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Model performance metrics\n", | |
"y_true = y_test\n", | |
"y_pred = clf_qda.predict(X_test)\n", | |
"# precision, recall, f1, overall accuracy\n", | |
"print(classification_report(y_true,y_pred))\n", | |
"\n", | |
"y_true, y_pred" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "CytsxxAo3DTv", | |
"outputId": "50f446f7-8e2a-4ee2-ff95-1a722a5a855d" | |
}, | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
" precision recall f1-score support\n", | |
"\n", | |
" 1 0.95 0.86 0.90 22\n", | |
" 2 0.95 1.00 0.97 19\n", | |
" 3 0.92 0.96 0.94 25\n", | |
"\n", | |
" accuracy 0.94 66\n", | |
" macro avg 0.94 0.94 0.94 66\n", | |
"weighted avg 0.94 0.94 0.94 66\n", | |
"\n" | |
] | |
}, | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"(array([1, 3, 2, 2, 3, 3, 1, 3, 1, 3, 1, 2, 3, 3, 2, 1, 1, 2, 1, 2, 3, 3,\n", | |
" 1, 3, 1, 2, 3, 3, 2, 1, 1, 1, 2, 2, 1, 3, 3, 1, 3, 3, 3, 2, 1, 2,\n", | |
" 3, 1, 1, 2, 2, 3, 2, 3, 2, 2, 3, 1, 1, 2, 1, 1, 2, 3, 3, 1, 3, 3],\n", | |
" dtype=int32),\n", | |
" array([3, 3, 2, 2, 3, 3, 1, 3, 1, 3, 1, 2, 3, 3, 2, 2, 1, 2, 1, 2, 3, 3,\n", | |
" 1, 3, 1, 2, 3, 3, 2, 1, 1, 1, 2, 2, 3, 1, 3, 1, 3, 3, 3, 2, 1, 2,\n", | |
" 3, 1, 1, 2, 2, 3, 2, 3, 2, 2, 3, 1, 1, 2, 1, 1, 2, 3, 3, 1, 3, 3],\n", | |
" dtype=int32))" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 26 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"y_score = clf_qda.predict_proba(X_test)\n", | |
"\n", | |
"# Compute ROC curve and ROC area for each class\n", | |
"fpr = dict()\n", | |
"tpr = dict()\n", | |
"roc_auc = dict()\n", | |
"for i in range(n_classes):\n", | |
" fpr[i], tpr[i], _ = roc_curve(y[:, i], y_score[:, i])\n", | |
" roc_auc[i] = auc(fpr[i], tpr[i])\n", | |
"\n", | |
"# Compute micro-average ROC curve and ROC area\n", | |
"fpr[\"micro\"], tpr[\"micro\"], _ = roc_curve(y.ravel(), y_score.ravel())\n", | |
"roc_auc[\"micro\"] = auc(fpr[\"micro\"], tpr[\"micro\"])\n", | |
"\n", | |
"lw = 2\n", | |
"\n", | |
"# First aggregate all false positive rates\n", | |
"all_fpr = np.unique(np.concatenate([fpr[i] for i in range(n_classes)]))\n", | |
"\n", | |
"# Then interpolate all ROC curves at this points\n", | |
"mean_tpr = np.zeros_like(all_fpr)\n", | |
"for i in range(n_classes):\n", | |
" mean_tpr += np.interp(all_fpr, fpr[i], tpr[i])\n", | |
"\n", | |
"# Finally average it and compute AUC\n", | |
"mean_tpr /= n_classes\n", | |
"\n", | |
"fpr[\"macro\"] = all_fpr\n", | |
"tpr[\"macro\"] = mean_tpr\n", | |
"roc_auc[\"macro\"] = auc(fpr[\"macro\"], tpr[\"macro\"])\n", | |
"\n", | |
"# Plot all ROC curves\n", | |
"plt.figure(figsize=(14, 10))\n", | |
"plt.plot(\n", | |
" fpr[\"micro\"],\n", | |
" tpr[\"micro\"],\n", | |
" label=\"micro-average ROC curve (area = {0:0.4f})\".format(roc_auc[\"micro\"]),\n", | |
" color=\"deeppink\",\n", | |
" linestyle=\":\",\n", | |
" linewidth=4,\n", | |
")\n", | |
"\n", | |
"plt.plot(\n", | |
" fpr[\"macro\"],\n", | |
" tpr[\"macro\"],\n", | |
" label=\"macro-average ROC curve (area = {0:0.4f})\".format(roc_auc[\"macro\"]),\n", | |
" color=\"navy\",\n", | |
" linestyle=\":\",\n", | |
" linewidth=4,\n", | |
")\n", | |
"\n", | |
"colors = cycle([\"aqua\", \"darkorange\", \"cornflowerblue\"])\n", | |
"for i, color in zip(range(n_classes), colors):\n", | |
" plt.plot(\n", | |
" fpr[i],\n", | |
" tpr[i],\n", | |
" color=color,\n", | |
" lw=lw,\n", | |
" label=\"ROC curve of class {0} (area = {1:0.4f})\".format(i, roc_auc[i]),\n", | |
" )\n", | |
"\n", | |
"plt.plot([0, 1], [0, 1], \"k--\", lw=lw)\n", | |
"plt.xlim([0.0, 1.0])\n", | |
"plt.ylim([0.0, 1.05])\n", | |
"plt.xlabel(\"False Positive Rate\")\n", | |
"plt.ylabel(\"True Positive Rate\")\n", | |
"plt.title(\"Some extension of Receiver operating characteristic to multiclass\")\n", | |
"plt.legend(loc=\"lower right\")\n", | |
"plt.show()" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 621 | |
}, | |
"id": "YeTDiJOB-6Ch", | |
"outputId": "fbc19a9d-0837-401c-f513-e2ffa65ce16b" | |
}, | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": [ | |
"<Figure size 1008x720 with 1 Axes>" | |
], | |
"image/png": "\n" | |
}, | |
"metadata": { | |
"needs_background": "light" | |
} | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"roc_auc_score(y_true, clf_qda.predict_proba(X_test),multi_class=\"ovr\")" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "uGKN6dkl3K-o", | |
"outputId": "8ab6026e-4f3b-40ce-a266-1e1d32803706" | |
}, | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"0.9904534796571903" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 28 | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"# Кросс-валидация и выбор модели" | |
], | |
"metadata": { | |
"id": "qkpkew5EC2wM" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# data preparation\n", | |
"X = DATA.loc[:, DATA.columns != 'class'].to_numpy()\n", | |
"y = DATA[\"class\"].to_numpy()\n", | |
"X, y" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "Yz49AkUCEPJc", | |
"outputId": "df3314c5-b61a-437f-d76a-7df2b1db2c57" | |
}, | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"(array([[15.26 , 14.84 , 0.871 , ..., 3.312 , 2.221 , 5.22 ],\n", | |
" [14.88 , 14.57 , 0.8811, ..., 3.333 , 1.018 , 4.956 ],\n", | |
" [14.29 , 14.09 , 0.905 , ..., 3.337 , 2.699 , 4.825 ],\n", | |
" ...,\n", | |
" [13.2 , 13.66 , 0.8883, ..., 3.232 , 8.315 , 5.056 ],\n", | |
" [11.84 , 13.21 , 0.8521, ..., 2.836 , 3.598 , 5.044 ],\n", | |
" [12.3 , 13.34 , 0.8684, ..., 2.974 , 5.637 , 5.063 ]]),\n", | |
" array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", | |
" 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", | |
" 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", | |
" 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,\n", | |
" 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,\n", | |
" 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,\n", | |
" 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,\n", | |
" 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,\n", | |
" 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,\n", | |
" 3], dtype=int32))" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 29 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"model_1 = LogisticRegression(penalty='none', multi_class='ovr', solver='newton-cg')\n", | |
"scores = cross_val_score(model_1, X, y, cv=10, scoring='roc_auc_ovr')\n", | |
"print(\"Logit has %0.4f OvR AUC with a standard deviation of %0.2f\" % (scores.mean(), scores.std()))\n", | |
"\n", | |
"model_2 = LinearDiscriminantAnalysis()\n", | |
"scores = cross_val_score(model_2, X, y, cv=10, scoring='roc_auc_ovr')\n", | |
"print(\"LDA has %0.4f OvR AUC with a standard deviation of %0.2f\" % (scores.mean(), scores.std()))\n", | |
"\n", | |
"model_3 = QuadraticDiscriminantAnalysis()\n", | |
"scores = cross_val_score(model_3, X, y, cv=10, scoring='roc_auc_ovr')\n", | |
"print(\"QDA has %0.4f OvR AUC with a standard deviation of %0.2f\" % (scores.mean(), scores.std()))" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "VNp9eVsOCS3c", | |
"outputId": "9f3d3634-f942-4dff-9130-4402aaa50c22" | |
}, | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stderr", | |
"text": [ | |
"/usr/local/lib/python3.7/dist-packages/sklearn/utils/optimize.py:212: ConvergenceWarning: newton-cg failed to converge. Increase the number of iterations.\n", | |
" ConvergenceWarning,\n", | |
"/usr/local/lib/python3.7/dist-packages/sklearn/utils/optimize.py:212: ConvergenceWarning: newton-cg failed to converge. Increase the number of iterations.\n", | |
" ConvergenceWarning,\n" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"Logit has 0.9957 OvR AUC with a standard deviation of 0.01\n", | |
"LDA has 0.9954 OvR AUC with a standard deviation of 0.01\n", | |
"QDA has 0.9879 OvR AUC with a standard deviation of 0.02\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"" | |
], | |
"metadata": { | |
"id": "OohWzlq1EGxK" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
} | |
] | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment