Skip to content

Instantly share code, notes, and snippets.

@dermesser
Last active September 4, 2022 18:00
Show Gist options
  • Select an option

  • Save dermesser/297be1597bef469ca2ebaf3508c71ba2 to your computer and use it in GitHub Desktop.

Select an option

Save dermesser/297be1597bef469ca2ebaf3508c71ba2 to your computer and use it in GitHub Desktop.
Super simple MNIST Variational Autoencoder
using MKL
using BSON
using Distributions, DistributionsAD
using Flux
using Images
using LinearAlgebra
using MLDatasets
import TensorBoardLogger as TBL
using Logging
import Flux.Losses: kldivergence
import Flux.MLUtils: DataLoader
"""
    load_mnist_dataset(kind=:train)

Load one split of the MNIST dataset through `MLDatasets`.

`kind` selects the split: `:train` calls `MLDatasets.MNIST.traindata()` and `:test`
calls `MLDatasets.MNIST.testdata()`. The result is converted to
`Tuple{AbstractArray, Vector}` (presumably `(images, labels)` — confirm against the
installed MLDatasets version).

# Throws
- `ArgumentError` if `kind` is neither `:train` nor `:test`.
"""
function load_mnist_dataset(kind=:train)::Tuple{AbstractArray, Vector}
    if kind == :train
        return MLDatasets.MNIST.traindata()
    elseif kind == :test
        return MLDatasets.MNIST.testdata()
    end
    # An unknown split used to fall through and yield `nothing`, which then failed in
    # the return-type conversion with an error that did not mention `kind` at all.
    throw(ArgumentError("kind must be :train or :test, got $(repr(kind))"))
end
# Bundle of the three networks that make up the variational autoencoder.
# The fields are untyped; each one is invoked like a function elsewhere in this
# file (e.g. `vae.encoder_mean(X)`), so any callable can be stored.
struct VAE
# Maps an input batch to the per-sample latent means.
encoder_mean
# Maps an input batch to the per-sample diagonal covariance entries.
encoder_cov
# Maps latent samples back to reconstructed inputs.
decoder_mean
end
"""
    encoder_cov_activate(x)

Activation used on the last layer of the covariance encoder: `relu(x)^2 + 0.01`.

The constant offset keeps the result at or above `0.01`, so the value is always
strictly positive.
"""
function encoder_cov_activate(x)
    # Keep the offset in the floating-point type of `x`. A bare `1e-2` literal is a
    # Float64 and would silently promote Float32 activations to Float64.
    offset = oftype(x / 1, 1e-2)
    return Flux.relu(x)^2 + offset
end
"""
    save(name::String, vae::VAE)

Serialize `vae` with BSON into the file at path `name`, stored under the key `vae`.
"""
function save(name::String, vae::VAE)
    # The path used to be hard-coded to "vae.bson", which ignored `name`.
    return BSON.@save name vae
end

"""
    load(v::Type{VAE}, name::String)

Read the BSON file at path `name` and return the model stored under the key `vae`
(the layout written by `save`). The first argument only selects this method.
"""
function load(v::Type{VAE}, name::String)::VAE
    # Read from the requested path instead of the formerly hard-coded "vae.bson".
    BSON.@load name vae
    return vae
end
"""
    prepare_data(X)

Insert a singleton dimension at position 3 of `X` via `Flux.unsqueeze`.

Used as the first layer of the convolutional encoders, presumably to add the
channel axis that `Conv` expects — confirm against the input layout.
"""
prepare_data(X) = Flux.unsqueeze(X, 3)
"""
    restore_image(decoder_output)

Reshape the decoder output into a `28 x 28 x N` array, where `N` is the size of the
second dimension of `decoder_output`.
"""
function restore_image(decoder_output)
    nsamples = size(decoder_output, 2)
    return reshape(decoder_output, (28, 28, nsamples))
end
"""
    create_networks(X; latent_dim=8)

Build a fully connected `VAE` with freshly initialized layers.

Both encoders flatten their input and pass it through `784 => 64 => 16 => latent_dim`
dense layers; the covariance encoder ends in `encoder_cov_activate`. The decoder maps
`latent_dim => 16 => 64 => size(X, 1) * size(X, 2)` and reshapes the result with
`restore_image`. `X` is only used for its size.
"""
function create_networks(X; latent_dim=8)
    shape = size(X)

    # Encoder for the latent means (linear output layer).
    mean_net = Chain(
        Flux.flatten,
        Dense(784 => 64, Flux.σ),
        Dense(64 => 16, Flux.σ),
        Dense(16 => latent_dim),
    )

    # Encoder for the latent covariance entries (strictly positive output).
    cov_net = Chain(
        Flux.flatten,
        Dense(784 => 64, Flux.σ),
        Dense(64 => 16, Flux.σ),
        Dense(16 => latent_dim, encoder_cov_activate),
    )

    # Decoder from latent space back to image space.
    decoder_net = Chain(
        Dense(latent_dim => 16),
        Dense(16 => 64),
        Dense(64 => shape[1] * shape[2], Flux.σ),
        restore_image,
    )

    return VAE(mean_net, cov_net, decoder_net)
end
"""
    create_conv_networks(X; latent_dim=8)

Build a `VAE` whose encoders start with a single `4 x 4` convolution.

Each encoder applies `prepare_data`, a `Conv((4,4), 1 => 1)` layer, flattens, and then
uses `25^2 => 16 => latent_dim` dense layers (the `25^2` width presumably assumes
28 x 28 inputs — confirm). The covariance encoder ends in `encoder_cov_activate`. The
decoder maps `latent_dim => 16 => 64 => size(X, 1) * size(X, 2)` and reshapes the
result with `restore_image`. `X` is only used for its size.
"""
function create_conv_networks(X; latent_dim=8)
    shape = size(X)

    # Convolutional encoder for the latent means.
    mean_net = Chain(
        prepare_data,
        Conv((4, 4), 1 => 1),
        Flux.flatten,
        Dense(25^2 => 16, Flux.σ),
        Dense(16 => latent_dim),
    )

    # Convolutional encoder for the latent covariance entries.
    cov_net = Chain(
        prepare_data,
        Conv((4, 4), 1 => 1),
        Flux.flatten,
        Dense(25^2 => 16, Flux.σ),
        Dense(16 => latent_dim, encoder_cov_activate),
    )

    # Fully connected decoder, same layout as in `create_networks`.
    decoder_net = Chain(
        Dense(latent_dim => 16),
        Dense(16 => 64),
        Dense(64 => shape[1] * shape[2], Flux.σ),
        restore_image,
    )

    return VAE(mean_net, cov_net, decoder_net)
end
"""
    normal_kld(µ1, Σ1, µ2, Σ2)

Compute the Kullback-Leibler divergence `KL(N(µ1, Σ1) || N(µ2, Σ2))` between two
multivariate normal distributions:

    1/2 * (tr(Σ2⁻¹ Σ1) - k + (µ2 - µ1)' Σ2⁻¹ (µ2 - µ1) + log(det Σ2 / det Σ1))

where `k = length(µ1)`.

# Arguments
- `µ1`, `µ2`: mean vectors of equal size.
- `Σ1`, `Σ2`: covariance matrices of the first and second distribution.

# Throws
- `DimensionMismatch` if `µ1` and `µ2` differ in size.
"""
function normal_kld(µ1, Σ1, µ2, Σ2)
    size(µ1) == size(µ2) || throw(DimensionMismatch(
        "mean vectors must have equal size, got $(size(µ1)) and $(size(µ2))"))
    k = length(µ1)
    δ = µ2 - µ1
    # Solve against Σ2 instead of forming its explicit inverse.
    trace_term = tr(Σ2 \ Σ1)
    quad_term = dot(δ, Σ2 \ δ)
    # Work with log-determinants directly: the plain determinants under- or overflow
    # for high-dimensional or small-variance covariances, turning the ratio into NaN.
    logdet_term = logdet(Σ2) - logdet(Σ1)
    return (trace_term - k + quad_term + logdet_term) / 2
end
"""
    vae_loss(X, vae, ζs, c::Real)

Compute the training loss of `vae` on the batch `X`.

The loss is the sum of two terms:
- the squared reconstruction error between `X` and the decoded latent sample,
  divided by `2c`;
- the KL divergence from each sample's encoded normal distribution (mean from
  `vae.encoder_mean`, diagonal covariance from `vae.encoder_cov`) to the standard
  normal prior, summed over the batch.

# Arguments
- `X`: input batch.
- `vae`: model providing `encoder_mean`, `encoder_cov` and `decoder_mean`.
- `ζs`: noise used for the reparameterization, one column per sample; it must be
  broadcast-compatible with the encoder outputs.
- `c`: scale of the reconstruction term (the error is divided by `2c`).
"""
function vae_loss(X, vae, ζs, c::Real)
    means = vae.encoder_mean(X)
    covs = vae.encoder_cov(X)
    # `covs` holds variances: it is used as the diagonal covariance in the KL term
    # below. The noise must therefore be scaled by the standard deviation, i.e. the
    # square root, and not by the variance itself.
    z = ζs .* sqrt.(covs) .+ means
    k = size(means, 1)
    reconstruction = sum((X .- vae.decoder_mean(z)).^2) / (2 * c)
    divergence = sum(
        Distributions.kldivergence(
            MvNormal(m, diagm(co)),
            MvNormal(zeros(k), I))
        for (m, co) in zip(eachcol(means), eachcol(covs)))
    return reconstruction + divergence
end
"""
    fill_param_dict!(dict, m, prefix)

Record the fields of model `m` in `dict`, using keys that start with `prefix`.

A `Chain` is traversed recursively; each layer extends the prefix with
`"layer_<index>/<string(layer)>/"`. For anything else, every field of `m` is stored
under `prefix * fieldname`: arrays are flattened with `vec`, all other values are
stored unchanged. Returns `nothing`.
"""
function fill_param_dict!(dict, m, prefix)
    # Leaf case: store every field of the object directly.
    if !(m isa Chain)
        for name in fieldnames(typeof(m))
            field = getfield(m, name)
            dict[prefix * string(name)] = field isa AbstractArray ? vec(field) : field
        end
        return nothing
    end

    # Chain case: descend into each layer with an extended prefix.
    for (idx, layer) in enumerate(m.layers)
        fill_param_dict!(dict, layer, "$(prefix)layer_$(idx)/$(layer)/")
    end
    return nothing
end
"""
    train(; previous_vae=nothing, epochs=100, lr=1e-3, latent_dim=8, c=.1,
          batchsize=8, shuffle=true)

Train a `VAE` on the MNIST training split and return it.

Progress is printed to standard output, and parameters plus train/test losses are
written periodically to a TensorBoard log in the directory `"log"` (overwriting an
existing log).

# Keywords
- `previous_vae`: model to continue training; when `nothing`, a new one is built with
  `create_networks`.
- `epochs`: number of passes over the training data.
- `lr`: learning rate passed to the Adam optimiser.
- `latent_dim`: latent dimension, used for the noise distribution and for a newly
  created model. NOTE(review): a supplied `previous_vae` must already use this latent
  dimension — this is not checked here.
- `c`: reconstruction scale forwarded to `vae_loss`.
- `batchsize`, `shuffle`: forwarded to the `DataLoader`; `batchsize` also sets the
  number of test images used for the printed loss.
"""
function train(;previous_vae=nothing, epochs=100, lr=1e-3, latent_dim=8, c=.1, batchsize=8, shuffle=true)
    logger = TBL.TBLogger("log", TBL.tb_overwrite)
    X, _ = load_mnist_dataset(:train)
    valX, _ = load_mnist_dataset(:test)

    # Sample counts for the monitoring losses, clamped so that very large batch sizes
    # cannot index past the end of the data.
    ntest = min(batchsize, size(valX, 3))
    testX = @view valX[:, :, 1:ntest]
    nlog_train = min(5ntest, size(X, 3))
    nlog_test = min(5ntest, size(valX, 3))

    vae = isnothing(previous_vae) ? create_networks(X; latent_dim=latent_dim) : previous_vae

    # Standard normal noise source for the reparameterization trick.
    reparam_ζ = MvNormal(zeros(latent_dim), LinearAlgebra.I)
    loss(batch, ζs) = vae_loss(batch, vae, ζs, c)

    data = DataLoader(X; batchsize=batchsize, shuffle=shuffle)
    opt = Flux.Optimise.Adam(lr)
    p = Flux.params(vae.encoder_mean, vae.encoder_cov, vae.decoder_mean)

    # Write the current parameters and the train/test losses to TensorBoard.
    tbcallback() = begin
        paramd = Dict{String, Any}()
        # The trailing "/" keeps the network name separate from the layer path;
        # without it the keys ran together as e.g. "enc_meanlayer_1/...".
        fill_param_dict!(paramd, vae.encoder_mean, "enc_mean/")
        fill_param_dict!(paramd, vae.encoder_cov, "enc_cov/")
        fill_param_dict!(paramd, vae.decoder_mean, "dec_mean/")
        with_logger(logger) do
            @info "model" params=paramd log_step_increment=0
            @info "train" loss=loss((@view X[:, :, 1:nlog_train]), rand(reparam_ζ, nlog_train))
            @info "test" loss=loss((@view valX[:, :, 1:nlog_test]), rand(reparam_ζ, nlog_test))
        end
    end

    for epoch in 1:epochs
        println("\n epoch $epoch: loss $(loss(testX, rand(reparam_ζ, ntest)))")
        print(" > ")
        j = 0  # number of samples processed so far in this epoch
        for d in data
            if j % 100 == 0
                print("$j ")
            end
            if j % 1000 == 0
                print("(loss: $(loss(testX, rand(reparam_ζ, ntest)))) ")
            end
            if j % 10000 == 0
                tbcallback()
            end
            # The final batch can be shorter than `batchsize`; size the noise from the
            # batch itself so it always matches the encoder output.
            nbatch = size(d, ndims(d))
            j += nbatch
            # Reparameterization trick
            ζs = rand(reparam_ζ, nbatch)
            grads = Flux.gradient(() -> loss(d, ζs), p)
            Flux.Optimise.update!(opt, p, grads)
        end
    end
    return vae
end
"""
    sample_image(inspiration, vae; n=8)

Encode `inspiration` with `vae`, draw `n` latent samples from the resulting normal
distribution, and write the decoded images to disk.

The input is saved as `"inspiration.png"` and the samples as `"test1.png"` through
`"test<n>.png"` in the current working directory. `inspiration` is presumably a single
28 x 28 image — confirm against the caller. Returns `nothing`.
"""
function sample_image(inspiration, vae; n=8)
    # Add a trailing singleton dimension so the encoders see a one-element batch.
    batched = Flux.unsqueeze(inspiration, 3)
    latent_means = vae.encoder_mean(batched)
    latent_covs = vae.encoder_cov(batched)
    posterior = MvNormal(latent_means[:, 1], diagm(latent_covs[:, 1]))

    Images.save("inspiration.png", Gray{N0f8}.(batched))
    for idx in 1:n
        z = rand(posterior, 1)
        decoded = reshape(vae.decoder_mean(z), 28, 28)
        Images.save("test$(idx).png", Gray{N0f8}.(decoded))
    end
    return nothing
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment