# Source: GitHub Gist by @seanie12 (gist id 3320e71d4abe080ddbab04d2ce5772fa),
# last active May 9, 2019 02:24.
# Spot-checks preprocessed SQuAD training features by printing random samples
# together with the answer spans recovered from their stored labels.
import torch
from pytorch_pretrained_bert import BertTokenizer
import random
import numpy as np
from squad_utils import convert_examples_to_features, read_squad_examples
import config
# Build the SQuAD v1.1 training features using the "bert-base-uncased"
# WordPiece tokenizer (lower-casing enabled to match the uncased vocabulary).
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", do_lower_case=True)

train_examples = read_squad_examples(
    "./squad/train-v1.1.json",
    is_training=True,
    debug=False,
)

# Sequence and query length limits come from the project's `config` module.
# doc_stride=128 is presumably the sliding-window stride over long passages
# (confirm in squad_utils.convert_examples_to_features).
train_features = convert_examples_to_features(
    train_examples,
    tokenizer=tokenizer,
    max_seq_length=config.max_seq_len,
    doc_stride=128,
    max_query_length=config.max_query_len,
    is_training=True,
)
def _stack_long(field_name):
    """Return a torch.long tensor made of ``field_name`` from every training feature.

    One entry per feature, in ``train_features`` order.
    """
    return torch.tensor(
        [getattr(feature, field_name) for feature in train_features],
        dtype=torch.long,
    )


all_c_ids = _stack_long("c_ids")
# Count of nonzero ids per row (sign, then sum over dim 1). Presumably the
# unpadded context length, assuming 0 is the pad id and ids are non-negative
# -- TODO confirm against squad_utils.
all_c_lens = torch.sign(all_c_ids).sum(dim=1)
all_tag_ids = _stack_long("tag_ids")
all_q_ids = _stack_long("q_ids")
all_input_ids = _stack_long("input_ids")
all_input_mask = _stack_long("input_mask")
all_segment_ids = _stack_long("segment_ids")
all_start_positions = _stack_long("start_position")
all_end_positions = _stack_long("end_position")
all_noq_start_positions = _stack_long("noq_start_position")
all_noq_end_positions = _stack_long("noq_end_position")

# Token and text fields are kept as plain Python lists rather than tensors.
all_context_tokens = [feature.context_tokens for feature in train_features]
all_answer_text = [feature.answer_text for feature in train_features]
all_q_tokens = [feature.q_tokens for feature in train_features]
# Spot-check 10 random training features. For each one, print the question,
# passage and answer text, followed by two "extracted:" token spans:
#   1. context[start : end + 1] from the noq start/end positions;
#   2. the run of tagged tokens starting at the first tag id 1, whose length
#      is the number of nonzero tag ids.
# If preprocessing is consistent, both spans should match the answer text.
for _ in range(10):
    idx = random.randrange(len(all_context_tokens))
    context = all_context_tokens[idx]
    q = all_q_tokens[idx]
    # Positions are 0-dim long tensors; convert them to ints for list slicing.
    start = all_noq_start_positions[idx].item()
    end = all_noq_end_positions[idx].item()
    answer = all_answer_text[idx]
    tag_ids = all_tag_ids[idx].tolist()
    # Number of nonzero tag ids (sign, then sum). Assumes tag ids are
    # non-negative and aligned with context tokens -- TODO confirm tag scheme.
    tag_len = int(np.sum(np.sign(tag_ids)))
    print("question:", q)
    print("passage:", context)
    print("answer :", answer)
    print("extracted:", context[start: end + 1])
    # tag_ids.index(1) raises ValueError when no tag id 1 is present. Report
    # that case instead of aborting the whole spot-check.
    if 1 in tag_ids:
        begin_idx = tag_ids.index(1)
        print("extracted:", context[begin_idx: begin_idx + tag_len])
    else:
        print("extracted:", "<no tag id 1 in tag_ids>")
# [GitHub page footer] Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.