@julien-c, created March 3, 2020 00:28
import logging

from transformers import BartForSequenceClassification, BartTokenizer, TextClassificationPipeline

logging.basicConfig(level=logging.INFO)

# BART-large fine-tuned on MNLI; its labels are contradiction / neutral / entailment.
model = BartForSequenceClassification.from_pretrained("facebook/bart-large-mnli")
tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-mnli")
classifier = TextClassificationPipeline(model=model, tokenizer=tokenizer)

# Each list is one (premise, hypothesis) pair.
x = ["BART is a seq2seq model.", "BART is not sequence to sequence."]
y = ["BART is denoising autoencoder.", "BART is version of autoencoder."]

# Wrapping a pair in an outer list makes the pipeline encode it as a sentence pair.
# Newer transformers releases also accept {"text": premise, "text_pair": hypothesis}.
print(classifier([x]))
# [{'label': 'contradiction', 'score': 0.9971421}]
print(classifier([y]))
# [{'label': 'entailment', 'score': 0.93996036}]
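The `score` in each result is a softmax probability over the model's three MNLI classes, and `label` is the class with the highest probability. Below is a minimal sketch of that post-processing step in plain Python. It assumes the label order contradiction / neutral / entailment from the bart-large-mnli config; `top_label` is a hypothetical helper, and the logits are made-up numbers, not output from the model.

```python
import math

# Assumed label order for bart-large-mnli (the id2label mapping in its config).
ID2LABEL = {0: "contradiction", 1: "neutral", 2: "entailment"}


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)  # subtract the max so math.exp cannot overflow
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top_label(logits, id2label=ID2LABEL):
    """Hypothetical helper mimicking the pipeline's {'label', 'score'} output."""
    probs = softmax(logits)
    best = max(range(len(probs)), key=probs.__getitem__)  # argmax index
    return {"label": id2label[best], "score": probs[best]}


# Made-up logits for illustration only (not produced by the model).
print(top_label([4.0, 0.5, -3.0]))  # label is 'contradiction', score about 0.97
```

Subtracting the maximum logit before exponentiating leaves the probabilities unchanged but keeps large logits from overflowing.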