Skip to content

Instantly share code, notes, and snippets.

@sharanry
Last active July 9, 2019 22:04
Show Gist options
  • Save sharanry/a30f63b4a541d812ba1c30988a167d49 to your computer and use it in GitHub Desktop.
Save sharanry/a30f63b4a541d812ba1c30988a167d49 to your computer and use it in GitHub Desktop.
Norm Flows
using Flux, Flux.Tracker
using Flux.Tracker: grad, update!
using Distributions
using StatsFuns
using MLToolkit
using LinearAlgebra
using Random
"""
    Flow

Abstract supertype for flow transformations defined in this file.
"""
abstract type Flow end

"""
    PlanarFlow(wₖ, uₖ, u_hat, bₖ, mu, sig, depth) <: Flow

Container for the parameters of a stack of `depth` planar layers. Layer `i` is
applied as `z -> z + u_hat[i] * tanh.(transpose(wₖ[i]) * z + bₖ[i])`.

# Fields
- `wₖ`: collection of per-layer weight vectors, indexed by layer.
- `uₖ`: collection of per-layer raw `u` vectors, indexed by layer.
- `u_hat`: collection of per-layer adjusted `u` vectors derived from `uₖ` and `wₖ`;
  its elements are overwritten in place by `update_u_hat!`.
- `bₖ`: per-layer biases, indexed by layer.
- `mu`: first argument passed to `BatchNormal` in `likelihood`.
- `sig`: container whose first element is passed as the second argument to
  `BatchNormal` in `likelihood`.
- `depth`: number of layers iterated over by the transform functions.

Every field type is a type parameter, so an instance has concretely typed fields
instead of `Any`, while the constructor still accepts any argument types.
"""
struct PlanarFlow{W,U,H,B,M,S,D} <: Flow
    wₖ::W
    uₖ::U
    u_hat::H
    bₖ::B
    mu::M
    sig::S
    depth::D
end
"""
    update_u_hat(uₖ, wₖ)

Return a new collection with one adjusted vector per pair `(uₖ[i], wₖ[i])`:

    u + (m(wᵀu) - wᵀu) * w / ‖w‖₂²

The result is built directly from the inputs, so no placeholder of a fixed
dimension is allocated and the element type follows the arguments.

# Arguments
- `uₖ`: collection of raw `u` vectors.
- `wₖ`: collection of weight vectors, paired element-wise with `uₖ`.
"""
function update_u_hat(uₖ, wₖ)
    # to preserve invertibility
    return map(uₖ, wₖ) do u, w
        wu = transpose(w) * u  # computed once and reused in both terms
        u + (m(wu) - wu) * w / (norm(w, 2)^2)
    end
end
"""
    update_u_hat!(flow::PlanarFlow)

Recompute `flow.u_hat[i]` in place for every layer `i in 1:flow.depth` from the
current `flow.uₖ[i]` and `flow.wₖ[i]`, using
`u + (m(wᵀu) - wᵀu) * w / ‖w‖₂²`.
"""
function update_u_hat!(flow::PlanarFlow)
    for layer in 1:flow.depth
        u = flow.uₖ[layer]
        w = flow.wₖ[layer]
        wu = transpose(w) * u
        flow.u_hat[layer] = u + (m(wu) - wu) * w / (norm(w, 2)^2)
    end
end
"""
    PlanarFlow(dims::Int, depth::Int)

Construct a `PlanarFlow` with `depth` layers over `dims`-dimensional inputs.

Weights, raw `u` vectors and biases are drawn with `randn` and wrapped in
`param`; `mu` is a `dims`-length vector of zeros and `sig` is `[1.0]`, both
wrapped in `param`. `u_hat` is derived from the drawn `uₖ` and `wₖ` via
`update_u_hat`.
"""
function PlanarFlow(dims::Int, depth::Int)
    # Draw order (weights, then u vectors, then biases) is kept fixed so the
    # random stream is consumed in the same sequence.
    weights = map(_ -> param(randn(dims)), 1:depth)
    raw_us = map(_ -> param(randn(dims)), 1:depth)
    biases = param(randn(depth))
    adjusted_us = update_u_hat(raw_us, weights)
    base_mu = param(zeros(dims))
    base_sig = param([1.0])
    return PlanarFlow(weights, raw_us, adjusted_us, biases, base_mu, base_sig, depth)
end
# PlanarFlow(wₖ, uₖ, bₖ) = PlanarFlow(wₖ, uₖ, update_u_hat(uₖ, wₖ), bₖ, length(wₖ))
# function getzₖ(fs, j, z)
# zₖ[i]
# end
"""
    planar_f(i, flow::PlanarFlow)

Return a function applying layer `i` of `flow`:
`z -> z + u_hat[i] * tanh.(transpose(wₖ[i]) * z + bₖ[i])`.

The layer parameters are read once when `planar_f` is called and captured by
the returned closure.
"""
function planar_f(i, flow::PlanarFlow)
    u_i = flow.u_hat[i]
    w_i = flow.wₖ[i]
    b_i = flow.bₖ[i]
    return z -> z + u_i * tanh.(transpose(w_i) * z + b_i)
end
"""
    m(x)

Return `-1 + log(1 + exp(x))` (a softplus shifted down by one).

The softplus is evaluated in a numerically stable form: for positive `x` it is
rewritten as `x + log1p(exp(-x))`, so large inputs no longer overflow `exp` and
produce `Inf`.
"""
m(x) = -1 + (x > 0 ? x + log1p(exp(-x)) : log1p(exp(x)))
"""
    dtanh(x)

Return the derivative of `tanh` at `x`, i.e. `1 - tanh(x)^2`.

Every operation is broadcast, so `x` may be a scalar (a scalar is returned) or
an array (the derivative is taken element-wise).
"""
dtanh(x) = 1 .- tanh.(x) .^ 2
"""
    ψ(z, w, b)

Return `dtanh(transpose(w) * z + b) * w`, i.e. the `tanh` derivative at the
pre-activation `wᵀz + b`, scaled onto `w`.
"""
function ψ(z, w, b)
    preactivation = transpose(w) * z + b
    return dtanh(preactivation) * w
end
"""
    transform_with_logdetj(z, flow::PlanarFlow)

Push `z` through all `flow.depth` layers and return the tuple
`(transformed, log_det_jacobian)`.

`log_det_jacobian` is the sum over layers of `log|1 + ψ(zᵢ₋₁)ᵀ u_hat[i]|`, where
`zᵢ₋₁` is the *input* to layer `i`. `flow.u_hat` is refreshed with
`update_u_hat!` before the layers are applied.
"""
function transform_with_logdetj(z, flow::PlanarFlow)
    update_u_hat!(flow)
    # compute log_det_jacobian
    log_det_jacobian = 0.0
    prev = z
    for i in 1:flow.depth
        u, w, b = flow.u_hat[i], flow.wₖ[i], flow.bₖ[i]
        # The Jacobian of layer i is I + u ψ(z)ᵀ at the layer's input z, so ψ
        # has to be evaluated before `prev` is replaced by the layer output.
        psi = ψ(prev, w, b)
        log_det_jacobian += log.(abs.(1 .+ transpose(psi)*u))
        prev = planar_f(i, flow)(prev)
    end
    return prev, log_det_jacobian
end
"""
    transform(z, flow::PlanarFlow)

Return `z` after applying layers `1:flow.depth` of `flow` in order.
`flow.u_hat` is refreshed with `update_u_hat!` first.
"""
function transform(z, flow::PlanarFlow)
    update_u_hat!(flow)
    # Left fold: the output of each layer is the input of the next one.
    return foldl((current, layer) -> planar_f(layer, flow)(current), 1:flow.depth; init=z)
end
"""
    logdetj(z, flow::PlanarFlow)

Return the accumulated log-absolute-determinant of the Jacobian of the full
stack of layers evaluated along the path starting at `z`.

The value is the sum over layers of `log|1 + ψ(zᵢ₋₁)ᵀ u_hat[i]|`, where `zᵢ₋₁`
is the *input* to layer `i`. `flow.u_hat` is refreshed with `update_u_hat!`
before the layers are applied.
"""
function logdetj(z, flow::PlanarFlow)
    update_u_hat!(flow)
    # compute log_det_jacobian
    log_det_jacobian = 0.0
    prev = z
    for i in 1:flow.depth
        u, w, b = flow.u_hat[i], flow.wₖ[i], flow.bₖ[i]
        # ψ belongs to the layer's input, so evaluate it before advancing `prev`.
        psi = ψ(prev, w, b)
        log_det_jacobian += log.(abs.(1 .+ transpose(psi)*u))
        prev = planar_f(i, flow)(prev)
    end
    return log_det_jacobian
end
"""
    likelihood(data, flow::PlanarFlow)

Accumulate, over every element `x` of `data`, the value
`sum(logpdf(base, y)) - ldj`, where `(y, ldj) = transform_with_logdetj(x, flow)`
and `base = BatchNormal(flow.mu, flow.sig[1])`. Returns `0.0` for empty `data`.

NOTE(review): the log-det term is subtracted while the base density is
evaluated at the transformed point — confirm this sign convention against the
change-of-variables formula and against how `train!` uses the returned value.
"""
function likelihood(data, flow::PlanarFlow)
    base = BatchNormal(flow.mu, flow.sig[1])
    total = 0.0
    for x in data
        y, ldj = transform_with_logdetj(x, flow)
        total += sum(logpdf(base, y)) - ldj
    end
    return total
end
"""
    train!(data, flow, opt, losses; epochs=150, batchsize=200)

Run `epochs` passes over `data` in consecutive mini-batches of at most
`batchsize` elements, calling `update!(opt, p, grad)` for each of `flow.uₖ`,
`flow.wₖ`, `flow.bₖ`, `flow.mu` and `flow.sig` with the gradient of
`likelihood(batch, flow)`.

After every epoch the value of `likelihood(data, flow)` is pushed onto `losses`
and a progress dot is printed; every tenth epoch the epoch number and that
value are printed as well.

# Arguments
- `data`: indexable collection of samples; it is shuffled **in place** once
  before training starts.
- `flow`: the flow whose parameters are updated.
- `opt`: optimiser handed to `update!`.
- `losses`: collection that receives one value per epoch.

# Keywords
- `epochs`: number of passes over `data`.
- `batchsize`: maximum number of samples per gradient step. The last batch is
  shortened when `length(data)` is not a multiple of `batchsize`, instead of
  indexing past the end of `data`.

NOTE(review): the step direction is whatever `update!` applies to the gradient
of `likelihood` — confirm it matches the intended objective sign.
"""
function train!(data, flow, opt, losses; epochs=150, batchsize=200)
    Random.shuffle!(data)
    # The parameter objects are updated in place, so the list is built once.
    trainable = vcat(flow.uₖ, flow.wₖ, [flow.bₖ], [flow.mu], [flow.sig])
    θ = Params(trainable)
    n = length(data)
    for i in 1:epochs
        for j in 1:batchsize:n
            # Clamp the upper index so a trailing partial batch stays in bounds.
            batch = data[j:min(j + batchsize - 1, n)]
            grads = Tracker.gradient(() -> likelihood(batch, flow), θ)
            for p in trainable
                update!(opt, p, grads[p])
            end
        end
        print('.')
        loss = likelihood(data, flow)
        push!(losses, loss)
        if i % 10 == 0
            print(i, '\n')
            println(loss)
        end
    end
    return nothing
end
"""
    inv_transform(z, flow::PlanarFlow)

Placeholder for the inverse transform. Not implemented yet: both arguments are
ignored and the constant `0` is returned.
"""
inv_transform(z, flow::PlanarFlow) = 0  # TODO: Implement
# Build a planar flow over 2-dimensional inputs with 10 layers.
flow = PlanarFlow(2, 10)
# Example of transforming a single 2-dimensional point:
# transform([1,1], flow)

# Two-component toy data set: 200 draws around (0, 10) followed by 200 draws
# around (0, -10).
mvn1 = MvNormal([0.0, 10.0], 1.0)
mvn2 = MvNormal([0.0, -10.0], 1.0)
data = [rand(component) for component in (mvn1, mvn2) for _ in 1:200]

# `losses` collects one value per epoch; it is filled by `train!`.
losses = []
opt = ADAM(0.1, (0.9, 0.999))
train!(data, flow, opt, losses; epochs=50)
# @info flow.mu
# Plotting Result
using StatsBase, Gadfly, Plots, StatPlots
begin
    # Regular grid over [-15, 15] x [-15, 15] with step 0.1.
    data_test = [[i, j] for i in -15:0.1:15 for j in -15:0.1:15]
    X_test = first.(data_test)
    Y_test = last.(data_test)
    # Exponentiated `likelihood` of each grid point taken on its own.
    P_test = [exp.(likelihood([point], flow).data) for point in data_test]
    # Rescale by the range of the values before using them as sampling weights.
    P_test = P_test ./ (maximum(P_test) - minimum(P_test))
    # Draw grid indices with probability proportional to the weights.
    k = StatsBase.sample(1:length(P_test), Weights(P_test), 10000)
    p2 = marginalhist(X_test[k], Y_test[k], bins=100; xlims=[-15, 15], ylims=[-15, 15])
end
savefig(p2, "0.1lr_20epochs")
@xukai92
Copy link

xukai92 commented Jul 3, 2019

Good job. It's heading in the right direction. Suggestions:

  • We need two transformations, right? The so-called forward and inverse, which map z -> x and x -> z respectively.
  • For both of these functions, in general, we need to return both the transformed variable and the log-determinant of the inverse Jacobian. Of course, in planar flow this term is 0, but for the sake of the interface we should have it.

@sharanry
Copy link
Author

sharanry commented Jul 3, 2019

Yes. I am implementing the inverse. For the inverse to exist, we also need a "maintenance" step for uk. This is described in detail in the appendix of https://arxiv.org/pdf/1505.05770.pdf

@xukai92
Copy link

xukai92 commented Jul 3, 2019

Yes you need to make sure the inverse exists.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment