Skip to content

Instantly share code, notes, and snippets.

@ChrisRackauckas
Last active April 9, 2024 00:57
Show Gist options
  • Save ChrisRackauckas/62a063f23cccf3a55a4ac9f6e497739a to your computer and use it in GitHub Desktop.
Save ChrisRackauckas/62a063f23cccf3a55a4ac9f6e497739a to your computer and use it in GitHub Desktop.
DiffEqFlux.jl (Julia) vs Jax on an Epidemic Model

DiffEqFlux.jl (Julia) vs Jax on an Epidemic Model

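The JAX developers optimized a differential-equation benchmark in this issue, which used DiffEqFlux.jl as a performance baseline. The Julia code from that issue was updated with some standard performance tricks and is the benchmark code here, so both codes have been optimized by their respective library developers. Note that the two scripts do not solve exactly the same problem: the JAX network has hidden widths 14, 14, 7 while the Julia network uses 7, 7, 7, and the JAX loss weights the death term by 20000 while the Julia loss weights it by 1.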

Results

Forward Pass

  • Julia: 75.350 ms
  • Jax: 91.1 ms

Relative Julia speed improvement: about 21% (Julia is about 1.21x faster)

Gradient Calculation

  • Julia: 3.726 s
  • Jax: 24.6 s

Relative Julia speed improvement: about 560% (Julia is about 6.6x faster)

Data Files

The NPZ datafiles required for recreating the benchmarks can be found at: https://github.com/ChrisRackauckas/coupled-epidemiology-models

# ---- JAX version of the benchmark ----
import jax.profiler
# Start a profiler server on port 9999 so an external profiler client can attach.
jax.profiler.start_server(9999)
import numpy as onp
import jax.numpy as jnp
from functools import partial
from jax import random
from jax.nn.initializers import (xavier_normal, xavier_uniform, glorot_normal, glorot_uniform, uniform,
                                 normal, lecun_uniform, lecun_normal, kaiming_uniform, kaiming_normal)
from jax.nn import (softplus, selu, gelu, glu, swish, relu, relu6, elu, sigmoid)
# `partial` comes from functools above. Do not import it from `jax`: `jax.partial`
# was only a re-export of functools.partial and is absent from newer JAX releases,
# where importing it raises ImportError.
from jax import vmap, grad, pmap, value_and_grad, jit
from jax.experimental.ode import odeint

# Input data (.npy files from the coupled-epidemiology-models repository), read
# with NumPy from the current directory and converted to JAX arrays.
coupling_matrix_ = onp.load('./coupling_matrix.npy')
epi_array_ = onp.load('./epi_array.npy')
mobilitypopulation_array_scaled_ = onp.load('./mobilitypopulation_array_scaled.npy')
coupling_matrix = jnp.asarray(coupling_matrix_)
epi_array = jnp.asarray(epi_array_)
mobilitypopulation_array_scaled = jnp.asarray(mobilitypopulation_array_scaled_)
def inv_softplus(x):
    """Invert softplus: return y such that ``softplus(y) == x`` (requires x > 0).

    Uses the identity ``log(exp(x) - 1) = x + log(1 - exp(-x))`` and evaluates
    ``1 - exp(-x)`` with ``expm1`` so the result stays accurate for small ``x``.
    """
    one_minus_exp = -jnp.expm1(-x)  # 1 - exp(-x) without cancellation
    return x + jnp.log(one_minus_exp)
# Module-level PRNG key; used below to draw the county scaling factors.
key = random.PRNGKey(0)
# MLP widths: 7 inputs -> hidden 14, 14, 7 -> 1 output (the per-county transmission rate).
layers = [7, 14, 14, 7, 1]
# One activation per weight layer; the final softplus keeps the output positive.
activations = [swish, swish, swish, softplus]
# Initializer factories; init_layers calls them to obtain the init functions.
weight_initializer = kaiming_uniform
bias_initializer = normal
def init_layers(nn_layers, nn_weight_initializer_, nn_bias_initializer_):
    """Build the flat parameter vector for the MLP evaluated by ``nnet``.

    For every consecutive pair ``(in_, out_)`` in ``nn_layers``, this draws an
    ``in_ * out_`` weight block followed by an ``out_`` bias block. All blocks
    are concatenated in layer order into one 1-D array.

    Args:
        nn_layers: layer widths, e.g. ``[7, 14, 14, 7, 1]``.
        nn_weight_initializer_: factory returning a JAX initializer
            ``init(key, shape)`` used for the weight blocks.
        nn_bias_initializer_: factory returning a JAX initializer used for the
            bias blocks.

    Returns:
        1-D array of length ``sum(in_ * out_ + out_)`` over all layers.
    """
    # Use the arguments rather than the module-level globals, so callers that
    # pass other widths or initializers get what they asked for.
    init_w = nn_weight_initializer_()
    init_b = nn_bias_initializer_()
    params = []
    for in_, out_ in zip(nn_layers[:-1], nn_layers[1:]):
        # NOTE(review): the key depends only on the fan-in. Layers sharing an input
        # width, and each layer's weights and biases, therefore draw from the same
        # key. Kept unchanged so the initial parameters stay reproducible.
        key = random.PRNGKey(in_)
        weights = init_w(key, (in_, out_)).reshape((in_ * out_,))
        biases = init_b(key, (out_,))
        params.append(jnp.concatenate((weights, biases)))
    return jnp.concatenate(params)
def nnet(nn_layers, nn_activations, nn_params, x):
    """Evaluate a fully connected network stored as one flat parameter vector.

    Args:
        nn_layers: layer widths ``[n_in, h1, ..., n_out]``.
        nn_activations: one activation function per weight layer.
        nn_params: flat vector laid out per layer as ``out * in`` weights (read
            as an ``(out, in)`` matrix) followed by ``out`` biases.
        x: input vector of length ``nn_layers[0]``.

    Returns:
        Column vector of shape ``(nn_layers[-1], 1)`` for a 1-D ``x``.
    """
    h = x[:, None]  # treat the input as a column vector
    offset = 0
    widths = zip(nn_layers[:-1], nn_layers[1:])
    for (fan_in, fan_out), activation in zip(widths, nn_activations):
        w_end = offset + fan_in * fan_out
        b_end = w_end + fan_out
        w = nn_params[offset:w_end].reshape((fan_out, fan_in))
        b = nn_params[w_end:b_end][:, None]
        h = activation(w @ h + b)
        offset = b_end
    return h
# Single-input network, compiled with jit (not used elsewhere in this script).
nn = jit(partial(nnet, layers,activations))
# Batched network: maps over axis 0 of the input with shared params and stacks the
# results on axis 0, giving (batch, 1, 1) for this 1-output network.
nn_batch = vmap(partial(nnet,layers,activations), (None,0),0)
# Unbatched alternative (unused):
#nn_batch=partial(nnet, layers,activations)
# Flat NN parameter vector in the layout nnet expects.
p_net = init_layers(layers,weight_initializer,bias_initializer)
# county-wise learnable scaling factors
n_counties = coupling_matrix.shape[0]
init_b = bias_initializer()
# NOTE(review): softplus is applied here AND again when the RHS reads these entries
# (softplus(p_[3:3+n_c])), so the effective scale is softplus(softplus(.)).
# Confirm this double transform is intended.
p_scaling = softplus(200*init_b(key,(n_counties,)))
def SEIRD_mobility_coupled(u, t, p_, mobility_, coupling_matrix_):
    """Right-hand side of the mobility-coupled SEIRD model, in ``odeint`` form.

    Args:
        u: stacked state of 16 compartments along the first axis, in the order
            s, e, id1..id7, d, ir1..ir5, r (each entry is a per-county vector).
        t: model time. It is rounded to the nearest integer to pick the day
            slice (last axis) of ``mobility_`` and ``coupling_matrix_``.
        p_: flat parameters ``[3 rates | n_c county scalers | NN weights]`` with
            ``n_c = coupling_matrix_.shape[0]``.
        mobility_: mobility features indexed ``[..., day]``. ``nn_batch`` maps over
            the first axis of each day slice (presumably one row per county).
        coupling_matrix_: county-to-county coupling indexed ``[..., day]``.

    Returns:
        ``jnp.stack`` of the 16 time derivatives, in the same order as ``u``.
    """
    (s, e,
     id1, id2, id3, id4, id5, id6, id7, _d,
     ir1, ir2, ir3, ir4, ir5, _r) = u

    # Global rates, constrained positive via softplus.
    kappa, alpha, gamma = softplus(p_[:3])
    # κ*α and γ*η are not independent: the probabilities of transition from e to
    # the Id and Ir chains have to add up to 1, which fixes η given κ, α and γ.
    eta = -jnp.log(-jnp.expm1(-kappa * alpha)) / (gamma + 1.0e-8)

    # The nearest whole day selects the time-varying inputs.
    day = jnp.rint(t.astype(jnp.float32)).astype(jnp.int32)
    n_c = coupling_matrix_.shape[0]
    county_scale = softplus(p_[3:3 + n_c])
    # Row j of the day's coupling matrix is scaled by county j's factor.
    cm = county_scale[:, None] * coupling_matrix_[..., day]
    beta = nn_batch(p_[3 + n_c:], mobility_[..., day])[:, 0, 0]

    infectious = id1 + id2 + id3 + ir1 + ir2 + ir3 + ir4 + ir5
    # NOTE(review): `infectious @ cm.T` equals `cm @ infectious`, so the coupling
    # term is 2 * cm @ infectious. If cm @ i + cm.T @ i was intended, the first
    # product should be `infectious @ cm`. Kept as-is to match the Julia version.
    force = beta * s * infectious + beta * s * (infectious @ cm.T + cm @ infectious)

    def _chain(rate, inflow, stages):
        # Linear chain: each stage drains at `rate` into the next, and the last
        # stage feeds the absorbing compartment (d or r).
        upstream = inflow
        rates = []
        for stage in stages:
            rates.append(rate * (upstream - stage))
            upstream = stage
        rates.append(rate * upstream)
        return rates

    death_path = _chain(kappa, alpha * e, (id1, id2, id3, id4, id5, id6, id7))
    recovery_path = _chain(gamma, eta * e, (ir1, ir2, ir3, ir4, ir5))
    return jnp.stack([
        -force,
        force - kappa * alpha * e - gamma * eta * e,
        *death_path,
        *recovery_path,
    ])
# Initial conditions
# Populations are normalised to 1 per county. Initial cases (ic0) are placed in e;
# recovered are back-computed from initial deaths via the ifr. The ic0 mass is also
# spread over the id chain (fraction ifr, 7 stages) and the ir chain (fraction
# 1 - ifr, 5 stages).
# NOTE(review): s0 subtracts ic0 only once, yet ic0 appears in e0 and again across
# the id/ir chains, so the compartments sum to n + ic0 rather than n. Confirm.
ifr = 0.007
# NOTE(review): redefines n_counties (first set from coupling_matrix.shape[0]);
# the two are assumed to agree.
n_counties = epi_array.shape[2]
n = jnp.tile(1.0,(n_counties,))
ic0 = epi_array[0,0,:]
d0 = epi_array[0,1,:]
r0 = d0/ifr
s0 = n-ic0-r0-d0
e0 = ic0
id10=id20=id30=id40=id50=id60=id70=ic0*ifr/7.0
ir10=ir20=ir30=ir40=ir50=ic0*(1.0-ifr)/5.0
# u0 has shape (16, n_counties): one row per compartment, in the order
# SEIRD_mobility_coupled unpacks them.
u0 = jnp.array([s0,
e0,
id10,id20,id30,id40,id50, id60, id70, d0,
ir10,ir20,ir30,ir40,ir50,r0])
# ODE Parameters
# Initial guesses for κ, α and γ. β0_, tb_ and β1_ are not referenced elsewhere in
# this script.
κ0_ = 0.97
α0_ = 0.00185
β0_ = 0.5
tb_ = 15
β1_ = 0.4
γ0_ = 0.24
# Stored in softplus-preimage space because the RHS applies softplus to p_[:3].
p_ode = inv_softplus(jnp.array([κ0_, α0_, γ0_]))
# Initial model parameters
# Layout [3 rates | n_counties scalers | NN weights], matching the slicing in
# SEIRD_mobility_coupled.
p_init = jnp.concatenate((p_ode,p_scaling,p_net))
# T + 1 evenly spaced save times 0, 1, ..., T, where T = epi_array.shape[0].
# NOTE(review): at t = T the RHS indexes day T of its inputs. If their time axis has
# length T, that index is out of range, and JAX clamps out-of-bounds gathers to the
# last element instead of raising. The Julia script integrates only up to its
# time length minus one.
t0 = jnp.linspace(0, float(epi_array.shape[0]), int(epi_array.shape[0])+1)
# LOSS Function
def diff(sol_, data_):
    """Per-day squared residuals for a single county.

    Args:
        sol_: one county's solution indexed ``[time, compartment]``; compartment
            0 is s and compartment 9 is d.
        data_: observations indexed ``[day, series]``. Series 0 is compared with
            the daily increase of ``1 - s`` and series 1 with the daily increase
            of ``d``; it should have ``sol_.shape[0] - 1`` rows.

    Returns:
        Array of case residual² + 20000 × death residual², one entry per day.
    """
    case_err = jnp.ediff1d(1 - sol_[:, 0]) - data_[:, 0]
    death_err = jnp.ediff1d(sol_[:, 9]) - data_[:, 1]
    return jnp.square(case_err) + 20000 * jnp.square(death_err)


# Vectorise over counties, which sit on axis 2 of both the solution and the data.
diff_v = vmap(diff, in_axes=(2, 2))
def loss(data_, m_array_, coupling_matrix_, params_):
    """Return the scalar training loss for the parameter vector ``params_``.

    Integrates ``SEIRD_mobility_coupled`` from the module-level initial state
    ``u0`` over the time grid ``t0`` (rtol=1e-4, atol=1e-8). It then sums the
    per-county, per-day squared residuals computed by ``diff_v``. The data
    arguments come first so they can be bound with ``partial`` below.
    """
    trajectory = odeint(
        SEIRD_mobility_coupled,
        u0,
        t0,
        params_,
        m_array_,
        coupling_matrix_,
        rtol=1e-4,
        atol=1e-8,
    )
    residuals = diff_v(trajectory, data_)
    return jnp.sum(residuals)


# Bind the fixed data so the loss depends only on the parameters, then compile
# the forward loss and its gradient.
loss_ = partial(loss, epi_array, mobilitypopulation_array_scaled, coupling_matrix)
loss_jit = jit(loss_)
grad_jit = jit(grad(loss_))
# IPython/Jupyter `%timeit` magics: this section runs only in an IPython session,
# not as a plain Python script. `.block_until_ready()` waits for JAX's asynchronous
# dispatch, so the timing covers the actual computation. The recorded results follow
# each call.
%timeit loss_jit(p_init).block_until_ready()
# 91.1 ms ± 4.13 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit grad_jit(p_init).block_until_ready()
# 24.6 s ± 1.13 s per loop (mean ± std. dev. of 7 runs, 1 loop each)
# ---- Julia (DiffEqFlux.jl) version of the benchmark ----
# Run from the script's directory so the relative .npy paths below resolve.
cd(@__DIR__)
using NPZ
using DiffEqSensitivity, DiffEqFlux, OrdinaryDiffEq, Plots
using Random, Distributions
using BenchmarkTools
using Zygote
using ComponentArrays
using Parameters: @unpack
using Octavian
coupling_matrix = npzread("./coupling_matrix.npy")
epi_array = npzread("./epi_array.npy")
mobilitypopulation_array_scaled = npzread("./mobilitypopulation_array_scaled.npy")
# batch dimension for FastDense layer is the second dimension
# Reverse the axes of epi_array. Later code indexes it as [county, series, day]
# (series 1 = cases, 2 = deaths — presumably; confirm against the data).
epi_array = permutedims(epi_array,[3,2,1])
# Swap the first two axes so mobility is [feature, county, day] for the network.
mobilitypopulation_array_scaled = permutedims(mobilitypopulation_array_scaled,[2,1,3]);
println(size(coupling_matrix))
println(size(epi_array))
println(size(mobilitypopulation_array_scaled))
# Network mapping 7 mobility features per county to a positive transmission rate
# (softplus output); counties form the batch (second) dimension of the input.
# NOTE(review): the JAX script uses layers [7, 14, 14, 7, 1], whereas this chain is
# 7 -> 7 -> 7 -> 7 -> 1, so the two benchmarks do not train the same-size network.
nn = FastChain(FastDense(7, 7, swish),
FastDense(7, 7, swish),
FastDense(7, 7, swish),
FastDense(7, 1, softplus));
# destructure neural network parameters into a vector. Conveniently, this matches with what we do in the JAX code
p0_nnet = initial_params(nn)
# n_weights is not used later in this script.
n_weights = length(p0_nnet);
# Evaluate once on the first day's mobility slice; the result is discarded
# (presumably a smoke test / compilation warm-up).
nn(mobilitypopulation_array_scaled[:,:,1], p0_nnet);
Random.seed!(123)
d = Normal()
n_counties = size(coupling_matrix,1)
# One standard-normal draw per county for the coupling scalers (softplus is applied
# inside the RHS). NOTE(review): the JAX script initialises these differently
# (softplus(200 * normal-initializer draw)), so the starting points differ.
p0_scaler = rand(d, n_counties);
# Build the in-place SEIRD right-hand side `f!(du, u, p, t)` as a closure over the
# mobility data, the time-varying coupling matrices and the neural network `nn_`.
# State layout (from the views below): `u` and `du` are n_counties × 16 matrices whose
# columns are s, e, id1..id7, d, ir1..ir5, r.
function SEIRD_mobility_coupled_outer(mobility_, coupling_matrix_, nn_)
# Preallocated buffer (a copy of the first coupling slice), reused for the scaled
# coupling matrix on the plain-float path. Every call of the returned closure shares
# it, so concurrent solves with the same closure would race on this buffer.
cm_cache = coupling_matrix_[:,:,1]
function SEIRD_mobility_coupled_inner(du,
u,
p_, t)
s = @view u[:,1]
e = @view u[:,2]
id1 = @view u[:,3]
id2 = @view u[:,4]
id3 = @view u[:,5]
id4 = @view u[:,6]
id5 = @view u[:,7]
id6 = @view u[:,8]
id7 = @view u[:,9]
d = @view u[:,10]
ir1 = @view u[:,11]
ir2 = @view u[:,12]
ir3 = @view u[:,13]
ir4 = @view u[:,14]
ir5 = @view u[:,15]
r = @view u[:,16]
ds = @view du[:,1]
de = @view du[:,2]
did1 = @view du[:,3]
did2 = @view du[:,4]
did3 = @view du[:,5]
did4 = @view du[:,6]
did5 = @view du[:,7]
did6 = @view du[:,8]
did7 = @view du[:,9]
dd = @view du[:,10]
dir1 = @view du[:,11]
dir2 = @view du[:,12]
dir3 = @view du[:,13]
dir4 = @view du[:,14]
dir5 = @view du[:,15]
dr = @view du[:,16]
# Day index into the time-varying inputs: t is 0-based model time and Julia arrays
# are 1-based, so t rounding to k selects slice k+1.
coupling = @view coupling_matrix_[:,:,Int32(round(t+1.0))]
cur_mobility = @view mobility_[:,:,Int32(round(t+1.0))]
# Global rates, kept positive through softplus.
κ, α, γ = softplus.(p_[1:3])
# κ*α and γ*η are not independent. The probability of transition from e to Ir and Id has to add up to 1
η = - log(-expm1(-κ*α))/(γ+1.0e-8)
n_c = size(coupling_matrix_,1)
# Parameter layout: p_ = [3 rates; n_c per-county coupling scalers; NN weights].
scaler_ = softplus.(p_[4:3+n_c])
p_nnet = p_[4+n_c:end]
# Per-county transmission rate from that day's mobility features.
β = vec(nn_(cur_mobility, p_nnet))
# Total infectious population per county.
i = @. id1+id2+id3+ir1+ir2+ir3+ir4+ir5
# Plain-float path: scale into the preallocated buffer and use Octavian's `matmul`.
# Otherwise (presumably tracked numbers during AD — confirm) allocate a fresh
# matrix, because the float buffer cannot hold those element types.
# NOTE(review): vec(i' * cm') equals cm * i, so c1 == c2 and the coupling term is
# 2 * cm * i. If cm * i + cm' * i was intended, c1 should be
# vec(reshape(i,1,:) * cm). The JAX version has the same form.
if eltype(u) <: AbstractFloat
cm_cache .= scaler_ .* coupling
c1 = vec(matmul(reshape(i,1,:),transpose(cm_cache)))
c2 = matmul(cm_cache,i)
else
cm_ = scaler_ .* coupling
c1 = vec(reshape(i,1,:)*transpose(cm_))
c2 = cm_*i
end
# Force of infection: local term plus the mobility-coupled term.
a = @. β * s * i + β * s * (c1+c2)
@. ds = -a
@. de = a - κ*α*e - γ*η*e
# id chain (7 stages at rate κ) draining into d.
@. did1 = κ*(α*e-id1)
@. did2 = κ*(id1-id2)
@. did3 = κ*(id2-id3)
@. did4 = κ*(id3-id4)
@. did5 = κ*(id4-id5)
@. did6 = κ*(id5-id6)
@. did7 = κ*(id6-id7)
@. dd = κ*id7
# ir chain (5 stages at rate γ) draining into r.
@. dir1 = γ*(η*e-ir1)
@. dir2 = γ*(ir1-ir2)
@. dir3 = γ*(ir2-ir3)
@. dir4 = γ*(ir3-ir4)
@. dir5 = γ*(ir4-ir5)
@. dr = γ*ir5
end
end
# Concrete in-place RHS used by the ODE problem below.
SEIRD_mobility_coupled = SEIRD_mobility_coupled_outer(mobilitypopulation_array_scaled,
coupling_matrix, nn);
# Initial guesses for the ODE rates (same values as the JAX script).
# β0_, tb_ and β1_ are not referenced elsewhere in this script.
κ0_ = 0.97
α0_ = 0.00185
β0_ = 0.5
tb_ = 15
β1_ = 0.4
γ0_ = 0.24
# Inverse of softplus(y) = log(1 + exp(y)): returns the y with softplus(y) == x.
# Evaluates x + log(1 - exp(-x)), using expm1 for accuracy at small x; needs x > 0.
function inv_softplus(x)
    one_minus_exp = -expm1(-x)  # 1 - exp(-x)
    return x + log(one_minus_exp)
end
# Initial rates stored in softplus-preimage space; the RHS applies softplus.
p0_ode = inv_softplus.([κ0_, α0_, γ0_]);
# Initial conditions (mirrors the JAX script): populations normalised to 1 per county,
# initial cases put in e, recovered back-computed from deaths via the ifr, and the ic0
# mass spread over the id (fraction ifr) and ir (fraction 1 - ifr) chains.
ifr = 0.007
n_counties = size(coupling_matrix,1)
n = ones(n_counties)
ic0 = epi_array[:,1,1]
d0 = epi_array[:,2,1]
r0 = d0/ifr
s0 = n-ic0-r0-d0
e0 = ic0
id10=id20=id30=id40=id50=id60=id70=ic0*ifr/7.0
ir10=ir20=ir30=ir40=ir50=ic0*(1.0-ifr)/5.0
# u0 is an n_counties × 16 matrix, one column per compartment in the order the RHS expects.
u0 = hcat(s0,e0,
id10,id20,id30,id40,id50, id60, id70, d0,
ir10,ir20,ir30,ir40,ir50, r0);
# Flat parameter vector: [3 rates; n_counties scalers; NN weights].
p_init = [p0_ode; p0_scaler; p0_nnet];
# Integrate over days 0 .. t_tot-1, saving once per day.
# NOTE(review): t_tot is Float32 while the 1.0 literals are Float64, so tspan and
# tsteps mix precisions and get promoted.
t_tot = Float32(size(coupling_matrix,3))
tspan = (0.0f0, t_tot-1.0)
tsteps = 0.0f0:1.0:t_tot-1.0
prob_seird = ODEProblem(SEIRD_mobility_coupled, u0, tspan, p_init)
# One-off forward solve with DP5 and looser tolerances; sol_univ is not used later
# in this script (presumably a warm-up / sanity check).
sol_univ = solve(prob_seird, DP5(),abstol = 1e-6, reltol = 1e-3;saveat=tsteps);
# Sum-of-squares loss used by the forward and gradient benchmarks.
# Solves the ODE with Tsit5 at `params_`. Gradients go through the interpolating
# adjoint, with a compiled ReverseDiff tape for the vector-Jacobian products.
# Residuals compare day-to-day model changes with the data from day 2 onwards:
#  - cases: new infections = -(change in s), i.e. the increase of 1 - s, matching
#    `diff` in the JAX script;
#  - deaths: change in compartment 10 (d).
# NOTE(review): the JAX loss weights the death term by 20000 and compares against
# the data starting at day 1; this loss uses weight 1 and epi_array[:, :, 2:end].
function loss(params_)
sol_ = solve(prob_seird, Tsit5(),p=params_, abstol = 1e-8, reltol = 1e-4,
sensealg=InterpolatingAdjoint(autojacvec=ReverseDiffVJP(true)),
saveat=tsteps)
sol = Array(sol_)
# s decreases as people get infected, so negate its increments to obtain the
# non-negative new infections before comparing with the case data.
l1 = sum((-diff(sol[:,1,:],dims=2)-epi_array[:,1,2:end]).^2)
l2 = sum((diff(sol[:,10,:],dims=2)-epi_array[:,2,2:end]).^2)
return l1+l2
end
# Run the loss and its gradient once before timing (presumably so compilation
# is excluded from the measurements below).
loss(p_init)
Zygote.gradient(loss,p_init)
@btime loss(p_init); # 75.350 ms (75142 allocations: 90.77 MiB)
@benchmark Zygote.gradient(loss,p_init)
#=
BenchmarkTools.Trial:
memory estimate: 501.95 MiB
allocs estimate: 368084
--------------
minimum time: 3.596 s (1.06% GC)
median time: 3.726 s (0.98% GC)
mean time: 3.726 s (0.98% GC)
maximum time: 3.855 s (0.91% GC)
--------------
samples: 2
evals/sample: 1
=#
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment