Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save abegehr/0d74bcfbc3726b6cb39270a49ee5ed3b to your computer and use it in GitHub Desktop.
Save abegehr/0d74bcfbc3726b6cb39270a49ee5ed3b to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# COVID19 week4 SIR-Model (LeoCorona)\n",
"https://www.kaggle.com/c/covid19-global-forecasting-week-4"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"id": "Eo9ulW5s6iz5"
},
"outputs": [],
"source": [
"from sklearn.metrics import mean_squared_log_error\n",
"import warnings\n",
"warnings.filterwarnings('ignore')\n",
"import pandas as pd\n",
"from sklearn import preprocessing\n",
"import numpy as np\n",
"from scipy import integrate, optimize\n",
"import math\n",
"\n",
"predictions_total = []\n",
"actual_total = []\n",
"val_loss_dict = {}\n",
"\n",
"val_info_dict = {}\n",
"predictions_dict = {}\n",
"actuals_dict = {}\n",
"colors_dict = {}\n",
"loss_dict = {}\n",
"train_start = 0\n",
"train_end = 0\n",
"val_start = 0\n",
"val_end = 0\n",
"test_start = 0\n",
"test_end = 0\n",
"modes = [\"Confirmed Cases\", \"Fatalities\"]\n",
"method = \"SIR\"\n",
"dynamic_start_day = False"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"id": "bL4IrSy_-xqw"
},
"outputs": [],
"source": [
"test = pd.read_csv(\"../input/covid19-global-forecasting-week-4/test.csv\", parse_dates=[\"Date\"])\n",
"train = pd.read_csv(\"../input/covid19-global-forecasting-week-4/train.csv\")\n",
"submission = pd.read_csv(\"../input/covid19-global-forecasting-week-4/submission.csv\")\n",
"all_data = train.copy()\n",
"# Create date columns\n",
"all_data['Date'] = pd.to_datetime(all_data['Date'])\n",
"le = preprocessing.LabelEncoder()\n",
"all_data['Day_num'] = le.fit_transform(all_data.Date)\n",
"all_data['Day'] = all_data['Date'].dt.day\n",
"all_data['Month'] = all_data['Date'].dt.month\n",
"all_data['Year'] = all_data['Date'].dt.year"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"id": "MEX_1h2CYUyZ",
"outputId": "7bef57b8-18b5-42a3-8518-069f9849dc4b"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Cleaned country details dataset\n",
"Joined dataset\n",
"Encoded dataset\n"
]
}
],
"source": [
"# Load countries data file\n",
"world_population = pd.read_csv(\"../input/population-by-country-2020/population_by_country_2020.csv\")\n",
"\n",
"# Select desired columns and rename some of them\n",
"world_population = world_population[['Country (or dependency)', 'Population (2020)', 'Density (P/Km²)', 'Land Area (Km²)', 'Med. Age', 'Urban Pop %']]\n",
"world_population.columns = ['Country (or dependency)', 'Population (2020)', 'Density', 'Land Area', 'Med Age', 'Urban Pop']\n",
"\n",
"# Replace United States by US\n",
"world_population.loc[world_population['Country (or dependency)']=='United States', 'Country (or dependency)'] = 'US'\n",
"\n",
"# Remove the % character from Urban Pop values\n",
"world_population['Urban Pop'] = world_population['Urban Pop'].str.rstrip('%')\n",
"\n",
"# Replace Urban Pop and Med Age \"N.A\" by their respective modes, then transform to int\n",
"world_population.loc[world_population['Urban Pop']=='N.A.', 'Urban Pop'] = int(world_population.loc[world_population['Urban Pop']!='N.A.', 'Urban Pop'].mode()[0])\n",
"world_population['Urban Pop'] = world_population['Urban Pop'].astype('int16')\n",
"world_population.loc[world_population['Med Age']=='N.A.', 'Med Age'] = int(world_population.loc[world_population['Med Age']!='N.A.', 'Med Age'].mode()[0])\n",
"world_population['Med Age'] = world_population['Med Age'].astype('int16')\n",
"\n",
"print(\"Cleaned country details dataset\")\n",
"\n",
"\n",
"# Now join the dataset to our previous DataFrame and clean missings (not match in left join)- label encode cities\n",
"print(\"Joined dataset\")\n",
"all_data = all_data.merge(world_population, left_on='Country_Region', right_on='Country (or dependency)', how='left')\n",
"all_data[['Population (2020)', 'Density', 'Land Area', 'Med Age', 'Urban Pop']] = all_data[['Population (2020)', 'Density', 'Land Area', 'Med Age', 'Urban Pop']].fillna(0)\n",
"\n",
"\n",
"print(\"Encoded dataset\")\n",
"# Label encode countries and provinces. Save dictionary for exploration purposes\n",
"all_data.drop('Country (or dependency)', inplace=True, axis=1)\n",
"all_data['Country_Region'] = le.fit_transform(all_data['Country_Region'])\n",
"\n",
"number_c = all_data['Country_Region']\n",
"countries = le.inverse_transform(all_data['Country_Region'])\n",
"country_dict = dict(zip(countries, number_c)) \n",
"all_data['Province_State'].fillna(\"None\", inplace=True)\n",
"all_data['Province_State'] = le.fit_transform(all_data['Province_State'])\n",
"number_p = all_data['Province_State']\n",
"province = le.inverse_transform(all_data['Province_State'])\n",
"province_dict = dict(zip(province, number_p)) "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## SIR Model"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"id": "bGq-gMq5oqUV"
},
"outputs": [],
"source": [
"class SIR:\n",
" def __init__(self, beta=0, gamma=0, fix_gamma=False):\n",
" self.beta = beta\n",
" self.gamma = gamma\n",
" self.infected_t0 = 0\n",
" self.fitted_on = np.array([])\n",
" self.fix_gamma = fix_gamma\n",
" self.fitted = False\n",
" \n",
" def ode(self, y, x, beta, gamma):\n",
" '''Defines the ODE that governs the SIRs behaviour'''\n",
" dSdt = -beta * y[0] * y[1]\n",
" dRdt = gamma * y[1]\n",
" dIdt = -(dSdt + dRdt)\n",
" return dSdt, dIdt, dRdt\n",
" \n",
" def solve_ode(self, x, beta, gamma):\n",
" '''Solves the resulting ODE to get predictions for each time step'''\n",
" return np.cumsum(integrate.odeint(self.ode, (1-self.infected_t0, self.infected_t0, 0.0), x, args=(beta, gamma))[:,1])\n",
" \n",
" def solve_ode_fixed(self, x, beta):\n",
" '''Solves the resulting ODE to get predictions for each time step'''\n",
" return np.cumsum(integrate.odeint(self.ode, (1-self.infected_t0, self.infected_t0, 0.0), x, args=(beta, self.gamma))[:,1])\n",
" \n",
" def describe(self):\n",
" assert self.fitted, \"You need to fit the model before describing it!\"\n",
" print(\"Beta: \", self.beta)\n",
" print(\"Gamma: \", self.gamma)\n",
" print(\"At t=0: \", self.infected_t0)\n",
" \n",
" plt.plot(range(1,len(self.fitted_on)+1), self.fitted_on, \"x\", label='Actual')\n",
" plt.plot(range(1,len(self.fitted_on)+1), self.predict(len(self.fitted_on)), label='Prediction')\n",
" plt.title(\"Fit of SIR model to actual\")\n",
" plt.ylabel(\"% of Population\")\n",
" plt.xlabel(\"Days\")\n",
" plt.legend()\n",
" plt.show()\n",
" \n",
" def evaluate(self, y_test):\n",
" assert self.fitted, \"You need to fit the model before evaluating it!\"\n",
" print(\"Beta: \", self.beta)\n",
" print(\"Gamma: \", self.gamma)\n",
" print(\"At t=0: \", self.infected_t0)\n",
" \n",
" y_train = self.fitted_on\n",
" l_train = len(self.fitted_on)\n",
" l_test = len(y_test)\n",
" l_all = l_train + l_test\n",
" \n",
" plt.plot(range(1, l_train + 1), y_train, \"x\", label='Actual Train')\n",
" plt.plot(range(1 + l_train, l_all + 1), y_test, \"x\", label='Actual Test')\n",
" plt.plot(range(1, l_all + 1), self.predict(l_all), label='Prediction')\n",
" plt.title(\"Fit of SIR model to actual\")\n",
" plt.ylabel(\"% of Population\")\n",
" plt.xlabel(\"Days\")\n",
" plt.legend()\n",
" plt.show()\n",
" \n",
" def fit(self, y):\n",
" '''Fits the parameters to the data, assuming the first data point is the start of the outbreak'''\n",
" if len(y) == 1: y = np.array([0, y[0]]) # SIR needs at least 2 datapoints to fit\n",
" self.infected_t0 = y[0]\n",
" x = np.array(range(1,len(y)+1), dtype=float)\n",
" self.fitted_on = y\n",
" if(self.fix_gamma):\n",
" popt, _ = optimize.curve_fit(self.solve_ode_fixed, x, y)\n",
" self.beta = popt[0]\n",
" else:\n",
" popt, _ = optimize.curve_fit(self.solve_ode, x, y, maxfev=1000)\n",
" self.beta = popt[0]\n",
" self.gamma = popt[1]\n",
" self.fitted = True\n",
" \n",
" def predict(self ,length):\n",
" '''Returns the predicted cumulated cases at each time step, assuming outbreak starts at t=0'''\n",
" #assert self.fitted, \"You need to fit the model before predicting!\"\n",
" return self.solve_ode(range(1, length+1), self.beta, self.gamma)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Data Prep"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"id": "_r7cJkVjnlsn"
},
"outputs": [],
"source": [
"unknown_countries = []\n",
"hardcoded_countries = {\n",
" \"Korea, South\": 51269000,\n",
" \"Diamond Princess\": 3711,\n",
" \"Taiwan*\": 23800000,\n",
" \"Saint Vincent and the Grenadines\": 109897,\n",
" \"Congo (Brazzaville)\":5261000,\n",
" \"Congo (Kinshasa)\":81340000,\n",
" \"Cote d'Ivoire\":24300000,\n",
" \"Czechia\": 10650000,\n",
" \"Saint Kitts and Nevis\": 55345,\n",
" \"Burma\": 53370000,\n",
" \"Kosovo\": 1831000,\n",
" \"MS Zaandam\": 1432, # cruise ship\n",
" \"West Bank and Gaza\": 4685,\n",
" \"Sao Tome and Principe\": 204327,\n",
"}\n",
"hardcoded_province = {\n",
" \"Saint Pierre and Miquelon\": 5888,\n",
" \"Bonaire, Sint Eustatius and Saba\": 25157,\n",
" \"Falkland Islands (Malvinas)\": 2840,\n",
"}\n",
"state_populations= pd.read_csv(\"../input/covid19-forecasting-metadata/region_metadata.csv\")\n",
"\n",
"def get_population(country_name, province_name=None):\n",
" if province_name:\n",
" pop = state_populations[state_populations['Province_State']==province_name]['population']\n",
" if len(pop)==0:\n",
" if province_name in hardcoded_province:\n",
" return hardcoded_province[province_name]\n",
" else:\n",
" print(f\"Warning: We have no province population data at the moment. Instead of data for {province_name}, using data for {country_name}\")\n",
" else:\n",
" return pop.iloc[0]\n",
" \n",
" if country_name in hardcoded_countries:\n",
" return hardcoded_countries[country_name]\n",
" \n",
" pop = all_data[all_data[\"Country_Region\"] == country_dict[country_name]].iloc[0][\"Population (2020)\"]\n",
" if not pop:\n",
" print(f\"population of {country_name} unknown\")\n",
" pop = 100\n",
" unknown_countries.append(country_name)\n",
" \n",
" return pop"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"id": "QELZtpCdQp8P",
"outputId": "09e65ef5-ade6-4800-da95-105f01a41fd3"
},
"outputs": [
{
"data": {
"text/plain": [
"1438116346.0"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"country_name = 'China'\n",
"all_data[all_data[\"Country_Region\"] == country_dict[country_name]].iloc[0][\"Population (2020)\"]"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"id": "tsjtS2dcCh6j",
"outputId": "fdde0367-1edd-4244-937e-c09f24723c5e"
},
"outputs": [
{
"data": {
"text/plain": [
"1438116346.0"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"country_name = 'Hubei'\n",
"all_data[all_data[\"Province_State\"] == province_dict[country_name]].iloc[0][\"Population (2020)\"]"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"id": "B3d-fcFy5bK2"
},
"outputs": [],
"source": [
"def get_country_data(country_name, province_name=None, train_split_factor=1.0):\n",
" if province_name:\n",
" confirmed_total_date_country = train[(train['Country_Region']==country_name) & (train['Province_State']==province_name)].groupby(['Date']).agg({'ConfirmedCases':['sum']})\n",
" fatalities_total_date_country = train[(train['Country_Region']==country_name) & (train['Province_State']==province_name)].groupby(['Date']).agg({'Fatalities':['sum']})\n",
" total_date_country = confirmed_total_date_country.join(fatalities_total_date_country)\n",
"\n",
" cases = total_date_country.ConfirmedCases['sum'].values\n",
" cases_normalized = total_date_country.ConfirmedCases['sum'].values / get_population(country_name, province_name)\n",
" fatalities_normalized = total_date_country.Fatalities['sum'].values / get_population(country_name, province_name)\n",
"\n",
" cases_final = cases_normalized[np.argmax(cases>0):]\n",
" fatalities_final = fatalities_normalized[np.argmax(fatalities_normalized>0):]\n",
"\n",
" cases_length = len(cases_final)\n",
" fat_length = len(fatalities_final)\n",
" cases_split = math.floor(cases_length * train_split_factor)\n",
" fat_split = math.floor(fat_length * train_split_factor)\n",
" else:\n",
" confirmed_total_date_country = train[train['Country_Region']==country_name].groupby(['Date']).agg({'ConfirmedCases':['sum']})\n",
" fatalities_total_date_country = train[train['Country_Region']==country_name].groupby(['Date']).agg({'Fatalities':['sum']})\n",
" total_date_country = confirmed_total_date_country.join(fatalities_total_date_country)\n",
"\n",
" cases = total_date_country.ConfirmedCases['sum'].values\n",
" cases_normalized = cases / get_population(country_name, province_name)\n",
" fatalities_normalized = total_date_country.Fatalities['sum'].values / get_population(country_name, province_name)\n",
"\n",
" cases_final = cases_normalized[np.argmax(cases>0):]\n",
" fatalities_final = fatalities_normalized[np.argmax(fatalities_normalized>0):]\n",
"\n",
" cases_length = len(cases_final)\n",
" fat_length = len(fatalities_final)\n",
" cases_split = math.floor(cases_length * train_split_factor)\n",
" fat_split = math.floor(fat_length * train_split_factor)\n",
" \n",
" return cases_final, fatalities_final, cases_split, fat_split, cases_length, fat_length"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Visualization Helpers"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"id": "DbvCXYEbsV_l"
},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"def visualize(val_loss_dict, val_info_dict, start=0, end=150):\n",
" fig = plt.figure(figsize=(10,2))\n",
" ax = fig.add_axes([0,0,1,1])\n",
"\n",
" loss_sorted = sorted(val_loss_dict.items(), key=lambda x: x[1], reverse=True)\n",
" print(loss_sorted[10:20])\n",
" losses = [x[1] for x in loss_sorted[start:end]]\n",
" countries = [x[0] for x in loss_sorted[start:end]]\n",
" colors = [val_info_dict[x][\"Color\"] for x in countries]\n",
" ax.bar(countries, losses, color=colors)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"id": "gzpYg0IBtp3h"
},
"outputs": [],
"source": [
"def visualize_country(country_name, val_info_dict=val_info_dict):\n",
" info = val_info_dict[country_name]\n",
" cases_actual = info[\"Cases Actual\"]\n",
" cases_predicted = info[\"Cases Predicted\"]\n",
" cases_split = info[\"Case Split\"]\n",
" fat_actual = info[\"Fatalities Actual\"]\n",
" fat_predicted = info[\"Fatalities Predicted\"]\n",
" fat_split = info[\"Fatality Split\"]\n",
" \n",
" fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(30,15))\n",
"\n",
" ax1.plot(cases_actual, 'o')\n",
" ax1.plot(cases_predicted)\n",
" ax1.axvline(x=cases_split, color='gray', linestyle='--')\n",
" ax1.set_title(\"Fit of SIR model to global infected cases\")\n",
" \n",
" ax2.plot(fat_actual, 'o')\n",
" ax2.plot(fat_predicted)\n",
" ax2.axvline(x=fat_split, color='gray', linestyle='--')\n",
" ax2.set_title(\"Fit of SIR model to global fatalities\")\n",
" \n",
" plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Training"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"id": "6IvVbeCU7YUz"
},
"outputs": [],
"source": [
"def train_val_country(country_name, train_split_factor=1.0):\n",
" cases, fatalities, case_split, fat_split, case_length, fat_length = get_country_data(country_name, train_split_factor=train_split_factor)\n",
" cases_train = cases[0:case_split]\n",
" cases_test = cases[case_split:]\n",
" fat_train = fatalities[0:fat_split]\n",
" fat_test = fatalities[fat_split:]\n",
" \n",
" case_model = SIR()\n",
" case_model.fit(cases_train)\n",
" fat_model = SIR()\n",
" fat_model.fit(fat_train)\n",
" \n",
" cases_pred_all = case_model.predict(len(cases_train) + len(cases_test))\n",
" cases_pred_train = cases_pred_all[:case_split]\n",
" cases_pred_test = cases_pred_all[case_split:]\n",
" fat_pred_all = fat_model.predict(len(fat_train) + len(fat_test))\n",
" fat_pred_train = fat_pred_all[:fat_split]\n",
" fat_pred_test = fat_pred_all[fat_split:]\n",
" \n",
" if(sum(cases_test) > sum(cases_pred_test)):\n",
" color = \"red\"\n",
" else:\n",
" color = \"blue\"\n",
" \n",
" cases_train_val_loss = np.sqrt(mean_squared_log_error(cases_train, cases_pred_train)) if (len(cases_train) > 0) else 0\n",
" fat_train_val_loss = np.sqrt(mean_squared_log_error(fat_train, fat_pred_train)) if (len(fat_train) > 0) else 0\n",
" cases_test_val_loss = np.sqrt(mean_squared_log_error(cases_test, cases_pred_test)) if (len(cases_test) > 0) else 0\n",
" fat_test_val_loss = np.sqrt(mean_squared_log_error(fat_test, fat_pred_test)) if (len(fat_test) > 0) else 0\n",
" #print(f\"Val Loss for {country_name}: {val_loss}\")\n",
" #print(f\"Sum actual: {sum(cases_test)} Sum predicted: {sum(cases_pred_val)}\")\n",
" val_loss_dict[country_name] = cases_test_val_loss\n",
" results_dict = {\n",
" \"Country\": country_name,\n",
" \"Province\": float('nan'),\n",
" \"Case Model\": case_model,\n",
" \"Fatality Model\": fat_model,\n",
" \"Color\": color,\n",
" \"Cases Predicted\": cases_pred_all,\n",
" \"Cases Actual\": cases,\n",
" \"Fatalities Predicted\": fat_pred_all,\n",
" \"Fatalities Actual\": fatalities,\n",
" \"Cases Loss Train\": cases_train_val_loss,\n",
" \"Fatality Loss Train\": fat_train_val_loss,\n",
" \"Cases Loss Test\": cases_test_val_loss,\n",
" \"Fatality Loss Test\": fat_test_val_loss,\n",
" \"Case Split\": case_split,\n",
" \"Fatality Split\": fat_split,\n",
" \"Case length\": case_length,\n",
" \"Fatality length\": fat_length\n",
" }\n",
" return results_dict\n",
"\n",
"def train_val_province(country_name, province_name, train_split_factor=1.0):\n",
" cases, fatalities, case_split, fat_split, case_length, fat_length = get_country_data(country_name, province_name, train_split_factor=train_split_factor)\n",
" cases_train = cases[0:case_split]\n",
" cases_test = cases[case_split:]\n",
" fat_train = fatalities[0:fat_split]\n",
" fat_test = fatalities[fat_split:]\n",
" \n",
" case_model = SIR()\n",
" case_model.fit(cases_train)\n",
" fat_model = SIR()\n",
" fat_model.fit(fat_train)\n",
" \n",
" cases_pred_all = case_model.predict(len(cases_train) + len(cases_test))\n",
" cases_pred_train = cases_pred_all[:case_split]\n",
" cases_pred_test = cases_pred_all[case_split:]\n",
" fat_pred_all = fat_model.predict(len(fat_train) + len(fat_test))\n",
" fat_pred_train = fat_pred_all[:fat_split]\n",
" fat_pred_test = fat_pred_all[fat_split:]\n",
" \n",
" if(sum(cases_test) > sum(cases_pred_test)):\n",
" color = \"red\"\n",
" else:\n",
" color = \"blue\"\n",
"\n",
" cases_train_val_loss = np.sqrt(mean_squared_log_error(cases_train, cases_pred_train)) if (len(cases_train) > 0) else 0\n",
" fat_train_val_loss = np.sqrt(mean_squared_log_error(fat_train, fat_pred_train)) if (len(fat_train) > 0) else 0\n",
" cases_test_val_loss = np.sqrt(mean_squared_log_error(cases_test, cases_pred_test)) if (len(cases_test) > 0) else 0\n",
" fat_test_val_loss = np.sqrt(mean_squared_log_error(fat_test, fat_pred_test)) if (len(fat_test) > 0) else 0\n",
" #print(f\"Val Loss for {country_name}: {val_loss}\")\n",
" #print(f\"Sum actual: {sum(cases_test)} Sum predicted: {sum(cases_pred_val)}\")\n",
" val_loss_dict[province_name] = cases_test_val_loss\n",
" results_dict = {\n",
" \"Country\": country_name,\n",
" \"Province\": province_name,\n",
" \"Case Model\": case_model,\n",
" \"Fatality Model\": fat_model,\n",
" \"Color\": color,\n",
" \"Cases Predicted\": cases_pred_all,\n",
" \"Cases Actual\": cases,\n",
" \"Fatalities Predicted\": fat_pred_all,\n",
" \"Fatalities Actual\": fatalities,\n",
" \"Cases Loss Train\": cases_train_val_loss,\n",
" \"Fatality Loss Train\": fat_train_val_loss,\n",
" \"Cases Loss Test\": cases_test_val_loss,\n",
" \"Fatality Loss Test\": fat_test_val_loss,\n",
" \"Case Split\": case_split,\n",
" \"Fatality Split\": fat_split,\n",
" \"Case length\": case_length,\n",
" \"Fatality length\": fat_length\n",
" }\n",
" return results_dict"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"id": "LZynM7cQ8fUN"
},
"outputs": [],
"source": [
"country_and_provinces = {}\n",
"only_provinces = {}\n",
"only_country = []\n",
"for country in test['Country_Region'].unique():\n",
" provinces = test[test['Country_Region']==country]['Province_State'].unique()\n",
" \n",
" if len(provinces)>1:\n",
" contains_nan = False\n",
" for province in provinces:\n",
" if type(province) == float:\n",
" contains_nan = True\n",
" if contains_nan:\n",
" country_and_provinces[country] = provinces\n",
" else:\n",
" only_provinces[country] = provinces\n",
" else:\n",
" only_country.append(country)\n"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"id": "eoEqisvajzsY",
"outputId": "0ed82a6b-94a5-4928-8eb5-3948dddc611e"
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 184/184 [00:24<00:00, 7.53it/s]\n"
]
}
],
"source": [
"from tqdm import tqdm\n",
"\n",
"train_split_factor = 0.9\n",
"\n",
"for country in tqdm(train['Country_Region'].unique()):\n",
" #If we only need to predict for the provinces, not for the whole country\n",
" if country in only_provinces:\n",
" for province in only_provinces[country]:\n",
" val_info_dict[province] = train_val_province(country, province, train_split_factor=train_split_factor)\n",
" \n",
" #If we need to predict for the provinces and for the whole country\n",
" elif country in country_and_provinces:\n",
" for province in country_and_provinces[country]:\n",
" #For the 'nan' province value: Make predictions for the whole country\n",
" if type(province) == float:\n",
" val_info_dict[country] = train_val_country(country, train_split_factor=train_split_factor)\n",
" else:\n",
" val_info_dict[province] = train_val_province(country, province, train_split_factor=train_split_factor)\n",
" \n",
" #If we don't have any provinces for this country\n",
" elif country in only_country:\n",
" val_info_dict[country] = train_val_country(country, train_split_factor=train_split_factor)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Results"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'Afghanistan': 1.2849805882297129e-06,\n",
" 'Albania': 7.179673514664038e-06,\n",
" 'Algeria': 2.311680547714642e-05,\n",
" 'Andorra': 0.00011380427356961391,\n",
" 'Angola': 1.83245844824193e-08,\n",
" 'Antigua and Barbuda': 7.796286718083674e-05,\n",
" 'Argentina': 4.606806475610179e-05,\n",
" 'Armenia': 0.00027505426536992997,\n",
" 'Australian Capital Territory': 8.853711342700853e-05,\n",
" 'New South Wales': 4.997783712991756e-05,\n",
" 'Northern Territory': 2.7453079591036178e-05,\n",
" 'Queensland': 5.73788133074826e-05,\n",
" 'South Australia': 1.8425943996740522e-05,\n",
" 'Tasmania': 3.569775087839743e-05,\n",
" 'Victoria': 0.00014693735369897645,\n",
" 'Western Australia': 9.654178851669384e-05,\n",
" 'Austria': 0.0010463090882809523,\n",
" 'Azerbaijan': 1.0059204702777996e-05,\n",
" 'Bahamas': 2.1788486201736117e-06,\n",
" 'Bahrain': 0.00026743168585066656,\n",
" 'Bangladesh': 3.1835889689564354e-06,\n",
" 'Barbados': 2.2303183829751415e-05,\n",
" 'Belarus': 0.00011702311186299479,\n",
" 'Belgium': 0.0004383376194654168,\n",
" 'Belize': 2.3887580549634123e-05,\n",
" 'Benin': 4.2369033498519815e-07,\n",
" 'Bhutan': 2.680172100714255e-06,\n",
" 'Bolivia': 7.010622961859406e-07,\n",
" 'Bosnia and Herzegovina': 0.0002075848205904463,\n",
" 'Botswana': 1.732374238185675e-06,\n",
" 'Brazil': 8.453395192549809e-05,\n",
" 'Brunei': 1.9534378280180854e-05,\n",
" 'Bulgaria': 4.686678489912542e-05,\n",
" 'Burkina Faso': 1.7163778218086352e-05,\n",
" 'Burma': 7.081479220440122e-07,\n",
" 'Burundi': 1.3641049933980924e-07,\n",
" 'Cabo Verde': 5.414606538709324e-06,\n",
" 'Cambodia': 1.5944615629192908e-06,\n",
" 'Cameroon': 2.109961616750171e-05,\n",
" 'Alberta': 0.00032055529594577834,\n",
" 'British Columbia': 0.00021677780292031907,\n",
" 'Manitoba': 3.309409172838959e-05,\n",
" 'New Brunswick': 1.3194813525915369e-05,\n",
" 'Newfoundland and Labrador': 1.2413172370969668e-05,\n",
" 'Northwest Territories': 1.8142081563949042e-05,\n",
" 'Nova Scotia': 7.222807860571439e-05,\n",
" 'Ontario': 8.772414666261898e-05,\n",
" 'Prince Edward Island': 2.129581461488591e-05,\n",
" 'Quebec': 0.0014828736626624401,\n",
" 'Saskatchewan': 6.3715023688719484e-06,\n",
" 'Yukon': 4.1650957622535295e-05,\n",
" 'Central African Republic': 8.953086537405141e-07,\n",
" 'Chad': 7.407117368522066e-07,\n",
" 'Chile': 0.0004153297743491753,\n",
" 'Anhui': 2.1566689139812655e-05,\n",
" 'Beijing': 2.6229565948391174e-06,\n",
" 'Chongqing': 5.7631453832574635e-08,\n",
" 'Fujian': 7.6019198198622304e-06,\n",
" 'Gansu': 1.6451591438256645e-06,\n",
" 'Guangdong': 1.3051481502116888e-06,\n",
" 'Guangxi': 3.0594524991051327e-06,\n",
" 'Guizhou': 2.5156739997183374e-06,\n",
" 'Hainan': 3.819901166330451e-06,\n",
" 'Hebei': 4.142053536523168e-06,\n",
" 'Heilongjiang': 7.238223756184245e-06,\n",
" 'Henan': 1.2236664988679813e-05,\n",
" 'Hong Kong': 9.763761425172463e-05,\n",
" 'Hubei': 1.0548900721327478e-05,\n",
" 'Hunan': 1.3637809109980497e-05,\n",
" 'Inner Mongolia': 1.2959307213464633e-06,\n",
" 'Jiangsu': 9.609667934990638e-06,\n",
" 'Jiangxi': 2.2755236066561313e-05,\n",
" 'Jilin': 1.2317509196038182e-06,\n",
" 'Liaoning': 9.553044151037838e-07,\n",
" 'Macau': 3.534290212370054e-05,\n",
" 'Ningxia': 5.308178505254384e-09,\n",
" 'Qinghai': 6.613373443466151e-08,\n",
" 'Shaanxi': 2.5261381292458854e-08,\n",
" 'Shandong': 7.885347476129242e-06,\n",
" 'Shanghai': 3.851036574706187e-06,\n",
" 'Shanxi': 1.0783754923652749e-06,\n",
" 'Sichuan': 3.567460498739672e-07,\n",
" 'Tianjin': 2.3862196846591877e-06,\n",
" 'Tibet': 1.5620604608296038e-12,\n",
" 'Xinjiang': 7.747761139738117e-07,\n",
" 'Yunnan': 2.846683448461055e-06,\n",
" 'Zhejiang': 2.0599924368058365e-06,\n",
" 'Colombia': 3.9084086457108786e-05,\n",
" 'Congo (Brazzaville)': 8.176248984776057e-07,\n",
" 'Congo (Kinshasa)': 1.1453189268570956e-06,\n",
" 'Costa Rica': 8.688118111562011e-05,\n",
" \"Cote d'Ivoire\": 7.204477644459691e-06,\n",
" 'Croatia': 0.0002590150833213281,\n",
" 'Cuba': 4.0840849786138845e-06,\n",
" 'Cyprus': 6.243690161459026e-06,\n",
" 'Czechia': 0.0005871861150005884,\n",
" 'Faroe Islands': 0.0001305659256466765,\n",
" 'Greenland': 3.603043499316395e-05,\n",
" 'Denmark': 0.0012591971527977609,\n",
" 'Diamond Princess': 0.0028284886543598346,\n",
" 'Djibouti': 7.444796408205375e-05,\n",
" 'Dominica': 2.0942769012272434e-05,\n",
" 'Dominican Republic': 0.000162158723695759,\n",
" 'Ecuador': 7.739318760134494e-05,\n",
" 'Egypt': 1.6672291572369493e-06,\n",
" 'El Salvador': 4.699286997379525e-06,\n",
" 'Equatorial Guinea': 7.714263286201959e-06,\n",
" 'Eritrea': 1.474604984397985e-06,\n",
" 'Estonia': 0.0006738493427564092,\n",
" 'Eswatini': 3.5618437456067586e-06,\n",
" 'Ethiopia': 5.100515059015048e-08,\n",
" 'Fiji': 2.4327863474647453e-06,\n",
" 'Finland': 0.0003279808210551552,\n",
" 'French Guiana': 0.00013398949436807886,\n",
" 'French Polynesia': 8.123726486541588e-05,\n",
" 'Guadeloupe': 1.469638697756939e-05,\n",
" 'Martinique': 7.703353094394305e-05,\n",
" 'Mayotte': 2.531573584602383e-05,\n",
" 'New Caledonia': 1.1830528960676557e-06,\n",
" 'Reunion': 5.792214707737096e-06,\n",
" 'Saint Barthelemy': 0.00023406160201836092,\n",
" 'Saint Pierre and Miquelon': 1.1677838844580471e-11,\n",
" 'St Martin': 0.0002135746509242852,\n",
" 'France': 0.0010104327526954083,\n",
" 'Gabon': 7.510302496475318e-06,\n",
" 'Gambia': 2.1036881987238083e-06,\n",
" 'Georgia': 0.0010140726625794663,\n",
" 'Germany': 8.409149255787794e-05,\n",
" 'Ghana': 1.6074047176706742e-06,\n",
" 'Greece': 0.00020463048801767328,\n",
" 'Grenada': 4.85706957133777e-06,\n",
" 'Guatemala': 9.596880499085776e-07,\n",
" 'Guinea': 2.697362538223174e-06,\n",
" 'Guinea-Bissau': 9.428352208242373e-07,\n",
" 'Guyana': 1.5133697855269813e-05,\n",
" 'Haiti': 8.268039774387613e-07,\n",
" 'Holy See': 0.004411124109562615,\n",
" 'Honduras': 1.5566439546521945e-05,\n",
" 'Hungary': 2.4688742074428397e-05,\n",
" 'Iceland': 0.004561714116364317,\n",
" 'India': 2.6424478286460263e-06,\n",
" 'Indonesia': 7.544530203816871e-06,\n",
" 'Iran': 0.0007038178820593384,\n",
" 'Iraq': 2.0548941622067137e-05,\n",
" 'Ireland': 0.0016420296695800357,\n",
" 'Israel': 3.740961728137474e-05,\n",
" 'Italy': 0.0019895647310509694,\n",
" 'Jamaica': 1.8741651045458255e-06,\n",
" 'Japan': 1.441285653047277e-05,\n",
" 'Jordan': 3.144443463861551e-05,\n",
" 'Kazakhstan': 1.4400188126095315e-05,\n",
" 'Kenya': 1.9731091003428376e-06,\n",
" 'Korea, South': 0.00036406925092194334,\n",
" 'Kosovo': 9.450696026603212e-05,\n",
" 'Kuwait': 0.00021356200372572363,\n",
" 'Kyrgyzstan': 1.9122272968050875e-06,\n",
" 'Laos': 5.821354311149184e-07,\n",
" 'Latvia': 0.00021026446444712913,\n",
" 'Lebanon': 6.374786835181886e-05,\n",
" 'Liberia': 6.046388765133942e-06,\n",
" 'Libya': 5.114971608578015e-07,\n",
" 'Liechtenstein': 0.00023814207843863712,\n",
" 'Lithuania': 2.9198335717654533e-05,\n",
" 'Luxembourg': 0.00034153927536285613,\n",
" 'MS Zaandam': 0.0003046000357504367,\n",
" 'Madagascar': 1.4801136211412838e-07,\n",
" 'Malawi': 3.439503197558827e-07,\n",
" 'Malaysia': 1.547418445013923e-05,\n",
" 'Maldives': 6.935977480746684e-06,\n",
" 'Mali': 4.229233591169801e-07,\n",
" 'Malta': 6.0665659023603595e-05,\n",
" 'Mauritania': 4.297897435663005e-07,\n",
" 'Mauritius': 6.197912849760222e-05,\n",
" 'Mexico': 1.0069840701698503e-05,\n",
" 'Moldova': 0.00021774708507984954,\n",
" 'Monaco': 0.00014323627567936414,\n",
" 'Mongolia': 2.256785782348178e-06,\n",
" 'Montenegro': 9.975932465295745e-06,\n",
" 'Morocco': 2.4503481414768627e-05,\n",
" 'Mozambique': 1.8472220968626146e-07,\n",
" 'Namibia': 1.403074733748481e-06,\n",
" 'Nepal': 3.17916386856687e-07,\n",
" 'Aruba': 6.460452522259138e-05,\n",
" 'Bonaire, Sint Eustatius and Saba': 3.5715945145786376e-05,\n",
" 'Curacao': 8.319910846078059e-06,\n",
" 'Sint Maarten': 0.00020481917519548254,\n",
" 'Netherlands': 0.00032544515198131566,\n",
" 'New Zealand': 0.000212729796348943,\n",
" 'Nicaragua': 5.214825941961771e-07,\n",
" 'Niger': 1.4896733742015378e-05,\n",
" 'Nigeria': 2.1804780524167706e-07,\n",
" 'North Macedonia': 0.0002547146718233912,\n",
" 'Norway': 0.0007488766880086114,\n",
" 'Oman': 6.04687007386788e-05,\n",
" 'Pakistan': 1.3605953260894951e-05,\n",
" 'Panama': 0.0007110676844440865,\n",
" 'Papua New Guinea': 8.0581307462877e-08,\n",
" 'Paraguay': 4.611596197112679e-06,\n",
" 'Peru': 5.921097639621376e-05,\n",
" 'Philippines': 4.51311391563911e-06,\n",
" 'Poland': 0.00019926593035117531,\n",
" 'Portugal': 0.0020090052825980516,\n",
" 'Qatar': 2.814324358593336e-05,\n",
" 'Romania': 0.00023183626336175473,\n",
" 'Russia': 5.35960394170891e-05,\n",
" 'Rwanda': 5.605379737950622e-06,\n",
" 'Saint Kitts and Nevis': 3.8142097296000895e-05,\n",
" 'Saint Lucia': 1.0111810024833086e-05,\n",
" 'Saint Vincent and the Grenadines': 7.502198593195164e-05,\n",
" 'San Marino': 0.0023135838973522157,\n",
" 'Sao Tome and Principe': 4.704141603703062e-13,\n",
" 'Saudi Arabia': 7.774870833088573e-05,\n",
" 'Senegal': 7.554969102456319e-06,\n",
" 'Serbia': 0.0002712179053672993,\n",
" 'Seychelles': 1.7494044915745855e-05,\n",
" 'Sierra Leone': 1.2114273028744893e-07,\n",
" 'Singapore': 0.00033238046861956027,\n",
" 'Slovakia': 9.638337659743148e-05,\n",
" 'Slovenia': 0.0003269185428750306,\n",
" 'Somalia': 2.6961864616424584e-06,\n",
" 'South Africa': 4.556875544337852e-05,\n",
" 'South Sudan': 1.3682715018445416e-08,\n",
" 'Spain': 6.747567701671771e-05,\n",
" 'Sri Lanka': 3.0040664491785722e-06,\n",
" 'Sudan': 3.643966547857473e-07,\n",
" 'Suriname': 3.447734893930676e-07,\n",
" 'Sweden': 9.488454021709986e-05,\n",
" 'Switzerland': 0.002024124499842536,\n",
" 'Syria': 1.5834271637320346e-07,\n",
" 'Taiwan*': 1.2218999978852254e-06,\n",
" 'Tanzania': 2.3987083857060715e-07,\n",
" 'Thailand': 2.969099381962298e-06,\n",
" 'Timor-Leste': 2.5355151317515834e-06,\n",
" 'Togo': 6.121565495982997e-07,\n",
" 'Trinidad and Tobago': 3.3820625104494725e-06,\n",
" 'Tunisia': 4.963227922583118e-05,\n",
" 'Turkey': 0.0011840397491182512,\n",
" 'Alabama': 0.00011006440910276353,\n",
" 'Alaska': 0.00020733281629836502,\n",
" 'Arizona': 3.262323188611717e-05,\n",
" 'Arkansas': 0.00038769623833965393,\n",
" 'California': 0.0001028189104813389,\n",
" 'Colorado': 0.000898060296190525,\n",
" 'Connecticut': 0.0005916389731664958,\n",
" 'Delaware': 2.1740967237831155e-05,\n",
" 'District of Columbia': 0.000569204030709808,\n",
" 'Florida': 7.441098676719035e-05,\n",
" 'Guam': 0.0001624030893381615,\n",
" 'Hawaii': 2.6354044452155405e-05,\n",
" 'Idaho': 0.0010328422094282828,\n",
" 'Illinois': 0.00027059063784426246,\n",
" 'Indiana': 3.808868595039972e-05,\n",
" 'Iowa': 0.00028386351044947514,\n",
" 'Kansas': 0.00033797457258202184,\n",
" 'Kentucky': 0.00029640422854047577,\n",
" 'Louisiana': 0.0003446505811545775,\n",
" 'Maine': 8.395350528971605e-05,\n",
" 'Maryland': 0.0010391429532535983,\n",
" 'Massachusetts': 0.00025848147353538844,\n",
" 'Michigan': 0.0018415825748126415,\n",
" 'Minnesota': 4.7428207312129696e-05,\n",
" 'Mississippi': 0.0010291939210410152,\n",
" 'Missouri': 0.0007866767544571429,\n",
" 'Montana': 0.00022008373963125228,\n",
" 'Nebraska': 0.000252896451933831,\n",
" 'Nevada': 4.5528716483467216e-05,\n",
" 'New Hampshire': 0.00013826413183675714,\n",
" 'New Jersey': 0.005348810023114104,\n",
" 'New Mexico': 0.0004152589380814988,\n",
" 'New York': 0.0011330844063905455,\n",
" 'North Carolina': 5.171026856527207e-05,\n",
" 'North Dakota': 0.0002419396036081185,\n",
" 'Ohio': 0.0005485756202291289,\n",
" 'Oklahoma': 0.0003663325351307428,\n",
" 'Oregon': 2.957204035272494e-05,\n",
" 'Pennsylvania': 0.0013292273208941523,\n",
" 'Puerto Rico': 1.9368596885678292e-05,\n",
" 'Rhode Island': 0.00026802882531775713,\n",
" 'South Carolina': 0.0004608273098738372,\n",
" 'South Dakota': 0.0004992767346474372,\n",
" 'Tennessee': 9.879279515508202e-05,\n",
" 'Texas': 0.0002788736212016803,\n",
" 'Utah': 0.0004892900644267518,\n",
" 'Vermont': 0.0008035736622334212,\n",
" 'Virgin Islands': 2.0756865298749132e-05,\n",
" 'Virginia': 1.4701512393884831e-05,\n",
" 'Washington': 0.00014259221679922322,\n",
" 'West Virginia': 0.00028123316420640594,\n",
" 'Wisconsin': 0.0005274779174654732,\n",
" 'Wyoming': 1.1603978586029659e-05,\n",
" 'Uganda': 8.311931995273981e-07,\n",
" 'Ukraine': 1.588412645597683e-05,\n",
" 'United Arab Emirates': 0.00026061449844999606,\n",
" 'Anguilla': 1.8342333564709167e-05,\n",
" 'Bermuda': 6.741287357772619e-05,\n",
" 'British Virgin Islands': 2.189531432540795e-06,\n",
" 'Cayman Islands': 0.00011804547269321272,\n",
" 'Channel Islands': 0.00020187854351237362,\n",
" 'Falkland Islands (Malvinas)': 0.0011110260590320661,\n",
" 'Gibraltar': 0.0009489182099183034,\n",
" 'Isle of Man': 0.0002602646112006958,\n",
" 'Montserrat': 0.0007721579600352,\n",
" 'Turks and Caicos Islands': 0.00010418803030155613,\n",
" 'United Kingdom': 0.0003282761078357052,\n",
" 'Uruguay': 1.5091083130086797e-05,\n",
" 'Uzbekistan': 8.690748043412879e-06,\n",
" 'Venezuela': 4.0066974051167964e-06,\n",
" 'Vietnam': 9.59451065278107e-07,\n",
" 'West Bank and Gaza': 0.0026469110335335735,\n",
" 'Western Sahara': 2.485149350976438e-06,\n",
" 'Zambia': 7.593249524086256e-08,\n",
" 'Zimbabwe': 2.8474272787437325e-07}"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"val_loss_dict"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[('Ireland', 0.0016420296695800357), ('Quebec', 0.0014828736626624401), ('Pennsylvania', 0.0013292273208941523), ('Denmark', 0.0012591971527977609), ('Turkey', 0.0011840397491182512), ('New York', 0.0011330844063905455), ('Falkland Islands (Malvinas)', 0.0011110260590320661), ('Austria', 0.0010463090882809523), ('Maryland', 0.0010391429532535983), ('Idaho', 0.0010328422094282828)]\n"
]
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 720x144 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# losses\n",
"visualize(val_loss_dict, val_info_dict, end=20)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Evaluate Country/Province by Example"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"def evaluate(name, val_info_dict=val_info_dict):\n",
" info = val_info_dict[name]\n",
" \n",
" case_split = info[\"Case Split\"]\n",
" fat_split = info[\"Fatality Split\"]\n",
" cases_test = info[\"Cases Actual\"][case_split:]\n",
" fat_test = info[\"Fatalities Actual\"][fat_split:]\n",
" \n",
" case_model = info[\"Case Model\"]\n",
" fat_model = info[\"Fatality Model\"]\n",
" \n",
" print(name)\n",
" print(\"Confirmed Cases:\")\n",
" print(\" Loss Train: \", info[\"Cases Loss Train\"])\n",
" print(\" Loss Test: \", info[\"Cases Loss Test\"])\n",
" display(case_model.evaluate(cases_test))\n",
" print(\"Fatalities:\")\n",
" print(\" Loss Train: \", info[\"Fatality Loss Train\"])\n",
" print(\" Loss Test: \", info[\"Fatality Loss Test\"])\n",
" display(fat_model.evaluate(fat_test))"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Germany\n",
"Confirmed Cases:\n",
" Loss Train: 2.33381364090002e-05\n",
" Loss Test: 8.409149255787794e-05\n",
"Beta: 12.531149570228767\n",
"Gamma: 12.378434406736405\n",
"At t=0: 1.1943743582253332e-08\n"
]
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"None"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Fatalities:\n",
" Loss Train: 1.8398965388451636e-06\n",
" Loss Test: 2.5670051269381823e-05\n",
"Beta: 1.1733371392603076\n",
"Gamma: 1.0014755640373423\n",
"At t=0: 2.3887487164506663e-08\n"
]
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"None"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"evaluate(\"Germany\")"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Spain\n",
"Confirmed Cases:\n",
" Loss Train: 4.052282895712593e-05\n",
" Loss Test: 6.747567701671771e-05\n",
"Beta: 9.307848945525745\n",
"Gamma: 9.13324577721468\n",
"At t=0: 2.138997215432035e-08\n"
]
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"None"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Fatalities:\n",
" Loss Train: 6.432470064248784e-05\n",
" Loss Test: 0.0006616223141886673\n",
"Beta: 0.7732031316965112\n",
"Gamma: 0.5483375624028562\n",
"At t=0: 2.138997215432035e-08\n"
]
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"None"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"evaluate(\"Spain\")"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Hubei\n",
"Confirmed Cases:\n",
" Loss Train: 4.226673712063659e-05\n",
" Loss Test: 1.0548900721327478e-05\n",
"Beta: 17.29318273461616\n",
"Gamma: 17.12789337739622\n",
"At t=0: 7.522873602168756e-06\n"
]
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"None"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Fatalities:\n",
" Loss Train: 8.224879625804729e-06\n",
" Loss Test: 9.56927970034335e-06\n",
"Beta: 99.84970312450345\n",
"Gamma: 99.63317697304076\n",
"At t=0: 2.8803795323619113e-07\n"
]
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"None"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"evaluate(\"Hubei\")"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Italy\n",
"Confirmed Cases:\n",
" Loss Train: 0.0006866289274446315\n",
" Loss Test: 0.0019895647310509694\n",
"Beta: 111.91787349023794\n",
"Gamma: 91.76769924641862\n",
"At t=0: 3.3068304103267285e-08\n"
]
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"None"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Fatalities:\n",
" Loss Train: 5.5144561469408936e-05\n",
" Loss Test: 0.000534957050231585\n",
"Beta: 0.9789425169319386\n",
"Gamma: 0.80319836997655\n",
"At t=0: 1.6534152051633643e-08\n"
]
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"None"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"evaluate(\"Italy\")"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"New York\n",
"Confirmed Cases:\n",
" Loss Train: 0.00021133868939752273\n",
" Loss Test: 0.0011330844063905455\n",
"Beta: -7.061149392747142\n",
"Gamma: -7.299197668607935\n",
"At t=0: 8.89297337387227e-06\n"
]
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"None"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Fatalities:\n",
" Loss Train: 3.9920989201465346e-05\n",
" Loss Test: 0.000540564087435805\n",
"Beta: 1.113548257206262\n",
"Gamma: 0.8549132026101509\n",
"At t=0: 1.0280894073840773e-07\n"
]
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"None"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"evaluate(\"New York\")"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"India\n",
"Confirmed Cases:\n",
" Loss Train: 2.61875346044632e-07\n",
" Loss Test: 2.6424478286460263e-06\n",
"Beta: 0.05846609781595513\n",
"Gamma: -0.022176095730341433\n",
"At t=0: 7.262104630499392e-10\n"
]
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"None"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Fatalities:\n",
" Loss Train: 1.460631700361812e-08\n",
" Loss Test: 5.0982790497263404e-08\n",
"Beta: 1.1839307316481835\n",
"Gamma: 1.0863128287203194\n",
"At t=0: 7.262104630499392e-10\n"
]
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"None"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"evaluate(\"India\")"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"France\n",
"Confirmed Cases:\n",
" Loss Train: 2.986904252118738e-05\n",
" Loss Test: 0.0010104327526954083\n",
"Beta: 0.952407147680054\n",
"Gamma: 0.8341517347357443\n",
"At t=0: 3.0654974021672176e-08\n"
]
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"None"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Fatalities:\n",
" Loss Train: 4.635690407049928e-06\n",
" Loss Test: 9.410958672959614e-05\n",
"Beta: 1.0390972416688997\n",
"Gamma: 0.9022561254969577\n",
"At t=0: 1.5327487010836088e-08\n"
]
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"None"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"evaluate(\"France\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Submission"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {
"id": "FgSEco90VBdM",
"outputId": "4f78feff-3f74-416e-9b3e-979b31ae8640"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"submission saved to csv.\n"
]
}
],
"source": [
"# submission date range: 02Apr20-14May20\n",
"pd_daterange_submission = pd.date_range(\"02Apr20\", \"14May20\") #TODO get from test dataset: min/max of Date\n",
"length_submission = len(pd_daterange_submission)\n",
"\n",
"def make_submission(val_info_dict=val_info_dict, name=\"submission\"):\n",
" # generate submission frames for all items in val_info_dict\n",
" frames = []\n",
" for attr, item in val_info_dict.items():\n",
" country = item[\"Country\"]\n",
" province = item[\"Province\"]\n",
" case_length = item[\"Case length\"]\n",
" fat_length = item[\"Fatality length\"]\n",
" case_model = item[\"Case Model\"]\n",
" fat_model = item[\"Fatality Model\"]\n",
"\n",
" if(type(province)==float):\n",
" pop = get_population(country)\n",
" else:\n",
" pop = get_population(country, province)\n",
" \n",
" case_preds = pop * case_model.predict(case_length + length_submission)[case_length:]\n",
" fat_preds = pop * fat_model.predict(fat_length + length_submission)[fat_length:]\n",
"\n",
" frames.append(pd.DataFrame({\n",
" \"Country_Region\": country,\n",
" \"Province_State\": province,\n",
" \"Date\": pd_daterange_submission,\n",
" \"ConfirmedCases\": case_preds,\n",
" \"Fatalities\": fat_preds\n",
" })\n",
" )\n",
" \n",
" # concat sub frames and prepare for mergeing with test to get ForecastId\n",
" submission_data = pd.concat(frames)\n",
" submission = test.copy()\n",
"\n",
" index = [\"id\", \"Date\"]\n",
" submission[\"id\"] = submission[\"Country_Region\"].astype(str) + \"_\" + submission[\"Province_State\"].astype(str)\n",
" submission = submission[[\"id\", \"Date\", \"ForecastId\"]].set_index(index)\n",
"\n",
" submission_data[\"id\"] = submission_data[\"Country_Region\"].astype(str) + \"_\" + submission_data[\"Province_State\"].astype(str)\n",
" submission_data = submission_data[[\"id\", \"Date\", \"ConfirmedCases\", \"Fatalities\"]].set_index(index)\n",
"\n",
" # merge w/ ForecastId and extract submission columns\n",
" submission = submission.join(submission_data)\n",
" submission = submission[[\"ForecastId\", \"ConfirmedCases\", \"Fatalities\"]]\n",
"\n",
" # fillna (China)\n",
" submission = submission.fillna(1)\n",
" \n",
" # write to csv\n",
" submission.to_csv(name + \".csv\", index=False)\n",
"\n",
" print(\"submission saved to csv.\")\n",
" \n",
"make_submission()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# [WIP] Temporal SIR-Model\n",
"Add temporal variability to SIR-Model's parameters"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [],
"source": [
"class SIRT:\n",
" def __init__(self, gamma=0, a=0, b=0, c=0, d=0, fix_gamma=False):\n",
" self.gamma = gamma\n",
" self.a = a\n",
" self.b = b\n",
" self.c = c\n",
" self.d = d\n",
" self.infected_t0 = 0\n",
" self.fitted_on = np.array([])\n",
" self.fix_gamma = fix_gamma\n",
" self.fitted = False\n",
" \n",
" def ode(self, y, timestep, c, d, gamma):\n",
" '''Defines the ODE that governs the SIRs behaviour'''\n",
" beta = c * timestep + d\n",
" \n",
" dSdt = -beta * y[0] * y[1]\n",
" dRdt = gamma * y[1]\n",
" dIdt = -(dSdt + dRdt)\n",
" return dSdt, dIdt, dRdt\n",
" \n",
" def solve_ode(self, x, c, d, gamma):\n",
" '''Solves the resulting ODE to get predictions for each time step'''\n",
" return np.cumsum(integrate.odeint(self.ode, (1-self.infected_t0, self.infected_t0, 0.0), x, args=(c, d, gamma))[:,1])\n",
" \n",
" def solve_ode_fixed(self, x, beta):\n",
" '''Solves the resulting ODE to get predictions for each time step'''\n",
" return np.cumsum(integrate.odeint(self.ode, (1-self.infected_t0, self.infected_t0, 0.0), x, args=(beta, self.gamma))[:,1])\n",
" \n",
" def describe(self):\n",
" assert self.fitted, \"You need to fit the model before describing it!\"\n",
" print(\"c: \", self.c)\n",
" print(\"d: \", self.d)\n",
" print(\"Gamma: \", self.gamma)\n",
" print(\"Infected at t=0: \", self.infected_t0)\n",
" \n",
" plt.plot(range(1,len(self.fitted_on)+1), self.fitted_on, \"x\", label='Actual')\n",
" plt.plot(range(1,len(self.fitted_on)+1), self.predict(len(self.fitted_on)), label='Prediction')\n",
" plt.title(\"Fit of SIR model to global infected cases\")\n",
" plt.ylabel(\"Population infected\")\n",
" plt.xlabel(\"Days\")\n",
" plt.legend()\n",
" plt.show()\n",
" \n",
" def fit(self, y):\n",
" '''Fits the parameters to the data, assuming the first data point is the start of the outbreak'''\n",
" self.infected_t0 = y[0]\n",
" x = np.array(range(1,len(y)+1), dtype=float)\n",
" self.fitted_on = y\n",
" if(self.fix_gamma):\n",
" popt, _ = optimize.curve_fit(self.solve_ode_fixed, x, y)\n",
" self.beta = popt[0]\n",
" else:\n",
" popt, _ = optimize.curve_fit(self.solve_ode, x, y)\n",
" self.c = popt[0]\n",
" self.d = popt[1]\n",
" self.gamma = popt[2]\n",
" self.fitted = True\n",
" \n",
" def predict(self ,length):\n",
" '''Returns the predicted cumulated cases at each time step, assuming outbreak starts at t=0'''\n",
" #assert self.fitted, \"You need to fit the model before predicting!\"\n",
" return self.solve_ode(range(1, length+1), self.c, self.d, self.gamma)"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [],
"source": [
"measures = pd.read_csv(\"../input/covid19-containment-and-mitigation-measures/COVID 19 Containment measures data.csv\")\n",
"measures[\"Keywords\"].fillna(value=\"-\", inplace=True)\n",
"measures[\"Country\"] = measures[\"Country\"].str.replace('South Korea', 'Korea, South', regex=True)\n",
"measures[\"Country\"] = measures[\"Country\"].str.replace('US:Georgia', 'US', regex=True)\n",
"measures[\"Country\"] = measures[\"Country\"].str.replace('US: Illinois', 'US', regex=True)\n",
"measures[\"Country\"] = measures[\"Country\"].str.replace('US:Maryland', 'US', regex=True)\n",
"\n",
"measures = measures[measures[\"Country\"] != \"Vatican City\"]\n",
"measures = measures[measures[\"Country\"] != \"Hong Kong\"]\n",
"\n",
"def get_measures(measure_name):\n",
" \n",
" took_measure = measures[measures[\"Keywords\"].str.contains(\"distancing\")]\n",
" output = pd.DataFrame(data=0,\n",
" columns=train['Country_Region'].unique(),\n",
" index=pd.date_range(\"02.01.2020\", \"03.01.2020\"))\n",
" \n",
" print(took_measure)\n",
" \n",
" for index, row in took_measure.iterrows():\n",
" output[row[\"Country\"]][pd.to_datetime(row[\"Date Start\"]):] = 1\n",
" return output\n",
" \n",
"#get_measures(\"distancing\")[\"Italy\"]"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"c: 45.99821649993995\n",
"d: 2081.254522256052\n",
"Gamma: 7.184669771794915\n",
"Infected at t=0: 2.138997215432035e-08\n"
]
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"model = SIRT()\n",
"c, _, case_split, _, case_length, _ = get_country_data(\"Spain\")\n",
"model.fit(c[:case_split])\n",
"model.describe()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# [WIP] Modeling SIR Parameters\n",
"Predicting SIR parameters from Country/Province Metadata"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(array([2.13899722e-08, 2.13899722e-08, 2.13899722e-08, 2.13899722e-08,\n",
" 2.13899722e-08, 2.13899722e-08, 2.13899722e-08, 2.13899722e-08,\n",
" 4.27799443e-08, 4.27799443e-08, 4.27799443e-08, 4.27799443e-08,\n",
" 4.27799443e-08, 4.27799443e-08, 4.27799443e-08, 4.27799443e-08,\n",
" 4.27799443e-08, 4.27799443e-08, 4.27799443e-08, 4.27799443e-08,\n",
" 4.27799443e-08, 4.27799443e-08, 4.27799443e-08, 4.27799443e-08,\n",
" 1.28339833e-07, 2.78069638e-07, 3.20849582e-07, 6.84479109e-07,\n",
" 9.62548747e-07, 1.79675766e-06, 2.56679666e-06, 3.52934541e-06,\n",
" 4.74857382e-06, 5.54000279e-06, 8.55598886e-06, 1.06949861e-05,\n",
" 1.43954513e-05, 2.29514401e-05, 3.62560028e-05, 4.87049666e-05,\n",
" 4.87049666e-05, 1.11912334e-04, 1.36703312e-04, 1.66799003e-04,\n",
" 2.12659103e-04, 2.51289393e-04, 2.97534513e-04, 3.84228070e-04,\n",
" 4.36569332e-04, 5.42749153e-04, 6.15346719e-04, 7.51558062e-04,\n",
" 8.53139039e-04, 1.05912447e-03, 1.23604093e-03, 1.40572758e-03,\n",
" 1.56649461e-03, 1.71355067e-03, 1.88137639e-03, 2.05179030e-03,\n",
" 2.22708112e-03, 2.39706723e-03, 2.54966329e-03, 2.69873001e-03,\n",
" 2.81590427e-03, 2.92347444e-03, 3.03613543e-03, 3.17042167e-03,\n",
" 3.27741431e-03, 3.38545506e-03, 3.48714299e-03, 3.56851044e-03,\n",
" 3.63841287e-03, 3.69064719e-03]),\n",
" array([2.13899722e-08, 4.27799443e-08, 6.41699165e-08, 1.06949861e-07,\n",
" 2.13899722e-07, 3.63629527e-07, 5.98919220e-07, 7.48649025e-07,\n",
" 1.15505850e-06, 1.17644847e-06, 2.84486630e-06, 4.17104457e-06,\n",
" 6.18170195e-06, 7.31537048e-06, 1.14008552e-05, 1.33259527e-05,\n",
" 1.77536769e-05, 2.23097410e-05, 2.94112117e-05, 3.79030307e-05,\n",
" 4.94322256e-05, 6.00630418e-05, 7.80092284e-05, 9.33672285e-05,\n",
" 1.09901677e-04, 1.27954813e-04, 1.45515981e-04, 1.65045025e-04,\n",
" 1.81044724e-04, 2.00787669e-04, 2.21343432e-04, 2.39524908e-04,\n",
" 2.55545997e-04, 2.70390638e-04, 2.85363619e-04, 3.00422159e-04,\n",
" 3.16400468e-04, 3.30410900e-04, 3.43972142e-04, 3.55201878e-04,\n",
" 3.68100031e-04, 3.79800346e-04, 3.86217337e-04]),\n",
" 74,\n",
" 43,\n",
" 74,\n",
" 43)"
]
},
"execution_count": 28,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"get_country_data(\"Spain\")"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [],
"source": [
"for attr, item in val_info_dict.items():\n",
" country = item[\"Country\"]\n",
" province = item[\"Province\"]\n",
" case_length = item[\"Case length\"]\n",
" fat_length = item[\"Fatality length\"]\n",
" case_model = item[\"Case Model\"]\n",
" fat_model = item[\"Fatality Model\"]"
]
}
],
"metadata": {
"colab": {
"collapsed_sections": [],
"name": "LeoCorona",
"provenance": [],
"toc_visible": true
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.6"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment