Skip to content

Instantly share code, notes, and snippets.

@tkf
Last active May 3, 2018 02:14
Show Gist options
  • Save tkf/53fd8a0a19dfeb4729215efb2f40e978 to your computer and use it in GitHub Desktop.
Save tkf/53fd8a0a19dfeb4729215efb2f40e978 to your computer and use it in GitHub Desktop.
# Zoom in on the solution after time t0 and overlay markers at the saved points.
# NOTE(review): `sol` must already be an ODE solution in scope; in this file it
# is only assigned further below (`sol = integrator.sol`), so run that part first.
t0 = 20
# Index of the first saved time point strictly after t0.
# NOTE(review): if no saved time exceeds t0 this is 0 (Julia 0.6) or `nothing`
# (Julia >= 0.7), and the `i0:end` indexing below fails.
i0 = findfirst(x -> x > t0, sol.t)
# Line plot of the solution restricted to [t0, final time].
plt = plot(sol, tspan=(t0, sol.t[end]))
# Remember the line plot's y-limits so the scatter overlay cannot rescale the axis.
yl = ylims(plt)
# Mark the saved points from i0 on. `sol'` is presumably a (time x component)
# matrix, giving one scatter series per state variable -- confirm.
scatter!(plt, sol.t[i0:end], sol'[i0:end, :])
# Restore the y-limits captured before the scatter.
ylims!(plt, yl)
plot!(plt, legend=:topleft)
# Return the plot so it is displayed when run interactively.
plt
module RNNODE

using DiffEqBase: ODEProblem
using DiffEqCallbacks: PositiveDomain
using Parameters: @unpack, @with_kw

# NOTE(review): `erfc` and `A_mul_B!` are Julia 0.6 Base names. Porting to
# Julia >= 0.7 needs `SpecialFunctions.erfc` and `LinearAlgebra.mul!`.

"""
    R(x)

Standard normal cumulative distribution function, `Φ(x) = erfc(-x / √2) / 2`.
Used as the nonlinearity in the right-hand side computed by `f!`.
"""
R(x) = erfc(-x / sqrt(2)) / 2

# Parameters of the network integrated by `f!`.
#
# Fields:
# - `W`:   matrix applied to the rectified state, `x = W * m - θ`.
# - `Q`:   matrix used in the normalizing denominator, `den = √(Q * m + γ)`.
# - `θ`:   threshold subtracted from `W * m`.
# - `τ`:   time constant dividing the right-hand side.
# - `γ`:   offset added inside the square root of the denominator.
# - `x`, `den`: pre-allocated work vectors (one entry per row of `W`), reused by
#   the `Float64` method of `f!` so it does not allocate on every call.
#
# The fields are untyped, so `f!` hands the unpacked values to the `_rhs!`
# kernel (a function barrier) to get code specialized on their concrete types.
@with_kw struct RNN
    W
    Q
    θ = 1
    τ = 1
    γ = 1e-5
    # Pre-allocated temporary variables:
    x = similar(@view W[:, 1])
    den = similar(@view W[:, 1])
end

"""
    make_rnn(; K = 900, kwargs...)

Build an `RNN` from the fixed 2x2 connectivity `J = [1 -2; 1 -1]`, using
`W = √K .* J` and `Q = J.^2`. Any other keywords (e.g. `θ`, `τ`, `γ`) are
forwarded to the `RNN` keyword constructor.
"""
function make_rnn(;
                  K = 900,
                  kwargs...)
    J = [1 -2
         1 -1]
    W = sqrt(K) .* J
    Q = J .^ 2
    return RNN(; W = W, Q = Q, kwargs...)
end

"""
    f!(du, u, rnn::RNN, t)

In-place right-hand side of the network ODE, written into `du`:

    m  = max.(0, u)
    du = (-u + R((W*m - θ) / √(Q*m + γ))) / τ

after which `du[i]` is clamped to be nonnegative wherever `u[i] < 0`
(see `zero_du_if_u_negative`). `t` is not used; it is part of the signature
expected by `ODEProblem`. Returns `du`.
"""
function f!(du::AbstractVector{Float64}, u, rnn::RNN, t)
    # Plain Float64 states reuse the work buffers stored in `rnn`.
    return f!(du, u, rnn, t, rnn.x, rnn.den)
end

function f!(du, u, rnn::RNN, t)
    # Other element types (presumably dual numbers from auto-diff Jacobians)
    # may not fit the buffers stored in `rnn`, so allocate matching ones.
    return f!(du, u, rnn, t, similar(du), similar(du))
end

# ...for auto-diff
function f!(du, u, rnn::RNN, t, x, den)
    @unpack W, Q, θ, γ, τ = rnn
    # Function barrier: `RNN`'s fields are untyped, so the unpacked values are
    # `Any` to the compiler here. Passing them as arguments lets `_rhs!` be
    # compiled for their concrete types instead of dispatching dynamically on
    # every broadcast and matrix product.
    return _rhs!(du, u, W, Q, θ, γ, τ, x, den)
end

# Type-specialized kernel of `f!`; `x` and `den` are scratch buffers.
function _rhs!(du, u, W, Q, θ, γ, τ, x, den)
    m = du  # reuse memory: `du` holds max.(0, u) until it is overwritten below
    @. m = max(0, u)
    # x = W * m - θ = √K J m - θ
    A_mul_B!(x, W, m)
    x .-= θ
    # den = √(Q * m + γ) = √(J.^2 * m + γ)
    A_mul_B!(den, Q, m)
    @. den = sqrt(den + γ)
    # `m` (aliasing `du`) is not read past this point, so `du` can be overwritten.
    @. du = (-u + R(x / den)) / τ
    @. du = zero_du_if_u_negative(du, u)
    return du
end

# Where `u` is negative, clamp `du` to be nonnegative so that component is not
# pushed further below zero. Despite the name, `du` is only replaced by zero when
# it was negative; a positive `du` is kept as is.
@inline zero_du_if_u_negative(du, u) = u < 0 ? max(zero(du), du) : du

"""
    make_ode(; callback = PositiveDomain(abstol = 1e-4), tspan = (0.0, 100.0),
             u0 = [0.1, 0.1], kwargs...)

Return an `ODEProblem` for `f!` whose parameter is the network built by
`make_rnn(; kwargs...)`. `u0` is the initial state; its length must match the
size of `W`. `callback` defaults to DiffEqCallbacks' `PositiveDomain`, see
http://docs.juliadiffeq.org/latest/features/callback_library.html#PositiveDomain-1
"""
function make_ode(;
                  callback = PositiveDomain(abstol = 1e-4),
                  tspan = (0.0, 100.0),
                  u0 = [0.1, 0.1],
                  kwargs...)
    rnn = make_rnn(; kwargs...)
    return ODEProblem(f!, u0, tspan, rnn;
                      callback = callback)
end

end
# Build the default network ODE, integrate it with Rodas4, and check that every
# saved state stayed nonnegative.
using DifferentialEquations
ode = RNNODE.make_ode()
# Alternative setups (e.g. without the PositiveDomain callback, or other solvers):
# ode = RNNODE.make_ode(callback=nothing)
# integrator = init(ode, AutoTsit5(Rosenbrock23()))
# integrator = init(ode, Rosenbrock23())
integrator = init(ode, Rodas4())
# NOTE: on the first run this timing also includes compilation.
@time solve!(integrator)
sol = integrator.sol
# Every saved state must be componentwise >= 0. This is what RNNODE's
# PositiveDomain callback and the clamp in its `f!` aim to guarantee.
@assert all(all(u .>= 0) for u in sol.u)
# Solver return code, shown when run interactively.
sol.retcode
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment