Created
May 19, 2021 06:32
-
-
Save tadasgedgaudas/188904dafdc1e119ff2c7fbe3616d0d3 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from difflib import SequenceMatcher | |
from transformers import AutoModelForCausalLM, AutoTokenizer | |
# Phrases that mark a generated reply as unusable (hedging, apologies,
# "I'm not a ..." disclaimers). get_nlp_answer rejects any reply that
# contains one of these as a case-sensitive substring.
blacklisted_answers = [
    "I'm not sure",
    "I'm sorry,",
    "I'm not a ",
]
def similar(first_string: str, second_string: str) -> float:
    """
    Checks how similar two strings are

    Uses difflib.SequenceMatcher.ratio(), i.e. 2*M/T where M is the number of
    matched characters and T is the combined length of both strings.

    :param first_string: string to compare with second string
    :param second_string: string to be compared with first string
    :return: similarity from 0 to 1; 1.0 for identical strings (including
        two empty strings), 0.0 when no characters match
    """
    return SequenceMatcher(None, first_string, second_string).ratio()
# Hugging Face Hub identifier of the pretrained conversational model; kept in
# one place so the tokenizer and model can never be loaded from different
# checkpoints.
MODEL_NAME = "microsoft/DialoGPT-medium"

# Load the tokenizer and model once, at import time. from_pretrained fetches
# the files from the Hugging Face Hub when they are not already cached locally.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
def get_nlp_answer(text, *, max_length=1000, similarity_threshold=0.7):
    """
    Generates an answer or comment based on given input

    The input is terminated with the tokenizer's end-of-sequence token (as in
    the DialoGPT usage example) and the model generates a continuation, which
    is returned as the reply.

    :param text: input to generate an answer/comment from
    :param max_length: maximum total length in tokens (prompt plus reply)
        passed to ``model.generate``
    :param similarity_threshold: replies whose ``similar`` ratio to ``text``
        is at or above this value are rejected as echoes of the input
    :return: generated answer with surrounding whitespace stripped, or None
        if the reply is empty, contains a blacklisted phrase, or is too
        similar to the input
    """
    # The tokenizer call returns input_ids together with an attention mask, so
    # generate() does not have to guess which positions are real tokens.
    inputs = tokenizer(text + tokenizer.eos_token, return_tensors="pt")
    prompt_length = inputs["input_ids"].shape[-1]
    output_ids = model.generate(
        **inputs,
        max_length=max_length,
        pad_token_id=tokenizer.eos_token_id,
    )

    # generate() returns the prompt tokens followed by the new tokens. Decode
    # only the new ones. String-replacing the prompt out of the full decode
    # also deleted any repeat of the prompt inside the reply, and it left the
    # prompt in place whenever decoding did not reproduce `text` exactly.
    reply = tokenizer.decode(output_ids[0, prompt_length:], skip_special_tokens=True)
    answer = reply.strip()

    # The model produced nothing but special tokens / whitespace.
    if not answer:
        return None
    if any(phrase in reply for phrase in blacklisted_answers):
        return None
    # Reject replies that merely parrot the input back.
    if similar(text, answer) >= similarity_threshold:
        return None
    return answer
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment