import collections
import json
from copy import deepcopy

import torch
from torch.utils.data import TensorDataset, RandomSampler, DataLoader

from pytorch_pretrained_bert.tokenization import whitespace_tokenize
class SquadExample(object):
    """
    A single training/test example for the SQuAD dataset.
    For examples without an answer, the start and end position are -1.
    """

    def __init__(self,
                 qas_id,
                 question_text,
                 doc_tokens,
                 orig_answer_text=None,
                 start_position=None,
                 end_position=None,
                 is_impossible=None):
        self.qas_id = qas_id
        self.question_text = question_text
        self.doc_tokens = doc_tokens
        self.orig_answer_text = orig_answer_text
        self.start_position = start_position
        self.end_position = end_position
        self.is_impossible = is_impossible

    def __str__(self):
        return self.__repr__()

    def __repr__(self):
        s = ""
        s += "qas_id: %s" % self.qas_id
        s += ", question_text: %s" % self.question_text
        s += ", doc_tokens: [%s]" % " ".join(self.doc_tokens)
        if self.start_position:
            s += ", start_position: %d" % self.start_position
        if self.end_position:
            s += ", end_position: %d" % self.end_position
        if self.is_impossible:
            s += ", is_impossible: %r" % self.is_impossible
        return s
def convert_examples_to_features(examples, tokenizer, max_seq_length,
                                 doc_stride, max_query_length, is_training):
    """Loads a data file into a list of `InputBatch`s."""
    unique_id = 1000000000

    features = []
    for (example_index, example) in enumerate(examples):
        query_tokens = tokenizer.tokenize(example.question_text)

        if len(query_tokens) > max_query_length:
            query_tokens = query_tokens[0:max_query_length]

        tok_to_orig_index = []
        orig_to_tok_index = []
        all_doc_tokens = []
        for (i, token) in enumerate(example.doc_tokens):
            orig_to_tok_index.append(len(all_doc_tokens))
            sub_tokens = tokenizer.tokenize(token)
            for sub_token in sub_tokens:
                tok_to_orig_index.append(i)
                all_doc_tokens.append(sub_token)

        tok_start_position = None
        tok_end_position = None
        if is_training and example.is_impossible:
            tok_start_position = -1
            tok_end_position = -1
        if is_training and not example.is_impossible:
            tok_start_position = orig_to_tok_index[example.start_position]
            if example.end_position < len(example.doc_tokens) - 1:
                tok_end_position = orig_to_tok_index[example.end_position + 1] - 1
            else:
                tok_end_position = len(all_doc_tokens) - 1
            (tok_start_position, tok_end_position) = _improve_answer_span(
                all_doc_tokens, tok_start_position, tok_end_position, tokenizer,
                example.orig_answer_text)

        # The -3 accounts for [CLS], [SEP] and [SEP]
        max_tokens_for_doc = max_seq_length - len(query_tokens) - 3

        # We can have documents that are longer than the maximum sequence length.
        # To deal with this we do a sliding window approach, where we take chunks
        # of up to our max length with a stride of `doc_stride`.
        _DocSpan = collections.namedtuple(  # pylint: disable=invalid-name
            "DocSpan", ["start", "length"])
        doc_spans = []
        start_offset = 0
        while start_offset < len(all_doc_tokens):
            length = len(all_doc_tokens) - start_offset
            if length > max_tokens_for_doc:
                length = max_tokens_for_doc
            doc_spans.append(_DocSpan(start=start_offset, length=length))
            if start_offset + length == len(all_doc_tokens):
                break
            start_offset += min(length, doc_stride)

        for (doc_span_index, doc_span) in enumerate(doc_spans):
            tokens = []
            token_to_orig_map = {}
            token_is_max_context = {}
            segment_ids = []
            tokens.append("[CLS]")
            segment_ids.append(0)
            for token in query_tokens:
                tokens.append(token)
                segment_ids.append(0)
            tokens.append("[SEP]")
            segment_ids.append(0)

            context_tokens = list()
            context_tokens.append("[CLS]")
            for i in range(doc_span.length):
                split_token_index = doc_span.start + i
                token_to_orig_map[len(tokens)] = tok_to_orig_index[split_token_index]

                is_max_context = _check_is_max_context(doc_spans, doc_span_index,
                                                       split_token_index)
                token_is_max_context[len(tokens)] = is_max_context
                tokens.append(all_doc_tokens[split_token_index])
                segment_ids.append(1)
                context_tokens.append(all_doc_tokens[split_token_index])
            tokens.append("[SEP]")
            segment_ids.append(1)
            context_tokens.append("[SEP]")

            input_ids = tokenizer.convert_tokens_to_ids(tokens)

            # The mask has 1 for real tokens and 0 for padding tokens. Only real
            # tokens are attended to.
            input_mask = [1] * len(input_ids)

            # Zero-pad up to the sequence length.
            while len(input_ids) < max_seq_length:
                input_ids.append(0)
                input_mask.append(0)
                segment_ids.append(0)

            assert len(input_ids) == max_seq_length
            assert len(input_mask) == max_seq_length
            assert len(segment_ids) == max_seq_length

            start_position = None
            end_position = None
            noq_start_position = None
            noq_end_position = None
            if is_training and not example.is_impossible:
                # For training, if our document chunk does not contain an annotation
                # we throw it out, since there is nothing to predict.
                doc_start = doc_span.start
                doc_end = doc_span.start + doc_span.length - 1
                out_of_span = False
                if not (tok_start_position >= doc_start and
                        tok_end_position <= doc_end):
                    out_of_span = True
                if out_of_span:
                    start_position = 0
                    end_position = 0
                    noq_start_position = 0
                    noq_end_position = 0
                else:
                    doc_offset = len(query_tokens) + 2
                    start_position = tok_start_position - doc_start + doc_offset
                    end_position = tok_end_position - doc_start + doc_offset
                    # plus one for [CLS] token
                    noq_start_position = tok_start_position - doc_start + 1
                    noq_end_position = tok_end_position - doc_start + 1
                if out_of_span:
                    continue
            if is_training and example.is_impossible:
                start_position = 0
                end_position = 0
                noq_start_position = 0
                noq_end_position = 0

            q_tokens = deepcopy(query_tokens)
            q_tokens.insert(0, "[CLS]")
            q_tokens.append("[SEP]")
            q_ids = tokenizer.convert_tokens_to_ids(q_tokens)
            c_ids = tokenizer.convert_tokens_to_ids(context_tokens)

            # pad up to maximum length
            while len(q_ids) < max_query_length:
                q_ids.append(0)
            while len(c_ids) < max_seq_length:
                c_ids.append(0)

            # BIO tagging scheme
            tag_ids = [0] * len(c_ids)  # Outside
            if noq_start_position is not None and noq_end_position is not None:
                tag_ids[noq_start_position] = 1  # Begin
                # Inside tag
                for idx in range(noq_start_position + 1, noq_end_position + 1):
                    tag_ids[idx] = 2
            assert len(tag_ids) == len(c_ids), \
                "length of tag :{}, length of c :{}".format(len(tag_ids), len(c_ids))

            features.append(
                InputFeatures(
                    unique_id=unique_id,
                    example_index=example_index,
                    doc_span_index=doc_span_index,
                    tokens=tokens,
                    token_to_orig_map=token_to_orig_map,
                    token_is_max_context=token_is_max_context,
                    input_ids=input_ids,
                    input_mask=input_mask,
                    c_ids=c_ids,
                    context_tokens=context_tokens,
                    q_ids=q_ids,
                    q_tokens=q_tokens,
                    answer_text=example.orig_answer_text,
                    tag_ids=tag_ids,
                    segment_ids=segment_ids,
                    noq_start_position=noq_start_position,
                    noq_end_position=noq_end_position,
                    start_position=start_position,
                    end_position=end_position,
                    is_impossible=example.is_impossible))
            unique_id += 1

    return features
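

# Hedged sketch (added for illustration, not part of the original gist): the
# sliding-window arithmetic above, isolated so the chunking is easy to trace.
# With 12 WordPiece tokens, max_tokens_for_doc=6 and doc_stride=4 it yields the
# spans (0, 6), (4, 6) and (8, 4); tokens that land in several overlapping spans
# are disambiguated later by _check_is_max_context.
def _demo_doc_spans(num_doc_tokens=12, max_tokens_for_doc=6, doc_stride=4):
    _DocSpan = collections.namedtuple("DocSpan", ["start", "length"])
    doc_spans = []
    start_offset = 0
    while start_offset < num_doc_tokens:
        # Take at most max_tokens_for_doc tokens, then slide forward by doc_stride.
        length = min(num_doc_tokens - start_offset, max_tokens_for_doc)
        doc_spans.append(_DocSpan(start=start_offset, length=length))
        if start_offset + length == num_doc_tokens:
            break
        start_offset += min(length, doc_stride)
    return doc_spans  # [(0, 6), (4, 6), (8, 4)] for the default arguments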
class InputFeatures(object):
    """A single set of features of data."""

    def __init__(self,
                 unique_id,
                 example_index,
                 doc_span_index,
                 tokens,
                 token_to_orig_map,
                 token_is_max_context,
                 input_ids,
                 c_ids,
                 context_tokens,
                 q_ids,
                 q_tokens,
                 answer_text,
                 tag_ids,
                 input_mask,
                 segment_ids,
                 noq_start_position=None,
                 noq_end_position=None,
                 start_position=None,
                 end_position=None,
                 is_impossible=None):
        self.unique_id = unique_id
        self.example_index = example_index
        self.doc_span_index = doc_span_index
        self.tokens = tokens
        self.token_to_orig_map = token_to_orig_map
        self.token_is_max_context = token_is_max_context
        self.input_ids = input_ids
        self.c_ids = c_ids
        self.context_tokens = context_tokens
        self.q_ids = q_ids
        self.q_tokens = q_tokens
        self.answer_text = answer_text
        self.tag_ids = tag_ids
        self.input_mask = input_mask
        self.segment_ids = segment_ids
        self.noq_start_position = noq_start_position
        self.noq_end_position = noq_end_position
        self.start_position = start_position
        self.end_position = end_position
        self.is_impossible = is_impossible
def read_squad_examples(input_file, is_training, version_2_with_negative=False, debug=False):
    """Read a SQuAD json file into a list of SquadExample."""
    with open(input_file, "r", encoding='utf-8') as reader:
        input_data = json.load(reader)["data"]

    def is_whitespace(c):
        if c == " " or c == "\t" or c == "\r" or c == "\n" or ord(c) == 0x202F:
            return True
        return False

    examples = []
    if debug:
        input_data = input_data[:10]
    for entry in input_data:
        for paragraph in entry["paragraphs"]:
            paragraph_text = paragraph["context"]
            doc_tokens = []
            char_to_word_offset = []
            prev_is_whitespace = True
            for c in paragraph_text:
                if is_whitespace(c):
                    prev_is_whitespace = True
                else:
                    if prev_is_whitespace:
                        doc_tokens.append(c)
                    else:
                        doc_tokens[-1] += c
                    prev_is_whitespace = False
                char_to_word_offset.append(len(doc_tokens) - 1)

            for qa in paragraph["qas"]:
                qas_id = qa["id"]
                question_text = qa["question"]
                start_position = None
                end_position = None
                orig_answer_text = None
                is_impossible = False
                if is_training:
                    if version_2_with_negative:
                        is_impossible = qa["is_impossible"]
                    # if (len(qa["answers"]) != 1) and (not is_impossible):
                    #     raise ValueError(
                    #         "For training, each question should have exactly 1 answer.")
                    if not is_impossible:
                        answer = qa["answers"][0]
                        orig_answer_text = answer["text"]
                        answer_offset = answer["answer_start"]
                        answer_length = len(orig_answer_text)
                        start_position = char_to_word_offset[answer_offset]
                        end_position = char_to_word_offset[answer_offset + answer_length - 1]
                        # Only add answers where the text can be exactly recovered from the
                        # document. If this CAN'T happen it's likely due to weird Unicode
                        # stuff so we will just skip the example.
                        #
                        # Note that this means for training mode, every example is NOT
                        # guaranteed to be preserved.
                        actual_text = " ".join(doc_tokens[start_position:(end_position + 1)])
                        cleaned_answer_text = " ".join(
                            whitespace_tokenize(orig_answer_text))
                        if actual_text.find(cleaned_answer_text) == -1:
                            continue
                    else:
                        start_position = -1
                        end_position = -1
                        orig_answer_text = ""

                example = SquadExample(
                    qas_id=qas_id,
                    question_text=question_text,
                    doc_tokens=doc_tokens,
                    orig_answer_text=orig_answer_text,
                    start_position=start_position,
                    end_position=end_position,
                    is_impossible=is_impossible)
                examples.append(example)
    return examples
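

# Hedged sketch (added for illustration): how the character-to-word mapping built in
# read_squad_examples behaves on a tiny hand-made paragraph. The strings below are
# hypothetical, not SQuAD data.
def _demo_char_to_word_offset():
    paragraph_text = "John Smith was born."
    doc_tokens = []
    char_to_word_offset = []
    prev_is_whitespace = True
    for c in paragraph_text:
        if c in (" ", "\t", "\r", "\n"):
            prev_is_whitespace = True
        else:
            if prev_is_whitespace:
                doc_tokens.append(c)
            else:
                doc_tokens[-1] += c
            prev_is_whitespace = False
        char_to_word_offset.append(len(doc_tokens) - 1)
    # doc_tokens == ["John", "Smith", "was", "born."]; an answer_start of 5
    # (the "S" in "Smith") maps to word index char_to_word_offset[5] == 1.
    return doc_tokens, char_to_word_offset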
def _improve_answer_span(doc_tokens, input_start, input_end, tokenizer,
                         orig_answer_text):
    """Returns tokenized answer spans that better match the annotated answer."""

    # The SQuAD annotations are character based. We first project them to
    # whitespace-tokenized words. But then after WordPiece tokenization, we can
    # often find a "better match". For example:
    #
    #   Question: What year was John Smith born?
    #   Context: The leader was John Smith (1895-1943).
    #   Answer: 1895
    #
    # The original whitespace-tokenized answer will be "(1895-1943).". However
    # after tokenization, our tokens will be "( 1895 - 1943 ) .". So we can match
    # the exact answer, 1895.
    #
    # However, this is not always possible. Consider the following:
    #
    #   Question: What country is the top exporter of electronics?
    #   Context: The Japanese electronics industry is the largest in the world.
    #   Answer: Japan
    #
    # In this case, the annotator chose "Japan" as a character sub-span of
    # the word "Japanese". Since our WordPiece tokenizer does not split
    # "Japanese", we just use "Japanese" as the annotation. This is fairly rare
    # in SQuAD, but does happen.
    tok_answer_text = " ".join(tokenizer.tokenize(orig_answer_text))

    for new_start in range(input_start, input_end + 1):
        for new_end in range(input_end, new_start - 1, -1):
            text_span = " ".join(doc_tokens[new_start:(new_end + 1)])
            if text_span == tok_answer_text:
                return (new_start, new_end)

    return (input_start, input_end)
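

# Hedged sketch (added for illustration): a stand-in tokenizer that splits on
# whitespace is enough to show how the span above shrinks from the whole
# "( 1895 - 1943 ) ." annotation down to just "1895". Everything here is
# hypothetical demo data, not part of the original gist.
def _demo_improve_answer_span():
    class _FakeTokenizer(object):
        def tokenize(self, text):
            return text.split()

    doc_tokens = ["the", "leader", "was", "john", "smith",
                  "(", "1895", "-", "1943", ")", "."]
    # The whitespace-level annotation covered "(1895-1943)." i.e. sub-tokens 5..10.
    start, end = _improve_answer_span(doc_tokens, 5, 10, _FakeTokenizer(), "1895")
    return start, end  # (6, 6): the single token "1895"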
def _check_is_max_context(doc_spans, cur_span_index, position):
    """Check if this is the 'max context' doc span for the token."""

    # Because of the sliding window approach taken to scoring documents, a single
    # token can appear in multiple documents. E.g.
    #   Doc: the man went to the store and bought a gallon of milk
    #   Span A: the man went to the
    #   Span B: to the store and bought
    #   Span C: and bought a gallon of
    #   ...
    #
    # Now the word 'bought' will have two scores from spans B and C. We only
    # want to consider the score with "maximum context", which we define as
    # the *minimum* of its left and right context (the *sum* of left and
    # right context will always be the same, of course).
    #
    # In the example the maximum context for 'bought' would be span C since
    # it has 1 left context and 3 right context, while span B has 4 left context
    # and 0 right context.
    best_score = None
    best_span_index = None
    for (span_index, doc_span) in enumerate(doc_spans):
        end = doc_span.start + doc_span.length - 1
        if position < doc_span.start:
            continue
        if position > end:
            continue
        num_left_context = position - doc_span.start
        num_right_context = end - position
        score = min(num_left_context, num_right_context) + 0.01 * doc_span.length
        if best_score is None or score > best_score:
            best_score = score
            best_span_index = span_index

    return cur_span_index == best_span_index
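

# Hedged sketch (added for illustration): reproduces the 'bought' example from the
# comment above with hypothetical spans, to show how the min(left, right) scoring
# picks span C over span B.
def _demo_check_is_max_context():
    _DocSpan = collections.namedtuple("DocSpan", ["start", "length"])
    doc_spans = [_DocSpan(start=0, length=5),   # Span A: tokens 0-4
                 _DocSpan(start=3, length=5),   # Span B: tokens 3-7
                 _DocSpan(start=6, length=5)]   # Span C: tokens 6-10
    position = 7  # 'bought'
    # Span B: min(4 left, 0 right) + 0.01 * 5 = 0.05
    # Span C: min(1 left, 3 right) + 0.01 * 5 = 1.05  -> C is the max-context span
    assert not _check_is_max_context(doc_spans, 1, position)
    assert _check_is_max_context(doc_spans, 2, position)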
def get_data_loader(self, file):
    # Note: this is written as a method (it uses `self.tokenizer`) and relies on a
    # `config` object providing debug, max_seq_len, max_query_len and batch_size.
    train_examples = read_squad_examples(file, is_training=True, debug=config.debug)
    train_features = convert_examples_to_features(train_examples,
                                                  tokenizer=self.tokenizer,
                                                  max_seq_length=config.max_seq_len,
                                                  max_query_length=config.max_query_len,
                                                  doc_stride=128,
                                                  is_training=True)

    all_c_ids = torch.tensor([f.c_ids for f in train_features], dtype=torch.long)
    all_c_lens = torch.sum(torch.sign(all_c_ids), 1).long()
    all_tag_ids = torch.tensor([f.tag_ids for f in train_features], dtype=torch.long)
    all_q_ids = torch.tensor([f.q_ids for f in train_features], dtype=torch.long)
    all_input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.long)
    all_input_mask = torch.tensor([f.input_mask for f in train_features], dtype=torch.long)
    all_segment_ids = torch.tensor([f.segment_ids for f in train_features], dtype=torch.long)
    all_start_positions = torch.tensor([f.start_position for f in train_features], dtype=torch.long)
    all_end_positions = torch.tensor([f.end_position for f in train_features], dtype=torch.long)

    train_data = TensorDataset(all_c_ids, all_c_lens, all_tag_ids, all_q_ids,
                               all_input_ids, all_input_mask, all_segment_ids,
                               all_start_positions, all_end_positions)
    sampler = RandomSampler(train_data)
    train_loader = DataLoader(train_data, sampler=sampler, batch_size=config.batch_size)
    return train_loader
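

# Hedged end-to-end sketch (added for illustration; the file path, vocab name,
# sequence lengths and batch size below are assumptions, not values from the
# original gist).
if __name__ == "__main__":
    from pytorch_pretrained_bert import BertTokenizer

    bert_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", do_lower_case=True)
    squad_examples = read_squad_examples("train-v2.0.json", is_training=True,
                                         version_2_with_negative=True, debug=True)
    squad_features = convert_examples_to_features(squad_examples,
                                                  tokenizer=bert_tokenizer,
                                                  max_seq_length=384,
                                                  doc_stride=128,
                                                  max_query_length=64,
                                                  is_training=True)
    input_ids = torch.tensor([f.input_ids for f in squad_features], dtype=torch.long)
    input_mask = torch.tensor([f.input_mask for f in squad_features], dtype=torch.long)
    segment_ids = torch.tensor([f.segment_ids for f in squad_features], dtype=torch.long)
    dataset = TensorDataset(input_ids, input_mask, segment_ids)
    loader = DataLoader(dataset, sampler=RandomSampler(dataset), batch_size=32)
    print("features:", len(squad_features), "batches:", len(loader))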