Created April 28, 2018 16:34
-
-
Save pmbaumgartner/adb33aa486b77ab58eb3df265393195d to your computer and use it in GitHub Desktop.
Load the Google News Word2Vec model, reduce its dimensionality with UMAP, and plot the result with Plotly
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import itertools
import os

import gensim.downloader as gensim_api
import pandas as pd
import plotly
import plotly.graph_objs as go
import requests
import umap
from numpy import log10
# Fetch the pretrained vectors named 'word2vec-google-news-300' through gensim's
# downloader API; every later step indexes into this model.
w2v_model = gensim_api.load(name='word2vec-google-news-300')
def read_1w_corpus(r, sep="\t"):
    """Yield the separator-split fields of each line of a text file.

    Parameters
    ----------
    r : str or path-like
        Path of the file to read. In this script it is a word-count file
        (presumably one ``word<sep>count`` record per line -- confirm against
        the data file).
    sep : str, default tab
        Field separator passed to ``str.split``.

    Yields
    ------
    list of str
        The fields of one line, with the trailing newline removed.
    """
    # ``with`` closes the handle as soon as the generator finishes or is
    # closed, instead of leaking it until garbage collection. An explicit
    # encoding makes the result independent of the platform locale.
    with open(r, encoding="utf-8") as handle:
        for line in handle:
            # Strip the newline so the last field (the count) carries no "\n".
            yield line.rstrip("\n").split(sep)
# Keep the most frequent words that also have a word2vec vector, project their
# vectors to 2-D with UMAP, and gather coordinates, words and counts in a
# DataFrame for plotting.

# gensim 4.0 removed ``KeyedVectors.vocab`` in favour of ``key_to_index``;
# accept either so the script runs against old and new gensim releases.
_vocab_source = getattr(w2v_model, 'key_to_index', None)
if _vocab_source is None:
    _vocab_source = w2v_model.vocab
vocabulary = set(_vocab_source)

# Unigram count file. Set COUNT_1W_PATH to point elsewhere instead of editing
# the script; the default is the original author's path.
counts_path = os.environ.get('COUNT_1W_PATH', '/Users/pbaumgartner/data/count_1w.txt')
max_words = 100000

# islice stops reading the count file once ``max_words`` in-vocabulary words
# have been found. The result is the same as filtering everything and slicing.
relevant_words = list(itertools.islice(
    ((word, count) for (word, count) in read_1w_corpus(counts_path) if word in vocabulary),
    max_words,
))
words = [word for word, _ in relevant_words]

model_reduced = w2v_model[words]
embedding = umap.UMAP(random_state=666).fit_transform(model_reduced)

d = pd.DataFrame(embedding, columns=['c1', 'c2'])
d['word'] = words
d['count'] = [int(count) for _, count in relevant_words]
# Vectorised log10 over the whole column (same values as a per-element apply).
d['log_count'] = log10(d['count'])
def build_tooltip(row):
    """Return the HTML hover label for one row of the embedding table.

    ``row`` must be indexable by ``'word'``, ``'count'`` and ``'log_count'``
    (a DataFrame row when used through ``DataFrame.apply(..., axis=1)``).
    The label shows the word, the count with thousands separators, and the
    rounded ``log_count`` as "Magnitude", one per line via ``<br>`` tags.
    """
    word_part = '<b>Word:</b> ' + row['word']
    count_part = '<b>Count:</b> ' + "{:,}".format(row['count'])
    magnitude_part = '<b>Magnitude:</b> ' + str(round(row['log_count']))
    return '<br>'.join((word_part, count_part, magnitude_part))
# Give every word a hover label, then draw the 2-D embedding as a Scattergl
# scatter coloured by ``log_count`` and write it to an offline HTML file.
d['tooltip'] = d.apply(build_tooltip, axis=1)

marker_style = {
    'color': d['log_count'],
    'colorscale': 'Viridis',
    'size': 6,
    'line': {'width': 0.5},
    'opacity': 0.75,
}

trace = go.Scattergl(
    x=d['c1'],
    y=d['c2'],
    name='Embedding',
    mode='markers',
    marker=marker_style,
    text=d['tooltip'],
)

layout = {
    'title': 'Word2Vec Google News- 2D UMAP Embeddings',
    'yaxis': {'zeroline': False},
    'xaxis': {'zeroline': False},
    'hovermode': 'closest',
}

fig = go.Figure(data=[trace], layout=layout)
chart = plotly.offline.plot(fig, filename='w2v-umap.html')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.