@Steboss89
Created June 2, 2022 21:36
Measure text similarity with RoBERTa
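!pip install sentence_transformers
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from sentence_transformers import SentenceTransformer, util

# load the RoBERTa model fine-tuned on the STS benchmark
model = SentenceTransformer('stsb-roberta-large')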
def create_heatmap(similarity, cmap="YlGnBu"):
    """Plot the pairwise similarity matrix as a labelled heatmap."""
    # row/column order must match the order of the texts in `data`:
    # john 0, luke 1, mark 2, matt 3
    labels = ['john', 'luke', 'mark', 'matt']
    df = pd.DataFrame(similarity, index=labels, columns=labels)
    fig, ax = plt.subplots(figsize=(5, 5))
    sns.heatmap(df, cmap=cmap, ax=ax)
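# encode the input texts; `data` is defined elsewhere (not shown here) and is
# expected to hold the four texts in the same order as the heatmap labels
embeddings = model.encode(data, convert_to_tensor=True)

# cosine similarity between every pair of embeddings
similarity = []
for i in range(len(data)):
    row = []
    for j in range(len(data)):
        row.append(util.pytorch_cos_sim(embeddings[i], embeddings[j]).item())
    similarity.append(row)

create_heatmap(similarity)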
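
The nested loop above computes one cosine similarity per pair of embeddings. As a minimal sketch of the same computation in plain NumPy, the whole matrix can be obtained by normalizing each row and taking a single matrix product. The function name `cosine_similarity_matrix` and the toy vectors are hypothetical and not part of the original gist; they only stand in for real sentence embeddings.

```python
import numpy as np

def cosine_similarity_matrix(vectors):
    """Return the matrix of pairwise cosine similarities between rows."""
    x = np.asarray(vectors, dtype=float)
    # length of each row vector, kept as a column so it broadcasts row-wise
    norms = np.linalg.norm(x, axis=1, keepdims=True)
    unit = x / norms  # assumes no row is all zeros
    # dot products of unit vectors are cosine similarities
    return unit @ unit.T

# toy stand-ins for sentence embeddings (illustrative values only)
toy_embeddings = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
sim = cosine_similarity_matrix(toy_embeddings)
```

With the real model, `util.cos_sim(embeddings, embeddings)` from sentence_transformers should return the same kind of matrix as a tensor in a single call, which can then be converted with `.cpu().numpy()` before plotting.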