Forked from quadrismegistus/gensim_word2vec_procrustes_align.py
Last active
June 3, 2024 01:54
-
-
Save zhicongchen/9e23d5c3f1e5b1293b16133485cd17d8 to your computer and use it in GitHub Desktop.
Code for aligning two gensim word2vec models using Procrustes matrix alignment (updated for compatibility with the Gensim 4.0 API). The code is modified from https://gist.github.com/quadrismegistus/09a93e219a6ffc4f216fb85235535faf, which was originally ported from HistWords by William Hamilton: https://github.com/williamleif/histwords
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def smart_procrustes_align_gensim(base_embed, other_embed, words=None):
    """Procrustes-align ``other_embed`` onto ``base_embed`` (gensim 4 API).

    Original script: https://gist.github.com/quadrismegistus/09a93e219a6ffc4f216fb85235535faf
    Code ported from HistWords <https://github.com/williamleif/histwords> by
    William Hamilton; the update for newer gensim versions was originally
    patched by Richard So.

    Steps:
      1. Intersect the two vocabularies in place (see ``intersection_align_gensim``)
         so that row ``i`` of both ``wv.vectors`` matrices refers to the same word.
      2. Solve the orthogonal Procrustes problem on the unit-normalized vectors.
      3. Rotate ``other_embed.wv.vectors`` by the resulting orthogonal matrix.

    Args:
        base_embed: gensim model (exposing ``.wv`` KeyedVectors) that defines
            the reference space.
        other_embed: gensim model whose vectors are rotated in place.
        words: optional iterable of words; if given (and non-empty), the shared
            vocabulary is further restricted to these words.

    Returns:
        ``other_embed``, whose ``wv.vectors`` now hold the aligned (rotated)
        vectors. Note that BOTH models are modified in place by the vocabulary
        intersection step.
    """
    # Make sure vocabulary and indices are aligned (both models mutated in place).
    in_base_embed, in_other_embed = intersection_align_gensim(base_embed, other_embed, words=words)

    # The vector matrices may just have been replaced. gensim only recomputes its
    # cached L2 norms when they are missing, so force a refresh; otherwise
    # get_normed_vectors() could divide by stale (wrong-length) norms.
    in_base_embed.wv.fill_norms(force=True)
    in_other_embed.wv.fill_norms(force=True)
    base_vecs = in_base_embed.wv.get_normed_vectors()
    other_vecs = in_other_embed.wv.get_normed_vectors()

    # Orthogonal Procrustes: with M = other^T @ base = U S V^T, the orthogonal R
    # minimizing ||other @ R - base||_F is R = U @ V^T (numpy's svd returns V^T).
    m = other_vecs.T @ base_vecs
    u, _, vt = np.linalg.svd(m)
    ortho = u @ vt

    # Rotate the raw (unnormalized) vectors. An orthogonal map preserves row norms
    # (up to floating-point rounding), so cached norms stay valid.
    other_embed.wv.vectors = other_embed.wv.vectors @ ortho
    return other_embed
def intersection_align_gensim(m1, m2, words=None):
    """Intersect the vocabularies of two gensim models, in place.

    Only the words shared by ``m1.wv`` and ``m2.wv`` are kept. If ``words`` is a
    non-empty iterable, the shared vocabulary is additionally restricted to it.

    The surviving words are re-indexed 0..N-1 in order of descending summed
    frequency (``count`` in m1 plus ``count`` in m2; ties broken by ``str(word)``
    so the result is reproducible), identically in both models. Afterwards row
    ``i`` of ``m1.wv.vectors`` and row ``i`` of ``m2.wv.vectors`` refer to the
    same word. ``index_to_key``, ``key_to_index`` and the per-word attribute
    arrays in ``wv.expandos`` (e.g. ``count``) are rewritten accordingly, and
    any cached vector norms are invalidated.

    If, after applying ``words``, nothing would be dropped from either model AND
    both models already list their words in the same order, the models are
    returned untouched.

    Only ``m.wv`` is modified; no other model state is touched, so continuing to
    train a model after alignment is presumably unsafe -- verify before relying on it.

    Args:
        m1, m2: gensim models exposing ``.wv`` KeyedVectors. Both must carry a
            ``count`` vector attribute (read via ``get_vecattr``) when any
            reordering is needed.
        words: optional list/set of words to further restrict the vocabulary.

    Returns:
        The tuple ``(m1, m2)`` (the same objects, mutated in place).
    """
    vocab_m1 = set(m1.wv.index_to_key)
    vocab_m2 = set(m2.wv.index_to_key)
    common_vocab = vocab_m1 & vocab_m2
    if words:
        common_vocab &= set(words)

    # Equal vocab sets are not enough: if the index orders differ, row i of the two
    # matrices would refer to different words, so the order must match as well.
    if (
        not vocab_m1 - common_vocab
        and not vocab_m2 - common_vocab
        and m1.wv.index_to_key == m2.wv.index_to_key
    ):
        return (m1, m2)

    # Deterministic tie-break first (set order depends on hash randomization),
    # then a stable sort by summed frequency, most frequent first.
    common_vocab = sorted(common_vocab, key=str)
    common_vocab.sort(
        key=lambda w: m1.wv.get_vecattr(w, "count") + m2.wv.get_vecattr(w, "count"),
        reverse=True,
    )

    for m in (m1, m2):
        wv = m.wv
        indices = [wv.key_to_index[w] for w in common_vocab]
        # Fancy indexing copies the selected rows in the new order and keeps a
        # 2-D (0, dim) shape even when the intersection is empty.
        wv.vectors = wv.vectors[indices]
        # Re-index per-word attribute arrays (e.g. "count") so get_vecattr() keeps
        # returning each word's own values under the new indices.
        expandos = getattr(wv, "expandos", None)
        if expandos:
            for attr in list(expandos):
                expandos[attr] = np.asarray(expandos[attr])[indices]
        # Cached L2 norms belong to the old matrix; drop them so they get recomputed.
        wv.norms = None
        # Separate list objects per model, so mutating one doesn't affect the other.
        wv.index_to_key = list(common_vocab)
        wv.key_to_index = {key: new_index for new_index, key in enumerate(common_vocab)}
        print(len(wv.key_to_index), len(wv.vectors))
    return (m1, m2)
I ended up using https://github.com/theochem/procrustes instead. Something like this:
from procrustes import rotational
common_words = sorted(CURRENT_WORDS.intersection(base_words))
print(f" Common words: {len(common_words)}")
common_words_embeddings_base = np.array([base_embeddings[word] for word in common_words])
common_words_embeddings_current = np.array([current_embeddings[word] for word in common_words])
# Solve the rotational Procrustes problem mapping the base embeddings onto the
# current ones. NOTE: despite the name, this is a ProcrustesResult object; the
# rotation matrix itself is `rotation_matrix.t`.
rotation_matrix = rotational(common_words_embeddings_base, common_words_embeddings_current)
# `new_a` is only the (padded / optionally translated & scaled) input A -- it is
# NOT rotated. Apply the rotation `t` explicitly to get the aligned embeddings.
base_words_embeddings_rotated = rotation_matrix.new_a @ rotation_matrix.t
# Use the actual embedding width instead of hard-coding 300.
rotated_model = KeyedVectors(base_words_embeddings_rotated.shape[1])
rotated_model.add_vectors(common_words, base_words_embeddings_rotated)
rotated_model.save("aligned.kv")
# Now release the memory and load the aligned vectors again
@estebarb Thank you so much for sharing!
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
I ran into the same issue. Any guidance would be appreciated.