@asehmi
Forked from pszemraj/compute_embeddings_e5.py
Created January 23, 2024 07:19
Helper scripts that use just transformers/torch to compute text embeddings with e5 models (https://huggingface.co/intfloat/e5-base-v2). The gist contains compute_embeddings_e5.py (standalone embedding helpers) and embed_e5_cli.py (a BERTopic topic-modeling CLI that reuses the same encoding logic).
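# ---------------------------------------------------------------------------
# compute_embeddings_e5.py (first file in this gist)
# ---------------------------------------------------------------------------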
import torch
import torch.nn.functional as F
from torch import Tensor
from transformers import AutoTokenizer, AutoModel
from pandas import DataFrame
from typing import List, Union
from tqdm.auto import tqdm, trange
def average_pool(last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
"""
Calculate the average pooling of the last hidden states of a given transformer model.
Args:
last_hidden_states (Tensor): Tensor of the last hidden states from transformer model.
attention_mask (Tensor): Tensor of attention masks for each of the input sequences.
Returns:
Tensor: A tensor representing the average pooled embeddings of the input sequences.
"""
last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
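# Shape note: given last_hidden_states of shape (batch, seq_len, hidden) and an
# attention_mask of shape (batch, seq_len), average_pool returns a (batch, hidden)
# tensor, averaging only over non-padding token positions.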
def encode_texts(texts: List[str], tokenizer, model, device, batch_size=16) -> Tensor:
"""
Encode a list of texts using a pre-trained transformer model and a tokenizer.
Args:
texts (List[str]): List of texts to be encoded.
tokenizer (transformers.PreTrainedTokenizer): Pre-trained tokenizer.
model (transformers.PreTrainedModel): Pre-trained transformer model.
device (str): The device to which tensors will be moved. Can be either "cuda" or "cpu".
batch_size (int, optional): Size of the batches in which the input data is split. Defaults to 16.
Returns:
Tensor: A tensor of embeddings for each input text.
"""
embeddings = []
for i in trange(0, len(texts), batch_size):
batch_texts = texts[i : i + batch_size]
batch_dict = tokenizer(
batch_texts,
max_length=512,
padding=True,
truncation=True,
return_tensors="pt",
)
# move tensors to the configured device
batch_dict = {k: v.to(device) for k, v in batch_dict.items()}
with torch.no_grad(): # deactivate autograd engine to reduce memory usage and speed up computations
outputs = model(**batch_dict)
batch_embeddings = average_pool(
outputs.last_hidden_state, batch_dict["attention_mask"]
)
# (Optionally) normalize embeddings
batch_embeddings = F.normalize(batch_embeddings, p=2, dim=1)
embeddings.append(batch_embeddings)
embeddings = torch.cat(embeddings, dim=0)
return embeddings.cpu()
def encode_textdata(
text_data: Union[DataFrame, List[str]],
model_name: str = "intfloat/e5-base-v2",
text_col: str = "text",
device=None,
batch_size=16,
) -> Tensor:
"""
Encode text data from a DataFrame or a list using a pre-trained transformer model.
Args:
text_data (Union[DataFrame, List[str]]): DataFrame or list containing the texts to be encoded.
model_name (str, optional): Pre-trained transformer model name. Defaults to 'intfloat/e5-base-v2'.
text_col (str, optional): Column name from which to extract texts if input is a DataFrame. Defaults to 'text'.
device (str, optional): The device to which tensors will be moved. Can be either "cuda" or "cpu".
batch_size (int, optional): Size of the batches in which the input data is split. Defaults to 16.
Returns:
Tensor: A tensor of embeddings for each input text.
Raises:
AssertionError: If the input is a DataFrame and the provided column name does not exist.
"""
# Check for GPU availability and set the device accordingly
device = device or ("cuda" if torch.cuda.is_available() else "cpu")
# Get the tokenizer and model
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name).to(device)
model = model.eval()
# Get the input texts directly from the dataframe column
if isinstance(text_data, DataFrame):
assert text_col in text_data.columns, f"column {text_col} not found in df"
input_texts = text_data[text_col].tolist()
else:
input_texts = text_data
return encode_texts(input_texts, tokenizer, model, device, batch_size)
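# Minimal usage sketch (assumes network access or a locally cached checkpoint; the
# "query: " / "passage: " prefixes follow the e5 model card's recommended input format):
if __name__ == "__main__":
    sample_texts = [
        "query: how do transformers compute attention?",
        "passage: Attention weights come from scaled dot products of queries and keys.",
    ]
    emb = encode_textdata(sample_texts, model_name="intfloat/e5-base-v2", batch_size=2)
    print(emb.shape)  # torch.Size([2, 768]) for e5-base-v2


# ---------------------------------------------------------------------------
# embed_e5_cli.py (second file in this gist)
# ---------------------------------------------------------------------------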
"""
This script contains functions for embedding text data using BERTopic and KeyBERTInspired.
It loads a pre-trained BERT model and tokenizer, and uses them to generate embeddings for text data.
The embeddings are then used to cluster the text data using BERTopic.
The script also provides a function for generating keyphrases using KeyBERTInspired.
This script is designed to be run from the command line interface (CLI):
python embed_e5_cli.py <command> <args>
"""
import logging
import os
import pathlib
import pprint as pp
import subprocess
import fire
import joblib
import plotly.io as pio
import torch
import torch.nn.functional as F
from bertopic import BERTopic
from bertopic.backend import BaseEmbedder
from bertopic.representation import KeyBERTInspired
from datasets import load_dataset
from joblib import Parallel, delayed
from nltk.corpus import stopwords
from torch import Tensor
from tqdm import tqdm
from transformers import AutoModel, AutoTokenizer
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
os.environ["TOKENIZERS_PARALLELISM"] = "false"
def check_ampere_gpu():
"""Check if the GPU supports NVIDIA Ampere or later and enable FP32 in PyTorch if it does."""
cmd = "nvidia-smi --query-gpu=name --format=csv,noheader"
try:
output = subprocess.check_output(cmd, shell=True, universal_newlines=True)
gpu_names = output.strip().split("\n")
except Exception as e:
logger.error(f"Error: {e}")
return
supported_gpus = ["A100", "A6000", "RTX 30"] # Add more models if needed
for gpu_name in gpu_names:
if any(supported_gpu in gpu_name for supported_gpu in supported_gpus):
torch.backends.cuda.matmul.allow_tf32 = True
logger.info(
f"{gpu_name} supports NVIDIA Ampere or later. Enabled TF32 in PyTorch."
)
else:
logger.info(f"{gpu_name} does not support NVIDIA Ampere or later.")
def drop_stopwords(summary):
    # NOTE: requires the NLTK stopwords corpus (run nltk.download("stopwords") once)
    stop_words = set(stopwords.words("english"))
    updated_summary = []
    for s in summary:
        words = [word for word in s.split() if word.lower() not in stop_words]
        updated_summary.append(" ".join(words))
    return updated_summary
def encode_texts(texts: list, tokenizer, model, device, batch_size=16) -> Tensor:
embeddings = []
for i in tqdm(range(0, len(texts), batch_size), desc="Encoding texts"):
batch_texts = texts[i : i + batch_size]
batch_dict = tokenizer(
batch_texts,
max_length=512,
padding=True,
truncation=True,
return_tensors="pt",
)
batch_dict = {k: v.to(device) for k, v in batch_dict.items()}
with torch.no_grad():
outputs = model(**batch_dict)
batch_embeddings = average_pool(
outputs.last_hidden_state, batch_dict["attention_mask"]
)
batch_embeddings = F.normalize(batch_embeddings, p=2, dim=1)
embeddings.append(batch_embeddings)
embeddings = torch.cat(embeddings, dim=0)
return embeddings.cpu()
def encode_textdata(
text_data,
model_name="intfloat/e5-base-v2",
text_col="text",
device=None,
batch_size=16,
) -> Tensor:
device = device or ("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name).to(device)
model = model.eval()
if isinstance(text_data, list):
input_texts = text_data
else:
input_texts = text_data[text_col].tolist()
return encode_texts(input_texts, tokenizer, model, device, batch_size)
class CustomEmbedder(BaseEmbedder):
def __init__(
self,
model_name_or_path: str = "intfloat/e5-small-v2",
device=None,
batch_size=16,
):
super().__init__()
        self.batch_size = batch_size
self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
self.model = AutoModel.from_pretrained(model_name_or_path).to(self.device)
self.model = self.model.eval()
def embed(self, documents, verbose=False) -> Tensor:
embeddings = []
batch_size = self.batch_size
for i in tqdm(range(0, len(documents), batch_size), desc="Embedding documents"):
batch_texts = documents[i : i + batch_size]
batch_dict = self.tokenizer(
batch_texts,
max_length=512,
padding=True,
truncation=True,
return_tensors="pt",
).to(self.model.device)
with torch.no_grad():
outputs = self.model(**batch_dict)
batch_embeddings = average_pool(
outputs.last_hidden_state, batch_dict["attention_mask"]
)
batch_embeddings = F.normalize(batch_embeddings, p=2, dim=1)
embeddings.append(batch_embeddings)
embeddings = torch.cat(embeddings, dim=0)
return embeddings.cpu().numpy()
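# Standalone usage sketch (assumes the e5 checkpoint is downloadable or cached):
#   embedder = CustomEmbedder("intfloat/e5-small-v2", batch_size=8)
#   vecs = embedder.embed(["passage: an example document"])  # numpy array of shape (1, hidden_size)
# In main() below, the same object is passed to BERTopic(embedding_model=...) when
# use_manual_embedder=True.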
def average_pool(last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
def preprocess_dataset(dataset, filter_stopwords):
if filter_stopwords:
ds_train = dataset["train"].map(
lambda example: {"summary": drop_stopwords(example["summary"])},
batched=True,
)
else:
ds_train = dataset["train"]
logger.info("Not removing stopwords")
return ds_train
def preprocess_data(dataset_name, filter_stopwords):
dataset = load_dataset(dataset_name)
ds_short_name = dataset_name.split("/")[-1]
ds_train = preprocess_dataset(dataset, filter_stopwords)
return ds_train, ds_short_name
def preprocess_documents(
ds_train, docs_column, embedding_model_name, split_by_maxlen=False
):
representation_model = KeyBERTInspired()
summ_docs = ds_train[docs_column]
summ_docs = list(set(summ_docs))
emb_short_name = embedding_model_name.split("/")[-1]
split_path = pathlib.Path(f"summ_docs_split-{emb_short_name}.joblib")
if split_path.exists() and split_by_maxlen:
logger.info(f"Loading existing file: {split_path}")
summ_docs_split = joblib.load(split_path)
else:
if split_by_maxlen:
tokenizer = AutoTokenizer.from_pretrained(embedding_model_name)
summ_docs_split = split_text_by_tokens_parallel(
summ_docs, tokenizer, n_jobs=max(os.cpu_count() - 4, 1)
)
joblib.dump(summ_docs_split, split_path, compress=5)
logger.info(f"Saved file: {split_path}")
else:
summ_docs_split = summ_docs
return summ_docs_split
def split_text(text, tokenizer):
encoded_text = tokenizer.encode_plus(
text,
add_special_tokens=False,
truncation=False,
padding=False,
return_tensors="pt",
max_length=10**6,
)
input_ids_list = encoded_text["input_ids"].view(-1).tolist()
chunks = [
input_ids_list[i : i + tokenizer.model_max_length]
for i in range(0, len(input_ids_list), tokenizer.model_max_length)
]
split_strings = [
tokenizer.decode(
chunk, skip_special_tokens=True, clean_up_tokenization_spaces=True
)
for chunk in chunks
]
return split_strings
def split_text_by_tokens_parallel(text_list, tokenizer, n_jobs=-1, backend="loky"):
logger.info(f"Running token splitting with n_jobs {n_jobs}, backend={backend}")
results = Parallel(n_jobs=n_jobs, backend=backend)(
delayed(split_text)(text, tokenizer)
for text in tqdm(text_list, desc="Splitting up text")
)
split_strings = [text for sublist in results for text in sublist]
return split_strings
def main(
dataset_name,
filter_stopwords=False,
docs_column: str = "target",
split_by_maxlen: bool = False,
calculate_probabilities: bool = False,
embedding_model_id: str = "intfloat/e5-base-v2",
nr_topics: int = None,
use_manual_embedder: bool = True,
min_topic_size=30,
save_html: bool = False,
):
"""
Main function for embedding E5 CLI.
Args:
dataset_name (str): Name of the dataset to use.
filter_stopwords (bool, optional): Whether to filter stopwords. Defaults to False.
docs_column (str, optional): Name of the column containing the documents. Defaults to "target".
split_by_maxlen (bool, optional): Whether to split documents by maximum length. Defaults to False.
calculate_probabilities (bool, optional): Whether to calculate probabilities. Defaults to False.
embedding_model_id (str, optional): ID of the embedding model to use. Defaults to "intfloat/e5-base-v2".
        nr_topics (int, optional): Number of topics; None is treated as "auto". Defaults to None.
use_manual_embedder (bool, optional): Whether to use a manual embedder. Defaults to True.
min_topic_size (int, optional): Minimum size of a topic. Defaults to 30.
save_html (bool, optional): Whether to save HTML. Defaults to False.
Returns:
None
"""
check_ampere_gpu()
nr_topics = nr_topics or "auto"
# Preprocess dataset
ds_train, ds_short_name = preprocess_data(
dataset_name,
filter_stopwords,
)
# Preprocess documents
summ_docs_split = preprocess_documents(
ds_train, docs_column, embedding_model_id, split_by_maxlen=split_by_maxlen
)
representation_model = KeyBERTInspired()
_kb_emb_model = (
CustomEmbedder(embedding_model_id)
if use_manual_embedder
else embedding_model_id
)
topic_model = BERTopic(
"english",
verbose=True,
calculate_probabilities=calculate_probabilities,
embedding_model=_kb_emb_model,
n_gram_range=(1, 1),
nr_topics=nr_topics,
min_topic_size=min_topic_size,
representation_model=representation_model,
)
logging.info(f"Params:\n{pp.pformat(topic_model.get_params())}")
emb_short_name = embedding_model_id.split("/")[-1]
# Encode documents
tm_out_docs, tm_out_probs = topic_model.fit_transform(summ_docs_split)
# Push to Hugging Face Hub
model_name = f"pszemraj/BERTopic-{ds_short_name}-{emb_short_name}-{docs_column}"
topic_model.push_to_hf_hub(
model_name, private=True, save_embedding_model=True, save_ctfidf=True
)
plot_base_name = f"{ds_short_name}-{emb_short_name}-{docs_column}"
if save_html:
# Visualize topics
fig_topics = topic_model.visualize_topics(
title=f"<b>Intertopic Distance Map:<br><br>{docs_column} via {emb_short_name}</b>"
)
pio.write_html(fig_topics, plot_base_name + "topics.html")
fig_docs = topic_model.visualize_documents(
summ_docs_split,
title=f"<b>Documents and Topics - <br>{docs_column} via {emb_short_name}</b>",
)
pio.write_html(fig_docs, plot_base_name + "documents.html")
fig_hierarchy = topic_model.visualize_hierarchy(
top_n_topics=30,
title=f"<b>Hierarchy - <br>{docs_column} via {emb_short_name}</b>",
)
pio.write_html(fig_hierarchy, plot_base_name + "hierarchy.html")
fig_barchart = topic_model.visualize_barchart(
top_n_topics=8,
title=f"<b>Topic Word Scores - <br>{docs_column} via {emb_short_name}</b>",
)
pio.write_html(fig_barchart, plot_base_name + "barchart.html")
else:
# Visualize topics
fig_topics = topic_model.visualize_topics(
title=f"<b>Intertopic Distance Map:<br><br>{docs_column} via {emb_short_name}</b>"
)
pio.write_image(fig_topics, plot_base_name + "topics.jpg")
fig_docs = topic_model.visualize_documents(
summ_docs_split,
title=f"<b>Documents and Topics - <br>{docs_column} via {emb_short_name}</b>",
)
pio.write_image(fig_docs, plot_base_name + "documents.jpg")
fig_hierarchy = topic_model.visualize_hierarchy(
top_n_topics=30,
title=f"<b>Hierarchy - <br>{docs_column} via {emb_short_name}</b>",
)
pio.write_image(fig_hierarchy, plot_base_name + "hierarchy.jpg")
fig_barchart = topic_model.visualize_barchart(
top_n_topics=8,
title=f"<b>Topic Word Scores - <br>{docs_column} via {emb_short_name}</b>",
)
pio.write_image(fig_barchart, plot_base_name + "barchart.jpg")
logging.info(f"Pushed model to {model_name}")
logging.info(f"Done!")
if __name__ == "__main__":
fire.Fire(main)