Last active
March 6, 2020 13:41
-
-
Save ctung/b31726c64e55b7ce48887f98b52c6acf to your computer and use it in GitHub Desktop.
COVID-19 death probability function
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#! python3 | |
import pandas as pd | |
import seaborn as sns | |
import matplotlib.pyplot as plt | |
import matplotlib.ticker as mtick | |
import datetime | |
import numpy as np | |
from lmfit import Parameters, fit_report, minimize | |
from pprint import pprint | |
from scipy.stats import norm | |
sns.set() | |
# Raw CSV time series from the CSSEGISandData/COVID-19 GitHub repository,
# keyed by the metric name that becomes the value column after melting.
_CSSE_BASE = (
    "https://raw.githubusercontent.com/CSSEGISandData/COVID-19/master/"
    "csse_covid_19_data/csse_covid_19_time_series/"
)
urls = {
    "confirmed": _CSSE_BASE + "time_series_19-covid-Confirmed.csv",
    "deaths": _CSSE_BASE + "time_series_19-covid-Deaths.csv",
    "recovered": _CSSE_BASE + "time_series_19-covid-Recovered.csv",
}
def main():
    """Fit outcome-delay models to global COVID-19 counts and plot the results.

    Loads the per-region time series, sums them into one global total per
    date, rescales the early confirmed counts, fits a scaled normal CDF to the
    death counts and then to the recovery counts, and shows a 3x2 figure with
    one column per outcome.
    """
    # collapse every region into a single global row per date
    df = readData()
    df = df[['date', 'deaths', 'recovered', 'confirmed']].groupby(['date']).sum().reset_index().set_index(['date'])

    # rescale counts reported before the change in classification method
    df['confirmed'] = df.apply(correct, key='confirmed', axis=1)

    # daily new confirmed cases; the first day has no predecessor, so it keeps
    # its cumulative value
    df['newConf'] = df.confirmed.diff()
    df.at[df.index.min(), 'newConf'] = df.confirmed.iloc[0]

    # store on every row the full history of daily new cases up to that date,
    # newest first, so list position == days since diagnosis and each row can
    # be modelled without looking at other rows
    df['confHist'] = np.nan
    df['confHist'] = df['confHist'].astype(object)
    for day in df.index:
        history = df[:day]['newConf'].tolist()
        history.reverse()
        df.at[day, 'confHist'] = history

    df = df.reset_index()
    df['date'] = df.date.dt.strftime('%y/%m/%d')

    # initial guess for the death model
    dfit_params = Parameters()
    for name, start in (('amp', 3), ('mean', 14), ('stddev', 4.2)):
        dfit_params.add(name, value=start)
    dout, death_pred = fit('deaths', dfit_params, df)
    df['predDeaths'] = death_pred

    # initial guess for the recovery model; its amplitude is capped so that
    # the death and recovery amplitudes cannot add up to more than 100
    rfit_params = Parameters()
    rfit_params.add('amp', value=96, max=100-dout.params['amp'])
    for name, start in (('mean', 23), ('stddev', 10)):
        rfit_params.add(name, value=start)
    rout, recov_pred = fit('recovered', rfit_params, df)
    df['predRecov'] = recov_pred

    fig, axs = plt.subplots(nrows=3, ncols=2, figsize=(12, 10))
    plot(df, dout, axs, 0, 'deaths', 'predDeaths')
    plot(df, rout, axs, 1, 'recovered', 'predRecov')
    # leave room above the grid for the figure-level title
    plt.tight_layout(rect=[0, 0.03, 1, 0.95])
    plt.suptitle('COVID-19 Outcome Modeling with Normal Distributions')
    plt.show()
def fit(key, fit_params, df):
    """Fit the outcome model to one observed column of ``df``.

    key: name of the column holding the observed counts (e.g. 'deaths').
    fit_params: lmfit ``Parameters`` with starting values for 'amp', 'mean'
        and 'stddev'.
    df: frame providing a 'confHist' column (per-row history of daily new
        confirmed cases) and the ``key`` column.

    Prints the lmfit report and returns ``(result, predicted)`` where
    ``result`` is the minimizer output and ``predicted`` is the list of model
    values evaluated at the best-fit parameters.
    """
    # model input: daily confirmed history per row; target: observed counts
    histories = df['confHist'].to_list()
    observed = df[key].to_list()

    # least-squares fit of residual() using lmfit's default method
    result = minimize(residual, fit_params, args=(histories,), kws={'data': observed})
    print(fit_report(result))

    # calling residual() without data yields the raw model predictions
    predicted = residual(result.params, histories)
    return result, predicted
def plot(df, out, axs, c, lbl, pred):
    """Draw the three diagnostic panels for one fitted outcome.

    df: frame with a 'date' column plus the ``lbl`` and ``pred`` columns.
    out: minimizer result whose ``params`` hold 'amp', 'mean' and 'stddev'.
    axs: 3x2 grid of matplotlib axes, indexed as ``axs[row, column]``.
    c: column of ``axs`` to draw into.
    lbl: name of the observed-count column; also used in titles and labels.
    pred: name of the predicted-count column.

    Row 0 shows the fitted normal density scaled by amp, row 1 the scaled
    cumulative distribution, row 2 the observed counts as bars with the
    predicted counts overlaid as a line on a twin y-axis.
    """
    amp = out.params['amp'].value
    mean = out.params['mean'].value
    stddev = out.params['stddev'].value
    days = list(range(45))
    label = lbl.title()

    # x/y are passed by keyword: newer seaborn releases reject positional x, y
    density = amp * norm.pdf(days, mean, stddev)
    sns.lineplot(x=days, y=density, ax=axs[0, c])
    axs[0, c].set(title='%s Normal Distribution' % label, xlabel='Days after diagnosis', ylabel='%s probability' % label)
    axs[0, c].yaxis.set_major_formatter(mtick.PercentFormatter())
    # backslashes are doubled so \mu and \sigma reach mathtext unchanged
    # instead of being parsed as (invalid) Python escape sequences
    axs[0, c].text(3, max(density) / 2, 'amp = %0.2f\n$\\mu = %0.2f$\n$\\sigma = %0.2f$' % (amp, mean, stddev))

    cumulative = [dfunc(amp, mean, stddev, days_since_diagnosis) for days_since_diagnosis in days]
    sns.lineplot(x=days, y=cumulative, ax=axs[1, c])
    axs[1, c].set(title='%s Cumulative Distribution Function' % label, xlabel='Days after diagnosis', ylabel='%s CDF' % label)
    axs[1, c].yaxis.set_major_formatter(mtick.PercentFormatter())

    sns.barplot(ax=axs[2, c], x='date', y=lbl, color='firebrick', data=df)
    axs[2, c].set(title='Predicted vs Actual %s count' % lbl, xlabel='Date', ylabel='%s count' % label)
    # draw the prediction on an explicit twin axis rather than relying on
    # twinx() having made it the current axes
    ax3a = axs[2, c].twinx()
    sns.lineplot(x='date', y=pred, color='green', data=df, dashes=True, ax=ax3a)
    ax3a.set_yticks([])
    ax3a.set(xlabel='', ylabel='')
    ax3a.set_ylim(0,)
    # one tick per week keeps the date labels readable
    ax3a.xaxis.set_major_locator(mtick.MultipleLocator(7))
def dfunc(amp, mean, stddev, x):
    """Return ``amp`` times the normal CDF at ``x``.

    The CDF is that of a normal distribution with location ``mean`` and scale
    ``stddev``. ``x`` may be a scalar or an array; the result has the same
    shape.
    """
    cumulative = norm.cdf(x, loc=mean, scale=stddev)
    return amp * cumulative
def residual(pars, x, data=None):
    """Evaluate the outcome model, or its residual against observed data.

    pars: object exposing ``valuesdict()`` with 'amp', 'mean' and 'stddev'
        (an lmfit ``Parameters``).
    x: sequence of rows; each row lists daily new confirmed cases with list
        position equal to days since diagnosis.
    data: observed counts, one per row of ``x``. Optional.

    For each row the model is the sum over its entries of
    ``dfunc(amp, mean, stddev, days) / 100 * new_cases``; the division by 100
    converts the percentage amplitude into a fraction. Returns the list of
    model values when ``data`` is None, otherwise the array ``model - data``.
    """
    vals = pars.valuesdict()
    amp = vals['amp']
    mean = vals['mean']
    stddev = vals['stddev']

    # The CDF weight depends only on days since diagnosis, not on the row, so
    # evaluate it once for the longest history instead of once per element.
    longest = max((len(row) for row in x), default=0)
    weights = dfunc(amp, mean, stddev, np.arange(longest)) / 100

    model = [np.dot(weights[:len(row)], row) for row in x]
    if data is None:
        return model
    return np.array(model) - np.array(data)
def correct(row, key=None):
    """Rescale an early count to match the later classification method.

    row: record whose ``name`` is its date and which is indexable by ``key``.
    key: name of the field to adjust.

    Values dated before 2020-02-12 are multiplied by 1.25, the value on
    2020-02-12 by 1.3 and the value on 2020-02-13 by 1.05; later values are
    returned unchanged.
    """
    switch_day = datetime.datetime(2020, 2, 12)
    following_day = datetime.datetime(2020, 2, 13)
    when = row.name
    value = row[key]

    if when < switch_day:
        return value * 1.25
    if when == switch_day:
        return value * 1.3
    if when == following_day:
        return value * 1.05
    return value
def readData():
    """Download every series in ``urls`` and merge them into one long frame.

    Each CSV is melted from one column per date into one row per date, with
    the counts stored in a column named after the ``urls`` key. The melted
    frames are outer-merged on date and the four location columns, so the
    result has one value column per metric.
    """
    columns = ['date', 'Province/State', 'Country/Region', 'Lat', 'Long']
    location_cols = columns[1:]

    # start from an empty frame that already has the merge keys
    merged = pd.DataFrame(columns=columns)
    for metric, url in urls.items():
        wide = pd.read_csv(url, parse_dates=True)
        # one row per (location, date) instead of one column per date
        tall = wide.melt(id_vars=location_cols, var_name='date', value_name=metric)
        tall['date'] = pd.to_datetime(tall.date)
        merged = merged.merge(tall, how='outer', on=columns)
    return merged
# Run the analysis only when executed as a script, so that importing this
# module does not download the data or open the plot window.
if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment