Skip to content

Instantly share code, notes, and snippets.

@tadasgedgaudas
Created May 19, 2021 06:32
Show Gist options
  • Save tadasgedgaudas/188904dafdc1e119ff2c7fbe3616d0d3 to your computer and use it in GitHub Desktop.
Save tadasgedgaudas/188904dafdc1e119ff2c7fbe3616d0d3 to your computer and use it in GitHub Desktop.
from difflib import SequenceMatcher
from transformers import AutoModelForCausalLM, AutoTokenizer
# Reply fragments that mark a generated answer as unusable: an answer
# containing any of these substrings is rejected by get_nlp_answer().
blacklisted_answers = [
    "I'm not sure",
    "I'm sorry,",
    "I'm not a ",
]
def similar(first_string, second_string):
    """
    Checks how similar two strings are
    :param first_string: string to compare with second string
    :param second_string: string to be compared with first string
    :return: similarity from 0 to 1 (1.0 for identical strings, including
        two empty strings)
    """
    # autojunk must be off: with the default heuristic, once the second
    # string reaches 200 characters every character making up more than ~1%
    # of it is ignored as a match anchor, so near-identical long texts can
    # score close to 0.
    matcher = SequenceMatcher(None, first_string, second_string, autojunk=False)
    return matcher.ratio()
# Download (on first use) and load the pretrained DialoGPT-medium checkpoint:
# tokenizer first, then the model. Both are module-level objects created once
# at import time and reused by get_nlp_answer().
_CHECKPOINT = "microsoft/DialoGPT-medium"
tokenizer = AutoTokenizer.from_pretrained(_CHECKPOINT)
model = AutoModelForCausalLM.from_pretrained(_CHECKPOINT)
def get_nlp_answer(text):
    """
    Generates an answer or comment based on given input
    :param text: input to generate an answer/comment from
    :return: generated answer from input, or None when the answer is empty,
        contains a blacklisted phrase, or is at least 70% similar to the input
    """
    # The input turn is terminated with the EOS token before generation.
    prompt_ids = tokenizer.encode(text + tokenizer.eos_token, return_tensors="pt")
    output_ids = model.generate(
        prompt_ids,
        max_length=1000,
        pad_token_id=tokenizer.eos_token_id,
    )
    # The generated sequence starts with the prompt tokens. Cut them off by
    # token count rather than with str.replace(): replace() would also delete
    # the input text wherever it recurs inside the reply, and would strip
    # nothing if decoding does not reproduce `text` character for character.
    reply_ids = output_ids[0, prompt_ids.shape[-1]:]
    answer = tokenizer.decode(reply_ids, skip_special_tokens=True)
    # An empty reply is as unusable as a rejected one.
    if not answer.strip():
        return None
    if any(phrase in answer for phrase in blacklisted_answers):
        return None
    # Discard replies that mostly echo the input back.
    if similar(text, answer) >= 0.7:
        return None
    return answer
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment