Figuring out some ELBO stuff...
# Let's figure out this ELBO thing
using Distributions, StatsPlots, Optim, Random
Random.seed!(45)

# The target distribution. Assume we don't know it, but we can
# compute its (unnormalized) logpdf and sample from it.
# For illustration, let's make it a weird mixture.
comps = [Normal(2, 3), Normal(-3, 1.5), LogNormal(3, 0.4)]
probs = [.1, .1, .8]
p = MixtureModel(comps, probs)

# Let's take a look!
histogram(rand(p, 100_000), normalize = true, label = "p")

# A function for computing the ELBO from two distributions, using
# Monte Carlo integration as in equation 7 of the following paper:
# https://www.jmlr.org/papers/volume23/21-0889/21-0889.pdf
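# Concretely, with x_1, ..., x_K sampled from q, the estimator is
#   ELBO(q) = E_q[log p(x) - log q(x)] ≈ (1/K) * Σ_k [log p(x_k) - log q(x_k)]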
function ELBO(p::Distribution, q::Distribution; K::Int = 1000)
    x = rand(q, K) # sample from q for MC integration
    return mean(logpdf.(p, x) .- logpdf.(q, x))
end
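
# Sanity check (an addition, not in the original gist): when the target
# is properly normalized, ELBO(p, q) = -KL(q || p), so the MC estimate
# should match Distributions' closed-form KL for two normals up to noise.
p_check, q_check = Normal(0, 1), Normal(1, 2)
ELBO(p_check, q_check; K = 100_000)   # ≈ -kldivergence(q_check, p_check)
-kldivergence(q_check, p_check)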

# We can compute the ELBO for a particular candidate distribution
ELBO(p, Normal(0, 1))
ELBO(p, Normal(10, 8))
# We can see that Normal(10, 8) is the better approximation of p,
# since it attains a higher ELBO
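# Note that the estimate itself is stochastic: repeated calls return
# different values. A quick look at the Monte Carlo noise (an addition,
# not part of the original gist); the noise shrinks as K grows:
elbos = [ELBO(p, Normal(10, 8)) for _ in 1:100]
mean(elbos), std(elbos)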

# Now define a loss function to minimize: the negative ELBO of a
# Normal(μ, σ), with σ = exp(θ[2]) to keep the scale positive
loss = θ -> -ELBO(p, Normal(θ[1], exp(θ[2])))
res = optimize(loss, ones(2), SimulatedAnnealing())
θ = res.minimizer
q = Normal(θ[1], exp(θ[2]))

# Now let's plot!
plot!(q, label = "q (ELBO)", color = "green")
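
# The MC estimate of the loss is noisy, which is why the stochastic
# SimulatedAnnealing optimizer is used above. A common alternative,
# sketched here as an addition (not part of the original gist), is the
# reparameterization trick: fix the base noise ε ~ N(0, 1) once and
# write x = μ + σ * ε, which makes the loss deterministic in θ so that
# a smooth optimizer such as Nelder-Mead can be applied directly.
ε = randn(1000)                     # fixed base sample (common random numbers)
function loss_reparam(θ)
    μ, σ = θ[1], exp(θ[2])
    x = μ .+ σ .* ε                 # reparameterized draws from Normal(μ, σ)
    return -mean(logpdf.(p, x) .- logpdf.(Normal(μ, σ), x))
end
res2 = optimize(loss_reparam, ones(2), NelderMead())
q2 = Normal(res2.minimizer[1], exp(res2.minimizer[2]))
plot!(q2, label = "q (reparam)", color = "purple")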