Generate embeddings for rare words in a document by averaging the embeddings of associated context words. Find nearest neighbors of these embeddings to evaluate their quality.
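Before the full script, a minimal self-contained sketch of the idea (the toy vocabulary, vectors, and context words are invented purely for illustration; the script below loads a fixed embedding file and a WSJ corpus instead): average the embeddings of the words seen around the rare word, normalize, and rank known words by cosine similarity.

import numpy as np

# Toy embedding table: word -> row index into a unit-normalized matrix (made up for illustration).
vocab = {"cat": 0, "dog": 1, "pizza": 2, "eats": 3, "sleeps": 4}
vectors = np.random.rand(len(vocab), 4)
vectors /= np.linalg.norm(vectors, axis=1)[:, np.newaxis]

# Context words observed around some rare, out-of-vocabulary word.
contexts = ["eats", "sleeps", "dog"]

# Average the context embeddings to get a stand-in vector for the rare word.
avg = np.mean([vectors[vocab[w]] for w in contexts if w in vocab], axis=0)
avg /= np.linalg.norm(avg)

# Rank the known words by cosine similarity (dot product of unit vectors).
rev_vocab = {i: w for w, i in vocab.items()}
for i in np.argsort(-vectors.dot(avg))[:3]:
    print rev_vocab[i]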
from collections import Counter, defaultdict
import itertools
import os
import random
import re

import numpy as np

EMBEDDING_FILE = "/u/nlp/data/depparser/nn/data/embeddings/en-cw.txt"
EMBEDDING_SERIALIZED = "embeddings.npz"
STOPWORDS_FILE = "/u/nlp/data/gaz/stopwords"
#CORPUS_FILE = "GENIA.raw.form.txt"
CORPUS_FILE = "train-wsj-0-18.raw"

# Tokens made up entirely of bracket/punctuation characters are skipped as context words.
PUNCT_REGEX = re.compile(r"^[][().,-]+$")

def load_embeddings(f_stream):
    """Load embeddings from a text file with one `word v1 v2 ...` entry per
    line. Returns a word -> row-index dict and the embedding matrix."""
    words = []
    embs = []
    for line in f_stream:
        fields = line.strip().split()
        word = fields[0]
        vals = [float(x) for x in fields[1:]]
        words.append(word)
        embs.append(np.array(vals))
    dict = {word: i for i, word in enumerate(words)}
    return dict, np.array(embs)

def load_corpus(f_stream):
    """Read one sentence per line, splitting periods and commas into separate
    tokens. Returns the token lists and token frequencies."""
    sentences = []
    freqs = Counter()
    for line in f_stream:
        tokens = line.strip().replace('.', ' .').replace(',', ' ,').split()
        for token in tokens:
            freqs[token] += 1
        sentences.append(tokens)
    return sentences, freqs

def load_corpus_with_contexts(f_stream, window=5, stopwords=frozenset()):
    """Like `load_corpus`, but also collect, for every token, the non-stopword,
    non-punctuation words appearing within `window` tokens of each occurrence."""
    sentences = []
    freqs = Counter()
    contexts = defaultdict(list)
    for line in f_stream:
        tokens = line.strip().replace('.', ' .').replace(',', ' ,').split()
        for i, token in enumerate(tokens):
            freqs[token] += 1
            context = [ctx_word for ctx_word in tokens[max(0, i - window):i + window]
                       if ctx_word != token and ctx_word not in stopwords and not PUNCT_REGEX.match(ctx_word)]
            contexts[token].extend(context)
        sentences.append(tokens)
    return sentences, freqs, contexts

def get_contexts(word, sentences, window=5, stopwords=frozenset()):
    """Collect the non-stopword, non-punctuation words appearing within
    `window` tokens of every occurrence of `word` in `sentences`."""
    context_words = []
    for sentence in sentences:
        idxs = [i for i, other_word in enumerate(sentence) if word == other_word]
        for idx in idxs:
            idx_context = [ctx_word for ctx_word in sentence[max(0, idx - window):idx + window]
                           if ctx_word != word and ctx_word not in stopwords and not PUNCT_REGEX.match(ctx_word)]
            context_words.extend(idx_context)
    return context_words

def avg_word_embeddings(words, dict, embs):
    """Average the embeddings of `words`; out-of-vocabulary words contribute
    zero vectors but still count toward the denominator."""
    avg = np.zeros(embs.shape[1])
    for word in words:
        word = word.lower()
        try:
            avg += embs[dict[word]]
        except KeyError:
            pass
    if words:
        avg /= float(len(words))
    return avg

def nearest_neighbors(x, embeddings, n=5):
    """Return the row indices of the `n` embeddings closest to `x` by cosine
    similarity (rows of `embeddings` are assumed to be unit-normalized)."""
    x = x / np.linalg.norm(x)
    dists = -embeddings.dot(x)
    return np.argsort(dists)[:n]

def nearest_neighbor_words(x, dict, embs, rev_dict=None, n=5):
    """Map the `n` nearest-neighbor indices of `x` back to their words."""
    if rev_dict is None:
        rev_dict = {v: k for k, v in dict.iteritems()}
    return [rev_dict[id] for id in nearest_neighbors(x, embs, n=n)]

def averaged_nn_simple(word_list, dict, embs, stopwords=frozenset()):
    """Generate embeddings for the provided words by averaging the embeddings
    of their context words. Yields triples of the form

        (word, freq, embedding)

    `word_list` may also be an integer, in which case that many words are
    sampled at random from the corpus vocabulary."""
    with open(CORPUS_FILE, 'r') as corpus_f:
        sentences, word_freqs = load_corpus(corpus_f)
    if isinstance(word_list, int):
        word_list = random.sample(word_freqs.keys(), word_list)
    for word in word_list:
        ctxs = get_contexts(word, sentences, stopwords=stopwords)
        yield word, word_freqs[word], avg_word_embeddings(ctxs, dict, embs)

def averaged_nn_nn(word_list, dict, embs, freq_threshold=10, window=5, stopwords=frozenset()):
    """Generate embeddings for the provided words by finding, for each word,
    the frequent in-vocabulary corpus word whose context embedding is closest
    to the word's own context embedding, and returning that word's embedding.
    Yields triples of the form

        (word, freq, embedding)
    """
    with open(CORPUS_FILE, 'r') as corpus_f:
        sentences, word_freqs, contexts = load_corpus_with_contexts(corpus_f, window=window,
                                                                    stopwords=stopwords)
    if isinstance(word_list, int):
        word_list = random.sample(word_freqs.keys(), word_list)
    # Compute context embeddings for all words in the corpus; keep the word
    # order so that rows can be mapped back to corpus words.
    corpus_words = list(contexts.keys())
    context_embeddings = np.array([avg_word_embeddings(contexts[w], dict, embs)
                                   for w in corpus_words])
    # Normalize context embeddings (guarding against all-zero rows for words
    # whose contexts are entirely out of vocabulary)
    norms = np.linalg.norm(context_embeddings, axis=1)
    norms[norms == 0.0] = 1.0
    context_embeddings /= norms[:, np.newaxis]
    for word in word_list:
        # Compute context embedding of the unknown word
        ctxs = contexts[word]
        unk_context_embedding = avg_word_embeddings(ctxs, dict, embs)
        unk_context_embedding /= np.linalg.norm(unk_context_embedding)
        # Cosine distance to every corpus word's context embedding
        dists = -context_embeddings.dot(unk_context_embedding)
        # Pick the closest corpus word that is frequent enough and has an embedding
        best_idx = next(idx for idx in dists.argsort()
                        if word_freqs[corpus_words[idx]] > freq_threshold
                        and corpus_words[idx].lower() in dict)
        # Return the word embedding associated with this context embedding
        emb = embs[dict[corpus_words[best_idx].lower()]]
        yield word, word_freqs[word], emb

if __name__ == '__main__':
    if os.path.exists(EMBEDDING_SERIALIZED):
        ser = np.load(EMBEDDING_SERIALIZED)
        dict, embs = ser['dict'][()], ser['embs']
    else:
        with open(EMBEDDING_FILE, 'r') as emb_f:
            dict, embs = load_embeddings(emb_f)
        # Normalize embeddings
        embs /= np.linalg.norm(embs, axis=1)[:, np.newaxis]
        np.savez(EMBEDDING_SERIALIZED, dict=dict, embs=embs)
    print 'Loaded %i embeddings' % len(dict)

    with open(STOPWORDS_FILE, 'r') as stopwords_f:
        stopwords = frozenset([x.strip() for x in stopwords_f.readlines()])

    # Number of words to fetch per trial
    n_words = 20
    # Number of neighbors to list
    n_list = 7

    dim = embs.shape[1]
    rev_dict = {v: k for k, v in dict.iteritems()}

    nn_simple_results = list(averaged_nn_simple(n_words, dict, embs, stopwords=stopwords))
    word_list = [word for word, _, _ in nn_simple_results]
    nn_nn_results = list(averaged_nn_nn(word_list, dict, embs, stopwords=stopwords))

    for nn_simple_result, nn_nn_result in zip(nn_simple_results, nn_nn_results):
        word, word_freq, simple_emb = nn_simple_result
        _, _, nn_emb = nn_nn_result

        # Print nearest neighbors of simple averaged embedding
        neighbors = nearest_neighbor_words(simple_emb, dict, embs, rev_dict=rev_dict, n=n_list)
        print '%30s %5d\t%s' % (word, word_freq, ' '.join(neighbors))

        # Print nearest neighbors of NN averaged embedding
        nn_neighbors = nearest_neighbor_words(nn_emb, dict, embs, rev_dict=rev_dict, n=n_list)
        print '%30s %5d\t%s' % (word, word_freq, ' '.join(nn_neighbors))

        # Print nearest neighbors of a random embedding
        rand_neighbors = nearest_neighbor_words(np.random.rand(dim), dict, embs, rev_dict=rev_dict, n=n_list)
        print '%30s %5s\t%s' % ('', '', ' '.join(rand_neighbors))

""" | |
Example output (run on PTB WSJ training data): | |
$ python avg_embeddings.py | |
Loaded 130000 embeddings | |
TRIMMING 1 ? happy jealous speechless depressed | |
mid-week windies spoils chanderpaul roode | |
Aug 41 cost reduced budget election vote | |
yankees metrodome canucks waratahs stormers | |
cultivating 1 a head easy constant power | |
cannot pre-election dampens mid-atlantic city-state | |
Knopf 4 book royal writing title reading | |
could seems appreciates dugouts seemed | |
Symbol 6 otc ; : ystem aggregator | |
dugouts brink mid-week hopes trumper | |
loyalty 28 brand market share potential buyer | |
dugouts mid-week persuasions deal-making post-cold | |
Hurwitz 1 president chief david general james | |
deserves appreciates emphasizes dictates materialises | |
plunging 7 sales money payment sale prices | |
pessimists nby clamour conferees doubters | |
intimidation 1 bombings killings arrests campaign assassination | |
saw emphasizes beyond exceeds ensures | |
Kelly/Varnell 1 west south east left come | |
dampens long-awaited post-war mid-week cross-media | |
police 59 security police emergency service aid | |
dampens chanderpaul cannot downplays execs | |
cost-efficiency 1 productivity flexibility growth consumption interdependence | |
mid-week dugouts blacklist dampens heightens | |
worriers 2 repressed refuting discussed heterodox applauded | |
bolls acpc clamour requirments fall-out | |
Money-fund 3 wages revenues assets incomes receipts | |
olein dampens stearin bradies bm&f | |
cocky 1 denying depriving accusing eliminating accepting | |
hurts dampens downplays scotched defers | |
Laurel 8 president bank agent governor party | |
deserves deserved downplays resounding exuded | |
minor-sport 1 games scores drivers students points | |
mid-week littlejohn dampens long-awaited dugouts | |
graphic 1 firms entrepreneurs photographers economists scientists | |
appreciates promises blacklist hustle imbroglio | |
campaign 108 hearing news special advertising public | |
mid-week reckons conferees loog dugouts | |
Manhattan 70 council center district school law | |
deserve deserves dampens enduring clear-cut | |
""" |