Last active
March 23, 2022 11:48
-
-
Save jamm1985/14e722192382a2919119d4cdb61e4aed to your computer and use it in GitHub Desktop.
Lab_12_intro_to_ML_regression_part_I
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"nbformat": 4, | |
"nbformat_minor": 0, | |
"metadata": { | |
"colab": { | |
"name": "Lab_12_intro_to_machine_learning", | |
"provenance": [], | |
"authorship_tag": "ABX9TyNRcdxwAzXP9o7UZX2LavzS", | |
"include_colab_link": true | |
}, | |
"kernelspec": { | |
"name": "python3", | |
"display_name": "Python 3" | |
}, | |
"language_info": { | |
"name": "python" | |
}, | |
"accelerator": "GPU" | |
}, | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "view-in-github", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"<a href=\"https://colab.research.google.com/gist/jamm1985/14e722192382a2919119d4cdb61e4aed/lab_12_intro_to_machine_learning.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"Видео лабораторной: https://youtu.be/r-z1cjvpwBE\n", | |
"\n", | |
"TG: https://t.me/data_science_news\n", | |
"\n", | |
"\n", | |
"\n", | |
"---" | |
], | |
"metadata": { | |
"id": "-xv-tPMotFuR" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": { | |
"id": "Z5BcJ-9oZr5D", | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"outputId": "47a371c5-6905-4fb1-fef8-89024915e987" | |
}, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"2.8.0\n" | |
] | |
} | |
], | |
"source": [ | |
"import pandas as pd\n", | |
"import numpy as np\n", | |
"import matplotlib.pylab as plt\n", | |
"\n", | |
"from sklearn.linear_model import LinearRegression\n", | |
"from sklearn.linear_model import Ridge\n", | |
"from sklearn.linear_model import Lasso\n", | |
"from sklearn.model_selection import cross_val_score\n", | |
"from sklearn.preprocessing import MaxAbsScaler\n", | |
"from sklearn.model_selection import train_test_split\n", | |
"from sklearn.metrics import mean_squared_error\n", | |
"\n", | |
"import tensorflow as tf\n", | |
"from tensorflow import keras\n", | |
"from tensorflow.keras import layers\n", | |
"\n", | |
"print(tf.__version__)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"# Разведочный анализ данных\n", | |
"\n", | |
"[Набор данных Copyright 2020 Google LLC](https://colab.research.google.com/github/google/eng-edu/blob/main/ml/cc/exercises/linear_regression_with_a_real_dataset.ipynb)\n", | |
"\n", | |
"[Документация по набору данных](https://developers.google.com/machine-learning/crash-course/california-housing-data-description)" | |
], | |
"metadata": { | |
"id": "fHFFL3xCe7k1" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"!wget https://download.mlcc.google.com/mledu-datasets/california_housing_train.csv\n", | |
"!head california_housing_train.csv" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "T5vVWNXre1QT", | |
"outputId": "a6700678-2ddd-42e2-c997-08b0d46b60cd" | |
}, | |
"execution_count": 2, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"--2022-03-16 05:30:23-- https://download.mlcc.google.com/mledu-datasets/california_housing_train.csv\n", | |
"Resolving download.mlcc.google.com (download.mlcc.google.com)... 74.125.133.94, 2a00:1450:400c:c07::5e\n", | |
"Connecting to download.mlcc.google.com (download.mlcc.google.com)|74.125.133.94|:443... connected.\n", | |
"HTTP request sent, awaiting response... 302 Found\n", | |
"Location: https://dl.google.com/mlcc/mledu-datasets/california_housing_train.csv [following]\n", | |
"--2022-03-16 05:30:23-- https://dl.google.com/mlcc/mledu-datasets/california_housing_train.csv\n", | |
"Resolving dl.google.com (dl.google.com)... 64.233.167.91, 64.233.167.136, 64.233.167.93, ...\n", | |
"Connecting to dl.google.com (dl.google.com)|64.233.167.91|:443... connected.\n", | |
"HTTP request sent, awaiting response... 200 OK\n", | |
"Length: 1706430 (1.6M) [text/csv]\n", | |
"Saving to: ‘california_housing_train.csv’\n", | |
"\n", | |
"\r californi 0%[ ] 0 --.-KB/s \rcalifornia_housing_ 100%[===================>] 1.63M --.-KB/s in 0.01s \n", | |
"\n", | |
"2022-03-16 05:30:23 (128 MB/s) - ‘california_housing_train.csv’ saved [1706430/1706430]\n", | |
"\n", | |
"\"longitude\",\"latitude\",\"housing_median_age\",\"total_rooms\",\"total_bedrooms\",\"population\",\"households\",\"median_income\",\"median_house_value\"\n", | |
"-114.310000,34.190000,15.000000,5612.000000,1283.000000,1015.000000,472.000000,1.493600,66900.000000\n", | |
"-114.470000,34.400000,19.000000,7650.000000,1901.000000,1129.000000,463.000000,1.820000,80100.000000\n", | |
"-114.560000,33.690000,17.000000,720.000000,174.000000,333.000000,117.000000,1.650900,85700.000000\n", | |
"-114.570000,33.640000,14.000000,1501.000000,337.000000,515.000000,226.000000,3.191700,73400.000000\n", | |
"-114.570000,33.570000,20.000000,1454.000000,326.000000,624.000000,262.000000,1.925000,65500.000000\n", | |
"-114.580000,33.630000,29.000000,1387.000000,236.000000,671.000000,239.000000,3.343800,74000.000000\n", | |
"-114.580000,33.610000,25.000000,2907.000000,680.000000,1841.000000,633.000000,2.676800,82400.000000\n", | |
"-114.590000,34.830000,41.000000,812.000000,168.000000,375.000000,158.000000,1.708300,48500.000000\n", | |
"-114.590000,33.610000,34.000000,4789.000000,1175.000000,3134.000000,1056.000000,2.178200,58400.000000\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"dataset_1 = pd.read_csv(\"california_housing_train.csv\")\n", | |
"dataset_1.describe()" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 300 | |
}, | |
"id": "YAmQv4zxfhnL", | |
"outputId": "8f8a8611-37fb-4839-ad10-e9d8c7862416" | |
}, | |
"execution_count": 3, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
" longitude latitude housing_median_age total_rooms \\\n", | |
"count 17000.000000 17000.000000 17000.000000 17000.000000 \n", | |
"mean -119.562108 35.625225 28.589353 2643.664412 \n", | |
"std 2.005166 2.137340 12.586937 2179.947071 \n", | |
"min -124.350000 32.540000 1.000000 2.000000 \n", | |
"25% -121.790000 33.930000 18.000000 1462.000000 \n", | |
"50% -118.490000 34.250000 29.000000 2127.000000 \n", | |
"75% -118.000000 37.720000 37.000000 3151.250000 \n", | |
"max -114.310000 41.950000 52.000000 37937.000000 \n", | |
"\n", | |
" total_bedrooms population households median_income \\\n", | |
"count 17000.000000 17000.000000 17000.000000 17000.000000 \n", | |
"mean 539.410824 1429.573941 501.221941 3.883578 \n", | |
"std 421.499452 1147.852959 384.520841 1.908157 \n", | |
"min 1.000000 3.000000 1.000000 0.499900 \n", | |
"25% 297.000000 790.000000 282.000000 2.566375 \n", | |
"50% 434.000000 1167.000000 409.000000 3.544600 \n", | |
"75% 648.250000 1721.000000 605.250000 4.767000 \n", | |
"max 6445.000000 35682.000000 6082.000000 15.000100 \n", | |
"\n", | |
" median_house_value \n", | |
"count 17000.000000 \n", | |
"mean 207300.912353 \n", | |
"std 115983.764387 \n", | |
"min 14999.000000 \n", | |
"25% 119400.000000 \n", | |
"50% 180400.000000 \n", | |
"75% 265000.000000 \n", | |
"max 500001.000000 " | |
], | |
"text/html": [ | |
"\n", | |
" <div id=\"df-726fb9fb-e88d-4f15-bb58-35e88103ed20\">\n", | |
" <div class=\"colab-df-container\">\n", | |
" <div>\n", | |
"<style scoped>\n", | |
" .dataframe tbody tr th:only-of-type {\n", | |
" vertical-align: middle;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>longitude</th>\n", | |
" <th>latitude</th>\n", | |
" <th>housing_median_age</th>\n", | |
" <th>total_rooms</th>\n", | |
" <th>total_bedrooms</th>\n", | |
" <th>population</th>\n", | |
" <th>households</th>\n", | |
" <th>median_income</th>\n", | |
" <th>median_house_value</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>count</th>\n", | |
" <td>17000.000000</td>\n", | |
" <td>17000.000000</td>\n", | |
" <td>17000.000000</td>\n", | |
" <td>17000.000000</td>\n", | |
" <td>17000.000000</td>\n", | |
" <td>17000.000000</td>\n", | |
" <td>17000.000000</td>\n", | |
" <td>17000.000000</td>\n", | |
" <td>17000.000000</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>mean</th>\n", | |
" <td>-119.562108</td>\n", | |
" <td>35.625225</td>\n", | |
" <td>28.589353</td>\n", | |
" <td>2643.664412</td>\n", | |
" <td>539.410824</td>\n", | |
" <td>1429.573941</td>\n", | |
" <td>501.221941</td>\n", | |
" <td>3.883578</td>\n", | |
" <td>207300.912353</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>std</th>\n", | |
" <td>2.005166</td>\n", | |
" <td>2.137340</td>\n", | |
" <td>12.586937</td>\n", | |
" <td>2179.947071</td>\n", | |
" <td>421.499452</td>\n", | |
" <td>1147.852959</td>\n", | |
" <td>384.520841</td>\n", | |
" <td>1.908157</td>\n", | |
" <td>115983.764387</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>min</th>\n", | |
" <td>-124.350000</td>\n", | |
" <td>32.540000</td>\n", | |
" <td>1.000000</td>\n", | |
" <td>2.000000</td>\n", | |
" <td>1.000000</td>\n", | |
" <td>3.000000</td>\n", | |
" <td>1.000000</td>\n", | |
" <td>0.499900</td>\n", | |
" <td>14999.000000</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>25%</th>\n", | |
" <td>-121.790000</td>\n", | |
" <td>33.930000</td>\n", | |
" <td>18.000000</td>\n", | |
" <td>1462.000000</td>\n", | |
" <td>297.000000</td>\n", | |
" <td>790.000000</td>\n", | |
" <td>282.000000</td>\n", | |
" <td>2.566375</td>\n", | |
" <td>119400.000000</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>50%</th>\n", | |
" <td>-118.490000</td>\n", | |
" <td>34.250000</td>\n", | |
" <td>29.000000</td>\n", | |
" <td>2127.000000</td>\n", | |
" <td>434.000000</td>\n", | |
" <td>1167.000000</td>\n", | |
" <td>409.000000</td>\n", | |
" <td>3.544600</td>\n", | |
" <td>180400.000000</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>75%</th>\n", | |
" <td>-118.000000</td>\n", | |
" <td>37.720000</td>\n", | |
" <td>37.000000</td>\n", | |
" <td>3151.250000</td>\n", | |
" <td>648.250000</td>\n", | |
" <td>1721.000000</td>\n", | |
" <td>605.250000</td>\n", | |
" <td>4.767000</td>\n", | |
" <td>265000.000000</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>max</th>\n", | |
" <td>-114.310000</td>\n", | |
" <td>41.950000</td>\n", | |
" <td>52.000000</td>\n", | |
" <td>37937.000000</td>\n", | |
" <td>6445.000000</td>\n", | |
" <td>35682.000000</td>\n", | |
" <td>6082.000000</td>\n", | |
" <td>15.000100</td>\n", | |
" <td>500001.000000</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>\n", | |
" <button class=\"colab-df-convert\" onclick=\"convertToInteractive('df-726fb9fb-e88d-4f15-bb58-35e88103ed20')\"\n", | |
" title=\"Convert this dataframe to an interactive table.\"\n", | |
" style=\"display:none;\">\n", | |
" \n", | |
" <svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n", | |
" width=\"24px\">\n", | |
" <path d=\"M0 0h24v24H0V0z\" fill=\"none\"/>\n", | |
" <path d=\"M18.56 5.44l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94zm-11 1L8.5 8.5l.94-2.06 2.06-.94-2.06-.94L8.5 2.5l-.94 2.06-2.06.94zm10 10l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94z\"/><path d=\"M17.41 7.96l-1.37-1.37c-.4-.4-.92-.59-1.43-.59-.52 0-1.04.2-1.43.59L10.3 9.45l-7.72 7.72c-.78.78-.78 2.05 0 2.83L4 21.41c.39.39.9.59 1.41.59.51 0 1.02-.2 1.41-.59l7.78-7.78 2.81-2.81c.8-.78.8-2.07 0-2.86zM5.41 20L4 18.59l7.72-7.72 1.47 1.35L5.41 20z\"/>\n", | |
" </svg>\n", | |
" </button>\n", | |
" \n", | |
" <style>\n", | |
" .colab-df-container {\n", | |
" display:flex;\n", | |
" flex-wrap:wrap;\n", | |
" gap: 12px;\n", | |
" }\n", | |
"\n", | |
" .colab-df-convert {\n", | |
" background-color: #E8F0FE;\n", | |
" border: none;\n", | |
" border-radius: 50%;\n", | |
" cursor: pointer;\n", | |
" display: none;\n", | |
" fill: #1967D2;\n", | |
" height: 32px;\n", | |
" padding: 0 0 0 0;\n", | |
" width: 32px;\n", | |
" }\n", | |
"\n", | |
" .colab-df-convert:hover {\n", | |
" background-color: #E2EBFA;\n", | |
" box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n", | |
" fill: #174EA6;\n", | |
" }\n", | |
"\n", | |
" [theme=dark] .colab-df-convert {\n", | |
" background-color: #3B4455;\n", | |
" fill: #D2E3FC;\n", | |
" }\n", | |
"\n", | |
" [theme=dark] .colab-df-convert:hover {\n", | |
" background-color: #434B5C;\n", | |
" box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n", | |
" filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n", | |
" fill: #FFFFFF;\n", | |
" }\n", | |
" </style>\n", | |
"\n", | |
" <script>\n", | |
" const buttonEl =\n", | |
" document.querySelector('#df-726fb9fb-e88d-4f15-bb58-35e88103ed20 button.colab-df-convert');\n", | |
" buttonEl.style.display =\n", | |
" google.colab.kernel.accessAllowed ? 'block' : 'none';\n", | |
"\n", | |
" async function convertToInteractive(key) {\n", | |
" const element = document.querySelector('#df-726fb9fb-e88d-4f15-bb58-35e88103ed20');\n", | |
" const dataTable =\n", | |
" await google.colab.kernel.invokeFunction('convertToInteractive',\n", | |
" [key], {});\n", | |
" if (!dataTable) return;\n", | |
"\n", | |
" const docLinkHtml = 'Like what you see? Visit the ' +\n", | |
" '<a target=\"_blank\" href=https://colab.research.google.com/notebooks/data_table.ipynb>data table notebook</a>'\n", | |
" + ' to learn more about interactive tables.';\n", | |
" element.innerHTML = '';\n", | |
" dataTable['output_type'] = 'display_data';\n", | |
" await google.colab.output.renderOutput(dataTable, element);\n", | |
" const docLink = document.createElement('div');\n", | |
" docLink.innerHTML = docLinkHtml;\n", | |
" element.appendChild(docLink);\n", | |
" }\n", | |
" </script>\n", | |
" </div>\n", | |
" </div>\n", | |
" " | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 3 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"dataset_1['median_house_value'] = dataset_1['median_house_value']/1000\n", | |
"dataset_1.describe()" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 300 | |
}, | |
"id": "mLYfDfOkjKcn", | |
"outputId": "c732e083-877c-4788-f041-d6842aae6f69" | |
}, | |
"execution_count": 4, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
" longitude latitude housing_median_age total_rooms \\\n", | |
"count 17000.000000 17000.000000 17000.000000 17000.000000 \n", | |
"mean -119.562108 35.625225 28.589353 2643.664412 \n", | |
"std 2.005166 2.137340 12.586937 2179.947071 \n", | |
"min -124.350000 32.540000 1.000000 2.000000 \n", | |
"25% -121.790000 33.930000 18.000000 1462.000000 \n", | |
"50% -118.490000 34.250000 29.000000 2127.000000 \n", | |
"75% -118.000000 37.720000 37.000000 3151.250000 \n", | |
"max -114.310000 41.950000 52.000000 37937.000000 \n", | |
"\n", | |
" total_bedrooms population households median_income \\\n", | |
"count 17000.000000 17000.000000 17000.000000 17000.000000 \n", | |
"mean 539.410824 1429.573941 501.221941 3.883578 \n", | |
"std 421.499452 1147.852959 384.520841 1.908157 \n", | |
"min 1.000000 3.000000 1.000000 0.499900 \n", | |
"25% 297.000000 790.000000 282.000000 2.566375 \n", | |
"50% 434.000000 1167.000000 409.000000 3.544600 \n", | |
"75% 648.250000 1721.000000 605.250000 4.767000 \n", | |
"max 6445.000000 35682.000000 6082.000000 15.000100 \n", | |
"\n", | |
" median_house_value \n", | |
"count 17000.000000 \n", | |
"mean 207.300912 \n", | |
"std 115.983764 \n", | |
"min 14.999000 \n", | |
"25% 119.400000 \n", | |
"50% 180.400000 \n", | |
"75% 265.000000 \n", | |
"max 500.001000 " | |
], | |
"text/html": [ | |
"\n", | |
" <div id=\"df-46692847-1b09-4381-8b12-8c1f463665fb\">\n", | |
" <div class=\"colab-df-container\">\n", | |
" <div>\n", | |
"<style scoped>\n", | |
" .dataframe tbody tr th:only-of-type {\n", | |
" vertical-align: middle;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>longitude</th>\n", | |
" <th>latitude</th>\n", | |
" <th>housing_median_age</th>\n", | |
" <th>total_rooms</th>\n", | |
" <th>total_bedrooms</th>\n", | |
" <th>population</th>\n", | |
" <th>households</th>\n", | |
" <th>median_income</th>\n", | |
" <th>median_house_value</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>count</th>\n", | |
" <td>17000.000000</td>\n", | |
" <td>17000.000000</td>\n", | |
" <td>17000.000000</td>\n", | |
" <td>17000.000000</td>\n", | |
" <td>17000.000000</td>\n", | |
" <td>17000.000000</td>\n", | |
" <td>17000.000000</td>\n", | |
" <td>17000.000000</td>\n", | |
" <td>17000.000000</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>mean</th>\n", | |
" <td>-119.562108</td>\n", | |
" <td>35.625225</td>\n", | |
" <td>28.589353</td>\n", | |
" <td>2643.664412</td>\n", | |
" <td>539.410824</td>\n", | |
" <td>1429.573941</td>\n", | |
" <td>501.221941</td>\n", | |
" <td>3.883578</td>\n", | |
" <td>207.300912</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>std</th>\n", | |
" <td>2.005166</td>\n", | |
" <td>2.137340</td>\n", | |
" <td>12.586937</td>\n", | |
" <td>2179.947071</td>\n", | |
" <td>421.499452</td>\n", | |
" <td>1147.852959</td>\n", | |
" <td>384.520841</td>\n", | |
" <td>1.908157</td>\n", | |
" <td>115.983764</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>min</th>\n", | |
" <td>-124.350000</td>\n", | |
" <td>32.540000</td>\n", | |
" <td>1.000000</td>\n", | |
" <td>2.000000</td>\n", | |
" <td>1.000000</td>\n", | |
" <td>3.000000</td>\n", | |
" <td>1.000000</td>\n", | |
" <td>0.499900</td>\n", | |
" <td>14.999000</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>25%</th>\n", | |
" <td>-121.790000</td>\n", | |
" <td>33.930000</td>\n", | |
" <td>18.000000</td>\n", | |
" <td>1462.000000</td>\n", | |
" <td>297.000000</td>\n", | |
" <td>790.000000</td>\n", | |
" <td>282.000000</td>\n", | |
" <td>2.566375</td>\n", | |
" <td>119.400000</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>50%</th>\n", | |
" <td>-118.490000</td>\n", | |
" <td>34.250000</td>\n", | |
" <td>29.000000</td>\n", | |
" <td>2127.000000</td>\n", | |
" <td>434.000000</td>\n", | |
" <td>1167.000000</td>\n", | |
" <td>409.000000</td>\n", | |
" <td>3.544600</td>\n", | |
" <td>180.400000</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>75%</th>\n", | |
" <td>-118.000000</td>\n", | |
" <td>37.720000</td>\n", | |
" <td>37.000000</td>\n", | |
" <td>3151.250000</td>\n", | |
" <td>648.250000</td>\n", | |
" <td>1721.000000</td>\n", | |
" <td>605.250000</td>\n", | |
" <td>4.767000</td>\n", | |
" <td>265.000000</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>max</th>\n", | |
" <td>-114.310000</td>\n", | |
" <td>41.950000</td>\n", | |
" <td>52.000000</td>\n", | |
" <td>37937.000000</td>\n", | |
" <td>6445.000000</td>\n", | |
" <td>35682.000000</td>\n", | |
" <td>6082.000000</td>\n", | |
" <td>15.000100</td>\n", | |
" <td>500.001000</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>\n", | |
" <button class=\"colab-df-convert\" onclick=\"convertToInteractive('df-46692847-1b09-4381-8b12-8c1f463665fb')\"\n", | |
" title=\"Convert this dataframe to an interactive table.\"\n", | |
" style=\"display:none;\">\n", | |
" \n", | |
" <svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n", | |
" width=\"24px\">\n", | |
" <path d=\"M0 0h24v24H0V0z\" fill=\"none\"/>\n", | |
" <path d=\"M18.56 5.44l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94zm-11 1L8.5 8.5l.94-2.06 2.06-.94-2.06-.94L8.5 2.5l-.94 2.06-2.06.94zm10 10l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94z\"/><path d=\"M17.41 7.96l-1.37-1.37c-.4-.4-.92-.59-1.43-.59-.52 0-1.04.2-1.43.59L10.3 9.45l-7.72 7.72c-.78.78-.78 2.05 0 2.83L4 21.41c.39.39.9.59 1.41.59.51 0 1.02-.2 1.41-.59l7.78-7.78 2.81-2.81c.8-.78.8-2.07 0-2.86zM5.41 20L4 18.59l7.72-7.72 1.47 1.35L5.41 20z\"/>\n", | |
" </svg>\n", | |
" </button>\n", | |
" \n", | |
" <style>\n", | |
" .colab-df-container {\n", | |
" display:flex;\n", | |
" flex-wrap:wrap;\n", | |
" gap: 12px;\n", | |
" }\n", | |
"\n", | |
" .colab-df-convert {\n", | |
" background-color: #E8F0FE;\n", | |
" border: none;\n", | |
" border-radius: 50%;\n", | |
" cursor: pointer;\n", | |
" display: none;\n", | |
" fill: #1967D2;\n", | |
" height: 32px;\n", | |
" padding: 0 0 0 0;\n", | |
" width: 32px;\n", | |
" }\n", | |
"\n", | |
" .colab-df-convert:hover {\n", | |
" background-color: #E2EBFA;\n", | |
" box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n", | |
" fill: #174EA6;\n", | |
" }\n", | |
"\n", | |
" [theme=dark] .colab-df-convert {\n", | |
" background-color: #3B4455;\n", | |
" fill: #D2E3FC;\n", | |
" }\n", | |
"\n", | |
" [theme=dark] .colab-df-convert:hover {\n", | |
" background-color: #434B5C;\n", | |
" box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n", | |
" filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n", | |
" fill: #FFFFFF;\n", | |
" }\n", | |
" </style>\n", | |
"\n", | |
" <script>\n", | |
" const buttonEl =\n", | |
" document.querySelector('#df-46692847-1b09-4381-8b12-8c1f463665fb button.colab-df-convert');\n", | |
" buttonEl.style.display =\n", | |
" google.colab.kernel.accessAllowed ? 'block' : 'none';\n", | |
"\n", | |
" async function convertToInteractive(key) {\n", | |
" const element = document.querySelector('#df-46692847-1b09-4381-8b12-8c1f463665fb');\n", | |
" const dataTable =\n", | |
" await google.colab.kernel.invokeFunction('convertToInteractive',\n", | |
" [key], {});\n", | |
" if (!dataTable) return;\n", | |
"\n", | |
" const docLinkHtml = 'Like what you see? Visit the ' +\n", | |
" '<a target=\"_blank\" href=https://colab.research.google.com/notebooks/data_table.ipynb>data table notebook</a>'\n", | |
" + ' to learn more about interactive tables.';\n", | |
" element.innerHTML = '';\n", | |
" dataTable['output_type'] = 'display_data';\n", | |
" await google.colab.output.renderOutput(dataTable, element);\n", | |
" const docLink = document.createElement('div');\n", | |
" docLink.innerHTML = docLinkHtml;\n", | |
" element.appendChild(docLink);\n", | |
" }\n", | |
" </script>\n", | |
" </div>\n", | |
" </div>\n", | |
" " | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 4 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"dataset_1.corr()['median_house_value']" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "WRci842_fzQy", | |
"outputId": "39ce79cc-6b47-46f6-897f-3786df05a3cd" | |
}, | |
"execution_count": 5, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"longitude -0.044982\n", | |
"latitude -0.144917\n", | |
"housing_median_age 0.106758\n", | |
"total_rooms 0.130991\n", | |
"total_bedrooms 0.045783\n", | |
"population -0.027850\n", | |
"households 0.061031\n", | |
"median_income 0.691871\n", | |
"median_house_value 1.000000\n", | |
"Name: median_house_value, dtype: float64" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 5 | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"# Задача регрессии (линейная модель)\n", | |
"\n", | |
"[Лабораторная №7 - линейная регрессия](https://www.youtube.com/watch?v=txDLkiesqpY)\n" | |
], | |
"metadata": { | |
"id": "fG3lpBVZe3jm" | |
} | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"## Подготовка данных" | |
], | |
"metadata": { | |
"id": "72ujRJiXL5Le" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"X = dataset_1.loc[:, dataset_1.columns != 'median_house_value'].to_numpy()\n", | |
"y = dataset_1['median_house_value'].to_numpy()" | |
], | |
"metadata": { | |
"id": "AGwVELBnoPbC" | |
}, | |
"execution_count": 6, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"X" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "F3NsKaY3on3Y", | |
"outputId": "7d74e34d-a2b7-4f34-9c91-d148eea39f73" | |
}, | |
"execution_count": 7, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"array([[-114.31 , 34.19 , 15. , ..., 1015. , 472. ,\n", | |
" 1.4936],\n", | |
" [-114.47 , 34.4 , 19. , ..., 1129. , 463. ,\n", | |
" 1.82 ],\n", | |
" [-114.56 , 33.69 , 17. , ..., 333. , 117. ,\n", | |
" 1.6509],\n", | |
" ...,\n", | |
" [-124.3 , 41.84 , 17. , ..., 1244. , 456. ,\n", | |
" 3.0313],\n", | |
" [-124.3 , 41.8 , 19. , ..., 1298. , 478. ,\n", | |
" 1.9797],\n", | |
" [-124.35 , 40.54 , 52. , ..., 806. , 270. ,\n", | |
" 3.0147]])" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 7 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"y" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "EbchplGSprI9", | |
"outputId": "b75e5034-8143-4ba4-e884-a5cd017680fc" | |
}, | |
"execution_count": 8, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"array([ 66.9, 80.1, 85.7, ..., 103.6, 85.8, 94.6])" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 8 | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"## resampling (cross validation)\n", | |
"\n", | |
"\n", | |
"\n", | |
"https://scikit-learn.org/stable/modules/cross_validation.html" | |
], | |
"metadata": { | |
"id": "q1ojgDeYvFoU" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.train_test_split.html\n", | |
"\n" | |
], | |
"metadata": { | |
"id": "H0BUsnbypsUs" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"$\\mathrm{MSE}=\\frac{1}{n}\\sum_{i=1}^n(y_i-\\hat{y}_i)^2$" | |
], | |
"metadata": { | |
"id": "_DKeB_iO80ts" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=2)\n", | |
"print(\"Dims of train \", X_train.shape, y_train.shape)\n", | |
"print(\"Dims of test \", X_test.shape, y_test.shape)" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "EqbEwqE6qYtO", | |
"outputId": "c0bca208-5fb7-4056-a3ff-a28c5144ed38" | |
}, | |
"execution_count": 9, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"Dims of train (11900, 8) (11900,)\n", | |
"Dims of test (5100, 8) (5100,)\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.LinearRegression.html\n", | |
"# fit model step 1\n", | |
"\n", | |
"lin_reg = LinearRegression().fit(X_train, y_train)\n", | |
"\n", | |
"print(\"R^2 on train set {}\".format(lin_reg.score(X_train, y_train)))\n", | |
"print(\"R^2 on test set {}\".format(lin_reg.score(X_test, y_test)))\n", | |
"\n", | |
"print(\"MSE on train set {}\".format(\n", | |
" mean_squared_error(y_train, lin_reg.predict(X_train))\n", | |
" )\n", | |
")\n", | |
"print(\"MSE on test set {}\".format(\n", | |
" mean_squared_error(y_test, lin_reg.predict(X_test))\n", | |
" )\n", | |
")" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "bJX_Yqvjqtg3", | |
"outputId": "95a1e961-e6ab-4922-bd32-1e78a3d5d891" | |
}, | |
"execution_count": 10, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"R^2 on train set 0.6375851309751237\n", | |
"R^2 on test set 0.6492538345819212\n", | |
"MSE on train set 4803.095967423598\n", | |
"MSE on test set 4880.3979596385025\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"lin_reg.coef_, lin_reg.intercept_" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "D5gbhax50OQS", | |
"outputId": "e3253da9-6531-40ff-c30f-69748eed9e13" | |
}, | |
"execution_count": 11, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"(array([-4.29214290e+01, -4.24176134e+01, 1.16110862e+00, -8.98362849e-03,\n", | |
" 1.15185032e-01, -3.66694219e-02, 4.69730379e-02, 4.03237644e+01]),\n", | |
" -3612.7404586294624)" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 11 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# resampling k-fold cross validation\n", | |
"# https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.cross_val_score.html#sklearn.model_selection.cross_val_score\n", | |
"#\n", | |
"# NOTE: with an integer cv and a regressor, cross_val_score splits with KFold WITHOUT shuffling.\n", | |
"# The CSV rows appear to be ordered by longitude, so each fold is a contiguous geographic slice;\n", | |
"# this is a likely cause of the large spread between the fold scores below.\n", | |
"# Passing cv=KFold(n_splits=5, shuffle=True, random_state=...) gives randomly mixed folds.\n", | |
"\n", | |
"# fit model step 2\n", | |
"\n", | |
"lin_reg_2 = LinearRegression()\n", | |
"\n", | |
"scores = cross_val_score(lin_reg_2, X, y, cv=5)\n", | |
"scores" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "mfIRTg8ysVGY", | |
"outputId": "45aa6ffb-dae2-4e0e-cbc1-328ddb73aee6" | |
}, | |
"execution_count": 12, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"array([0.52201778, 0.56428342, 0.60260364, 0.39679821, 0.65576199])" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 12 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"print(\"%0.2f R^2 with a standard deviation of %0.2f\" % (scores.mean(), scores.std()))" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "_fwlqMYtui16", | |
"outputId": "11895a96-3cec-44dd-8dbe-a93d525f956c" | |
}, | |
"execution_count": 13, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"0.55 R^2 with a standard deviation of 0.09\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# scoring='neg_mean_squared_error' returns the NEGATED MSE (sklearn scorers follow\n", | |
"# 'higher is better'), so flip the sign back before reporting it as MSE.\n", | |
"lin_reg_2 = LinearRegression()\n", | |
"scores = cross_val_score(lin_reg_2, X, y, cv=5, scoring='neg_mean_squared_error')\n", | |
"mse_scores = -scores\n", | |
"print(\"%0.2f MSE with a standard deviation of %0.2f\" % (mse_scores.mean(), mse_scores.std()))" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "s_5kGTf2-KY5", | |
"outputId": "a4aa2692-5bbb-490a-e198-471b2aab5351" | |
}, | |
"execution_count": 14, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"5352.33 MSE with a standard deviation of 1329.57\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"### Регуляризация в задаче линейной регрессии\n", | |
"\n", | |
"Пусть $X$ — данные в матричной форме (столбец из единиц соответствует свободному члену $\\beta_0$):\n", | |
"\n", | |
"$$\\bf{X}=\\left[ \\begin{matrix} 1 & X_{11} & X_{12} & ... & X_{1k}\\\\ 1& X_{21} & X_{22} & ... & X_{2k} \\\\ ... & ... & ... & ... & ... \\\\ 1 & X_{n1} & X_{n2} & ... & X_{nk} \\end{matrix} \\right]$$\n", | |
"\n", | |
"$Y$, $\\beta$, $\\epsilon$ — зависимая переменная, коэффициенты модели и вектор ошибок соответственно:\n", | |
"\n", | |
"$$\\bf{Y}=\\left[ \\begin{matrix} y_1 \\\\ y_2 \\\\ ... \\\\ y_n \\end{matrix} \\right], \\bf{\\beta}=\\left[ \\begin{matrix} \\beta_0 \\\\ \\beta_1 \\\\ ... \\\\ \\beta_k \\end{matrix} \\right], \\bf{\\epsilon}=\\left[ \\begin{matrix} \\epsilon_1 \\\\ \\epsilon_2 \\\\ ... \\\\ \\epsilon_n \\end{matrix} \\right]$$\n", | |
"\n", | |
"Тогда модель линейной регрессии в общем виде: \n", | |
"\n", | |
"$$Y = \\bf{X \\beta + \\epsilon}$$\n", | |
"\n", | |
"Пусть $\\hat{y}=f(X)=X\\hat{\\beta}$ - оценка параметров $\\beta$ для линейной модели.\n", | |
"\n", | |
"Зафиксируем целевую функцию (objective function): \n", | |
"\n", | |
"$$\\mathrm{RSS}(\\beta)=\\sum_{i=1}^n(y_i - f(X_i))^2=(y-X\\beta)^T(y-X\\beta)$$\n", | |
"\n", | |
"Тогда $\\hat{\\beta}^{\\mathrm{OLS}}=\\mathrm{argmin}_\\beta\\{\\sum_{i=1}^n(y_i - f(X_i))^2\\}$ оценка параметров линейной модели методом наименьших квадратов.\n", | |
"\n", | |
"$$\\hat{\\beta}^{\\mathrm{OLS}}=(X^TX)^{-1}X^Ty$$ - решение в общем виде для оценки методом наименьших кварратов.\n", | |
"\n", | |
"### Ridge регуляризация\n", | |
"\n", | |
"Пусть $\\mathrm{RSS}(\\beta)=\\sum_{i=1}^n(y_i - f(X_i))^2+\\lambda\\sum_{i=1}^k\\beta_j^2=(y-X\\beta)^T(y-X\\beta)+\\lambda\\beta^T\\beta$ - целевая функция с параметром регуляризации $\\lambda\\beta^T\\beta$, тогда\n", | |
"\n", | |
"$\\hat{\\beta}^{\\mathrm{ridge}}=\\mathrm{argmin}_\\beta\\{\\sum_{i=1}^n(y_i - f(X_i))^2+\\lambda\\sum_{i=1}^k\\beta_j^2\\}$ оценка параметров линейной модели c $L_2$ регуляризацией.\n", | |
"\n", | |
"$$\\hat{\\beta}^{\\mathrm{ridge}}=(X^TX+\\lambda I)^{-1}X^Ty$$ - решение в общем виде для оценки c $L_2$ регуляризацией.\n", | |
"\n", | |
"### LASSO регуляризация\n", | |
"\n", | |
"Пусть $\\mathrm{RSS}(\\beta)=\\sum_{i=1}^n(y_i - f(X_i))^2+\\lambda\\sum_{i=1}^k|\\beta_j|=(y-X\\beta)^T(y-X\\beta)+\\lambda||\\beta||_1$ - целевая функция с парамтером регуляризации $\\lambda||\\beta||_1$, тогда\n", | |
"\n", | |
"$\\hat{\\beta}^{\\mathrm{lasso}}=\\mathrm{argmin}_\\beta\\{\\frac{1}{2}\\sum_{i=1}^n(y_i - f(X_i))^2+\\lambda\\sum_{i=1}^k|\\beta_j|\\}$ оценка параметров линейной модели c $L_1$ регуляризацией.\n", | |
"\n", | |
"Ддя $\\hat{\\beta}^{\\mathrm{lasso}}$ не существут решения в общем виде, однако, оптимальная оценка может быть найдена методами оптимизации квадратичных функций." | |
], | |
"metadata": { | |
"id": "mjdeEgVEgj9X" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"#https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.Ridge.html\n", | |
"\n", | |
"# fit model Ridge\n", | |
"\n", | |
"ridge_reg_1 = Ridge(alpha=1.0, solver='cholesky')\n", | |
"scores = cross_val_score(ridge_reg_1, X, y, cv=5, scoring='r2')\n", | |
"print(\"%0.2f R^2 with a standard deviation of %0.2f\" % (scores.mean(), scores.std()))\n", | |
"\n", | |
"ridge_reg_2 = Ridge(alpha=1.0, solver='cholesky')\n", | |
"scores = cross_val_score(ridge_reg_2, X, y, cv=5, scoring='neg_mean_squared_error')\n", | |
"print(\"%0.2f MSE with a standard deviation of %0.2f\" % (scores.mean(), scores.std()))\n", | |
"print(\"\\n\")\n", | |
"\n", | |
"print(\"--------------- RIDGE coeffs alpha=1.0\")\n", | |
"ridge_reg_3 = Ridge(alpha=1.0, solver='cholesky').fit(X_train,y_train)\n", | |
"print(ridge_reg_3.coef_, ridge_reg_3.intercept_, \"\\n\")\n", | |
"\n", | |
"\n", | |
"print(\"--------------- RIDGE coeffs alpha=10.0\")\n", | |
"ridge_reg_4 = Ridge(alpha=10.0, solver='cholesky').fit(X_train,y_train)\n", | |
"print(ridge_reg_4.coef_, ridge_reg_4.intercept_, \"\\n\")\n", | |
"\n", | |
"print(\"--------------- RIDGE coeffs alpha=100.0\")\n", | |
"ridge_reg_5 = Ridge(alpha=100.0, solver='cholesky').fit(X_train,y_train)\n", | |
"print(ridge_reg_5.coef_, ridge_reg_5.intercept_, \"\\n\")\n", | |
"\n", | |
"print(\"--------------- OLS coeffs\")\n", | |
"print(lin_reg.coef_, lin_reg.intercept_)" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "fNwBKze-7lqL", | |
"outputId": "65c5efa7-1e14-4a99-8c89-76c51d386563" | |
}, | |
"execution_count": 15, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"0.55 R^2 with a standard deviation of 0.09\n", | |
"-5352.03 MSE with a standard deviation of 1329.88\n", | |
"\n", | |
"\n", | |
"--------------- RIDGE coeffs alpha=1.0\n", | |
"[-4.29074407e+01 -4.24045336e+01 1.16133727e+00 -8.98484030e-03\n", | |
" 1.15166492e-01 -3.66696731e-02 4.70033197e-02 4.03245856e+01] -3611.5452675178294 \n", | |
"\n", | |
"--------------- RIDGE coeffs alpha=10.0\n", | |
"[-4.27819790e+01 -4.22872240e+01 1.16338783e+00 -8.99564254e-03\n", | |
" 1.14999813e-01 -3.66719551e-02 4.72750721e-02 4.03319039e+01] -3600.825268345685 \n", | |
"\n", | |
"--------------- RIDGE coeffs alpha=100.0\n", | |
"[-4.15688113e+01 -4.11531876e+01 1.18319623e+00 -9.09373770e-03\n", | |
" 1.13350796e-01 -3.66968047e-02 4.99175259e-02 4.03981251e+01] -3497.137932761225 \n", | |
"\n", | |
"--------------- OLS coeffs\n", | |
"[-4.29214290e+01 -4.24176134e+01 1.16110862e+00 -8.98362849e-03\n", | |
" 1.15185032e-01 -3.66694219e-02 4.69730379e-02 4.03237644e+01] -3612.7404586294624\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"#https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.Lasso.html\n", | |
"\n", | |
"# fit model Lasso\n", | |
"\n", | |
"lasso_reg_1 = Lasso(alpha=1.0)\n", | |
"scores = cross_val_score(lasso_reg_1, X, y, cv=5, scoring='r2')\n", | |
"print(\"%0.2f R^2 with a standard deviation of %0.2f\" % (scores.mean(), scores.std()))\n", | |
"\n", | |
"lasso_reg_2 = Lasso(alpha=1.0)\n", | |
"scores = cross_val_score(lasso_reg_2, X, y, cv=5, scoring='neg_mean_squared_error')\n", | |
"print(\"%0.2f MSE with a standard deviation of %0.2f\" % (scores.mean(), scores.std()))\n", | |
"print(\"\\n\")\n", | |
"\n", | |
"print(\"--------------- LASSO coeffs alpha=1.0\")\n", | |
"lasso_reg_3 = Lasso(alpha=1.0).fit(X_train,y_train)\n", | |
"print(lasso_reg_3.coef_, lasso_reg_3.intercept_, \"\\n\")\n", | |
"\n", | |
"print(\"--------------- LASSO coeffs alpha=10.0\")\n", | |
"lasso_reg_4 = Lasso(alpha=10.0).fit(X_train,y_train)\n", | |
"print(lasso_reg_4.coef_, lasso_reg_4.intercept_, \"\\n\")\n", | |
"\n", | |
"print(\"--------------- LASSO coeffs alpha=100.0\")\n", | |
"lasso_reg_5 = Lasso(alpha=100.0).fit(X_train,y_train)\n", | |
"print(lasso_reg_5.coef_, lasso_reg_5.intercept_, \"\\n\")\n", | |
"\n", | |
"print(\"--------------- OLS coeffs\")\n", | |
"print(lin_reg.coef_, lin_reg.intercept_)" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "Rn6JohJ7GqM9", | |
"outputId": "f4c099b8-819b-4c42-8931-bc6c33ba81af" | |
}, | |
"execution_count": 16, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"0.56 R^2 with a standard deviation of 0.07\n", | |
"-5318.24 MSE with a standard deviation of 1427.98\n", | |
"\n", | |
"\n", | |
"--------------- LASSO coeffs alpha=1.0\n", | |
"[-3.90786668e+01 -3.88232746e+01 1.21680467e+00 -9.28340186e-03\n", | |
" 1.09766158e-01 -3.67483260e-02 5.54222354e-02 4.05165803e+01] -3284.0806183661234 \n", | |
"\n", | |
"--------------- LASSO coeffs alpha=10.0\n", | |
"[-4.49374412e+00 -6.47412275e+00 1.71806994e+00 -1.19820710e-02\n", | |
" 6.10033331e-02 -3.74577344e-02 1.31459141e-01 4.22521234e+01] -326.13902350915623 \n", | |
"\n", | |
"--------------- LASSO coeffs alpha=100.0\n", | |
"[-0. -0. 1.00948049 0.03475236 -0.21063508 -0.05565397\n", | |
" 0.22692586 6.0528728 ] 142.71153243167387 \n", | |
"\n", | |
"--------------- OLS coeffs\n", | |
"[-4.29214290e+01 -4.24176134e+01 1.16110862e+00 -8.98362849e-03\n", | |
" 1.15185032e-01 -3.66694219e-02 4.69730379e-02 4.03237644e+01] -3612.7404586294624\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"## Конструирование признаков (Feature engineering)" | |
], | |
"metadata": { | |
"id": "EqxNm2YENBIB" | |
} | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"### Rescale" | |
], | |
"metadata": { | |
"id": "Vt4pawzSPYIF" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"#https://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.MaxAbsScaler.html\n", | |
"abs_norm = MaxAbsScaler().fit(X)\n", | |
"X_normalized = abs_norm.transform(X)\n", | |
"X_normalized" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "JQRsDv-NJIRs", | |
"outputId": "70645f80-bca8-48ee-dcdd-40fe1aab19b6" | |
}, | |
"execution_count": 17, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"array([[-0.91926015, 0.81501788, 0.28846154, ..., 0.02844571,\n", | |
" 0.07760605, 0.09957267],\n", | |
" [-0.92054684, 0.82002384, 0.36538462, ..., 0.0316406 ,\n", | |
" 0.07612627, 0.12133252],\n", | |
" [-0.92127061, 0.80309893, 0.32692308, ..., 0.00933244,\n", | |
" 0.01923709, 0.11005927],\n", | |
" ...,\n", | |
" [-0.99959791, 0.99737783, 0.32692308, ..., 0.03486352,\n", | |
" 0.07497534, 0.20208532],\n", | |
" [-0.99959791, 0.99642431, 0.36538462, ..., 0.03637688,\n", | |
" 0.07859257, 0.13197912],\n", | |
" [-1. , 0.96638856, 1. , ..., 0.02258842,\n", | |
" 0.04439329, 0.20097866]])" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 17 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"lin_reg_3 = LinearRegression()\n", | |
"\n", | |
"scores = cross_val_score(lin_reg_3, X_normalized, y, cv=5, scoring='r2')\n", | |
"print(\"%0.2f R^2 with a standard deviation of %0.2f\" % (scores.mean(), scores.std()))\n", | |
"\n", | |
"scores = cross_val_score(lin_reg_3, X_normalized, y, cv=5, scoring='neg_mean_squared_error')\n", | |
"print(\"%0.2f MSE with a standard deviation of %0.2f\" % (scores.mean(), scores.std()))\n", | |
"\n" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "LEXldutWNYit", | |
"outputId": "a10a0bf3-4004-44ca-8949-a6dbd02fcdf9" | |
}, | |
"execution_count": 18, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"0.55 R^2 with a standard deviation of 0.09\n", | |
"-5352.33 MSE with a standard deviation of 1329.57\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"### Nonlinearity" | |
], | |
"metadata": { | |
"id": "BgbHHmq_Pd84" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"dataset_1.describe()" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 300 | |
}, | |
"id": "Gw3ZfTppPUER", | |
"outputId": "f89e7245-916e-48b8-9a45-e850ba016d8c" | |
}, | |
"execution_count": 19, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
" longitude latitude housing_median_age total_rooms \\\n", | |
"count 17000.000000 17000.000000 17000.000000 17000.000000 \n", | |
"mean -119.562108 35.625225 28.589353 2643.664412 \n", | |
"std 2.005166 2.137340 12.586937 2179.947071 \n", | |
"min -124.350000 32.540000 1.000000 2.000000 \n", | |
"25% -121.790000 33.930000 18.000000 1462.000000 \n", | |
"50% -118.490000 34.250000 29.000000 2127.000000 \n", | |
"75% -118.000000 37.720000 37.000000 3151.250000 \n", | |
"max -114.310000 41.950000 52.000000 37937.000000 \n", | |
"\n", | |
" total_bedrooms population households median_income \\\n", | |
"count 17000.000000 17000.000000 17000.000000 17000.000000 \n", | |
"mean 539.410824 1429.573941 501.221941 3.883578 \n", | |
"std 421.499452 1147.852959 384.520841 1.908157 \n", | |
"min 1.000000 3.000000 1.000000 0.499900 \n", | |
"25% 297.000000 790.000000 282.000000 2.566375 \n", | |
"50% 434.000000 1167.000000 409.000000 3.544600 \n", | |
"75% 648.250000 1721.000000 605.250000 4.767000 \n", | |
"max 6445.000000 35682.000000 6082.000000 15.000100 \n", | |
"\n", | |
" median_house_value \n", | |
"count 17000.000000 \n", | |
"mean 207.300912 \n", | |
"std 115.983764 \n", | |
"min 14.999000 \n", | |
"25% 119.400000 \n", | |
"50% 180.400000 \n", | |
"75% 265.000000 \n", | |
"max 500.001000 " | |
], | |
"text/html": [ | |
"\n", | |
" <div id=\"df-f8bc8975-24c0-4138-91f5-42f08bbd54cb\">\n", | |
" <div class=\"colab-df-container\">\n", | |
" <div>\n", | |
"<style scoped>\n", | |
" .dataframe tbody tr th:only-of-type {\n", | |
" vertical-align: middle;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>longitude</th>\n", | |
" <th>latitude</th>\n", | |
" <th>housing_median_age</th>\n", | |
" <th>total_rooms</th>\n", | |
" <th>total_bedrooms</th>\n", | |
" <th>population</th>\n", | |
" <th>households</th>\n", | |
" <th>median_income</th>\n", | |
" <th>median_house_value</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>count</th>\n", | |
" <td>17000.000000</td>\n", | |
" <td>17000.000000</td>\n", | |
" <td>17000.000000</td>\n", | |
" <td>17000.000000</td>\n", | |
" <td>17000.000000</td>\n", | |
" <td>17000.000000</td>\n", | |
" <td>17000.000000</td>\n", | |
" <td>17000.000000</td>\n", | |
" <td>17000.000000</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>mean</th>\n", | |
" <td>-119.562108</td>\n", | |
" <td>35.625225</td>\n", | |
" <td>28.589353</td>\n", | |
" <td>2643.664412</td>\n", | |
" <td>539.410824</td>\n", | |
" <td>1429.573941</td>\n", | |
" <td>501.221941</td>\n", | |
" <td>3.883578</td>\n", | |
" <td>207.300912</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>std</th>\n", | |
" <td>2.005166</td>\n", | |
" <td>2.137340</td>\n", | |
" <td>12.586937</td>\n", | |
" <td>2179.947071</td>\n", | |
" <td>421.499452</td>\n", | |
" <td>1147.852959</td>\n", | |
" <td>384.520841</td>\n", | |
" <td>1.908157</td>\n", | |
" <td>115.983764</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>min</th>\n", | |
" <td>-124.350000</td>\n", | |
" <td>32.540000</td>\n", | |
" <td>1.000000</td>\n", | |
" <td>2.000000</td>\n", | |
" <td>1.000000</td>\n", | |
" <td>3.000000</td>\n", | |
" <td>1.000000</td>\n", | |
" <td>0.499900</td>\n", | |
" <td>14.999000</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>25%</th>\n", | |
" <td>-121.790000</td>\n", | |
" <td>33.930000</td>\n", | |
" <td>18.000000</td>\n", | |
" <td>1462.000000</td>\n", | |
" <td>297.000000</td>\n", | |
" <td>790.000000</td>\n", | |
" <td>282.000000</td>\n", | |
" <td>2.566375</td>\n", | |
" <td>119.400000</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>50%</th>\n", | |
" <td>-118.490000</td>\n", | |
" <td>34.250000</td>\n", | |
" <td>29.000000</td>\n", | |
" <td>2127.000000</td>\n", | |
" <td>434.000000</td>\n", | |
" <td>1167.000000</td>\n", | |
" <td>409.000000</td>\n", | |
" <td>3.544600</td>\n", | |
" <td>180.400000</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>75%</th>\n", | |
" <td>-118.000000</td>\n", | |
" <td>37.720000</td>\n", | |
" <td>37.000000</td>\n", | |
" <td>3151.250000</td>\n", | |
" <td>648.250000</td>\n", | |
" <td>1721.000000</td>\n", | |
" <td>605.250000</td>\n", | |
" <td>4.767000</td>\n", | |
" <td>265.000000</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>max</th>\n", | |
" <td>-114.310000</td>\n", | |
" <td>41.950000</td>\n", | |
" <td>52.000000</td>\n", | |
" <td>37937.000000</td>\n", | |
" <td>6445.000000</td>\n", | |
" <td>35682.000000</td>\n", | |
" <td>6082.000000</td>\n", | |
" <td>15.000100</td>\n", | |
" <td>500.001000</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>\n", | |
" <button class=\"colab-df-convert\" onclick=\"convertToInteractive('df-f8bc8975-24c0-4138-91f5-42f08bbd54cb')\"\n", | |
" title=\"Convert this dataframe to an interactive table.\"\n", | |
" style=\"display:none;\">\n", | |
" \n", | |
" <svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n", | |
" width=\"24px\">\n", | |
" <path d=\"M0 0h24v24H0V0z\" fill=\"none\"/>\n", | |
" <path d=\"M18.56 5.44l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94zm-11 1L8.5 8.5l.94-2.06 2.06-.94-2.06-.94L8.5 2.5l-.94 2.06-2.06.94zm10 10l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94z\"/><path d=\"M17.41 7.96l-1.37-1.37c-.4-.4-.92-.59-1.43-.59-.52 0-1.04.2-1.43.59L10.3 9.45l-7.72 7.72c-.78.78-.78 2.05 0 2.83L4 21.41c.39.39.9.59 1.41.59.51 0 1.02-.2 1.41-.59l7.78-7.78 2.81-2.81c.8-.78.8-2.07 0-2.86zM5.41 20L4 18.59l7.72-7.72 1.47 1.35L5.41 20z\"/>\n", | |
" </svg>\n", | |
" </button>\n", | |
" \n", | |
" <style>\n", | |
" .colab-df-container {\n", | |
" display:flex;\n", | |
" flex-wrap:wrap;\n", | |
" gap: 12px;\n", | |
" }\n", | |
"\n", | |
" .colab-df-convert {\n", | |
" background-color: #E8F0FE;\n", | |
" border: none;\n", | |
" border-radius: 50%;\n", | |
" cursor: pointer;\n", | |
" display: none;\n", | |
" fill: #1967D2;\n", | |
" height: 32px;\n", | |
" padding: 0 0 0 0;\n", | |
" width: 32px;\n", | |
" }\n", | |
"\n", | |
" .colab-df-convert:hover {\n", | |
" background-color: #E2EBFA;\n", | |
" box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n", | |
" fill: #174EA6;\n", | |
" }\n", | |
"\n", | |
" [theme=dark] .colab-df-convert {\n", | |
" background-color: #3B4455;\n", | |
" fill: #D2E3FC;\n", | |
" }\n", | |
"\n", | |
" [theme=dark] .colab-df-convert:hover {\n", | |
" background-color: #434B5C;\n", | |
" box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n", | |
" filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n", | |
" fill: #FFFFFF;\n", | |
" }\n", | |
" </style>\n", | |
"\n", | |
" <script>\n", | |
" const buttonEl =\n", | |
" document.querySelector('#df-f8bc8975-24c0-4138-91f5-42f08bbd54cb button.colab-df-convert');\n", | |
" buttonEl.style.display =\n", | |
" google.colab.kernel.accessAllowed ? 'block' : 'none';\n", | |
"\n", | |
" async function convertToInteractive(key) {\n", | |
" const element = document.querySelector('#df-f8bc8975-24c0-4138-91f5-42f08bbd54cb');\n", | |
" const dataTable =\n", | |
" await google.colab.kernel.invokeFunction('convertToInteractive',\n", | |
" [key], {});\n", | |
" if (!dataTable) return;\n", | |
"\n", | |
" const docLinkHtml = 'Like what you see? Visit the ' +\n", | |
" '<a target=\"_blank\" href=https://colab.research.google.com/notebooks/data_table.ipynb>data table notebook</a>'\n", | |
" + ' to learn more about interactive tables.';\n", | |
" element.innerHTML = '';\n", | |
" dataTable['output_type'] = 'display_data';\n", | |
" await google.colab.output.renderOutput(dataTable, element);\n", | |
" const docLink = document.createElement('div');\n", | |
" docLink.innerHTML = docLinkHtml;\n", | |
" element.appendChild(docLink);\n", | |
" }\n", | |
" </script>\n", | |
" </div>\n", | |
" </div>\n", | |
" " | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 19 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"dataset_1['total_rooms'] = np.log(dataset_1['total_rooms'])\n", | |
"dataset_1['total_bedrooms'] = np.log(dataset_1['total_bedrooms'])\n", | |
"dataset_1['population'] = np.log(dataset_1['population'])\n", | |
"dataset_1['households'] = np.log(dataset_1['households'])\n", | |
"dataset_1.describe()" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 300 | |
}, | |
"id": "iEVcRnCAPrSB", | |
"outputId": "24ac541a-ab30-47b4-c5d6-5851764d8187" | |
}, | |
"execution_count": 20, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
" longitude latitude housing_median_age total_rooms \\\n", | |
"count 17000.000000 17000.000000 17000.000000 17000.000000 \n", | |
"mean -119.562108 35.625225 28.589353 7.634250 \n", | |
"std 2.005166 2.137340 12.586937 0.742704 \n", | |
"min -124.350000 32.540000 1.000000 0.693147 \n", | |
"25% -121.790000 33.930000 18.000000 7.287561 \n", | |
"50% -118.490000 34.250000 29.000000 7.662468 \n", | |
"75% -118.000000 37.720000 37.000000 8.055554 \n", | |
"max -114.310000 41.950000 52.000000 10.543682 \n", | |
"\n", | |
" total_bedrooms population households median_income \\\n", | |
"count 17000.000000 17000.000000 17000.000000 17000.000000 \n", | |
"mean 6.055277 7.027439 5.984960 3.883578 \n", | |
"std 0.726624 0.732922 0.727605 1.908157 \n", | |
"min 0.000000 1.098612 0.000000 0.499900 \n", | |
"25% 5.693732 6.672033 5.641907 2.566375 \n", | |
"50% 6.073045 7.062192 6.013715 3.544600 \n", | |
"75% 6.474276 7.450661 6.405641 4.767000 \n", | |
"max 8.771060 10.482402 8.713089 15.000100 \n", | |
"\n", | |
" median_house_value \n", | |
"count 17000.000000 \n", | |
"mean 207.300912 \n", | |
"std 115.983764 \n", | |
"min 14.999000 \n", | |
"25% 119.400000 \n", | |
"50% 180.400000 \n", | |
"75% 265.000000 \n", | |
"max 500.001000 " | |
], | |
"text/html": [ | |
"\n", | |
" <div id=\"df-203b3666-8591-4c23-baab-3556e84920aa\">\n", | |
" <div class=\"colab-df-container\">\n", | |
" <div>\n", | |
"<style scoped>\n", | |
" .dataframe tbody tr th:only-of-type {\n", | |
" vertical-align: middle;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>longitude</th>\n", | |
" <th>latitude</th>\n", | |
" <th>housing_median_age</th>\n", | |
" <th>total_rooms</th>\n", | |
" <th>total_bedrooms</th>\n", | |
" <th>population</th>\n", | |
" <th>households</th>\n", | |
" <th>median_income</th>\n", | |
" <th>median_house_value</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>count</th>\n", | |
" <td>17000.000000</td>\n", | |
" <td>17000.000000</td>\n", | |
" <td>17000.000000</td>\n", | |
" <td>17000.000000</td>\n", | |
" <td>17000.000000</td>\n", | |
" <td>17000.000000</td>\n", | |
" <td>17000.000000</td>\n", | |
" <td>17000.000000</td>\n", | |
" <td>17000.000000</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>mean</th>\n", | |
" <td>-119.562108</td>\n", | |
" <td>35.625225</td>\n", | |
" <td>28.589353</td>\n", | |
" <td>7.634250</td>\n", | |
" <td>6.055277</td>\n", | |
" <td>7.027439</td>\n", | |
" <td>5.984960</td>\n", | |
" <td>3.883578</td>\n", | |
" <td>207.300912</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>std</th>\n", | |
" <td>2.005166</td>\n", | |
" <td>2.137340</td>\n", | |
" <td>12.586937</td>\n", | |
" <td>0.742704</td>\n", | |
" <td>0.726624</td>\n", | |
" <td>0.732922</td>\n", | |
" <td>0.727605</td>\n", | |
" <td>1.908157</td>\n", | |
" <td>115.983764</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>min</th>\n", | |
" <td>-124.350000</td>\n", | |
" <td>32.540000</td>\n", | |
" <td>1.000000</td>\n", | |
" <td>0.693147</td>\n", | |
" <td>0.000000</td>\n", | |
" <td>1.098612</td>\n", | |
" <td>0.000000</td>\n", | |
" <td>0.499900</td>\n", | |
" <td>14.999000</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>25%</th>\n", | |
" <td>-121.790000</td>\n", | |
" <td>33.930000</td>\n", | |
" <td>18.000000</td>\n", | |
" <td>7.287561</td>\n", | |
" <td>5.693732</td>\n", | |
" <td>6.672033</td>\n", | |
" <td>5.641907</td>\n", | |
" <td>2.566375</td>\n", | |
" <td>119.400000</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>50%</th>\n", | |
" <td>-118.490000</td>\n", | |
" <td>34.250000</td>\n", | |
" <td>29.000000</td>\n", | |
" <td>7.662468</td>\n", | |
" <td>6.073045</td>\n", | |
" <td>7.062192</td>\n", | |
" <td>6.013715</td>\n", | |
" <td>3.544600</td>\n", | |
" <td>180.400000</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>75%</th>\n", | |
" <td>-118.000000</td>\n", | |
" <td>37.720000</td>\n", | |
" <td>37.000000</td>\n", | |
" <td>8.055554</td>\n", | |
" <td>6.474276</td>\n", | |
" <td>7.450661</td>\n", | |
" <td>6.405641</td>\n", | |
" <td>4.767000</td>\n", | |
" <td>265.000000</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>max</th>\n", | |
" <td>-114.310000</td>\n", | |
" <td>41.950000</td>\n", | |
" <td>52.000000</td>\n", | |
" <td>10.543682</td>\n", | |
" <td>8.771060</td>\n", | |
" <td>10.482402</td>\n", | |
" <td>8.713089</td>\n", | |
" <td>15.000100</td>\n", | |
" <td>500.001000</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>\n", | |
" <button class=\"colab-df-convert\" onclick=\"convertToInteractive('df-203b3666-8591-4c23-baab-3556e84920aa')\"\n", | |
" title=\"Convert this dataframe to an interactive table.\"\n", | |
" style=\"display:none;\">\n", | |
" \n", | |
" <svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n", | |
" width=\"24px\">\n", | |
" <path d=\"M0 0h24v24H0V0z\" fill=\"none\"/>\n", | |
" <path d=\"M18.56 5.44l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94zm-11 1L8.5 8.5l.94-2.06 2.06-.94-2.06-.94L8.5 2.5l-.94 2.06-2.06.94zm10 10l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94z\"/><path d=\"M17.41 7.96l-1.37-1.37c-.4-.4-.92-.59-1.43-.59-.52 0-1.04.2-1.43.59L10.3 9.45l-7.72 7.72c-.78.78-.78 2.05 0 2.83L4 21.41c.39.39.9.59 1.41.59.51 0 1.02-.2 1.41-.59l7.78-7.78 2.81-2.81c.8-.78.8-2.07 0-2.86zM5.41 20L4 18.59l7.72-7.72 1.47 1.35L5.41 20z\"/>\n", | |
" </svg>\n", | |
" </button>\n", | |
" \n", | |
" <style>\n", | |
" .colab-df-container {\n", | |
" display:flex;\n", | |
" flex-wrap:wrap;\n", | |
" gap: 12px;\n", | |
" }\n", | |
"\n", | |
" .colab-df-convert {\n", | |
" background-color: #E8F0FE;\n", | |
" border: none;\n", | |
" border-radius: 50%;\n", | |
" cursor: pointer;\n", | |
" display: none;\n", | |
" fill: #1967D2;\n", | |
" height: 32px;\n", | |
" padding: 0 0 0 0;\n", | |
" width: 32px;\n", | |
" }\n", | |
"\n", | |
" .colab-df-convert:hover {\n", | |
" background-color: #E2EBFA;\n", | |
" box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n", | |
" fill: #174EA6;\n", | |
" }\n", | |
"\n", | |
" [theme=dark] .colab-df-convert {\n", | |
" background-color: #3B4455;\n", | |
" fill: #D2E3FC;\n", | |
" }\n", | |
"\n", | |
" [theme=dark] .colab-df-convert:hover {\n", | |
" background-color: #434B5C;\n", | |
" box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n", | |
" filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n", | |
" fill: #FFFFFF;\n", | |
" }\n", | |
" </style>\n", | |
"\n", | |
" <script>\n", | |
" const buttonEl =\n", | |
" document.querySelector('#df-203b3666-8591-4c23-baab-3556e84920aa button.colab-df-convert');\n", | |
" buttonEl.style.display =\n", | |
" google.colab.kernel.accessAllowed ? 'block' : 'none';\n", | |
"\n", | |
" async function convertToInteractive(key) {\n", | |
" const element = document.querySelector('#df-203b3666-8591-4c23-baab-3556e84920aa');\n", | |
" const dataTable =\n", | |
" await google.colab.kernel.invokeFunction('convertToInteractive',\n", | |
" [key], {});\n", | |
" if (!dataTable) return;\n", | |
"\n", | |
" const docLinkHtml = 'Like what you see? Visit the ' +\n", | |
" '<a target=\"_blank\" href=https://colab.research.google.com/notebooks/data_table.ipynb>data table notebook</a>'\n", | |
" + ' to learn more about interactive tables.';\n", | |
" element.innerHTML = '';\n", | |
" dataTable['output_type'] = 'display_data';\n", | |
" await google.colab.output.renderOutput(dataTable, element);\n", | |
" const docLink = document.createElement('div');\n", | |
" docLink.innerHTML = docLinkHtml;\n", | |
" element.appendChild(docLink);\n", | |
" }\n", | |
" </script>\n", | |
" </div>\n", | |
" </div>\n", | |
" " | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 20 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"X_log = dataset_1.loc[:, dataset_1.columns != 'median_house_value'].to_numpy()\n", | |
"X_log" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "w_oNZW7yOegy", | |
"outputId": "58f52926-788e-4fc6-b2d5-d592ff65f380" | |
}, | |
"execution_count": 21, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"array([[-114.31 , 34.19 , 15. , ..., 6.92264389,\n", | |
" 6.15697899, 1.4936 ],\n", | |
" [-114.47 , 34.4 , 19. , ..., 7.02908756,\n", | |
" 6.13772705, 1.82 ],\n", | |
" [-114.56 , 33.69 , 17. , ..., 5.80814249,\n", | |
" 4.76217393, 1.6509 ],\n", | |
" ...,\n", | |
" [-124.3 , 41.84 , 17. , ..., 7.12608727,\n", | |
" 6.12249281, 3.0313 ],\n", | |
" [-124.3 , 41.8 , 19. , ..., 7.1685799 ,\n", | |
" 6.16961073, 1.9797 ],\n", | |
" [-124.35 , 40.54 , 52. , ..., 6.69208374,\n", | |
" 5.59842196, 3.0147 ]])" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 21 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"lin_reg_3 = LinearRegression()\n", | |
"\n", | |
"scores = cross_val_score(lin_reg_3, X_log, y, cv=5, scoring='r2')\n", | |
"print(\"%0.2f R^2 with a standard deviation of %0.2f\" % (scores.mean(), scores.std()))\n", | |
"\n", | |
"scores = cross_val_score(lin_reg_3, X_log, y, cv=5, scoring='neg_mean_squared_error')\n", | |
"print(\"%0.2f MSE with a standard deviation of %0.2f\" % (scores.mean(), scores.std()))" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "GcjUzXfFO0K2", | |
"outputId": "ebe9df39-e4f2-4c3e-c610-d49036b2f3a9" | |
}, | |
"execution_count": 22, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"0.59 R^2 with a standard deviation of 0.08\n", | |
"-4880.92 MSE with a standard deviation of 1192.24\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"# Задача регрессии (простая нейронная сеть)" | |
], | |
"metadata": { | |
"id": "kkG2-rbFQmuq" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"dataset_1 = pd.read_csv(\"california_housing_train.csv\")\n", | |
"X = dataset_1.loc[:, dataset_1.columns != 'median_house_value'].to_numpy()\n", | |
"y = dataset_1['median_house_value'].to_numpy()\n", | |
"X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=2)\n", | |
"print(\"Dims of train \", X_train.shape, y_train.shape)\n", | |
"print(\"Dims of test \", X_test.shape, y_test.shape)" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "5qavR5MfXsy1", | |
"outputId": "aa3f86e0-c715-4acb-eb57-b605af844440" | |
}, | |
"execution_count": 23, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"Dims of train (11900, 8) (11900,)\n", | |
"Dims of test (5100, 8) (5100,)\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# https://www.tensorflow.org/api_docs/python/tf/keras/layers/Dense\n", | |
"# https://keras.io/guides/functional_api/\n", | |
"\n", | |
"\n", | |
"input = keras.Input(shape=(8,))\n", | |
"x = layers.Dense(64, activation='relu')(input)\n", | |
"x = layers.Dense(64, activation='relu')(x)\n", | |
"output = layers.Dense(1)(x)\n", | |
"\n", | |
"nn_1 = keras.Model(input, output)\n", | |
"nn_1.summary()" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "8VnC8Z7iPGFj", | |
"outputId": "3787a96e-9fa6-4506-df4a-c6d9c934f340" | |
}, | |
"execution_count": 24, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"Model: \"model\"\n", | |
"_________________________________________________________________\n", | |
" Layer (type) Output Shape Param # \n", | |
"=================================================================\n", | |
" input_1 (InputLayer) [(None, 8)] 0 \n", | |
" \n", | |
" dense (Dense) (None, 64) 576 \n", | |
" \n", | |
" dense_1 (Dense) (None, 64) 4160 \n", | |
" \n", | |
" dense_2 (Dense) (None, 1) 65 \n", | |
" \n", | |
"=================================================================\n", | |
"Total params: 4,801\n", | |
"Trainable params: 4,801\n", | |
"Non-trainable params: 0\n", | |
"_________________________________________________________________\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Optimize MAE and report MSE as an extra metric; the held-out split is passed\n", | |
"# as validation_data so both quantities are also reported on it every epoch.\n", | |
"LR = 0.001\n", | |
"nn_1.compile(\n", | |
"    optimizer=keras.optimizers.Adam(learning_rate=LR),\n", | |
"    loss=[tf.keras.losses.MeanAbsoluteError()],\n", | |
"    metrics=[tf.keras.metrics.MeanSquaredError()],\n", | |
")\n", | |
"\n", | |
"print(\"Fit model on training data\")\n", | |
"history = nn_1.fit(\n", | |
"    x=X_train,\n", | |
"    y=y_train,\n", | |
"    epochs=30,\n", | |
"    validation_data=(X_test, y_test),\n", | |
")" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "YcDZg6DFXvZw", | |
"outputId": "f2975ecd-6b18-45c2-e4be-bab6d993dd0e" | |
}, | |
"execution_count": 25, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"Fit model on training data\n", | |
"Epoch 1/30\n", | |
"372/372 [==============================] - 5s 5ms/step - loss: 139482.4844 - mean_squared_error: 33668884480.0000 - val_loss: 116905.9531 - val_mean_squared_error: 27397017600.0000\n", | |
"Epoch 2/30\n", | |
"372/372 [==============================] - 2s 4ms/step - loss: 111370.8438 - mean_squared_error: 24605741056.0000 - val_loss: 108485.2969 - val_mean_squared_error: 24429299712.0000\n", | |
"Epoch 3/30\n", | |
"372/372 [==============================] - 3s 7ms/step - loss: 99567.7891 - mean_squared_error: 19507484672.0000 - val_loss: 94041.8203 - val_mean_squared_error: 17196939264.0000\n", | |
"Epoch 4/30\n", | |
"372/372 [==============================] - 3s 8ms/step - loss: 86368.7422 - mean_squared_error: 13820747776.0000 - val_loss: 84807.7422 - val_mean_squared_error: 12693773312.0000\n", | |
"Epoch 5/30\n", | |
"372/372 [==============================] - 3s 8ms/step - loss: 81420.1562 - mean_squared_error: 11497886720.0000 - val_loss: 82982.4219 - val_mean_squared_error: 11738802176.0000\n", | |
"Epoch 6/30\n", | |
"372/372 [==============================] - 2s 5ms/step - loss: 80465.3203 - mean_squared_error: 11095239680.0000 - val_loss: 82303.9922 - val_mean_squared_error: 11350008832.0000\n", | |
"Epoch 7/30\n", | |
"372/372 [==============================] - 2s 5ms/step - loss: 79616.7266 - mean_squared_error: 10756466688.0000 - val_loss: 81435.4219 - val_mean_squared_error: 11249816576.0000\n", | |
"Epoch 8/30\n", | |
"372/372 [==============================] - 2s 4ms/step - loss: 79001.1172 - mean_squared_error: 10591767552.0000 - val_loss: 80723.8125 - val_mean_squared_error: 11035149312.0000\n", | |
"Epoch 9/30\n", | |
"372/372 [==============================] - 2s 4ms/step - loss: 78283.8750 - mean_squared_error: 10401550336.0000 - val_loss: 80951.7109 - val_mean_squared_error: 10745340928.0000\n", | |
"Epoch 10/30\n", | |
"372/372 [==============================] - 2s 4ms/step - loss: 77595.8047 - mean_squared_error: 10213095424.0000 - val_loss: 79334.5234 - val_mean_squared_error: 10440878080.0000\n", | |
"Epoch 11/30\n", | |
"372/372 [==============================] - 2s 4ms/step - loss: 76887.1719 - mean_squared_error: 9992031232.0000 - val_loss: 78739.1875 - val_mean_squared_error: 10250883072.0000\n", | |
"Epoch 12/30\n", | |
"372/372 [==============================] - 2s 4ms/step - loss: 76198.5625 - mean_squared_error: 9840065536.0000 - val_loss: 77644.0000 - val_mean_squared_error: 10119945216.0000\n", | |
"Epoch 13/30\n", | |
"372/372 [==============================] - 2s 5ms/step - loss: 75468.0938 - mean_squared_error: 9653007360.0000 - val_loss: 77002.9844 - val_mean_squared_error: 9891603456.0000\n", | |
"Epoch 14/30\n", | |
"372/372 [==============================] - 2s 4ms/step - loss: 74665.4922 - mean_squared_error: 9469588480.0000 - val_loss: 76956.6953 - val_mean_squared_error: 9763158016.0000\n", | |
"Epoch 15/30\n", | |
"372/372 [==============================] - 2s 4ms/step - loss: 73971.3203 - mean_squared_error: 9307841536.0000 - val_loss: 75374.4844 - val_mean_squared_error: 9504031744.0000\n", | |
"Epoch 16/30\n", | |
"372/372 [==============================] - 2s 5ms/step - loss: 73166.8906 - mean_squared_error: 9104160768.0000 - val_loss: 75193.6016 - val_mean_squared_error: 9687987200.0000\n", | |
"Epoch 17/30\n", | |
"372/372 [==============================] - 2s 4ms/step - loss: 72396.7656 - mean_squared_error: 8948977664.0000 - val_loss: 73819.3516 - val_mean_squared_error: 9310675968.0000\n", | |
"Epoch 18/30\n", | |
"372/372 [==============================] - 2s 5ms/step - loss: 71540.7031 - mean_squared_error: 8752439296.0000 - val_loss: 73288.3984 - val_mean_squared_error: 8950458368.0000\n", | |
"Epoch 19/30\n", | |
"372/372 [==============================] - 2s 4ms/step - loss: 70802.9062 - mean_squared_error: 8597010432.0000 - val_loss: 72580.0391 - val_mean_squared_error: 8838232064.0000\n", | |
"Epoch 20/30\n", | |
"372/372 [==============================] - 2s 4ms/step - loss: 69980.5547 - mean_squared_error: 8406166528.0000 - val_loss: 71637.8906 - val_mean_squared_error: 8852182016.0000\n", | |
"Epoch 21/30\n", | |
"372/372 [==============================] - 2s 4ms/step - loss: 69242.8984 - mean_squared_error: 8236980736.0000 - val_loss: 70392.6641 - val_mean_squared_error: 8521859584.0000\n", | |
"Epoch 22/30\n", | |
"372/372 [==============================] - 2s 5ms/step - loss: 68497.3047 - mean_squared_error: 8084510208.0000 - val_loss: 71137.5312 - val_mean_squared_error: 8470008832.0000\n", | |
"Epoch 23/30\n", | |
"372/372 [==============================] - 2s 4ms/step - loss: 67920.0859 - mean_squared_error: 7970188288.0000 - val_loss: 68958.7109 - val_mean_squared_error: 8200147456.0000\n", | |
"Epoch 24/30\n", | |
"372/372 [==============================] - 2s 4ms/step - loss: 67173.3984 - mean_squared_error: 7818501120.0000 - val_loss: 68429.4219 - val_mean_squared_error: 8104857600.0000\n", | |
"Epoch 25/30\n", | |
"372/372 [==============================] - 2s 4ms/step - loss: 66582.7344 - mean_squared_error: 7707080192.0000 - val_loss: 67965.9844 - val_mean_squared_error: 7992566272.0000\n", | |
"Epoch 26/30\n", | |
"372/372 [==============================] - 2s 5ms/step - loss: 66125.7734 - mean_squared_error: 7616282112.0000 - val_loss: 67667.4922 - val_mean_squared_error: 7905951232.0000\n", | |
"Epoch 27/30\n", | |
"372/372 [==============================] - 2s 4ms/step - loss: 65612.4609 - mean_squared_error: 7541339648.0000 - val_loss: 66948.7656 - val_mean_squared_error: 7765702656.0000\n", | |
"Epoch 28/30\n", | |
"372/372 [==============================] - 2s 4ms/step - loss: 65159.6484 - mean_squared_error: 7475697664.0000 - val_loss: 66089.9922 - val_mean_squared_error: 7663987712.0000\n", | |
"Epoch 29/30\n", | |
"372/372 [==============================] - 2s 5ms/step - loss: 64645.4375 - mean_squared_error: 7347859968.0000 - val_loss: 66723.0781 - val_mean_squared_error: 7919130624.0000\n", | |
"Epoch 30/30\n", | |
"372/372 [==============================] - 2s 5ms/step - loss: 64289.3945 - mean_squared_error: 7301460480.0000 - val_loss: 65446.8633 - val_mean_squared_error: 7637944320.0000\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# list all metric series recorded by fit()\n", | |
"print(history.history.keys())\n", | |
"# summarize history for the MSE metric (regression model: there is no\n", | |
"# 'accuracy' here, the tracked metric is mean squared error)\n", | |
"plt.plot(history.history['mean_squared_error'])\n", | |
"plt.plot(history.history['val_mean_squared_error'])\n", | |
"plt.title('model mean squared error')\n", | |
"plt.ylabel('MSE')\n", | |
"plt.xlabel('epoch')\n", | |
"plt.legend(['train', 'test'], loc='upper left')\n", | |
"plt.show()\n", | |
"# summarize history for the training objective (mean absolute error)\n", | |
"plt.plot(history.history['loss'])\n", | |
"plt.plot(history.history['val_loss'])\n", | |
"plt.title('model loss (MAE)')\n", | |
"plt.ylabel('loss (MAE)')\n", | |
"plt.xlabel('epoch')\n", | |
"plt.legend(['train', 'test'], loc='upper left')\n", | |
"plt.show()" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 591 | |
}, | |
"id": "IoIffhtsc9o1", | |
"outputId": "b4e7fb25-91ed-4bd8-f791-dd6e61e56981" | |
}, | |
"execution_count": 26, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"dict_keys(['loss', 'mean_squared_error', 'val_loss', 'val_mean_squared_error'])\n" | |
] | |
}, | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": [ | |
"<Figure size 432x288 with 1 Axes>" | |
], | |
"image/png": "\n" | |
}, | |
"metadata": { | |
"needs_background": "light" | |
} | |
}, | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": [ | |
"<Figure size 432x288 with 1 Axes>" | |
], | |
"image/png": "\n" | |
}, | |
"metadata": { | |
"needs_background": "light" | |
} | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"X_train_norm, X_test_norm, y_train_norm, y_test_norm = train_test_split(\n", | |
"    X_normalized, y, test_size=0.3, random_state=2\n", | |
")\n", | |
"\n", | |
"# Same architecture as the baseline, rebuilt from scratch and trained on the\n", | |
"# split of X_normalized (this rebinds nn_1 and history).\n", | |
"input = keras.Input(shape=(8,))\n", | |
"x = layers.Dense(64, activation='relu')(input)\n", | |
"x = layers.Dense(64, activation='relu')(x)\n", | |
"output = layers.Dense(1)(x)\n", | |
"nn_1 = keras.Model(input, output)\n", | |
"\n", | |
"LR = 0.001\n", | |
"nn_1.compile(\n", | |
"    optimizer=keras.optimizers.Adam(learning_rate=LR),\n", | |
"    loss=[tf.keras.losses.MeanAbsoluteError()],\n", | |
"    metrics=[tf.keras.metrics.MeanSquaredError()],\n", | |
")\n", | |
"\n", | |
"print(\"Fit model on training data\")\n", | |
"history = nn_1.fit(\n", | |
"    x=X_train_norm,\n", | |
"    y=y_train_norm,\n", | |
"    epochs=30,\n", | |
"    validation_data=(X_test_norm, y_test_norm),\n", | |
")" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "Xgrc_ugHZlJq", | |
"outputId": "971dc864-4300-4a2d-db70-c676a6bdfb23" | |
}, | |
"execution_count": 27, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"Fit model on training data\n", | |
"Epoch 1/30\n", | |
"372/372 [==============================] - 2s 5ms/step - loss: 206795.4688 - mean_squared_error: 56014397440.0000 - val_loss: 205778.3906 - val_mean_squared_error: 56247169024.0000\n", | |
"Epoch 2/30\n", | |
"372/372 [==============================] - 2s 4ms/step - loss: 200349.9219 - mean_squared_error: 53364846592.0000 - val_loss: 192691.4375 - val_mean_squared_error: 50942058496.0000\n", | |
"Epoch 3/30\n", | |
"372/372 [==============================] - 2s 4ms/step - loss: 177556.3281 - mean_squared_error: 44675104768.0000 - val_loss: 159030.9219 - val_mean_squared_error: 38844985344.0000\n", | |
"Epoch 4/30\n", | |
"372/372 [==============================] - 2s 4ms/step - loss: 135406.7656 - mean_squared_error: 30624692224.0000 - val_loss: 114283.0156 - val_mean_squared_error: 24465571840.0000\n", | |
"Epoch 5/30\n", | |
"372/372 [==============================] - 2s 7ms/step - loss: 99154.9609 - mean_squared_error: 18977902592.0000 - val_loss: 91153.8203 - val_mean_squared_error: 16145519616.0000\n", | |
"Epoch 6/30\n", | |
"372/372 [==============================] - 2s 4ms/step - loss: 85705.9609 - mean_squared_error: 13909401600.0000 - val_loss: 86766.2344 - val_mean_squared_error: 13758028800.0000\n", | |
"Epoch 7/30\n", | |
"372/372 [==============================] - 2s 4ms/step - loss: 83946.8828 - mean_squared_error: 12788144128.0000 - val_loss: 86312.3516 - val_mean_squared_error: 13260951552.0000\n", | |
"Epoch 8/30\n", | |
"372/372 [==============================] - 2s 5ms/step - loss: 83658.9844 - mean_squared_error: 12504192000.0000 - val_loss: 86073.3828 - val_mean_squared_error: 13159665664.0000\n", | |
"Epoch 9/30\n", | |
"372/372 [==============================] - 2s 5ms/step - loss: 83417.0938 - mean_squared_error: 12436531200.0000 - val_loss: 85819.3906 - val_mean_squared_error: 13067907072.0000\n", | |
"Epoch 10/30\n", | |
"372/372 [==============================] - 2s 4ms/step - loss: 83161.7734 - mean_squared_error: 12325829632.0000 - val_loss: 85551.3047 - val_mean_squared_error: 12969872384.0000\n", | |
"Epoch 11/30\n", | |
"372/372 [==============================] - 2s 4ms/step - loss: 82888.3438 - mean_squared_error: 12262994944.0000 - val_loss: 85267.9609 - val_mean_squared_error: 12870999040.0000\n", | |
"Epoch 12/30\n", | |
"372/372 [==============================] - 2s 4ms/step - loss: 82610.1328 - mean_squared_error: 12163811328.0000 - val_loss: 84970.5938 - val_mean_squared_error: 12798576640.0000\n", | |
"Epoch 13/30\n", | |
"372/372 [==============================] - 2s 5ms/step - loss: 82309.5859 - mean_squared_error: 12090469376.0000 - val_loss: 84662.9531 - val_mean_squared_error: 12719804416.0000\n", | |
"Epoch 14/30\n", | |
"372/372 [==============================] - 2s 4ms/step - loss: 82001.8516 - mean_squared_error: 11967596544.0000 - val_loss: 84345.6406 - val_mean_squared_error: 12668781568.0000\n", | |
"Epoch 15/30\n", | |
"372/372 [==============================] - 2s 4ms/step - loss: 81691.1875 - mean_squared_error: 11925498880.0000 - val_loss: 84017.8125 - val_mean_squared_error: 12454376448.0000\n", | |
"Epoch 16/30\n", | |
"372/372 [==============================] - 2s 5ms/step - loss: 81364.3203 - mean_squared_error: 11785666560.0000 - val_loss: 83680.6406 - val_mean_squared_error: 12476708864.0000\n", | |
"Epoch 17/30\n", | |
"372/372 [==============================] - 2s 4ms/step - loss: 81035.7109 - mean_squared_error: 11714954240.0000 - val_loss: 83329.4141 - val_mean_squared_error: 12273813504.0000\n", | |
"Epoch 18/30\n", | |
"372/372 [==============================] - 2s 5ms/step - loss: 80690.4766 - mean_squared_error: 11618946048.0000 - val_loss: 82975.2422 - val_mean_squared_error: 12153061376.0000\n", | |
"Epoch 19/30\n", | |
"372/372 [==============================] - 2s 5ms/step - loss: 80342.5625 - mean_squared_error: 11458452480.0000 - val_loss: 82605.3125 - val_mean_squared_error: 12068451328.0000\n", | |
"Epoch 20/30\n", | |
"372/372 [==============================] - 2s 4ms/step - loss: 79982.4844 - mean_squared_error: 11394854912.0000 - val_loss: 82227.0781 - val_mean_squared_error: 11984166912.0000\n", | |
"Epoch 21/30\n", | |
"372/372 [==============================] - 2s 4ms/step - loss: 79606.8125 - mean_squared_error: 11279780864.0000 - val_loss: 81832.9531 - val_mean_squared_error: 11838595072.0000\n", | |
"Epoch 22/30\n", | |
"372/372 [==============================] - 2s 5ms/step - loss: 79217.4922 - mean_squared_error: 11152124928.0000 - val_loss: 81424.6328 - val_mean_squared_error: 11750178816.0000\n", | |
"Epoch 23/30\n", | |
"372/372 [==============================] - 2s 4ms/step - loss: 78818.5156 - mean_squared_error: 11080593408.0000 - val_loss: 81001.7969 - val_mean_squared_error: 11587097600.0000\n", | |
"Epoch 24/30\n", | |
"372/372 [==============================] - 2s 5ms/step - loss: 78403.0312 - mean_squared_error: 10902419456.0000 - val_loss: 80566.4219 - val_mean_squared_error: 11503527936.0000\n", | |
"Epoch 25/30\n", | |
"372/372 [==============================] - 2s 5ms/step - loss: 77971.4219 - mean_squared_error: 10838893568.0000 - val_loss: 80113.0391 - val_mean_squared_error: 11304907776.0000\n", | |
"Epoch 26/30\n", | |
"372/372 [==============================] - 2s 5ms/step - loss: 77523.2891 - mean_squared_error: 10658941952.0000 - val_loss: 79633.1875 - val_mean_squared_error: 11194773504.0000\n", | |
"Epoch 27/30\n", | |
"372/372 [==============================] - 2s 4ms/step - loss: 77053.9062 - mean_squared_error: 10558457856.0000 - val_loss: 79135.6172 - val_mean_squared_error: 11027051520.0000\n", | |
"Epoch 28/30\n", | |
"372/372 [==============================] - 2s 4ms/step - loss: 76557.9375 - mean_squared_error: 10388887552.0000 - val_loss: 78609.7656 - val_mean_squared_error: 10929961984.0000\n", | |
"Epoch 29/30\n", | |
"372/372 [==============================] - 2s 4ms/step - loss: 76051.3672 - mean_squared_error: 10264171520.0000 - val_loss: 78064.2422 - val_mean_squared_error: 10782563328.0000\n", | |
"Epoch 30/30\n", | |
"372/372 [==============================] - 2s 4ms/step - loss: 75515.6250 - mean_squared_error: 10129316864.0000 - val_loss: 77490.2656 - val_mean_squared_error: 10604675072.0000\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# list all metric series recorded by fit()\n", | |
"print(history.history.keys())\n", | |
"# summarize history for the MSE metric (regression model: there is no\n", | |
"# 'accuracy' here, the tracked metric is mean squared error)\n", | |
"plt.plot(history.history['mean_squared_error'])\n", | |
"plt.plot(history.history['val_mean_squared_error'])\n", | |
"plt.title('model mean squared error')\n", | |
"plt.ylabel('MSE')\n", | |
"plt.xlabel('epoch')\n", | |
"plt.legend(['train', 'test'], loc='upper left')\n", | |
"plt.show()\n", | |
"# summarize history for the training objective (mean absolute error)\n", | |
"plt.plot(history.history['loss'])\n", | |
"plt.plot(history.history['val_loss'])\n", | |
"plt.title('model loss (MAE)')\n", | |
"plt.ylabel('loss (MAE)')\n", | |
"plt.xlabel('epoch')\n", | |
"plt.legend(['train', 'test'], loc='upper left')\n", | |
"plt.show()" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 591 | |
}, | |
"id": "3YrJ2xZgeiUe", | |
"outputId": "0e003de6-f261-47db-c89a-3cc9ae564119" | |
}, | |
"execution_count": 28, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"dict_keys(['loss', 'mean_squared_error', 'val_loss', 'val_mean_squared_error'])\n" | |
] | |
}, | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": [ | |
"<Figure size 432x288 with 1 Axes>" | |
], | |
"image/png": "\n" | |
}, | |
"metadata": { | |
"needs_background": "light" | |
} | |
}, | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": [ | |
"<Figure size 432x288 with 1 Axes>" | |
], | |
"image/png": "\n" | |
}, | |
"metadata": { | |
"needs_background": "light" | |
} | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# https://www.tensorflow.org/api_docs/python/tf/keras/layers/Normalization\n", | |
"\n", | |
"# The Normalization layer only standardizes its inputs once its statistics are set:\n", | |
"# without adapt() (or explicit mean/variance) it keeps its initial state and passes\n", | |
"# the raw features through unchanged. Learn mean/variance from the training\n", | |
"# features only, so nothing from the test split leaks into the model.\n", | |
"normalizer = layers.Normalization()\n", | |
"normalizer.adapt(np.asarray(X_train, dtype='float32'))\n", | |
"\n", | |
"# 'inputs' avoids shadowing Python's builtin input(); width follows X_train.\n", | |
"inputs = keras.Input(shape=(X_train.shape[1],))\n", | |
"x = normalizer(inputs)\n", | |
"x = layers.Dense(64, activation='relu')(x)\n", | |
"x = layers.Dense(64, activation='relu')(x)\n", | |
"output = layers.Dense(1)(x)\n", | |
"\n", | |
"nn_2 = keras.Model(inputs, output)\n", | |
"nn_2.summary()\n" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "sIxJUcbsZLxQ", | |
"outputId": "28ccc6cf-67a5-4d47-a295-816ab9fffaff" | |
}, | |
"execution_count": 29, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"Model: \"model_2\"\n", | |
"_________________________________________________________________\n", | |
" Layer (type) Output Shape Param # \n", | |
"=================================================================\n", | |
" input_3 (InputLayer) [(None, 8)] 0 \n", | |
" \n", | |
" normalization (Normalizatio (None, 8) 17 \n", | |
" n) \n", | |
" \n", | |
" dense_6 (Dense) (None, 64) 576 \n", | |
" \n", | |
" dense_7 (Dense) (None, 64) 4160 \n", | |
" \n", | |
" dense_8 (Dense) (None, 1) 65 \n", | |
" \n", | |
"=================================================================\n", | |
"Total params: 4,818\n", | |
"Trainable params: 4,801\n", | |
"Non-trainable params: 17\n", | |
"_________________________________________________________________\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Same training recipe as nn_1: MAE objective, MSE tracked as a metric, raw\n", | |
"# X_train fed in (any scaling happens inside the model's Normalization layer).\n", | |
"LR = 0.001\n", | |
"nn_2.compile(\n", | |
"    optimizer=keras.optimizers.Adam(learning_rate=LR),\n", | |
"    loss=[tf.keras.losses.MeanAbsoluteError()],\n", | |
"    metrics=[tf.keras.metrics.MeanSquaredError()],\n", | |
")\n", | |
"\n", | |
"print(\"Fit model on training data\")\n", | |
"history = nn_2.fit(\n", | |
"    x=X_train,\n", | |
"    y=y_train,\n", | |
"    epochs=30,\n", | |
"    validation_data=(X_test, y_test),\n", | |
")" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "0THrBDVnf6T3", | |
"outputId": "ddd4df87-28c1-4f1b-aefd-e89a123c277f" | |
}, | |
"execution_count": 30, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"Fit model on training data\n", | |
"Epoch 1/30\n", | |
"372/372 [==============================] - 3s 5ms/step - loss: 144987.7812 - mean_squared_error: 35822739456.0000 - val_loss: 116765.2109 - val_mean_squared_error: 27681896448.0000\n", | |
"Epoch 2/30\n", | |
"372/372 [==============================] - 2s 5ms/step - loss: 112227.6172 - mean_squared_error: 24842373120.0000 - val_loss: 110188.4219 - val_mean_squared_error: 24272132096.0000\n", | |
"Epoch 3/30\n", | |
"372/372 [==============================] - 2s 4ms/step - loss: 102708.3047 - mean_squared_error: 20724606976.0000 - val_loss: 97876.5781 - val_mean_squared_error: 19019096064.0000\n", | |
"Epoch 4/30\n", | |
"372/372 [==============================] - 2s 4ms/step - loss: 88996.7344 - mean_squared_error: 14882031616.0000 - val_loss: 86128.2422 - val_mean_squared_error: 13466316800.0000\n", | |
"Epoch 5/30\n", | |
"372/372 [==============================] - 2s 4ms/step - loss: 81907.7969 - mean_squared_error: 11681053696.0000 - val_loss: 83248.2266 - val_mean_squared_error: 11745435648.0000\n", | |
"Epoch 6/30\n", | |
"372/372 [==============================] - 2s 5ms/step - loss: 80561.3438 - mean_squared_error: 11091396608.0000 - val_loss: 82357.6172 - val_mean_squared_error: 11481886720.0000\n", | |
"Epoch 7/30\n", | |
"372/372 [==============================] - 2s 4ms/step - loss: 79836.7969 - mean_squared_error: 10822083584.0000 - val_loss: 81863.3359 - val_mean_squared_error: 11213162496.0000\n", | |
"Epoch 8/30\n", | |
"372/372 [==============================] - 2s 5ms/step - loss: 79075.3516 - mean_squared_error: 10595145728.0000 - val_loss: 80830.5781 - val_mean_squared_error: 10984462336.0000\n", | |
"Epoch 9/30\n", | |
"372/372 [==============================] - 2s 5ms/step - loss: 78431.5938 - mean_squared_error: 10399730688.0000 - val_loss: 80304.7188 - val_mean_squared_error: 10730515456.0000\n", | |
"Epoch 10/30\n", | |
"372/372 [==============================] - 2s 5ms/step - loss: 77628.9766 - mean_squared_error: 10205846528.0000 - val_loss: 79689.5312 - val_mean_squared_error: 10670451712.0000\n", | |
"Epoch 11/30\n", | |
"372/372 [==============================] - 2s 4ms/step - loss: 76945.5391 - mean_squared_error: 10016186368.0000 - val_loss: 78680.8672 - val_mean_squared_error: 10284953600.0000\n", | |
"Epoch 12/30\n", | |
"372/372 [==============================] - 2s 4ms/step - loss: 76145.0469 - mean_squared_error: 9805711360.0000 - val_loss: 77772.7109 - val_mean_squared_error: 10197805056.0000\n", | |
"Epoch 13/30\n", | |
"372/372 [==============================] - 2s 4ms/step - loss: 75720.4062 - mean_squared_error: 9702811648.0000 - val_loss: 77072.8438 - val_mean_squared_error: 10036048896.0000\n", | |
"Epoch 14/30\n", | |
"372/372 [==============================] - 2s 4ms/step - loss: 74835.0000 - mean_squared_error: 9465024512.0000 - val_loss: 76226.7344 - val_mean_squared_error: 9817841664.0000\n", | |
"Epoch 15/30\n", | |
"372/372 [==============================] - 2s 5ms/step - loss: 74007.2266 - mean_squared_error: 9262068736.0000 - val_loss: 75532.3984 - val_mean_squared_error: 9636645888.0000\n", | |
"Epoch 16/30\n", | |
"372/372 [==============================] - 2s 4ms/step - loss: 73280.5625 - mean_squared_error: 9077710848.0000 - val_loss: 75950.6094 - val_mean_squared_error: 9874145280.0000\n", | |
"Epoch 17/30\n", | |
"372/372 [==============================] - 2s 4ms/step - loss: 72531.5156 - mean_squared_error: 8928563200.0000 - val_loss: 74895.6172 - val_mean_squared_error: 9232152576.0000\n", | |
"Epoch 18/30\n", | |
"372/372 [==============================] - 2s 4ms/step - loss: 71859.8203 - mean_squared_error: 8754529280.0000 - val_loss: 73252.5703 - val_mean_squared_error: 9070807040.0000\n", | |
"Epoch 19/30\n", | |
"372/372 [==============================] - 2s 4ms/step - loss: 71020.6172 - mean_squared_error: 8572276736.0000 - val_loss: 73646.8516 - val_mean_squared_error: 8964041728.0000\n", | |
"Epoch 20/30\n", | |
"372/372 [==============================] - 2s 4ms/step - loss: 70221.9531 - mean_squared_error: 8412365312.0000 - val_loss: 71759.7578 - val_mean_squared_error: 8829378560.0000\n", | |
"Epoch 21/30\n", | |
"372/372 [==============================] - 2s 4ms/step - loss: 69619.6797 - mean_squared_error: 8267113984.0000 - val_loss: 71125.3750 - val_mean_squared_error: 8510948352.0000\n", | |
"Epoch 22/30\n", | |
"372/372 [==============================] - 2s 4ms/step - loss: 68809.0703 - mean_squared_error: 8099959296.0000 - val_loss: 70721.3750 - val_mean_squared_error: 8593863680.0000\n", | |
"Epoch 23/30\n", | |
"372/372 [==============================] - 2s 4ms/step - loss: 68094.2812 - mean_squared_error: 7958014976.0000 - val_loss: 70390.0703 - val_mean_squared_error: 8355642368.0000\n", | |
"Epoch 24/30\n", | |
"372/372 [==============================] - 2s 5ms/step - loss: 67535.4062 - mean_squared_error: 7867279872.0000 - val_loss: 68796.2031 - val_mean_squared_error: 8227276288.0000\n", | |
"Epoch 25/30\n", | |
"372/372 [==============================] - 2s 5ms/step - loss: 66880.1406 - mean_squared_error: 7741731328.0000 - val_loss: 68784.9375 - val_mean_squared_error: 8112913920.0000\n", | |
"Epoch 26/30\n", | |
"372/372 [==============================] - 2s 5ms/step - loss: 66363.6172 - mean_squared_error: 7637346816.0000 - val_loss: 67678.5938 - val_mean_squared_error: 7986403840.0000\n", | |
"Epoch 27/30\n", | |
"372/372 [==============================] - 2s 4ms/step - loss: 65961.8672 - mean_squared_error: 7552273408.0000 - val_loss: 67385.1562 - val_mean_squared_error: 7926674432.0000\n", | |
"Epoch 28/30\n", | |
"372/372 [==============================] - 2s 5ms/step - loss: 65690.1406 - mean_squared_error: 7523852288.0000 - val_loss: 68142.6797 - val_mean_squared_error: 8206020096.0000\n", | |
"Epoch 29/30\n", | |
"372/372 [==============================] - 2s 5ms/step - loss: 65099.7734 - mean_squared_error: 7418447872.0000 - val_loss: 66803.6172 - val_mean_squared_error: 7650893824.0000\n", | |
"Epoch 30/30\n", | |
"372/372 [==============================] - 2s 4ms/step - loss: 64623.6250 - mean_squared_error: 7302447616.0000 - val_loss: 66301.5781 - val_mean_squared_error: 7767943168.0000\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# list all metric series recorded by fit()\n", | |
"print(history.history.keys())\n", | |
"# summarize history for the MSE metric (regression model: there is no\n", | |
"# 'accuracy' here, the tracked metric is mean squared error)\n", | |
"plt.plot(history.history['mean_squared_error'])\n", | |
"plt.plot(history.history['val_mean_squared_error'])\n", | |
"plt.title('model mean squared error')\n", | |
"plt.ylabel('MSE')\n", | |
"plt.xlabel('epoch')\n", | |
"plt.legend(['train', 'test'], loc='upper left')\n", | |
"plt.show()\n", | |
"# summarize history for the training objective (mean absolute error)\n", | |
"plt.plot(history.history['loss'])\n", | |
"plt.plot(history.history['val_loss'])\n", | |
"plt.title('model loss (MAE)')\n", | |
"plt.ylabel('loss (MAE)')\n", | |
"plt.xlabel('epoch')\n", | |
"plt.legend(['train', 'test'], loc='upper left')\n", | |
"plt.show()" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 591 | |
}, | |
"id": "IAnqOOy8gDw_", | |
"outputId": "d0ee1725-ad19-423e-c126-c5c581b6ccef" | |
}, | |
"execution_count": 31, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"dict_keys(['loss', 'mean_squared_error', 'val_loss', 'val_mean_squared_error'])\n" | |
] | |
}, | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": [ | |
"<Figure size 432x288 with 1 Axes>" | |
], | |
"image/png": "\n" | |
}, | |
"metadata": { | |
"needs_background": "light" | |
} | |
}, | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": [ | |
"<Figure size 432x288 with 1 Axes>" | |
], | |
"image/png": "\n" | |
}, | |
"metadata": { | |
"needs_background": "light" | |
} | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Load the training CSV and rescale the target column by 1/1000 in one chain.\n", | |
"dataset_1 = pd.read_csv(\"california_housing_train.csv\").assign(\n", | |
"    median_house_value=lambda df: df['median_house_value'] / 1000\n", | |
")\n", | |
"dataset_1.describe()" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 300 | |
}, | |
"id": "wQCaX_4KidvU", | |
"outputId": "06769e3f-66e7-4b2f-9038-357f59abd0a5" | |
}, | |
"execution_count": 32, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
" longitude latitude housing_median_age total_rooms \\\n", | |
"count 17000.000000 17000.000000 17000.000000 17000.000000 \n", | |
"mean -119.562108 35.625225 28.589353 2643.664412 \n", | |
"std 2.005166 2.137340 12.586937 2179.947071 \n", | |
"min -124.350000 32.540000 1.000000 2.000000 \n", | |
"25% -121.790000 33.930000 18.000000 1462.000000 \n", | |
"50% -118.490000 34.250000 29.000000 2127.000000 \n", | |
"75% -118.000000 37.720000 37.000000 3151.250000 \n", | |
"max -114.310000 41.950000 52.000000 37937.000000 \n", | |
"\n", | |
" total_bedrooms population households median_income \\\n", | |
"count 17000.000000 17000.000000 17000.000000 17000.000000 \n", | |
"mean 539.410824 1429.573941 501.221941 3.883578 \n", | |
"std 421.499452 1147.852959 384.520841 1.908157 \n", | |
"min 1.000000 3.000000 1.000000 0.499900 \n", | |
"25% 297.000000 790.000000 282.000000 2.566375 \n", | |
"50% 434.000000 1167.000000 409.000000 3.544600 \n", | |
"75% 648.250000 1721.000000 605.250000 4.767000 \n", | |
"max 6445.000000 35682.000000 6082.000000 15.000100 \n", | |
"\n", | |
" median_house_value \n", | |
"count 17000.000000 \n", | |
"mean 207.300912 \n", | |
"std 115.983764 \n", | |
"min 14.999000 \n", | |
"25% 119.400000 \n", | |
"50% 180.400000 \n", | |
"75% 265.000000 \n", | |
"max 500.001000 " | |
], | |
"text/html": [ | |
"\n", | |
" <div id=\"df-f5087c31-e76f-4ab9-ad0a-713d66fe9688\">\n", | |
" <div class=\"colab-df-container\">\n", | |
" <div>\n", | |
"<style scoped>\n", | |
" .dataframe tbody tr th:only-of-type {\n", | |
" vertical-align: middle;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>longitude</th>\n", | |
" <th>latitude</th>\n", | |
" <th>housing_median_age</th>\n", | |
" <th>total_rooms</th>\n", | |
" <th>total_bedrooms</th>\n", | |
" <th>population</th>\n", | |
" <th>households</th>\n", | |
" <th>median_income</th>\n", | |
" <th>median_house_value</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>count</th>\n", | |
" <td>17000.000000</td>\n", | |
" <td>17000.000000</td>\n", | |
" <td>17000.000000</td>\n", | |
" <td>17000.000000</td>\n", | |
" <td>17000.000000</td>\n", | |
" <td>17000.000000</td>\n", | |
" <td>17000.000000</td>\n", | |
" <td>17000.000000</td>\n", | |
" <td>17000.000000</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>mean</th>\n", | |
" <td>-119.562108</td>\n", | |
" <td>35.625225</td>\n", | |
" <td>28.589353</td>\n", | |
" <td>2643.664412</td>\n", | |
" <td>539.410824</td>\n", | |
" <td>1429.573941</td>\n", | |
" <td>501.221941</td>\n", | |
" <td>3.883578</td>\n", | |
" <td>207.300912</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>std</th>\n", | |
" <td>2.005166</td>\n", | |
" <td>2.137340</td>\n", | |
" <td>12.586937</td>\n", | |
" <td>2179.947071</td>\n", | |
" <td>421.499452</td>\n", | |
" <td>1147.852959</td>\n", | |
" <td>384.520841</td>\n", | |
" <td>1.908157</td>\n", | |
" <td>115.983764</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>min</th>\n", | |
" <td>-124.350000</td>\n", | |
" <td>32.540000</td>\n", | |
" <td>1.000000</td>\n", | |
" <td>2.000000</td>\n", | |
" <td>1.000000</td>\n", | |
" <td>3.000000</td>\n", | |
" <td>1.000000</td>\n", | |
" <td>0.499900</td>\n", | |
" <td>14.999000</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>25%</th>\n", | |
" <td>-121.790000</td>\n", | |
" <td>33.930000</td>\n", | |
" <td>18.000000</td>\n", | |
" <td>1462.000000</td>\n", | |
" <td>297.000000</td>\n", | |
" <td>790.000000</td>\n", | |
" <td>282.000000</td>\n", | |
" <td>2.566375</td>\n", | |
" <td>119.400000</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>50%</th>\n", | |
" <td>-118.490000</td>\n", | |
" <td>34.250000</td>\n", | |
" <td>29.000000</td>\n", | |
" <td>2127.000000</td>\n", | |
" <td>434.000000</td>\n", | |
" <td>1167.000000</td>\n", | |
" <td>409.000000</td>\n", | |
" <td>3.544600</td>\n", | |
" <td>180.400000</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>75%</th>\n", | |
" <td>-118.000000</td>\n", | |
" <td>37.720000</td>\n", | |
" <td>37.000000</td>\n", | |
" <td>3151.250000</td>\n", | |
" <td>648.250000</td>\n", | |
" <td>1721.000000</td>\n", | |
" <td>605.250000</td>\n", | |
" <td>4.767000</td>\n", | |
" <td>265.000000</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>max</th>\n", | |
" <td>-114.310000</td>\n", | |
" <td>41.950000</td>\n", | |
" <td>52.000000</td>\n", | |
" <td>37937.000000</td>\n", | |
" <td>6445.000000</td>\n", | |
" <td>35682.000000</td>\n", | |
" <td>6082.000000</td>\n", | |
" <td>15.000100</td>\n", | |
" <td>500.001000</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>\n", | |
" <button class=\"colab-df-convert\" onclick=\"convertToInteractive('df-f5087c31-e76f-4ab9-ad0a-713d66fe9688')\"\n", | |
" title=\"Convert this dataframe to an interactive table.\"\n", | |
" style=\"display:none;\">\n", | |
" \n", | |
" <svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n", | |
" width=\"24px\">\n", | |
" <path d=\"M0 0h24v24H0V0z\" fill=\"none\"/>\n", | |
" <path d=\"M18.56 5.44l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94zm-11 1L8.5 8.5l.94-2.06 2.06-.94-2.06-.94L8.5 2.5l-.94 2.06-2.06.94zm10 10l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94z\"/><path d=\"M17.41 7.96l-1.37-1.37c-.4-.4-.92-.59-1.43-.59-.52 0-1.04.2-1.43.59L10.3 9.45l-7.72 7.72c-.78.78-.78 2.05 0 2.83L4 21.41c.39.39.9.59 1.41.59.51 0 1.02-.2 1.41-.59l7.78-7.78 2.81-2.81c.8-.78.8-2.07 0-2.86zM5.41 20L4 18.59l7.72-7.72 1.47 1.35L5.41 20z\"/>\n", | |
" </svg>\n", | |
" </button>\n", | |
" \n", | |
" <style>\n", | |
" .colab-df-container {\n", | |
" display:flex;\n", | |
" flex-wrap:wrap;\n", | |
" gap: 12px;\n", | |
" }\n", | |
"\n", | |
" .colab-df-convert {\n", | |
" background-color: #E8F0FE;\n", | |
" border: none;\n", | |
" border-radius: 50%;\n", | |
" cursor: pointer;\n", | |
" display: none;\n", | |
" fill: #1967D2;\n", | |
" height: 32px;\n", | |
" padding: 0 0 0 0;\n", | |
" width: 32px;\n", | |
" }\n", | |
"\n", | |
" .colab-df-convert:hover {\n", | |
" background-color: #E2EBFA;\n", | |
" box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n", | |
" fill: #174EA6;\n", | |
" }\n", | |
"\n", | |
" [theme=dark] .colab-df-convert {\n", | |
" background-color: #3B4455;\n", | |
" fill: #D2E3FC;\n", | |
" }\n", | |
"\n", | |
" [theme=dark] .colab-df-convert:hover {\n", | |
" background-color: #434B5C;\n", | |
" box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n", | |
" filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n", | |
" fill: #FFFFFF;\n", | |
" }\n", | |
" </style>\n", | |
"\n", | |
" <script>\n", | |
" const buttonEl =\n", | |
" document.querySelector('#df-f5087c31-e76f-4ab9-ad0a-713d66fe9688 button.colab-df-convert');\n", | |
" buttonEl.style.display =\n", | |
" google.colab.kernel.accessAllowed ? 'block' : 'none';\n", | |
"\n", | |
" async function convertToInteractive(key) {\n", | |
" const element = document.querySelector('#df-f5087c31-e76f-4ab9-ad0a-713d66fe9688');\n", | |
" const dataTable =\n", | |
" await google.colab.kernel.invokeFunction('convertToInteractive',\n", | |
" [key], {});\n", | |
" if (!dataTable) return;\n", | |
"\n", | |
" const docLinkHtml = 'Like what you see? Visit the ' +\n", | |
" '<a target=\"_blank\" href=https://colab.research.google.com/notebooks/data_table.ipynb>data table notebook</a>'\n", | |
" + ' to learn more about interactive tables.';\n", | |
" element.innerHTML = '';\n", | |
" dataTable['output_type'] = 'display_data';\n", | |
" await google.colab.output.renderOutput(dataTable, element);\n", | |
" const docLink = document.createElement('div');\n", | |
" docLink.innerHTML = docLinkHtml;\n", | |
" element.appendChild(docLink);\n", | |
" }\n", | |
" </script>\n", | |
" </div>\n", | |
" </div>\n", | |
" " | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 32 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"dataset_1[\"rooms_per_person\"] = dataset_1[\"total_rooms\"] / dataset_1[\"population\"]\n", | |
"dataset_1[\"rooms_on_income\"] = dataset_1[\"median_income\"]*dataset_1[\"total_rooms\"]\n", | |
"dataset_1.corr()['median_house_value']" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "UsNhFAwajhZ7", | |
"outputId": "eb015fd8-d237-4861-b313-287272a30879" | |
}, | |
"execution_count": 33, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"longitude -0.044982\n", | |
"latitude -0.144917\n", | |
"housing_median_age 0.106758\n", | |
"total_rooms 0.130991\n", | |
"total_bedrooms 0.045783\n", | |
"population -0.027850\n", | |
"households 0.061031\n", | |
"median_income 0.691871\n", | |
"median_house_value 1.000000\n", | |
"rooms_per_person 0.206969\n", | |
"rooms_on_income 0.375019\n", | |
"Name: median_house_value, dtype: float64" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 33 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"#dataset_2 = dataset_1.drop(columns=[\"longitude\", \"total_bedrooms\", \"population\", \"households\"])\n", | |
"X = dataset_1.loc[:, dataset_1.columns != 'median_house_value'].to_numpy()\n", | |
"y = dataset_1['median_house_value'].to_numpy()\n", | |
"X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=2)\n", | |
"print(\"Dims of train \", X_train.shape, y_train.shape)\n", | |
"print(\"Dims of test \", X_test.shape, y_test.shape)" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "EHl0evTDktL7", | |
"outputId": "445bccf0-f26a-4f98-f1f4-f85a6f0b092a" | |
}, | |
"execution_count": 34, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"Dims of train (11900, 10) (11900,)\n", | |
"Dims of test (5100, 10) (5100,)\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"input = keras.Input(shape=(10,))\n", | |
"x = layers.Normalization()(input)\n", | |
"x = layers.Dense(64, activation='relu')(x)\n", | |
"x = layers.Dense(64, activation='relu')(x)\n", | |
"output = layers.Dense(1)(x)\n", | |
"\n", | |
"nn_2 = keras.Model(input, output)\n", | |
"\n", | |
"LR = 0.001\n", | |
"nn_2.compile(\n", | |
" optimizer=keras.optimizers.Adam(learning_rate=LR),\n", | |
" loss=[tf.keras.losses.MeanAbsoluteError()],\n", | |
" metrics=[tf.keras.metrics.MeanSquaredError()]\n", | |
"\n", | |
")\n", | |
"\n", | |
"print(\"Fit model on training data\")\n", | |
"history = nn_2.fit(\n", | |
" X_train,\n", | |
" y_train,\n", | |
" epochs=30,\n", | |
" validation_data=(X_test, y_test))" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "wpsxHYMWlUmM", | |
"outputId": "2502eb9a-5b9c-48cd-cbc3-ddb268b4eee3" | |
}, | |
"execution_count": 35, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"Fit model on training data\n", | |
"Epoch 1/30\n", | |
"372/372 [==============================] - 2s 5ms/step - loss: 123.5778 - mean_squared_error: 84509.2656 - val_loss: 94.7028 - val_mean_squared_error: 28951.8555\n", | |
"Epoch 2/30\n", | |
"372/372 [==============================] - 2s 4ms/step - loss: 74.1135 - mean_squared_error: 13103.9932 - val_loss: 84.0273 - val_mean_squared_error: 13924.6475\n", | |
"Epoch 3/30\n", | |
"372/372 [==============================] - 2s 4ms/step - loss: 69.8124 - mean_squared_error: 11254.4961 - val_loss: 62.4042 - val_mean_squared_error: 8254.2021\n", | |
"Epoch 4/30\n", | |
"372/372 [==============================] - 2s 4ms/step - loss: 63.7925 - mean_squared_error: 8883.7812 - val_loss: 65.3682 - val_mean_squared_error: 11510.4131\n", | |
"Epoch 5/30\n", | |
"372/372 [==============================] - 2s 4ms/step - loss: 64.2597 - mean_squared_error: 9500.2275 - val_loss: 57.6894 - val_mean_squared_error: 7735.1694\n", | |
"Epoch 6/30\n", | |
"372/372 [==============================] - 2s 5ms/step - loss: 60.9381 - mean_squared_error: 7677.7090 - val_loss: 58.1475 - val_mean_squared_error: 6879.2720\n", | |
"Epoch 7/30\n", | |
"372/372 [==============================] - 2s 4ms/step - loss: 62.2062 - mean_squared_error: 8341.0576 - val_loss: 60.1291 - val_mean_squared_error: 7726.6885\n", | |
"Epoch 8/30\n", | |
"372/372 [==============================] - 2s 5ms/step - loss: 60.3868 - mean_squared_error: 7539.1187 - val_loss: 63.7201 - val_mean_squared_error: 10191.7227\n", | |
"Epoch 9/30\n", | |
"372/372 [==============================] - 2s 5ms/step - loss: 60.5445 - mean_squared_error: 7589.5142 - val_loss: 62.1348 - val_mean_squared_error: 8403.5000\n", | |
"Epoch 10/30\n", | |
"372/372 [==============================] - 2s 4ms/step - loss: 59.5142 - mean_squared_error: 7382.1558 - val_loss: 57.4914 - val_mean_squared_error: 7134.7656\n", | |
"Epoch 11/30\n", | |
"372/372 [==============================] - 2s 4ms/step - loss: 59.4686 - mean_squared_error: 7449.5151 - val_loss: 52.8626 - val_mean_squared_error: 5637.7847\n", | |
"Epoch 12/30\n", | |
"372/372 [==============================] - 1s 4ms/step - loss: 55.0070 - mean_squared_error: 6172.5854 - val_loss: 65.5911 - val_mean_squared_error: 8472.5693\n", | |
"Epoch 13/30\n", | |
"372/372 [==============================] - 2s 4ms/step - loss: 58.2922 - mean_squared_error: 7076.7959 - val_loss: 63.0623 - val_mean_squared_error: 7742.8179\n", | |
"Epoch 14/30\n", | |
"372/372 [==============================] - 1s 4ms/step - loss: 56.5393 - mean_squared_error: 6516.9219 - val_loss: 58.3570 - val_mean_squared_error: 7519.5229\n", | |
"Epoch 15/30\n", | |
"372/372 [==============================] - 2s 5ms/step - loss: 54.7303 - mean_squared_error: 6120.8174 - val_loss: 57.7792 - val_mean_squared_error: 6764.9136\n", | |
"Epoch 16/30\n", | |
"372/372 [==============================] - 2s 4ms/step - loss: 55.3258 - mean_squared_error: 6324.9487 - val_loss: 52.5866 - val_mean_squared_error: 5486.2563\n", | |
"Epoch 17/30\n", | |
"372/372 [==============================] - 2s 4ms/step - loss: 56.8465 - mean_squared_error: 6713.0439 - val_loss: 52.5608 - val_mean_squared_error: 5543.4482\n", | |
"Epoch 18/30\n", | |
"372/372 [==============================] - 2s 4ms/step - loss: 55.0057 - mean_squared_error: 6202.8843 - val_loss: 63.2907 - val_mean_squared_error: 9807.6797\n", | |
"Epoch 19/30\n", | |
"372/372 [==============================] - 2s 5ms/step - loss: 54.4341 - mean_squared_error: 6005.2930 - val_loss: 52.2481 - val_mean_squared_error: 5230.1445\n", | |
"Epoch 20/30\n", | |
"372/372 [==============================] - 2s 5ms/step - loss: 52.3851 - mean_squared_error: 5476.8960 - val_loss: 50.6791 - val_mean_squared_error: 5269.7285\n", | |
"Epoch 21/30\n", | |
"372/372 [==============================] - 2s 4ms/step - loss: 54.1140 - mean_squared_error: 6032.5000 - val_loss: 52.0576 - val_mean_squared_error: 5089.3145\n", | |
"Epoch 22/30\n", | |
"372/372 [==============================] - 2s 4ms/step - loss: 51.5882 - mean_squared_error: 5353.1353 - val_loss: 54.0523 - val_mean_squared_error: 5410.7358\n", | |
"Epoch 23/30\n", | |
"372/372 [==============================] - 2s 4ms/step - loss: 54.3766 - mean_squared_error: 6009.6519 - val_loss: 51.8409 - val_mean_squared_error: 5514.4302\n", | |
"Epoch 24/30\n", | |
"372/372 [==============================] - 2s 5ms/step - loss: 51.8519 - mean_squared_error: 5430.7134 - val_loss: 50.4292 - val_mean_squared_error: 5203.2998\n", | |
"Epoch 25/30\n", | |
"372/372 [==============================] - 2s 5ms/step - loss: 51.6805 - mean_squared_error: 5387.0952 - val_loss: 54.4006 - val_mean_squared_error: 5912.8208\n", | |
"Epoch 26/30\n", | |
"372/372 [==============================] - 2s 5ms/step - loss: 50.6099 - mean_squared_error: 5249.1851 - val_loss: 52.4657 - val_mean_squared_error: 5114.3223\n", | |
"Epoch 27/30\n", | |
"372/372 [==============================] - 2s 4ms/step - loss: 51.1354 - mean_squared_error: 5322.7368 - val_loss: 50.4774 - val_mean_squared_error: 4757.1084\n", | |
"Epoch 28/30\n", | |
"372/372 [==============================] - 2s 5ms/step - loss: 50.6150 - mean_squared_error: 5105.1929 - val_loss: 55.7018 - val_mean_squared_error: 6655.2378\n", | |
"Epoch 29/30\n", | |
"372/372 [==============================] - 2s 5ms/step - loss: 50.2288 - mean_squared_error: 5055.5469 - val_loss: 49.8351 - val_mean_squared_error: 4791.1860\n", | |
"Epoch 30/30\n", | |
"372/372 [==============================] - 2s 5ms/step - loss: 50.5648 - mean_squared_error: 5132.6577 - val_loss: 49.7307 - val_mean_squared_error: 5083.0269\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"" | |
], | |
"metadata": { | |
"id": "AcvusFN2lyP0" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
} | |
] | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment