Skip to content

Instantly share code, notes, and snippets.

Created January 4, 2025 12:49
Show Gist options
  • Save shreyansh26/d7eb4458b89567778477e1e544743e84 to your computer and use it in GitHub Desktop.
Save shreyansh26/d7eb4458b89567778477e1e544743e84 to your computer and use it in GitHub Desktop.
import sklearn
import torch
from sentence_transformers import SentenceTransformer
import torch.nn.functional as F
def get_score_diff(vectors):
    """Return all pairwise differences between the pairwise similarities of ``vectors``.

    Similarities are plain dot products (``vectors @ vectors.T``). These are cosine
    similarities when the rows are L2-normalised, as the demo script does.

    The ``m = n * (n - 1) / 2`` distinct pair similarities are read from the strict
    upper triangle in row-major order (for n=3: ``[s01, s02, s12]``). The result holds
    ``sims[b] - sims[a]`` for every ``a < b`` over those pairs, again in row-major
    order, so it has ``m * (m - 1) / 2`` elements.

    For n=3 the three entries line up with the three anchor triplets. For larger n,
    some entries compare two pairs that share no sentence.

    Args:
        vectors: 2-D tensor with one embedding per row (``n`` rows).

    Returns:
        1-D tensor of similarity differences. It is empty when ``n < 3``, including
        ``n <= 1``, where the former reshape-based version raised a RuntimeError.
    """
    n = vectors.shape[0]
    scores = torch.matmul(vectors, vectors.T)
    # Index the strict upper triangles directly. This avoids dense ones/boolean masks and
    # a dense m x m difference matrix. triu_indices is row-major, matching mask indexing.
    rows, cols = torch.triu_indices(n, n, offset=1, device=vectors.device)
    pair_scores = scores[rows, cols]
    m = pair_scores.shape[0]
    first, second = torch.triu_indices(m, m, offset=1, device=vectors.device)
    return pair_scores[second] - pair_scores[first]
def get_triplet_loss(teacher_vectors, student_vectors, triplet_margin=0.15):
    """Margin ranking loss that teaches the student the teacher's similarity ordering.

    Each comparison from ``get_score_diff`` contrasts two pair similarities. The
    teacher's difference decides which of the two pairs should be more similar. The
    student is penalised unless it reproduces that ordering by at least
    ``triplet_margin``.

    Args:
        teacher_vectors: teacher embeddings, one row per sentence.
        student_vectors: student embeddings for the same sentences, in the same order.
        triplet_margin: minimum similarity gap the student must keep between the
            teacher-preferred pair and the other pair.

    Returns:
        Scalar tensor holding the mean hinge loss over all comparisons. When there are
        no comparisons (fewer than three sentences), returns a zero tensor that is
        still connected to the student's autograd graph instead of NaN.
    """
    # +1 where the teacher finds the first pair more similar (negative difference),
    # -1 otherwise (ties count as -1). These labels carry no gradient.
    triplet_label = torch.where(get_score_diff(teacher_vectors) < 0, 1, -1)
    student_score_diff = get_score_diff(student_vectors)
    # The labels are created on the teacher's device; move them next to the student values.
    triplet_label = triplet_label.to(student_score_diff.device)
    # Sign-align the student's differences. A value <= -margin means the teacher's
    # ordering is respected with enough margin.
    triplet_score_diff = student_score_diff * triplet_label
    triplet_loss = F.relu(triplet_score_diff + triplet_margin)
    if triplet_loss.numel() == 0:
        # mean() of an empty tensor is NaN; sum() yields 0 and keeps the graph intact.
        return triplet_loss.sum()
    return triplet_loss.mean()
# Triplet loss in its usual distance form:
#   `loss = max(||anchor - positive|| - ||anchor - negative|| + margin, 0)`
# The walkthrough below uses cosine similarities instead of distances. The teacher decides
# which sentence of each triplet plays the "positive" role.

# Sample sentences - the first and third are quite similar
sentences = ["A man is eating a piece of bread.", "A man is riding a horse.", "A man is eating food."]

# Use a GPU when one is present; otherwise fall back to the CPU so the demo still runs.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Two related models (a large teacher and an extra-small student), just for demonstration
teacher_model = SentenceTransformer('mixedbread-ai/mxbai-embed-large-v1', device=device)
student_model = SentenceTransformer('mixedbread-ai/mxbai-embed-xsmall-v1', device=device)
teacher_vectors = teacher_model.encode(sentences)
student_vectors = student_model.encode(sentences)

# L2-normalise the embeddings so that dot products are cosine similarities
teacher_vectors = F.normalize(torch.tensor(teacher_vectors), p=2, dim=1)
student_vectors = F.normalize(torch.tensor(student_vectors), p=2, dim=1)
# Normalised embeddings look like:
# tensor([[ 0.0197, -0.0210, -0.0139, ..., 0.0017, -0.0085, 0.0261],
#         [ 0.0283, 0.0133, 0.0458, ..., 0.0178, 0.0238, 0.0034],
#         [ 0.0183, -0.0171, -0.0099, ..., -0.0072, 0.0146, 0.0038]])
# Teacher cosine-similarity matrix (teacher_vectors @ teacher_vectors.T):
# [[1.0000001  0.35981846 0.819199  ]
#  [0.35981846 0.9999999  0.38630775]
#  [0.819199   0.38630775 1.        ]]

# Differences in cosine similarity between pairs of sentence pairs, from the teacher model.
# With three sentences there are three pairs, and each difference compares two pairs sharing an anchor:
score_diff = get_score_diff(teacher_vectors)
# [cos_sim(sent1, sent3) - cos_sim(sent1, sent2),   - sent1 is anchor
#  cos_sim(sent2, sent3) - cos_sim(sent1, sent2),   - sent2 is anchor
#  cos_sim(sent2, sent3) - cos_sim(sent1, sent3)]   - sent3 is anchor
# tensor([ 0.4594, 0.0265, -0.4329])

# Define triplet margin (same value as get_triplet_loss's default)
triplet_margin = 0.15

# Triplet label: +1 where the teacher finds the first pair of a comparison more similar
# (negative difference), -1 otherwise. Reuses the teacher differences computed above.
triplet_label = torch.where(score_diff < 0, 1, -1)
# tensor([-1, -1, 1])

# The same differences in cosine similarity, from the student model
student_score_diff = get_score_diff(student_vectors)
# tensor([ 0.5816, 0.0959, -0.4856])

# Flip the sign of each student difference where the teacher's difference is non-negative
# (label -1). Keep it as is where the teacher's difference is negative (label +1).
# Afterwards, a negative value means the student orders the two pairs the same way as
# the teacher, and its magnitude shows how clearly it does so.
triplet_loss = student_score_diff * triplet_label
# tensor([-0.5816, -0.0959, -0.4856])

# Add the margin to all scores
triplet_loss = triplet_loss + triplet_margin
# tensor([-0.4316, 0.0541, -0.3356])

# If a value is still negative, the student separates the anchor-positive and anchor-negative
# similarities (in the teacher's order) by more than the margin, so the triplet adds nothing.
# If it is positive, then relative to the margin either the anchor-positive similarity is not
# large enough or the anchor-negative similarity is not small enough.
# Only these triplets contribute to the loss.
triplet_loss = F.relu(triplet_loss)
# tensor([0.0000, 0.0541, 0.0000])
triplet_loss = triplet_loss.mean()
# tensor(0.0180)

# The helper computes the same value in a single call
print(get_triplet_loss(teacher_vectors, student_vectors, triplet_margin=triplet_margin))
# tensor(0.0180)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment