@seanie12
Created May 14, 2019 01:06
from model import Seq2seq
import os
from squad_utils import read_squad_examples, convert_examples_to_features, write_predictions
from pytorch_pretrained_bert import BertTokenizer, BertForQuestionAnswering
import torch
from torch.utils.data import DataLoader, TensorDataset
import torch.nn.functional as F
import config
import collections
import re, string, sys, json
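

# SQuAD v1.1-style evaluation helpers (answer normalization, token-level F1,
# and exact match) that follow the official SQuAD evaluation script.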
def normalize_answer(s):
    """Lower text and remove punctuation, articles and extra whitespace."""
    def remove_articles(text):
        return re.sub(r'\b(a|an|the)\b', ' ', text)

    def white_space_fix(text):
        return ' '.join(text.split())

    def remove_punc(text):
        exclude = set(string.punctuation)
        return ''.join(ch for ch in text if ch not in exclude)

    def lower(text):
        return text.lower()

    return white_space_fix(remove_articles(remove_punc(lower(s))))


def f1_score(prediction, ground_truth):
    prediction_tokens = normalize_answer(prediction).split()
    ground_truth_tokens = normalize_answer(ground_truth).split()
    common = collections.Counter(prediction_tokens) & collections.Counter(ground_truth_tokens)
    num_same = sum(common.values())
    if num_same == 0:
        return 0
    precision = 1.0 * num_same / len(prediction_tokens)
    recall = 1.0 * num_same / len(ground_truth_tokens)
    f1 = (2 * precision * recall) / (precision + recall)
    return f1


def exact_match_score(prediction, ground_truth):
    return normalize_answer(prediction) == normalize_answer(ground_truth)


def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
    scores_for_ground_truths = []
    for ground_truth in ground_truths:
        score = metric_fn(prediction, ground_truth)
        scores_for_ground_truths.append(score)
    return max(scores_for_ground_truths)


def evaluate(dataset, predictions):
    f1 = exact_match = total = 0
    for article in dataset:
        for paragraph in article['paragraphs']:
            for qa in paragraph['qas']:
                total += 1
                if qa['id'] not in predictions:
                    message = 'Unanswered question ' + qa['id'] + \
                              ' will receive score 0.'
                    print(message, file=sys.stderr)
                    continue
                ground_truths = list(map(lambda x: x['text'], qa['answers']))
                prediction = predictions[qa['id']]
                exact_match += metric_max_over_ground_truths(
                    exact_match_score, prediction, ground_truths)
                f1 += metric_max_over_ground_truths(
                    f1_score, prediction, ground_truths)
    exact_match = 100.0 * exact_match / total
    f1 = 100.0 * f1 / total
    return {'exact_match': exact_match, 'f1': f1}


class Hypothesis(object):
    def __init__(self, tokens, log_probs, state, context=None):
        self.tokens = tokens
        self.log_probs = log_probs
        self.state = state
        self.context = context

    def extend(self, token, log_prob, state, context=None):
        h = Hypothesis(tokens=self.tokens + [token],
                       log_probs=self.log_probs + [log_prob],
                       state=state,
                       context=context)
        return h

    @property
    def latest_token(self):
        return self.tokens[-1]

    @property
    def avg_log_prob(self):
        return sum(self.log_probs) / len(self.tokens)


class EvalFeature(object):
    def __init__(self,
                 unique_id,
                 example_index,
                 tokens,
                 token_to_orig_map,
                 token_is_max_context,
                 input_ids,
                 input_mask,
                 segment_ids,
                 start_position=None,
                 end_position=None):
        self.unique_id = unique_id
        self.example_index = example_index
        self.tokens = tokens
        self.token_to_orig_map = token_to_orig_map
        self.token_is_max_context = token_is_max_context
        self.input_ids = input_ids
        self.input_mask = input_mask
        self.segment_ids = segment_ids
        self.start_position = start_position
        self.end_position = end_position
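

# BeamSearcher: loads a trained question-generation Seq2seq model, decodes one
# question per (context, answer) feature with beam search, writes the
# source/generated/golden/answer files, and finally scores the generated
# questions by running a pretrained BERT QA model over them (qa_eval).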
class BeamSearcher(object):
    def __init__(self, model_path, output_dir):
        self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
        self.output_dir = output_dir
        self.golden_q_ids = None
        self.all_c_tokens = None
        self.all_answer_text = None
        self.all_example_ids = None
        self.all_unique_ids = None
        self.all_token_to_orig_map = None
        self.all_token_is_max_context = None
        self.data_loader = self.get_data_loader("./squad/new_test-v1.1.json")
        self.tok2idx = self.tokenizer.vocab
        self.idx2tok = {idx: tok for tok, idx in self.tok2idx.items()}
        self.model = Seq2seq(dropout=0.0, model_path=model_path, use_tag=config.use_tag)
        self.model.requires_grad = False
        self.model.eval_mode()
        self.src_file = output_dir + "/src.txt"
        self.pred_file = output_dir + "/generated.txt"
        self.golden_file = output_dir + "/golden.txt"
        self.ans_file = output_dir + "/answer.txt"
        self.total_file = output_dir + "/all_files.csv"
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)

    def get_data_loader(self, file):
        train_examples = read_squad_examples(file, is_training=True, debug=config.debug)
        train_features = convert_examples_to_features(train_examples,
                                                      tokenizer=self.tokenizer,
                                                      max_seq_length=config.max_seq_len,
                                                      max_query_length=config.max_query_len,
                                                      doc_stride=128,
                                                      is_training=True)
        all_c_ids = torch.tensor([f.c_ids for f in train_features], dtype=torch.long)
        all_c_lens = torch.sum(torch.sign(all_c_ids), 1)
        all_q_ids = torch.tensor([f.q_ids for f in train_features], dtype=torch.long)
        all_tag_ids = torch.tensor([f.tag_ids for f in train_features], dtype=torch.long)
        all_noq_start_positions = torch.tensor([f.noq_start_position for f in train_features], dtype=torch.long)
        all_noq_end_positions = torch.tensor([f.noq_end_position for f in train_features], dtype=torch.long)
        # for qa eval
        self.all_example_ids = [f.example_index for f in train_features]
        self.all_unique_ids = [f.unique_id for f in train_features]
        self.all_c_tokens = [f.context_tokens for f in train_features]
        self.all_token_is_max_context = [f.token_is_max_context for f in train_features]
        self.all_token_to_orig_map = [f.token_to_orig_map for f in train_features]
        train_data = TensorDataset(all_c_ids, all_c_lens, all_tag_ids, all_q_ids,
                                   all_noq_start_positions, all_noq_end_positions)
        train_loader = DataLoader(train_data, shuffle=False, batch_size=1)
        self.all_answer_text = [f.answer_text for f in train_features]
        self.golden_q_ids = all_q_ids
        return train_loader

    @staticmethod
    def sort_hypotheses(hypotheses):
        return sorted(hypotheses, key=lambda h: h.avg_log_prob, reverse=True)
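
    # decode: for every context in the evaluation set, generate one question
    # with beam search, write source/prediction/golden/answer lines to disk,
    # and build an EvalFeature that pairs the generated question with its
    # context so it can later be scored by a BERT QA model in qa_eval.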
    def decode(self):
        pred_fw = open(self.pred_file, "w")
        golden_fw = open(self.golden_file, "w")
        src_fw = open(self.src_file, "w")
        ans_fw = open(self.ans_file, "w")
        features = []
        for i, eval_data in enumerate(self.data_loader):
            c_ids, c_lens, tag_seq, q_ids, noq_start_position, noq_end_position = eval_data
            c_ids = c_ids.to(config.device)
            c_lens = c_lens.to(config.device)
            tag_seq = tag_seq.to(config.device)
            if config.use_tag is False:
                tag_seq = None
            best_question = self.beam_search(c_ids, c_lens, tag_seq)
            # discard START token
            output_indices = [int(idx) for idx in best_question.tokens[1:]]
            sep_idx = self.tokenizer.vocab["[SEP]"]
            try:
                fst_stop_idx = output_indices.index(sep_idx)
                output_indices = output_indices[:fst_stop_idx]
            except ValueError:
                pass
            decoded_words = self.tokenizer.convert_ids_to_tokens(output_indices)
            decoded_words = " ".join(decoded_words)
            q_id = self.golden_q_ids[i]
            q_len = torch.sum(torch.sign(q_ids), 1).item()
            # discard [CLS], [SEP] and unnecessary PAD tokens
            q_id = q_id[1:q_len - 1].cpu().numpy()
            golden_question = self.tokenizer.convert_ids_to_tokens(q_id)
            answer_text = self.all_answer_text[i]
            # de-tokenize src tokens
            src_tokens = self.all_c_tokens[i]
            # discard [CLS] and [SEP] tokens
            src_txt = " ".join(src_tokens[1:-1])
            src_txt = src_txt.replace(" ##", "")
            src_txt = src_txt.replace("##", "").strip()
            print("write {}th question".format(i))
            pred_fw.write(decoded_words + "\n")
            golden_fw.write(" ".join(golden_question) + "\n")
            src_fw.write(src_txt + "\n")
            ans_fw.write(answer_text + "\n")
            # construct eval examples
            cls_id = self.tokenizer.vocab["[CLS]"]
            input_ids = list()
            segment_ids = list()
            input_ids.append(cls_id)
            segment_ids.append(0)
            for q_id in output_indices:
                input_ids.append(q_id)
                segment_ids.append(0)
            input_ids.append(sep_idx)
            segment_ids.append(0)
            # drop the batch dimension, then exclude [CLS] but not [SEP]
            c_ids = c_ids.squeeze(0).cpu().tolist()[1:]
            for c_id in c_ids:
                input_ids.append(c_id)
                segment_ids.append(1)
            input_mask = [1] * len(input_ids)
            while len(input_ids) < config.max_seq_len:
                input_ids.append(0)
                input_mask.append(0)
                segment_ids.append(0)
            assert len(input_ids) == config.max_seq_len
            assert len(input_mask) == config.max_seq_len
            assert len(segment_ids) == config.max_seq_len
            # -1 for [CLS]
            noq_start_position -= 1
            noq_end_position -= 1
            # +2 for [CLS] and [SEP]
            start_position = len(output_indices) + 2 + noq_start_position
            end_position = len(output_indices) + 2 + noq_end_position
            example_index = self.all_example_ids[i]
            unique_id = self.all_unique_ids[i]
            tokens = ["[CLS]"]
            for q_token in decoded_words.split(" "):
                tokens.append(q_token)
            tokens.append("[SEP]")
            # discard [CLS] but not [SEP]
            c_tokens = self.all_c_tokens[i][1:]
            for c_token in c_tokens:
                tokens.append(c_token)
            token_to_orig_map = self.all_token_to_orig_map[i]
            token_is_max_context = self.all_token_is_max_context[i]
            feature = EvalFeature(unique_id=unique_id,
                                  tokens=tokens,
                                  token_to_orig_map=token_to_orig_map,
                                  token_is_max_context=token_is_max_context,
                                  example_index=example_index,
                                  input_ids=input_ids,
                                  input_mask=input_mask,
                                  segment_ids=segment_ids,
                                  start_position=start_position,
                                  end_position=end_position)
            features.append(feature)
        pred_fw.close()
        golden_fw.close()
        src_fw.close()
        ans_fw.close()
        self.merge_files(self.total_file)
        self.qa_eval(features, config.qa_path)

    def qa_eval(self, eval_features, model_path):
        eval_examples = read_squad_examples("./squad/new_test-v1.1.json", is_training=False,
                                            debug=config.debug)
        all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long)
        all_example_index = torch.arange(all_input_ids.size(0))
        eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_example_index)
        eval_dataloader = DataLoader(eval_data, shuffle=False, batch_size=config.batch_size)
        model = BertForQuestionAnswering.from_pretrained(model_path)
        model = model.to(config.device)
        model.eval()
        all_results = []
        RawResult = collections.namedtuple("RawResult",
                                           ["unique_id", "start_logits", "end_logits"])
        for data in eval_dataloader:
            input_ids, input_mask, segment_ids, example_indices = data
            input_ids = input_ids.to(config.device)
            input_mask = input_mask.to(config.device)
            segment_ids = segment_ids.to(config.device)
            with torch.no_grad():
                batch_start_logits, batch_end_logits = model(input_ids, segment_ids, input_mask)
            for i, example_index in enumerate(example_indices):
                start_logits = batch_start_logits[i].detach().cpu().tolist()
                end_logits = batch_end_logits[i].detach().cpu().tolist()
                eval_feature = eval_features[example_index.item()]
                unique_id = int(eval_feature.unique_id)
                all_results.append(RawResult(unique_id=unique_id,
                                             start_logits=start_logits,
                                             end_logits=end_logits))
        output_prediction_file = os.path.join(self.output_dir, "predictions.json")
        output_nbest_file = os.path.join(self.output_dir, "nbest_predictions.json")
        output_null_log_odds_file = os.path.join(self.output_dir, "null_odds.json")
        write_predictions(eval_examples, eval_features, all_results,
                          n_best_size=20, max_answer_length=30, do_lower_case=True,
                          output_prediction_file=output_prediction_file,
                          output_nbest_file=output_nbest_file,
                          output_null_log_odds_file=output_null_log_odds_file,
                          verbose_logging=False,
                          version_2_with_negative=False,
                          null_score_diff_threshold=0)
        with open("./squad/new_test-v1.1.json") as dataset_file:
            dataset_json = json.load(dataset_file)
            dataset = dataset_json['data']
        with open(output_prediction_file) as prediction_file:
            predictions = json.load(prediction_file)
        print(json.dumps(evaluate(dataset, predictions)))
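
    # beam_search: standard beam search over the Seq2seq decoder. Hypotheses
    # start from [CLS], are expanded with the top 2 * beam_size tokens per
    # step, finish when [SEP] is produced after min_decode_step steps, and are
    # ranked by average log-probability.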
    def beam_search(self, src_seq, src_len, tag_seq):
        if config.use_gpu:
            src_seq = src_seq.to(config.device)
            src_len = src_len.to(config.device)
            if config.use_tag:
                tag_seq = tag_seq.to(config.device)
        # forward encoder
        enc_outputs, enc_states = self.model.encoder(src_seq, src_len, tag_seq)
        h, c = enc_states  # [2, b, d] but b = 1
        hypotheses = [Hypothesis(tokens=[self.tok2idx["[CLS]"]],
                                 log_probs=[0.0],
                                 state=(h[:, 0, :], c[:, 0, :]),
                                 context=None) for _ in range(config.beam_size)]
        # tile enc_outputs, enc_mask for beam search
        ext_src_seq = src_seq.repeat(config.beam_size, 1)
        enc_outputs = enc_outputs.repeat(config.beam_size, 1, 1)
        zeros = enc_outputs.sum(dim=-1)
        enc_mask = (zeros == 0).byte()
        enc_features = self.model.decoder.get_encoder_features(enc_outputs)
        num_steps = 0
        results = []
        while num_steps < config.max_decode_step and len(results) < config.beam_size:
            latest_tokens = [h.latest_token for h in hypotheses]
            prev_y = torch.LongTensor(latest_tokens).view(-1)
            if config.use_gpu:
                prev_y = prev_y.to(config.device)
            # make a batch whose size is the beam size
            all_state_h = []
            all_state_c = []
            for h in hypotheses:
                state_h, state_c = h.state  # [num_layers, d]
                all_state_h.append(state_h)
                all_state_c.append(state_c)
            prev_h = torch.stack(all_state_h, dim=1)  # [num_layers, beam, d]
            prev_c = torch.stack(all_state_c, dim=1)  # [num_layers, beam, d]
            prev_states = (prev_h, prev_c)
            # [beam_size, |V|]
            logits, states = self.model.decoder.decode(prev_y,
                                                       ext_src_seq,
                                                       prev_states,
                                                       enc_features,
                                                       enc_mask)
            h_state, c_state = states
            log_probs = F.log_softmax(logits, dim=1)
            top_k_log_probs, top_k_ids \
                = torch.topk(log_probs, config.beam_size * 2, dim=-1)
            all_hypotheses = []
            num_orig_hypotheses = 1 if num_steps == 0 else len(hypotheses)
            for i in range(num_orig_hypotheses):
                h = hypotheses[i]
                state_i = (h_state[:, i, :], c_state[:, i, :])
                for j in range(config.beam_size * 2):
                    new_h = h.extend(token=top_k_ids[i][j].item(),
                                     log_prob=top_k_log_probs[i][j].item(),
                                     state=state_i,
                                     context=None)
                    all_hypotheses.append(new_h)
            hypotheses = []
            for h in self.sort_hypotheses(all_hypotheses):
                if h.latest_token == self.tok2idx["[SEP]"]:
                    if num_steps >= config.min_decode_step:
                        results.append(h)
                else:
                    hypotheses.append(h)
                if len(hypotheses) == config.beam_size or len(results) == config.beam_size:
                    break
            num_steps += 1
        if len(results) == 0:
            results = hypotheses
        h_sorted = self.sort_hypotheses(results)
        return h_sorted[0]

    def merge_files(self, output_file):
        all_c_tokens = open(self.src_file, "r").readlines()
        all_answer_text = open(self.ans_file, "r").readlines()
        all_pred_q = open(self.pred_file, "r").readlines()
        all_golden_q = open(self.golden_file, "r").readlines()
        data = zip(all_c_tokens, all_answer_text, all_pred_q, all_golden_q)
        with open(output_file, "w") as f:
            for c_token, answer, pred_q, golden_q in data:
                line = pred_q.strip() + "\t" + golden_q.strip() + "\t" + c_token.strip() + "\t" + answer.strip() + "\n"
                f.write(line)
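

# Example usage: the checkpoint path and output directory below are
# placeholders; point them at the trained question-generation model and a
# writable result directory.
if __name__ == "__main__":
    searcher = BeamSearcher(model_path="./save/best_model.pt",
                            output_dir="./result/qg")
    searcher.decode()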