Last active
February 24, 2022 11:50
-
-
Save okiriza/fe874412f540a6f7eb0111c4f6649afe to your computer and use it in GitHub Desktop.
Script for visualizing autoencoder and PCA encoding on MNIST data
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import colorlover as cl | |
from plotly import graph_objs as go | |
from plotly import offline | |
from sklearn.decomposition import PCA | |
import torch | |
from torch.autograd import Variable | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from torchvision import datasets, transforms | |
# Enable inline Plotly figure rendering when run inside a Jupyter notebook.
offline.init_notebook_mode()
class AutoEncoder(nn.Module):
    """Convolutional autoencoder for MNIST-style single-channel images.

    The encoder applies two (conv k=5 -> max-pool 2 -> SELU) stages and two
    linear layers to compress an image batch into ``code_size``-dimensional
    codes; the decoder maps codes back to flat images via two linear layers
    and reshapes to ``(N, 1, width, height)``.
    """

    def __init__(self, code_size, image_width=28, image_height=28):
        """
        Args:
            code_size: dimensionality of the latent code.
            image_width: input image width in pixels (default matches MNIST).
            image_height: input image height in pixels (default matches MNIST).
        """
        super().__init__()
        self.code_size = code_size
        self.image_width = image_width
        self.image_height = image_height
        self.image_size = image_width * image_height

        self.enc_cnn_1 = nn.Conv2d(1, 10, kernel_size=5)
        self.enc_cnn_2 = nn.Conv2d(10, 20, kernel_size=5)

        # Spatial extent after one (conv k=5, stride 1) + (max-pool 2) stage;
        # applied twice this yields 4 for 28-pixel inputs, so the flatten size
        # below equals the original hard-coded 4 * 4 * 20.
        def _stage(d):
            return (d - 4) // 2

        flat_size = 20 * _stage(_stage(image_width)) * _stage(_stage(image_height))
        self.enc_linear_1 = nn.Linear(flat_size, 50)
        self.enc_linear_2 = nn.Linear(50, self.code_size)

        self.dec_linear_1 = nn.Linear(self.code_size, 160)
        self.dec_linear_2 = nn.Linear(160, self.image_size)

    def forward(self, images):
        """Return ``(reconstruction, code)`` for a batch of images."""
        code = self.encode(images)
        reconst = self.decode(code)
        return reconst, code

    def encode(self, images):
        """Map images ``(N, 1, H, W)`` to latent codes ``(N, code_size)``."""
        code = self.enc_cnn_1(images)
        code = F.selu(F.max_pool2d(code, 2))
        code = self.enc_cnn_2(code)
        code = F.selu(F.max_pool2d(code, 2))
        code = code.view([images.size(0), -1])
        code = F.selu(self.enc_linear_1(code))
        code = self.enc_linear_2(code)
        return code

    def decode(self, code):
        """Map latent codes ``(N, code_size)`` back to images ``(N, 1, W, H)``."""
        reconst = F.selu(self.dec_linear_1(code))
        # F.sigmoid is deprecated (and removed in modern PyTorch); use torch.sigmoid.
        reconst = torch.sigmoid(self.dec_linear_2(reconst))
        reconst = reconst.view([code.size(0), 1, self.image_width, self.image_height])
        return reconst
def viz_code(code, labels):
    """Build a 2-D scatter plot of encoded digits, one trace per class.

    Args:
        code: array-like of shape (N, 2) holding the 2-D embedding.
        labels: array of N digit labels (0-9) used to group/color points.

    Returns:
        A plotly Figure with ten scatter traces, one per digit class.
    """
    palette = cl.scales['10']['qual']['Paired']

    # One trace per digit so each class gets its own color and legend entry.
    traces = []
    for digit in range(10):
        mask = labels == digit
        traces.append(go.Scatter(
            x=code[mask, 0],
            y=code[mask, 1],
            mode='markers',
            name=str(digit),
            marker=dict(color=palette[digit]),
        ))

    layout = go.Layout(
        xaxis=dict(range=[-5, 4], tickfont=dict(size=16)),
        yaxis=dict(tickfont=dict(size=16)),
        legend=dict(font=dict(size=16)),
    )
    return go.Figure(data=traces, layout=layout)
# ---------------------------------------------------------------------------
# Compare a trained 2-D autoencoder code against a 2-D PCA projection of MNIST.
# ---------------------------------------------------------------------------
IMAGE_SIZE = 784               # 28 * 28 pixels, flattened
IMAGE_WIDTH = IMAGE_HEIGHT = 28

# Load previously trained autoencoder with code_size = 2
# See https://gist.github.com/okiriza/16ec1f29f5dd7b6d822a0a3f2af39274
autoencoder = AutoEncoder(2)
autoencoder.load_state_dict(torch.load('path/to/saved/state_dict.pk'))
autoencoder.eval()  # inference only

# Train PCA
## Read the entire training set as a single batch
train_data = datasets.MNIST('path/to/data/mnist/', train=True, transform=transforms.ToTensor())
train_loader = torch.utils.data.DataLoader(train_data, shuffle=False, batch_size=len(train_data), num_workers=4, drop_last=False)
train_images, _ = next(iter(train_loader))  # Ignore train labels
train_images = train_images.numpy()

## Fit PCA on training data; whiten=True scales components to unit variance
pca = PCA(n_components=2, whiten=True)
pca.fit(train_images.reshape([-1, IMAGE_SIZE]))

# Encode test images
## Read all test data at once
test_data = datasets.MNIST('path/to/data/mnist/', train=False, transform=transforms.ToTensor())
test_loader = torch.utils.data.DataLoader(test_data, shuffle=False, batch_size=len(test_data), num_workers=4, drop_last=False)
test_images, test_labels = next(iter(test_loader))
test_labels = test_labels.numpy()

## Encode with the autoencoder. `Variable` is deprecated; run under no_grad
## so no autograd graph is built, and convert to NumPy so the normalization
## and plotting below operate on plain arrays (matching the PCA output).
with torch.no_grad():
    code_ae = autoencoder.encode(test_images).numpy()
code_ae_normed = (code_ae - code_ae.mean(axis=0)) / code_ae.std(axis=0)  # For comparable axis to PCA's code
code_pca = pca.transform(test_images.numpy().reshape([-1, IMAGE_SIZE]))

# Visualize both 2-D embeddings as standalone HTML files
fig_ae = viz_code(code_ae_normed, test_labels)
fig_pca = viz_code(code_pca, test_labels)
offline.plot(fig_ae, filename='mnist_ae_2-dim.html')
offline.plot(fig_pca, filename='mnist_pca_2-dim.html')
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment