Skip to content

Instantly share code, notes, and snippets.

@dcolinmorgan
Created July 7, 2022 09:52
Show Gist options
  • Save dcolinmorgan/5b8fcbbeb0405d4cf501610d19dc6395 to your computer and use it in GitHub Desktop.
Save dcolinmorgan/5b8fcbbeb0405d4cf501610d19dc6395 to your computer and use it in GitHub Desktop.
wrangle_classifier_demo.ipynb
Display the source blob
Display the rendered blob
Raw
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "wrangle_classifier_demo.ipynb",
"provenance": [],
"authorship_tag": "ABX9TyNB4Susi1JVmSLW4VVZ+dwQ",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/dcolinmorgan/5b8fcbbeb0405d4cf501610d19dc6395/wrangle_classifier_demo.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"id": "y4RYn54E-pbG"
},
"outputs": [],
"source": [
"from sklearn.metrics import mean_absolute_error,r2_score,accuracy_score\n",
"\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn.pipeline import make_pipeline\n",
"from sklearn.preprocessing import StandardScaler\n",
"\n",
"from sklearn.linear_model import SGDClassifier\n",
"from sklearn.neighbors import KNeighborsClassifier\n",
"from sklearn.tree import DecisionTreeClassifier\n",
"from sklearn.ensemble import RandomForestClassifier\n",
"from lightgbm import LGBMClassifier\n",
"from xgboost.sklearn import XGBClassifier\n",
"\n",
"import seaborn as sns,numpy as np, pandas as pd,glob,os\n",
"from geopy.geocoders import Nominatim\n",
"from geopy.extra.rate_limiter import RateLimiter\n",
"import geopy.distance"
]
},
{
"cell_type": "markdown",
"source": [
"## Preprocessing\n",
"\n",
"---\n",
"\n",
"1. `df2` → health data (hospital codes expanded to full names)\n",
"2. `data2` → pollution data (one CSV per station)\n",
"\n"
],
"metadata": {
"id": "fE4zBvDq__SS"
}
},
{
"cell_type": "code",
"source": [
"!wget https://github.com/dcolinmorgan/AXA_AE_app/raw/main/axa_p.zip\n",
"!7z x /content/axa_p.zip\n",
"!git clone https://github.com/dcolinmorgan/aqi-stations-scraper.git\n",
"\n",
"# Health data. The two helper columns get distinct names: duplicate labels\n",
"# make df2['tmp'] return a DataFrame instead of a single column.\n",
"df2 = pd.read_parquet('/content/AE_AXA_dat_full.parquet')\n",
"df2.columns = ['pat_id','cd9_loc','sess','sex','age','cd9_code','mini_loc','loc1','date','tmp1','diag1','diag2','tmp2']\n",
"\n",
"# Expand hospital codes to full names (applied to every column of df2).\n",
"# PWH is Prince of Wales Hospital; Princess Margaret Hospital is PMH.\n",
"HOSPITAL_NAMES = {\n",
"    'RH': 'Ruttonjee Hospital',\n",
"    'PYN': 'Pamela Youde Nethersole Eastern Hospital',\n",
"    'QEH': 'Queen Elizabeth Hospital',\n",
"    'CMC': 'Caritas Medical Centre',\n",
"    'KWH': 'Kwong Wah Hospital',\n",
"    'TMH': 'Tuen Mun Hospital',\n",
"    'PWH': 'Prince of Wales Hospital',\n",
"    'NDH': 'North District Hospital',\n",
"    'YCH': 'Yan Chai Hospital',\n",
"    'UCH': 'United Christian Hospital',\n",
"    'QMH': 'Queen Mary Hospital',\n",
"    'POH': 'Pok Oi Hospital',\n",
"    'TKO': 'Tseung Kwan O Hospital',\n",
"    'AHN': 'Alice Ho Miu Ling Nethersole Hospital',\n",
"    'SJH': 'St. John Hospital',\n",
"    'NLT': 'North Lantau Hospital',\n",
"    'TSH': 'Tang Shiu Kin Hospital',\n",
"    'PMH': 'Princess Margaret Hospital',\n",
"}\n",
"df2 = df2.replace(HOSPITAL_NAMES)\n",
"\n",
"# Pollution data: one CSV per station; the station name is the file-name prefix.\n",
"# Collect the frames and concatenate once (DataFrame.append was removed in pandas 2.0).\n",
"frames = []\n",
"for file in glob.glob('/content/aqi-stations-scraper/data/japan-aqi/*'):\n",
"    data = pd.read_csv(file, sep=' |,', engine='python')\n",
"    data['loc1'] = os.path.basename(file).split(',')[0]\n",
"    frames.append(data)\n",
"cc = pd.concat(frames)\n",
"\n",
"# .copy() so the column assignments below do not write into a slice of cc.\n",
"data2 = cc[['date','pm25','pm10','o3','no2','so2','co','loc1']].copy()\n",
"data2['loc1'] = data2['loc1'].str.upper().replace({'-': ' '}, regex=True)\n",
"data2['date'] = pd.to_datetime(data2['date'])\n",
"\n",
"# Geocode each distinct cd9_loc value. Nominatim's usage policy allows at most\n",
"# one request per second; RateLimiter returns None when a lookup keeps failing.\n",
"geolocator = Nominatim(user_agent='example app')\n",
"geocode = RateLimiter(geolocator.geocode, min_delay_seconds=1)\n",
"\n",
"hospital_records = []\n",
"for name in pd.unique(df2['cd9_loc']):\n",
"    location = geocode(str(name) + ', Hong Kong')\n",
"    if location is None:\n",
"        print('no location data for: ' + str(name))\n",
"        continue\n",
"    hospital_records.append([location.latitude, location.longitude, name])\n",
"df_loc = pd.DataFrame(hospital_records, columns=['lat','long','name'])"
],
"metadata": {
"id": "qlvaQ8g5_-gk"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Station names the geocoder cannot resolve as scraped.\n",
"data2 = data2.replace({\n",
"    'CENTRALNAYA STR': 'central',\n",
"    'SOUTHERN': 'southern island',\n",
"    'SOUTHERN PART OF CHENGYANG DISTRICT': 'chengyang district',\n",
"})\n",
"\n",
"# Geocode each station (max 1 request/s per Nominatim policy). Stations that\n",
"# cannot be resolved are skipped rather than given the previous lookup's coordinates.\n",
"geocode = RateLimiter(geolocator.geocode, min_delay_seconds=1)\n",
"station_records = []\n",
"for name in pd.unique(data2['loc1']):\n",
"    location = geocode(str(name) + ', Hong Kong')\n",
"    if location is None:\n",
"        print('no location data for: ' + str(name))\n",
"        continue\n",
"    station_records.append([location.latitude, location.longitude, name])\n",
"station_loc = pd.DataFrame(station_records, columns=['lat','long','name'])\n",
"# Several names can resolve to the same point; keep the first of each.\n",
"station_loc = station_loc[~station_loc.duplicated(['lat','long'], keep='first')]\n",
"\n",
"# Hospitals first, then stations, with a fresh 0..n-1 index.\n",
"data_loc = pd.concat([df_loc[['lat','long','name']], station_loc], ignore_index=True)\n",
"\n",
"# geopy expects (lat, lon) order\n",
"data_loc['latlon'] = list(zip(data_loc['lat'], data_loc['long']))\n",
"data_loc['src'] = data_loc['name']\n",
"data_loc['dst'] = data_loc['name']\n",
"\n",
"# Pairwise geodesic distances in km: distances.loc[i, j] is the distance from\n",
"# location i to location j (geopy.distance.distance is the geodesic alias).\n",
"points = data_loc['latlon'].tolist()\n",
"distances = pd.DataFrame(\n",
"    [[geopy.distance.distance(p, q, ellipsoid='WGS-84').km for q in points] for p in points],\n",
"    index=data_loc.index, columns=data_loc.index,\n",
")\n",
"\n",
"# Pairs of locations within 5 km of each other (self-pairs included).\n",
"D_D = (distances < 5).astype(int)\n",
"D_D.index = data_loc['src']\n",
"D_D.columns = data_loc['dst']\n",
"E_E = D_D.stack().reset_index()\n",
"distance_mat = E_E[E_E[0] > 0]\n",
"\n",
"# Long-form (src, dst, km) table. Work on a copy: relabelling an alias would\n",
"# also relabel `distances`.\n",
"distance = distances.copy()\n",
"distance.index = data_loc['src']\n",
"distance.columns = data_loc['dst']\n",
"distance = distance.stack().reset_index()\n",
"\n",
"# Plain float matrix of km distances.\n",
"distances.to_numpy()"
],
"metadata": {
"id": "N8poYsIf_-jH"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"## Post-processing: load matched data and compare classifiers"
],
"metadata": {
"id": "oL_WaMT3BL1q"
}
},
{
"cell_type": "code",
"source": [
"!wget https://github.com/dcolinmorgan/AXA_AE_app/raw/main/coord_fulldiag_UVI_min.npy.gz\n",
"!7z x /content/coord_fulldiag_UVI_min.npy.gz\n",
"\n",
"# NOTE(review): allow_pickle=True lets the downloaded file run arbitrary code\n",
"# on load; only use it with an archive you trust.\n",
"DATA5_COLUMNS = [\n",
"    'pm25', 'pm10', 'o3', 'no2', 'so2', 'co',\n",
"    'lat', 'long', 'name', 'year', 'week', 'diag1', 'UVI',\n",
"]\n",
"data5 = pd.DataFrame(\n",
"    np.load('/content/coord_fulldiag_UVI_min.npy', allow_pickle=True),\n",
"    columns=DATA5_COLUMNS,\n",
")\n",
"\n",
"# Continuous week index counted from the earliest year in the data.\n",
"# NOTE(review): assumes 52 weeks per year; if `week` can be 53 (ISO weeks),\n",
"# week 53 of one year gets the same weekA as week 1 of the next. Confirm.\n",
"earliest_year = np.nanmin(data5['year'])\n",
"data5['weekA'] = data5['week'] + 52 * (data5['year'] - earliest_year)"
],
"metadata": {
"id": "lnmxHgVx-1B7"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Count records per unique (pollutants, location, UVI, week) combination;\n",
"# the per-group non-null count of diag1 is the classification target.\n",
"GROUP_KEYS = ['pm25','pm10','o3','no2','so2','co','lat','long','UVI','weekA']\n",
"# lat/long are grouped on but not used as model features.\n",
"FEATURES = ['pm25','pm10','o3','no2','so2','co','UVI','weekA']\n",
"\n",
"group_counts = data5.groupby(GROUP_KEYS).agg('count').reset_index()\n",
"y = group_counts['diag1']\n",
"MATCHED_DATA = group_counts.drop(columns=['diag1'])\n",
"X = MATCHED_DATA[FEATURES]\n",
"\n",
"X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, random_state=2)"
],
"metadata": {
"id": "buSXrQ2E-1EC"
},
"execution_count": 3,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# One seed for every stochastic model (same value as train_test_split) so the\n",
"# results table is reproducible across reruns. KNN is deterministic.\n",
"RANDOM_STATE = 2\n",
"\n",
"# SGD is sensitive to feature scale, so standardise its inputs first.\n",
"SGD = make_pipeline(StandardScaler(), SGDClassifier(random_state=RANDOM_STATE))\n",
"KNN = KNeighborsClassifier()\n",
"DTC = DecisionTreeClassifier(random_state=RANDOM_STATE)\n",
"RFC = RandomForestClassifier(random_state=RANDOM_STATE)\n",
"LGBM = LGBMClassifier(random_state=RANDOM_STATE)\n",
"# NOTE(review): newer xgboost releases require class labels 0..n_classes-1;\n",
"# the count labels here may need encoding (e.g. LabelEncoder) there.\n",
"XGB = XGBClassifier(random_state=RANDOM_STATE)\n",
"\n",
"for model in (SGD, KNN, DTC, RFC, LGBM, XGB):\n",
"    model.fit(X_train, y_train)\n"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "knfC6CPw-4Td",
"outputId": "a56dd489-c7c0-4ce5-ea3f-e2440bc8392c"
},
"execution_count": 4,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"XGBClassifier(objective='multi:softprob')"
]
},
"metadata": {},
"execution_count": 4
}
]
},
{
"cell_type": "code",
"source": [
"# Score every fitted model on the held-out split. The labels are per-group\n",
"# counts, so MAE and R2 are reported alongside accuracy.\n",
"models = {'SGC': SGD, 'KNN': KNN, 'DTC': DTC, 'RFC': RFC, 'LGBM': LGBM, 'XGB': XGB}\n",
"metrics = {'ACC': accuracy_score, 'MAE': mean_absolute_error, 'R2': r2_score}\n",
"\n",
"results = pd.DataFrame(index=list(metrics))\n",
"for label, model in models.items():\n",
"    pred = model.predict(X_test)\n",
"    results[label] = [score(y_test, pred) for score in metrics.values()]\n",
"results"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 143
},
"id": "5a1hzBti-4V9",
"outputId": "69487cb1-7601-406e-d120-ca81d6d8883a"
},
"execution_count": 5,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
" SGC KNN DTC RFC LGBM XGB\n",
"ACC 0.034091 0.151515 0.189394 0.250000 0.238636 0.219697\n",
"MAE 7.905303 3.073864 2.744318 2.446970 2.325758 2.585227\n",
"R2 -5.371673 -0.383712 -0.076455 0.044189 0.115882 -0.023739"
],
"text/html": [
"\n",
" <div id=\"df-36961955-48ae-47cc-b86a-d821c4fdd341\">\n",
" <div class=\"colab-df-container\">\n",
" <div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>SGC</th>\n",
" <th>KNN</th>\n",
" <th>DTC</th>\n",
" <th>RFC</th>\n",
" <th>LGBM</th>\n",
" <th>XGB</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>ACC</th>\n",
" <td>0.034091</td>\n",
" <td>0.151515</td>\n",
" <td>0.189394</td>\n",
" <td>0.250000</td>\n",
" <td>0.238636</td>\n",
" <td>0.219697</td>\n",
" </tr>\n",
" <tr>\n",
" <th>MAE</th>\n",
" <td>7.905303</td>\n",
" <td>3.073864</td>\n",
" <td>2.744318</td>\n",
" <td>2.446970</td>\n",
" <td>2.325758</td>\n",
" <td>2.585227</td>\n",
" </tr>\n",
" <tr>\n",
" <th>R2</th>\n",
" <td>-5.371673</td>\n",
" <td>-0.383712</td>\n",
" <td>-0.076455</td>\n",
" <td>0.044189</td>\n",
" <td>0.115882</td>\n",
" <td>-0.023739</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>\n",
" <button class=\"colab-df-convert\" onclick=\"convertToInteractive('df-36961955-48ae-47cc-b86a-d821c4fdd341')\"\n",
" title=\"Convert this dataframe to an interactive table.\"\n",
" style=\"display:none;\">\n",
" \n",
" <svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n",
" width=\"24px\">\n",
" <path d=\"M0 0h24v24H0V0z\" fill=\"none\"/>\n",
" <path d=\"M18.56 5.44l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94zm-11 1L8.5 8.5l.94-2.06 2.06-.94-2.06-.94L8.5 2.5l-.94 2.06-2.06.94zm10 10l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94z\"/><path d=\"M17.41 7.96l-1.37-1.37c-.4-.4-.92-.59-1.43-.59-.52 0-1.04.2-1.43.59L10.3 9.45l-7.72 7.72c-.78.78-.78 2.05 0 2.83L4 21.41c.39.39.9.59 1.41.59.51 0 1.02-.2 1.41-.59l7.78-7.78 2.81-2.81c.8-.78.8-2.07 0-2.86zM5.41 20L4 18.59l7.72-7.72 1.47 1.35L5.41 20z\"/>\n",
" </svg>\n",
" </button>\n",
" \n",
" <style>\n",
" .colab-df-container {\n",
" display:flex;\n",
" flex-wrap:wrap;\n",
" gap: 12px;\n",
" }\n",
"\n",
" .colab-df-convert {\n",
" background-color: #E8F0FE;\n",
" border: none;\n",
" border-radius: 50%;\n",
" cursor: pointer;\n",
" display: none;\n",
" fill: #1967D2;\n",
" height: 32px;\n",
" padding: 0 0 0 0;\n",
" width: 32px;\n",
" }\n",
"\n",
" .colab-df-convert:hover {\n",
" background-color: #E2EBFA;\n",
" box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
" fill: #174EA6;\n",
" }\n",
"\n",
" [theme=dark] .colab-df-convert {\n",
" background-color: #3B4455;\n",
" fill: #D2E3FC;\n",
" }\n",
"\n",
" [theme=dark] .colab-df-convert:hover {\n",
" background-color: #434B5C;\n",
" box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n",
" filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n",
" fill: #FFFFFF;\n",
" }\n",
" </style>\n",
"\n",
" <script>\n",
" const buttonEl =\n",
" document.querySelector('#df-36961955-48ae-47cc-b86a-d821c4fdd341 button.colab-df-convert');\n",
" buttonEl.style.display =\n",
" google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
"\n",
" async function convertToInteractive(key) {\n",
" const element = document.querySelector('#df-36961955-48ae-47cc-b86a-d821c4fdd341');\n",
" const dataTable =\n",
" await google.colab.kernel.invokeFunction('convertToInteractive',\n",
" [key], {});\n",
" if (!dataTable) return;\n",
"\n",
" const docLinkHtml = 'Like what you see? Visit the ' +\n",
" '<a target=\"_blank\" href=https://colab.research.google.com/notebooks/data_table.ipynb>data table notebook</a>'\n",
" + ' to learn more about interactive tables.';\n",
" element.innerHTML = '';\n",
" dataTable['output_type'] = 'display_data';\n",
" await google.colab.output.renderOutput(dataTable, element);\n",
" const docLink = document.createElement('div');\n",
" docLink.innerHTML = docLinkHtml;\n",
" element.appendChild(docLink);\n",
" }\n",
" </script>\n",
" </div>\n",
" </div>\n",
" "
]
},
"metadata": {},
"execution_count": 5
}
]
},
{
"cell_type": "code",
"source": [
"sns.set_theme(style='whitegrid')\n",
"# ACC, MAE and R2 live on very different scales (R2 can be far below zero),\n",
"# so draw one panel per metric, each with its own y-axis.\n",
"scores_long = results.melt(ignore_index=False).reset_index()\n",
"g = sns.catplot(\n",
"    data=scores_long, kind='bar',\n",
"    x='variable', y='value', hue='variable', col='index',\n",
"    sharey=False, dodge=False, legend=False,\n",
"    palette='dark', alpha=.6, height=4, aspect=1,\n",
")\n",
"g.set_axis_labels('model', 'score')\n",
"g.set_titles('{col_name}');"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 437
},
"id": "9ruYuUaL-9AD",
"outputId": "7cefdd47-65a0-4048-bb65-60d9577df904"
},
"execution_count": 6,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"<Figure size 805.625x432 with 1 Axes>"
],
"image/png": "\n"
},
"metadata": {}
}
]
},
{
"cell_type": "code",
"source": [
""
],
"metadata": {
"id": "iYsSBrnB_EPc"
},
"execution_count": null,
"outputs": []
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment