Skip to content

Instantly share code, notes, and snippets.

@pmbaumgartner
Created April 28, 2018 16:34
Show Gist options
  • Save pmbaumgartner/adb33aa486b77ab58eb3df265393195d to your computer and use it in GitHub Desktop.
Save pmbaumgartner/adb33aa486b77ab58eb3df265393195d to your computer and use it in GitHub Desktop.
Load Google News Word2Vec vectors, reduce their dimensionality with UMAP, and plot them with plot.ly
import gensim.downloader as gensim_api
import umap
import requests
import pandas as pd
from numpy import log10
import plotly
import plotly.graph_objs as go
# Load the pretrained 'word2vec-google-news-300' vectors via gensim's downloader API.
# NOTE(review): presumably this downloads and caches the model on first use — confirm.
w2v_model = gensim_api.load('word2vec-google-news-300')
def read_1w_corpus(r, sep="\t"):
    """Lazily yield the fields of each line in a delimited text file.

    Args:
        r: Path of the file to read.
        sep: Field delimiter; defaults to a tab.

    Yields:
        A list of string fields for each line, in file order, with the
        line's trailing newline removed.
    """
    # The context manager closes the file when the generator is exhausted
    # or closed, instead of leaving the handle open.
    with open(r) as corpus_file:
        for line in corpus_file:
            # Strip the newline so the last field is a clean value
            # rather than e.g. "123\n".
            yield line.rstrip("\n").split(sep)
# gensim >= 4.0 exposes the vocabulary mapping as `key_to_index` and no longer
# supports `vocab`; fall back to `vocab` so gensim 3.x keeps working.
vocab_index = getattr(w2v_model, 'key_to_index', None)
if vocab_index is None:
    vocab_index = w2v_model.vocab
vocabulary = set(vocab_index)

# Keep the (word, count) pairs from the count file whose word has a vector,
# capped at the first 100,000 matches in file order.
relevant_words = [(word, count) for (word, count) in read_1w_corpus('/Users/pbaumgartner/data/count_1w.txt') if word in vocabulary][:100000]

# Look up the embedding vector of every retained word.
model_reduced = w2v_model[[w[0] for w in relevant_words]]

# Reduce the vectors with UMAP; the fixed random_state seeds the run.
embedding = umap.UMAP(random_state=666).fit_transform(model_reduced)

# One row per word: the two embedding coordinates plus the word, its count,
# and the base-10 logarithm of that count.
d = pd.DataFrame(embedding, columns=['c1', 'c2'])
d['word'] = [w[0] for w in relevant_words]
d['count'] = [int(w[1]) for w in relevant_words]
d['log_count'] = d['count'].apply(log10)
def build_tooltip(row):
    """Build the HTML hover text for one word.

    Args:
        row: Mapping-like record (for example a DataFrame row) providing
            the keys 'word', 'count' and 'log_count'.

    Returns:
        A string holding the word, its count with thousands separators,
        and its log_count passed through round(), separated by <br> tags.
    """
    log_count = row['log_count']
    try:
        magnitude = round(log_count)
    except (OverflowError, ValueError):
        # Non-finite values (log10(0) is -inf, missing data is NaN) cannot
        # be rounded; show them as-is instead of aborting the whole apply.
        magnitude = log_count
    template = (
        '<b>Word:</b> {word}'
        '<br>'
        '<b>Count:</b> {count:,}'
        '<br>'
        '<b>Magnitude:</b> {magnitude}'
    )
    return template.format(word=row['word'], count=row['count'], magnitude=magnitude)
# Pre-render the hover text for every row.
d['tooltip'] = d.apply(build_tooltip, axis=1)

# WebGL scatter of the two embedding coordinates; marker colour follows
# log_count on the Viridis scale and the tooltip column is the hover text.
trace = go.Scattergl(
    x=d['c1'],
    y=d['c2'],
    name='Embedding',
    mode='markers',
    marker={
        'color': d['log_count'],
        'colorscale': 'Viridis',
        'size': 6,
        'line': {'width': 0.5},
        'opacity': 0.75,
    },
    text=d['tooltip'],
)

# Hide the zero lines on both axes and hover on the closest point.
layout = {
    'title': 'Word2Vec Google News- 2D UMAP Embeddings',
    'yaxis': {'zeroline': False},
    'xaxis': {'zeroline': False},
    'hovermode': 'closest',
}

# Assemble the figure and render it offline to 'w2v-umap.html'.
fig = go.Figure(data=[trace], layout=layout)
chart = plotly.offline.plot(fig, filename='w2v-umap.html')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment