Last active
April 17, 2019 23:20
-
-
Save whilo/a37e587b54457cae3ac80d4926932ee6 to your computer and use it in GitHub Desktop.
CNF playground with nested Jacobian not working.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using DifferentialEquations | |
using Distributions | |
using Flux, DiffEqFlux | |
using Flux.Tracker | |
"""
    f(z, p)

Apply the element-wise map `tanh(α*z + β)` to `z`, where `(α, β)` are the first
two entries obtained by destructuring `p`.

Works for a scalar `z` as well as for an array, because the whole expression is
a single fused broadcast.
"""
function f(z, p)
    scale, shift = p
    return @. tanh(scale * z + shift)
end
# Patched stand-in for the Jacobian helper provided by Tracker.
"""
    jacobian2(m, xp)

Build the Jacobian of `m` at the tracked input `xp` by back-propagating every
output component separately and reading the gradient stored on `xp`.

# Arguments
- `m`: callable applied as `m(xp)`; the result is indexed as `y[i]` and each entry
  is handed to `back!`.
- `xp`: tracked input. Its elements must expose `.data` and the container itself
  must expose `.tracker.grad` (presumably a Tracker `TrackedArray` — confirm
  against the callers).

# Returns
A `length(m(xp)) × length(xp)` matrix whose row `i` is the gradient of `y[i]`
with respect to `xp`.

# Notes
The gradient accumulator on `xp` is *not* reset by this function (the reset was
disabled in the original code), so it still holds the summed gradients when the
function returns. If the accumulator already holds values before the call, these
still end up in the first row — NOTE(review): decide whether callers should
zero it beforehand.
"""
function jacobian2(m, xp)
    x = [xi.data for xi in xp]
    y = m(xp)
    k = length(y)
    n = length(x)
    J = Matrix{eltype(x)}(undef, k, n)
    previous = nothing
    for i in 1:k
        back!(y[i], once = false)  # adds the gradient of y[i] to the accumulator
        accumulated = xp.tracker.grad
        # Because the accumulator is never cleared, every pass adds on top of the
        # earlier ones. Subtracting the snapshot taken after the preceding pass
        # isolates the contribution of y[i] without mutating the accumulator.
        J[i, :] = previous === nothing ? accumulated : accumulated .- previous
        previous = copy(accumulated)
    end
    return J
end
"""
    cnf(du, u, p, t)

In-place right-hand side of the augmented flow ODE.

The first entry of `u` is the flow state; the second is carried along but not
read here. `p` is destructured into `(α, β)`. The function writes

- `du[1] = f(state, p)`, the state derivative, and
- `du[2] = -(1 - tanh(α*state + β)^2) * α`, the negated derivative of
  `tanh(α*z + β)` with respect to `z`, written out by hand.

`t` is accepted for the solver's call signature and is not used.
"""
function cnf(du, u, p, t)
    state, _ = u
    scale, shift = p
    du[1] = f(state, p)
    # Disabled autodiff variant of the line below:
    #du[2] = -sum(jacobian2((z)->f(z, p), [z]))
    squashed = tanh(scale * state + shift)
    du[2] = -(1 - squashed^2) * scale # manual
end
"""
    predict_rd(x)

Solve the `cnf` ODE starting from the initial state `[x, 0.0]` over the time span
`(0.0, 10.0)`, using the global parameter vector `p`.

The problem is handed to `diffeq_rd` with the `Tsit5()` solver and `saveat=0.1`;
whatever `diffeq_rd` returns is returned unchanged.
"""
function predict_rd(x)
    problem = ODEProblem(cnf, [x, 0.0], (0.0, 10.0), p)
    return diffeq_rd(p, problem, Tsit5(); saveat=0.1)
end
"""
    loss_rd(xs)

Return the negative mean log-likelihood of the samples `xs` under the flow.

Each sample is pushed through `predict_rd`, and the last saved column of the
result is taken. Its first entry is scored under a `Normal(0.0, 1.0)` base
density and its second entry is subtracted from that score.
"""
function loss_rd(xs)
    base = Normal(0.0, 1.0)
    # Last saved column of every solve.
    finals = map(x -> predict_rd(x)[:, end], xs)
    z = map(s -> s[1], finals) # TODO better slicing
    delta_logp = map(s -> s[2], finals)
    logpx = logpdf.(base, z) - delta_logp
    return -mean(logpx)
end
# Trainable parameter vector (α, β); read as a global by `predict_rd`.
p = param([0.0, 0.0]) # Initial Parameter Vector
params = Params([p])

# 100 draws from Normal(2.0, 0.1), wrapped in an outer vector so that each
# training step receives the full sample set, repeated for 100 steps.
raw_data = [[rand(Normal(2.0, 0.1)) for _ in 1:100]]
data = Iterators.repeated(raw_data, 100);

opt = ADAM(0.1)
Flux.train!(loss_rd, params, data, opt)
# Check whether the transformed samples look standard normal.
using Plots
preds = map(r -> predict_rd(r)[:, end], raw_data[1]);
histogram(map(s -> s[1].data, preds))

# Plot the traces of the flow, one series per sample, against the shared time grid.
trajs = map(i -> predict_rd(raw_data[1][i]), 1:100)
plot(trajs[1].t, map(traj -> map(u -> u[1].data, traj.u), trajs))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment