Word Embedding Postprocessing: All-but-the-Top
""" | |
All-but-the-Top: Simple and Effective Postprocessing for Word Representations | |
Paper: https://arxiv.org/abs/1702.01417 | |
Last Updated: Fri 15 Nov 2019 11:47:00 AM CET | |
**Prior version had serious issues, please excuse any inconveniences.** | |
""" | |
import numpy as np | |
from sklearn.decomposition import PCA | |
def all_but_the_top(v, D): | |
""" | |
Arguments: | |
:v: word vectors of shape (n_words, n_dimensions) | |
:D: number of principal components to subtract | |
""" | |
# 1. Subtract mean vector | |
v_tilde = v - np.mean(v, axis=0) | |
# 2. Compute the first `D` principal components | |
# on centered embedding vectors | |
u = PCA(n_components=D).fit(v_tilde).components_ # [D, emb_size] | |
# Subtract first `D` principal components | |
# [vocab_size, emb_size] @ [emb_size, D] @ [D, emb_size] -> [vocab_size, emb_size] | |
return v_tilde - (v @ u.T @ u) |
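

# --- Usage sketch (illustrative, not part of the original gist) ---
# A minimal, hedged example of calling `all_but_the_top`; the random
# embedding matrix, vocabulary size, dimensionality, and D=3 are
# assumptions chosen only for demonstration, not values from the paper.
if __name__ == "__main__":
    rng = np.random.default_rng(0)
    # Stand-in "word vectors": 1000 words, 300 dimensions
    v = rng.normal(size=(1000, 300))
    v_post = all_but_the_top(v, D=3)
    print(v_post.shape)  # (1000, 300)
    # After postprocessing, the vectors are (approximately) zero-mean
    print(np.allclose(v_post.mean(axis=0), 0.0))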
@lgalke You are correct, it should be mean centered embeddings that I need to subtract from. I'll fix that, thanks.