Skip to content

Instantly share code, notes, and snippets.

@davidmezzetti
Last active October 20, 2021 15:56
Show Gist options
  • Save davidmezzetti/63511943cd986c6869a2fa0e7ee6693e to your computer and use it in GitHub Desktop.
Save davidmezzetti/63511943cd986c6869a2fa0e7ee6693e to your computer and use it in GitHub Desktop.
from txtai.pipeline import MLOnnx, Similarity
def tokenize(inputs, **kwargs):
    """
    Wrap raw text in the dictionary structure handed to the exported model.

    Args:
        inputs: a single string or an iterable of strings
        kwargs: accepted so the function can be called with extra keyword
            arguments; none of them are used

    Returns:
        dict with a single "input_ids" key whose value holds one
        single-element list per input, in input order
    """

    # Promote a lone string to a one-item batch so both call styles share a path
    if isinstance(inputs, str):
        inputs = [inputs]

    # Each text becomes its own one-element row
    # NOTE(review): presumably the exported model expects a (batch, 1) string input - confirm
    return {"input_ids": [[x] for x in inputs]}
# Export to ONNX
# NOTE(review): `pipeline` is not defined in the visible part of this file;
# presumably it is a trained scikit-learn pipeline built earlier - confirm
onnx = MLOnnx()
skmodel = onnx(pipeline)

# Load the exported model, paired with the custom tokenize function, into a
# similarity pipeline
# NOTE(review): dynamic=False presumably selects fixed-label classification
# rather than zero-shot - confirm against the txtai Labels documentation
similarity = Similarity((skmodel, tokenize), dynamic=False)

# Add labels to model: map class index -> label name, then derive the inverse
# mapping so both directions stay in sync
similarity.pipeline.model.config.id2label = {0: "sadness", 1: "joy", 2: "love", 3: "anger", 4: "fear", 5: "surprise"}
similarity.pipeline.model.config.label2id = {v: k for k, v in similarity.pipeline.model.config.id2label.items()}

# Run a similarity search for each query
# NOTE(review): `data` is not defined in the visible part of this file;
# it is indexed by the ids returned from similarity() - confirm it is the text list
for query in ("joy", "anger", "surprise"):
    print(query)

    # Print top 3 results
    # NOTE(review): the third positional argument (None) is presumably the
    # multilabel flag requesting raw scores - verify against txtai docs
    for uid, score in similarity(query, data, None)[:3]:
        print(data[uid], score)

    # Blank line separating the output of consecutive queries
    print()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment