Created: March 3, 2020 00:28
-
-
Save julien-c/5985591f7c9097b94f84dcf441a93117 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Run BART fine-tuned on MNLI over two example sentence pairs.

Each example is a two-sentence list passed to a text-classification
pipeline; the expected outputs recorded below are an NLI label
(``contradiction`` / ``entailment``) with a score.
"""
import logging

# Top-level exports are used instead of the ``transformers.modeling_bart`` /
# ``transformers.tokenization_bart`` / ``transformers.pipelines`` submodule
# paths, which were reorganized in transformers 4.x.
from transformers import (
    BartForSequenceClassification,
    BartTokenizer,
    TextClassificationPipeline,
)

# NOTE(review): "bart-large-mnli" is the shortcut id this script was written
# against; newer transformers/Hub setups may require
# "facebook/bart-large-mnli" instead — confirm for your installed version.
MODEL_NAME = "bart-large-mnli"


def main():
    """Load the MNLI model and print predictions for two example pairs."""
    logging.basicConfig(level=logging.INFO)

    model = BartForSequenceClassification.from_pretrained(MODEL_NAME)
    tokenizer = BartTokenizer.from_pretrained(MODEL_NAME)
    # Named ``classifier`` rather than ``pipeline`` so it does not shadow the
    # ``transformers.pipeline`` factory function.
    classifier = TextClassificationPipeline(model=model, tokenizer=tokenizer)

    x = ["BART is a seq2seq model.", "BART is not sequence to sequence."]
    y = ["BART is denoising autoencoder.", "BART is version of autoencoder."]

    # The outer list is presumably meant to make each two-sentence list a
    # single (premise, hypothesis) input rather than two separate texts —
    # verify against the pipeline version in use.
    print(classifier([x]))
    # Output recorded in the original gist:
    # [{'label': 'contradiction', 'score': 0.9971421}]
    print(classifier([y]))
    # Output recorded in the original gist:
    # [{'label': 'entailment', 'score': 0.93996036}]


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.