semantic sim
import torch
from transformers import BertTokenizer, BertModel

# Load a pretrained BERT tokenizer and encoder.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained('bert-base-uncased')

def encode(text):
    # Tokenize the text and wrap it with BERT's special tokens.
    tokens = tokenizer.tokenize(text)
    tokens = ['[CLS]'] + tokens + ['[SEP]']
    return tokenizer.convert_tokens_to_ids(tokens)

ground_truth = """
I am playing cricket because this is a passion for me.
"""

response_text = """
My passion is the game which called cricket.
"""

ground_truth_ids = encode(ground_truth)
response_ids = encode(response_text)

def extract_embeddings(token_ids):
    # Run the token IDs through BERT and mean-pool the last hidden states
    # into a single sentence vector.
    tokens_tensor = torch.tensor([token_ids])
    with torch.no_grad():
        outputs = model(tokens_tensor)
    hidden_state = outputs[0]          # last hidden states: shape (1, seq_len, 768)
    return hidden_state[0].mean(dim=0)

ground_truth_emb = extract_embeddings(ground_truth_ids)
response_emb = extract_embeddings(response_ids)

# Cosine similarity between the two sentence vectors (closer to 1 = more similar).
cos = torch.nn.CosineSimilarity(dim=0)
similarity = cos(ground_truth_emb, response_emb).item()
print(similarity)
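Note that the mean pooling above also averages the [CLS] and [SEP] embeddings, and the input tensors are built by hand. A slightly more idiomatic variant is sketched below on the same two sentences: it lets the tokenizer build the input IDs and attention mask in one call and pools only over real tokens. This is an alternative sketch, not part of the original gist; the helper name sentence_embedding is hypothetical.

import torch
from transformers import BertTokenizer, BertModel

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained('bert-base-uncased')

def sentence_embedding(text):
    # Tokenize, add special tokens, and return PyTorch tensors in one call.
    inputs = tokenizer(text, return_tensors='pt', truncation=True)
    with torch.no_grad():
        outputs = model(**inputs)
    hidden = outputs.last_hidden_state              # (1, seq_len, 768)
    mask = inputs['attention_mask'].unsqueeze(-1)   # (1, seq_len, 1)
    # Attention-mask-weighted mean over tokens -> one vector per sentence.
    return (hidden * mask).sum(dim=1) / mask.sum(dim=1)

emb_a = sentence_embedding("I am playing cricket because this is a passion for me.")
emb_b = sentence_embedding("My passion is the game which called cricket.")
similarity = torch.nn.functional.cosine_similarity(emb_a, emb_b).item()
print(similarity)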
Requirements: torch, transformers