from model import Seq2seq
import os
from squad_utils import read_squad_examples, convert_examples_to_features, write_predictions
from pytorch_pretrained_bert import BertTokenizer, BertForQuestionAnswering
import torch
from torch.utils.data import DataLoader, TensorDataset
import torch.nn.functional as F
import config
import collections
import re, string, sys, json
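

# SQuAD-style answer-evaluation helpers (exact match and F1), following the
# standard SQuAD v1.1 evaluation logic: predictions and gold answers are
# lowercased and stripped of punctuation, articles, and extra whitespace
# before token-level comparison.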
def normalize_answer(s):
    """Lower text and remove punctuation, articles and extra whitespace."""
    def remove_articles(text):
        return re.sub(r'\b(a|an|the)\b', ' ', text)

    def white_space_fix(text):
        return ' '.join(text.split())

    def remove_punc(text):
        exclude = set(string.punctuation)
        return ''.join(ch for ch in text if ch not in exclude)

    def lower(text):
        return text.lower()

    return white_space_fix(remove_articles(remove_punc(lower(s))))

def f1_score(prediction, ground_truth):
    prediction_tokens = normalize_answer(prediction).split()
    ground_truth_tokens = normalize_answer(ground_truth).split()
    common = collections.Counter(prediction_tokens) & collections.Counter(ground_truth_tokens)
    num_same = sum(common.values())
    if num_same == 0:
        return 0
    precision = 1.0 * num_same / len(prediction_tokens)
    recall = 1.0 * num_same / len(ground_truth_tokens)
    f1 = (2 * precision * recall) / (precision + recall)
    return f1

def exact_match_score(prediction, ground_truth):
    return (normalize_answer(prediction) == normalize_answer(ground_truth))


def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
    scores_for_ground_truths = []
    for ground_truth in ground_truths:
        score = metric_fn(prediction, ground_truth)
        scores_for_ground_truths.append(score)
    return max(scores_for_ground_truths)

def evaluate(dataset, predictions):
    f1 = exact_match = total = 0
    for article in dataset:
        for paragraph in article['paragraphs']:
            for qa in paragraph['qas']:
                total += 1
                if qa['id'] not in predictions:
                    message = 'Unanswered question ' + qa['id'] + \
                              ' will receive score 0.'
                    print(message, file=sys.stderr)
                    continue
                ground_truths = list(map(lambda x: x['text'], qa['answers']))
                prediction = predictions[qa['id']]
                exact_match += metric_max_over_ground_truths(
                    exact_match_score, prediction, ground_truths)
                f1 += metric_max_over_ground_truths(
                    f1_score, prediction, ground_truths)
    exact_match = 100.0 * exact_match / total
    f1 = 100.0 * f1 / total
    return {'exact_match': exact_match, 'f1': f1}
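

# A single beam-search hypothesis: the token ids decoded so far, their
# log-probabilities, and the decoder state needed to extend it one more step.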
class Hypothesis(object):
    def __init__(self, tokens, log_probs, state, context=None):
        self.tokens = tokens
        self.log_probs = log_probs
        self.state = state
        self.context = context

    def extend(self, token, log_prob, state, context=None):
        h = Hypothesis(tokens=self.tokens + [token],
                       log_probs=self.log_probs + [log_prob],
                       state=state,
                       context=context)
        return h

    @property
    def latest_token(self):
        return self.tokens[-1]

    @property
    def avg_log_prob(self):
        return sum(self.log_probs) / len(self.tokens)

class EvalFeature(object):
    def __init__(self,
                 unique_id,
                 example_index,
                 tokens,
                 token_to_orig_map,
                 token_is_max_context,
                 input_ids,
                 input_mask,
                 segment_ids,
                 start_position=None,
                 end_position=None):
        self.unique_id = unique_id
        self.example_index = example_index
        self.tokens = tokens
        self.token_to_orig_map = token_to_orig_map
        self.token_is_max_context = token_is_max_context
        self.input_ids = input_ids
        self.input_mask = input_mask
        self.segment_ids = segment_ids
        self.start_position = start_position
        self.end_position = end_position
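

# BeamSearcher runs the full generate-then-evaluate pipeline: it loads SQuAD
# contexts, generates a question for each context/answer pair by beam search
# over the Seq2seq decoder, writes source/golden/generated/answer text to
# disk, and finally scores the generated questions with a BERT QA model.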
class BeamSearcher(object):
    def __init__(self, model_path, output_dir):
        self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
        self.output_dir = output_dir
        self.golden_q_ids = None
        self.all_c_tokens = None
        self.all_answer_text = None
        self.all_example_ids = None
        self.all_unique_ids = None
        self.all_token_to_orig_map = None
        self.all_token_is_max_context = None
        self.data_loader = self.get_data_loader("./squad/new_test-v1.1.json")
        self.tok2idx = self.tokenizer.vocab
        self.idx2tok = {idx: tok for tok, idx in self.tok2idx.items()}
        # inference only: load the question-generation model and switch it to eval mode
        self.model = Seq2seq(dropout=0.0, model_path=model_path, use_tag=config.use_tag)
        self.model.requires_grad = False
        self.model.eval_mode()
        self.src_file = output_dir + "/src.txt"
        self.pred_file = output_dir + "/generated.txt"
        self.golden_file = output_dir + "/golden.txt"
        self.ans_file = output_dir + "/answer.txt"
        self.total_file = output_dir + "/all_files.csv"
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)
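
    # Build a batch-size-1 DataLoader over the SQuAD file and cache the
    # per-feature metadata (example ids, context tokens, token maps) that is
    # needed later when constructing QA evaluation features.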
    def get_data_loader(self, file):
        train_examples = read_squad_examples(file, is_training=True, debug=config.debug)
        train_features = convert_examples_to_features(train_examples,
                                                      tokenizer=self.tokenizer,
                                                      max_seq_length=config.max_seq_len,
                                                      max_query_length=config.max_query_len,
                                                      doc_stride=128,
                                                      is_training=True)
        all_c_ids = torch.tensor([f.c_ids for f in train_features], dtype=torch.long)
        all_c_lens = torch.sum(torch.sign(all_c_ids), 1)
        all_q_ids = torch.tensor([f.q_ids for f in train_features], dtype=torch.long)
        all_tag_ids = torch.tensor([f.tag_ids for f in train_features], dtype=torch.long)
        all_noq_start_positions = torch.tensor([f.noq_start_position for f in train_features], dtype=torch.long)
        all_noq_end_positions = torch.tensor([f.noq_end_position for f in train_features], dtype=torch.long)
        # for qa eval
        self.all_example_ids = [f.example_index for f in train_features]
        self.all_unique_ids = [f.unique_id for f in train_features]
        self.all_c_tokens = [f.context_tokens for f in train_features]
        self.all_token_is_max_context = [f.token_is_max_context for f in train_features]
        self.all_token_to_orig_map = [f.token_to_orig_map for f in train_features]
        train_data = TensorDataset(all_c_ids, all_c_lens, all_tag_ids, all_q_ids,
                                   all_noq_start_positions, all_noq_end_positions)
        train_loader = DataLoader(train_data, shuffle=False, batch_size=1)
        self.all_answer_text = [f.answer_text for f in train_features]
        self.golden_q_ids = all_q_ids
        return train_loader

    @staticmethod
    def sort_hypotheses(hypotheses):
        return sorted(hypotheses, key=lambda h: h.avg_log_prob, reverse=True)
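
    # For every context: run beam search, cut the best hypothesis at the first
    # [SEP], write the generated/golden question and the de-tokenized source
    # to file, and assemble a BERT QA input ([CLS] question [SEP] context)
    # for qa_eval.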
    def decode(self):
        pred_fw = open(self.pred_file, "w")
        golden_fw = open(self.golden_file, "w")
        src_fw = open(self.src_file, "w")
        ans_fw = open(self.ans_file, "w")
        features = []
        for i, eval_data in enumerate(self.data_loader):
            c_ids, c_lens, tag_seq, q_ids, noq_start_position, noq_end_position = eval_data
            c_ids = c_ids.to(config.device)
            c_lens = c_lens.to(config.device)
            tag_seq = tag_seq.to(config.device)
            if config.use_tag is False:
                tag_seq = None
            best_question = self.beam_search(c_ids, c_lens, tag_seq)
            # discard START token
            output_indices = [int(idx) for idx in best_question.tokens[1:]]
            sep_idx = self.tokenizer.vocab["[SEP]"]
            try:
                fst_stop_idx = output_indices.index(sep_idx)
                output_indices = output_indices[:fst_stop_idx]
            except ValueError:
                pass
            decoded_words = self.tokenizer.convert_ids_to_tokens(output_indices)
            decoded_words = " ".join(decoded_words)
            q_id = self.golden_q_ids[i]
            q_len = torch.sum(torch.sign(q_ids), 1).item()
            # discard [CLS], [SEP] and unnecessary PAD tokens
            q_id = q_id[1:q_len - 1].cpu().numpy()
            golden_question = self.tokenizer.convert_ids_to_tokens(q_id)
            answer_text = self.all_answer_text[i]
            # de-tokenize src tokens
            src_tokens = self.all_c_tokens[i]
            # discard [CLS] and [SEP] tokens
            src_txt = " ".join(src_tokens[1:-1])
            src_txt = src_txt.replace(" ##", "")
            src_txt = src_txt.replace("##", "").strip()
            print("write {}th question".format(i))
            pred_fw.write(decoded_words + "\n")
            golden_fw.write(" ".join(golden_question) + "\n")
            src_fw.write(src_txt + "\n")
            ans_fw.write(answer_text + "\n")
            # construct eval examples
            cls_id = self.tokenizer.vocab["[CLS]"]
            input_ids = list()
            segment_ids = list()
            input_ids.append(cls_id)
            segment_ids.append(0)
            for q_token_id in output_indices:
                input_ids.append(q_token_id)
                segment_ids.append(0)
            input_ids.append(sep_idx)
            segment_ids.append(0)
            # exclude [CLS] but not [SEP]; c_ids is a [1, seq_len] batch, so drop the
            # batch dim and trailing PAD ids (c_lens counts the non-PAD tokens)
            c_ids = c_ids.squeeze(0).cpu().tolist()[1:c_lens.item()]
            for c_id in c_ids:
                input_ids.append(c_id)
                segment_ids.append(1)
            input_mask = [1] * len(input_ids)
            while len(input_ids) < config.max_seq_len:
                input_ids.append(0)
                input_mask.append(0)
                segment_ids.append(0)
            assert len(input_ids) == config.max_seq_len
            assert len(input_mask) == config.max_seq_len
            assert len(segment_ids) == config.max_seq_len
            # -1 for [CLS]
            noq_start_position -= 1
            noq_end_position -= 1
            # +2 for [CLS] and [SEP]
            start_position = len(output_indices) + 2 + noq_start_position
            end_position = len(output_indices) + 2 + noq_end_position
            example_index = self.all_example_ids[i]
            unique_id = self.all_unique_ids[i]
            tokens = ["[CLS]"]
            for q_token in decoded_words.split(" "):
                tokens.append(q_token)
            tokens.append("[SEP]")
            # discard [CLS] but not [SEP]
            c_tokens = self.all_c_tokens[i][1:]
            for c_token in c_tokens:
                tokens.append(c_token)
            token_to_orig_map = self.all_token_to_orig_map[i]
            token_is_max_context = self.all_token_is_max_context[i]
            feature = EvalFeature(unique_id=unique_id,
                                  tokens=tokens,
                                  token_to_orig_map=token_to_orig_map,
                                  token_is_max_context=token_is_max_context,
                                  example_index=example_index,
                                  input_ids=input_ids,
                                  input_mask=input_mask,
                                  segment_ids=segment_ids,
                                  start_position=start_position,
                                  end_position=end_position)
            features.append(feature)
        pred_fw.close()
        golden_fw.close()
        src_fw.close()
        ans_fw.close()
        self.merge_files(self.total_file)
        self.qa_eval(features, config.qa_path)
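
    # QA-based evaluation: answer the generated questions with a pretrained
    # BertForQuestionAnswering model and report EM/F1 against the original
    # SQuAD answers.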
    def qa_eval(self, eval_features, model_path):
        eval_examples = read_squad_examples("./squad/new_test-v1.1.json", is_training=False,
                                            debug=config.debug)
        all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long)
        all_example_index = torch.arange(all_input_ids.size(0))
        eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_example_index)
        eval_dataloader = DataLoader(eval_data, shuffle=False, batch_size=config.batch_size)
        model = BertForQuestionAnswering.from_pretrained(model_path)
        model = model.to(config.device)
        model.eval()
        all_results = []
        RawResult = collections.namedtuple("RawResult",
                                           ["unique_id", "start_logits", "end_logits"])
        for data in eval_dataloader:
            input_ids, input_mask, segment_ids, example_indices = data
            input_ids = input_ids.to(config.device)
            input_mask = input_mask.to(config.device)
            segment_ids = segment_ids.to(config.device)
            with torch.no_grad():
                batch_start_logits, batch_end_logits = model(input_ids, segment_ids, input_mask)
            for i, example_index in enumerate(example_indices):
                # collect per-example start/end logits for write_predictions
                start_logits = batch_start_logits[i].detach().cpu().tolist()
                end_logits = batch_end_logits[i].detach().cpu().tolist()
                eval_feature = eval_features[example_index.item()]
                unique_id = int(eval_feature.unique_id)
                all_results.append(RawResult(unique_id=unique_id,
                                             start_logits=start_logits,
                                             end_logits=end_logits))
        output_prediction_file = os.path.join(self.output_dir, "predictions.json")
        output_nbest_file = os.path.join(self.output_dir, "nbest_predictions.json")
        output_null_log_odds_file = os.path.join(self.output_dir, "null_odds.json")
        write_predictions(eval_examples, eval_features, all_results,
                          n_best_size=20, max_answer_length=30, do_lower_case=True,
                          output_prediction_file=output_prediction_file,
                          output_nbest_file=output_nbest_file,
                          output_null_log_odds_file=output_null_log_odds_file,
                          verbose_logging=False,
                          version_2_with_negative=False,
                          null_score_diff_threshold=0)
        with open("./squad/new_test-v1.1.json") as dataset_file:
            dataset_json = json.load(dataset_file)
            dataset = dataset_json['data']
        with open(output_prediction_file) as prediction_file:
            predictions = json.load(prediction_file)
        print(json.dumps(evaluate(dataset, predictions)))
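
    # Beam search over the Seq2seq decoder: keep `beam_size` partial
    # hypotheses, expand each with its top 2 * beam_size next tokens, and move
    # a hypothesis to `results` once it emits [SEP] after at least
    # `min_decode_step` steps.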
    def beam_search(self, src_seq, src_len, tag_seq):
        if config.use_gpu:
            src_seq = src_seq.to(config.device)
            src_len = src_len.to(config.device)
            if config.use_tag:
                tag_seq = tag_seq.to(config.device)
        # forward encoder
        enc_outputs, enc_states = self.model.encoder(src_seq, src_len, tag_seq)
        h, c = enc_states  # [2, b, d] but b = 1
        hypotheses = [Hypothesis(tokens=[self.tok2idx["[CLS]"]],
                                 log_probs=[0.0],
                                 state=(h[:, 0, :], c[:, 0, :]),
                                 context=None) for _ in range(config.beam_size)]
        # tile enc_outputs, enc_mask for beam search
        ext_src_seq = src_seq.repeat(config.beam_size, 1)
        enc_outputs = enc_outputs.repeat(config.beam_size, 1, 1)
        zeros = enc_outputs.sum(dim=-1)
        enc_mask = (zeros == 0).byte()
        enc_features = self.model.decoder.get_encoder_features(enc_outputs)
        num_steps = 0
        results = []
        while num_steps < config.max_decode_step and len(results) < config.beam_size:
            latest_tokens = [h.latest_token for h in hypotheses]
            prev_y = torch.LongTensor(latest_tokens).view(-1)
            if config.use_gpu:
                prev_y = prev_y.to(config.device)
            # batch the hypotheses' decoder states into a beam-sized batch
            all_state_h = []
            all_state_c = []
            for h in hypotheses:
                state_h, state_c = h.state  # [num_layers, d]
                all_state_h.append(state_h)
                all_state_c.append(state_c)
            prev_h = torch.stack(all_state_h, dim=1)  # [num_layers, beam, d]
            prev_c = torch.stack(all_state_c, dim=1)  # [num_layers, beam, d]
            prev_states = (prev_h, prev_c)
            # [beam_size, |V|]
            logits, states = self.model.decoder.decode(prev_y,
                                                       ext_src_seq,
                                                       prev_states,
                                                       enc_features,
                                                       enc_mask)
            h_state, c_state = states
            log_probs = F.log_softmax(logits, dim=1)
            top_k_log_probs, top_k_ids \
                = torch.topk(log_probs, config.beam_size * 2, dim=-1)
            all_hypotheses = []
            num_orig_hypotheses = 1 if num_steps == 0 else len(hypotheses)
            for i in range(num_orig_hypotheses):
                h = hypotheses[i]
                state_i = (h_state[:, i, :], c_state[:, i, :])
                for j in range(config.beam_size * 2):
                    new_h = h.extend(token=top_k_ids[i][j].item(),
                                     log_prob=top_k_log_probs[i][j].item(),
                                     state=state_i,
                                     context=None)
                    all_hypotheses.append(new_h)
            hypotheses = []
            for h in self.sort_hypotheses(all_hypotheses):
                if h.latest_token == self.tok2idx["[SEP]"]:
                    if num_steps >= config.min_decode_step:
                        results.append(h)
                else:
                    hypotheses.append(h)
                if len(hypotheses) == config.beam_size or len(results) == config.beam_size:
                    break
            num_steps += 1
        if len(results) == 0:
            results = hypotheses
        h_sorted = self.sort_hypotheses(results)
        return h_sorted[0]
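
    # Merge the per-example output files into a single tab-separated file:
    # generated question, golden question, context, answer.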
    def merge_files(self, output_file):
        all_c_tokens = open(self.src_file, "r").readlines()
        all_answer_text = open(self.ans_file, "r").readlines()
        all_pred_q = open(self.pred_file, "r").readlines()
        all_golden_q = open(self.golden_file, "r").readlines()
        data = zip(all_c_tokens, all_answer_text, all_pred_q, all_golden_q)
        with open(output_file, "w") as f:
            for c_token, answer, pred_q, golden_q in data:
                line = pred_q.strip() + "\t" + golden_q.strip() + "\t" + c_token.strip() + "\t" + answer.strip() + "\n"
                f.write(line)
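

# Minimal usage sketch, not part of the original gist: the checkpoint path and
# output directory below are hypothetical placeholders and must point at a
# trained Seq2seq checkpoint and a writable directory in your own setup.
if __name__ == "__main__":
    searcher = BeamSearcher(model_path="./save/seq2seq.ckpt",   # hypothetical checkpoint
                            output_dir="./result/beam_search")  # hypothetical output dir
    searcher.decode()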