-
-
Save Coderx7/f9b523202aee1f40b1584b4c6b0e66e0 to your computer and use it in GitHub Desktop.
t-SNE visualization code
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Dinesh Jayaraman | |
# Based on code by | |
# Authors: Fabian Pedregosa <[email protected]> | |
# Olivier Grisel <[email protected]> | |
# Mathieu Blondel <[email protected]> | |
# Gael Varoquaux | |
# License: BSD 3 clause (C) INRIA 2011 | |
print(__doc__) | |
from time import time | |
import numpy as np | |
import matplotlib.pyplot as plt | |
from matplotlib import offsetbox | |
from matplotlib.patches import Rectangle | |
from sklearn import (manifold, datasets, decomposition, ensemble, | |
discriminant_analysis, random_projection) | |
import scipy.misc | |
import matplotlib.patches as mpatches | |
# `scipy.misc.imresize` no longer exists in newer SciPy releases, and a plain
# attribute access would abort the script at start-up. Keep the alias
# available when the installed SciPy still provides it; otherwise bind None.
imresize = getattr(scipy.misc, 'imresize', None)
from collections import namedtuple | |
import h5py | |
import ipdb | |
# ----------------------------------------------------------------------
# Experiment configuration
num_samples = 10000  # upper bound on the number of samples kept
n_neighbors = 30     # only appears in the names of the saved figures

# Dataset name plus the HDF5 keys holding the inputs and the labels.
dsname = 'ModelNet10'
datafield, labelfield = 'view', 'labs'
# dsname = 'ModelNet30'; datafield = 'data'; labelfield = 'label';

# Input files are derived from the dataset name.
data_h5filename = '%s_1view_train_torchfeed.h5' % dsname
feat_h5filename = '3000112_-%s_trainfeatures.h5' % dsname

# Human-readable class names, looked up by dataset; datasets without an
# entry get an empty list.
classnames = {
    'ModelNet10': [
        'bathtub',
        'bed',
        'chair',
        'desk',
        'dresser',
        'monitor',
        'night_stand',
        'sofa',
        'table',
        'toilet',
    ],
}.get(dsname, [])
# ----------------------------------------------------------------------
# Load the inputs, labels and pre-computed features from the HDF5 files.
# ds = datasets.load_digits(n_class=6)

# `with` guarantees each file is closed even if a read below fails.
with h5py.File(data_h5filename, 'r') as h5file:
    # Drop singleton axes, then swap the last two axes of every image
    # (original author's note: the result is Nx32x32).
    images = np.squeeze(h5file[datafield][:]).transpose(0, 2, 1)
    target = h5file[labelfield][:].reshape(-1)  # flattened labels, N

with h5py.File(feat_h5filename, 'r') as h5file:
    data = h5file['feat_ip1'][:]  # original author's note: D=Nx256
    lab = h5file['cls_labs'][:].reshape(-1)  # flattened labels, N

# Both files must list the same labels in the same order. Raise explicitly:
# an `assert` would be skipped when Python runs with -O.
if not np.array_equal(lab, target):
    raise ValueError(
        'labels in %s do not match labels in %s'
        % (feat_h5filename, data_h5filename))

# Subsample the data and bundle it as a real namedtuple instance; the rest
# of the script only reads ds.data / ds.target / ds.images.
ds = namedtuple('dataset', ['data', 'target', 'images'])(
    data=data[:num_samples],
    target=target[:num_samples].astype(int),
    images=images[:num_samples],
)
num_cls = np.max(ds.target) + 1  # labels are treated as 0-based class ids
X = ds.data
y = ds.target
n_samples, n_features = X.shape
# ipdb.set_trace()
#---------------------------------------------------------------------- | |
# Scale and visualize the embedding vectors | |
def plot_embedding(xy, title=None, show_images=True, show_points=True):
    """Plot a 2-D embedding as class-coloured labels and/or thumbnails.

    The embedding is rescaled to the unit square before drawing. The
    function reads the module-level ``ds`` (``ds.target`` for class ids,
    ``ds.images`` for thumbnails), ``num_cls`` and ``classnames``.

    Args:
        xy: array with one row per sample and two columns (x, y).
        title: optional title for the axes.
        show_images: if true, draw image thumbnails at the sample
            positions, skipping samples too close to one already drawn.
        show_points: if true, draw each sample as its 1-based class number
            and add a colour legend.

    Returns:
        ``[fig, lgd]``: the created figure and the legend artist, or
        ``None`` for the legend when ``show_points`` is false.

    Raises:
        AssertionError: if both ``show_images`` and ``show_points`` are
            false.
    """
    assert show_images or show_points, \
        'at least one of show_images / show_points must be True'

    # Rescale every axis to [0, 1]. An axis with zero extent keeps a
    # divisor of 1 so the division cannot produce NaN/inf.
    xy = np.asarray(xy, dtype=float)
    xy_min, xy_max = np.min(xy, 0), np.max(xy, 0)
    span = xy_max - xy_min
    span = np.where(span == 0, 1.0, span)
    xy = (xy - xy_min) / span

    fig = plt.figure(figsize=(11, 6) if show_points else (12, 6))
    ax = plt.subplot(111)
    lgd = None  # stays None unless a legend is drawn below

    palette = plt.cm.tab10

    def class_color(cls):
        # Integer lookup into the palette, wrapping past its last entry, so
        # the point labels and the legend always use the same colour.
        return palette(int(cls) % palette.N)

    if show_points:
        for point, cls in zip(xy, ds.target):
            plt.text(point[0], point[1], str(cls + 1),
                     color=class_color(cls),
                     fontdict={'weight': 'bold', 'size': 9})

        patches = []
        for clsno in range(num_cls):
            label = '%d: ' % (clsno + 1)
            # Datasets without a name list are labelled by number only.
            if clsno < len(classnames):
                label += classnames[clsno]
            patches.append(
                mpatches.Patch(color=class_color(clsno), label=label))
        lgd = ax.legend(handles=patches, fontsize=16,
                        bbox_to_anchor=(1.3, 1))

    if show_images and hasattr(offsetbox, 'AnnotationBbox'):
        # only print thumbnails with matplotlib > 1.0
        # Seeded with the corner (1, 1), so samples near that corner never
        # get a thumbnail.
        shown_images = np.array([[1., 1.]])
        for i in range(xy.shape[0]):
            dist = np.sum((xy[i] - shown_images) ** 2, 1)
            if np.min(dist) < 2e-3:
                # don't show points that are too close
                continue
            shown_images = np.r_[shown_images, [xy[i]]]
            # Pad each thumbnail with a one-pixel constant border.
            thumbnail = offsetbox.OffsetImage(
                np.pad(ds.images[i], 1, 'constant'),
                cmap=plt.cm.gray)
            imagebox = offsetbox.AnnotationBbox(
                thumbnail, xy[i], frameon=False)
            ax.add_artist(imagebox)

    plt.xticks([])
    plt.yticks([])
    if title is not None:
        plt.title(title)
    return [fig, lgd]
# print(fig.get_size_inches()) | |
#---------------------------------------------------------------------- | |
# Plot images of the ds
# Tile the first n_img_per_row ** 2 images into one padded mosaic.
# Height and width are read separately so non-square images also fit.
img_h, img_w = ds.images.shape[1], ds.images.shape[2]
pad = 4
half_pad = pad // 2  # integer margin before each tile
n_img_per_row = min(20, int(np.floor(np.sqrt(n_samples))))
img = np.zeros(((img_h + pad) * n_img_per_row,
                (img_w + pad) * n_img_per_row))
for i in range(n_img_per_row):
    ix = (img_h + pad) * i + half_pad
    for j in range(n_img_per_row):
        iy = (img_w + pad) * j + half_pad
        img[ix:ix + img_h, iy:iy + img_w] = ds.images[i * n_img_per_row + j]
plt.imshow(img, cmap=plt.cm.gray)
plt.xticks([])
plt.yticks([])
plt.title('A selection from the dataset')
#---------------------------------------------------------------------- | |
# t-SNE embedding of the ds dataset
print("Computing t-SNE embedding")
tsne = manifold.TSNE(n_components=2, init='pca', random_state=0)
t0 = time()  # timer starts right before the fit
X_tsne = tsne.fit_transform(X)

print("plotting")
# Figure 1: class-number labels with a legend. Figure 2: thumbnails only.
fig1, lgd1 = plot_embedding(
    X_tsne, title=None, show_images=False, show_points=True)
fig2, _ = plot_embedding(
    X_tsne, title=None, show_images=True, show_points=False)

if n_samples > 0:
    print("saving figures")
    name_args = (dsname, num_samples, n_neighbors)

    # Make each figure current before saving it through pyplot.
    plt.figure(fig1.number)
    noims_path = (
        '/home/dineshj/Documents/Drafts/oneshot/figs/%s_tsne_noims-%dsamples-%dnbrs-wide.pdf'
        % name_args)
    # The legend sits outside the axes, so it is passed explicitly to be
    # included in the tight bounding box.
    plt.savefig(noims_path, bbox_extra_artists=(lgd1,), bbox_inches='tight')

    plt.figure(fig2.number)
    ims_path = (
        '/home/dineshj/Documents/Drafts/oneshot/figs/%s_tsne-%dsamples-%dnbrs-wide.pdf'
        % name_args)
    plt.savefig(ims_path, bbox_inches='tight')

# Time elapsed since the timer was started, in seconds.
elapsed = time() - t0
print("%.2f s" % elapsed)
plt.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment