Last active
October 26, 2024 11:46
-
-
Save rcsmit/afc46478b9dea858aa505e614c45aed9 to your computer and use it in GitHub Desktop.
donkeyshot_simulation_reproduction_ratio.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
from scipy.integrate import odeint | |
import pandas as pd | |
import matplotlib.pyplot as plt | |
import seaborn as sns | |
import numpy as np | |
import matplotlib.pyplot as plt | |
from matplotlib import cm | |
from mpl_toolkits.mplot3d import Axes3D | |
## https://github.com/donkeyshot/covid-mitigation/ | |
# translated to Python by ClaudeAI : https://claude.ai/chat/8fff8b6c-db02-413d-951f-e23d36e5cbee | |
# ======https://github.com/donkeyshot/covid-mitigation/blob/master/R/simulation%20model.R ========= | |
def covid_model(vars, t, parms):
    """
    Right-hand side of the COVID-19 compartmental ODE system.

    Parameters
    ----------
    vars : sequence of 11 numbers
        Current state in the order
        (S, E, I1, I2m, I3m, I2s, I3s, Y2, Y3, Inf, Rep).
    t : float
        Current time. Not used by the equations; it is part of the
        callback signature expected by the ODE integrator.
    parms : mapping
        Must provide the keys 'popsize', 'beta1', 'beta2', 'beta3',
        'sigma', 'gamma1', 'gamma2', 'gamma3', 'alpha', 'p_par',
        'b_par' and 'r_par'.

    Returns
    -------
    list
        The 11 time derivatives, in the same order as ``vars``.
    """
    S, E, I1, I2m, I3m, I2s, I3s, Y2, Y3, Inf, Rep = vars

    # Pull every rate out of the parameter mapping in one go
    (popsize, beta1, beta2, beta3, sigma,
     gamma1, gamma2, gamma3, alpha,
     p_par, b_par, r_par) = (
        parms[key] for key in (
            'popsize', 'beta1', 'beta2', 'beta3', 'sigma',
            'gamma1', 'gamma2', 'gamma3', 'alpha',
            'p_par', 'b_par', 'r_par',
        )
    )

    # Force of infection: weighted sum over all infectious compartments.
    # The I*s terms are scaled by b_par and the Y* terms by r_par.
    pressure = beta1 * I1 + beta2 * I2m + beta3 * I3m
    pressure = pressure + beta2 * b_par * I2s + b_par * beta3 * I3s
    pressure = pressure + beta2 * r_par * Y2 + beta3 * r_par * Y3
    foi = pressure / popsize

    # Flows that appear in more than one equation
    new_infections = foi * S
    leaving_I1 = gamma1 * I1        # split p_par / (1 - p_par) over I2s / I2m
    leaving_I2m = gamma2 * I2m
    leaving_Y2 = gamma2 * Y2
    moved_to_Y2 = alpha * I2s       # also what the Rep counter accumulates

    derivatives = [
        -new_infections,                                  # dS
        new_infections - sigma * E,                       # dE
        sigma * E - leaving_I1,                           # dI1
        leaving_I1 * (1 - p_par) - leaving_I2m,           # dI2m
        leaving_I2m - gamma3 * I3m,                       # dI3m
        leaving_I1 * p_par - (alpha + gamma2) * I2s,      # dI2s
        gamma2 * I2s - gamma3 * I3s,                      # dI3s
        moved_to_Y2 - leaving_Y2,                         # dY2
        leaving_Y2 - gamma3 * Y3,                         # dY3
        new_infections,                                   # dInf (cumulative)
        moved_to_Y2,                                      # dRep (cumulative)
    ]
    return derivatives
def covid_sim(popsize=60e6, seedsize=10, burnin_time=30, timewindow=365,
              beta1=0.25, beta2=0.16, beta3=0.016,
              sigma=1, gamma1=0.2, gamma2=0.14, gamma3=0.14,
              alpha=0.5, p_par=0.5, b_par=0.8, r_par=0.01,
              start_dist=60, end_dist=108, effect_dist=1,
              **kwargs):
    """
    Run the COVID-19 ODE model through up to three consecutive phases.

    The model is first integrated for ``burnin_time`` time units; only the
    final burn-in state is kept and used as the starting point. It is then
    integrated (1) up to ``start_dist`` with the given parameters, (2) from
    ``start_dist`` to ``end_dist`` with beta1/beta2/beta3 multiplied by
    ``effect_dist``, and (3) from ``end_dist`` to ``timewindow`` with the
    original parameters again. Phases beyond ``timewindow`` are skipped.

    Parameters
    ----------
    popsize, seedsize : number
        Initial values of the S and E compartments; every other compartment
        starts at 0.
    burnin_time : number
        Length of the burn-in run (truncated to an integer number of steps).
    timewindow : number
        Total simulated time after the burn-in.
    beta1, beta2, beta3, sigma, gamma1, gamma2, gamma3, alpha, p_par, b_par, r_par : float
        Model rates and proportions, passed to ``covid_model`` by name.
    start_dist, end_dist : number
        Start and end time of the distancing phase.
    effect_dist : float
        Multiplier applied to the three betas during the distancing phase.
    **kwargs
        Extra entries merged into the parameter dict handed to
        ``covid_model``; they override the named arguments of the same key.

    Returns
    -------
    pandas.DataFrame
        One row per output time point, with columns
        S, E, I1, I2m, I3m, I2s, I3s, Y2, Y3, Infected, Reported.
    """
    columns = ['S', 'E', 'I1', 'I2m', 'I3m', 'I2s', 'I3s',
               'Y2', 'Y3', 'Infected', 'Reported']

    # Initialize parameters.
    # NOTE: 'r_par' previously received b_par by mistake, which made the
    # r_par argument a no-op and inflated the contribution of Y2/Y3.
    parameters = {
        'popsize': popsize, 'beta1': beta1, 'beta2': beta2, 'beta3': beta3,
        'sigma': sigma, 'gamma1': gamma1, 'gamma2': gamma2, 'gamma3': gamma3,
        'alpha': alpha, 'p_par': p_par, 'b_par': b_par, 'r_par': r_par,
    }
    # Update parameters with any additional kwargs
    parameters.update(kwargs)

    # Initial state: S = popsize, E = seedsize, everything else empty
    initial_state = [popsize, seedsize, 0, 0, 0, 0, 0, 0, 0, 0, 0]

    # Burn-in simulation; int() lets a float burnin_time (e.g. 30.0) work,
    # since np.linspace requires an integer sample count.
    t_burnin = np.linspace(0, burnin_time, int(burnin_time) + 1)
    burnin_res = odeint(covid_model, initial_state, t_burnin, args=(parameters,))

    # Main simulation before distancing, started from the last burn-in state
    pre_end = min(start_dist, timewindow)
    t_main = np.linspace(0, pre_end, int(pre_end) + 1)
    main_res = odeint(covid_model, burnin_res[-1], t_main, args=(parameters,))
    if timewindow <= start_dist:
        return pd.DataFrame(main_res, columns=columns)

    # Simulation during distancing: scale all transmission rates
    dist_params = parameters.copy()
    for beta_key in ('beta1', 'beta2', 'beta3'):
        dist_params[beta_key] *= effect_dist
    dist_end = min(end_dist, timewindow)
    t_dist = np.linspace(start_dist, dist_end, int(dist_end - start_dist) + 1)
    dist_res = odeint(covid_model, main_res[-1], t_dist, args=(dist_params,))
    if timewindow <= end_dist:
        # Drop the last pre-distancing row: it equals the first distancing row
        combined_res = np.vstack((main_res[:-1], dist_res))
        return pd.DataFrame(combined_res, columns=columns)

    # Simulation after distancing, back on the original parameters
    t_post = np.linspace(end_dist, timewindow, int(timewindow - end_dist) + 1)
    post_res = odeint(covid_model, dist_res[-1], t_post, args=(parameters,))
    # Each phase starts on the previous phase's final state, so trim the
    # duplicated boundary rows before stacking.
    combined_res = np.vstack((main_res[:-1], dist_res[:-1], post_res))
    return pd.DataFrame(combined_res, columns=columns)
def plot_scenarios(scenario1=None, scenario2=None, scenario3=None):
    """
    Simulate a baseline plus two alternative scenarios and plot the
    per-step increase of the 'Reported' column for each of them.

    Parameters
    ----------
    scenario1 : dict or None
        Keyword arguments for ``covid_sim`` defining the baseline run.
    scenario2, scenario3 : dict or None
        Keyword overrides layered on top of ``scenario1`` for the two
        alternative runs. ``None`` is treated as an empty dict.

    Returns
    -------
    module
        The ``matplotlib.pyplot`` module, with the figure left current so
        the caller can keep annotating it.
    """
    base_kwargs = {} if scenario1 is None else scenario1
    overrides_1 = {} if scenario2 is None else scenario2
    overrides_2 = {} if scenario3 is None else scenario3

    # (legend label, covid_sim keyword arguments) for every run
    runs = (
        ('baseline', base_kwargs),
        ('scenario_1', {**base_kwargs, **overrides_1}),
        ('scenario_2', {**base_kwargs, **overrides_2}),
    )

    # Simulate each run and tag its rows with a label and a step counter
    frames = []
    for label, sim_kwargs in runs:
        frame = covid_sim(**sim_kwargs)
        frame['scenario'] = label
        frame['time'] = range(len(frame))
        frames.append(frame)

    all_plots = pd.concat(frames)

    # Step-to-step differences of the cumulative counters, per scenario
    all_plots['Infection'] = all_plots.groupby('scenario')['Infected'].diff()
    all_plots['Reporting'] = all_plots.groupby('scenario')['Reported'].diff()

    plt.figure(figsize=(16, 10))
    sns.set_style("whitegrid")

    for label, _ in runs:
        subset = all_plots[all_plots['scenario'] == label]
        plt.plot(subset['time'], subset['Reporting'], linewidth=2, label=label)

    plt.xlabel('Months since transmission established')
    plt.ylabel('Rate of cases being reported')
    # One tick every 30 steps, labelled 0..12
    plt.xticks(np.arange(0, 371, 30), range(13))
    plt.legend()

    return plt
# Example usage for creating the main figure | |
def donkeyshot_simulation_create_main_figure():
    """
    Build, save and display the main three-scenario figure.

    Runs ``plot_scenarios`` with a baseline and two distancing scenarios,
    adds explanatory text annotations and an arrow, writes the result to
    ``fig_main.png`` and then shows it.

    Returns
    -------
    object
        Whatever ``plot_scenarios`` returned.
    """
    scenarios = {
        'scenario1': {},
        'scenario2': {'effect_dist': 0.75, 'start_dist': 70, 'end_dist': 365},
        'scenario3': {'effect_dist': 0.5, 'start_dist': 80, 'end_dist': 200}
    }
    fig = plot_scenarios(**scenarios)

    # Add annotations: (x, y, text) in data coordinates
    annotations = [
        (130, 3.5e5, f"Timing and width of peak uncertain due to\n"
                     "- Stochasticity in early dynamics\n"
                     "- Heterogeneities in contact patterns\n"
                     "- Spatial variation\n"
                     "- Uncertainty in key epidemiological parameters\n"
                     f"{scenarios}"),
        (140, 2.1e5, "Social distancing\nflattens curve"),
        (155, 1.25e5, "Risk of resurgence\n following lifting of\n interventions"),
        (0, 0.6e5, "Epidemic growth,\ndoubling time\n~4-7 days")
    ]
    for x, y, text in annotations:
        plt.annotate(text, (x, y), xytext=(5, 5), textcoords='offset points')

    # Add arrow
    plt.annotate('', xy=(200, 0.12e5), xytext=(205, 0.67e5),
                 arrowprops=dict(arrowstyle='->'))

    # Save BEFORE showing: once plt.show() returns, the figure may already
    # have been closed, and savefig() would then write a blank image.
    plt.savefig('fig_main.png', dpi=300, bbox_inches='tight')
    plt.show()
    return fig
# =========== https://github.com/donkeyshot/covid-mitigation/blob/master/R/reproduction%20ratio.R =========================== | |
def covid_r0(beta1=0.25, beta2=0.16, beta3=0.016, | |
sigma=1, gamma1=0.2, gamma2=0.14, gamma3=0.14, | |
alpha=0.5, p=0.5, b=0.8, r=0.01): | |
""" | |
Calculate the basic reproduction number (R0) for the COVID-19 model | |
""" | |
r0 = ( | |
beta1 / gamma1 + # from I1 | |
beta2 * (1 - p) / gamma2 + # from I2m | |
beta3 * (1 - p) / gamma3 + # from I3m | |
beta2 * b * p / (gamma2 + alpha) + # from I2s | |
b * beta3 * p * (gamma2 / (gamma2 + alpha)) / gamma3 + # from I3s | |
beta2 * r * p * (alpha / (gamma2 + alpha)) / gamma2 + | |
beta3 * r * p * (alpha / (gamma2 + alpha)) / gamma3 | |
) | |
return r0 | |
def calculate_r0_surface(x_range, y_range, **fixed_params):
    """
    Evaluate R0 on a grid of isolation delays and isolated proportions.

    Parameters
    ----------
    x_range : sequence of numbers
        Isolation delays; each value ``d`` is passed to ``covid_r0`` as
        ``alpha = 1 / d``, so values must be non-zero.
    y_range : sequence of numbers
        Proportions, passed to ``covid_r0`` as ``p``.
    **fixed_params
        Remaining keyword arguments forwarded unchanged to ``covid_r0``
        (must not include ``alpha`` or ``p``).

    Returns
    -------
    tuple of numpy.ndarray
        ``(X, Y, Z)`` where ``X`` and ``Y`` are the ``np.meshgrid`` of the
        inputs and ``Z[j, i]`` is R0 at ``x_range[i]``, ``y_range[j]``.
    """
    X, Y = np.meshgrid(x_range, y_range)
    # Allocate explicitly as float: np.zeros_like(X) would inherit an
    # integer dtype from an integer x_range and truncate every R0 value.
    Z = np.zeros(X.shape, dtype=float)
    for i, delay in enumerate(x_range):
        for j, proportion in enumerate(y_range):
            Z[j, i] = covid_r0(alpha=1 / delay, p=proportion, **fixed_params)
    return X, Y, Z
def plot_r0_surface(save_path=None, dpi=300):
    """
    Draw a labelled 3D surface of R0 over isolation delay and proportion
    isolated.

    Parameters
    ----------
    save_path : str or None
        When truthy, the figure is written to this path and closed;
        otherwise it is shown on screen.
    dpi : int
        Resolution used when saving.

    Returns
    -------
    matplotlib.figure.Figure
        The figure that was created.
    """
    # Grid axes: isolation delay and proportion isolated
    isolation_delays = np.arange(0.5, 10.5, 0.5)
    proportions = np.arange(0, 1.05, 0.05)

    # Surface values, with the remaining model parameters held fixed
    X, Y, Z = calculate_r0_surface(
        isolation_delays, proportions,
        beta1=0.25, beta2=0.16, beta3=0.016,
        gamma1=0.2, gamma2=0.14, gamma3=0.14,
        b=0.8, r=0.01,
    )
    z_top = np.max(Z)

    fig = plt.figure(figsize=(10, 8))
    ax = fig.add_subplot(111, projection='3d')

    # Colour scale and z-axis both start at 1 and end at the surface maximum
    surf = ax.plot_surface(X, Y, Z, cmap=cm.coolwarm, antialiased=True,
                           vmin=1, vmax=z_top)

    ax.view_init(elev=30, azim=210)  # Similar viewing angle to R code
    ax.set_xlabel('Isolation delay (1/alpha)')
    ax.set_ylabel('Proportion isolated')
    ax.set_zlabel('Reproductive number')
    ax.set_xlim(0, 10)
    ax.set_zlim(1, z_top)

    fig.colorbar(surf, ax=ax, shrink=0.5, aspect=5)
    plt.tight_layout()

    # Without a target path, just display the figure
    if not save_path:
        plt.show()
        return fig

    plt.savefig(save_path, dpi=dpi, bbox_inches='tight')
    plt.close()
    return fig
def create_publication_figure(save_path='fig_appendix3_.png', dpi=300):
    """
    Save an unlabelled, margin-free version of the R0 surface plot.

    Parameters
    ----------
    save_path : str
        Output file path.
    dpi : int
        Resolution of the saved image.

    Returns
    -------
    matplotlib.figure.Figure
        The figure that was created (already closed when returned).
    """
    # Surface on the same grid as the labelled plot, with covid_r0 defaults
    delays = np.arange(0.5, 10.5, 0.5)
    proportions = np.arange(0, 1.05, 0.05)
    X, Y, Z = calculate_r0_surface(delays, proportions)
    z_top = np.max(Z)

    # Bare figure with a single 3D axis
    fig = plt.figure(figsize=(8, 8))
    ax = fig.add_subplot(111, projection='3d')

    surf = ax.plot_surface(X, Y, Z, cmap=cm.coolwarm, antialiased=True,
                           vmin=1, vmax=z_top)

    ax.view_init(elev=30, azim=210)
    ax.set_xlim(0, 10)
    ax.set_zlim(1, z_top)

    # Blank out all three axis labels for the publication version
    for clear_label in (ax.set_xlabel, ax.set_ylabel, ax.set_zlabel):
        clear_label('')

    # Let the axes fill the whole canvas
    plt.subplots_adjust(left=0, right=1, bottom=0, top=1)

    plt.savefig(save_path, dpi=dpi, bbox_inches='tight')
    plt.close()
    return fig
# Example usage: | |
def donkeyshot_reproduction():
    """
    Demonstrate the reproduction-ratio helpers: show the R0 surface, save
    the publication figure, and print one R0 value.
    """
    # Interactive plot (no save path, so it is shown on screen)
    plot_r0_surface()

    # Publication figure written to disk
    create_publication_figure(save_path='fig_appendix3__.png', dpi=300)

    # Single R0 value at alpha=0.5, p=0.5
    r0_value = covid_r0(alpha=0.5, p=0.5)
    message = f"R0 value for alpha=0.5, p=0.5: {r0_value:.2f}"
    print(message)
def main():
    """Run the simulation figure first, then the reproduction-ratio demo."""
    pipeline = (
        donkeyshot_simulation_create_main_figure,
        donkeyshot_reproduction,
    )
    for step in pipeline:
        step()


if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment