Created
May 8, 2017 15:55
-
-
Save ajfisch/d447b9fd610b1d9868843350145fc6e3 to your computer and use it in GitHub Desktop.
Interactive DrQA model with ParlAI
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Run: python path/to/file --pretrained_model path/to/model | |
# | |
# Example interaction: | |
# Context: I was thirsty today. So I went to the market and bought some water. | |
# Question: What did I buy? | |
# Reply: some water | |
import torch | |
import logging | |
from parlai.agents.drqa.agents import DocReaderAgent | |
from parlai.core.params import ParlaiParser | |
def main(opt):
    """Run an interactive question-answering loop over a DrQA document reader.

    Repeatedly prompts on stdin for a context passage and a question, sends
    them to a ``DocReaderAgent`` as a single observation, and prints the
    agent's reply. The loop ends cleanly on end-of-input (Ctrl-D) or Ctrl-C
    at a prompt.

    Args:
        opt: Option dict as produced by ``ParlaiParser.parse_args()`` after
            ``DocReaderAgent.add_cmdline_args``. Must carry a non-empty
            ``'pretrained_model'`` entry.

    Raises:
        ValueError: If no pretrained model path is supplied.
    """
    # Look the logger up by name rather than relying on a module-level global
    # that only exists when this file is executed as a script.
    logger = logging.getLogger('DrQA')

    # A declared command-line option is normally present as a key even when
    # unset (presumably defaulting to None -- confirm in add_cmdline_args), so
    # check the value, not mere membership. Raise instead of assert so the
    # check is not stripped under `python -O`.
    if not opt.get('pretrained_model'):
        raise ValueError('--pretrained_model is required for interactive mode')

    doc_reader = DocReaderAgent(opt)

    # Log params (lazy %-args: formatting only happens if INFO is enabled).
    logger.info('[ Created with options: ] %s',
                ''.join('\n{}\t{}'.format(k, v) for k, v in opt.items()))

    while True:
        try:
            context = input('Context: ')
            question = input('Question: ')
        except (EOFError, KeyboardInterrupt):
            # End of input or Ctrl-C at a prompt: finish the prompt line and
            # stop instead of dumping a traceback.
            print()
            break
        # Context and question travel in one newline-joined 'text' field;
        # episode_done=True marks each pair as a complete episode.
        observation = {'text': '\n'.join([context, question]),
                       'episode_done': True}
        doc_reader.observe(observation)
        reply = doc_reader.act()
        print('Reply: %s' % reply['text'])
if __name__ == '__main__':
    # Build the parser; the DrQA agent registers its own flags on it
    # (presumably including --pretrained_model, --no_cuda and --gpu).
    parser = ParlaiParser()
    DocReaderAgent.add_cmdline_args(parser)
    opt = parser.parse_args()

    # Send log records to stderr (StreamHandler's default stream), each
    # prefixed with a timestamp. `logger` is deliberately a module-level name
    # so that code reading the global `logger` keeps working.
    logger = logging.getLogger('DrQA')
    logger.setLevel(logging.INFO)
    stderr_handler = logging.StreamHandler()
    stderr_handler.setFormatter(logging.Formatter(
        fmt='%(asctime)s: %(message)s',
        datefmt='%m/%d/%Y %I:%M:%S %p',
    ))
    logger.addHandler(stderr_handler)

    # Use the GPU only if the user did not opt out and CUDA is available.
    opt['cuda'] = not opt['no_cuda'] and torch.cuda.is_available()
    if opt['cuda']:
        gpu_id = opt['gpu']
        logger.info('[ Using CUDA (GPU %d) ]' % gpu_id)
        torch.cuda.set_device(gpu_id)

    # Hand control to the interactive loop.
    main(opt)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment