Created
April 10, 2023 23:29
-
-
Save lgmoneda/f54575eebaa8932ca926f5d0526e8a31 to your computer and use it in GitHub Desktop.
Averaging embeddings of a document with many sentences
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import re | |
import numpy as np | |
import pandas as pd | |
import torch | |
import warnings | |
from adjustText import adjust_text | |
from sentence_transformers import SentenceTransformer | |
from sklearn.manifold import TSNE | |
from tqdm import tqdm | |
from transformers import AutoModel, AutoTokenizer | |
# Switch off the HuggingFace `tokenizers` library's internal parallelism through
# its environment variable (presumably to avoid its fork-related warning).
os.environ.update({"TOKENIZERS_PARALLELISM": "false"})
# Suppress every Python warning for the remainder of the process.
warnings.filterwarnings('ignore')

# Per-model token budget: adjust_sentences() packs sentences into chunks whose
# token count stays within the value for the chosen model name.
MAX_SEQUENCE_LENGTH = {
    "all-mpnet-base-v2": 384,
    "all-MiniLM-L6-v2": 256,
}
def adjust_sentences(strings: list[str], model_name: str) -> list[str]: | |
# Create an empty list to store the concatenated strings | |
concatenated_strings = [] | |
# Load tokenizer | |
tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/{}'.format(model_name)) | |
n = MAX_SEQUENCE_LENGTH[model_name] | |
# Keep track of the current string being concatenated | |
current_string = "" | |
# Go through each string in the list | |
for string in strings: | |
# If the current string plus the next string is shorter than or equal to n, concatenate the two strings | |
current_string_tokens = tokenizer(current_string, padding=True, truncation=True, return_tensors='pt') | |
string_tokens = tokenizer(string, padding=True, truncation=True, return_tensors='pt') | |
if len(current_string_tokens["input_ids"][0]) + len(string_tokens["input_ids"][0]) <= n: | |
current_string += string | |
# Otherwise, add the current string to the list of concatenated strings and reset the current string to the next string | |
else: | |
if len(current_string) > 0: | |
concatenated_strings.append(current_string) | |
current_string = string | |
# Add the final string to the list of concatenated strings | |
concatenated_strings.append(current_string) | |
return concatenated_strings | |
def get_embeddings_for_text(text, model_name): | |
""" | |
Split texts into sentences and get embeddings for each sentence. | |
The final embeddings is the mean of all sentence embeddings. | |
:param text: str. Input text. | |
:return: np.array. Embeddings. | |
""" | |
sentences = list(set(re.findall('[^!?。.?!]+[!?。.?!]?', text))) | |
token_normalized_sentences = adjust_sentences(sentences, model_name) | |
embedder = SentenceTransformer(model_name) | |
return np.mean( | |
embedder.encode( | |
token_normalized_sentences | |
), axis=0) | |
def get_embeddings_for_doc_corpus(corpus, model_name):
    """Embed every document in ``corpus`` and return the results as a float tensor.

    :param corpus: iterable of str documents; progress is displayed with tqdm.
    :param model_name: sentence-transformers model name forwarded to
        get_embeddings_for_text.
    :return: torch.Tensor converted with ``.float()`` (float32), built from
        ``np.array`` over the per-document embeddings — one row per document,
        assuming all embeddings share the same length.
    """
    per_doc = [get_embeddings_for_text(document, model_name) for document in tqdm(corpus)]
    stacked = np.array(per_doc)
    return torch.from_numpy(stacked).float()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
cool