Skip to content

Instantly share code, notes, and snippets.

@sharanry
Last active June 22, 2019 11:15
Show Gist options
  • Save sharanry/cf832d8a13bdb600e923e113cd25d54a to your computer and use it in GitHub Desktop.
Save sharanry/cf832d8a13bdb600e923e113cd25d54a to your computer and use it in GitHub Desktop.
using Flux, Flux.Tracker
using Flux.Tracker: grad, update!
using LinearAlgebra
using Distributions
"""
    functional(i)

Return the `i`-th flow map as a closure.

The closure reads its parameters from the globals `uₖ`, `wₖ` and `bₖ` at index `i`
(captured when `functional` is called) and maps `z` to
`z + u * tanh.(transpose(w) * z + b)`.
"""
function functional(i)
    scale, weights, bias = uₖ[i], wₖ[i], bₖ[i]
    return z -> z + scale * tanh.(transpose(weights) * z + bias)
end
"""
    dtanh(z)

Return `1 - tanh(z)^2`, applied elementwise.

Every operation is broadcast, so `z` may be a scalar or an array; a scalar input
yields a scalar result.
"""
dtanh(z) = 1 .- tanh.(z) .^ 2
# func = functional([1,0,1], [0,1,0],1)
"""
    ψ(z, w, b)

Return `dtanh(transpose(w) * z + b) * w`, i.e. `w` scaled by `dtanh` evaluated at
the affine term `transpose(w) * z + b`.
"""
function ψ(z, w, b)
    preactivation = transpose(w) * z + b
    return dtanh(preactivation) * w
end
"""
    getJacobianF(u, w, b)

Return a closure that maps `z` to `abs(1 + transpose(u) * ψ(z, w, b))`.

`u`, `w` and the closure's argument accept any `AbstractArray` rather than only
`Array`, so array wrappers (for example the values produced by `param` in this
file) can be passed as well as plain arrays.
"""
function getJacobianF(u::AbstractArray, w::AbstractArray, b::Any)
    jacobian(z::AbstractArray) = abs(1 + transpose(u) * ψ(z, w, b))
    return jacobian
end
"""
    getzₖ(fs, j, z)

Apply the first `j` maps of `fs` to `z` in order and return the result, i.e.
`fs[j](…fs[2](fs[1](z)))`.

# Arguments
- `fs`: indexable collection of callables.
- `j`: number of leading maps to apply; `j = 0` returns `z` unchanged.
- `z`: starting value.

Indexing `fs[i]` fails in the usual way when `j` exceeds the number of maps.
"""
function getzₖ(fs, j, z)
    current = z
    for i in 1:j
        current = fs[i](current)
    end
    return current
end
# Flow parameters: ten `w` vectors and ten `u` vectors of length 10, plus one
# length-10 vector `bₖ` whose i-th entry is read for flow i. Every entry is drawn
# as rand + 1 and wrapped with `param`.
wₖ = map(_ -> param(rand(10) .+ 1), 1:10)
uₖ = map(_ -> param(rand(10) .+ 1), 1:10)
bₖ = param(rand(10) .+ 1)
# One flow closure per parameter index.
fₖ = map(functional, 1:10)
# Starting point: a length-10 vector of ones.
z = ones(10)
# zₖ[j] is `z` pushed through the first j flows.
zₖ = map(j -> getzₖ(fₖ, j, z), 1:10)
"""
    𝑭(x, y, t, K, p, z₀, q₀, fₖ, uₖ, wₖ, bₖ)

Evaluate the flow objective for the input `(x, y)`, log it with `@info`, and
return it.

The value is `log1p(q₀(z₀)) - βₜ * log1p(p(x, y))` minus, for each `k` in `1:K`,
`log(abs(1 + transpose(uₖ[k]) * ψ(z, wₖ[k], bₖ[k])))`, where `z` is `z₀` pushed
through `fₖ[1]`, …, `fₖ[k-1]`.

# Arguments
- `x`, `y`: inputs forwarded to `p`.
- `t`: step index; currently unused because `βₜ` is fixed at `0.1`.
- `K`: number of flows whose terms are subtracted.
- `p`: callable evaluated as `p(x, y)`.
- `z₀`: starting point that is propagated through the flows.
- `q₀`: callable evaluated as `q₀(z₀)`.
- `fₖ`: flow maps, indexable by `1:K-1`.
- `uₖ`, `wₖ`, `bₖ`: per-flow parameters, indexable by `1:K`.
"""
function 𝑭(x, y, t::Int, K::Int, p, z₀, q₀, fₖ, uₖ, wₖ, bₖ)
    # FIXME: Considering fixed t for simplicity
    # βₜ = min(1, 0.01 + 10000t)
    βₜ = 0.1
    # FIXME: How to calculate this quantity from equation (2) of the paper.
    # NOTE(review): this term is kept exactly as written; confirm whether plain
    # logs are intended here instead of log1p.
    result = log1p.(q₀(z₀)) - βₜ .* log1p.(p(x, y))
    # Walk the flows once, carrying the current point forward. The point used for
    # term k is the caller-supplied z₀ after the first k-1 flows.
    zprev = z₀
    for k in 1:K
        if k > 1
            zprev = fₖ[k - 1](zprev)
        end
        # log|1 + uᵀψ|: the same quantity getJacobianF wraps in abs. log1p(1 + x)
        # would evaluate log(2 + x) instead.
        result -= log(abs(1 + transpose(uₖ[k]) * ψ(zprev, wₖ[k], bₖ[k])))
    end
    @info result
    return result
end
# Base distribution: a 10-dimensional MvNormal built from an all-zero integer
# vector `mu` and the Float64 vector `sig` = 1.0, 2.0, …, 10.0.
# NOTE(review): how MvNormal interprets `sig` depends on the Distributions
# version in use; verify.
# NOTE(review): the name `norm` is also a LinearAlgebra function and this file
# does `using LinearAlgebra`; confirm the shadowing is intended.
sig = collect(1.0:10.0)
mu = zeros(Int, 10)
norm = MvNormal(mu, sig)
# Number of flows passed to 𝑭.
K = 10
# Density of the base distribution at `x`.
function q(x)
    return pdf(norm, x)
end
"""
    p(x, y)

Test energy function as mentioned in the paper:
`((y - sin(2π * x / 4)) / 0.4)^2 / 2`, evaluated elementwise over `y` and the
sine term.
"""
function p(x, y)
    phase = 2π * x / 4
    scaled = (y .- sin.(phase)) ./ 0.4
    return scaled .^ 2 ./ 2
end
"""
    𝑭(x, y)

Convenience method: evaluate the full objective at `(x, y)` with `t = 1` and the
script-level globals `K`, `p`, `z`, `q`, `fₖ`, `uₖ`, `wₖ` and `bₖ`.
"""
function 𝑭(x, y)
    return 𝑭(x, y, 1, K, p, z, q, fₖ, uₖ, wₖ, bₖ)
end
# Log the objective once at the single point x = 1, y = 0.
@info 𝑭(1, 0, 1, K, p, z, q, fₖ, uₖ, wₖ, bₖ)
# Gather every parameter array (all uₖ, all wₖ, then bₖ) for differentiation.
θ = Params([uₖ; wₖ; [bₖ]])
# FIXME: How should I generate data?
# x and y are the same range from -5 to 5 in steps of 0.1.
x = -5:0.1:5
y = x
# Gradient descent with learning rate 0.1
opt = Descent(0.1)
# Optimisation driver: ten update steps, printing the first entry of uₖ[1] after
# each one.
# NOTE(review): x and y are ranges here, so p(x, y) inside 𝑭 is a vector, while a
# gradient call of this form presumably needs a scalar objective; confirm how the
# data should be reduced.
i = 1
println("Starting!...")
while i <= 10
    global i
    # Gradients of the negated objective with respect to every parameter in θ.
    grads = Tracker.gradient(θ) do
        -𝑭(x, y)
    end
    # One optimiser step per parameter array.
    for prm in vcat(uₖ, wₖ, [bₖ])
        update!(opt, prm, grads[prm])
    end
    i += 1
    println(uₖ[1][1])
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment