import sklearn.metrics  # import the submodule so that sklearn.metrics.pairwise resolves below
import torch
import torch.nn.functional as F
from sentence_transformers import SentenceTransformer
def get_score_diff(vectors):
    # Pairwise similarity scores (the vectors are L2-normalized below,
    # so the dot product equals the cosine similarity)
    scores = torch.matmul(vectors, vectors.T)
    # Keep only the upper triangle, i.e. each sentence pair once
    scores = scores[torch.triu(torch.ones_like(scores), diagonal=1).bool()]
    # Difference between every pair of pair-similarities
    score_diff = scores.reshape((1, -1)) - scores.reshape((-1, 1))
    score_diff = score_diff[torch.triu(torch.ones_like(score_diff), diagonal=1).bool()]
    return score_diff
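
# A quick sanity check (added illustration, not part of the original gist): with n
# inputs there are n*(n-1)/2 similarity pairs, and get_score_diff compares every
# pair of pairs, so 3 vectors -> 3 pair scores -> 3 pairwise differences
_toy = F.normalize(torch.tensor([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]), p=2, dim=1)
print(get_score_diff(_toy).shape)
# torch.Size([3])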
def get_triplet_loss(teacher_vectors, student_vectors, triplet_margin=0.15):
    # The teacher decides the ordering: label 1 where the teacher's score
    # difference is negative, -1 where it is positive
    triplet_label = torch.where(get_score_diff(teacher_vectors) < 0, 1, -1)
    # After multiplying by the label, a student difference that agrees with
    # the teacher's ordering is negative
    triplet_score_diff = get_score_diff(student_vectors) * triplet_label
    # Penalize differences that are not at least `triplet_margin` below zero
    triplet_score_diff_with_margin = triplet_score_diff + triplet_margin
    triplet_loss = F.relu(triplet_score_diff_with_margin)
    triplet_loss = triplet_loss.mean()
    return triplet_loss
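
# Added note: even a student identical to the teacher does not reach zero loss,
# because the margin requires each agreeing score difference to sit at least
# `triplet_margin` below zero. In the toy vectors above, the two off-axis
# similarities are both exactly 1/sqrt(2), so their tied difference contributes
# margin/3 to the mean:
print(get_triplet_loss(_toy, _toy))
# tensor(0.0500)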
# Classic triplet loss: `loss = max(||anchor - positive|| - ||anchor - negative|| + margin, 0)`
# Here the same idea is applied to differences of cosine similarities, with the teacher
# model's ordering deciding which sentence of each pair plays the "positive" role
# Sample sentences - the first and third are quite similar
sentences = ["A man is eating a piece of bread.", "A man is riding a horse.", "A man is eating food."]

# Using a large teacher and a small student from the same model family for demonstration
teacher_model = SentenceTransformer('mixedbread-ai/mxbai-embed-large-v1', device=0)
student_model = SentenceTransformer('mixedbread-ai/mxbai-embed-xsmall-v1', device=0)

teacher_vectors = teacher_model.encode(sentences)
student_vectors = student_model.encode(sentences)

# Normalize the embeddings so that dot products equal cosine similarities
teacher_vectors = F.normalize(torch.tensor(teacher_vectors), p=2, dim=1)
student_vectors = F.normalize(torch.tensor(student_vectors), p=2, dim=1)
print(teacher_vectors)
# tensor([[ 0.0197, -0.0210, -0.0139, ..., 0.0017, -0.0085, 0.0261],
#         [ 0.0283,  0.0133,  0.0458, ..., 0.0178,  0.0238, 0.0034],
#         [ 0.0183, -0.0171, -0.0099, ..., -0.0072, 0.0146, 0.0038]])

print(sklearn.metrics.pairwise.cosine_similarity(teacher_vectors))
# [[1.0000001  0.35981846 0.819199  ]
#  [0.35981846 0.9999999  0.38630775]
#  [0.819199   0.38630775 1.        ]]
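
# Because the vectors are unit-normalized, the plain matmul inside get_score_diff
# reproduces this cosine-similarity matrix (up to floating-point noise):
print(torch.matmul(teacher_vectors, teacher_vectors.T))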
# This is the difference in cosine similarity between pairs of sentence embeddings from the teacher model
score_diff = get_score_diff(teacher_vectors)
print(score_diff)
# [cos_sim(sent1, sent3) - cos_sim(sent1, sent2),  <- sent1 is anchor
#  cos_sim(sent2, sent3) - cos_sim(sent1, sent2),  <- sent2 is anchor
#  cos_sim(sent2, sent3) - cos_sim(sent1, sent3)]  <- sent3 is anchor
# tensor([ 0.4594, 0.0265, -0.4329])
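# Check against the similarity matrix above:
# 0.8192 - 0.3598 = 0.4594, 0.3863 - 0.3598 = 0.0265, 0.3863 - 0.8192 = -0.4329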
# Define the triplet margin
triplet_margin = 0.15

# Define the triplet labels - the reasoning becomes clearer in the next step
triplet_label = torch.where(get_score_diff(teacher_vectors) < 0, 1, -1)
print(triplet_label)
# tensor([-1, -1, 1])
# The first two teacher differences are positive, so their labels are -1 (the student's
# differences there will be sign-flipped); the last is negative, so its label is 1
# This is the difference in cosine similarity between pairs of sentence embeddings from the student model
triplet_loss = get_score_diff(student_vectors)
print(triplet_loss)
# tensor([ 0.5816, 0.0959, -0.4856])
# Multiply by the teacher-derived labels: flip the sign of the student's score
# differences wherever the teacher's difference was positive, and keep the rest as is.
# We don't know a priori which sentence of each pair is the "positive", so the
# teacher's ordering defines it. After this step, a student difference that agrees
# with the teacher's ordering is negative, and a disagreement is positive
# (and will contribute to the loss)
triplet_loss = triplet_loss * triplet_label
print(triplet_loss)
# tensor([-0.5816, -0.0959, -0.4856])
# Add the margin to all scores
triplet_loss += triplet_margin
print(triplet_loss)
# tensor([-0.4316, 0.0541, -0.3356])
# If a score is still negative after adding the margin, the student already orders that
# anchor-positive-negative triplet the way the teacher does, with room to spare.
# If the score is positive, the anchor-positive similarity is not high enough, or the
# anchor-negative similarity is not low enough, so the triplet should contribute to the loss.
triplet_loss = F.relu(triplet_loss)
print(triplet_loss)
# tensor([0.0000, 0.0541, 0.0000])
# Only the second triplet (sent2 as anchor) violates the margin, contributing 0.0541
triplet_loss = triplet_loss.mean()
print(triplet_loss)
# tensor(0.0180)

print(get_triplet_loss(teacher_vectors, student_vectors))
# tensor(0.0180)
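
# --- Added sketch (not part of the original gist) ---
# The walkthrough above computes the loss on detached numpy embeddings from
# .encode(), which cannot be backpropagated. A minimal distillation step with
# gradients could look like the following, assuming the sentence-transformers
# forward API (model.tokenize(...) and model(features)['sentence_embedding']);
# the optimizer, learning rate, and single-batch loop are illustrative only.
optimizer = torch.optim.AdamW(student_model.parameters(), lr=2e-5)

# Forward the student with gradients enabled
features = student_model.tokenize(sentences)
features = {k: v.to(student_model.device) for k, v in features.items()}
student_out = F.normalize(student_model(features)['sentence_embedding'], p=2, dim=1)

# The teacher only provides targets, so no gradients are needed for it
with torch.no_grad():
    teacher_out = F.normalize(
        torch.tensor(teacher_model.encode(sentences)), p=2, dim=1
    ).to(student_out.device)

loss = get_triplet_loss(teacher_out, student_out)
loss.backward()
optimizer.step()
optimizer.zero_grad()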