Last active
August 1, 2019 04:35
-
-
Save sharanry/a4b251c59de4f12cf68d81dea8721ec4 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"using Tracker, Bijectors" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"PlanarLayer([-1.38262; 0.107765] (tracked), [2.61178; -0.740486] (tracked), [0.659587; -0.588327] (tracked), [0.910486] (tracked))" | |
] | |
}, | |
"execution_count": 2, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"flow = PlanarLayer(2)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"2×100 Array{Float64,2}:\n", | |
" -0.708806 -0.77163 -0.130857 1.74092 … -1.26968 -0.990943 0.309014\n", | |
" -1.45037 -0.152275 -0.62374 0.255844 0.232834 1.44029 0.646235" | |
] | |
}, | |
"execution_count": 3, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"z = randn(2, 100)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"inv (generic function with 1 method)" | |
] | |
}, | |
"execution_count": 4, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"using Roots, LinearAlgebra\n", | |
"function inv(flow::PlanarLayer, y)\n", | |
" function f(y) \n", | |
" return loss(alpha) = (transpose(flow.w.data)*y)[1] - alpha -(transpose(flow.w.data)*flow.u_hat.data)[1]*tanh(alpha+flow.b.data[1]) \n", | |
" end\n", | |
" alphas = transpose([find_zero(f(y[:,i:i]), randn(), Order16()) for i in 1:size(y)[2]])\n", | |
" z_para = (flow.w.data ./norm(flow.w.data,2))*alphas\n", | |
" z_per = y - z_para - flow.u_hat.data*tanh.(transpose(flow.w.data)*z_para .+ flow.b.data)\n", | |
" \n", | |
" return z_para + z_per\n", | |
"end" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"2×100 Array{Float64,2}:\n", | |
" -0.708806 -0.77163 -0.130857 1.74092 … -1.26968 -0.990943 0.309014\n", | |
" -1.45037 -0.152275 -0.62374 0.255844 0.232834 1.44029 0.646235" | |
] | |
}, | |
"execution_count": 5, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"z" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"2×100 Array{Float64,2}:\n", | |
" -0.727291 -0.785757 -0.142205 1.79629 … -1.27419 -0.997874 0.381794\n", | |
" -1.43388 -0.139674 -0.613618 0.206456 0.236856 1.44647 0.581317" | |
] | |
}, | |
"execution_count": 6, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"inv(flow, transform(flow, z).data)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"RadialLayer([0.704708] (tracked), [-0.340588] (tracked), [0.829187; 0.840773] (tracked))" | |
] | |
}, | |
"execution_count": 7, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"flow2 = RadialLayer(2)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"abs(y - flow2.z_not.data) - r * (1 + flow2.β.data / (flow2.α.data + r)) " | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 60, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"inv (generic function with 2 methods)" | |
] | |
}, | |
"execution_count": 60, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"using Roots, LinearAlgebra\n", | |
"using StatsFuns: softplus\n", | |
"function inv(flow::RadialLayer, y)\n", | |
" α = softplus(flow.α_.data[1])\n", | |
" β_hat = -α + softplus(flow.β.data[1])\n", | |
" function f(y) \n", | |
" return loss(r) = norm(y - flow.z_not.data, 2) - r * (1 + β_hat / (α + r)) \n", | |
" end\n", | |
" rs = transpose([find_zero(f(y[:,i:i]), 0.0, Order16()) for i in 1:size(y)[2]])\n", | |
"# print(rs)\n", | |
" z = (y.-flow.z_not.data) .* (rs .* (1 .+ β_hat ./ (α .+ rs)) )\n", | |
" return z\n", | |
"end" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 61, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"2×100 Array{Float64,2}:\n", | |
" -0.708806 -0.77163 -0.130857 1.74092 … -1.26968 -0.990943 0.309014\n", | |
" -1.45037 -0.152275 -0.62374 0.255844 0.232834 1.44029 0.646235" | |
] | |
}, | |
"execution_count": 61, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"z" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 62, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"2×100 Array{Float64,2}:\n", | |
" -2.72452 -1.84197 -0.929259 … -3.06125 -2.22814 -0.120209 \n", | |
" -4.30542 -1.0546 -1.52326 -0.720227 0.607145 -0.0389007" | |
] | |
}, | |
"execution_count": 62, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"inv(flow2, transform(flow2, z).data)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"@webio": { | |
"lastCommId": null, | |
"lastKernelId": null | |
}, | |
"kernelspec": { | |
"display_name": "Julia 1.1.0", | |
"language": "julia", | |
"name": "julia-1.1" | |
}, | |
"language_info": { | |
"file_extension": ".jl", | |
"mimetype": "application/julia", | |
"name": "julia", | |
"version": "1.1.0" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Does changing
Order16
to other order level improve things?
Unfortunately no. Since both inv
and logabsdetjac
have problems for radial flows, there might be a problem with transformation implementation itself.
In case of radial flows, where the inverse in completely off, one main deviation I had made from the paper, was the usage of softplus
to accommodate the constrained nature of the \alpha parameter. Could this be the reason? Is there a better way to accomplish non-negative constraint?
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Does changing
Order16
to other order level improve things?