Cosine similarity adaptation of Sentence-BERT
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer

# Default checkpoint (assumed): all-MiniLM-L6-v2 outputs 384-dimensional
# embeddings, matching the default `dim` below. Any 384-dim SBERT model works.
SBERT_MODEL = "sentence-transformers/all-MiniLM-L6-v2"


class CosineSentenceBERT(nn.Module):
    def __init__(self, model_name=SBERT_MODEL, dim=384):
        super().__init__()
        self.model_name = model_name
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)
        # Feed-forward head that maps pooled embeddings to the output dimension
        self.ffnn = nn.Sequential(
            nn.Linear(dim, dim),
            nn.GELU(),
            nn.Dropout(0.1),
        )

    @staticmethod
    def mean_pool(token_embeds, attention_mask):
        # Mean-pool token embeddings, ignoring padding positions
        in_mask = attention_mask.unsqueeze(-1).expand(token_embeds.size()).float()
        pool = torch.sum(token_embeds * in_mask, 1) / torch.clamp(in_mask.sum(1), min=1e-9)
        return pool

    def encode(self, input_ids, attention_mask):
        outputs = self.model(input_ids, attention_mask=attention_mask)[0]
        embeddings = self.mean_pool(outputs, attention_mask)
        return self.ffnn(embeddings)

    def forward(self, input_ids_a, input_ids_b, attention_mask_a=None, attention_mask_b=None, labels=None):
        # Encode both sentences
        embed_a = self.encode(input_ids_a, attention_mask_a)
        embed_b = self.encode(input_ids_b, attention_mask_b)

        # Compute cosine similarity
        cosine_sim = F.cosine_similarity(embed_a, embed_b)

        loss = None
        if labels is not None:
            loss_fct = nn.CosineEmbeddingLoss()
            # CosineEmbeddingLoss expects 1 for similar pairs and -1 for dissimilar pairs
            loss = loss_fct(embed_a, embed_b, (labels * 2) - 1)

        return {"loss": loss, "similarity": cosine_sim}

    def predict(self, a: str, b: str):
        encoded_a = self.tokenizer(a, padding=True, truncation=True, return_tensors="pt")
        encoded_b = self.tokenizer(b, padding=True, truncation=True, return_tensors="pt")
        # Disable dropout for deterministic inference
        self.eval()
        with torch.no_grad():
            embed_a = self.encode(encoded_a["input_ids"].to(self.model.device),
                                  encoded_a["attention_mask"].to(self.model.device))
            embed_b = self.encode(encoded_b["input_ids"].to(self.model.device),
                                  encoded_b["attention_mask"].to(self.model.device))
        similarity = F.cosine_similarity(embed_a, embed_b).item()
        return similarity
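
A minimal usage sketch, assuming the default checkpoint above; the sentences and printed score are illustrative only:

# Instantiate the model and compare two sentences (sketch, not from the original gist)
model = CosineSentenceBERT()

# Cosine similarity between the two sentence embeddings, in [-1, 1]
score = model.predict("A dog is playing in the park.", "A puppy runs across the grass.")
print(f"similarity: {score:.3f}")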