import jax.profiler

# start a profiling server on port 9999 so a profiler UI can attach
jax.profiler.start_server(9999)

import numpy as onp
import jax.numpy as jnp

from functools import partial

from jax import random, vmap, grad, jit
from jax.nn.initializers import normal, kaiming_uniform
from jax.nn import softplus, swish

from jax.experimental.ode import odeint

coupling_matrix_ = onp.load('./coupling_matrix.npy')
epi_array_ = onp.load('./epi_array.npy')
mobilitypopulation_array_scaled_ = onp.load('./mobilitypopulation_array_scaled.npy')

coupling_matrix = jnp.asarray(coupling_matrix_)
epi_array = jnp.asarray(epi_array_)
mobilitypopulation_array_scaled = jnp.asarray(mobilitypopulation_array_scaled_)

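# Array shapes, as inferred from the indexing further down (an assumption,
# since the .npy files themselves are not shown here):
#   epi_array:                       (n_days, 2, n_counties)   cases, deaths
#   coupling_matrix:                 (n_counties, n_counties, n_days+1)
#   mobilitypopulation_array_scaled: (n_counties, 7, n_days+1)
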

def inv_softplus(x):
    # inverse of softplus: softplus(inv_softplus(x)) == x for x > 0
    return x + jnp.log(-jnp.expm1(-x))
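
# quick numerical sanity check of the inverse identity (illustrative only):
_x = jnp.array([0.1, 1.0, 5.0])
assert jnp.allclose(softplus(inv_softplus(_x)), _x, atol=1e-5)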

key = random.PRNGKey(0)

# MLP that maps 7 mobility features to a single transmission rate;
# the final softplus keeps the output positive
layers = [7, 14, 14, 7, 1]
activations = [swish, swish, swish, softplus]
weight_initializer = kaiming_uniform
bias_initializer = normal


def init_layers(key_, nn_layers, nn_weight_initializer_, nn_bias_initializer_):
    init_w = nn_weight_initializer_()
    init_b = nn_bias_initializer_()
    params = []
    for in_, out_ in zip(nn_layers[:-1], nn_layers[1:]):
        # fresh subkeys per layer, so no two layers share random draws
        key_, w_key, b_key = random.split(key_, 3)
        # weights are stored flat; nnet() below slices and reshapes them
        weights = init_w(w_key, (in_, out_)).reshape((in_*out_,))
        biases = init_b(b_key, (out_,))
        params.append(jnp.concatenate((weights, biases)))
    return jnp.concatenate(params)


def nnet(nn_layers, nn_activations, nn_params, x):
    # walk the flat parameter vector, slicing out each layer's weights
    # and biases in turn
    n_s = 0
    x_in = jnp.expand_dims(x, axis=1)
    for in_, out_, act_ in zip(nn_layers[:-1], nn_layers[1:], nn_activations):
        n_w = in_*out_
        n_b = out_
        n_t = n_w + n_b
        weights = nn_params[n_s:n_s+n_w].reshape((out_, in_))
        biases = jnp.expand_dims(nn_params[n_s+n_w:n_s+n_t], axis=1)
        x_in = act_(jnp.matmul(weights, x_in) + biases)
        n_s += n_t
    return x_in


nn = jit(partial(nnet, layers, activations))
# batch over inputs (axis 0 of x); the parameter vector is shared (None)
nn_batch = vmap(partial(nnet, layers, activations), (None, 0), 0)

p_net = init_layers(key, layers, weight_initializer, bias_initializer)

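# Parameter-count sanity check (illustrative): each layer pair contributes
# in*out weights plus out biases, which for [7, 14, 14, 7, 1] totals
# 112 + 210 + 105 + 8 = 435 entries in the flat vector.
_n_params = sum(i*o + o for i, o in zip(layers[:-1], layers[1:]))
assert p_net.shape == (_n_params,)

# shape check for the batched network: 5 feature vectors in, 5 scalars out
assert nn_batch(p_net, jnp.ones((5, 7))).shape == (5, 1, 1)
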
# county-wise learnable scaling factors (raw values; softplus is applied
# again inside the ODE right-hand side to keep the effective scaling positive)
n_counties = coupling_matrix.shape[0]
init_b = bias_initializer()
p_scaling = softplus(200*init_b(key, (n_counties,)))


def SEIRD_mobility_coupled(u, t, p_, mobility_, coupling_matrix_):
    s, e, id1, id2, id3, id4, id5, id6, id7, d, ir1, ir2, ir3, ir4, ir5, r = u
    κ, α, γ = softplus(p_[:3])
    # κ*α and γ*η are not independent: the probabilities of transitioning
    # from e into the Id (death) chain and the Ir (recovery) chain must add
    # up to 1, i.e. (1 - exp(-κ*α)) + (1 - exp(-γ*η)) = 1; solving for η:
    η = -jnp.log(-jnp.expm1(-κ*α))/(γ + 1.0e-8)
    # map continuous ODE time to the nearest day index into the data arrays
    ind = jnp.rint(t.astype(jnp.float32))
    n_c = coupling_matrix_.shape[0]
    scaler_ = softplus(p_[3:3+n_c])
    cm_ = jnp.expand_dims(scaler_, axis=1)*coupling_matrix_[..., ind.astype(jnp.int32)]
    # the network predicts a per-county transmission rate β from the day's
    # mobility features
    β = nn_batch(p_[3+n_c:], mobility_[..., ind.astype(jnp.int32)])[:, 0, 0]
    # infectious pool: the first three Id stages plus the whole Ir chain
    i = id1 + id2 + id3 + ir1 + ir2 + ir3 + ir4 + ir5

    # new exposures: within-county transmission plus mobility-coupled
    # transmission between counties
    a = β*s*i + β*s*(jnp.matmul(i, cm_.T) + jnp.matmul(cm_, i))
    ds = -a
    de = a - κ*α*e - γ*η*e

    # seven-stage chain from exposure to death
    d_id1 = κ*(α*e - id1)
    d_id2 = κ*(id1 - id2)
    d_id3 = κ*(id2 - id3)
    d_id4 = κ*(id3 - id4)
    d_id5 = κ*(id4 - id5)
    d_id6 = κ*(id5 - id6)
    d_id7 = κ*(id6 - id7)
    d_d = κ*id7

    # five-stage chain from exposure to recovery
    d_ir1 = γ*(η*e - ir1)
    d_ir2 = γ*(ir1 - ir2)
    d_ir3 = γ*(ir2 - ir3)
    d_ir4 = γ*(ir3 - ir4)
    d_ir5 = γ*(ir4 - ir5)
    d_r = γ*ir5

    return jnp.stack([ds,
                      de,
                      d_id1, d_id2, d_id3, d_id4, d_id5, d_id6, d_id7, d_d,
                      d_ir1, d_ir2, d_ir3, d_ir4, d_ir5, d_r])

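# Illustrative check of the branching constraint encoded in η above:
# (1 - exp(-κ*α)) + (1 - exp(-γ*η)) sums to exactly 1 when
# η = -log(-expm1(-κ*α))/γ (values below are the initial guesses used later)
_κ, _α, _γ = 0.97, 0.00185, 0.24
_η = -onp.log(-onp.expm1(-_κ*_α))/_γ
assert onp.isclose((1 - onp.exp(-_κ*_α)) + (1 - onp.exp(-_γ*_η)), 1.0)
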
# Initial conditions
ifr = 0.007  # assumed infection fatality ratio
n_counties = epi_array.shape[2]
n = jnp.tile(1.0, (n_counties,))  # populations are normalized to 1
ic0 = epi_array[0, 0, :]  # day-0 infections
d0 = epi_array[0, 1, :]   # day-0 deaths
r0 = d0/ifr               # recovered, inferred from deaths via the IFR
s0 = n - ic0 - r0 - d0
e0 = ic0
# spread the initial infections uniformly across the chain stages
id10 = id20 = id30 = id40 = id50 = id60 = id70 = ic0*ifr/7.0
ir10 = ir20 = ir30 = ir40 = ir50 = ic0*(1.0 - ifr)/5.0
# u0 stacks the 16 compartments row-wise: shape (16, n_counties)
u0 = jnp.array([s0,
                e0,
                id10, id20, id30, id40, id50, id60, id70, d0,
                ir10, ir20, ir30, ir40, ir50, r0])

# ODE parameters: initial guesses for the learnable rates (β needs no fixed
# initial value since it is produced by the network)
κ0_ = 0.97
α0_ = 0.00185
γ0_ = 0.24

# store the raw parameters in inverse-softplus space, so that the softplus
# applied inside SEIRD_mobility_coupled recovers the values above
p_ode = inv_softplus(jnp.array([κ0_, α0_, γ0_]))

# Initial model parameters: [3 rate params | n_counties scalings | network weights]
p_init = jnp.concatenate((p_ode, p_scaling, p_net))

# daily time grid spanning the observation window
t0 = jnp.linspace(0, float(epi_array.shape[0]), int(epi_array.shape[0]) + 1)

# Loss function
def diff(sol_, data_):
    # daily new infections (increments of 1 - s) against observed cases
    l1 = jnp.square(jnp.ediff1d(1 - sol_[:, 0]) - data_[:, 0])
    # daily new deaths (increments of d) against observed deaths,
    # up-weighted to balance the two terms
    l2 = jnp.square(jnp.ediff1d(sol_[:, 9]) - data_[:, 1])
    return l1 + 20000*l2

# vectorize the per-county loss over the county axis (axis 2)
diff_v = vmap(diff, (2, 2))


def loss(data_, m_array_, coupling_matrix_, params_):
    sol_ = odeint(SEIRD_mobility_coupled, u0, t0, params_, m_array_, coupling_matrix_,
                  rtol=1e-4, atol=1e-8)
    return jnp.sum(diff_v(sol_, data_))

loss_ = partial(loss, epi_array, mobilitypopulation_array_scaled, coupling_matrix)

grad_jit = jit(grad(loss_))
loss_jit = jit(loss_)

%timeit loss_jit(p_init).block_until_ready()
# 91.1 ms ± 4.13 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

%timeit grad_jit(p_init).block_until_ready()
# 24.6 s ± 1.13 s per loop (mean ± std. dev. of 7 runs, 1 loop each)
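
# With loss_jit and grad_jit in hand, the fit can be driven by any
# first-order optimizer. A minimal sketch using plain gradient descent
# (step size and iteration count are illustrative assumptions, not tuned):
p = p_init
for step in range(10):
    p = p - 1e-4*grad_jit(p)
print(loss_jit(p))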