Created
November 6, 2017 16:38
-
-
Save jgc128/94ee90d32f953f85f6e2aaa6fd200ee3 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def tokenize_story_line(story_line):
    """Split one story or question line into its word tokens.

    A token is a maximal run of regex word characters (letters, digits and
    underscore); punctuation and whitespace only separate tokens and are
    dropped.

    >>> tokenize_story_line('Where is Mary?')
    ['Where', 'is', 'Mary']
    """
    word_matches = re.finditer(r'(\w+)', story_line)
    return [match.group(1) for match in word_matches]
def load_stories(file):
    """Parse a bAbI-style task file into (story, question, answer) samples.

    Each non-blank line has the form ``<id> <text>``; the whole line is
    lower-cased before parsing, and blank lines are skipped. An id of 1 starts
    a new story. A line whose text contains a tab is a question of the form
    ``question<TAB>answer[<TAB>supporting ids]``; any other line is a story
    sentence that is tokenized and appended to the current story.

    Args:
        file: object with an ``open()`` method returning a readable text file
            (presumably a ``pathlib.Path`` -- confirm against callers).

    Returns:
        A list with one ``(story, question, answer)`` tuple per question line:
        ``story`` is a list of the tokenized sentences seen so far in the
        current story, ``question`` is the tokenized question, and ``answer``
        is the answer string with surrounding whitespace removed.

    Raises:
        ValueError: if a non-blank line has no ``<id> <text>`` split or its id
            is not an integer.
    """
    data = []
    with file.open() as f:
        story = []
        for raw_line in f:
            line = raw_line.lower().strip()
            if not line:
                # Blank/whitespace-only lines would otherwise break the
                # id/text unpack below.
                continue
            story_line_id, story_line = line.split(' ', 1)
            if int(story_line_id) == 1:
                # new story starts
                story = []
            if '\t' not in story_line:
                # just a regular story line
                story.append(tokenize_story_line(story_line))
            else:
                # Question line. The optional third field (supporting fact ids)
                # is not used; accepting `*_` also tolerates lines without it.
                question, answer, *_ = story_line.split('\t')
                # Snapshot the story so far: `story` keeps growing after this
                # question, so the sample must hold its own copy.
                data.append((list(story), tokenize_story_line(question), answer.strip()))
    return data
def load_task_data(data_dir, task_id):
    """Load the train and test splits of one bAbI task from ``data_dir``.

    The split files are located by glob: ``qa{task_id}_*_train.txt`` and
    ``qa{task_id}_*_test.txt``. Each is parsed with ``load_stories``.

    Args:
        data_dir: directory containing the task files (str or path-like).
        task_id: task number interpolated into the file-name patterns.

    Returns:
        ``(train_stories, test_stories)`` as returned by ``load_stories``.

    Raises:
        FileNotFoundError: if no file in ``data_dir`` matches the train or test
            pattern. Previously a bare ``StopIteration`` escaped here, which
            can silently end an enclosing ``map``/generator iteration.
    """
    p = Path(data_dir)
    train_file = _find_task_file(p, task_id, 'train')
    test_file = _find_task_file(p, task_id, 'test')

    logging.info('Train file: %s', train_file)
    logging.info('Test file: %s', test_file)

    train_stories = load_stories(train_file)
    test_stories = load_stories(test_file)

    logging.info('Train stories: %d, test stories: %d', len(train_stories), len(test_stories))

    return train_stories, test_stories


def _find_task_file(data_dir, task_id, split):
    """Return the first (sorted) file in ``data_dir`` matching ``qa{task_id}_*_{split}.txt``."""
    pattern = f'qa{task_id}_*_{split}.txt'
    # Sort so the choice does not depend on filesystem listing order.
    matches = sorted(data_dir.glob(pattern))
    if not matches:
        raise FileNotFoundError(f'No file matching {pattern!r} in {data_dir}')
    return matches[0]
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment