@sdwfrost
Created March 14, 2021 15:50
SIR example in Bridge.jl
# Stochastic differential equation model using Bridge.jl
Simon Frost (@sdwfrost), 2021-03-13
## Introduction
A stochastic differential equation version of the SIR model is:
- Stochastic
- Continuous in time
- Continuous in state
This implementation uses `Bridge.jl`, and is adapted from [this post on parameter inference for a simple SIR model](http://www.math.chalmers.se/~smoritz/journal/2018/01/19/parameter-inference-for-a-simple-sir-model/).
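For orientation, an SDE of this kind can be written in the general form below. The symbols $X_t$ (the state) and $W_t$ (a vector of independent Wiener processes) are notation introduced here for illustration; $b$ is the drift and $\sigma$ the diffusion matrix, which correspond to the `Bridge.b` and `Bridge.σ` methods defined later in this document.

```math
\mathrm{d}X_t = b(t, X_t)\,\mathrm{d}t + \sigma(t, X_t)\,\mathrm{d}W_t,
\qquad X_t = (S_t, I_t, R_t)
```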
## Libraries
```julia
using Bridge
using StaticArrays
using Random
using DataFrames
using StatsPlots
using BenchmarkTools
```
## Transitions
`Bridge.jl` uses structs and multiple dispatch, so I first define a struct that is a subtype of `Bridge.ContinuousTimeProcess`. The type parameter gives the state type (a static vector of three `Float64` values, one per compartment), and the fields hold the parameter values and their types.
```julia
struct SIR <: ContinuousTimeProcess{SVector{3,Float64}}
    β::Float64
    c::Float64
    γ::Float64
end
```
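The pattern of fixing the state type in an abstract supertype and attaching behaviour through dispatch can be sketched in plain Julia. This is only an illustration of the idiom: `ToyProcess`, `Decay`, `statetype`, and `toy_drift` are hypothetical names, not part of `Bridge.jl`.

```julia
# Hypothetical stand-in for an abstract process type parameterised by its state type.
abstract type ToyProcess{T} end

# A concrete model: the state type is fixed in the supertype, parameters are fields.
struct Decay <: ToyProcess{Float64}
    rate::Float64
end

# Generic code can recover the state type from the abstract supertype...
statetype(::ToyProcess{T}) where {T} = T

# ...while model-specific behaviour is added by dispatching on the concrete type.
toy_drift(t, u, P::Decay) = -P.rate * u

P = Decay(0.5)
statetype(P)            # Float64
toy_drift(0.0, 2.0, P)  # -1.0
```

The `SIR` struct above plays the role of `Decay` here, with `SVector{3,Float64}` as the state type.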
I now define a method of `Bridge.b`, the drift, which takes the time, the current state, and this struct, and returns a static vector (`@SVector`) of the deterministic rates of change of S, I, and R.
```julia
function Bridge.b(t, u, P::SIR)
    (S,I,R) = u
    N = S + I + R
    dS = -P.β*P.c*S*I/N
    dI = P.β*P.c*S*I/N - P.γ*I
    dR = P.γ*I
    return @SVector [dS,dI,dR]
end
```
The noise enters through `Bridge.σ`, which returns a 3×2 static matrix (`@SMatrix`): the first column scales the Wiener process driving infection, and the second the one driving recovery.
```julia
function Bridge.σ(t, u, P::SIR)
    (S,I,R) = u
    N = S + I + R
    ifrac = P.β*P.c*I/N*S   # infection rate
    rfrac = P.γ*I           # recovery rate
    return @SMatrix Float64[
         sqrt(ifrac)   0.0
        -sqrt(ifrac)  -sqrt(rfrac)
         0.0           sqrt(rfrac)
    ]
end
```
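As a sanity check on the diffusion matrix, the sketch below rebuilds it in plain Julia (ordinary arrays rather than `StaticArrays`) at the initial state and parameter values used in this document. It shows two properties that follow from the matrix as written: each column sums to zero, so the noise moves individuals between compartments without changing the total, and the diagonal of `σ * σ'` equals the total event rate affecting each compartment.

```julia
# Values mirror the initial state and parameters used in this document.
β, c, γ = 0.05, 10.0, 0.25
S, I, R = 990.0, 10.0, 0.0
N = S + I + R
ifrac = β * c * S * I / N   # infection rate
rfrac = γ * I               # recovery rate

σ = [ sqrt(ifrac)   0.0
     -sqrt(ifrac)  -sqrt(rfrac)
      0.0           sqrt(rfrac)]

# Each column sums to zero: the noise conserves S + I + R.
column_sums = vec(sum(σ, dims = 1))

# Instantaneous covariance of the noise per unit time.
Σ = σ * σ'
```

Because a Wiener process is symmetric about zero, flipping the sign of a whole column of `σ` leaves `Σ`, and hence the distribution of the solution, unchanged.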
## Time domain
```julia
δt = 0.1
tmax = 40.0
tspan = (0.0,tmax)
ts = 0.0:δt:tmax;
```
## Initial conditions
```julia
u0 = @SVector [990.0,10.0,0.0]; # S,I,R
```
## Parameter values
```julia
p = [0.05,10.0,0.25]; # β,c,γ
```
## Random number seed
```julia
Random.seed!(1234);
```
## Running the model
Set up the model object by splatting the parameter vector into the `SIR` constructor.
```julia
prob = SIR(p...);
```
Generate a two-dimensional Wiener process sampled on the time grid.
```julia
W = sample(ts, Wiener{SVector{2,Float64}}());
```
Solve the SDE with the Euler-Maruyama scheme, driven by the sampled noise.
```julia
sol = solve(Bridge.EulerMaruyama(), u0, W, prob);
```
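To make the Euler-Maruyama step concrete, here is a minimal hand-rolled sketch in plain Julia. It is not `Bridge.jl`'s implementation: the names `drift`, `diffusion`, and `euler_maruyama` are hypothetical, ordinary arrays replace static arrays, the Wiener increments are drawn inside the loop rather than taken from a pre-sampled path, and the clamp at zero inside `diffusion` is a guard added for this sketch so that `sqrt` never receives a negative rate.

```julia
using Random

# Deterministic rates of change of (S, I, R).
function drift(u, β, c, γ)
    S, I, R = u
    N = S + I + R
    infection = β * c * S * I / N
    recovery = γ * I
    return [-infection, infection - recovery, recovery]
end

# 3×2 diffusion matrix; rates are clamped at zero (an assumption of this sketch).
function diffusion(u, β, c, γ)
    S, I, R = u
    N = S + I + R
    infection = max(β * c * S * I / N, 0.0)
    recovery = max(γ * I, 0.0)
    return [ sqrt(infection)   0.0
            -sqrt(infection)  -sqrt(recovery)
             0.0               sqrt(recovery)]
end

# One Euler-Maruyama step per grid interval: u += b(u)*dt + σ(u)*dW.
function euler_maruyama(u0, ts, p; rng = Random.default_rng())
    us = [float.(u0)]
    for k in 1:length(ts)-1
        dt = ts[k+1] - ts[k]
        dW = sqrt(dt) .* randn(rng, 2)   # Wiener increments ~ Normal(0, dt)
        u = us[end]
        push!(us, u .+ drift(u, p...) .* dt .+ diffusion(u, p...) * dW)
    end
    return us
end

ts_sketch = 0.0:0.1:40.0
path = euler_maruyama([990.0, 10.0, 0.0], ts_sketch, (0.05, 10.0, 0.25);
                      rng = MersenneTwister(1234))
```

The result has one state per grid point, and because both the drift and each column of the diffusion matrix sum to zero, the total population stays at 1000 up to floating-point error.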
## Post-processing
We can convert the output to a dataframe for convenience.
```julia
df_sde = DataFrame(Bridge.mat(sol.yy)', :auto) # :auto names the columns x1, x2, x3
df_sde[!,:t] = ts;
```
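The reshaping step can be illustrated without `Bridge.jl`. The sketch below assumes, as the code above suggests, that the solution is stored as a vector of state vectors; it stacks them as columns and transposes, giving one row per time point and one column per compartment. The data and the name `states_by_time` are made up for illustration.

```julia
# Four time points of a three-state path, stored as a vector of state vectors.
yy = [[990.0, 10.0, 0.0],
      [989.0, 10.5, 0.5],
      [988.0, 11.0, 1.0],
      [987.0, 11.5, 1.5]]

# Stack the states as columns (3 × 4), then transpose to 4 × 3.
states_by_time = permutedims(reduce(hcat, yy))
```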
## Plotting
We can now plot the results.
```julia
@df df_sde plot(:t,
    [:x1 :x2 :x3],
    label=["S" "I" "R"],
    xlabel="Time",
    ylabel="Number")
```
## Benchmarking
```julia
# Interpolate globals with `$` so that global-variable lookup is not part of the timing.
@benchmark begin
    W = sample($ts, Wiener{SVector{2,Float64}}())
    solve(Bridge.EulerMaruyama(), $u0, W, $prob)
end
```