Skip to content

Instantly share code, notes, and snippets.

@okiriza
Last active February 24, 2022 11:50
Show Gist options
Save okiriza/fe874412f540a6f7eb0111c4f6649afe to your computer and use it in GitHub Desktop.
Script for visualizing autoencoder and PCA encoding on MNIST data
import colorlover as cl
from plotly import graph_objs as go
from plotly import offline
from sklearn.decomposition import PCA
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
offline.init_notebook_mode()
class AutoEncoder(nn.Module):
    """Convolutional autoencoder for 28x28 single-channel (MNIST) images.

    The encoder is two conv+max-pool stages followed by two linear layers
    that compress each image to ``code_size`` numbers; the decoder is two
    linear layers that expand the code back into a flat image, reshaped to
    (N, 1, W, H).
    """

    def __init__(self, code_size, image_width=28, image_height=28):
        """
        Args:
            code_size: dimensionality of the latent code.
            image_width, image_height: output image size in pixels.
                Defaults (28x28) match MNIST; parameters replace the
                previous dependency on module-level globals. NOTE: the
                encoder's flattened size (4 * 4 * 20) is only valid for
                28x28 inputs, so non-default sizes affect the decoder only.
        """
        super().__init__()
        self.code_size = code_size
        self.image_width = image_width
        self.image_height = image_height

        # Encoder: conv -> pool -> conv -> pool -> linear -> linear.
        self.enc_cnn_1 = nn.Conv2d(1, 10, kernel_size=5)
        self.enc_cnn_2 = nn.Conv2d(10, 20, kernel_size=5)
        # 4 * 4 * 20 is the flattened feature size after the two
        # conv (k=5, no pad) + 2x2 max-pool stages for a 28x28 input.
        self.enc_linear_1 = nn.Linear(4 * 4 * 20, 50)
        self.enc_linear_2 = nn.Linear(50, self.code_size)

        # Decoder: two linear layers back up to a flat image.
        self.dec_linear_1 = nn.Linear(self.code_size, 160)
        self.dec_linear_2 = nn.Linear(160, image_width * image_height)

    def forward(self, images):
        """Encode then decode; returns (reconstruction, code)."""
        code = self.encode(images)
        reconst = self.decode(code)
        return reconst, code

    def encode(self, images):
        """Map a batch of images (N, 1, 28, 28) to codes (N, code_size)."""
        code = self.enc_cnn_1(images)
        code = F.selu(F.max_pool2d(code, 2))
        code = self.enc_cnn_2(code)
        code = F.selu(F.max_pool2d(code, 2))
        code = code.view([images.size(0), -1])
        code = F.selu(self.enc_linear_1(code))
        code = self.enc_linear_2(code)
        return code

    def decode(self, code):
        """Map codes (N, code_size) to images (N, 1, W, H) in [0, 1]."""
        reconst = F.selu(self.dec_linear_1(code))
        # torch.sigmoid replaces the deprecated F.sigmoid.
        reconst = torch.sigmoid(self.dec_linear_2(reconst))
        reconst = reconst.view(
            [code.size(0), 1, self.image_width, self.image_height])
        return reconst
def viz_code(code, labels):
    """Build a plotly scatter figure of a 2-D code, one trace per digit.

    Args:
        code: array-like of shape (n_samples, 2); the two columns are the
            x and y coordinates to plot.
        labels: array of digit labels (0-9) used to split points into one
            coloured, legend-labelled trace per class.

    Returns:
        A plotly ``go.Figure`` ready for ``offline.plot``.
    """
    palette = cl.scales['10']['qual']['Paired']
    layout = go.Layout(
        xaxis=dict(range=[-5, 4], tickfont=dict(size=16)),
        yaxis=dict(tickfont=dict(size=16)),
        legend=dict(font=dict(size=16)),
    )

    # One scatter trace per digit so each class gets its own colour
    # and legend entry.
    traces = []
    for digit in range(10):
        mask = labels == digit
        traces.append(go.Scatter(
            x=code[mask, 0],
            y=code[mask, 1],
            mode='markers',
            name=str(digit),
            marker=dict(color=palette[digit]),
        ))

    return go.Figure(data=traces, layout=layout)
IMAGE_SIZE = 784
IMAGE_WIDTH = IMAGE_HEIGHT = 28

# Load previously trained autoencoder with code_size = 2
# See https://gist.github.com/okiriza/16ec1f29f5dd7b6d822a0a3f2af39274
autoencoder = AutoEncoder(2)
autoencoder.load_state_dict(torch.load('path/to/saved/state_dict.pk'))
autoencoder.eval()  # inference mode (no-op for this model, but good hygiene)

# Train PCA
## Read the full training set in a single batch; labels are ignored.
train_data = datasets.MNIST('path/to/data/mnist/', train=True, transform=transforms.ToTensor())
train_loader = torch.utils.data.DataLoader(train_data, shuffle=False, batch_size=len(train_data), num_workers=4, drop_last=False)
train_images, _ = next(iter(train_loader))  # Ignore train labels
train_images = train_images.numpy()

## Fit a 2-component PCA on the flattened training images; whitening makes
## the component scales comparable to the standardized autoencoder code.
pca = PCA(n_components=2, whiten=True)
pca.fit(train_images.reshape([-1, IMAGE_SIZE]))

# Encode test images
## Read all test data at once
test_data = datasets.MNIST('path/to/data/mnist/', train=False, transform=transforms.ToTensor())
test_loader = torch.utils.data.DataLoader(test_data, shuffle=False, batch_size=len(test_data), num_workers=4, drop_last=False)
test_images, test_labels = next(iter(test_loader))
test_labels = test_labels.numpy()

## Encode without building an autograd graph (torch.no_grad replaces the
## deprecated Variable wrapper) and convert to numpy so boolean-mask
## indexing and plotly plotting in viz_code operate on plain arrays.
with torch.no_grad():
    code_ae = autoencoder.encode(test_images).numpy()
# Standardize per dimension so the axes are comparable to PCA's whitened code.
code_ae_normed = (code_ae - code_ae.mean(axis=0)) / code_ae.std(axis=0)
code_pca = pca.transform(test_images.numpy().reshape([-1, IMAGE_SIZE]))

# Visualize: write one standalone HTML file per method.
fig_ae = viz_code(code_ae_normed, test_labels)
fig_pca = viz_code(code_pca, test_labels)
offline.plot(fig_ae, filename='mnist_ae_2-dim.html')
offline.plot(fig_pca, filename='mnist_pca_2-dim.html')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment