Last active: September 4, 2022 18:00
-
-
Save dermesser/297be1597bef469ca2ebaf3508c71ba2 to your computer and use it in GitHub Desktop.
Super simple MNIST Variational Autoencoder
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| using MKL | |
| using BSON | |
| using Distributions, DistributionsAD | |
| using Flux | |
| using Images | |
| using LinearAlgebra | |
| using MLDatasets | |
| import TensorBoardLogger as TBL | |
| using Logging | |
| import Flux.Losses: kldivergence | |
| import Flux.MLUtils: DataLoader | |
"""
    load_mnist_dataset(kind=:train) -> Tuple{AbstractArray, Vector}

Load an MNIST split through `MLDatasets` and return its `(features, labels)` tuple.

# Arguments
- `kind`: `:train` for the training split, `:test` for the test split.

# Throws
- `ArgumentError` if `kind` is neither `:train` nor `:test`.
"""
function load_mnist_dataset(kind=:train)::Tuple{AbstractArray, Vector}
    if kind == :train
        return MLDatasets.MNIST.traindata()
    elseif kind == :test
        return MLDatasets.MNIST.testdata()
    end
    # Fail loudly here: falling through would yield `nothing`, which only surfaces
    # as an opaque conversion error against the declared return type.
    throw(ArgumentError("kind must be :train or :test, got $(repr(kind))"))
end
"""
    VAE(encoder_mean, encoder_cov, decoder_mean)

Bundle of the three networks that make up the variational autoencoder.

# Fields
- `encoder_mean`: maps a batch of inputs to per-sample latent means.
- `encoder_cov`: maps a batch of inputs to the per-sample diagonal of the latent
  covariance (the loss builds `diagm` of each column from it).
- `decoder_mean`: maps latent codes back to reconstructed inputs.

The fields are untyped (`Any`), so any callable can be stored in them.
"""
struct VAE
    encoder_mean::Any
    encoder_cov::Any
    decoder_mean::Any
end
"""
    encoder_cov_activate(x)

Output activation for the covariance encoder: `relu(x)^2 + 0.01`.

The result is never below `0.01`, so the predicted variances stay strictly positive.
The `0.01` floor is converted to the floating-point type of `x`, so a `Float32`
network produces `Float32` values instead of being silently promoted to `Float64`.
"""
function encoder_cov_activate(x)
    r = float(max(x, zero(x)))  # relu(x), as a float of x's precision
    return r^2 + oftype(r, 1e-2)
end
"""
    save(name::String, vae::VAE)

Serialize `vae` with BSON to the file at path `name`, stored under the key `:vae`.
"""
function save(name::String, vae::VAE)
    # `BSON.@save` evaluates its first argument, so the caller's path is honoured
    # (instead of always writing a fixed "vae.bson").
    BSON.@save name vae
end
"""
    load(::Type{VAE}, name::String) -> VAE

Read a `VAE` from the BSON file at path `name`, taking the value stored under the
key `:vae` (the key `save(name, vae)` writes).
"""
function load(v::Type{VAE}, name::String)::VAE
    # Read from the caller-supplied path rather than a fixed "vae.bson".
    BSON.@load name vae
    return vae
end
"""
    prepare_data(X)

Insert a singleton dimension at position 3: an `A×B×N` batch becomes `A×B×1×N`
(and a single `A×B` array becomes `A×B×1`). This is the channel slot that Flux's
`Conv` layers read. The result is a reshape and shares memory with `X`.
"""
function prepare_data(X)
    dims = size(X)
    lead = dims[1:2]
    rest = dims[3:end]
    return reshape(X, lead..., 1, rest...)
end
"""
    restore_image(decoder_output; imsize=(28, 28))

Reshape a flat decoder output of size `prod(imsize)×N` into an image batch of size
`imsize[1]×imsize[2]×N`. A vector input is treated as a batch of one.

# Keywords
- `imsize`: spatial size of each image. Defaults to MNIST's 28×28, so existing
  callers (e.g. as the last stage of a `Chain`) are unaffected.
"""
function restore_image(decoder_output; imsize::Tuple{Integer,Integer}=(28, 28))
    return reshape(decoder_output, imsize[1], imsize[2], size(decoder_output, 2))
end
"""
    create_networks(X; latent_dim=8) -> VAE

Build a fully connected VAE sized for images whose spatial size is the first two
dimensions of `X`.

- `encoder_mean`: flatten → 64 → 16 → `latent_dim`, sigmoid hidden activations,
  linear output.
- `encoder_cov`: same shape, but the output goes through `encoder_cov_activate`,
  so every predicted variance is at least 0.01.
- `decoder_mean`: `latent_dim` → 16 → 64 → pixels with a sigmoid output, then
  reshaped into images by `restore_image`.
"""
function create_networks(X; latent_dim=8)
    indims = size(X)
    # Derive the pixel count from the data for both encoder input and decoder output
    # so the two sides agree (784 for 28×28 MNIST).
    npixels = indims[1] * indims[2]
    encoder_means = Chain(
        Flux.flatten,
        Dense(npixels => 64, Flux.σ),
        Dense(64 => 16, Flux.σ),
        Dense(16 => latent_dim))
    encoder_cov = Chain(
        Flux.flatten,
        Dense(npixels => 64, Flux.σ),
        Dense(64 => 16, Flux.σ),
        Dense(16 => latent_dim, encoder_cov_activate))
    # NOTE(review): the first two decoder layers have no activation, so together they
    # compose to a single linear map; kept as-is to preserve the architecture.
    # `restore_image` reshapes the output to 28×28 images.
    decoder_means = Chain(
        Dense(latent_dim => 16),
        Dense(16 => 64),
        Dense(64 => npixels, Flux.σ),
        restore_image)
    return VAE(encoder_means, encoder_cov, decoder_means)
end
"""
    create_conv_networks(X; latent_dim=8) -> VAE

Build a VAE whose two encoders start with a single 4×4 convolution (1 input
channel, 1 output channel), followed by dense layers → 16 → `latent_dim`. The
covariance encoder ends in `encoder_cov_activate`. The decoder is the dense stack
`latent_dim` → 16 → 64 → pixels with a sigmoid output, reshaped by `restore_image`.
Image sizes are taken from the first two dimensions of `X`.
"""
function create_conv_networks(X; latent_dim=8)
    indims = size(X)
    npixels = indims[1] * indims[2]
    kernel = (4, 4)
    # An unpadded, stride-1 convolution shrinks each spatial axis by (kernel - 1):
    # 28×28 → 25×25 for MNIST, i.e. 625 features after flattening.
    conv_features = (indims[1] - kernel[1] + 1) * (indims[2] - kernel[2] + 1)
    encoder_means = Chain(
        prepare_data,  # insert the channel dimension Conv expects
        Conv(kernel, 1 => 1),
        Flux.flatten,
        Dense(conv_features => 16, Flux.σ),
        Dense(16 => latent_dim))
    encoder_cov = Chain(
        prepare_data,
        Conv(kernel, 1 => 1),
        Flux.flatten,
        Dense(conv_features => 16, Flux.σ),
        Dense(16 => latent_dim, encoder_cov_activate))
    # `restore_image` reshapes the decoder output to 28×28 images.
    decoder_means = Chain(
        Dense(latent_dim => 16),
        Dense(16 => 64),
        Dense(64 => npixels, Flux.σ),
        restore_image)
    return VAE(encoder_means, encoder_cov, decoder_means)
end
# log-determinant that also accepts a bare `UniformScaling` (e.g. `I` for a
# standard-normal reference distribution), where det is evaluated directly.
_logdet(A) = logdet(A)
_logdet(J::UniformScaling) = log(det(J))

"""
    normal_kld(µ1, Σ1, µ2, Σ2)

Kullback–Leibler divergence KL(N(µ1, Σ1) ‖ N(µ2, Σ2)) between two multivariate
normal distributions with mean vectors `µ1`, `µ2` and covariance matrices `Σ1`, `Σ2`:

    ½ [tr(Σ2⁻¹ Σ1) − k + (µ2 − µ1)ᵀ Σ2⁻¹ (µ2 − µ1) + log det Σ2 − log det Σ1]

where `k = length(µ1)`. `Σ2` may also be `I`.

# Throws
- `DimensionMismatch` if `µ1` and `µ2` have different sizes.
"""
function normal_kld(µ1, Σ1, µ2, Σ2)
    size(µ1) == size(µ2) ||
        throw(DimensionMismatch("mean sizes differ: $(size(µ1)) vs $(size(µ2))"))
    k = length(µ1)
    Δ = µ2 - µ1
    # Solve against Σ2 instead of forming its inverse, and subtract log-determinants
    # instead of taking log(det(Σ2) / det(Σ1)): raw determinants underflow/overflow
    # quickly as k grows, which turns the ratio into NaN or Inf.
    trace_term = tr(Σ2 \ Σ1)
    quad_term = dot(Δ, Σ2 \ Δ)
    return (trace_term - k + quad_term + _logdet(Σ2) - _logdet(Σ1)) / 2
end
"""
    vae_loss(X, vae, ζs, c::Real)

VAE training loss for the batch `X` (samples along the last axis), with a
standard-normal prior on the latent code.

- `vae`: anything with callable fields `encoder_mean`, `encoder_cov`, and
  `decoder_mean`. `encoder_cov` returns the per-sample diagonal of the latent
  covariance (i.e. variances).
- `ζs`: pre-sampled standard-normal noise, one column per sample, used for the
  reparameterization `z = mean + sqrt(var) .* ζ`.
- `c`: scale of the squared-error reconstruction term.

Returns

    sum((X - decoder_mean(z))²) / (2c) + Σ_samples KL(N(mean, diag(var)) ‖ N(0, I)).
"""
function vae_loss(X, vae, ζs, c::Real)
    means = vae.encoder_mean(X)
    vars = vae.encoder_cov(X)
    # The encoder's second head is a variance (it is the covariance diagonal in the
    # KL term), so the noise must be scaled by its square root, the std. deviation.
    z = sqrt.(vars) .* ζs .+ means
    reconstruction = sum(abs2, X .- vae.decoder_mean(z)) / (2 * c)
    # Closed form of KL(N(m, diag(v)) ‖ N(0, I)) = ½ Σᵢ (vᵢ + mᵢ² − 1 − log vᵢ),
    # summed over all latent dimensions and all samples in the batch.
    kl = sum(vars .+ abs2.(means) .- 1 .- log.(vars)) / 2
    return reconstruction + kl
end
"""
    fill_param_dict!(dict, m, prefix)

Store the fields of model `m` into `dict`, keyed by `prefix * fieldname`.

A `Chain` is walked recursively. Its `n`-th layer is recorded with the prefix
`prefix * "layer_<n>/<string(layer)>/"`. For any other object, every field is stored
under its name; array-valued fields are flattened with `vec` first, and all other
values are stored as-is. Returns `nothing`; the results live in `dict`.
"""
function fill_param_dict!(dict, m, prefix)
    if m isa Chain
        for (n, sublayer) in enumerate(m.layers)
            subprefix = string(prefix, "layer_", n, "/", sublayer, "/")
            fill_param_dict!(dict, sublayer, subprefix)
        end
        return nothing
    end
    for fname in fieldnames(typeof(m))
        fval = getfield(m, fname)
        dict[prefix * string(fname)] = fval isa AbstractArray ? vec(fval) : fval
    end
    return nothing
end
"""
    train(; previous_vae=nothing, epochs=100, lr=1e-3, latent_dim=8, c=.1,
          batchsize=8, shuffle=true) -> VAE

Train a VAE on the MNIST training split with Adam and return it.

# Keywords
- `previous_vae`: a `VAE` to continue training. When `nothing`, a fresh model is
  built with `create_networks`. Its latent dimension must equal `latent_dim`,
  which also sizes the reparameterization noise.
- `epochs`: number of passes over the training data.
- `lr`: Adam learning rate.
- `latent_dim`: latent dimension of a new model and of the noise distribution.
- `c`: reconstruction-term scale passed to `vae_loss`.
- `batchsize`, `shuffle`: `DataLoader` settings. `batchsize` is also the number of
  test images used for the printed progress loss.

Progress is printed to stdout. Whenever the running sample count within an epoch
is a multiple of 10000, all layer fields and the train/test losses (on
`5 * batchsize` images) are logged to a TensorBoard logger in the directory "log",
opened with `tb_overwrite`.
"""
function train(; previous_vae=nothing, epochs=100, lr=1e-3, latent_dim=8, c=.1,
               batchsize=8, shuffle=true)
    logger = TBL.TBLogger("log", TBL.tb_overwrite)
    X, _ = load_mnist_dataset(:train)
    valX, _ = load_mnist_dataset(:test)
    ntest = batchsize
    testX = @view valX[:, :, 1:ntest]

    # Single assignment: `vae` is captured by the closures below, and a variable
    # assigned in several places tends to be boxed when captured.
    vae = if isnothing(previous_vae)
        create_networks(X; latent_dim=latent_dim)
    else
        previous_vae
    end

    # Standard-normal noise source for the reparameterization trick.
    reparam_ζ = MvNormal(zeros(latent_dim), LinearAlgebra.I)
    loss(x, ζs) = vae_loss(x, vae, ζs, c)
    # Loss on `x` with fresh noise, one noise column per sample (last axis of `x`).
    noisy_loss(x) = loss(x, rand(reparam_ζ, size(x, ndims(x))))

    data = DataLoader(X; batchsize=batchsize, shuffle=shuffle)
    opt = Flux.Optimise.Adam(lr)
    p = Flux.params(vae.encoder_mean, vae.encoder_cov, vae.decoder_mean)

    # Log every layer field plus train/test losses to TensorBoard.
    function tbcallback()
        paramd = Dict{String, Any}()
        # Trailing "/" so tags read "enc_mean/layer_1/..." rather than
        # "enc_meanlayer_1/...".
        fill_param_dict!(paramd, vae.encoder_mean, "enc_mean/")
        fill_param_dict!(paramd, vae.encoder_cov, "enc_cov/")
        fill_param_dict!(paramd, vae.decoder_mean, "dec_mean/")
        with_logger(logger) do
            @info "model" params=paramd log_step_increment=0
            @info "train" loss=noisy_loss(view(X, :, :, 1:5ntest))
            @info "test" loss=noisy_loss(view(valX, :, :, 1:5ntest))
        end
    end

    for epoch in 1:epochs
        println("\n epoch $epoch: loss $(noisy_loss(testX))")
        print(" > ")
        j = 0  # samples consumed so far in this epoch
        for d in data
            j % 100 == 0 && print("$j ")
            j % 1000 == 0 && print("(loss: $(noisy_loss(testX))) ")
            j % 10000 == 0 && tbcallback()
            j += batchsize
            # Reparameterization trick: noise is drawn outside the gradient, so it is
            # a constant w.r.t. the parameters. Size it from the batch itself. The
            # last batch of an epoch is smaller whenever batchsize does not divide
            # the dataset size, and a fixed `batchsize` columns would then fail to
            # broadcast against the encoder outputs.
            ζs = rand(reparam_ζ, size(d, ndims(d)))
            grads = Flux.gradient(() -> loss(d, ζs), p)
            Flux.Optimise.update!(opt, p, grads)
        end
    end
    return vae
end
"""
    sample_image(inspiration, vae; n=8)

Encode the single image `inspiration` with `vae` and decode `n` latent samples
drawn from the encoder's Gaussian `N(mean, diagm(cov))` for that image.

Writes the inspiration to "inspiration.png" and the samples to "test1.png" …
"test<n>.png" as 8-bit grayscale. Pixel values must therefore be representable as
`N0f8`, i.e. lie in [0, 1]. Each sample is reshaped to the inspiration's size.
"""
function sample_image(inspiration, vae; n=8)
    imsize = size(inspiration)
    # Append a batch dimension of 1 for the encoders.
    batch = Flux.unsqueeze(inspiration, 3)
    means = vae.encoder_mean(batch)
    covs = vae.encoder_cov(batch)
    dist = MvNormal(means[:, 1], diagm(covs[:, 1]))
    # Save the image itself rather than its batch-of-one wrapper, so the PNG writer
    # receives the same kind of array as for the samples below.
    Images.save("inspiration.png", Gray{N0f8}.(inspiration))
    for i in 1:n
        z = rand(dist, 1)
        # Decoder output carries a trailing batch dimension of 1; drop it by
        # reshaping back to the inspiration's size (28×28 for MNIST).
        Images.save("test$(i).png", Gray{N0f8}.(reshape(vae.decoder_mean(z), imsize)))
    end
end
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.