
@lgmoneda
Created April 10, 2023 23:29
Averaging embeddings of a document with many sentences
import os
import re
import numpy as np
import pandas as pd
import torch
import warnings
from adjustText import adjust_text
from sentence_transformers import SentenceTransformer
from sklearn.manifold import TSNE
from tqdm import tqdm
from transformers import AutoModel, AutoTokenizer
os.environ["TOKENIZERS_PARALLELISM"] = "false"
warnings.filterwarnings('ignore')
MAX_SEQUENCE_LENGTH = {
    "all-mpnet-base-v2": 384,
    "all-MiniLM-L6-v2": 256,
}
def adjust_sentences(strings: list[str], model_name: str) -> list[str]:
    """Greedily pack sentences into chunks that fit the model's token budget."""
    # List of concatenated chunks to return
    concatenated_strings = []
    # Load the tokenizer matching the sentence-transformers model
    tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/{}'.format(model_name))
    n = MAX_SEQUENCE_LENGTH[model_name]
    # The chunk currently being built
    current_string = ""
    for string in strings:
        # Token counts include special tokens ([CLS]/[SEP]) for each piece,
        # so their sum slightly overestimates the concatenated length;
        # the budget check below is therefore conservative
        current_string_tokens = tokenizer(current_string, padding=True, truncation=True, return_tensors='pt')
        string_tokens = tokenizer(string, padding=True, truncation=True, return_tensors='pt')
        # If both fit within the budget, append the sentence to the current chunk
        if len(current_string_tokens["input_ids"][0]) + len(string_tokens["input_ids"][0]) <= n:
            current_string += string
        # Otherwise, close the current chunk and start a new one with this sentence
        else:
            if len(current_string) > 0:
                concatenated_strings.append(current_string)
            current_string = string
    # Add the final chunk, if non-empty
    if current_string:
        concatenated_strings.append(current_string)
    return concatenated_strings
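
# Illustration (hypothetical inputs, not from the original gist): with
# all-MiniLM-L6-v2's 256-token budget, short sentences are greedily packed
# into a single chunk:
#
#   adjust_sentences(["A short sentence. ", "Another one. "], "all-MiniLM-L6-v2")
#   # -> ["A short sentence. Another one. "]
#
# Note that each call reloads the tokenizer; caching it outside the function
# would speed up repeated calls without changing the result.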
def get_embeddings_for_text(text, model_name):
    """
    Split text into sentences and get embeddings for each sentence chunk.
    The final embedding is the mean of all chunk embeddings.

    :param text: str. Input text.
    :param model_name: str. Name of the sentence-transformers model.
    :return: np.array. Embedding vector.
    """
    # set() deduplicates repeated sentences; the resulting order is
    # arbitrary, which is fine because only the mean is used
    sentences = list(set(re.findall('[^!?。.?!]+[!?。.?!]?', text)))
    token_normalized_sentences = adjust_sentences(sentences, model_name)
    embedder = SentenceTransformer(model_name)
    return np.mean(
        embedder.encode(token_normalized_sentences),
        axis=0,
    )
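
# Example (assumed inputs): a two-sentence text yields a single mean vector
# of the model's embedding dimension:
#
#   vec = get_embeddings_for_text("First point. Second point!", "all-MiniLM-L6-v2")
#   vec.shape  # (384,) for all-MiniLM-L6-v2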
def get_embeddings_for_doc_corpus(corpus, model_name):
    """Embed each document in the corpus and stack the results into a float tensor."""
    embeddings = []
    for doc in tqdm(corpus):
        embeddings.append(get_embeddings_for_text(doc, model_name))
    embeddings = np.array(embeddings)
    embeddings = torch.from_numpy(embeddings).float()
    return embeddings
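
# Minimal usage sketch (the corpus below and the model choice are assumptions
# for illustration, not part of the original gist):
if __name__ == "__main__":
    corpus = [
        "Sentence embeddings are useful. They summarize meaning compactly. ",
        "A second, shorter document. It has two sentences. ",
    ]
    doc_embeddings = get_embeddings_for_doc_corpus(corpus, "all-MiniLM-L6-v2")
    print(doc_embeddings.shape)  # torch.Size([2, 384]) for all-MiniLM-L6-v2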
@randomwangran

cool