Created
October 14, 2021 18:27
-
-
Save davidmezzetti/dea39e9157eb03cf39b82dd8a3c4d214 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from datasets import load_dataset | |
from sklearn.feature_extraction.text import TfidfVectorizer | |
from sklearn.linear_model import LogisticRegression | |
from sklearn.pipeline import Pipeline | |
from txtai.models import Models | |
from txtai.pipeline import MLOnnx | |
from transformers import pipeline | |
def tokenize(inputs, **kwargs):
    """
    Pass-through "tokenizer" that hands raw text to the model unchanged.

    No real tokenization happens here: each input string is wrapped in its
    own single-element list and returned under the "input_ids" key.

    Args:
        inputs: a single string or an iterable of strings
        **kwargs: accepted so the function can be called with extra
            tokenizer-style keyword arguments; all of them are ignored

    Returns:
        dict with one key, "input_ids", mapping to a list holding one
        single-element list per input
    """

    # Normalize a lone string to a batch of one so both call styles share a path
    if isinstance(inputs, str):
        inputs = [inputs]

    # One row per input, one column per row
    # NOTE(review): presumably the exported model expects (batch, 1) string input - confirm
    return {"input_ids": [[x] for x in inputs]}
# Load the "sst2" configuration of the "glue" dataset
ds = load_dataset("glue", "sst2")

# Train the model
# TF-IDF features feed a logistic regression classifier; max_iter is set
# explicitly to 250 for the solver.
model = Pipeline([
    ('tfidf', TfidfVectorizer()),
    ('lr', LogisticRegression(max_iter=250))
])

model.fit(ds["train"]["sentence"], ds["train"]["label"])

# Convert model to ONNX
# `model` is rebound here: from this point on it holds the MLOnnx output,
# not the scikit-learn pipeline.
onnx = MLOnnx()
model = onnx(model)

# Run HF pipeline
# The custom `tokenize` function is passed as the tokenizer so raw text reaches
# the model as-is. function_to_apply="none" is forwarded to the pipeline call
# NOTE(review): expected to return the model scores without extra post-processing - confirm
nlp = pipeline("text-classification", model=Models.load(model), tokenizer=tokenize)
print(nlp(["That is great", "That is terrible"], function_to_apply="none"))

# Sample output recorded by the author:
# [{'label': 'LABEL_1', 'score': 0.900902509689331}, {'label': 'LABEL_0', 'score': 0.9138815402984619}]
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment