Skip to content

Instantly share code, notes, and snippets.

@pjpetersik
Created April 22, 2022 13:53
Show Gist options
  • Save pjpetersik/f6d4549f05a8b4f64773174eabbece0f to your computer and use it in GitHub Desktop.
Save pjpetersik/f6d4549f05a8b4f64773174eabbece0f to your computer and use it in GitHub Desktop.
import epipack as epk
import numpy as np
import matplotlib.pyplot as plt
from scipy.interpolate import interp1d
def create_growth_rate_function(r, K):
    """Return a logistic growth-rate callback ``growth_rate(t, y)``.

    The callback ignores ``t``, adds up the compartment values ``y`` into a
    total population ``N`` and evaluates the logistic term
    ``r * N * (1 - N / K)``.

    Parameters
    ----------
    r : float
        Intrinsic growth rate.
    K : float
        Carrying capacity; the returned rate is zero when ``N == K``.
    """
    def growth_rate(t, y):
        total = np.sum(y)
        # Fraction of the carrying capacity that is still unoccupied.
        remaining_capacity = 1 - total / K
        return r * total * remaining_capacity
    return growth_rate
def create_infection_rate_function(low, high, K, N_high):
    """Build a piecewise-linear ("tent") infection-rate callback.

    The returned rate depends only on the total population ``N = sum(y)``:
    it rises linearly from ``low`` at ``N = 0`` to ``high`` at ``N = N_high``
    and falls linearly back to ``low`` at ``N = K``. For ``N > K`` the rate is
    held at ``low`` instead of extrapolating the falling line. Extrapolating
    would drive the rate below ``low``, e.g. to negative (unphysical) values
    when ``low == 0``.

    Parameters
    ----------
    low : float
        Rate at ``N = 0``, at ``N = K`` and for every ``N > K``.
    high : float
        Peak rate, reached at ``N = N_high``.
    K : float
        Population size at which the rate has returned to ``low``.
    N_high : float
        Population size of the peak; must satisfy ``0 < N_high < K``.

    Returns
    -------
    callable
        ``infection_rate(t, y)``. ``t`` is ignored; ``y`` holds the
        compartment values, which are summed with ``np.sum``.

    Raises
    ------
    ValueError
        If ``N_high`` is not strictly between 0 and ``K``.
    """
    if not (0 < N_high < K):
        # ValueError subclasses Exception, so existing `except Exception`
        # handlers still catch it.
        raise ValueError("N_high must be greater than 0 and smaller than K.")

    # Rising branch: line through (0, low) and (N_high, high).
    m1 = (high - low) / N_high
    n1 = low
    # Falling branch: line through (N_high, high) and (K, low).
    m2 = (low - high) / (K - N_high)
    n2 = low - m2 * K

    def infection_rate(t, y):
        N = np.sum(y)
        if N <= N_high:
            return m1 * N + n1
        if N <= K:
            return m2 * N + n2
        # Beyond K, hold the rate at `low` rather than continuing the
        # falling line below it.
        return low
    return infection_rate
# Compartment labels: susceptible, infected, removed.
S, I, R = "SIR"

# Logistic growth parameters and initial conditions.
K = 40    # carrying capacity
r = 0.1   # intrinsic growth rate
S0 = 0    # initial number of susceptibles
I0 = 1    # initial number of infecteds

growth_rate = create_growth_rate_function(r=r, K=K)
# Tent-shaped infection rate: 0.0 at N = 0, peak 0.5 at N = 20, back to 0.0 at N = K.
infection_rate = create_infection_rate_function(low=0.0, high=0.5, N_high=20, K=K)
removing_rate = 0.05

processes = [
    (None, growth_rate, S),          # new S individuals (no source compartment)
    (S, I, infection_rate, I, I),    # S + I -> I + I
    (I, removing_rate, R),           # I -> R
]

model = epk.EpiModel(
    compartments=[S, I, R],
    initial_population_size=S0 + I0,
    # NOTE(review): presumably tells epipack that the total population
    # changes over time (see the source-less process above) — confirm.
    correct_for_dynamical_population_size=True,
)
# NOTE(review): allow_nonzero_column_sums is presumably needed because the
# growth process does not conserve the total population — confirm.
model = model.set_processes(processes, allow_nonzero_column_sums=True)
model = model.set_initial_conditions({S: S0, I: I0})

# Integrate on an evenly spaced time grid.
tmin, tmax, timesteps = 0, 100, 1000
tarr = np.linspace(tmin, tmax, timesteps)
result = model.integrate(tarr)
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6))

# Left panel: size of every compartment over time, plus the total
# population N obtained by summing all compartments.
ax1.set_xlim(tmin, tmax)
ax1.set_ylim(0, 1.1 * K)  # a little headroom above the carrying capacity K
N = np.zeros_like(result[S])
for compartment, size in result.items():
    ax1.plot(tarr, size, label=compartment)
    N += size
ax1.plot(tarr, N, label="N", color='k')
ax1.set_xlabel('time [days]')
# These curves are compartment sizes (individuals currently in each
# compartment), not incidence (new cases per unit time).
ax1.set_ylabel('number of individuals')
ax1.legend()

# Right panel: the infection rate as a function of N. Sample it only over the
# displayed range [0, K] at a 0.1 step; the original grid ran past K.
Narr = np.linspace(0, K, int(round(10 * K)) + 1)
rate_arr = [infection_rate(0, value) for value in Narr]
ax2.set_xlim(0, K)
ax2.set_ylim(0, 0.6)
ax2.set_xlabel('N')
ax2.set_ylabel('infection rate')
ax2.plot(Narr, rate_arr)

plt.savefig("img.png")
plt.close(fig)  # release the figure once it has been written to disk
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment