Last active
July 12, 2023 18:09
-
-
Save riveSunder/5e85dfd850792c29fa1272f256067993 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import argparse | |
| # idioms and their meanings are from https://en.wikipedia.org/wiki/English-language_idioms | |
| # you'll need hugging face transformers, numpy, pytorch, and matplotlib for this demo | |
| import numpy as np | |
| from transformers import pipeline, AutoTokenizer, AutoModel | |
| import matplotlib.pyplot as plt | |
def cls_pooling(model_output):
    """Return the hidden state at sequence position 0 for each batch item.

    Args:
        model_output: object exposing a ``last_hidden_state`` attribute that
            can be indexed along three axes.

    Returns:
        The slice ``last_hidden_state[:, 0, :]``.
    """
    hidden_states = model_output.last_hidden_state
    return hidden_states[:, 0, :]
def mean_pooling(model_output, attention_mask=None):
    """Average the hidden states over the sequence axis (axis 1).

    Args:
        model_output: object exposing a ``last_hidden_state`` attribute.
        attention_mask: optional tensor with one entry per
            (batch item, sequence position); non-zero marks a real token and
            zero marks padding. When given, padded positions are excluded
            from the average. When omitted, every position is averaged,
            padding included.

    Returns:
        One pooled vector per batch item.
    """
    hidden_states = model_output.last_hidden_state
    if attention_mask is None:
        return hidden_states.mean(1)

    # Broadcast the mask over the feature axis so padded rows contribute 0.
    weights = attention_mask.unsqueeze(-1).to(hidden_states.dtype)
    token_sum = (hidden_states * weights).sum(1)
    # Clamp so a fully masked row yields zeros instead of dividing by zero.
    token_count = weights.sum(1).clamp(min=1e-9)
    return token_sum / token_count
def l2_distance(query, embeddings):
    """Compute the Euclidean distance from ``query`` to each embedding.

    Args:
        query: a single vector (the original note describes it as 1xk).
        embeddings: n vectors stacked as an nxk matrix; iterated row by row.

    Returns:
        A list with one distance per row of ``embeddings``, in row order.
    """
    distances = []
    for row in embeddings:
        difference = query - row
        distances.append(difference.pow(2).sum().sqrt())
    return distances
def cosine_similarity(query, embeddings):
    """Compute the cosine similarity between ``query`` and each embedding.

    Args:
        query: a single vector, either 1-D (k,) or a 1xk row.
        embeddings: n vectors stacked as an nxk matrix; iterated row by row,
            so each row is a 1-D vector of length k.

    Returns:
        A list with one similarity per row of ``embeddings``, in row order.
        A zero-length query or row produces ``nan`` (division by zero).
    """
    # The squared norm of the query does not depend on the row, so compute
    # it once. Summing elementwise products gives a scalar for any query
    # shape, which keeps the grouping of the denominator explicit instead of
    # relying on left-to-right evaluation of `@` and `*`.
    query_sq_norm = (query * query).sum()

    similarities = []
    for row in embeddings:
        row_sq_norm = (row * row).sum()
        denominator = (query_sq_norm * row_sq_norm).sqrt()
        similarities.append((query @ row) / denominator)
    return similarities
def cosine_distance(query, embeddings):
    """Turn cosine similarities into distances.

    Each similarity ``s`` returned by ``cosine_similarity`` is rescaled with
    ``s / 2 + 0.5`` and subtracted from 1, so a similarity of 1 maps to 0
    and a similarity of -1 maps to 1.

    Returns:
        A list with one distance per row of ``embeddings``, in row order.
    """
    cosine_distances = []
    for similarity in cosine_similarity(query, embeddings):
        rescaled = similarity / 2. + 0.5
        cosine_distances.append(1.0 - rescaled)
    return cosine_distances
_SIMPLE_IDIOMS = [
    "a bitter pill to swallow",
    "a dime a dozen",
    "a hot potato",
    "Achilles' heel",
    "dollars to donuts",
    "at the drop of a hat",
]

_ALL_IDIOMS = [
    "a bitter pill to swallow",
    "a dime a dozen",
    "a hot potato",
    "a sandwich short of a picnic",
    "ace in the hole",
    "Achilles' heel",
    "all ears",
    "all thumbs",
    "an arm and a leg",
    "apple of discord",
    "around the clock",
    "as queer as a [strange object]",
    "at the drop of a hat",
    "back to the drawing board",
    "back to the grindstone",
    "ball is in his/her/your court",
    "balls to the wall",
    "barking up the wrong tree",
    "basket case",
    "beating a dead horse",
    "beat around the bush",
    "bed of roses",
    "the bee's knees",
    "best of both worlds",
    "bird brain",
    "bite off more than one can chew",
    "bite the bullet",
]

# One meaning per entry of _ALL_IDIOMS, in the same order.
_ALL_MEANINGS = [
    "A situation or information that is unpleasant but must be accepted",
    "Anything that is common, inexpensive, and easy to get or available anywhere",
    "A controversial issue or situation that is awkward or unpleasant to deal with",
    "Lacking intelligence",
    "A hidden or secret strength; an unrevealed advantage",
    "A small but fatal weakness in spite of overall strength",
    "Listening intently; fully focused or awaiting an explanation",
    "Clumsy, awkward",
    "Very expensive or costly; a large amount of money",
    "Anything causing trouble, discord, or jealousy",
    "When something is done all day and all night without stopping",
    "Something particularly strange or unusual",
    "Without any hesitation; instantly",
    "Revising something (such as a plan) from the beginning, typically after it has failed",
    "To return to a hard and/or tedious task",
    "It is up to him/her/you to make the next decision or step",
    "Full throttle; at maximum speed",
    "Looking in the wrong place",
    "One made powerless or ineffective, as by nerves, panic, or stress",
    "To uselessly dwell on a subject far beyond its point of resolution",
    "To treat a topic but omit its main points, often intentionally or to delay or avoid talking about something difficult or unpleasant",
    "A situation or activity that is comfortable or easy",
    "Something or someone outstandingly good, excellent, or wonderful",
    "A combination of two seemingly contradictory benefits",
    "A person who is not too smart; a person who acts stupid",
    "To take on more responsibility than one can manage",
    "To endure a painful or unpleasant situation that is unavoidable",
]


def _print_all_scores(query_text, idioms, distances, cosine_similarities):
    """Print the l2 distance and cosine similarity of every idiom."""
    print(f"\nquery: {query_text}")
    print("\n\t l2 distance \t cosine similarity")
    for d, s, phrase in zip(distances, cosine_similarities, idioms):
        print(f" {phrase}:\n\t\t {d:.3f} \t\t {s:.3f}")
    print("")


def _print_top_k(query_text, idioms, distances, cosine_similarities, k):
    """Print the k best idioms by cosine similarity and by l2 distance."""
    # Higher is better for similarity, so reverse the ascending argsort.
    by_cosine = np.argsort(cosine_similarities)[::-1]
    # Lower is better for distance, so the ascending order is already right.
    by_l2 = np.argsort(distances)

    print(f"\nquery: {query_text} \n top-{k} by cosine similarity")
    for i in by_cosine[:k]:
        print(f" {idioms[i]}\n\t\t {cosine_similarities[i]:.3f}")

    print(f"\nquery: {query_text} \n top-{k} by l2 distance similarity")
    for i in by_l2[:k]:
        print(f" {idioms[i]}\n\t\t {distances[i]:.3f}")


def _plot_matrices(cosine_matrix, distance_matrix, save):
    """Show both score matrices side by side; optionally save the figure."""
    plt.figure(figsize=(12, 6))

    plt.subplot(121)
    plt.imshow(cosine_matrix, cmap="inferno")
    plt.colorbar()
    plt.title("Cosine similarity\n(brighter is better)", fontsize=32)
    plt.xlabel("idioms", fontsize=28)
    plt.ylabel("meanings", fontsize=28)

    plt.subplot(122)
    plt.imshow(distance_matrix, cmap="gray")
    plt.colorbar()
    plt.title("Euclidean distance\n(darker is better)", fontsize=32)
    plt.xlabel("idioms", fontsize=28)
    plt.ylabel("meanings", fontsize=28)

    plt.tight_layout()
    if save:
        plt.savefig("idioms_plot.png")
    plt.show()


def main(args):
    """Embed idioms and queries, then report how well they match.

    Reads these attributes from ``args``:
        simple_demo: use the 6-idiom list instead of the full list.
        all_compare: use the built-in meanings as queries instead of
            ``args.query``.
        query: a query string or list of query strings.
        cls_pooling: pool with ``cls_pooling`` instead of ``mean_pooling``
            (treated as False when the attribute is absent).
        verbose: 0 prints nothing, any non-zero value prints the top-k
            matches, and 2 additionally prints every score.
        top_k: how many matches to print per query.
        display_figure: 0 no plot, non-zero shows the plot, 2 also saves it
            to 'idioms_plot.png'.
    """
    # Imported here because torch.no_grad is needed below and the file's
    # import block does not bring torch into scope.
    import torch

    my_model = "sentence-transformers/multi-qa-mpnet-base-dot-v1"
    pad_length = 256

    idioms = list(_SIMPLE_IDIOMS) if args.simple_demo else list(_ALL_IDIOMS)

    query = list(_ALL_MEANINGS) if args.all_compare else args.query
    if not isinstance(query, list):
        query = [query]

    tokenizer = AutoTokenizer.from_pretrained(my_model)
    model = AutoModel.from_pretrained(my_model)

    # Honour the -c/--cls_pooling switch; mean pooling is the default.
    pool = cls_pooling if getattr(args, "cls_pooling", False) else mean_pooling

    # Every input is padded to pad_length tokens; truncation keeps inputs
    # longer than pad_length from producing rows of unequal length.
    idiom_tokens = tokenizer(idioms, padding="max_length",
                             max_length=pad_length, truncation=True,
                             return_tensors="pt")
    query_tokens = tokenizer(query, padding="max_length",
                             max_length=pad_length, truncation=True,
                             return_tensors="pt")

    # Inference only: no gradients are needed for the embeddings.
    with torch.no_grad():
        idiom_embeddings = pool(model(**idiom_tokens))
        query_embeddings = pool(model(**query_tokens))

    cosine_matrix = np.zeros((len(query), len(idioms)))
    distance_matrix = np.zeros((len(query), len(idioms)))

    for idx, query_embedding in enumerate(query_embeddings):
        # Convert the per-idiom tensors to plain floats so they can be
        # stored in the numpy matrices, sorted and formatted directly.
        cosine_similarities = [
            float(s)
            for s in cosine_similarity(query_embedding, idiom_embeddings)
        ]
        distances = [
            float(d) for d in l2_distance(query_embedding, idiom_embeddings)
        ]

        cosine_matrix[idx, :] = cosine_similarities
        distance_matrix[idx, :] = distances

        if args.verbose == 2:
            _print_all_scores(query[idx], idioms, distances,
                              cosine_similarities)
        if args.verbose:
            _print_top_k(query[idx], idioms, distances, cosine_similarities,
                         args.top_k)

    if args.display_figure:
        _plot_matrices(cosine_matrix, distance_matrix,
                       save=(args.display_figure == 2))
if __name__ == "__main__":
    # Command-line entry point: declare the options, parse them, echo the
    # parsed namespace, and hand it to main().
    parser = argparse.ArgumentParser()

    # (flags, keyword arguments) for each option, in registration order.
    option_specs = [
        (("-a", "--all_compare"), {
            "action": "store_true",
            "help": "compares all meanings against all idioms. if not set, user supplied quer(y/ies) (is/are) used. ",
        }),
        (("-c", "--cls_pooling"), {
            "action": "store_true",
            "help": "use CLS pooling. if this flag is not set, mean pooling is used.",
        }),
        (("-d", "--display_figure"), {
            "type": int,
            "default": 0,
            "help": "level of figure display: 0 - no plot, 1- show plot, 2- save plot ('idioms_plot.png') and show",
        }),
        (("-k", "--top_k"), {
            "type": int,
            "default": 3,
            "help": "k for printing out top k matches",
        }),
        (("-q", "--query"), {
            "type": str,
            "nargs": "+",
            "default": ["fatal weakness, especially in the absence of other vulnerabilities"],
            "help": "query phrase (or space-separated phrases)",
        }),
        (("-s", "--simple_demo"), {
            "action": "store_true",
            "help": "use this flag for a simple vector search of 6 idioms",
        }),
        (("-v", "--verbose"), {
            "type": int,
            "default": 1,
            "help": "0 - quiet, 1- prints out top scores 2- prints out all scores",
        }),
    ]
    for flags, options in option_specs:
        parser.add_argument(*flags, **options)

    args = parser.parse_args()
    print(args)
    main(args)
Author
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
python demo.py -a -d 2
python demo.py -q 'a fatal weakness surrounded by strength and invulnerability' -k 5 -v 1
output: