@dermesser
Created February 26, 2022 12:50
MNIST autoencoder network with ~251k weights in Flux.jl, running on CPU and GPU.
using MKL
using CUDA
import Flux
import MLDatasets
import Images
import BSON: @load, @save
import NNlib
using Dates

# Log a message prefixed with the current timestamp.
function logn(args...)
    println(now(), " ", args...)
end
# MNIST image dimensions (not used below; kept for reference).
DIMS = (28, 28)

# Build the convolutional autoencoder: the encoder compresses a 28x28x1 image
# into a 4-dimensional latent code, and the decoder reconstructs a 28x28x1 image.
function build_model()
    encoder = Flux.Chain(
        Flux.Conv((4, 4), 1 => 16, Flux.relu, stride=2, pad=Flux.SamePad()),   # 28x28x1  -> 14x14x16
        Flux.Conv((4, 4), 16 => 16, Flux.relu, stride=2, pad=Flux.SamePad()),  # 14x14x16 -> 7x7x16
        Flux.Conv((4, 4), 16 => 16, Flux.relu, stride=1, pad=Flux.SamePad()),  # 7x7x16   -> 7x7x16
        Flux.Conv((4, 4), 16 => 1, Flux.relu, stride=1, pad=0),                # 7x7x16   -> 4x4x1
        Flux.flatten,                                                          # 4x4x1    -> 16
        Flux.Dense(16, 4, Flux.sigmoid))                                       # 16       -> 4 (latent code)
    decoder = Flux.Chain(
        Flux.Dense(4, 36, Flux.relu),                                          # 4        -> 36
        x -> reshape(x, 6, 6, 1, size(x, 2)),                                  # 36       -> 6x6x1
        Flux.ConvTranspose((3, 3), 1 => 32, Flux.relu),                        # 6x6x1    -> 8x8x32
        Flux.ConvTranspose((3, 3), 32 => 32, Flux.relu),                       # 8x8x32   -> 10x10x32
        #Flux.ConvTranspose((3, 3), 64 => 64, Flux.relu),
        Flux.Conv((3, 3), 32 => 16, Flux.relu),                                # 10x10x32 -> 8x8x16
        Flux.Conv((3, 3), 16 => 8, Flux.relu),                                 # 8x8x16   -> 6x6x8
        Flux.flatten,                                                          # 6x6x8    -> 288
        Flux.Dense(288, 784, Flux.sigmoid),                                    # 288      -> 784
        #Flux.Dense(784, 784, Flux.sigmoid),
        x -> reshape(x, 28, 28, 1, size(x, 2)))                                # 784      -> 28x28x1
    Flux.Chain(encoder, decoder) |> Flux.gpu
end
# Move the model to the CPU and save encoder and decoder to model.bson.
function save_model(m)
    m = Flux.cpu(m)
    encoder = m[1]
    decoder = m[2]
    @save "model.bson" encoder decoder
end

# Load the model from model.bson if it exists, otherwise build a fresh one.
function load_model()
    if isfile("model.bson")
        logn("Loading model from file...")
        @load "model.bson" encoder decoder
        return Flux.Chain(encoder, decoder) |> Flux.gpu
    end
    logn("No model.bson found: creating new model")
    return build_model()
end
# Reconstruction loss: mean squared error between the model output and its input.
function make_loss(model)::Function
    return (data) -> Flux.Losses.mse(model(data), data)
end

# Train with ADAM over mini-batches of 32; log the loss over the full dataset
# at most once per minute and checkpoint the model after every epoch.
function train(model, data, loss, epochs=100)
    params = Flux.params(model)
    cb() = logn("Current loss: $(loss(data))")
    dl = Flux.DataLoader(data, batchsize=32, shuffle=true, partial=true)
    for i in 1:epochs
        logn("Epoch $i")
        Flux.train!(loss, params, dl, Flux.ADAM(1e-4), cb=Flux.throttle(cb, 60))
        save_model(model)
    end
end
function main()
    m = load_model() |> Flux.gpu
    l = make_loss(m)
    # Use the MNIST test split (10k images) as training data for the autoencoder.
    data, _ = MLDatasets.MNIST.testdata()
    logn(size(data))
    # Add a singleton channel dimension: 28x28x10000 -> 28x28x1x10000.
    data = Flux.gpu(convert.(Float32, Flux.unsqueeze(data, 3)))
    train(m, data, l)
end

main()
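
# Optional usage sketch (not part of the original gist): render an MNIST digit
# next to its reconstruction after training has produced a model.bson. The
# helper name `show_reconstruction` is hypothetical; it reuses the Images
# package imported at the top and the same MLDatasets `testdata()` call as
# main(), and is meant to be called interactively after running this script.
function show_reconstruction(idx=1)
    m = load_model() |> Flux.cpu
    imgs, _ = MLDatasets.MNIST.testdata()
    x = convert.(Float32, Flux.unsqueeze(imgs[:, :, idx:idx], 3))  # 28x28x1x1
    y = m(x)
    # Transpose so the digit is upright; original on the left, reconstruction on the right.
    Images.Gray.(hcat(x[:, :, 1, 1]', y[:, :, 1, 1]'))
end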