""" | |
generic & basic sbert-like embedder class for the jina-bert model | |
Usage: | |
model = EmbeddingModel("jinaai/jina-embeddings-v2-base-en") | |
embeddings = model.encode( | |
["How is the weather today?", "What is the current weather like today?"] | |
) | |
print(model.cos_sim(embeddings[0], embeddings[1])) | |
""" | |
import torch | |
from transformers import AutoTokenizer, AutoModel | |
from numpy.linalg import norm | |
import numpy as np | |
from tqdm.auto import trange, tqdm | |
class EmbeddingModel:
    """
    A generic and basic SBERT-like embedding class for the Jina AI models.

    Attributes:
        device (str): The device (CPU/GPU) on which the model runs.
        batch_size (int): Batch size for processing inputs.
        max_length (int): Maximum sequence length used when tokenizing inputs.
        tokenizer: Tokenizer corresponding to the model.
        model: The embedding model loaded from Hugging Face.

    Methods:
        encode(sentences): Encodes a list of sentences into embeddings.
        cos_sim(a, b): Computes cosine similarity between two vectors.
    """

    MAX_REASONABLE_LENGTH = 8192  # hard cap on sequence length to keep memory use reasonable
    def __init__(
        self,
        model_name: str = "jinaai/jina-embeddings-v2-small-en",
        device: str = None,
        batch_size: int = 8,
        max_length: int = None,
        compile: bool = True,
    ):
        """
        Initializes the EmbeddingModel with a specified model, device, and batch size.

        Args:
            model_name (str): The model to use for embeddings, default is 'jinaai/jina-embeddings-v2-small-en'.
            device (str): The device to run the model on ('cuda' or 'cpu'); defaults to GPU if available.
            batch_size (int): Size of batches for processing, default is 8.
            max_length (int): Maximum tokenized sequence length; defaults to the tokenizer's model_max_length.
            compile (bool): Whether to wrap the model with torch.compile, default is True.
        """
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.batch_size = batch_size
        self.compile = compile
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(
            model_name, trust_remote_code=True, torch_dtype="auto"
        ).to(self.device)
        self.max_length = max_length or self.tokenizer.model_max_length
        if self.compile:
            self.model = torch.compile(self.model)
    def __repr__(self):
        """
        Returns a string representation of the EmbeddingModel instance.
        """
        return (
            f"EmbeddingModel(model_name={self.model.config.name_or_path}, "
            f"device={self.device}, batch_size={self.batch_size})"
        )
    def encode(self, sentences, max_length: int = None):
        """
        Encodes a list of sentences into embeddings using the model.

        Args:
            sentences (list of str): A list of sentences to be encoded.
            max_length (int): Optional per-call override of the maximum sequence length.

        Returns:
            numpy.ndarray: An array of sentence embeddings.
        """
        max_length = max_length or self.max_length
        self.model.eval()
        embeddings = []
        with torch.no_grad():
            for i in trange(0, len(sentences), self.batch_size, desc="encoding text"):
                batch = sentences[i : i + self.batch_size]
                inputs = self.tokenizer(
                    batch,
                    return_tensors="pt",
                    padding="longest",
                    truncation=True,
                    max_length=min(max_length, self.MAX_REASONABLE_LENGTH),
                ).to(self.device)
                outputs = self.model(**inputs)
                # SBERT-style mean pooling: average token embeddings, ignoring padding
                input_mask_expanded = (
                    inputs["attention_mask"]
                    .unsqueeze(-1)
                    .expand(outputs.last_hidden_state.shape)
                    .float()
                )
                sum_embeddings = torch.sum(
                    outputs.last_hidden_state * input_mask_expanded, 1
                )
                sum_mask = input_mask_expanded.sum(1)
                sum_mask = torch.clamp(sum_mask, min=1e-9)  # avoid division by zero
                mean_embeddings = sum_embeddings / sum_mask
                embeddings.append(mean_embeddings.cpu().numpy())
        embeddings = np.concatenate(embeddings, axis=0)
        return embeddings
    @staticmethod
    def cos_sim(a, b):
        """
        Computes cosine similarity between two 1-D embedding vectors.
        """
        return (a @ b.T) / (norm(a) * norm(b))
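

if __name__ == "__main__":
    # Minimal usage sketch mirroring the docstring example above; not part of
    # the original class. compile=False is an assumption here, chosen only to
    # skip the torch.compile warm-up cost for a quick one-off run.
    model = EmbeddingModel("jinaai/jina-embeddings-v2-small-en", compile=False)
    embeddings = model.encode(
        ["How is the weather today?", "What is the current weather like today?"]
    )
    print(embeddings.shape)  # (2, hidden_size)
    print(model.cos_sim(embeddings[0], embeddings[1]))  # high for paraphrases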