
@vankesteren
Created November 13, 2019 07:46
Sparse coding for MNIST feature extraction using autodiff
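The script's objective and its update on H can be written out explicitly. The objective is the sparse-coding criterion the code cites (Goodfellow, p. 629, eq. 19.16), divided by L = length(X) so that it becomes a per-pixel average. Here H is the 60000 × M code matrix and W is the M × 784 dictionary. The derivative of |H_ij| is taken as sign(H_ij), which only holds away from H_ij = 0. The last line shows why the step size s is scaled by L: the scaling cancels the 1/L in the loss.

```latex
\begin{aligned}
J(H, W) &= \frac{1}{L}\Big(\sum_{i,j} |H_{ij}| + \sum_{i,j} (X - HW)_{ij}^2\Big), \qquad L = \operatorname{length}(X) \\
\frac{\partial J}{\partial H} &= \frac{1}{L}\Big(\operatorname{sign}(H) - 2\,(X - HW)\,W^\top\Big) \\
H - s\,\frac{\partial J}{\partial H} &= H - 10^{-2}\Big(\operatorname{sign}(H) - 2\,(X - HW)\,W^\top\Big) \quad \text{for } s = 10^{-2} L
\end{aligned}
```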
# Sparse coding for MNIST feature extraction using autodiff
using Zygote: gradient
using MLDatasets: MNIST
using LinearAlgebra: Diagonal
using ImageCore
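# Sparse coding objective (Goodfellow et al., Deep Learning, p. 629, eq. 19.16),
# averaged per pixel: L1 penalty on the codes H plus squared reconstruction
# error of X ≈ H*W. Reads the globals X and L defined below.
function loss(H::Matrix{Float64}, W::Matrix{Float64})
    (sum(abs.(H)) + sum((X - H*W).^2)) / L
end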
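# data: one row per training image, one column per pixel (60000 × 784)
X = MNIST.convert2features(MNIST.traintensor())'
L = length(X) # number of pixel values; normalizes the loss per pixel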
# initialization
M = 20 # number of hidden features
H = randn(size(X, 1), M) # random codes, one row per image
W = H \ X                # least-squares dictionary (M × 784) given H
# initial loss
J = loss(H, W)
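# gradient descent for H & least squares with weight decay for W
s = 1.0e-2 * L # step size, scaled by L to undo the per-pixel averaging in loss
λ = 4.0e-4     # weight decay
ϵ = 1.0e-10    # stop once the loss improves by less than this
for i = 1:10_000
    global H, W, J
    H -= s .* gradient(loss, H, W)[1]  # gradient step on the codes
    W = H \ X .- λ .* W                # least-squares refit, minus λ times the previous W
    J_prev = J
    J = loss(H, W)
    if (J_prev - J) < ϵ
        break
    end
    println("Loss ", i, " | ", J)
end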
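# check features as images: row k of basis*W is 1000 × row k of W,
# i.e. the pixel pattern of feature k
basis = Diagonal(ones(M)) .* 1000 # one-hot code per feature, scaled by 1000
imgs = reshape(basis*W, M, 28, 28)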
for k in 1:M
    display(MNIST.convert2image(imgs[k, :, :]))
end
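The two updates inside the loop can be checked on a toy problem that needs neither Zygote nor MNIST. The sketch below uses hand-picked matrices and the hypothetical helpers `sc_loss`, `sc_grad_H` and `fd_grad_H`, none of which come from the gist. It does two things:

- It compares a hand-derived gradient of the per-pixel loss with respect to H against central finite differences.
- It confirms that `H \ X` agrees with the normal-equations least-squares solution when H is tall and has full column rank.

It covers only the plain least-squares part of the W update. The gist then subtracts `λ .* W` on top of that.

```julia
using LinearAlgebra

# Hypothetical toy data: 3 "images" of 4 pixels, M = 2 features.
H_toy = [0.5 -1.2; 0.3 0.8; -0.7 0.1]                        # 3×2 codes, no exact zeros
W_toy = [0.2 -0.4 0.6 0.1; -0.3 0.5 0.2 0.7]                 # 2×4 dictionary
X_toy = [0.1 0.9 0.0 0.4; 0.6 0.2 0.8 0.3; 0.5 0.5 0.1 0.9]  # 3×4 data

# Same objective as the gist's `loss`, but with X passed in
# instead of read from the globals X and L.
sc_loss(H, W, X) = (sum(abs.(H)) + sum((X - H * W) .^ 2)) / length(X)

# Hand-derived gradient w.r.t. H, valid where no entry of H is exactly zero:
# d/dH sum(abs.(H)) = sign.(H),  d/dH sum((X - H*W).^2) = -2 (X - H*W) W'
sc_grad_H(H, W, X) = (sign.(H) .- 2 .* (X - H * W) * W') ./ length(X)

# Central finite differences, perturbing one entry of H at a time.
function fd_grad_H(H, W, X; h = 1e-6)
    G = similar(H)
    for idx in eachindex(H)
        Hp = copy(H); Hp[idx] += h
        Hm = copy(H); Hm[idx] -= h
        G[idx] = (sc_loss(Hp, W, X) - sc_loss(Hm, W, X)) / (2h)
    end
    return G
end

# Check 1: the analytic gradient matches finite differences.
@assert isapprox(sc_grad_H(H_toy, W_toy, X_toy), fd_grad_H(H_toy, W_toy, X_toy); atol = 1e-6)

# Check 2: for a tall, full-column-rank H, `H \ X` is the least-squares
# solution, so it agrees with the normal equations (H'H) W = H'X.
@assert isapprox(H_toy \ X_toy, (H_toy' * H_toy) \ (H_toy' * X_toy); rtol = 1e-8)
```

Away from H_ij = 0, `sc_grad_H` is the quantity that `gradient(loss, H, W)[1]` computes in the gist. At exactly zero, the L1 term is not differentiable.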
@vankesteren (author) commented:

Feature number 13 looks like this in my run:

[image: feature 13 from the author's run, not preserved]
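Feature 13 is `imgs[13, :, :]`, which is row 13 of `basis*W` laid out as a 28 × 28 picture. This works because Julia's `reshape` is column-major. Reshaping an M × 784 matrix to M × 28 × 28 keeps each row together and arranges its entries in the same pixel order that the images were flattened in.

The sketch below checks this on a toy 4 × 4 tensor. It assumes that `convert2features` is a plain column-major reshape of the 28 × 28 × N tensor into 784 × N. The names `T`, `X_rows` and `imgs_toy` are illustrative and do not come from the gist.

```julia
# Hypothetical stand-in for the MNIST tensor: N = 3 images of 4×4 pixels.
N, side = 3, 4
T = reshape(collect(1.0:(side * side * N)), side, side, N)

# Flatten each image column-major and put one image per row, mirroring
# convert2features followed by a transpose (assumed to be a plain reshape).
X_rows = permutedims(reshape(T, side * side, N))   # N × 16

# View a one-row-per-item matrix as images, like reshape(basis*W, M, 28, 28).
imgs_toy = reshape(X_rows, N, side, side)

# Slice k recovers item k in its original pixel layout.
for k in 1:N
    @assert imgs_toy[k, :, :] == T[:, :, k]
    @assert imgs_toy[k, :, :] == reshape(X_rows[k, :], side, side)
end
```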
