Last active
August 29, 2015 14:24
-
-
Save scottlingran/a3097f1f3c417764f31c to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy
import six
from modelfile import model

# Inputs: words (or raw vectors) that pull the query vector toward them
# (positive) or away from them (negative); topn bounds the result size.
positive = []
negative = []
topn = 10  # FIX: was referenced below but never defined (NameError)


def _unitvec(vec):
    """Scale *vec* to unit (L2) length; return it unchanged if all-zero."""
    norm = numpy.linalg.norm(vec)
    return vec / norm if norm > 0 else vec


# Add weights for each word, if not already present; default to 1.0 for
# positive and -1.0 for negative words. Raw numpy vectors are accepted as-is.
# Produces lists of (word, weight) tuples.
positive = [
    (word, 1.0) if isinstance(word, six.string_types + (numpy.ndarray,)) else word
    for word in positive
]
negative = [
    (word, -1.0) if isinstance(word, six.string_types + (numpy.ndarray,)) else word
    for word in negative
]

# Compute the weighted average of all input words.
all_words = set()  # vocab indices of input words, excluded from the result
mean = []
for word, weight in positive + negative:
    if isinstance(word, numpy.ndarray):
        # Caller supplied a raw vector directly; use it without vocab lookup.
        mean.append(weight * word)
    elif word in model.vocab:
        word_index = model.vocab[word].index
        mean.append(weight * model.syn0norm[word_index])
        all_words.add(word_index)
    else:
        raise KeyError("word '%s' not in vocabulary" % word)
if not mean:
    raise ValueError("cannot compute similarity with no input")
# FIX: original called gensim's matutils.unitvec without importing matutils;
# the local _unitvec helper performs the same L2 normalisation via numpy.
mean = _unitvec(numpy.array(mean).mean(axis=0)).astype(numpy.float32)

# Similarity of the mean vector against every word vector in one matrix-vector
# product (syn0norm rows are presumably unit-normalised — TODO confirm — so the
# dot product is the cosine similarity).
# NOTE: SLOW 1.5s
dists = numpy.dot(model.syn0norm, mean)
# Over-fetch by len(all_words) so that dropping the input words below still
# leaves at least topn candidates.
# NOTE: SLOW 0.5s
best = numpy.argsort(dists)[::-1][:topn + len(all_words)]
# Ignore (don't return) words from the input.
# FIX: cap the result at topn — the over-fetch above was only to compensate
# for the excluded input words, not to return extra entries.
result = [
    (model.index2word[sim], float(dists[sim]))
    for sim in best
    if sim not in all_words
][:topn]
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment