Last active
December 4, 2018 11:26
-
-
Save wottpal/ff65bd4a7eee744b7a11b33ebe8c8bd7 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#! -*- coding: utf-8 -*- | |
import sys | |
import time | |
import numpy as np | |
import gensim | |
import matplotlib.pyplot as plt | |
from matplotlib import font_manager, rc | |
from sklearn.manifold import TSNE | |
from sklearn.cluster import MiniBatchKMeans | |
from scipy.spatial import distance | |
import random | |
import math | |
from adjustText import adjust_text | |
VOCAB_SIZE = 10000  # number of words embedded by t-SNE; at most 3,000,000 (full model)
# Load the Google-News Model (https://drive.google.com/file/d/0B7XkCwpI5KDYNlNUTTlSS21pQmM)
model_path = "./GoogleNews-vectors-negative300.bin"
model = gensim.models.KeyedVectors.load_word2vec_format(
    model_path, binary=True)
# Embedding matrix with one row per vocabulary word (`model.wv.syn0` is the
# older, deprecated spelling of the same array).
wv = model.vectors
# Word lookup whose iteration order is assumed to match the rows of `wv`
# (the plotting code zips the two together). gensim 4 removed `vocab` in
# favour of `key_to_index`; older releases only have `vocab`.
vocabulary = model.key_to_index if hasattr(model, "key_to_index") else model.vocab
# Run TSNE on the first VOCAB_SIZE word vectors only.
tsne = TSNE(n_components=2, random_state=0)
np.set_printoptions(suppress=True)
Y = tsne.fit_transform(wv[:VOCAB_SIZE, :])
# Materialise as a list: a bare zip() is a one-shot iterator, so any code
# iterating it a second time would silently see no words at all.
word_positions = list(zip(vocabulary, Y[:, 0], Y[:, 1]))
def get_word_position(word):
    """Find a word's vocabulary index and its 2-D position in the t-SNE plot.

    Only the first ``len(Y)`` vocabulary entries were embedded by t-SNE
    (``VOCAB_SIZE`` of them when run as a script), so words outside that
    prefix have no position.

    Args:
        word: The word to look up.

    Returns:
        ``(word_index, x, y)`` if the word was embedded, otherwise
        ``(None, None, None)``.
    """
    if word not in vocabulary:
        return None, None, None
    # Scan only the embedded prefix instead of materialising the whole
    # vocabulary as a list; its iteration order is assumed to match the row
    # order of `Y` (the same assumption the plotting code makes).
    for word_index, candidate in zip(range(len(Y)), vocabulary):
        if candidate == word:
            word_x = Y[word_index, 0]
            word_y = Y[word_index, 1]
            print(f"'{word}' at ({word_x},{word_y})")
            return word_index, word_x, word_y
    # The word exists in the model but lies beyond the t-SNE subset.
    return None, None, None
def get_random_color(pastel_factor=0.9):
    """Return a random pastel colour as an ``[r, g, b]`` list of floats.

    Each channel is drawn uniformly from [0, 1] and then mapped through
    ``(channel + pastel_factor) / (1 + pastel_factor)``. For non-negative
    ``pastel_factor`` this pushes channels towards 1, so larger values give
    lighter, less saturated colours.
    """
    denominator = 1.0 + pastel_factor
    raw_channels = (random.uniform(0, 1.0) for _ in range(3))
    return [(channel + pastel_factor) / denominator for channel in raw_channels]
def save_word_plot(word, max_dist):
    """Plot and save the t-SNE neighbourhood of a word.

    Every embedded word whose 2-D t-SNE position lies strictly within
    ``max_dist`` (Euclidean, in plot coordinates) of ``word`` is drawn and
    labelled. The figure is written to ``results/plot_{word}_{max_dist}.pdf``.
    Nothing is plotted if ``word`` has no t-SNE position.

    Args:
        word: Word whose neighbourhood is plotted.
        max_dist: Distance threshold in t-SNE plot coordinates.
    """
    import os

    word_index, word_x, word_y = get_word_position(word)
    if word_index is None or word_x is None or word_y is None:
        return

    # Distances from the word to every embedded point, computed in one
    # vectorized pass instead of one scipy call per word.
    dists = np.hypot(Y[:, 0] - word_x, Y[:, 1] - word_y)
    neighbours = [
        (w, x, y)
        for w, x, y, d in zip(vocabulary, Y[:, 0], Y[:, 1], dists)
        if d < max_dist
    ]
    x_new = [x for _, x, _ in neighbours]
    y_new = [y for _, _, y in neighbours]

    # Create the figure only once there is something to plot, and always
    # close it so repeated calls don't accumulate open figures.
    fig = plt.figure()
    try:
        # `color=` rather than `c=`: a 3-element RGB list passed as `c` is
        # value-mapped instead when exactly three points are plotted.
        plt.scatter(x_new, y_new, color=get_random_color())
        texts = [plt.text(x, y, label) for label, x, y in neighbours]
        adjust_text(texts)
        plt.axis('off')
        os.makedirs('results', exist_ok=True)
        filename = f'results/plot_{word}_{max_dist}.pdf'
        plt.savefig(filename, bbox_inches='tight')
    finally:
        plt.close(fig)
    print(f"Saved '{filename}'")
def save_full_plot(show_labels=False, kmeans_clusters=8):
    """Plot and save the whole t-SNE map, coloured by MiniBatchKMeans cluster.

    The 2-D t-SNE coordinates are clustered into ``kmeans_clusters`` groups,
    and each group gets a random pastel colour. With a single cluster, every
    point is drawn gray. The figure is written to
    ``results/full_plot_[labels_]{VOCAB_SIZE}_{kmeans_clusters}.pdf``.

    Args:
        show_labels: If True, annotate every point with its word.
        kmeans_clusters: Number of k-means clusters (1 gives a gray plot).
    """
    import os

    bw = kmeans_clusters == 1
    labels = MiniBatchKMeans(n_clusters=kmeans_clusters).fit_predict(Y)
    colors = ["gray"] if bw else [get_random_color() for _ in range(kmeans_clusters)]

    fig = plt.figure()
    try:
        # One Line2D per cluster (markers only) instead of one per word:
        # same marker styling, far fewer artists to render into the PDF.
        for cluster, cluster_color in enumerate(colors):
            in_cluster = labels == cluster
            plt.plot(Y[in_cluster, 0], Y[in_cluster, 1], linestyle='None',
                     color=cluster_color, marker='o', markersize=1)
        if show_labels:
            # Zip afresh on every call rather than reusing a shared iterator,
            # which would be empty after the first plot.
            for label, x, y in zip(vocabulary, Y[:, 0], Y[:, 1]):
                plt.annotate(label, xy=(x, y), xytext=(0, 0), textcoords='offset points')
        plt.axis('off')
        os.makedirs('results', exist_ok=True)
        filename = f'results/full_plot_{"labels_" if show_labels else ""}{VOCAB_SIZE}_{kmeans_clusters}.pdf'
        plt.savefig(filename, bbox_inches='tight')
    finally:
        plt.close(fig)
    print(f"Saved '{filename}'")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment