Created
January 7, 2022 22:48
-
-
Save karlrohe/70eab491de35ca24e21b74bed3642df8 to your computer and use it in GitHub Desktop.
PCA on mnist handwritten 2's
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# PCA on n = 6990 images of handwritten 2's, each with d = 784 pixels.
# install.packages("remotes")
# remotes::install_github("jlmelville/snedata")
# thank you jlmelville for making this data so easy to access!
library(snedata)
library(magrittr)
library(Matrix)
library(rARPACK)

# get the data:
# download_mnist() comes from the snedata package; the rest of this script
# reads pixel values from columns 1:784 and the digit from the Label column.
mnist <- download_mnist()
# code to plot an image of a hand written digit.
#
# I fiddled with this code:
# source("https://gist.githubusercontent.com/brendano/39760/raw/22467aa8a5d104add5e861ce91ff5652c6b271b6/gistfile1.txt")
# thank you brendano on github!
#
# Arguments:
#   arr784  pixel intensities of one image as a flat vector (or a one-row
#           data frame). Only the first side^2 entries are used, so any
#           trailing entries (e.g. a label column) are ignored.
#   col     colour palette passed to image(); the default is a 12-step grey
#           ramp in which the lowest values are white and the highest are
#           nearly black.
#   ...     further arguments forwarded to image().
#   side    width and height of the square image in pixels (28 for MNIST).
#
# Called for its side effect of drawing on the current graphics device.
show_digit <- function(arr784, col = gray(12:1 / 12), ..., side = 28) {
  n_pixels <- side^2
  pixels <- as.numeric(arr784)
  if (length(pixels) < n_pixels) {
    # indexing past the end would silently pad with NA and draw a partial image
    stop(
      "expected at least ", n_pixels, " pixel values, got ", length(pixels),
      call. = FALSE
    )
  }
  pixels <- pixels[seq_len(n_pixels)]
  # the matrix is filled column by column; reversing the column order flips
  # the picture vertically so image() draws the digit upright
  image(matrix(pixels, nrow = side)[, side:1], col = col, ...)
}
# here is one example:
show_digit(mnist[1, ])

# Here are 25 examples, drawn on a 5 x 5 grid with no margins, axes or labels.
# par() returns the settings it replaces; they are put back after the loop so
# that later plots on this device are not stuck with the 5 x 5 layout.
old_par <- par(
  mfrow = c(5, 5), mar = c(0, 0, 0, 0),
  xaxt = "n",
  yaxt = "n",
  ann = FALSE
)
for (i in seq_len(25)) show_digit(mnist[i, ])
par(old_par)
# logical mask over the rows of mnist: TRUE where the label is the chosen digit
images_of_selected_digit <- mnist$Label %in% c("2") # select the two's
# images_of_selected_digit <- mnist$Label %in% c("8") # alternative: select the eight's

# keep only the pixel columns (1:784) of the selected rows, as a numeric matrix
x <- as.matrix(mnist[images_of_selected_digit, 1:784])
# this matrix has 6990 rows and 784 columns.
dim(x)
# note: even though we think of images as rectangle-shaped and we also think
# of matrices as rectangle-shaped, here we treat each image as a *vector*.
# By doing this, we discard some information: namely, which pixels are next to one another.
# compute PCs.
# I use rARPACK because it will be much faster, when we only need a couple PCs
# I do not scale the data because some pixels have zero variance.
# others have very very small variance.
# scaling would divide by zero and these small numbers.
# this would be a bad idea!
# Because we don't scale, only center,
# we study the "covariance" instead of "correlation"
x_centered <- scale(x, center = TRUE, scale = FALSE)
s <- rARPACK::svds(x_centered, k = 2)

# if you want the screeplot, you will want more k...
# s100 <- rARPACK::svds(scale(x, scale = FALSE), k = 100)
# plot(s100$d^2)

# this is the first pc (the first left singular vector, one value per image):
pc1 <- s$u[, 1]
dd <- density(pc1) # kernel density estimate of the Xhat's

# the first 784 columns are pixels. The last column (785) is pc1
dat <- cbind(x, pc1)

# 100 equally spaced positions along pc1, from its smallest to its largest value
vals <- seq(from = min(pc1), to = max(pc1), length.out = 100)
dif <- dd$bw # this is the bandwidth from the kernel density estimate
# Render the frames of the video into video/<frame>.png.
#
# The window centre sweeps up through vals (frames 1..length(vals)) and then
# back down again, so the image sequence can be played as a loop. Each frame
# shows, on the left, the average of all images whose pc1 lies within +/- dif
# of the centre and, on the right, the density of pc1 with the window marked.
#
# The sweep stops one frame short of 2 * length(vals): for that last frame the
# mirrored index would be 0, which selects no value from vals, so nothing
# could ever be drawn for it.

# png() fails if the output directory is missing
dir.create("video", showWarnings = FALSE)

n_vals <- length(vals)
for (frame in seq_len(2 * n_vals - 1)) {
  # position in vals: walk up first, then mirror back down
  i <- if (frame > n_vals) 2 * n_vals - frame else frame

  # identify the images in a region around vals[i]
  lower_pc <- vals[i] - dif
  upper_pc <- vals[i] + dif
  which_images <- (dat[, 785] > lower_pc) & (dat[, 785] < upper_pc)

  # if there are no images to plot, skip the frame without opening a device
  if (!any(which_images)) next

  png(
    file = paste0("video/", frame, ".png"),
    width = 1100 / 2, height = 850 / 2
  )
  par(mar = rep(0, 4), mfrow = c(1, 2))
  # drop = FALSE keeps a one-row matrix when a single image is selected, so
  # colMeans() covers both the one-image and the many-image case
  show_digit(colMeans(dat[which_images, 1:784, drop = FALSE]))
  plot(dd, main = "", ylab = "", yaxt = "n")
  # vertical grey lines at the edges of the current window
  abline(v = c(lower_pc, upper_pc), col = "grey")
  # close the device only when one was opened for this frame
  dev.off()
}
# To make the video, I go to quickTime -> File -> open image sequence.
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment