Last active
August 29, 2019 08:50
-
-
Save sharanry/2d4e824f27ed01e4f7a1d6fb6f02305a to your computer and use it in GitHub Desktop.
Simple NFVI
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 320, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"using Debugger" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 321, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"using Bijectors\n", | |
"using Distributions\n", | |
"using Turing\n", | |
"using TrackedDistributions\n", | |
"using ForwardDiff\n", | |
"using Random\n", | |
"using Tracker\n", | |
"using Flux\n", | |
"using Distances" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 335, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"@model simple() = begin\n", | |
" a ~ Normal(10, 5) \n", | |
" b ~ Normal(a, 5) \n", | |
" return a, b\n", | |
"end\n", | |
"model = simple();\n", | |
"# @model gdemo_d() = begin\n", | |
"# s ~ InverseGamma(2, 3)\n", | |
"# m ~ Normal(0, sqrt(s))\n", | |
"# 1.5 ~ Normal(m, sqrt(s))\n", | |
"# 2.0 ~ Normal(m, sqrt(s))\n", | |
"# return s, m\n", | |
"# end\n", | |
"# model = gdemo_d();" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 362, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"get_transforms (generic function with 1 method)" | |
] | |
}, | |
"execution_count": 362, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"function get_transforms(model::Turing.Model)\n", | |
" varinfo = Turing.VarInfo(model)\n", | |
" num_params = sum([size(varinfo.metadata[sym].vals, 1) for sym ∈ keys(varinfo.metadata)])\n", | |
" \n", | |
" base = MvNormal(zeros(num_params), ones(num_params))\n", | |
" flow = Bijectors.compose(Bijectors.Scale(num_params, param), Bijectors.Shift(num_params, param), [ i%2==1 ? Bijectors.RadialLayer(num_params, param) : Bijectors.PlanarLayer(num_params, param) for i in 1:10]...);\n", | |
" trans_base = transformed(base, flow);\n", | |
" return (base=base, flow=flow, trans_base=trans_base) \n", | |
"end" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 363, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"base, flow, trans_base = get_transforms(model);" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 364, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"rng = MersenneTwister(1234);" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 365, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"6.4086145128384295 (tracked)" | |
] | |
}, | |
"execution_count": 365, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"function F(rng, model, trans_base; num_samples=10, x=nothing, print_info=false)\n", | |
" logpdf_p = []\n", | |
" if x==nothing\n", | |
" x = rand(rng, trans_base.dist, num_samples)\n", | |
" end\n", | |
" \n", | |
" _x, y, logjac, logpdf = forward(trans_base, x)\n", | |
" if print_info\n", | |
" @info \"x\" x\n", | |
" @info \"y\" y\n", | |
" @info \"logjac\" logjac\n", | |
" @info \"logpdf\" logpdf\n", | |
" end\n", | |
" \n", | |
" varinfo = Turing.VarInfo(model)\n", | |
" \n", | |
" for i in 1:size(y, 2)\n", | |
" varinfo_new = Turing.VarInfo(varinfo, Turing.SampleFromUniform(), y[:,i])\n", | |
" model(varinfo_new)\n", | |
" if print_info\n", | |
" @info \"varinfo_new\" varinfo_new.logp\n", | |
" end\n", | |
" append!(logpdf_p, varinfo_new.logp)\n", | |
" end\n", | |
" \n", | |
" logpdf_q = logpdf - logjac# - [x[1,i] for i in 1:size(x, 2)]\n", | |
" if print_info\n", | |
" @info \"logpdf_p\" logpdf_p \n", | |
" @info \"logpdf_q\" logpdf_q\n", | |
" end\n", | |
" logpdf_p = Tracker.collect(logpdf_p)\n", | |
"# mean(logpdf_q - logpdf_p)#, logpdf_q, logpdf_p\n", | |
" kl_divergence(exp.(logpdf_q), exp.(logpdf_p))\n", | |
"end\n", | |
"\n", | |
"# Sanity Check\n", | |
"out = F(rng, model, trans_base, print_info=false, num_samples=10)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 366, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"image/png": "", | |
"text/plain": [ | |
"Figure(PyObject <Figure size 640x480 with 1 Axes>)" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"# High variability of F. A potential problem\n", | |
"using PyPlot\n", | |
"PyPlot.hist([F(rng, model, trans_base, num_samples=100).data for i in 1:1000], bins=50);" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 367, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"get_ϕ! (generic function with 1 method)" | |
] | |
}, | |
"execution_count": 367, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# To extract all tracked params\n", | |
"function get_ϕ!(flow::Composed, ϕ)\n", | |
" for i in flow.ts\n", | |
" if typeof(i) <: Composed\n", | |
" get_ϕ!(i, ϕ)\n", | |
" else\n", | |
" for j in propertynames(i, false)\n", | |
" append!(ϕ, [getproperty(i,j)])\n", | |
" end\n", | |
" end\n", | |
" end\n", | |
"end" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 368, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"32-element Array{Any,1}:\n", | |
" [1.0 0.0; 0.0 1.0] (tracked) \n", | |
" [0.0; 0.0] (tracked) \n", | |
" [-0.517413] (tracked) \n", | |
" [0.718511] (tracked) \n", | |
" [1.77198; 1.4844] (tracked) \n", | |
" [1.55561; -2.24004] (tracked) \n", | |
" [-0.417338; -3.55184] (tracked) \n", | |
" [-1.46857] (tracked) \n", | |
" [-0.427506] (tracked) \n", | |
" [-0.425289] (tracked) \n", | |
" [-0.33839; -0.133261] (tracked) \n", | |
" [-1.87512; -0.371103] (tracked) \n", | |
" [1.91111; 0.645955] (tracked) \n", | |
" ⋮ \n", | |
" [0.987035] (tracked) \n", | |
" [0.842437] (tracked) \n", | |
" [-0.604559; -0.282666] (tracked)\n", | |
" [0.524474; 1.61454] (tracked) \n", | |
" [-1.82578; 2.07189] (tracked) \n", | |
" [0.209868] (tracked) \n", | |
" [-0.257118] (tracked) \n", | |
" [1.20459] (tracked) \n", | |
" [0.279818; 2.293] (tracked) \n", | |
" [0.366493; -1.01509] (tracked) \n", | |
" [0.575756; -1.35863] (tracked) \n", | |
" [-0.210145] (tracked) " | |
] | |
}, | |
"execution_count": 368, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"ϕ = []\n", | |
"get_ϕ!(flow, ϕ)\n", | |
"ϕ" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 369, | |
"metadata": { | |
"scrolled": true | |
}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"┌ Info: 5.687605605915064 (tracked)\n", | |
"└ @ Main In[369]:6\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"..........\n", | |
"(100/10000) done; loss=5.798888461987018\n", | |
"..........\n", | |
"(200/10000) done; loss=5.532302103696383\n", | |
"..........\n", | |
"(300/10000) done; loss=5.40636217290014\n", | |
"..........\n", | |
"(400/10000) done; loss=5.277394195225437\n", | |
"..........\n", | |
"(500/10000) done; loss=5.089418426079885\n", | |
"..........\n", | |
"(600/10000) done; loss=5.062693849123746\n", | |
"..........\n", | |
"(700/10000) done; loss=5.07495216807752\n", | |
"..........\n", | |
"(800/10000) done; loss=4.935018449783298\n", | |
"..........\n", | |
"(900/10000) done; loss=4.838258708632543\n", | |
"..........\n", | |
"(1000/10000) done; loss=4.6412013565065005\n", | |
"..........\n", | |
"(1100/10000) done; loss=4.622866199425547\n", | |
"..........\n", | |
"(1200/10000) done; loss=4.322743314518878\n", | |
"..........\n", | |
"(1300/10000) done; loss=4.480712672678272\n", | |
"..........\n", | |
"(1400/10000) done; loss=4.158097191002094\n", | |
"..........\n", | |
"(1500/10000) done; loss=4.264347791127247\n", | |
"..........\n", | |
"(1600/10000) done; loss=4.173307206543895\n", | |
"..........\n", | |
"(1700/10000) done; loss=3.9286591644533093\n", | |
"..........\n", | |
"(1800/10000) done; loss=4.002175241609857\n", | |
"..........\n", | |
"(1900/10000) done; loss=3.93243579345907\n", | |
"..........\n", | |
"(2000/10000) done; loss=3.886448858843637\n", | |
"..........\n", | |
"(2100/10000) done; loss=3.8873126379114797\n", | |
"..........\n", | |
"(2200/10000) done; loss=3.931958139182028\n", | |
"..........\n", | |
"(2300/10000) done; loss=3.7904366437152492\n", | |
"..........\n", | |
"(2400/10000) done; loss=3.751041661311117\n", | |
"..........\n", | |
"(2500/10000) done; loss=3.716294025864434\n", | |
"..........\n", | |
"(2600/10000) done; loss=3.697364805947052\n", | |
"..........\n", | |
"(2700/10000) done; loss=3.646594445889464\n", | |
"..........\n", | |
"(2800/10000) done; loss=3.575102239215171\n", | |
"..........\n", | |
"(2900/10000) done; loss=3.655719770510173\n", | |
"..........\n", | |
"(3000/10000) done; loss=3.5701629347308574\n", | |
"..........\n", | |
"(3100/10000) done; loss=3.568918757494621\n", | |
"..........\n", | |
"(3200/10000) done; loss=3.6122952747296955\n", | |
"..........\n", | |
"(3300/10000) done; loss=3.384335376480747\n", | |
"..........\n", | |
"(3400/10000) done; loss=3.345501497890491\n", | |
"..........\n", | |
"(3500/10000) done; loss=3.5238478717671415\n", | |
"..........\n", | |
"(3600/10000) done; loss=3.410132750398626\n", | |
"..........\n", | |
"(3700/10000) done; loss=3.3041982877108067\n", | |
"..........\n", | |
"(3800/10000) done; loss=3.2532203455570756\n", | |
"..........\n", | |
"(3900/10000) done; loss=3.3720866060014174\n", | |
"..........\n", | |
"(4000/10000) done; loss=3.2459250924003067\n", | |
"..........\n", | |
"(4100/10000) done; loss=3.214724416423852\n", | |
"..........\n", | |
"(4200/10000) done; loss=3.174708212806242\n", | |
"..........\n", | |
"(4300/10000) done; loss=3.289272111654915\n", | |
"..........\n", | |
"(4400/10000) done; loss=3.250804309497583\n", | |
"..........\n", | |
"(4500/10000) done; loss=3.2941254904724366\n", | |
"..........\n", | |
"(4600/10000) done; loss=3.178987992552885\n", | |
"..........\n", | |
"(4700/10000) done; loss=3.125103800561713\n", | |
"..........\n", | |
"(4800/10000) done; loss=3.228475500384154\n", | |
"..........\n", | |
"(4900/10000) done; loss=3.0838138888630042\n", | |
"..........\n", | |
"(5000/10000) done; loss=3.136007874265254\n", | |
"..........\n", | |
"(5100/10000) done; loss=3.0248921318124626\n", | |
"..........\n", | |
"(5200/10000) done; loss=3.085563156977406\n", | |
"..........\n", | |
"(5300/10000) done; loss=3.0645348100492162\n", | |
"..........\n", | |
"(5400/10000) done; loss=2.958291727211016\n", | |
"..........\n", | |
"(5500/10000) done; loss=2.9304965829886176\n", | |
"..........\n", | |
"(5600/10000) done; loss=2.796315805425349\n", | |
"..........\n", | |
"(5700/10000) done; loss=2.180155893884919\n", | |
"..........\n", | |
"(5800/10000) done; loss=2.0358903633729097\n", | |
"..........\n", | |
"(5900/10000) done; loss=2.144586747322932\n", | |
"..........\n", | |
"(6000/10000) done; loss=2.1564737450051426\n", | |
"..........\n", | |
"(6100/10000) done; loss=2.167814701455386\n", | |
"..........\n", | |
"(6200/10000) done; loss=2.105650026763533\n", | |
"..........\n", | |
"(6300/10000) done; loss=2.2037724907621925\n", | |
"..........\n", | |
"(6400/10000) done; loss=2.0986737473657593\n", | |
"..........\n", | |
"(6500/10000) done; loss=2.139643804825428\n", | |
"..........\n", | |
"(6600/10000) done; loss=2.0869740899137\n", | |
"..........\n", | |
"(6700/10000) done; loss=2.144057399045887\n", | |
"..........\n", | |
"(6800/10000) done; loss=2.140419313082119\n", | |
"..........\n", | |
"(6900/10000) done; loss=2.123319137390585\n", | |
"..........\n", | |
"(7000/10000) done; loss=2.272158844173498\n", | |
"..........\n", | |
"(7100/10000) done; loss=2.077289583190817\n", | |
"..........\n", | |
"(7200/10000) done; loss=2.131396233779441\n", | |
"..........\n", | |
"(7300/10000) done; loss=2.1276211355646013\n", | |
"...." | |
] | |
}, | |
{ | |
"ename": "InterruptException", | |
"evalue": "InterruptException:", | |
"output_type": "error", | |
"traceback": [ | |
"InterruptException:", | |
"", | |
"Stacktrace:", | |
" [1] gradient_(::getfield(Main, Symbol(\"##396#397\")), ::Params) at /home/sharan/.julia/packages/Tracker/SAr25/src/back.jl:7", | |
" [2] #gradient#24(::Bool, ::Function, ::Function, ::Params) at /home/sharan/.julia/packages/Tracker/SAr25/src/back.jl:164", | |
" [3] gradient(::Function, ::Params) at /home/sharan/.julia/packages/Tracker/SAr25/src/back.jl:164", | |
" [4] top-level scope at In[369]:11", | |
" [5] top-level scope at util.jl:213", | |
" [6] top-level scope at In[369]:7" | |
] | |
} | |
], | |
"source": [ | |
"Phi = Flux.Params(ϕ)\n", | |
"opt = ADAM(2e-4)\n", | |
"niters = 10_000\n", | |
"losses = []\n", | |
"# Initial F\n", | |
"@info F(rng, model, trans_base, num_samples=10)\n", | |
"timeused = @elapsed for iter = 1:niters\n", | |
" iter % 10 == 0 && print(\".\")\n", | |
"\n", | |
" loss = F(rng, model, trans_base, num_samples=10)\n", | |
" gs = Tracker.gradient(() -> loss, Phi)\n", | |
" for p in ϕ\n", | |
" Tracker.update!(opt, p, gs[p])\n", | |
" end\n", | |
" append!(losses, loss.data)\n", | |
" \n", | |
" if iter % 100 == 0\n", | |
" mean_loss = mean(losses[end-99:end])\n", | |
" println(\"\\n($iter/$niters) done; loss=$mean_loss\")\n", | |
" end\n", | |
"end\n", | |
"println(\"done; $(timeused)s used\")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 371, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"image/png": "", | |
"text/plain": [ | |
"Figure(PyObject <Figure size 640x480 with 1 Axes>)" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"text/plain": [ | |
"1-element Array{PyCall.PyObject,1}:\n", | |
" PyObject <matplotlib.lines.Line2D object at 0x7f648f7ace80>" | |
] | |
}, | |
"execution_count": 371, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Moving Mean Loss Plot\n", | |
"using PyPlot\n", | |
"PyPlot.plot([mean(losses[i-99:i]) for i in 100:7347])\n", | |
"# PyPlot.plot(losses)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 372, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"32-element Array{TrackedArray{Float64,N,A} where A<:AbstractArray{Float64,N} where N,1}:\n", | |
" [-0.00539625 0.00162034; 0.00945505 -0.000361126] (tracked)\n", | |
" [-0.00731297; 0.00805593] (tracked) \n", | |
" [0.00160014] (tracked) \n", | |
" [-0.00177673] (tracked) \n", | |
" [0.000831536; -0.000520562] (tracked) \n", | |
" [-0.00307323; -0.00107566] (tracked) \n", | |
" [0.00131467; -0.00104153] (tracked) \n", | |
" [-0.00300304] (tracked) \n", | |
" [-0.00100636] (tracked) \n", | |
" [0.000581572] (tracked) \n", | |
" [-0.000491236; 2.52041e-5] (tracked) \n", | |
" [0.000513647; -0.00079435] (tracked) \n", | |
" [0.00075561; -0.00257502] (tracked) \n", | |
" ⋮ \n", | |
" [0.0001627] (tracked) \n", | |
" [-0.000169308] (tracked) \n", | |
" [-0.000202415; 0.000218317] (tracked) \n", | |
" [0.00332348; 0.000783873] (tracked) \n", | |
" [0.00102101; 0.000384939] (tracked) \n", | |
" [0.00212457] (tracked) \n", | |
" [0.000564511] (tracked) \n", | |
" [-0.000680569] (tracked) \n", | |
" [7.86032e-5; 5.91268e-5] (tracked) \n", | |
" [0.339627; -1.19814] (tracked) \n", | |
" [-0.000927946; 0.000120444] (tracked) \n", | |
" [-0.0133198] (tracked) " | |
] | |
}, | |
"execution_count": 372, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"loss = F(rng, model, trans_base, num_samples=10)\n", | |
"gs = Tracker.gradient(() -> loss, Phi)\n", | |
"[gs[ϕ[i]] for i in 1:length(ϕ)]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 373, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"9.890143206195424" | |
] | |
}, | |
"execution_count": 373, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"mean((flow(rand(rng, base, 10000)))[1,:].data)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 374, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"9.979383902117426" | |
] | |
}, | |
"execution_count": 374, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"mean((flow(rand(rng, base, 10000)))[2,:].data)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 375, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"image/png": "", | |
"text/plain": [ | |
"Figure(PyObject <Figure size 640x480 with 1 Axes>)" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"PyPlot.hist(flow(rand(rng, base, 10000))[1,:].data);" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"@webio": { | |
"lastCommId": null, | |
"lastKernelId": null | |
}, | |
"kernelspec": { | |
"display_name": "Julia 1.1.0", | |
"language": "julia", | |
"name": "julia-1.1" | |
}, | |
"language_info": { | |
"file_extension": ".jl", | |
"mimetype": "application/julia", | |
"name": "julia", | |
"version": "1.1.0" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment