```python
import numpy as np
import gensim


def smart_procrustes_align_gensim(base_embed, other_embed, words=None):
    """Procrustes-align two gensim word2vec models (to allow for comparison of the same word across models).
    Code ported from HistWords <https://github.com/williamleif/histwords> by William Hamilton <[email protected]>.
    (With help from William. Thank you!)

    First, intersect the vocabularies (see `intersection_align_gensim` documentation).
    Then do the alignment on the other_embed model.
    Replace the other_embed model's syn0 and syn0norm numpy matrices with the aligned version.
    Return other_embed.

    If `words` is set, intersect the two models' vocabulary with the vocabulary in words (see `intersection_align_gensim` documentation).
    """
    # patch by Richard So (https://twitter.com/richardjeanso) (thanks!) to update this code for a newer version of gensim
    base_embed.init_sims()
    other_embed.init_sims()

    # make sure vocabulary and indices are aligned
    in_base_embed, in_other_embed = intersection_align_gensim(base_embed, other_embed, words=words)

    # get the (l2-normalized) embedding matrices
    base_vecs = in_base_embed.syn0norm
    other_vecs = in_other_embed.syn0norm

    # just a matrix dot product with numpy
    m = other_vecs.T.dot(base_vecs)
    # SVD method from numpy
    u, _, v = np.linalg.svd(m)
    # another matrix operation
    ortho = u.dot(v)
    # Replace the original array with the modified one,
    # i.e. multiply the embedding matrix (syn0norm) by "ortho"
    other_embed.syn0norm = other_embed.syn0 = (other_embed.syn0norm).dot(ortho)
    return other_embed


def intersection_align_gensim(m1, m2, words=None):
    """
    Intersect two gensim word2vec models, m1 and m2.
    Only the shared vocabulary between them is kept.
    If 'words' is set (as list or set), then the vocabulary is intersected with this list as well.
    Indices are re-organized from 0..N in order of descending frequency (= sum of counts from both m1 and m2).
    These indices correspond to the new syn0 and syn0norm objects in both gensim models:
        -- so that Row 0 of m1.syn0 will be for the same word as Row 0 of m2.syn0
        -- you can find the index of any word in the .index2word list: model.index2word.index(word) => 2
    The .vocab dictionary is also updated for each model, preserving the count but updating the index.
    """
    # Get the vocab for each model
    vocab_m1 = set(m1.vocab.keys())
    vocab_m2 = set(m2.vocab.keys())

    # Find the common vocabulary
    common_vocab = vocab_m1 & vocab_m2
    if words:
        common_vocab &= set(words)

    # If no alignment is necessary because the vocabularies are identical...
    if not vocab_m1 - common_vocab and not vocab_m2 - common_vocab:
        return (m1, m2)

    # Otherwise sort by frequency (summed for both)
    common_vocab = list(common_vocab)
    common_vocab.sort(key=lambda w: m1.vocab[w].count + m2.vocab[w].count, reverse=True)

    # Then for each model...
    for m in [m1, m2]:
        # Replace the old syn0norm array with a new one (restricted to the common vocab)
        indices = [m.vocab[w].index for w in common_vocab]
        old_arr = m.syn0norm
        new_arr = np.array([old_arr[index] for index in indices])
        m.syn0norm = m.syn0 = new_arr

        # Replace the old vocab dictionary with a new one (restricted to the common vocab)
        # and the old index2word with a new one
        m.index2word = common_vocab
        old_vocab = m.vocab
        new_vocab = {}
        for new_index, word in enumerate(common_vocab):
            old_vocab_obj = old_vocab[word]
            new_vocab[word] = gensim.models.word2vec.Vocab(index=new_index, count=old_vocab_obj.count)
        m.vocab = new_vocab

    return (m1, m2)
```
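For context, a minimal usage sketch (the model filenames and the test word are hypothetical; both models are assumed to be ordinary trained gensim Word2Vec models using the pre-4.0 API):

```python
import gensim
import numpy as np

# Hypothetical filenames for two already-trained word2vec models
model_a = gensim.models.Word2Vec.load("decade_1990s.w2v")
model_b = gensim.models.Word2Vec.load("decade_2000s.w2v")

# Align model_b onto model_a's coordinate space (both models get trimmed to their shared vocabulary)
model_b_aligned = smart_procrustes_align_gensim(model_a, model_b)

# Both matrices are l2-normalized after alignment, so a dot product gives cosine similarity
word = "broadcast"
similarity = np.dot(model_a[word], model_b_aligned[word])
```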
I guess you could also just use scipy.linalg.orthogonal_procrustes?
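For what it's worth, a small sketch of that alternative: scipy.linalg.orthogonal_procrustes returns the same rotation the gist computes by hand, assuming `base_vecs` and `other_vecs` are the intersected, l2-normalized matrices produced by intersection_align_gensim above:

```python
from scipy.linalg import orthogonal_procrustes

# Orthogonal R minimizing ||other_vecs @ R - base_vecs||_F,
# equivalent to the u.dot(v) rotation computed from the SVD above
R, _ = orthogonal_procrustes(other_vecs, base_vecs)
other_vecs_aligned = other_vecs.dot(R)
```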
Line 64 brings up the following error. Any idea why?

TypeError: 'NoneType' object has no attribute '__getitem__'

m.wv.syn0norm has no value somehow.
Yes. You need to perform the l2 normalization before applying this routine. This is done by calling init_sims().
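In other words, a minimal sketch, assuming you are calling intersection_align_gensim directly on two trained models:

```python
# Populate syn0norm (the l2-normalized vectors) before the intersection code indexes it;
# smart_procrustes_align_gensim already does this internally via Richard So's patch.
m1.init_sims()
m2.init_sims()
m1, m2 = intersection_align_gensim(m1, m2)
```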
None of the matrices here is shifted to the origin, right? Yet, I found this shifting done in some explanations of Procrustes analysis, e.g. here. Is the shifting omitted on purpose, perhaps because it has no effect on the outcome or cosine?
Thanks a lot for this code. I have 5 word2vec models trained (for 2011, 2012, 2013, 2014, and 2015) and I would like to combine them using this code.
Do we have to combine them in chronological order? I.e. combine 2011 and 2012 -> get combined model 2011_2012,
then combine 2011_2012 and 2013 -> get combined model 2011_2012_2013, and so on...
Please kindly correct me if I am wrong.
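One way to read that chaining as code, a sketch only: `models` is a hypothetical dict of trained gensim Word2Vec models keyed by year, and note that smart_procrustes_align_gensim aligns coordinate spaces rather than merging models into one.

```python
years = [2011, 2012, 2013, 2014, 2015]
aligned = {years[0]: models[years[0]]}  # treat the earliest year as the reference space
for prev, curr in zip(years, years[1:]):
    # Align each year's model to the previous (already aligned) year's model, in chronological order
    aligned[curr] = smart_procrustes_align_gensim(aligned[prev], models[curr])
```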
Hey! Thank you so much for this code. I used this in an NLP project of mine where I am comparing the same word across religious texts. I am forking this and uploading a generalized version that works with any number of models, input as an array.
Here is the fork with the updated code: https://gist.github.com/tangert/106822a0f56f8308db3f1d77be2c7942
Can you tell me how to align more than 2 word2vec models to each other so that the words can be compared in different models?
Here is an updated version for the gensim 4.0 API: https://gist.github.com/zhicongchen/9e23d5c3f1e5b1293b16133485cd17d8
Note that for the newer Gensim versions, calls to .index2word, .vocab, .syn0 and .syn0norm should be replaced with .wv.index2word, .wv.vocab, .wv.syn0 and .wv.syn0norm respectively.
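For illustration, the same accessors written against the model.wv KeyedVectors object (a sketch for gensim 1.x–3.x; the variable names follow the gist above):

```python
base_vecs  = in_base_embed.wv.syn0norm    # was in_base_embed.syn0norm
other_vecs = in_other_embed.wv.syn0norm   # was in_other_embed.syn0norm
vocab_m1   = set(m1.wv.vocab.keys())      # was set(m1.vocab.keys())
m.wv.index2word = common_vocab            # was m.index2word = common_vocab
```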