Created March 14, 2021 15:50
SIR example in Bridge.jl
# Stochastic differential equation model using Bridge.jl

Simon Frost (@sdwfrost), 2021-03-13

## Introduction

A stochastic differential equation version of the SIR model is:

- Stochastic
- Continuous in time
- Continuous in state

This implementation uses `Bridge.jl` and is modified from [this post on parameter inference for a simple SIR model](http://www.math.chalmers.se/~smoritz/journal/2018/01/19/parameter-inference-for-a-simple-sir-model/).
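
For reference, reading the drift `Bridge.b` and diffusion `Bridge.σ` defined below as an Itô SDE gives the following system. Here $N = S + I + R$, and $W_1$ (infection) and $W_2$ (recovery) are two independent standard Wiener processes. This is a transcription of the code, not a formula quoted from the original post.

$$
\begin{aligned}
dS &= -\frac{\beta c S I}{N}\,dt + \sqrt{\frac{\beta c S I}{N}}\,dW_1\\
dI &= \left(\frac{\beta c S I}{N} - \gamma I\right)dt - \sqrt{\frac{\beta c S I}{N}}\,dW_1 - \sqrt{\gamma I}\,dW_2\\
dR &= \gamma I\,dt + \sqrt{\gamma I}\,dW_2
\end{aligned}
$$

The sign attached to $dW_1$ is arbitrary, because $W_1$ and $-W_1$ have the same distribution. The drift terms sum to zero, and the coefficients on $S$, $I$ and $R$ in each noise term also sum to zero. As a result, $N$ stays constant along every path.
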

## Libraries

```julia
using Bridge
using StaticArrays
using Random
using DataFrames
using StatsPlots
using BenchmarkTools
```

## Transitions

`Bridge.jl` uses structs and multiple dispatch. I therefore first define a struct that subtypes `Bridge.ContinuousTimeProcess`. The supertype's parameter gives the state type, here a static vector of three `Float64` values (S, I and R). The struct's fields hold the parameter values and their types.

```julia
struct SIR <: ContinuousTimeProcess{SVector{3,Float64}}
    β::Float64
    c::Float64
    γ::Float64
end
```
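
The following minimal plain-Julia sketch shows what the type parameter does without depending on `Bridge.jl`. `ToyProcess`, `ToySIR` and `statetype` are hypothetical names used only for illustration. The state type is a tuple rather than an `SVector`, so no packages are needed.

```julia
# Hypothetical stand-in for an abstract process type whose parameter T
# records the type of the state.
abstract type ToyProcess{T} end

# Same field layout as the SIR struct above, with a 3-tuple as the state type
struct ToySIR <: ToyProcess{NTuple{3,Float64}}
    β::Float64
    c::Float64
    γ::Float64
end

# A generic function can recover the state type from any subtype by dispatch
statetype(::ToyProcess{T}) where {T} = T

P = ToySIR(0.05, 10.0, 0.25)
statetype(P)   # NTuple{3, Float64}
```

Here `ToyProcess{NTuple{3,Float64}}` plays the role of `ContinuousTimeProcess{SVector{3,Float64}}`. Any function that receives the process can recover the number of states and their type by dispatching on the supertype.
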

I now add a method to `Bridge.b`, the drift term. It takes the current time `t`, the state `u` and this struct, and returns a static vector (`@SVector`) of the derivatives of S, I and R.

```julia
function Bridge.b(t, u, P::SIR)
    (S,I,R) = u
    N = S + I + R
    dS = -P.β*P.c*S*I/N
    dI = P.β*P.c*S*I/N - P.γ*I
    dR = P.γ*I
    return @SVector [dS,dI,dR]
end
```
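
As a sanity check on the drift, here is a package-free copy of it, evaluated at the initial conditions and parameter values used later ($S=990$, $I=10$, $R=0$, $\beta=0.05$, $c=10$, $\gamma=0.25$). The function name `sir_drift` and the tuple state are illustrative assumptions. With $N = 1000$, the infection rate is $\beta c S I / N = 4.95$ and the recovery rate is $\gamma I = 2.5$.

```julia
# Package-free copy of the drift above, with a tuple instead of an SVector
function sir_drift(u, β, c, γ)
    S, I, R = u
    N = S + I + R
    infection = β * c * S * I / N   # rate of S → I transitions
    recovery = γ * I                # rate of I → R transitions
    return (-infection, infection - recovery, recovery)
end

du = sir_drift((990.0, 10.0, 0.0), 0.05, 10.0, 0.25)
# du ≈ (-4.95, 2.45, 2.5)
```

The three derivatives sum to zero, so the drift alone keeps $N$ constant. Also, $dI/dt > 0$ at the start, so the infected class initially grows.
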
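
The noise enters through `Bridge.σ`, which returns a 3×2 matrix with one column per independent noise source. The first column scales the Wiener increment driving infection (S to I) by the square root of the infection rate. The second column scales the increment driving recovery (I to R) by the square root of the recovery rate.

```julia
function Bridge.σ(t, u, P::SIR)
    (S,I,R) = u
    N = S + I + R
    ifrac = P.β*P.c*I/N*S
    rfrac = P.γ*I
    return @SMatrix Float64[
        sqrt(ifrac) 0.0
        -sqrt(ifrac) -sqrt(rfrac)
        0.0 sqrt(rfrac)
    ]
end
```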
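
As a check on the diffusion matrix, here is a package-free sketch of it evaluated at the initial conditions. It uses an illustrative function name (`sir_diffusion`) and a plain `Matrix` in place of an `SMatrix`.

```julia
using LinearAlgebra

# Package-free copy of the diffusion matrix above (3 states × 2 noise sources)
function sir_diffusion(u, β, c, γ)
    S, I, R = u
    N = S + I + R
    ifrac = β * c * I / N * S   # infection rate
    rfrac = γ * I               # recovery rate
    return [sqrt(ifrac) 0.0
            -sqrt(ifrac) -sqrt(rfrac)
            0.0 sqrt(rfrac)]
end

σ = sir_diffusion((990.0, 10.0, 0.0), 0.05, 10.0, 0.25)
colsums = sum(σ, dims=1)   # each column sums to zero
Σ = σ * σ'                 # instantaneous covariance rate of the (S, I, R) increments
```

Because each column sums to zero, the noise moves individuals between compartments without changing $N$. The diagonal entry `Σ[2, 2]` equals $\beta c S I / N + \gamma I$. It is the variance rate of $I$ and combines both noise sources.
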

## Time domain

```julia
δt = 0.1
tmax = 40.0
tspan = (0.0,tmax)
ts = 0.0:δt:tmax;
```
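
A quick check on the grid this produces. The snippet repeats the definitions so that it runs on its own.

```julia
δt = 0.1
tmax = 40.0
ts = 0.0:δt:tmax

npoints = length(ts)   # 401 grid points, i.e. 400 steps of size δt
tlast = last(ts)       # 40.0: the range includes tmax
```

Objects built on `ts`, such as the data frame of results below, should therefore have 401 rows.
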

## Initial conditions

```julia
u0 = @SVector [990.0,10.0,0.0]; # S,I,R
```

## Parameter values

```julia
p = [0.05,10.0,0.25]; # β,c,γ
```
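
Transmission enters the model only through the product $\beta c$. The basic reproduction number of the corresponding deterministic SIR model is therefore $R_0 = \beta c / \gamma$, and the mean infectious period is $1/\gamma$. A minimal check with the values above:

```julia
# Parameters in the order used above: β, c, γ
p = [0.05, 10.0, 0.25]
β, c, γ = p

R0 = β * c / γ   # basic reproduction number: ≈ 2.0
D = 1 / γ        # mean infectious period: 4.0 time units
```

Here $R_0 = 2 > 1$ and $R_0 S(0)/N = 1.98 > 1$, so in the deterministic model the number infected grows at first rather than declining.
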

## Random number seed

```julia
Random.seed!(1234);
```
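
Seeding the global random number generator makes the noise, and hence the simulated trajectory, reproducible within a session. A minimal illustration using only the `Random` standard library:

```julia
using Random

Random.seed!(1234)
a = randn(3)
Random.seed!(1234)
b = randn(3)

a == b   # true: resetting the seed reproduces the same draws
```

A given seed is not guaranteed to produce the same stream across Julia versions, because Julia 1.7 changed the default generator. Exact trajectories may therefore differ from those produced when this example was written.
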
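
## Running the model

Set up the process object from the parameter vector.

```julia
prob = SIR(p...);
```

Generate noise: a sample path of a two-dimensional Wiener process on the grid `ts`, with one dimension per column of `σ`.

```julia
W = sample(ts, Wiener{SVector{2,Float64}}());
```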
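
Conceptually, a sample path of a $d$-dimensional Wiener process on a grid starts at zero and is a cumulative sum of independent Gaussian increments, each with variance equal to the step size. The sketch below builds such a path by hand. `wiener_path` is a hypothetical helper, not part of `Bridge.jl`, and it stores the path as a plain matrix rather than a `Bridge` sample path.

```julia
using Random

# Hypothetical helper: d-dimensional Wiener path on the grid ts.
# Row k holds W(ts[k]); the path starts at zero.
function wiener_path(rng, ts, d)
    W = zeros(length(ts), d)
    for k in 2:length(ts)
        dt = ts[k] - ts[k-1]
        # independent N(0, dt) increment in each dimension
        W[k, :] = W[k-1, :] .+ sqrt(dt) .* randn(rng, d)
    end
    return W
end

rng = MersenneTwister(1234)
W = wiener_path(rng, 0.0:0.1:40.0, 2)   # 401 × 2 matrix
```
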

Solve the SDE with the Euler-Maruyama scheme, driven by the noise path `W`.

```julia
sol = solve(Bridge.EulerMaruyama(), u0, W, prob);
```
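
The Euler-Maruyama scheme updates the state as $u_{k+1} = u_k + b(t_k, u_k)\,\Delta t + \sigma(t_k, u_k)\,\Delta W_k$. The following package-free sketch applies that update to the same SIR drift and diffusion. It illustrates the scheme and is not `Bridge.jl`'s implementation. The function name `sir_em` is hypothetical. Clamping the rates at zero inside `sqrt` is an added safeguard, not part of the original model, so that a state drifting slightly below zero does not raise a `DomainError`.

```julia
using Random

# Hypothetical Euler-Maruyama sketch for the SIR SDE; rows of U are time points.
function sir_em(rng, u0, ts, β, c, γ)
    U = zeros(length(ts), 3)
    U[1, :] = u0
    for k in 1:length(ts)-1
        dt = ts[k+1] - ts[k]
        S, I, R = U[k, 1], U[k, 2], U[k, 3]
        N = S + I + R
        infection = β * c * S * I / N
        recovery = γ * I
        dW1, dW2 = sqrt(dt) * randn(rng), sqrt(dt) * randn(rng)
        # clamp at zero before sqrt (safeguard added in this sketch)
        si, sr = sqrt(max(infection, 0.0)), sqrt(max(recovery, 0.0))
        # u[k+1] = u[k] + drift * dt + (diffusion matrix) * [dW1, dW2]
        U[k+1, 1] = S - infection * dt + si * dW1
        U[k+1, 2] = I + (infection - recovery) * dt - si * dW1 - sr * dW2
        U[k+1, 3] = R + recovery * dt + sr * dW2
    end
    return U
end

U = sir_em(MersenneTwister(1234), [990.0, 10.0, 0.0], 0.0:0.1:40.0, 0.05, 10.0, 0.25)
Nend = sum(U[end, :])   # ≈ 1000.0
```

The drift and both noise columns sum to zero, so `S + I + R` is conserved up to rounding error on every path. This makes a convenient test of any implementation.
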
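
## Post-processing

We can convert the output to a dataframe for convenience. `Bridge.mat` stacks the states in `sol.yy` into a matrix, which is transposed so that each row is a time point. The `:auto` argument, which recent versions of DataFrames.jl require when building a data frame from a matrix, names the columns `x1`, `x2` and `x3` (S, I and R).

```julia
df_sde = DataFrame(Bridge.mat(sol.yy)', :auto)
df_sde[!,:t] = ts;
```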
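
The transpose in the code above suggests that `Bridge.mat` stacks the state vectors as columns. The package-free sketch below works under that assumption, with `reduce(hcat, yy)` standing in for `Bridge.mat`. It uses a made-up four-step trajectory and then performs a typical post-processing step: finding the peak number infected and when it occurs.

```julia
# Made-up trajectory for illustration only: one (S, I, R) vector per time point
ts = [0.0, 0.1, 0.2, 0.3]
yy = [[990.0, 10.0, 0.0], [985.0, 13.0, 2.0], [981.0, 14.0, 5.0], [979.0, 12.0, 9.0]]

# Stack states as columns (3 × 4), then transpose so rows are time points (4 × 3)
M = permutedims(reduce(hcat, yy))

# Peak prevalence and the time at which it occurs
ipeak = argmax(M[:, 2])
peak = (ts[ipeak], M[ipeak, 2])   # (0.2, 14.0)
```
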

## Plotting

We can now plot the results.

```julia
@df df_sde plot(:t,
    [:x1 :x2 :x3],
    label=["S" "I" "R"],
    xlabel="Time",
    ylabel="Number")
```
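
## Benchmarking

The global variables are interpolated with `$`, so that the benchmark does not also measure the cost of accessing untyped globals.

```julia
@benchmark begin
    W = sample($ts, Wiener{SVector{2,Float64}}());
    solve(Bridge.EulerMaruyama(), $u0, W, $prob);
end
```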