t-SNE wrapper to output SVG maps
*.pyc
mnist2500*
build/
pip-log.txt
text-documents/
"""Experimental script to semantically map text/PDF document using t-SNE | |
See: http://homepage.tudelft.nl/19j49/t-SNE.html | |
Related: http://github.com/turian/textSNE | |
You will need: | |
- lxml | |
- numpy | |
- nltk | |
- http://pypi.python.org/pypi/svg.charts/ | |
(fix the setup.py to point to the correct readme file). | |
This file is release under the MIT license but the t-SNE code is for research | |
only usage (this the authors page for details). | |
""" | |
import tsne
from nltk.stem.porter import PorterStemmer
from nltk import word_tokenize
from nltk import sent_tokenize
from svg.charts.plot import Plot
import cssutils
import pkg_resources
import numpy as np
import os
from itertools import izip
from lxml import etree
def extract_text(pdf_folder, txt_folder):
    """Convert every PDF file in pdf_folder to a text file in txt_folder."""
    if not os.path.exists(txt_folder):
        os.makedirs(txt_folder)
    for pdf_filename in os.listdir(pdf_folder):
        basename, ext = os.path.splitext(pdf_filename)
        if ext.lower() != ".pdf":
            print "skipping", pdf_filename
            continue
        pdf_filepath = os.path.join(pdf_folder, pdf_filename)
        text_filepath = os.path.join(txt_folder, basename + ".txt")
        cmd = "pdftotext %s %s > /dev/null 2>&1" % (pdf_filepath, text_filepath)
        print cmd
        os.system(cmd)
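# Example call (hypothetical paths, mirroring the folder used in __main__ below):
#     extract_text("/home/ogrisel/Desktop/arxiv-pdf", "/home/ogrisel/Desktop/arxiv-txt")
# Because filenames are interpolated into a shell command above, a variant based
# on the subprocess module would sidestep quoting issues with spaces in paths:
#     import subprocess
#     subprocess.call(["pdftotext", pdf_filepath, text_filepath])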
class HashingVectorizer(object):
    """Compute term frequency vectors using a hashed term space"""

    def __init__(self, dim=5000, stemmer=None, probes=3):
        self.dim = dim
        self.probes = probes
        self.stemmer = stemmer if stemmer is not None else PorterStemmer()

    def hash(self, term, probe=0):
        # hash the stemmed, lowercased term; the probe index perturbs the hash
        h = hash(self.stemmer.stem(term.lower()))
        return abs(h + hash(probe * " ")) % self.dim

    def term_frequencies(self, files):
        """Tokenize documents, hash the terms and compute the term frequencies"""
        if isinstance(files, basestring):
            # a single string is interpreted as a folder of text files
            folder = files
            files = os.listdir(folder)
            files.sort()
            files = [os.path.join(folder, f) for f in files]
        freqs = np.zeros((len(files), self.dim))
        for i, filepath in enumerate(files):
            print "analysing file %d/%d: %s" % (i + 1, len(files), filepath)
            sentences = sent_tokenize(file(filepath).read())
            for sentence in sentences:
                for term in word_tokenize(sentence):
                    # TODO: add support for co-occurrence tokens in a sentence
                    # window
                    for probe in xrange(self.probes):
                        freqs[i][self.hash(term, probe)] += 1.0
            freqs[i] /= freqs[i].sum()
        return freqs
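# Illustrative sketch: each term is hashed self.probes times with a different
# probe suffix, so a collision on one hashed index only corrupts a fraction of
# that term's weight. For example (index values depend on Python's hash()):
#     hv = HashingVectorizer(dim=5000, probes=3)
#     indices = [hv.hash("documents", probe) for probe in xrange(hv.probes)]
#     freqs = hv.term_frequencies("/home/ogrisel/Desktop/arxiv-txt")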
class SemanticMap(Plot):
    """2D vector map of documents projected using the t-SNE algorithm"""

    draw_lines_between_points = False
    show_x_guidelines = False
    show_y_guidelines = False
    show_x_title = False
    show_y_title = False
    show_x_labels = False
    show_y_labels = False
    show_label_popup = True
    width = 1280
    height = 1280

    def __init__(self, positions, text_labels=None, categories=None, urls=None):
        self.x = positions[:, 0]
        self.y = positions[:, 1]
        self.positions = positions
        dx = self.x.max() - self.x.min()
        dy = self.y.max() - self.y.min()
        super(SemanticMap, self).__init__({
            'min_x_value': self.x.min() - 0.03 * dx,
            'min_y_value': self.y.min() - 0.03 * dy,
            'max_x_value': self.x.max(),
            'max_y_value': self.y.max(),
        })
        # split data by category
        if categories is None:
            categories = ["Default category"] * len(positions)
        if text_labels is None:
            text_labels = [None] * len(positions)
        if urls is None:
            urls = [None] * len(positions)
        categorized = {}
        self.point_labels = {}
        self.urls = {}
        for pos, category, label, url in izip(positions, categories,
                                              text_labels, urls):
            categorized.setdefault(category, []).append(pos[0])
            categorized.setdefault(category, []).append(pos[1])
            self.point_labels[tuple(pos)] = label
            self.urls[tuple(pos)] = url
        # add the data to the graph
        unique_categories = categorized.keys()
        unique_categories.sort()
        for category in unique_categories:
            self.add_data({'data': categorized[category], 'title': category})
    def draw_data_points(self, line, data_points, graph_points):
        if not self.show_data_points \
           and not self.show_data_values:
            return
        for ((dx, dy), (gx, gy)) in izip(data_points, graph_points):
            if self.show_data_points:
                etree.SubElement(self.graph, 'circle', {
                    'cx': str(gx),
                    'cy': str(gy),
                    'r': '3',
                    'class': 'dataPoint%(line)s' % vars()})
            if self.show_label_popup and self.point_labels[(dx, dy)]:
                self.add_popup(gx, gy, self.point_labels[(dx, dy)],
                               self.urls[(dx, dy)])
    def add_popup(self, x, y, label, url):
        "Adds pop-up point information to a graph."
        txt_width = len(label) * self.font_size * 0.6 + 10
        tx = x + [10, -10][int(x + txt_width > self.width)]
        anchor = ['start', 'end'][x + txt_width > self.width]
        style = 'fill: #000; text-anchor: %s;' % anchor
        id = 'label-%s-%s' % (x, y)
        t = etree.SubElement(self.foreground, 'text', {
            'x': str(tx),
            'y': str(y - self.font_size),
            'class': 'dataPointLabel',
            'visibility': 'hidden',
            'style': style,
            'id': id
        })
        t.text = label
        # Note, prior to the etree conversion, this circle element was never
        # added to anything (now it's added to the foreground)
        visibility = "document.getElementById('%s').setAttribute('visibility', '%%s')" % id
        if url is not None:
            parent = etree.SubElement(self.foreground, 'a', {
                '{http://www.w3.org/1999/xlink}href': url,
            })
        else:
            parent = self.foreground
        t = etree.SubElement(parent, 'circle', {
            'cx': str(x),
            'cy': str(y),
            'r': '5',
            'style': 'opacity: 0;',
            'onmouseover': visibility % 'visible',
            'onmouseout': visibility % 'hidden',
        })
    def draw_graph(self):
        """Simple background without axes"""
        transform = 'translate (%s %s)' % (self.border_left, self.border_top)
        self.graph = etree.SubElement(self.root, 'g', transform=transform)
        etree.SubElement(self.graph, 'rect', {
            'x': '0',
            'y': '0',
            'width': str(self.graph_width),
            'height': str(self.graph_height),
            'class': 'graphBackground'
        })

    @staticmethod
    def load_resource_stylesheet(name, subs=dict()):
        css_stream = pkg_resources.resource_stream('doc_tsne', name)
        css_string = css_stream.read()
        css_string = css_string % subs
        sheet = cssutils.parseString(css_string)
        return sheet
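# graph.css below is a %-template; assuming the CSS files are shipped as package
# data of a "doc_tsne" package, a possible call with illustrative font sizes is:
#     sheet = SemanticMap.load_resource_stylesheet('graph.css', subs={
#         'title_font_size': 16, 'subtitle_font_size': 14,
#         'x_label_font_size': 12, 'y_label_font_size': 12,
#         'x_title_font_size': 14, 'y_title_font_size': 14,
#         'key_font_size': 10})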
if __name__ == "__main__":
    folder = "/home/ogrisel/Desktop/arxiv-txt"
    files = os.listdir(folder)
    files.sort()
    filepaths = [os.path.join(folder, f) for f in files]
    labels = [f.rsplit("-", 1)[0] for f in files]
    int_labels = [abs(hash(l)) for l in labels]
    #freqs = HashingVectorizer().term_frequencies(filepaths)
    #projected = tsne.tsne(freqs, perplexity=20.0)
    #m = SemanticMap(projected, categories=labels)
    #file(r'/tmp/out.svg', 'w').write(m.burn())
    # (the scatter plot below additionally needs: import matplotlib.pyplot as plt)
    #plt.scatter(projected[:,0], projected[:,1], 20.0, int_labels)
/* | |
$Id: graph.css 81 2009-09-01 02:04:44Z jaraco $ | |
Base styles for svg.charts.Graph | |
*/ | |
.svgBackground{ | |
fill:#ffffff; | |
} | |
.graphBackground{ | |
fill:#f0f0f0; | |
} | |
/* graphs titles */ | |
.mainTitle{ | |
text-anchor: middle; | |
fill: #000000; | |
font-size: %(title_font_size)dpx; | |
font-family: "Arial", sans-serif; | |
font-weight: normal; | |
} | |
.subTitle{ | |
text-anchor: middle; | |
fill: #999999; | |
font-size: %(subtitle_font_size)dpx; | |
font-family: "Arial", sans-serif; | |
font-weight: normal; | |
} | |
.axis{ | |
stroke: #000000; | |
stroke-width: 1px; | |
} | |
.guideLines{ | |
stroke: #666666; | |
stroke-width: 1px; | |
stroke-dasharray: 5,5; | |
} | |
.xAxisLabels{ | |
text-anchor: middle; | |
fill: #000000; | |
font-size: %(x_label_font_size)dpx; | |
font-family: "Arial", sans-serif; | |
font-weight: normal; | |
} | |
.yAxisLabels{ | |
text-anchor: end; | |
fill: #000000; | |
font-size: %(y_label_font_size)dpx; | |
font-family: "Arial", sans-serif; | |
font-weight: normal; | |
} | |
.xAxisTitle{ | |
text-anchor: middle; | |
fill: #ff0000; | |
font-size: %(x_title_font_size)dpx; | |
font-family: "Arial", sans-serif; | |
font-weight: normal; | |
} | |
.yAxisTitle{ | |
fill: #ff0000; | |
text-anchor: middle; | |
font-size: %(y_title_font_size)dpx; | |
font-family: "Arial", sans-serif; | |
font-weight: normal; | |
} | |
.dataPointLabel{ | |
fill: #000000; | |
text-anchor:middle; | |
font-size: 10px; | |
font-family: "Arial", sans-serif; | |
font-weight: normal; | |
} | |
.staggerGuideLine{ | |
fill: none; | |
stroke: #000000; | |
stroke-width: 0.5px; | |
} | |
.keyText{ | |
fill: #000000; | |
text-anchor:start; | |
font-size: %(key_font_size)dpx; | |
font-family: "Arial", sans-serif; | |
font-weight: normal; | |
} |
/* | |
$Id: plot.css 81 2009-09-01 02:04:44Z jaraco $ | |
default line styles | |
*/ | |
.line1{ | |
fill: none; | |
stroke: #ff0000; | |
stroke-width: 1px; | |
} | |
.line2{ | |
fill: none; | |
stroke: #0000ff; | |
stroke-width: 1px; | |
} | |
.line3{ | |
fill: none; | |
stroke: #00ff00; | |
stroke-width: 1px; | |
} | |
.line4{ | |
fill: none; | |
stroke: #ffcc00; | |
stroke-width: 1px; | |
} | |
.line5{ | |
fill: none; | |
stroke: #00ccff; | |
stroke-width: 1px; | |
} | |
.line6{ | |
fill: none; | |
stroke: #ff00ff; | |
stroke-width: 1px; | |
} | |
.line7{ | |
fill: none; | |
stroke: #00ffff; | |
stroke-width: 1px; | |
} | |
.line8{ | |
fill: none; | |
stroke: #ffff00; | |
stroke-width: 1px; | |
} | |
.line9{ | |
fill: none; | |
stroke: #cc6666; | |
stroke-width: 1px; | |
} | |
.line10{ | |
fill: none; | |
stroke: #663399; | |
stroke-width: 1px; | |
} | |
.line11{ | |
fill: none; | |
stroke: #339900; | |
stroke-width: 1px; | |
} | |
.line12{ | |
fill: none; | |
stroke: #9966FF; | |
stroke-width: 1px; | |
} | |
/* default fill styles */ | |
.fill1{ | |
fill: #cc0000; | |
fill-opacity: 0.2; | |
stroke: none; | |
} | |
.fill2{ | |
fill: #0000cc; | |
fill-opacity: 0.2; | |
stroke: none; | |
} | |
.fill3{ | |
fill: #00cc00; | |
fill-opacity: 0.2; | |
stroke: none; | |
} | |
.fill4{ | |
fill: #ffcc00; | |
fill-opacity: 0.2; | |
stroke: none; | |
} | |
.fill5{ | |
fill: #00ccff; | |
fill-opacity: 0.2; | |
stroke: none; | |
} | |
.fill6{ | |
fill: #ff00ff; | |
fill-opacity: 0.2; | |
stroke: none; | |
} | |
.fill7{ | |
fill: #00ffff; | |
fill-opacity: 0.2; | |
stroke: none; | |
} | |
.fill8{ | |
fill: #ffff00; | |
fill-opacity: 0.2; | |
stroke: none; | |
} | |
.fill9{ | |
fill: #cc6666; | |
fill-opacity: 0.2; | |
stroke: none; | |
} | |
.fill10{ | |
fill: #663399; | |
fill-opacity: 0.2; | |
stroke: none; | |
} | |
.fill11{ | |
fill: #339900; | |
fill-opacity: 0.2; | |
stroke: none; | |
} | |
.fill12{ | |
fill: #9966FF; | |
fill-opacity: 0.2; | |
stroke: none; | |
} | |
/* default line styles */ | |
.key1,.dataPoint1{ | |
fill: #ff0000; | |
stroke: none; | |
stroke-width: 1px; | |
} | |
.key2,.dataPoint2{ | |
fill: #0000ff; | |
stroke: none; | |
stroke-width: 1px; | |
} | |
.key3,.dataPoint3{ | |
fill: #00ff00; | |
stroke: none; | |
stroke-width: 1px; | |
} | |
.key4,.dataPoint4{ | |
fill: #ffcc00; | |
stroke: none; | |
stroke-width: 1px; | |
} | |
.key5,.dataPoint5{ | |
fill: #00ccff; | |
stroke: none; | |
stroke-width: 1px; | |
} | |
.key6,.dataPoint6{ | |
fill: #ff00ff; | |
stroke: none; | |
stroke-width: 1px; | |
} | |
.key7,.dataPoint7{ | |
fill: #00ffff; | |
stroke: none; | |
stroke-width: 1px; | |
} | |
.key8,.dataPoint8{ | |
fill: #ffff00; | |
stroke: none; | |
stroke-width: 1px; | |
} | |
.key9,.dataPoint9{ | |
fill: #cc6666; | |
stroke: none; | |
stroke-width: 1px; | |
} | |
.key10,.dataPoint10{ | |
fill: #663399; | |
stroke: none; | |
stroke-width: 1px; | |
} | |
.key11,.dataPoint11{ | |
fill: #ff0000; | |
stroke: none; | |
stroke-width: 1px; | |
} | |
.key12,.dataPoint12{ | |
fill: #0000ff; | |
stroke: none; | |
stroke-width: 1px; | |
} | |
.key13,.dataPoint13{ | |
fill: #00ff00; | |
stroke: none; | |
stroke-width: 1px; | |
} | |
.key14,.dataPoint14{ | |
fill: #ffcc00; | |
stroke: none; | |
stroke-width: 1px; | |
} | |
.key15,.dataPoint15{ | |
fill: #00ccff; | |
stroke: none; | |
stroke-width: 1px; | |
} | |
.key16,.dataPoint16{ | |
fill: #ff00ff; | |
stroke: none; | |
stroke-width: 1px; | |
} | |
.key17,.dataPoint17{ | |
fill: #00ffff; | |
stroke: none; | |
stroke-width: 1px; | |
} | |
.key18,.dataPoint18{ | |
fill: #ffff00; | |
stroke: none; | |
stroke-width: 1px; | |
} | |
.key19,.dataPoint19{ | |
fill: #cc6666; | |
stroke: none; | |
stroke-width: 1px; | |
} | |
.key20,.dataPoint20{ | |
fill: #663399; | |
stroke: none; | |
stroke-width: 1px; | |
} | |
.key21,.dataPoint21{ | |
fill: #ff0000; | |
stroke: none; | |
stroke-width: 1px; | |
} | |
.key22,.dataPoint22{ | |
fill: #0000ff; | |
stroke: none; | |
stroke-width: 1px; | |
} | |
.key23,.dataPoint23{ | |
fill: #00ff00; | |
stroke: none; | |
stroke-width: 1px; | |
} | |
.key24,.dataPoint24{ | |
fill: #ffcc00; | |
stroke: none; | |
stroke-width: 1px; | |
} | |
.key25,.dataPoint25{ | |
fill: #00ccff; | |
stroke: none; | |
stroke-width: 1px; | |
} | |
.key26,.dataPoint26{ | |
fill: #ff00ff; | |
stroke: none; | |
stroke-width: 1px; | |
} | |
.key27,.dataPoint27{ | |
fill: #00ffff; | |
stroke: none; | |
stroke-width: 1px; | |
} | |
.key28,.dataPoint28{ | |
fill: #ffff00; | |
stroke: none; | |
stroke-width: 1px; | |
} | |
.key29,.dataPoint29{ | |
fill: #cc6666; | |
stroke: none; | |
stroke-width: 1px; | |
} | |
.key30,.dataPoint30{ | |
fill: #663399; | |
stroke: none; | |
stroke-width: 1px; | |
} | |
.key31,.dataPoint31{ | |
fill: #ff0000; | |
stroke: none; | |
stroke-width: 1px; | |
} | |
.key32,.dataPoint32{ | |
fill: #0000ff; | |
stroke: none; | |
stroke-width: 1px; | |
} | |
.key33,.dataPoint33{ | |
fill: #00ff00; | |
stroke: none; | |
stroke-width: 1px; | |
} | |
.key34,.dataPoint34{ | |
fill: #ffcc00; | |
stroke: none; | |
stroke-width: 1px; | |
} | |
.key35,.dataPoint35{ | |
fill: #00ccff; | |
stroke: none; | |
stroke-width: 1px; | |
} | |
.key36,.dataPoint36{ | |
fill: #ff00ff; | |
stroke: none; | |
stroke-width: 1px; | |
} | |
.key37,.dataPoint37{ | |
fill: #00ffff; | |
stroke: none; | |
stroke-width: 1px; | |
} | |
.key38,.dataPoint38{ | |
fill: #ffff00; | |
stroke: none; | |
stroke-width: 1px; | |
} | |
.key39,.dataPoint39{ | |
fill: #cc6666; | |
stroke: none; | |
stroke-width: 1px; | |
} | |
.key40,.dataPoint40{ | |
fill: #663399; | |
stroke: none; | |
stroke-width: 1px; | |
} | |
.key41,.dataPoint41{ | |
fill: #ff0000; | |
stroke: none; | |
stroke-width: 1px; | |
} | |
.key42,.dataPoint42{ | |
fill: #0000ff; | |
stroke: none; | |
stroke-width: 1px; | |
} | |
.key43,.dataPoint43{ | |
fill: #00ff00; | |
stroke: none; | |
stroke-width: 1px; | |
} | |
.key44,.dataPoint44{ | |
fill: #ffcc00; | |
stroke: none; | |
stroke-width: 1px; | |
} | |
.key45,.dataPoint45{ | |
fill: #00ccff; | |
stroke: none; | |
stroke-width: 1px; | |
} | |
.key46,.dataPoint46{ | |
fill: #ff00ff; | |
stroke: none; | |
stroke-width: 1px; | |
} | |
.key47,.dataPoint47{ | |
fill: #00ffff; | |
stroke: none; | |
stroke-width: 1px; | |
} | |
.key48,.dataPoint48{ | |
fill: #ffff00; | |
stroke: none; | |
stroke-width: 1px; | |
} | |
.key49,.dataPoint49{ | |
fill: #cc6666; | |
stroke: none; | |
stroke-width: 1px; | |
} | |
.constantLine{ | |
color: navy; | |
stroke: navy; | |
stroke-width: 1px; | |
stroke-dasharray: 9,1,1; | |
} |
#
#  tsne.py
#
#  Implementation of t-SNE in Python. The implementation was tested on Python 2.5.1, and it requires a working
#  installation of NumPy. The implementation comes with an example on the MNIST dataset. In order to plot the
#  results of this example, a working installation of matplotlib is required.
#  The example can be run by executing: ipython tsne.py -pylab
#
#
#  Created by Laurens van der Maaten on 20-12-08.
#  Copyright (c) 2008 Tilburg University. All rights reserved.
#
#  Modified by Joseph Turian:
#   * Use psyco if available.
#   * Added parameter use_pca, with default False. NB this changes the default behavior.
#  TODO:
#   * Make tsne.pca == calc_tsne.PCA
#  Modified by Olivier Grisel:
#   * Make it possible to ctrl-C to early stop
#   * Cosmetic fixes
#

import numpy as Math
import pylab as Plot
import sys

try:
    import psyco
    psyco.full()
    print >> sys.stderr, "psyco is usable!"
except ImportError:
    print >> sys.stderr, "No psyco"
def Hbeta(D=Math.array([]), beta=1.0):
    """Compute the perplexity and the P-row for a specific value of the precision of a Gaussian distribution."""

    # Compute P-row and corresponding perplexity
    P = Math.exp(-D.copy() * beta)
    sumP = sum(P)
    H = Math.log(sumP) + beta * Math.sum(D * P) / sumP
    P = P / sumP
    return H, P
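# For reference, in standard t-SNE notation (D_ij the squared distance between
# points i and j, beta_i the precision of the Gaussian centred on point i),
# Hbeta computes
#     p_{j|i} = exp(-beta_i * D_ij) / sum_k exp(-beta_i * D_ik)
#     H(P_i)  = log(sum_k exp(-beta_i * D_ik)) + beta_i * sum_j D_ij * p_{j|i}
# i.e. the Shannon entropy of P_i in nats, so the perplexity that x2p matches
# below is exp(H(P_i)).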
def x2p(X=Math.array([]), tol=1e-5, perplexity=30.0):
    """Performs a binary search to get P-values in such a way that each conditional Gaussian has the same perplexity."""

    # Initialize some variables
    print "Computing pairwise distances..."
    (n, d) = X.shape
    sum_X = Math.sum(Math.square(X), 1)
    D = Math.add(Math.add(-2 * Math.dot(X, X.T), sum_X).T, sum_X)
    P = Math.zeros((n, n))
    beta = Math.ones((n, 1))
    logU = Math.log(perplexity)

    # Loop over all datapoints
    for i in range(n):

        # Print progress
        if i % 500 == 0:
            print "Computing P-values for point ", i, " of ", n, "..."

        # Compute the Gaussian kernel and entropy for the current precision
        betamin = -Math.inf
        betamax = Math.inf
        Di = D[i, Math.concatenate((Math.r_[0:i], Math.r_[i+1:n]))]
        (H, thisP) = Hbeta(Di, beta[i])

        # Evaluate whether the perplexity is within tolerance
        Hdiff = H - logU
        tries = 0
        while Math.abs(Hdiff) > tol and tries < 50:

            # If not, increase or decrease precision
            if Hdiff > 0:
                betamin = beta[i]
                if betamax == Math.inf or betamax == -Math.inf:
                    beta[i] = beta[i] * 2
                else:
                    beta[i] = (beta[i] + betamax) / 2
            else:
                betamax = beta[i]
                if betamin == Math.inf or betamin == -Math.inf:
                    beta[i] = beta[i] / 2
                else:
                    beta[i] = (beta[i] + betamin) / 2

            # Recompute the values
            (H, thisP) = Hbeta(Di, beta[i])
            Hdiff = H - logU
            tries = tries + 1

        # Set the final row of P
        P[i, Math.concatenate((Math.r_[0:i], Math.r_[i+1:n]))] = thisP

    # Return final P-matrix
    print "Mean value of sigma: ", Math.mean(Math.sqrt(1 / beta))
    return P
def pca(X=Math.array([]), no_dims=50):
    """Runs PCA on the NxD array X in order to reduce its dimensionality to no_dims dimensions."""
    print "Preprocessing the data using PCA..."
    (n, d) = X.shape
    X = X - Math.tile(Math.mean(X, 0), (n, 1))
    (l, M) = Math.linalg.eig(Math.dot(X.T, X))
    Y = Math.dot(X, M[:, 0:no_dims])
    return Y
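# Note: numpy.linalg.eig does not sort its eigenvalues, so M[:, 0:no_dims] above
# is not guaranteed to hold the leading principal components. A minimal sorted
# variant (illustrative sketch, not used elsewhere in this gist) would be:
#     idx = Math.argsort(-l.real)
#     Y = Math.dot(X, M[:, idx[:no_dims]].real)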
def tsne(X=Math.array([]), no_dims=2, initial_dims=50, perplexity=30.0, use_pca=False):
    """Runs t-SNE on the dataset in the NxD array X to reduce its dimensionality to no_dims dimensions.

    The syntax of the function is Y = tsne.tsne(X, no_dims, perplexity), where X is an NxD NumPy array."""

    # Check inputs
    if X.dtype != "float64":
        print "Error: array X should have type float64."
        return -1
    #if no_dims.__class__ != "<type 'int'>":    # doesn't work yet!
    #    print "Error: number of dimensions should be an integer."
    #    return -1

    # Initialize variables
    if use_pca:
        X = pca(X, initial_dims)
    (n, d) = X.shape
    max_iter = 5000
    initial_momentum = 0.5
    final_momentum = 0.8
    eta = 500
    min_gain = 0.01
    Y = Math.random.randn(n, no_dims)
    dY = Math.zeros((n, no_dims))
    iY = Math.zeros((n, no_dims))
    gains = Math.ones((n, no_dims))

    # Compute P-values
    P = x2p(X, 1e-5, perplexity)
    P = P + Math.transpose(P)
    P = P / Math.sum(P)
    P = P * 4    # early exaggeration
    P = Math.maximum(P, 1e-12)
    try:
        # Run iterations
        for iter in range(max_iter):

            # Compute pairwise affinities
            sum_Y = Math.sum(Math.square(Y), 1)
            num = 1 / (1 + Math.add(Math.add(-2 * Math.dot(Y, Y.T), sum_Y).T, sum_Y))
            num[range(n), range(n)] = 0
            Q = num / Math.sum(num)
            Q = Math.maximum(Q, 1e-12)

            # Compute gradient
            PQ = P - Q
            for i in range(n):
                dY[i, :] = Math.sum(Math.tile(PQ[:, i] * num[:, i], (no_dims, 1)).T * (Y[i, :] - Y), 0)

            # Perform the update
            if iter < 20:
                momentum = initial_momentum
            else:
                momentum = final_momentum
            gains = (gains + 0.2) * ((dY > 0) != (iY > 0)) + (gains * 0.8) * ((dY > 0) == (iY > 0))
            gains[gains < min_gain] = min_gain
            iY = momentum * iY - eta * (gains * dY)
            Y = Y + iY
            Y = Y - Math.tile(Math.mean(Y, 0), (n, 1))

            # Compute current value of cost function
            if (iter + 1) % 100 == 0:
                C = Math.sum(P * Math.log(P / Q))
                print "Iteration ", (iter + 1), ": error is ", C

            # Stop lying about P-values
            if iter == 100:
                P = P / 4
    except KeyboardInterrupt:
        print >> sys.stderr, "early stopping by user"

    # Return solution
    return Y
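# For reference, the inner loop over i above computes (up to a constant factor
# of 4, effectively absorbed into the learning rate eta) the gradient of the
# Kullback-Leibler cost C = sum_ij p_ij * log(p_ij / q_ij):
#     dC/dy_i = 4 * sum_j (p_ij - q_ij) * (y_i - y_j) / (1 + ||y_i - y_j||^2)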
if __name__ == "__main__":
    print "Run Y = tsne.tsne(X, no_dims, perplexity) to perform t-SNE on your dataset."
    print "Running example on 2,500 MNIST digits..."
    X = Math.loadtxt("mnist2500_X.txt")
    labels = Math.loadtxt("mnist2500_labels.txt")
    Y = tsne(X, 2, 50, 20.0)
    Plot.scatter(Y[:, 0], Y[:, 1], 20, labels)