Helper script using just transformers/torch to compute text embeddings (for E5 models, e.g. https://huggingface.co/intfloat/e5-base-v2).
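Note: per the E5 model card, each input text should be prefixed with "query: " or "passage: " for best results; the helpers below pass texts through unchanged, so add the prefix yourself, e.g. texts = [f"query: {q}" for q in questions] (where questions is a hypothetical list of strings).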
import torch
import torch.nn.functional as F
from torch import Tensor
from transformers import AutoTokenizer, AutoModel
from pandas import DataFrame
from typing import List, Union
from tqdm.auto import trange


def average_pool(last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
    """
    Calculate the average pooling of the last hidden states of a given transformer model.

    Args:
        last_hidden_states (Tensor): Tensor of the last hidden states from the transformer model.
        attention_mask (Tensor): Tensor of attention masks for each of the input sequences.

    Returns:
        Tensor: A tensor representing the average pooled embeddings of the input sequences.
    """
    last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
    return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
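

# Worked example: with attention_mask [[1, 1, 0], [1, 1, 1]], the padded position
# in the first sequence is zeroed out, the sums over seq_len are divided by the
# true token counts (2 and 3 respectively), i.e. a mean over only the real tokens.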


def encode_texts(texts: List[str], tokenizer, model, device, batch_size=16) -> Tensor:
    """
    Encode a list of texts using a pre-trained transformer model and a tokenizer.

    Args:
        texts (List[str]): List of texts to be encoded.
        tokenizer (transformers.PreTrainedTokenizer): Pre-trained tokenizer.
        model (transformers.PreTrainedModel): Pre-trained transformer model.
        device (str): The device to which tensors will be moved. Can be either "cuda" or "cpu".
        batch_size (int, optional): Size of the batches in which the input data is split. Defaults to 16.

    Returns:
        Tensor: A tensor of embeddings for each input text.
    """
    embeddings = []
    for i in trange(0, len(texts), batch_size):
        batch_texts = texts[i : i + batch_size]
        batch_dict = tokenizer(
            batch_texts,
            max_length=512,
            padding=True,
            truncation=True,
            return_tensors="pt",
        )
        # move tensors to the configured device
        batch_dict = {k: v.to(device) for k, v in batch_dict.items()}
        with torch.no_grad():  # deactivate autograd to reduce memory usage and speed up inference
            outputs = model(**batch_dict)
            batch_embeddings = average_pool(
                outputs.last_hidden_state, batch_dict["attention_mask"]
            )
            # (optionally) L2-normalize embeddings so dot product == cosine similarity
            batch_embeddings = F.normalize(batch_embeddings, p=2, dim=1)
            embeddings.append(batch_embeddings)
    embeddings = torch.cat(embeddings, dim=0)
    return embeddings.cpu()


def encode_textdata(
    text_data: Union[DataFrame, List[str]],
    model_name: str = "intfloat/e5-base-v2",
    text_col: str = "text",
    device=None,
    batch_size=16,
) -> Tensor:
    """
    Encode text data from a DataFrame or a list using a pre-trained transformer model.

    Args:
        text_data (Union[DataFrame, List[str]]): DataFrame or list containing the texts to be encoded.
        model_name (str, optional): Pre-trained transformer model name. Defaults to "intfloat/e5-base-v2".
        text_col (str, optional): Column from which to extract texts if the input is a DataFrame. Defaults to "text".
        device (str, optional): The device to which tensors will be moved. Can be either "cuda" or "cpu".
        batch_size (int, optional): Size of the batches in which the input data is split. Defaults to 16.

    Returns:
        Tensor: A tensor of embeddings for each input text.

    Raises:
        AssertionError: If the input is a DataFrame and the provided column name does not exist.
    """
    # check for GPU availability and set the device accordingly
    device = device or ("cuda" if torch.cuda.is_available() else "cpu")
    # get the tokenizer and model
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name).to(device)
    model = model.eval()
    # get the input texts directly from the dataframe column
    if isinstance(text_data, DataFrame):
        assert text_col in text_data.columns, f"column {text_col} not found in df"
        input_texts = text_data[text_col].tolist()
    else:
        input_texts = text_data
    return encode_texts(input_texts, tokenizer, model, device, batch_size)
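
A minimal usage sketch for the helper above (the module name embed_e5 is hypothetical; embeddings come back L2-normalized, so a dot product gives cosine similarity):

    from pandas import DataFrame
    from embed_e5 import encode_textdata  # hypothetical module name for the file above

    df = DataFrame(
        {
            "text": [
                "passage: The Eiffel Tower is located in Paris.",
                "passage: PyTorch is a deep learning framework.",
            ]
        }
    )
    embeddings = encode_textdata(df)  # defaults to intfloat/e5-base-v2
    print(embeddings.shape)  # torch.Size([2, 768]) for e5-base-v2

The second file in the gist, embed_e5_cli.py (the name its own docstring uses), follows: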
""" | |
This script contains functions for embedding text data using BERTopic and KeyBERTInspired. | |
It loads a pre-trained BERT model and tokenizer, and uses them to generate embeddings for text data. | |
The embeddings are then used to cluster the text data using BERTopic. | |
The script also provides a function for generating keyphrases using KeyBERTInspired. | |
This script is designed to be run from the command line interface (CLI): | |
python embed_e5_cli.py <command> <args> | |
""" | |
import logging
import os
import pathlib
import pprint as pp
import subprocess

import fire
import joblib
import plotly.io as pio
import torch
import torch.nn.functional as F
from bertopic import BERTopic
from bertopic.backend import BaseEmbedder
from bertopic.representation import KeyBERTInspired
from datasets import load_dataset
from joblib import Parallel, delayed
from nltk.corpus import stopwords
from torch import Tensor
from tqdm import tqdm
from transformers import AutoModel, AutoTokenizer

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

os.environ["TOKENIZERS_PARALLELISM"] = "false"


def check_ampere_gpu():
    """Check if the GPU is NVIDIA Ampere or later and enable TF32 matmuls in PyTorch if so."""
    cmd = "nvidia-smi --query-gpu=name --format=csv,noheader"
    try:
        output = subprocess.check_output(cmd, shell=True, universal_newlines=True)
        gpu_names = output.strip().split("\n")
    except Exception as e:
        logger.error(f"Error: {e}")
        return
    supported_gpus = ["A100", "A6000", "RTX 30"]  # add more models if needed
    for gpu_name in gpu_names:
        if any(supported_gpu in gpu_name for supported_gpu in supported_gpus):
            torch.backends.cuda.matmul.allow_tf32 = True
            logger.info(
                f"{gpu_name} supports NVIDIA Ampere or later. Enabled TF32 in PyTorch."
            )
        else:
            logger.info(f"{gpu_name} does not support NVIDIA Ampere or later.")


# NLTK English stopword list (requires a one-time `nltk.download("stopwords")`)
STOPWORDS = set(stopwords.words("english"))


def drop_stopwords(summary):
    """Remove English stopwords from each string in a list of texts."""
    updated_summary = []
    for s in summary:
        words = s.split()
        words = [word for word in words if word.lower() not in STOPWORDS]
        updated_summary.append(" ".join(words))
    return updated_summary


def encode_texts(texts: list, tokenizer, model, device, batch_size=16) -> Tensor:
    """Batch-encode texts into L2-normalized, mean-pooled embeddings."""
    embeddings = []
    for i in tqdm(range(0, len(texts), batch_size), desc="Encoding texts"):
        batch_texts = texts[i : i + batch_size]
        batch_dict = tokenizer(
            batch_texts,
            max_length=512,
            padding=True,
            truncation=True,
            return_tensors="pt",
        )
        batch_dict = {k: v.to(device) for k, v in batch_dict.items()}
        with torch.no_grad():
            outputs = model(**batch_dict)
            batch_embeddings = average_pool(
                outputs.last_hidden_state, batch_dict["attention_mask"]
            )
            batch_embeddings = F.normalize(batch_embeddings, p=2, dim=1)
            embeddings.append(batch_embeddings)
    embeddings = torch.cat(embeddings, dim=0)
    return embeddings.cpu()


def encode_textdata(
    text_data,
    model_name="intfloat/e5-base-v2",
    text_col="text",
    device=None,
    batch_size=16,
) -> Tensor:
    """Load the model/tokenizer and encode a list of texts or a DataFrame column."""
    device = device or ("cuda" if torch.cuda.is_available() else "cpu")
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name).to(device)
    model = model.eval()
    if isinstance(text_data, list):
        input_texts = text_data
    else:
        input_texts = text_data[text_col].tolist()
    return encode_texts(input_texts, tokenizer, model, device, batch_size)


class CustomEmbedder(BaseEmbedder):
    """BERTopic embedding backend wrapping an E5 model; embed() returns a numpy array, as BERTopic expects."""

    def __init__(
        self,
        model_name_or_path: str = "intfloat/e5-small-v2",
        device=None,
        batch_size=16,
    ):
        super().__init__()
        self.batch_size = batch_size
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
        self.model = AutoModel.from_pretrained(model_name_or_path).to(self.device)
        self.model = self.model.eval()

    def embed(self, documents, verbose=False):
        embeddings = []
        batch_size = self.batch_size
        for i in tqdm(range(0, len(documents), batch_size), desc="Embedding documents"):
            batch_texts = documents[i : i + batch_size]
            batch_dict = self.tokenizer(
                batch_texts,
                max_length=512,
                padding=True,
                truncation=True,
                return_tensors="pt",
            ).to(self.model.device)
            with torch.no_grad():
                outputs = self.model(**batch_dict)
                batch_embeddings = average_pool(
                    outputs.last_hidden_state, batch_dict["attention_mask"]
                )
                batch_embeddings = F.normalize(batch_embeddings, p=2, dim=1)
                embeddings.append(batch_embeddings)
        embeddings = torch.cat(embeddings, dim=0)
        return embeddings.cpu().numpy()
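

# Example (hypothetical) standalone use of the backend:
#     embedder = CustomEmbedder("intfloat/e5-small-v2")
#     vectors = embedder.embed(["query: what is topic modeling?"])
#     # vectors is a (1, 384) numpy array for e5-small-v2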


def average_pool(last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
    """Mask-aware mean pooling: average only over non-padding token positions."""
    last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
    return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]


def preprocess_dataset(dataset, filter_stopwords):
    if filter_stopwords:
        ds_train = dataset["train"].map(
            lambda example: {"summary": drop_stopwords(example["summary"])},
            batched=True,
        )
    else:
        ds_train = dataset["train"]
        logger.info("Not removing stopwords")
    return ds_train


def preprocess_data(dataset_name, filter_stopwords):
    dataset = load_dataset(dataset_name)
    ds_short_name = dataset_name.split("/")[-1]
    ds_train = preprocess_dataset(dataset, filter_stopwords)
    return ds_train, ds_short_name


def preprocess_documents(
    ds_train, docs_column, embedding_model_name, split_by_maxlen=False
):
    summ_docs = ds_train[docs_column]
    summ_docs = list(set(summ_docs))  # deduplicate (note: set() does not preserve order)
    emb_short_name = embedding_model_name.split("/")[-1]
    split_path = pathlib.Path(f"summ_docs_split-{emb_short_name}.joblib")
    if split_path.exists() and split_by_maxlen:
        logger.info(f"Loading existing file: {split_path}")
        summ_docs_split = joblib.load(split_path)
    else:
        if split_by_maxlen:
            tokenizer = AutoTokenizer.from_pretrained(embedding_model_name)
            summ_docs_split = split_text_by_tokens_parallel(
                summ_docs, tokenizer, n_jobs=max(os.cpu_count() - 4, 1)
            )
            joblib.dump(summ_docs_split, split_path, compress=5)
            logger.info(f"Saved file: {split_path}")
        else:
            summ_docs_split = summ_docs
    return summ_docs_split


def split_text(text, tokenizer):
    # tokenize without truncation; the large max_length keeps long texts intact
    encoded_text = tokenizer.encode_plus(
        text,
        add_special_tokens=False,
        truncation=False,
        padding=False,
        return_tensors="pt",
        max_length=10**6,
    )
    input_ids_list = encoded_text["input_ids"].view(-1).tolist()
    # chunk the token ids into model_max_length-sized windows
    chunks = [
        input_ids_list[i : i + tokenizer.model_max_length]
        for i in range(0, len(input_ids_list), tokenizer.model_max_length)
    ]
    # decode each chunk back into a plain-text segment
    split_strings = [
        tokenizer.decode(
            chunk, skip_special_tokens=True, clean_up_tokenization_spaces=True
        )
        for chunk in chunks
    ]
    return split_strings
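

# Worked example: with tokenizer.model_max_length == 512, a 1,300-token text
# is split into three chunks of 512, 512, and 276 tokens, each decoded back
# to a standalone string.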


def split_text_by_tokens_parallel(text_list, tokenizer, n_jobs=-1, backend="loky"):
    logger.info(f"Running token splitting with n_jobs {n_jobs}, backend={backend}")
    results = Parallel(n_jobs=n_jobs, backend=backend)(
        delayed(split_text)(text, tokenizer)
        for text in tqdm(text_list, desc="Splitting up text")
    )
    split_strings = [text for sublist in results for text in sublist]
    return split_strings


def main(
    dataset_name,
    filter_stopwords=False,
    docs_column: str = "target",
    split_by_maxlen: bool = False,
    calculate_probabilities: bool = False,
    embedding_model_id: str = "intfloat/e5-base-v2",
    nr_topics: int = None,
    use_manual_embedder: bool = True,
    min_topic_size=30,
    save_html: bool = False,
):
    """
    Main entry point for the embed-E5 CLI.

    Args:
        dataset_name (str): Name of the Hugging Face dataset to use.
        filter_stopwords (bool, optional): Whether to filter stopwords. Defaults to False.
        docs_column (str, optional): Name of the column containing the documents. Defaults to "target".
        split_by_maxlen (bool, optional): Whether to split documents by the model's maximum length. Defaults to False.
        calculate_probabilities (bool, optional): Whether to calculate topic probabilities. Defaults to False.
        embedding_model_id (str, optional): ID of the embedding model to use. Defaults to "intfloat/e5-base-v2".
        nr_topics (int, optional): Number of topics. Defaults to None, which maps to "auto".
        use_manual_embedder (bool, optional): Whether to use the CustomEmbedder backend. Defaults to True.
        min_topic_size (int, optional): Minimum size of a topic. Defaults to 30.
        save_html (bool, optional): Whether to save plots as HTML instead of static images. Defaults to False.

    Returns:
        None
    """
    check_ampere_gpu()
    nr_topics = nr_topics or "auto"
    # preprocess dataset
    ds_train, ds_short_name = preprocess_data(
        dataset_name,
        filter_stopwords,
    )
    # preprocess documents
    summ_docs_split = preprocess_documents(
        ds_train, docs_column, embedding_model_id, split_by_maxlen=split_by_maxlen
    )
    representation_model = KeyBERTInspired()
    _kb_emb_model = (
        CustomEmbedder(embedding_model_id)
        if use_manual_embedder
        else embedding_model_id
    )
    topic_model = BERTopic(
        language="english",
        verbose=True,
        calculate_probabilities=calculate_probabilities,
        embedding_model=_kb_emb_model,
        n_gram_range=(1, 1),
        nr_topics=nr_topics,
        min_topic_size=min_topic_size,
        representation_model=representation_model,
    )
    logger.info(f"Params:\n{pp.pformat(topic_model.get_params())}")
    emb_short_name = embedding_model_id.split("/")[-1]
    # fit the topic model and transform the documents
    tm_out_docs, tm_out_probs = topic_model.fit_transform(summ_docs_split)
    # push to Hugging Face Hub (requires an authenticated Hugging Face account)
    model_name = f"pszemraj/BERTopic-{ds_short_name}-{emb_short_name}-{docs_column}"
    topic_model.push_to_hf_hub(
        model_name, private=True, save_embedding_model=True, save_ctfidf=True
    )
    plot_base_name = f"{ds_short_name}-{emb_short_name}-{docs_column}"
    # build the visualizations once, then export as HTML or static images
    figures = {
        "topics": topic_model.visualize_topics(
            title=f"<b>Intertopic Distance Map:<br><br>{docs_column} via {emb_short_name}</b>"
        ),
        "documents": topic_model.visualize_documents(
            summ_docs_split,
            title=f"<b>Documents and Topics - <br>{docs_column} via {emb_short_name}</b>",
        ),
        "hierarchy": topic_model.visualize_hierarchy(
            top_n_topics=30,
            title=f"<b>Hierarchy - <br>{docs_column} via {emb_short_name}</b>",
        ),
        "barchart": topic_model.visualize_barchart(
            top_n_topics=8,
            title=f"<b>Topic Word Scores - <br>{docs_column} via {emb_short_name}</b>",
        ),
    }
    for fig_name, fig in figures.items():
        if save_html:
            pio.write_html(fig, f"{plot_base_name}-{fig_name}.html")
        else:
            pio.write_image(fig, f"{plot_base_name}-{fig_name}.jpg")
logging.info(f"Pushed model to {model_name}") | |
logging.info(f"Done!") | |


if __name__ == "__main__":
    fire.Fire(main)
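
A minimal invocation sketch (the dataset name is hypothetical; fire maps the flags onto main()'s arguments, and push_to_hf_hub assumes you are logged in to the Hugging Face Hub):

    python embed_e5_cli.py \
        --dataset_name "someuser/some-summarization-dataset" \
        --docs_column "target" \
        --embedding_model_id "intfloat/e5-base-v2" \
        --min_topic_size 30 \
        --save_html=True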