@riveSunder
Last active July 12, 2023 18:09
import argparse

# idioms and their meanings are from https://en.wikipedia.org/wiki/English-language_idioms
# you'll need Hugging Face transformers, pytorch, numpy, and matplotlib for this demo
import numpy as np
import torch
from transformers import pipeline, AutoTokenizer, AutoModel

import matplotlib.pyplot as plt
def cls_pooling(model_output):
    # use the hidden state of the first ([CLS]) token as the sentence vector
    cls_vectors = model_output.last_hidden_state[:, 0, :]
    return cls_vectors

def mean_pooling(model_output):
    # average the hidden states over all token positions
    mean_vectors = model_output.last_hidden_state.mean(1)
    return mean_vectors
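
# note: model_output.last_hidden_state has shape (batch, sequence_length, hidden_size),
# and mean_pooling above averages over *every* position, including the [PAD]
# tokens introduced by padding="max_length". A mask-aware variant is sketched
# below as an optional alternative; it is hypothetical and not used by this demo.
def masked_mean_pooling(model_output, attention_mask):
    # zero out the padding positions, then average over the real tokens only
    mask = attention_mask.unsqueeze(-1).float()
    summed = (model_output.last_hidden_state * mask).sum(1)
    counts = mask.sum(1).clamp(min=1e-9)
    return summed / counts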
def l2_distance(query, embeddings):
    # query is a 1xk vector,
    # embeddings is n vectors in an nxk matrix
    distances = [((query - e)**2).sum().sqrt()
            for e in embeddings]
    return distances
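
# the comprehension above computes ||query - e||_2 = sqrt(sum_k (query_k - e_k)^2)
# for each row e; if you prefer a single vectorized call, torch.cdist should give
# the same result, e.g. torch.cdist(query.unsqueeze(0), embeddings).squeeze(0)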
def cosine_similarity(query, embeddings):
    # cos(q, e) = (q . e) / (||q|| * ||e||)
    q = query
    similarities = [(q @ e.t())
            / (q @ q.t() * e @ e.t()).sqrt()
            for e in embeddings]
    return similarities
def cosine_distance(query, embeddings):
    similarities = cosine_similarity(query, embeddings)
    cosine_distances = [1.0 - (s / 2. + 0.5)
            for s in similarities]
    return cosine_distances
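
# cosine similarity lies in [-1, 1]; (s/2 + 0.5) rescales it to [0, 1], so the
# distance above is 0.0 for vectors pointing the same way and 1.0 for opposite
# vectors. (This helper is defined for completeness; main() below only uses
# cosine_similarity and l2_distance.)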
def main(args):

    with torch.no_grad():
        my_model = "sentence-transformers/multi-qa-mpnet-base-dot-v1"

        # quick smoke test of the feature-extraction pipeline (the result is unused)
        extraction = pipeline("feature-extraction", model=my_model)
        embedding = extraction("This is a test")

        pad_length = 256

        if args.simple_demo:
            idioms = ["a bitter pill to swallow",
                    "a dime a dozen",
                    "a hot potato",
                    "Achilles' heel",
                    "dollars to donuts",
                    "at the drop of a hat"]
        else:
            idioms = [
                    "a bitter pill to swallow",
                    "a dime a dozen",
                    "a hot potato",
                    "a sandwich short of a picnic",
                    "ace in the hole",
                    "Achilles' heel",
                    "all ears",
                    "all thumbs",
                    "an arm and a leg",
                    "apple of discord",
                    "around the clock",
                    "as queer as a [strange object]",
                    "at the drop of a hat",
                    "back to the drawing board",
                    "back to the grindstone",
                    "ball is in his/her/your court",
                    "balls to the wall",
                    "barking up the wrong tree",
                    "basket case",
                    "beating a dead horse",
                    "beat around the bush",
                    "bed of roses",
                    "the bee's knees",
                    "best of both worlds",
                    "bird brain",
                    "bite off more than one can chew",
                    "bite the bullet",
                    ]
        if args.all_compare:
            query = [
                    "A situation or information that is unpleasant but must be accepted",
                    "Anything that is common, inexpensive, and easy to get or available anywhere",
                    "A controversial issue or situation that is awkward or unpleasant to deal with",
                    "Lacking intelligence",
                    "A hidden or secret strength; an unrevealed advantage",
                    "A small but fatal weakness in spite of overall strength",
                    "Listening intently; fully focused or awaiting an explanation",
                    "Clumsy, awkward",
                    "Very expensive or costly; a large amount of money",
                    "Anything causing trouble, discord, or jealousy",
                    "When something is done all day and all night without stopping",
                    "Something particularly strange or unusual",
                    "Without any hesitation; instantly",
                    "Revising something (such as a plan) from the beginning, typically after it has failed",
                    "To return to a hard and/or tedious task",
                    "It is up to him/her/you to make the next decision or step",
                    "Full throttle; at maximum speed",
                    "Looking in the wrong place",
                    "One made powerless or ineffective, as by nerves, panic, or stress",
                    "To uselessly dwell on a subject far beyond its point of resolution",
                    "To treat a topic but omit its main points, often intentionally or to delay or avoid talking about something difficult or unpleasant",
                    "A situation or activity that is comfortable or easy",
                    "Something or someone outstandingly good, excellent, or wonderful",
                    "A combination of two seemingly contradictory benefits",
                    "A person who is not too smart; a person who acts stupid",
                    "To take on more responsibility than one can manage",
                    "To endure a painful or unpleasant situation that is unavoidable",
                    ]
        else:
            query = args.query

        # args.query with nargs="+" already yields a list; this guards the
        # case where a single string is passed in programmatically
        if not isinstance(query, list):
            query = [query]
        tokenizer = AutoTokenizer.from_pretrained(my_model)
        model = AutoModel.from_pretrained(my_model)

        idiom_tokens = tokenizer(idioms, padding="max_length",
                max_length=pad_length,
                return_tensors="pt")
        def_tokens = tokenizer(query, padding="max_length",
                max_length=pad_length,
                return_tensors="pt")

        # the tokenizer output is already a dict of tensors
        # (input_ids, attention_mask), so it can be unpacked directly
        output_idioms = model(**idiom_tokens)
        output_def = model(**def_tokens)

        # honor the --cls_pooling flag; mean pooling is the default
        pooling = cls_pooling if args.cls_pooling else mean_pooling
        idiom_embeddings = pooling(output_idioms)
        query_embeddings = pooling(output_def)
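
        # shapes: idiom_embeddings is (len(idioms), hidden_size) and
        # query_embeddings is (len(query), hidden_size); hidden_size is
        # 768 for the mpnet-base checkpoint used here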
        cosine_matrix = np.zeros((len(query), len(idioms)))
        distance_matrix = np.zeros((len(query), len(idioms)))

        for idx, query_embedding in enumerate(query_embeddings):
            cosine_similarities = cosine_similarity(query_embedding, idiom_embeddings)
            distances = l2_distance(query_embedding, idiom_embeddings)

            # numpy converts the lists of 0-d tensors elementwise
            cosine_matrix[idx, :] = cosine_similarities
            distance_matrix[idx, :] = distances

            if args.verbose == 2:
                print(f"\nquery: {query[idx]}")
                print("\n\t l2 distance \t cosine similarity")

                for d, s, phrase in zip(distances,
                        cosine_similarities, idioms):
                    print(f"  {phrase}:\n\t\t {d:.3f} \t\t {s.item():.3f}")

                print("")
            if args.verbose:
                k = args.top_k

                indices_cosine = list(np.argsort(cosine_similarities))
                indices_l2 = list(np.argsort(distances))

                # higher is better for similarity
                indices_cosine.reverse()
                # lower is better for distance, so indices_l2 stays ascending

                top_k_cosine = np.array(idioms)[indices_cosine]
                top_k_l2 = np.array(idioms)[indices_l2]

                top_cosine = np.array(cosine_similarities)[indices_cosine]
                top_l2 = np.array(distances)[indices_l2]

                print(f"\nquery: {query[idx]} \n top-{k} by cosine similarity")
                for phrase, s in zip(top_k_cosine[:k], top_cosine[:k]):
                    print(f"  {phrase}\n\t\t {s:.3f}")

                print(f"\nquery: {query[idx]} \n top-{k} by l2 distance")
                for phrase, d in zip(top_k_l2[:k], top_l2[:k]):
                    print(f"  {phrase}\n\t\t {d:.3f}")
        if args.display_figure:
            plt.figure(figsize=(12, 6))

            plt.subplot(121)
            plt.imshow(cosine_matrix, cmap="inferno"); plt.colorbar()
            plt.title("Cosine similarity\n(brighter is better)", fontsize=32)
            plt.xlabel("idioms", fontsize=28)
            plt.ylabel("meanings", fontsize=28)

            plt.subplot(122)
            plt.imshow(distance_matrix, cmap="gray"); plt.colorbar()
            plt.title("Euclidean distance\n(darker is better)", fontsize=32)
            plt.xlabel("idioms", fontsize=28)
            plt.ylabel("meanings", fontsize=28)

            plt.tight_layout()

            if args.display_figure == 2:
                # save the figure before showing it
                plt.savefig("idioms_plot.png")

            plt.show()
if __name__ == "__main__":

    parser = argparse.ArgumentParser()

    parser.add_argument("-a", "--all_compare",
            action="store_true",
            help="compare all meanings against all idioms; if not set, the user-supplied query (or queries) is used")
    parser.add_argument("-c", "--cls_pooling",
            action="store_true",
            help="use CLS pooling; if this flag is not set, mean pooling is used")
    parser.add_argument("-d", "--display_figure", type=int,
            default=0,
            help="level of figure display: 0 - no plot, 1 - show plot, 2 - save plot ('idioms_plot.png') and show")
    parser.add_argument("-k", "--top_k", type=int,
            default=3,
            help="k for printing out the top-k matches")
    parser.add_argument("-q", "--query", type=str, nargs="+",
            default=["fatal weakness, especially in the absence of other vulnerabilities"],
            help="query phrase (or space-separated phrases)")
    parser.add_argument("-s", "--simple_demo",
            action="store_true",
            help="run a simple vector search over just 6 idioms")
    parser.add_argument("-v", "--verbose", type=int,
            default=1,
            help="0 - quiet, 1 - prints out top scores, 2 - prints out all scores")

    args = parser.parse_args()
    print(args)

    main(args)
@riveSunder (Author) commented:
python demo.py -a -d 2
[figure: idioms_plot.png, cosine similarity and Euclidean distance heatmaps over all idiom/meaning pairs]

python demo.py -q 'a fatal weakness surrounded by strength and invulnerability' -k 5 -v 1
output:

query: a fatal weakness surrounded by strength and invulnerability 
 top-5 by cosine similarity
  Achilles' heel
                 0.696
  ace in the hole
                 0.675
  beating a dead horse
                 0.661
  the bee's knees
                 0.640
  a bitter pill to swallow
                 0.640

query: a fatal weakness surrounded by strength and invulnerability 
 top-5 by l2 distance
  Achilles' heel
                 4.993
  ace in the hole
                 5.150
  beating a dead horse
                 5.213
  bite the bullet
                 5.375
  at the drop of a hat
                 5.375
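
A few other invocations to try, using the flags defined in the argparse block above (these are examples to run yourself, not additional recorded output):

python demo.py -s -v 2
python demo.py -q 'very expensive' 'listening intently' -k 3 -c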
