Last active: October 20, 2021 15:56
-
-
Save davidmezzetti/63511943cd986c6869a2fa0e7ee6693e to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from txtai.pipeline import MLOnnx, Similarity
def tokenize(inputs, **kwargs):
    """Pass-through "tokenizer" used alongside the ONNX-exported model.

    No tokenization is performed: each text is wrapped, unchanged, in its
    own single-element list under the ``"input_ids"`` key.

    Args:
        inputs: a single string or an iterable of strings.
        **kwargs: accepted so the function can be called like a tokenizer;
            all keyword arguments are ignored.

    Returns:
        dict with key ``"input_ids"`` mapping to a list containing one
        ``[text]`` list per input text, in input order.
    """
    # Treat a lone string as a batch of one, not as an iterable of characters.
    if isinstance(inputs, str):
        inputs = [inputs]

    return {"input_ids": [[text] for text in inputs]}
# Export to ONNX
# NOTE(review): `pipeline` (and `data`, used below) are not defined in this
# snippet; they are presumably created earlier (a trained model and a list of
# texts) — confirm before running.
onnx = MLOnnx()
skmodel = onnx(pipeline)

# Load the ONNX model into a similarity pipeline, paired with the pass-through
# `tokenize` function so raw text reaches the model unchanged.
# NOTE(review): with dynamic=False the query strings appear to be resolved via
# the model's label2id mapping (set below) — confirm against the txtai docs.
similarity = Similarity((skmodel, tokenize), dynamic=False)

# Attach label names to the model config; label2id is the inverse of id2label.
config = similarity.pipeline.model.config
config.id2label = {0: "sadness", 1: "joy", 2: "love", 3: "anger", 4: "fear", 5: "surprise"}
config.label2id = {label: uid for uid, label in config.id2label.items()}

# Score every text in `data` against each query label and show the best matches.
for query in ("joy", "anger", "surprise"):
    print(query)

    # Print top 3 results as (text, score). The third positional argument is
    # txtai's `multilabel` option; None presumably returns raw, un-normalized
    # scores — verify for the installed txtai version.
    for uid, score in similarity(query, data, None)[:3]:
        print(data[uid], score)

    # Blank line between query groups.
    print()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.