Skip to content

Instantly share code, notes, and snippets.

@ctung
Last active March 6, 2020 13:41
Show Gist options
  • Save ctung/b31726c64e55b7ce48887f98b52c6acf to your computer and use it in GitHub Desktop.
Save ctung/b31726c64e55b7ce48887f98b52c6acf to your computer and use it in GitHub Desktop.
covid-19 death probability function
#! python3
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.ticker as mtick
import datetime
import numpy as np
from lmfit import Parameters, fit_report, minimize
from pprint import pprint
from scipy.stats import norm
# Module-level call: configure seaborn's default plot styling once, at import
# time, before any of the figures below are drawn.
sns.set()
# Source CSVs for readData(), keyed by the name of the value column each one
# becomes once its per-date columns are melted into rows.
# NOTE(review): these are the early-2020 JHU CSSE file names; confirm they
# still resolve before relying on them.
urls = dict(
    confirmed="https://raw.githubusercontent.com/CSSEGISandData/COVID-19/master/csse_covid_19_data/csse_covid_19_time_series/time_series_19-covid-Confirmed.csv",
    deaths="https://raw.githubusercontent.com/CSSEGISandData/COVID-19/master/csse_covid_19_data/csse_covid_19_time_series/time_series_19-covid-Deaths.csv",
    recovered="https://raw.githubusercontent.com/CSSEGISandData/COVID-19/master/csse_covid_19_data/csse_covid_19_time_series/time_series_19-covid-Recovered.csv",
)
def main():
    """Fit and plot death/recovery outcome curves for global COVID-19 data.

    Loads the per-region series with ``readData``, collapses them into global
    daily totals with a per-day case history (``_prepare_frame``), fits a
    scaled normal CDF over "days since diagnosis" first to deaths and then to
    recoveries (``fit``), and shows both fits in a 3x2 grid of plots (``plot``).
    """
    df = _prepare_frame(readData())

    # initial guesses; 'amp' is in percent of confirmed cases (residual divides it by 100)
    dfit_params = Parameters()
    dfit_params.add('amp', value=3)
    dfit_params.add('mean', value=14)
    dfit_params.add('stddev', value=4.2)
    # fit, then store the predicted values computed with the best-fit params
    dout, df['predDeaths'] = fit('deaths', dfit_params, df)

    # cap the recovery share so that deaths + recoveries cannot exceed 100%
    rfit_params = Parameters()
    rfit_params.add('amp', value=96, max=100 - dout.params['amp'].value)
    rfit_params.add('mean', value=23)
    rfit_params.add('stddev', value=10)
    rout, df['predRecov'] = fit('recovered', rfit_params, df)

    _, axs = plt.subplots(nrows=3, ncols=2, figsize=(12, 10))
    plot(df, dout, axs, 0, 'deaths', 'predDeaths')
    plot(df, rout, axs, 1, 'recovered', 'predRecov')
    plt.tight_layout(rect=[0, 0.03, 1, 0.95])
    plt.suptitle('COVID-19 Outcome Modeling with Normal Distributions')
    plt.show()


def _prepare_frame(df):
    """Collapse per-region rows into global daily totals plus case histories.

    Args:
        df: long-form frame from ``readData`` with 'date', 'deaths',
            'recovered' and 'confirmed' columns.

    Returns:
        A frame with one row per date (sorted by date) with columns 'date'
        (formatted '%y/%m/%d'), 'deaths', 'recovered', 'confirmed',
        'newConf' (daily new confirmed cases) and 'confHist'. For each row,
        'confHist' is the list of 'newConf' values from that day back to the
        first day, so list index i means "i days since diagnosis".
    """
    # merge countries into a global total, indexed by date
    df = df[['date', 'deaths', 'recovered', 'confirmed']].groupby('date').sum()
    # adjust for the Feb 12-13 change in how confirmed cases were classified
    df['confirmed'] = df.apply(correct, key='confirmed', axis=1)

    # daily new infections; the first day has no predecessor, so use its total
    new_conf = df['confirmed'].diff()
    if len(new_conf):
        new_conf.iloc[0] = df['confirmed'].iloc[0]
    df['newConf'] = new_conf

    # Attach each day's history (most recent first) to its own row, so the model
    # can evaluate a row without referencing prior rows. daily[i::-1] is
    # daily[0..i] reversed, built in one pass instead of slicing the frame per day.
    daily = new_conf.tolist()
    df['confHist'] = pd.Series(
        [daily[i::-1] for i in range(len(daily))], index=df.index, dtype=object
    )

    df = df.reset_index()
    df['date'] = df['date'].dt.strftime('%y/%m/%d')
    return df
def fit(key, fit_params, df):
    """Least-squares fit of the outcome model to column ``key`` of ``df``.

    The model input is each row's confirmed-case history ('confHist'); the
    observed output is ``df[key]`` (e.g. 'deaths' or 'recovered'). The fit
    report is printed to stdout.

    Args:
        key: name of the observed outcome column.
        fit_params: lmfit ``Parameters`` holding the initial amp/mean/stddev.
        df: frame with 'confHist' and ``key`` columns.

    Returns:
        ``(result, predicted)``: the result returned by lmfit ``minimize`` and
        the list of model predictions evaluated at its best-fit params.
    """
    histories = df['confHist'].to_list()
    observed = df[key].to_list()

    result = minimize(residual, fit_params, args=(histories,), kws={'data': observed})
    print(fit_report(result))

    predicted = residual(result.params, histories)
    return result, predicted
def plot(df, out, axs, c, lbl, pred, days=45):
    """Draw one outcome's fitted curves and predicted-vs-actual counts.

    Fills column ``c`` of ``axs``:
      * row 0 - the amp-scaled normal PDF of the fitted parameters,
      * row 1 - the amp-scaled CDF (``dfunc``) that the model uses,
      * row 2 - actual ``lbl`` counts per date as bars, with the ``pred``
        column drawn as a line on a twin y-axis.

    Args:
        df: frame with 'date', ``lbl`` and ``pred`` columns.
        out: lmfit fit result whose params include amp, mean and stddev.
        axs: 2-D array of matplotlib Axes with at least 3 rows.
        c: column of ``axs`` to draw into.
        lbl: observed column name, also used in titles/labels.
        pred: predicted column name.
        days: number of days after diagnosis to plot in rows 0-1 (default 45).
    """
    amp = out.params['amp'].value
    mean = out.params['mean'].value
    stddev = out.params['stddev'].value
    title = lbl.title()
    x = np.arange(days)
    pdf_ax, cdf_ax, count_ax = axs[0, c], axs[1, c], axs[2, c]

    # row 0: amp-scaled normal PDF (x/y passed by keyword; seaborn >= 0.12
    # rejects positional data arguments)
    y = amp * norm.pdf(x, mean, stddev)
    sns.lineplot(x=x, y=y, ax=pdf_ax)
    pdf_ax.set(title='%s Normal Distribution' % title, xlabel='Days after diagnosis',
               ylabel='%s probability' % title)
    pdf_ax.yaxis.set_major_formatter(mtick.PercentFormatter())
    # '\\m' / '\\s' are escaped explicitly: bare '\m' and '\s' are invalid
    # escape sequences (SyntaxWarning on Python 3.12+); the text is unchanged.
    pdf_ax.text(3, y.max() / 2,
                'amp = %0.2f\n$\\mu = %0.2f$\n$\\sigma = %0.2f$' % (amp, mean, stddev))

    # row 1: amp-scaled CDF, evaluated for all days at once
    sns.lineplot(x=x, y=dfunc(amp, mean, stddev, x), ax=cdf_ax)
    cdf_ax.set(title='%s Cumulative Distribution Function' % title,
               xlabel='Days after diagnosis', ylabel='%s CDF' % title)
    cdf_ax.yaxis.set_major_formatter(mtick.PercentFormatter())

    # row 2: actual counts as bars, predicted counts as a line on a twin y-axis
    sns.barplot(ax=count_ax, x='date', y=lbl, color='firebrick', data=df)
    count_ax.set(title='Predicted vs Actual %s count' % lbl, xlabel='Date',
                 ylabel='%s count' % title)
    twin = count_ax.twinx()
    # pass the twin explicitly instead of relying on it being the current axes
    sns.lineplot(x='date', y=pred, color='green', data=df, dashes=True, ax=twin)
    twin.set_yticks([])
    twin.set(xlabel='', ylabel='')
    twin.set_ylim(bottom=0)
    twin.xaxis.set_major_locator(mtick.MultipleLocator(7))
def dfunc(amp, mean, stddev, x):
    """Return ``amp`` times the normal CDF (given mean/stddev) evaluated at ``x``.

    ``x`` may be a scalar or a numpy array (evaluated element-wise). In this
    script ``amp`` is expressed in percent, so the value is the cumulative
    percentage of cases that have reached the outcome by ``x`` days after
    diagnosis.
    """
    cumulative = norm.cdf(x, loc=mean, scale=stddev)
    return amp * cumulative
def residual(pars, x, data=None):
    """Evaluate the outcome model, or its residual against observed data.

    For each history ``row`` in ``x`` (daily new confirmed cases, where list
    index i means "i days since diagnosis") the prediction is

        sum_i dfunc(amp, mean, stddev, i) / 100 * row[i]

    i.e. ``amp`` is a percentage and the CDF weights each cohort by how long
    ago it was diagnosed.

    Args:
        pars: lmfit-style parameters; ``pars.valuesdict()`` must provide
            'amp', 'mean' and 'stddev'.
        x: sequence of per-day case histories (lists of numbers).
        data: observed values aligned with ``x``, or None.

    Returns:
        The list of predictions when ``data`` is None; otherwise the numpy
        array ``predictions - data`` that the least-squares fit minimizes.
    """
    vals = pars.valuesdict()
    amp = vals['amp']
    mean = vals['mean']
    stddev = vals['stddev']

    rows = list(x)
    # The per-lag weights depend only on the parameters, so evaluate the CDF once
    # per lag (vectorized) instead of once per (row, lag) pair on every call.
    longest = max((len(row) for row in rows), default=0)
    weights = dfunc(amp, mean, stddev, np.arange(longest)) / 100
    # same summation order as the original per-element loop
    model = [sum(w * c for w, c in zip(weights, row)) for row in rows]

    if data is None:
        return model
    return np.array(model) - np.array(data)
def correct(row, key=None):
    """Rescale ``row[key]`` for the Feb 12-13, 2020 case-classification change.

    The case classification method changed around 2020-02-12. Values from
    before the change are scaled up so the series matches the current method.

    * dates before 2020-02-12 are multiplied by 1.25,
    * 2020-02-12 by 1.3,
    * 2020-02-13 by 1.05,
    * later dates are returned unchanged.

    Intended for ``DataFrame.apply(..., axis=1)`` on a date-indexed frame, so
    ``row.name`` is the row's date.
    """
    day = row.name
    cutover = datetime.datetime(2020, 2, 12)
    if day < cutover:
        factor = 1.25
    elif day == cutover:
        factor = 1.3
    elif day == cutover + datetime.timedelta(days=1):
        factor = 1.05
    else:
        return row[key]
    return row[key] * factor
def readData(sources=None):
    """Load the wide per-region time series and merge them into one long frame.

    Each source is a CSV with one row per region ('Province/State',
    'Country/Region', 'Lat', 'Long') and one column per date. Each is melted
    into one row per (region, date). All sources are then outer-joined on
    region + date, giving one value column per source key.

    Args:
        sources: mapping of value-column name to anything ``pd.read_csv``
            accepts (URL, path or file-like buffer). Defaults to the
            module-level ``urls``.

    Returns:
        DataFrame with columns 'date', 'Province/State', 'Country/Region',
        'Lat', 'Long' followed by one column per source key. When
        ``sources`` is empty, this is an empty frame with just the key columns.
    """
    from functools import reduce

    if sources is None:
        sources = urls
    id_vars = ['Province/State', 'Country/Region', 'Lat', 'Long']
    columns = ['date'] + id_vars

    frames = []
    for key, src in sources.items():
        wide = pd.read_csv(src)
        tidy = wide.melt(id_vars=id_vars, var_name='date', value_name=key)
        tidy['date'] = pd.to_datetime(tidy['date'])
        frames.append(tidy)

    if not frames:
        return pd.DataFrame(columns=columns)
    # Merge the loaded frames with each other rather than seeding with an empty
    # DataFrame(columns=...), whose columns are all object dtype. The join keys
    # then keep the dtypes read from the CSVs (datetime 'date', numeric 'Lat'/'Long').
    merged = reduce(lambda left, right: left.merge(right, how='outer', on=columns), frames)
    # same column order as before: key columns first, then values in source order
    return merged[columns + list(sources)]
# Run only when executed as a script, so importing this module (e.g. to reuse
# readData/fit/residual) does not fetch the data sources or open plot windows.
if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment