Created
April 16, 2019 05:40
-
-
Save MLWhiz/ddd23771948608020ba422158f58913d to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class SimpleDialogueManager(object):
    """
    This is a simple dialogue manager to test the telegram bot.
    The main part of our bot will be written here.

    Each incoming message is routed by an intent classifier either to a
    ChatterBot chit-chat model or, for programming questions, to the most
    similar Stack Overflow post within the predicted tag (found by comparing
    word-embedding vectors).
    """

    def __init__(self):
        """Load every model and vectorizer the manager needs.

        In order: a ChatterBot instance trained on the English corpus, the
        pre-trained GoogleNews Word2Vec vectors, the pickled intent and tag
        classifiers, and the pickled TF-IDF vectorizer.
        """
        # Instantiate all the models and TFIDF Objects.
        print("Loading resources...")
        # Instantiate a Chatterbot for Chitchat type questions.
        # Imported locally, so chatterbot is only required when a manager is
        # actually constructed.
        from chatterbot import ChatBot
        from chatterbot.trainers import ChatterBotCorpusTrainer
        chatbot = ChatBot('MLWhizChatterbot')
        trainer = ChatterBotCorpusTrainer(chatbot)
        trainer.train('chatterbot.corpus.english')
        self.chitchat_bot = chatbot
        print("Loading Word2vec model...")
        # Instantiate the Google's pre-trained Word2Vec model.
        self.model = gensim.models.KeyedVectors.load_word2vec_format(
            'GoogleNews-vectors-negative300.bin', binary=True)
        print("Loading Classifier objects...")
        # Load the intent classifier and tag classifier.
        # NOTE: unpickling can execute arbitrary code -- these resource files
        # must come from a trusted source.
        self.intent_recognizer = self._load_pickle('resources/intent_clf.pkl')
        self.tag_classifier = self._load_pickle('resources/tag_clf.pkl')
        # Load the TFIDF vectorizer object.
        self.tfidf_vectorizer = self._load_pickle('resources/tfidf.pkl')
        print("Finished Loading Resources")

    @staticmethod
    def _load_pickle(path):
        """Unpickle and return the object stored at *path*, closing the file afterwards."""
        with open(path, 'rb') as f:
            return pickle.load(f)

    def get_similar_question(self, question, tag):
        """Find the stored post whose embedding is closest to *question*.

        Args:
            question: the question text to embed with ``question_to_vec``.
            tag: the predicted tag; selects the embeddings file
                ``resources/embeddings_folder/<tag>.pkl``, which holds a
                ``(post_ids, post_embeddings)`` pair.

        Returns:
            ``post_ids`` indexed by the index array that
            ``pairwise_distances_argmin`` returns (one entry per query row).
            With a single question this holds one id, so callers take ``[0]``.
        """
        import numpy as np

        # Load the post ids and their precomputed embeddings for this tag.
        embeddings_path = 'resources/embeddings_folder/' + tag + ".pkl"
        post_ids, post_embeddings = self._load_pickle(embeddings_path)
        # Embed the question; 300 presumably matches the dimensionality of the
        # GoogleNews-vectors-negative300 model loaded in __init__.
        question_vec = question_to_vec(question, self.model, 300)
        # pairwise_distances_argmin requires a 2-D (n_samples, n_features)
        # array, so present the single embedding as one row. This is a no-op
        # when question_to_vec already returns shape (1, d).
        question_vec = np.reshape(question_vec, (1, -1))
        # Index of the most similar post (one index per query row).
        best_post_index = pairwise_distances_argmin(question_vec, post_embeddings)
        return post_ids[best_post_index]

    def generate_answer(self, question):
        """Return a text reply for *question*.

        The question is normalised with ``text_prepare`` and vectorised with
        the TF-IDF vectorizer. The intent classifier then picks the route:
        ``'dialogue'`` goes to the ChatterBot chit-chat model, and anything
        else is answered with a link to the most similar Stack Overflow thread
        for the tag predicted by the tag classifier.
        """
        prepared_question = text_prepare(question)
        features = self.tfidf_vectorizer.transform([prepared_question])
        # Find intent.
        intent = self.intent_recognizer.predict(features)[0]
        if intent == 'dialogue':
            # Chit-chat part: ChatBot.get_response returns a Statement object;
            # convert it to its text so both branches return a str.
            return str(self.chitchat_bot.get_response(question))
        # Stack Overflow question: find the programming-language tag.
        tag = self.tag_classifier.predict(features)[0]
        # Find the most similar question's post id.
        # NOTE(review): the raw question (not prepared_question) is embedded
        # here -- confirm this matches how post_embeddings were built.
        post_id = self.get_similar_question(question, tag)[0]
        return 'I think its about %s\nThis thread might help you: https://stackoverflow.com/questions/%s' % (tag, post_id)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment