# Source: GitHub gist davidmezzetti/3e6d622d41abf322bcddbc5c6ae414d9 (last active May 10, 2022).
# Note: GitHub flagged that this file may contain hidden or bidirectional Unicode text that
# could be interpreted differently than it appears; review it in an editor that reveals
# hidden Unicode characters.
import numpy as np | |
import requests | |
def transform(inputs, timeout=60):
    """Embed ``inputs`` using the Hugging Face Inference API feature-extraction pipeline.

    This function is passed to txtai as the ``transform`` of an ``external``
    embeddings method, so it is presumably called with a batch (list) of texts
    and expected to return one vector per text — confirm against the txtai docs.

    Args:
        inputs: Value sent as the ``inputs`` field of the JSON request body.
        timeout: Seconds to wait for the API before giving up. Without a
            timeout, ``requests`` can block indefinitely on a stalled connection.

    Returns:
        A ``float32`` numpy array built from the API's JSON response.

    Raises:
        requests.HTTPError: If the API responds with an error status. Checking
            this up front gives a clear error instead of a confusing failure
            when an error payload is coerced to floats.
    """
    response = requests.post(
        "https://api-inference.huggingface.co/pipeline/feature-extraction/sentence-transformers/nli-mpnet-base-v2",
        json={"inputs": inputs},
        timeout=timeout,
    )
    # An error response body would not be a list of vectors; fail loudly here instead
    response.raise_for_status()
    return np.array(response.json(), dtype=np.float32)
# Build a txtai index whose vectors are produced by transform() (Hugging Face Inference API)
config = {
    "method": "external",  # vectors come from the transform function below, not a local model
    "transform": transform,
    "content": True,  # keep document text so the SQL search can select the `text` column
}
embeddings = Embeddings(config)

# Each document is an (id, text, tags) tuple; ids are the positions in `data`, tags unused
documents = [(uid, text, None) for uid, text in enumerate(data)]
embeddings.index(documents)
# Results table header: left-justified 20-char query column, then the best match
print("Query".ljust(20), "Best Match")
print("-" * 50)

queries = ("feel good story", "climate change", "public health story", "war", "wildlife", "asia", "lucky", "dishonest junk")

# Run a similarity search per query and show the text of the top-ranked result
for query in queries:
    sql = f"select id, text, score from txtai where similar('{query}')"
    best = embeddings.search(sql, 1)[0]
    print(query.ljust(20), best["text"])
# (GitHub page footer: sign in, or sign up for free, to comment on this gist.)