Created
June 12, 2021 17:00
-
-
Save talhaanwarch/f7cb08037ff59bdcd85df57c61ee94b4 to your computer and use it in GitHub Desktop.
CALCULATE SENTENCE SIMILARITY using Pretrained BERT model
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# -*- coding: utf-8 -*- | |
""" | |
Created on Fri Jun 11 18:58:05 2021 | |
# CALCULATE SENTENCE SIMILARITY | |
@author: TAC | |
""" | |
import torch#pytorch | |
from transformers import AutoTokenizer, AutoModel#for embeddings | |
from sklearn.metrics.pairwise import cosine_similarity#for similarity | |
# Download (or load from the local HuggingFace cache) the pretrained
# BERT tokenizer and encoder. output_hidden_states=True makes the model
# return the activations of every layer, which the embedding code below
# relies on (it reads hidden_states[-1]).
#download pretrained model
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased",)
model = AutoModel.from_pretrained("bert-base-uncased",output_hidden_states=True)
#create embeddings
def get_embeddings(text, token_length):
    """Return a mean-pooled sentence embedding for ``text``.

    Tokenizes ``text`` to a fixed length, runs it through the BERT
    encoder, and mean-pools the last hidden layer over the token axis.

    Parameters
    ----------
    text : str
        Sentence to embed.
    token_length : int
        Fixed length to pad/truncate the tokenized input to.

    Returns
    -------
    numpy.ndarray
        Array of shape (1, hidden_size) — kept 2-D so it can be fed
        straight into sklearn's cosine_similarity.
    """
    # return_tensors='pt' yields batched PyTorch tensors directly,
    # replacing the manual torch.tensor(...).unsqueeze(0) dance.
    tokens = tokenizer(text, max_length=token_length, padding='max_length',
                       truncation=True, return_tensors='pt')
    # Inference only: no_grad skips building the autograd graph, saving
    # memory and time (and makes .detach() unnecessary).
    with torch.no_grad():
        last_hidden = model(tokens.input_ids,
                            attention_mask=tokens.attention_mask).hidden_states[-1]
    # Mean-pool across tokens (dim=1, the sequence axis) -> one vector per sentence.
    return torch.mean(last_hidden, dim=1).numpy()
#calculate similarity
def calculate_similarity(text1, text2, token_length=20, text3=None):
    """Report which of two candidate sentences is closer to an input sentence.

    Embeds ``text1``, ``text2`` and the input sentence with
    get_embeddings(), compares each candidate to the input with cosine
    similarity, prints both scores and a verdict.

    Parameters
    ----------
    text1, text2 : str
        Candidate sentences to compare against the input sentence.
    token_length : int, optional
        Tokenizer padding/truncation length (default 20).
    text3 : str, optional
        The input sentence. If None (the default, preserving the old
        behavior), the user is prompted for it interactively.

    Returns
    -------
    tuple[float, float]
        (similarity of text1 vs input, similarity of text2 vs input).
    """
    if text3 is None:
        # Backward-compatible interactive fallback when no sentence is passed.
        text3 = input('input your sentence \n')
    out1 = get_embeddings(text1, token_length=token_length)  # embed candidate 1
    out2 = get_embeddings(text2, token_length=token_length)  # embed candidate 2
    out3 = get_embeddings(text3, token_length=token_length)  # embed input sentence
    # cosine_similarity returns a (1, 1) matrix; [0][0] extracts the scalar.
    sim1 = cosine_similarity(out1, out3)[0][0]
    sim2 = cosine_similarity(out2, out3)[0][0]
    print(sim1, sim2)
    if sim1 > sim2:
        print('sentence 1 is more similar to input sentence')
    else:
        print('sentence 2 is more similar to input sentence')
    # Return the raw scores so callers can use them programmatically
    # (callers that ignored the old None return are unaffected).
    return sim1, sim2
# Demo driver: compare two example sentences against a user-supplied one.
# Guarded so importing this module for get_embeddings/calculate_similarity
# does not trigger the interactive prompt.
if __name__ == '__main__':
    text1 = 'Before viewing the output, let understand the parameters the tokenizer takes'
    text2 = 'if the token length is smaller than the token in a sentence then remove some of the tokens to make them equal in length'
    calculate_similarity(text1, text2)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment