Forked from quadrismegistus/gensim_word2vec_procrustes_align.py
Last active
June 6, 2024 18:00
-
-
Save tangert/106822a0f56f8308db3f1d77be2c7942 to your computer and use it in GitHub Desktop.
Function to align any number of word2vec models using Procrustes matrix alignment.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Code originally ported from HistWords <https://github.com/williamleif/histwords> by William Hamilton <[email protected]>.
def align_gensim_models(models, words=None):
    """Restrict several gensim word2vec models to their shared vocabulary.

    Generalized from the original two-model intersection to any number of
    models. The models are modified in place: afterwards row ``i`` of each
    model's vector matrix belongs to the same word in every model.

    Before calling this, run ``model.init_sims()`` on each model so that
    ``wv.vectors_norm`` is populated.

    Only the vocabulary shared by all models is kept. Words are re-indexed
    from 0..N in order of descending frequency (the sum of each word's count
    over all models); ties are broken alphabetically so the ordering is
    reproducible between runs. For every model:

    -- ``wv.vectors_norm`` and ``wv.syn0`` are both replaced by the selected
       rows of the old ``wv.vectors_norm``
    -- ``wv.index2word`` becomes the shared, ordered word list
    -- ``wv.vocab`` is rebuilt, preserving each count but updating the index

    Args:
        models: Sequence of gensim word2vec models exposing ``wv.vocab``,
            ``wv.vectors_norm``, ``wv.syn0`` and ``wv.index2word``.
        words: Optional list or set of words. When non-empty, the shared
            vocabulary is additionally intersected with it.

    Returns:
        The ``models`` object that was passed in. It is returned untouched
        when it is empty or when every model's vocabulary already equals the
        shared vocabulary.
    """
    # Nothing to intersect; avoids a TypeError from intersecting zero sets.
    if not models:
        return models

    # Get the vocab for each model
    vocabs = [set(m.wv.vocab.keys()) for m in models]

    # Find the common vocabulary. set.intersection always builds a NEW set,
    # so the in-place `&=` below can never mutate one of the per-model vocab
    # sets (which would make the "identical" check trivially true and drop
    # the `words` filter when only one model is given).
    common_vocab = set.intersection(*vocabs)
    if words:
        common_vocab &= set(words)

    # If no alignment is necessary because every vocab is already identical
    # to the common vocabulary, hand the models back unchanged.
    if all(not vocab - common_vocab for vocab in vocabs):
        print("All identical!")
        return models

    def total_count(word):
        # Frequency of `word` summed over all models.
        return sum(m.wv.vocab[word].count for m in models)

    # Otherwise sort by descending summed frequency; the word itself is the
    # secondary key so that equal counts do not depend on set iteration order.
    common_vocab = sorted(common_vocab, key=lambda w: (-total_count(w), w))

    # Then for each model...
    for m in models:
        old_vocab = m.wv.vocab

        # Replace old vectors_norm array with new one (with common vocab).
        # Indexing with a list of row numbers copies exactly those rows, in
        # the new shared order.
        indices = [old_vocab[w].index for w in common_vocab]
        new_arr = m.wv.vectors_norm[indices]
        m.wv.vectors_norm = m.wv.syn0 = new_arr

        # Replace old index2word with the new one. Each model gets its own
        # copy so that mutating one model's list cannot affect the others.
        m.wv.index2word = list(common_vocab)

        # Replace old vocab dictionary with new one (with common vocab),
        # keeping the count and assigning the new row index.
        m.wv.vocab = {
            word: gensim.models.word2vec.Vocab(
                index=new_index, count=old_vocab[word].count
            )
            for new_index, word in enumerate(common_vocab)
        }

    return models
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Thank you for sharing this. Has there been any method developed to align matrices while still maintaining unique words?