Last active
July 9, 2019 22:04
-
-
Save sharanry/a30f63b4a541d812ba1c30988a167d49 to your computer and use it in GitHub Desktop.
Norm Flows
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using Flux, Flux.Tracker | |
using Flux.Tracker: grad, update! | |
using Distributions | |
using StatsFuns | |
using MLToolkit | |
using LinearAlgebra | |
using Random | |
"""
Abstract supertype for the flow models defined in this file.
"""
abstract type Flow end

"""
    PlanarFlow

A stack of `depth` planar layers `f(z) = z + û * tanh(wᵀz + b)` together with
the parameters of the base distribution that `likelihood` evaluates.

Fields (all declared `::Any`):
- `wₖ`    : per-layer weight vectors `w`
- `uₖ`    : per-layer unconstrained vectors `u`
- `u_hat` : per-layer constrained vectors `û`, derived from `uₖ` and `wₖ`
- `bₖ`    : per-layer biases, indexed as `bₖ[i]`
- `mu`    : mean handed to `BatchNormal` in `likelihood`
- `sig`   : container whose first entry is the scale handed to `BatchNormal`
- `depth` : number of planar layers
"""
struct PlanarFlow <: Flow
    wₖ::Any
    uₖ::Any
    u_hat::Any
    bₖ::Any
    mu::Any
    sig::Any
    depth::Any
end
"""
    update_u_hat(uₖ, wₖ)

Return a vector with one constrained vector per layer:
`û = u + (m(wᵀu) - wᵀu) * w / ‖w‖²`, where `u = uₖ[i]` and `w = wₖ[i]`.

By construction `wᵀû = m(wᵀu)`, and `m(x) = -1 + softplus(x)` is never below
`-1`. That is the condition under which a `tanh` planar layer is invertible.
"""
function update_u_hat(uₖ, wₖ)
    # Build each entry directly rather than overwriting placeholders: the old
    # `param([1,1])` pre-fill hard-coded a 2-D shape and a tracked eltype.
    return map(eachindex(uₖ)) do i
        u, w = uₖ[i], wₖ[i]
        wu = transpose(w) * u
        u + (m(wu) - wu) * w / norm(w, 2)^2
    end
end
"""
    update_u_hat!(flow::PlanarFlow)

Recompute `flow.u_hat[layer]` in place for every layer from the current
`flow.uₖ[layer]` and `flow.wₖ[layer]`, so that `wᵀû = m(wᵀu)`.
Returns `nothing`.
"""
function update_u_hat!(flow::PlanarFlow)
    for layer in 1:flow.depth
        w = flow.wₖ[layer]
        u = flow.uₖ[layer]
        wu = transpose(w) * u
        flow.u_hat[layer] = u + (m(wu) - wu) * w / (norm(w, 2)^2)
    end
    return nothing
end
"""
    PlanarFlow(dims::Int, depth::Int; rng::AbstractRNG=Random.GLOBAL_RNG)

Create a `depth`-layer planar flow acting on `dims`-dimensional points.

`wₖ`, `uₖ` (one `dims`-vector per layer) and `bₖ` (one scalar per layer) are
drawn from a standard normal using `rng`. Pass a seeded RNG for reproducible
initialization. The base-distribution parameters start at `mu = 0` and
`sig = [1.0]`. All parameters are wrapped with `param` so they are trainable.
"""
function PlanarFlow(dims::Int, depth::Int; rng::AbstractRNG=Random.GLOBAL_RNG)
    # Draw order (all w, then all u, then b) matches the original code.
    wₖ = [param(randn(rng, dims)) for _ in 1:depth]
    uₖ = [param(randn(rng, dims)) for _ in 1:depth]
    bₖ = param(randn(rng, depth))
    u_hat = update_u_hat(uₖ, wₖ)
    mu = param(zeros(dims))
    sig = param([1.0])
    return PlanarFlow(wₖ, uₖ, u_hat, bₖ, mu, sig, depth)
end
# PlanarFlow(wₖ, uₖ, bₖ) = PlanarFlow(wₖ, uₖ, update_u_hat(uₖ, wₖ), bₖ, length(wₖ)) | |
# function getzₖ(fs, j, z) | |
# zₖ[i] | |
# end | |
"""
    planar_f(i, flow::PlanarFlow)

Return the `i`-th planar layer as a function `z -> z + û * tanh(wᵀz + b)`.
It uses the constrained `flow.u_hat[i]`, `flow.wₖ[i]` and `flow.bₖ[i]`.
"""
function planar_f(i, flow::PlanarFlow)
    û = flow.u_hat[i]
    w = flow.wₖ[i]
    b = flow.bₖ[i]
    return z -> z + û * tanh.(transpose(w) * z + b)
end
"""
    m(x)

Return `-1 + log(1 + exp(x))`, i.e. softplus shifted down by one, so `m(x) > -1`.

The overflow-safe form is used because `exp(x)` overflows to `Inf` for
`x ≳ 710`. For `x > 0` we use the identity `log(1 + eˣ) = x + log1p(e⁻ˣ)`.
"""
m(x) = -1 + (x > 0 ? x + log1p(exp(-x)) : log1p(exp(x)))
"""
    dtanh(x)

Derivative of `tanh`, `1 - tanh(x)^2`, applied element-wise.

Fully broadcast so it also accepts arrays. The previous `tanh.(x)^2` tried a
matrix power on vector input.
"""
dtanh(x) = 1 .- tanh.(x) .^ 2
"""
    ψ(z, w, b)

Return `tanh'(wᵀz + b) * w`.
This is the vector that appears in the planar-layer log-det term `log|1 + ûᵀψ(z)|`.
"""
function ψ(z, w, b)
    preactivation = transpose(w) * z + b
    slope = dtanh(preactivation)
    return slope * w
end
"""
    transform_with_logdetj(z, flow::PlanarFlow)

Push `z` through every planar layer of `flow`. Return the transformed point
together with the summed log-absolute-determinant of the layer Jacobians.

For layer `i` with input `zᵢ₋₁`, `log|det ∂f/∂z| = log|1 + ûᵀψ(zᵢ₋₁)|`.
ψ must therefore be evaluated at the layer's *input*, before the layer is
applied.
"""
function transform_with_logdetj(z, flow::PlanarFlow)
    update_u_hat!(flow)
    log_det_jacobian = 0
    prev = z
    for i in 1:flow.depth
        u, w, b = flow.u_hat[i], flow.wₖ[i], flow.bₖ[i]
        # Jacobian term first: it depends on the input to layer i.
        psi = ψ(prev, w, b)
        log_det_jacobian += log.(abs.(1 .+ transpose(psi) * u))
        prev = planar_f(i, flow)(prev)
    end
    return prev, log_det_jacobian
end
"""
    transform(z, flow::PlanarFlow)

Refresh `flow.u_hat`, then apply planar layers `1:flow.depth` to `z` in order
and return the result. With zero layers, `z` is returned unchanged.
"""
function transform(z, flow::PlanarFlow)
    update_u_hat!(flow)
    return foldl((acc, layer) -> planar_f(layer, flow)(acc), 1:flow.depth; init=z)
end
"""
    logdetj(z, flow::PlanarFlow)

Return only the summed log-absolute-determinant of the layer Jacobians when
`z` is pushed through `flow`.

Each term is `log|1 + ûᵀψ(zᵢ₋₁)|`, with ψ evaluated at the input of layer `i`.
"""
function logdetj(z, flow::PlanarFlow)
    update_u_hat!(flow)
    log_det_jacobian = 0
    prev = z
    for i in 1:flow.depth
        u, w, b = flow.u_hat[i], flow.wₖ[i], flow.bₖ[i]
        # Evaluate ψ on the layer input, then advance to the next layer.
        psi = ψ(prev, w, b)
        log_det_jacobian += log.(abs.(1 .+ transpose(psi) * u))
        prev = planar_f(i, flow)(prev)
    end
    return log_det_jacobian
end
"""
    likelihood(data, flow::PlanarFlow)

Return the total log-likelihood `Σₓ log p(x)` of the points in `data` under `flow`.

Each point `x` is mapped to `f(x)` and scored under the base density
`BatchNormal(flow.mu, flow.sig[1])`. Because `f` maps data into the base
space, change of variables gives `log p(x) = log p_base(f(x)) + log|det ∂f/∂x|`.
"""
function likelihood(data, flow::PlanarFlow)
    mvn = BatchNormal(flow.mu, flow.sig[1])
    ll = 0.0
    for x in data
        transformed, log_det_jacobian = transform_with_logdetj(x, flow)
        # NOTE(review): `sum` kept from the original; presumably BatchNormal's
        # logpdf returns per-dimension/per-sample terms — confirm in MLToolkit.
        # The Jacobian term is *added*: subtracting it is the formula for
        # pushing base samples forward, i.e. the opposite direction.
        ll += sum(logpdf(mvn, transformed)) + log_det_jacobian
    end
    return ll
end
"""
    train!(data, flow, opt, losses; epochs=150, batchsize=200)

Fit `flow` to the points in `data` by maximum likelihood, using minibatch
steps with the optimiser `opt`. In this file, `opt` is a Flux `ADAM`.

- `data` is shuffled in place once.
- Each epoch walks `data` in consecutive batches of `batchsize` points; the
  last batch may be shorter.
- After each epoch, the full-data `likelihood(data, flow)` is appended to
  `losses` (larger is better).
- Every 10 epochs the epoch number and that value are printed.
"""
function train!(data, flow, opt, losses; epochs=150, batchsize=200)
    batchsize >= 1 || throw(ArgumentError("batchsize must be positive, got $batchsize"))
    Random.shuffle!(data)
    # Parameter arrays are updated in place, so their identities never change:
    # collect them once instead of on every batch.
    trainable = vcat(flow.uₖ, flow.wₖ, [flow.bₖ], [flow.mu], [flow.sig])
    θ = Params(trainable)
    for i in 1:epochs
        for j in 1:batchsize:length(data)
            # Clamp to the end so a short final batch doesn't throw a BoundsError.
            batch = data[j:min(j + batchsize - 1, end)]
            # `update!` takes a descent step, so differentiate the *negative*
            # log-likelihood in order to maximise the likelihood.
            grads = Tracker.gradient(() -> -likelihood(batch, flow), θ)
            for p in trainable
                update!(opt, p, grads[p])
            end
        end
        print('.')
        loss = likelihood(data, flow)
        append!(losses, loss)
        if i % 10 == 0
            print(i, '\n')
            println(loss)
        end
    end
end
"""
    inv_transform(z, flow::PlanarFlow; tol=1e-12, maxiter=200)

Numerically invert `transform`: return `x` such that `transform(x, flow) ≈ z`.

Layers are undone in reverse order. For a layer `y = x + û tanh(wᵀx + b)`:
- Projecting onto `w` gives the scalar equation `wᵀy = α + (wᵀû) tanh(α + b)`
  for `α = wᵀx`.
- `update_u_hat!` keeps `wᵀû = m(wᵀu) ≥ -1`, so the right-hand side is
  non-decreasing in `α`.
- Because `|tanh| ≤ 1`, the root lies in `[wᵀy - |wᵀû|, wᵀy + |wᵀû|]`, so
  bisection (until the bracket is narrower than `tol`, at most `maxiter`
  halvings) recovers `α`.
- Then `x = y - û tanh(α + b)`.

The computation uses untracked values (`Tracker.data`). The result therefore
does not take part in gradient computation.
"""
function inv_transform(z, flow::PlanarFlow; tol=1e-12, maxiter=200)
    update_u_hat!(flow)
    y = float.(Tracker.data(z))
    for i in flow.depth:-1:1
        û = Tracker.data(flow.u_hat[i])
        w = Tracker.data(flow.wₖ[i])
        b = Tracker.data(flow.bₖ[i])
        wy = dot(w, y)
        wû = dot(w, û)
        lo, hi = wy - abs(wû), wy + abs(wû)
        for _ in 1:maxiter
            hi - lo <= tol && break
            mid = (lo + hi) / 2
            if mid + wû * tanh(mid + b) < wy
                lo = mid
            else
                hi = mid
            end
        end
        α = (lo + hi) / 2
        y = y - û * tanh(α + b)
    end
    return y
end
# Fit a 2-dimensional planar flow of depth 10 to a bimodal toy dataset:
# 200 draws from each of two isotropic Gaussians (σ = 1.0) centred at
# (0, 10) and (0, -10).
flow = PlanarFlow(2, 10)
# Example of pushing a single 2-D point through the flow:
# transform([1,1], flow)
mvn1 = MvNormal([0.0, 10.0], 1.0)
mvn2 = MvNormal([0.0, -10.0], 1.0)
data = [[rand(mvn1) for _ in 1:200]; [rand(mvn2) for _ in 1:200]]
losses = []
opt = ADAM(0.1, (0.9, 0.999))
train!(data, flow, opt, losses; epochs=50)
# Plotting Result: evaluate the learned density on a 0.1-spaced grid over
# [-15, 15]², draw 10000 grid points with probability proportional to it,
# and show their marginal histogram.
using StatsBase
using Gadfly
using Plots
using StatPlots
begin
    data_test = [[x, y] for x = -15:0.1:15 for y = -15:0.1:15];
    X_test = first.(data_test)
    Y_test = last.(data_test)
    P_test = map(pt -> exp.(likelihood([pt], flow).data), data_test)
    # Rescaling does not affect the sampling weights; kept from the original.
    P_test ./= (maximum(P_test) - minimum(P_test))
    k = StatsBase.sample(1:length(P_test), Weights(P_test), 10000);
    p2 = marginalhist(X_test[k], Y_test[k], bins=100; xlims=[-15, 15], ylims=[-15, 15])
end
savefig(p2, "0.1lr_20epochs")
Yes. I am implementing the inverse. For the inverse to exist, we also need a "maintenance" step for u_k. This is described in detail in the appendix of https://arxiv.org/pdf/1505.05770.pdf
Yes, you need to make sure the inverse exists.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.
Good job. It's heading in the right direction. Suggestions:
- Implement `forward` and `inverse`, which map z -> x and x -> z respectively.
- Of course, in a planar flow this term is 0, but we should still have it for the sake of the interface.