{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"import matplotlib.pyplot as plt\n",
"import matplotlib.ticker as ticker\n",
"from sklearn.datasets import load_boston\n",
"from sklearn import preprocessing\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn.linear_model import LinearRegression,Ridge\n",
"from sklearn.metrics import r2_score,mean_squared_error"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"After loading the dataset, the features should be standardized. Ridge regression shrinks coefficients by penalizing their size, so the features must be on the same scale for the penalty to treat them fairly."
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"# Load the Boston housing data and keep a DataFrame copy with the target as PRICE\n",
"house_price = load_boston()\n",
"df = pd.DataFrame(house_price.data,columns=house_price.feature_names)\n",
"df['PRICE'] = house_price.target\n",
"# Standardize the features (zero mean, unit variance), then hold out 30% as a test set\n",
"house_price.data = preprocessing.scale(house_price.data)\n",
"X_train,X_test,y_train,y_test = train_test_split(house_price.data,house_price.target,test_size=0.3,random_state=10)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we can iterate over lambda values from 0 to 199 (lambda is the `alpha` parameter of `Ridge`). Note that the coefficients at lambda equal to zero are the same as the OLS coefficients."
]
},
{
"cell_type": "code",
"execution_count": 54,
"metadata": {},
"outputs": [],
"source": [
"# Baseline: Ridge with alpha=0 applies no penalty, so it reproduces the OLS fit\n",
"ridge_reg = Ridge(alpha=0)\n",
"ridge_reg.fit(X_train,y_train)\n",
"ridge_df = pd.DataFrame({'variables':house_price.feature_names,'estimate':ridge_reg.coef_})\n",
"ridge_train_pred = []\n",
"ridge_test_pred = []\n",
"ols_pred = ridge_reg.predict(X_test)\n",
"# Refit for each lambda (alpha) in 0..199, storing the coefficients and predictions\n",
"for alpha in np.arange(0,200,1):\n",
"    ridge_reg = Ridge(alpha=alpha)\n",
"    ridge_reg.fit(X_train,y_train)\n",
"    var_name = 'estimate' + str(alpha)\n",
"    ridge_df[var_name] = ridge_reg.coef_\n",
"    ridge_train_pred.append(ridge_reg.predict(X_train))\n",
"    ridge_test_pred.append(ridge_reg.predict(X_test))\n"
]
},
{
"cell_type": "code",
"execution_count": 55,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th>variables</th>\n",
" <th>estimate</th>\n",
" <th>CRIM</th>\n",
" <th>ZN</th>\n",
" <th>INDUS</th>\n",
" <th>CHAS</th>\n",
" <th>NOX</th>\n",
" <th>RM</th>\n",
" <th>AGE</th>\n",
" <th>DIS</th>\n",
" <th>RAD</th>\n",
" <th>TAX</th>\n",
" <th>PTRATIO</th>\n",
" <th>B</th>\n",
" <th>LSTAT</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>estimate</td>\n",
" <td>-1.321404</td>\n",
" <td>1.514832</td>\n",
" <td>-0.166266</td>\n",
" <td>0.411579</td>\n",
" <td>-1.771168</td>\n",
" <td>2.352821</td>\n",
" <td>0.318499</td>\n",
" <td>-3.256645</td>\n",
" <td>2.632576</td>\n",
" <td>-2.059466</td>\n",
" <td>-1.755201</td>\n",
" <td>1.181143</td>\n",
" <td>-3.887043</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>estimate0</td>\n",
" <td>-1.321404</td>\n",
" <td>1.514832</td>\n",
" <td>-0.166266</td>\n",
" <td>0.411579</td>\n",
" <td>-1.771168</td>\n",
" <td>2.352821</td>\n",
" <td>0.318499</td>\n",
" <td>-3.256645</td>\n",
" <td>2.632576</td>\n",
" <td>-2.059466</td>\n",
" <td>-1.755201</td>\n",
" <td>1.181143</td>\n",
" <td>-3.887043</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>estimate1</td>\n",
" <td>-1.305844</td>\n",
" <td>1.491704</td>\n",
" <td>-0.186300</td>\n",
" <td>0.416554</td>\n",
" <td>-1.734866</td>\n",
" <td>2.368304</td>\n",
" <td>0.305241</td>\n",
" <td>-3.213459</td>\n",
" <td>2.540705</td>\n",
" <td>-1.976631</td>\n",
" <td>-1.746362</td>\n",
" <td>1.179751</td>\n",
" <td>-3.867286</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>estimate2</td>\n",
" <td>-1.291053</td>\n",
" <td>1.469716</td>\n",
" <td>-0.204508</td>\n",
" <td>0.421192</td>\n",
" <td>-1.700188</td>\n",
" <td>2.382762</td>\n",
" <td>0.292616</td>\n",
" <td>-3.171513</td>\n",
" <td>2.455276</td>\n",
" <td>-1.900395</td>\n",
" <td>-1.737941</td>\n",
" <td>1.178322</td>\n",
" <td>-3.847925</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>estimate3</td>\n",
" <td>-1.276962</td>\n",
" <td>1.448771</td>\n",
" <td>-0.221104</td>\n",
" <td>0.425531</td>\n",
" <td>-1.667018</td>\n",
" <td>2.396278</td>\n",
" <td>0.280573</td>\n",
" <td>-3.130750</td>\n",
" <td>2.375608</td>\n",
" <td>-1.830033</td>\n",
" <td>-1.729900</td>\n",
" <td>1.176856</td>\n",
" <td>-3.828944</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>196</th>\n",
" <td>estimate195</td>\n",
" <td>-0.660784</td>\n",
" <td>0.677185</td>\n",
" <td>-0.504568</td>\n",
" <td>0.499952</td>\n",
" <td>-0.498848</td>\n",
" <td>2.188311</td>\n",
" <td>-0.201508</td>\n",
" <td>-0.898071</td>\n",
" <td>0.183866</td>\n",
" <td>-0.495208</td>\n",
" <td>-1.241954</td>\n",
" <td>0.828847</td>\n",
" <td>-2.302692</td>\n",
" </tr>\n",
" <tr>\n",
" <th>197</th>\n",
" <td>estimate196</td>\n",
" <td>-0.659786</td>\n",
" <td>0.676331</td>\n",
" <td>-0.504713</td>\n",
" <td>0.499619</td>\n",
" <td>-0.497990</td>\n",
" <td>2.185530</td>\n",
" <td>-0.202161</td>\n",
" <td>-0.894260</td>\n",
" <td>0.181707</td>\n",
" <td>-0.494909</td>\n",
" <td>-1.240543</td>\n",
" <td>0.827560</td>\n",
" <td>-2.298731</td>\n",
" </tr>\n",
" <tr>\n",
" <th>198</th>\n",
" <td>estimate197</td>\n",
" <td>-0.658794</td>\n",
" <td>0.675484</td>\n",
" <td>-0.504856</td>\n",
" <td>0.499284</td>\n",
" <td>-0.497142</td>\n",
" <td>2.182753</td>\n",
" <td>-0.202809</td>\n",
" <td>-0.890472</td>\n",
" <td>0.179565</td>\n",
" <td>-0.494614</td>\n",
" <td>-1.239137</td>\n",
" <td>0.826278</td>\n",
" <td>-2.294789</td>\n",
" </tr>\n",
" <tr>\n",
" <th>199</th>\n",
" <td>estimate198</td>\n",
" <td>-0.657808</td>\n",
" <td>0.674643</td>\n",
" <td>-0.504998</td>\n",
" <td>0.498948</td>\n",
" <td>-0.496303</td>\n",
" <td>2.179981</td>\n",
" <td>-0.203453</td>\n",
" <td>-0.886709</td>\n",
" <td>0.177438</td>\n",
" <td>-0.494324</td>\n",
" <td>-1.237735</td>\n",
" <td>0.825002</td>\n",
" <td>-2.290864</td>\n",
" </tr>\n",
" <tr>\n",
" <th>200</th>\n",
" <td>estimate199</td>\n",
" <td>-0.656827</td>\n",
" <td>0.673807</td>\n",
" <td>-0.505138</td>\n",
" <td>0.498611</td>\n",
" <td>-0.495473</td>\n",
" <td>2.177215</td>\n",
" <td>-0.204092</td>\n",
" <td>-0.882969</td>\n",
" <td>0.175327</td>\n",
" <td>-0.494038</td>\n",
" <td>-1.236337</td>\n",
" <td>0.823731</td>\n",
" <td>-2.286957</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>201 rows × 14 columns</p>\n",
"</div>"
],
"text/plain": [
"variables estimate CRIM ZN INDUS CHAS NOX \\\n",
"0 estimate -1.321404 1.514832 -0.166266 0.411579 -1.771168 \n",
"1 estimate0 -1.321404 1.514832 -0.166266 0.411579 -1.771168 \n",
"2 estimate1 -1.305844 1.491704 -0.186300 0.416554 -1.734866 \n",
"3 estimate2 -1.291053 1.469716 -0.204508 0.421192 -1.700188 \n",
"4 estimate3 -1.276962 1.448771 -0.221104 0.425531 -1.667018 \n",
".. ... ... ... ... ... ... \n",
"196 estimate195 -0.660784 0.677185 -0.504568 0.499952 -0.498848 \n",
"197 estimate196 -0.659786 0.676331 -0.504713 0.499619 -0.497990 \n",
"198 estimate197 -0.658794 0.675484 -0.504856 0.499284 -0.497142 \n",
"199 estimate198 -0.657808 0.674643 -0.504998 0.498948 -0.496303 \n",
"200 estimate199 -0.656827 0.673807 -0.505138 0.498611 -0.495473 \n",
"\n",
"variables RM AGE DIS RAD TAX PTRATIO \\\n",
"0 2.352821 0.318499 -3.256645 2.632576 -2.059466 -1.755201 \n",
"1 2.352821 0.318499 -3.256645 2.632576 -2.059466 -1.755201 \n",
"2 2.368304 0.305241 -3.213459 2.540705 -1.976631 -1.746362 \n",
"3 2.382762 0.292616 -3.171513 2.455276 -1.900395 -1.737941 \n",
"4 2.396278 0.280573 -3.130750 2.375608 -1.830033 -1.729900 \n",
".. ... ... ... ... ... ... \n",
"196 2.188311 -0.201508 -0.898071 0.183866 -0.495208 -1.241954 \n",
"197 2.185530 -0.202161 -0.894260 0.181707 -0.494909 -1.240543 \n",
"198 2.182753 -0.202809 -0.890472 0.179565 -0.494614 -1.239137 \n",
"199 2.179981 -0.203453 -0.886709 0.177438 -0.494324 -1.237735 \n",
"200 2.177215 -0.204092 -0.882969 0.175327 -0.494038 -1.236337 \n",
"\n",
"variables B LSTAT \n",
"0 1.181143 -3.887043 \n",
"1 1.181143 -3.887043 \n",
"2 1.179751 -3.867286 \n",
"3 1.178322 -3.847925 \n",
"4 1.176856 -3.828944 \n",
".. ... ... \n",
"196 0.828847 -2.302692 \n",
"197 0.827560 -2.298731 \n",
"198 0.826278 -2.294789 \n",
"199 0.825002 -2.290864 \n",
"200 0.823731 -2.286957 \n",
"\n",
"[201 rows x 14 columns]"
]
},
"execution_count": 55,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"ridge_df = ridge_df.set_index('variables').T.rename_axis('estimate').rename_axis().reset_index()\n",
"ridge_df\n"
]
},
{
"cell_type": "code",
"execution_count": 56,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 720x360 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"fig,ax = plt.subplots(figsize=(10,5))\n",
"ax.plot(ridge_df.RM,'r',ridge_df.ZN,'g',ridge_df.RAD,'b',ridge_df.CRIM,'c',ridge_df.TAX,'y')\n",
"ax.axhline(y=0,color='black',linestyle='--')\n",
"ax.set_xlabel(\"Lambda\")\n",
"ax.set_ylabel(\"Beta Estimate\")\n",
"ax.set_title(\"Ridge Regression Trace\",fontsize=16)\n",
"ax.legend(labels=['Room','Residential zone','Highway Access','Crime Rate','Tax'])\n",
"ax.grid(True)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"As lambda increases, most of the plotted estimates shrink towards zero, while the coefficient for rooms (RM) changes comparatively little (from about 2.35 at lambda 0 to about 2.18 at lambda 199). This suggests that the number of rooms is an important predictor of house price.\n"
]
},
{
"cell_type": "code",
"execution_count": 60,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Text(0, 0.5, 'MSE')"
]
},
"execution_count": 60,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# MSE of Ridge and OLS\n",
"ridge_mse_test = [mean_squared_error(y_test, p) for p in ridge_test_pred]\n",
"ols_mse = mean_squared_error(y_test, ols_pred)\n",
"\n",
"# plot mse\n",
"plt.plot(ridge_mse_test[:25], 'ro')\n",
"plt.axhline(y=ols_mse, color='g', linestyle='--')\n",
"plt.title(\"Ridge Test Set MSE\", fontsize=16)\n",
"plt.xlabel(\"Model Simplicity$\\longrightarrow$\")\n",
"plt.ylabel(\"MSE\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The green dashed line in the graph above is the test-set MSE of OLS, and the red points are the ridge test-set MSE for the first 25 lambda values, which increase along the x-axis. Ridge prediction improves (lower error) up to a certain lambda. In short, a model with some bias from the penalty can predict better than the pure OLS model."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python",
"language": "python",
"name": "conda-env-python-py"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.7"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
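
A minimal NumPy sketch of the idea the notebook demonstrates, under stated assumptions: it uses synthetic data instead of the Boston dataset, and a hypothetical helper `ridge_coefficients` that applies the closed-form ridge solution to centered data rather than calling scikit-learn's `Ridge`. It checks that lambda equal to zero reproduces the OLS coefficients and that the overall size of the coefficient vector shrinks as lambda grows.

```python
import numpy as np


def ridge_coefficients(X, y, alpha):
    """Closed-form ridge solution on centered data (hypothetical helper).

    Solves (X'X + alpha * I) beta = X'y after centering X and y, which
    leaves the intercept unpenalized.
    """
    Xc = X - X.mean(axis=0)
    yc = y - y.mean()
    n_features = Xc.shape[1]
    return np.linalg.solve(Xc.T @ Xc + alpha * np.eye(n_features), Xc.T @ yc)


# Synthetic stand-in for the housing data (assumption: 3 features, linear signal)
rng = np.random.default_rng(0)
X = rng.normal(size=(200, 3))
X = (X - X.mean(axis=0)) / X.std(axis=0)  # standardize, as the notebook does
y = X @ np.array([3.0, -2.0, 0.5]) + rng.normal(scale=0.5, size=200)

# OLS coefficients from a least-squares solve on the centered data
ols_coef = np.linalg.lstsq(X - X.mean(axis=0), y - y.mean(), rcond=None)[0]

# Ridge coefficients for a few lambda (alpha) values
alphas = [0, 1, 10, 100]
coefs = {a: ridge_coefficients(X, y, a) for a in alphas}
norms = [float(np.linalg.norm(coefs[a])) for a in alphas]

for a, norm in zip(alphas, norms):
    print(f"alpha={a:>3}  coef={np.round(coefs[a], 3)}  norm={norm:.3f}")
```

The Euclidean norm of the coefficient vector always decreases as lambda increases, although an individual coefficient can temporarily grow, as RM does for small lambda in the table above.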