Getting word embeddings
from dataclasses import dataclass, field

import torch
from torch import LongTensor, Tensor
from transformers import (
    AutoTokenizer,
    AutoModel,
    PreTrainedModel,
    PreTrainedTokenizer,
    BatchEncoding,
)


@dataclass
class Embedder:
    model_name: str
    layers: list[int] | None = None
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    verbose: bool = False
    tokenizer: PreTrainedTokenizer | None = field(default=None, init=False)
    model: PreTrainedModel | None = field(default=None, init=False)

    def __post_init__(self):
        # Get embedding from last layer by default
        self.layers = [-1] if self.layers is None else self.layers
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
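        # Note: output_hidden_states=True below makes the forward pass also return the
        # hidden states of the embedding layer and of every transformer layer (available
        # as output.hidden_states), which `_get_hidden_states` indexes into.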
        self.model = AutoModel.from_pretrained(
            self.model_name, output_hidden_states=True
        ).to(self.device)

    @torch.inference_mode()
    def _get_hidden_states(
        self, encoded: dict[str, LongTensor] | BatchEncoding, token_ids_word: list[int]
    ) -> Tensor:
        """
        Push the input IDs through the model. Get the hidden states from the requested layers and sum them
        if multiple layer representations are requested. Then select only the tokens that constitute the
        requested word, and finally average those word tokens to get a single word vector.
        """
        output = self.model(**encoded)
        # Get all hidden states
        states = output.hidden_states
        # Stack and sum all requested layers
        if len(self.layers) > 1:
            output = torch.stack([states[i] for i in self.layers]).sum(0).squeeze()
        else:
            output = states[self.layers[0]].squeeze()
        # Only select the tokens that constitute the requested word
        word_tokens_output = output[token_ids_word]
        return word_tokens_output.mean(dim=0)

    def _get_embedding_from_word_idx(self, sentence: str, word_idx: int) -> Tensor:
        """Get a word vector by first tokenizing the input sentence, getting all token idxs
        that make up the word of interest, and then calling `_get_hidden_states`."""
        encoded = self.tokenizer(sentence, return_tensors="pt").to(self.model.device)
        # Get all token idxs that belong to the word of interest
        token_ids_word = [
            position_idx
            for position_idx, idx in enumerate(encoded.word_ids())
            if idx == word_idx
        ]
        if self.verbose:
            # Sanity check that the word is correctly tokenized
            subword_toks = [encoded.tokens()[idx] for idx in token_ids_word]
            print(f"Subword tokens belonging to word: {subword_toks}")
        return self._get_hidden_states(encoded, token_ids_word)

    def _get_word_idx(self, sentence: str, word: str) -> int:
        """
        Get the index of a WORD in a sentence.
        """
        words = sentence.split(" ")
        try:
            word_idx = words.index(word)
        except ValueError as exc:
            raise ValueError(
                f"Word '{word}' not found in sentence '{sentence}'."
                f" Note that the word index can only be found considering white-space split tokens."
            ) from exc
        if self.verbose:
            print(f"Word '{word}' has word index {word_idx} in the sentence.")
        return word_idx

    def get_word_embedding(self, sentence: str, word: str) -> Tensor:
        """
        Given a sentence and a word in that sentence, return the word embedding.
        """
        word_idx = self._get_word_idx(sentence, word)
        return self._get_embedding_from_word_idx(sentence, word_idx)


if __name__ == "__main__":
    embedder = Embedder("pdelobelle/robbert-v2-dutch-base", verbose=True, device="cuda")
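    # Illustrative alternative (not run here): representations from several layers can be
    # summed into one vector by passing e.g. layers=[-4, -3, -2, -1]; by default only the
    # last layer (-1) is used.
    # embedder = Embedder("pdelobelle/robbert-v2-dutch-base", layers=[-4, -3, -2, -1])
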
sent = "Ik heb gisteren een koekje gegeten." | |
w = "koekje" | |
koek_embed = embedder.get_word_embedding(sent, w) | |
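    # Embed other nouns in the same carrier sentence ("Ik heb gisteren een ... gegeten.",
    # i.e. "I ate a/an ... yesterday.") to compare them against 'koekje' (cookie) below.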
sent = "Ik heb gisteren een appel gegeten." | |
w = "appel" | |
appel_embed = (w, embedder.get_word_embedding(sent, w)) | |
sent = "Ik heb gisteren een olifant gegeten." | |
w = "olifant" | |
olifant_embed = (w, embedder.get_word_embedding(sent, w)) | |
sent = "Ik heb gisteren een oma gegeten." | |
w = "oma" | |
oma_embed = (w, embedder.get_word_embedding(sent, w)) | |
sent = "Ik heb gisteren een herfst gegeten." | |
w = "herfst" | |
herfst_embed = (w, embedder.get_word_embedding(sent, w)) | |
if embedder.verbose: | |
print() | |
embeds = [appel_embed, olifant_embed, oma_embed, herfst_embed] | |
for w, embed in embeds: | |
sim = torch.nn.functional.cosine_similarity(koek_embed, embed, dim=0) | |
print(f"Similarity between 'koekje' and '{w}': {sim}") |
Printed output: