@kinoc
Created February 6, 2022 11:10
Contradiction detector built on the Hugging Face RoBERTa model published by ynie (Facebook/Meta AI).
from transformers import AutoTokenizer, AutoModelForSequenceClassification # 4.0.1
import torch # 1.7
# use "contra activate dual" on laptop
# https://github.com/facebookresearch/ParlAI/issues/3391
# https://github.com/facebookresearch/ParlAI/issues/3665
# https://arxiv.org/abs/2012.13391
# https://huggingface.co/ynie/roberta-large_conv_contradiction_detector_v0
if __name__ == '__main__':
    max_length = 256
    hg_model_hub_name = "ynie/roberta-large_conv_contradiction_detector_v0"
    tokenizer = AutoTokenizer.from_pretrained(hg_model_hub_name)
    model = AutoModelForSequenceClassification.from_pretrained(hg_model_hub_name)
    model.eval()  # inference only

    # Persona-style premises; each hypothesis typed at the prompt is checked against all of them.
    premiseList = ["I'm a Spanish teacher.",
                   "I am at home.",
                   "I love my job.",
                   "I have no siblings",
                   "I found the treasure.",
                   "I am alone.",
                   "I am single.",
                   "I know python."]
    # premise = "I'm a Spanish teacher."

    while True:
        # e.g. "I don't know how to speak Spanish.", "I only speak English", "I am married", ...
        hypothesis = input("Hypo:")
        for premise in premiseList:
            tokenized_input_seq_pair = tokenizer.encode_plus(premise, hypothesis,
                                                             max_length=max_length,
                                                             return_token_type_ids=True,
                                                             truncation=True)
            input_ids = torch.Tensor(tokenized_input_seq_pair['input_ids']).long().unsqueeze(0)
            # remember BART doesn't have 'token_type_ids'; remove the line below if you are using BART
            token_type_ids = torch.Tensor(tokenized_input_seq_pair['token_type_ids']).long().unsqueeze(0)
            attention_mask = torch.Tensor(tokenized_input_seq_pair['attention_mask']).long().unsqueeze(0)

            with torch.no_grad():
                outputs = model(input_ids,
                                attention_mask=attention_mask,
                                token_type_ids=token_type_ids,
                                labels=None)
            # batch size is one, so take the single row of softmax scores:
            # index 0 = non-contradiction probability, index 1 = contradiction probability
            predicted_probability = torch.softmax(outputs[0], dim=1)[0].tolist()
            # print("Premise:", premise)
            # print("Hypothesis:", hypothesis)
            # print("Non contradiction:", predicted_probability[0])
            # print("Contradiction:", predicted_probability[1])

            thresh = 0.7
            if predicted_probability[1] > predicted_probability[0]:
                if predicted_probability[1] > thresh:
                    # confident contradiction
                    print(" |--->:", premise, predicted_probability[1], predicted_probability[0])
                else:
                    # leans toward contradiction, but below threshold
                    print(" |    :", premise, predicted_probability[1], predicted_probability[0])
            else:
                if predicted_probability[0] > thresh:
                    # confident non-contradiction
                    print(" |   *:", premise, predicted_probability[1], predicted_probability[0])
                else:
                    # leans toward non-contradiction, but below threshold
                    print(" |    :", premise, predicted_probability[1], predicted_probability[0])