@BramVanroy
Last active October 9, 2024 13:54
Getting word embeddings
from dataclasses import dataclass, field

import torch
from torch import LongTensor, Tensor
from transformers import (
    AutoTokenizer,
    AutoModel,
    PreTrainedModel,
    PreTrainedTokenizer,
    BatchEncoding,
)


@dataclass
class Embedder:
    model_name: str
    layers: list[int] | None = None
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    verbose: bool = False
    tokenizer: PreTrainedTokenizer | None = field(default=None, init=False)
    model: PreTrainedModel | None = field(default=None, init=False)

    def __post_init__(self):
        # Get the embedding from the last layer by default
        self.layers = [-1] if self.layers is None else self.layers
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        self.model = AutoModel.from_pretrained(
            self.model_name, output_hidden_states=True
        ).to(self.device)
    @torch.inference_mode()
    def _get_hidden_states(
        self, encoded: dict[str, LongTensor] | BatchEncoding, token_ids_word: list[int]
    ) -> Tensor:
        """
        Push the input IDs through the model and get the hidden states from the requested
        layers, summing them if multiple layers are requested. Then select only the tokens
        that constitute the requested word and average them to get a single word vector.
        """
        output = self.model(**encoded)
        # Get all hidden states
        states = output.hidden_states
        # Stack and sum all requested layers
        if len(self.layers) > 1:
            output = torch.stack([states[i] for i in self.layers]).sum(0).squeeze()
        else:
            output = states[self.layers[0]].squeeze()
        # Only select the tokens that constitute the requested word
        word_tokens_output = output[token_ids_word]
        return word_tokens_output.mean(dim=0)
    def _get_embedding_from_word_idx(self, sentence: str, word_idx: int) -> Tensor:
        """Get a word vector by first tokenizing the input sentence, collecting all token
        positions that make up the word of interest, and then calling `_get_hidden_states`."""
        encoded = self.tokenizer(sentence, return_tensors="pt").to(self.model.device)
        # Get all token positions that belong to the word of interest
        token_ids_word = [
            position_idx
            for position_idx, idx in enumerate(encoded.word_ids())
            if idx == word_idx
        ]
        if self.verbose:
            # Sanity check that the word is correctly tokenized
            subword_toks = [encoded.tokens()[idx] for idx in token_ids_word]
            print(f"Subword tokens belonging to word: {subword_toks}")
        return self._get_hidden_states(encoded, token_ids_word)
    def _get_word_idx(self, sentence: str, word: str) -> int:
        """
        Get the index of a word in a sentence, based on white-space splitting.
        """
        words = sentence.split(" ")
        try:
            word_idx = words.index(word)
        except ValueError as exc:
            raise ValueError(
                f"Word '{word}' not found in sentence '{sentence}'."
                f" Note that words can only be matched against white-space-split tokens."
            ) from exc
        if self.verbose:
            print(f"Word '{word}' has word index {word_idx} in the sentence.")
        return word_idx
    def get_word_embedding(self, sentence: str, word: str) -> Tensor:
        """
        Given a sentence and a word in that sentence, return the word embedding.
        """
        word_idx = self._get_word_idx(sentence, word)
        return self._get_embedding_from_word_idx(sentence, word_idx)


if __name__ == "__main__":
    embedder = Embedder("pdelobelle/robbert-v2-dutch-base", verbose=True, device="cuda")

    sent = "Ik heb gisteren een koekje gegeten."
    w = "koekje"
    koek_embed = embedder.get_word_embedding(sent, w)

    sent = "Ik heb gisteren een appel gegeten."
    w = "appel"
    appel_embed = (w, embedder.get_word_embedding(sent, w))

    sent = "Ik heb gisteren een olifant gegeten."
    w = "olifant"
    olifant_embed = (w, embedder.get_word_embedding(sent, w))

    sent = "Ik heb gisteren een oma gegeten."
    w = "oma"
    oma_embed = (w, embedder.get_word_embedding(sent, w))

    sent = "Ik heb gisteren een herfst gegeten."
    w = "herfst"
    herfst_embed = (w, embedder.get_word_embedding(sent, w))

    if embedder.verbose:
        print()

    embeds = [appel_embed, olifant_embed, oma_embed, herfst_embed]
    for w, embed in embeds:
        sim = torch.nn.functional.cosine_similarity(koek_embed, embed, dim=0)
        print(f"Similarity between 'koekje' and '{w}': {sim}")
@BramVanroy (Author)
Printed output:

Word 'koekje' has word index 4 in the sentence.
Subword tokens belonging to word: ['Ġkoekje']
Word 'appel' has word index 4 in the sentence.
Subword tokens belonging to word: ['Ġappel']
Word 'olifant' has word index 4 in the sentence.
Subword tokens belonging to word: ['Ġolifant']
Word 'oma' has word index 4 in the sentence.
Subword tokens belonging to word: ['Ġoma']
Word 'herfst' has word index 4 in the sentence.
Subword tokens belonging to word: ['Ġherfst']

Similarity between 'koekje' and 'appel': 0.9116837978363037
Similarity between 'koekje' and 'olifant': 0.8721239566802979
Similarity between 'koekje' and 'oma': 0.7907631993293762
Similarity between 'koekje' and 'herfst': 0.8304888010025024
