# RLTrainer: adversarial/RL fine-tuning of a question-generation (QG) VAE
# against a BiDAF QA model.
# Standard-library and PyTorch imports; the remaining names (CatTrainer,
# BiDAF, DiscreteVAE, EMA, the progress/ETA utilities, and the SQuAD
# write_predictions/evaluate helpers) are project-local modules whose import
# paths are not shown in the original gist.
import collections
import json
import os
import time

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import lr_scheduler


class RLTrainer(CatTrainer):
    def __init__(self, args):
        super(RLTrainer, self).__init__(args)
        # pre-trained QA model, used both as the adversary and as the reward source
        self.qa_model = BiDAF(embedding_size=100,
                              vocab_size=self.vocab_size,
                              hidden_size=args.qa_hidden_size,
                              drop_prob=0.2)
        state_dict = torch.load(args.qa_file, map_location="cpu")
        self.qa_model.load_state_dict(state_dict)
        self.qa_model = self.qa_model.to(self.device)
        # constant-factor LambdaLR keeps the QG learning rate fixed
        self.scheduler = lr_scheduler.LambdaLR(self.opt, lambda s: 1.)
        # exponential moving average of the QA weights, swapped in at eval time
        self.ema = EMA(self.qa_model, 0.999)
        params = self.qa_model.parameters()
        self.qa_opt = optim.Adam(params, self.args.qa_lr)
        # separate optimizer for the prior (policy) network
        pi_params = self.model.prior_encoder.parameters()
        self.pi_opt = optim.Adam(pi_params, self.args.lr)
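
    # The EMA wrapper's internals are not shown in this gist; it is assumed to
    # keep shadow copies updated as shadow = decay * shadow + (1 - decay) * param
    # on each call, i.e. with decay 0.999 roughly a 1000-step moving average
    # of the QA weights.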

    def init_model(self, args):
        # QG model: discrete VAE over a BERT encoder, with a copy mechanism
        sos_id = self.tokenizer.vocab["[CLS]"]
        eos_id = self.tokenizer.vocab["[SEP]"]
        model = DiscreteVAE(padding_idx=0,
                            sos_id=sos_id,
                            eos_id=eos_id,
                            bert_model="bert-base-uncased",
                            ntokens=len(self.tokenizer.vocab),
                            nhidden=512,
                            nlayers=1,
                            dropout=0.2,
                            nz=20,
                            nzdim=10,
                            freeze=self.args.freeze,
                            copy=True)
        model = model.to(self.device)
        state_dict = torch.load(self.args.qg_file, map_location="cpu")
        model.load_state_dict(state_dict)
        return model
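
    # Note: the latent code is a set of nz=20 categorical variables with
    # nzdim=10 classes each; train() below reshapes the prior logits with
    # self.args.num_classes, which is assumed to match nzdim here.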

    def get_opt(self):
        params = self.model.parameters()
        opt = optim.Adam(params, self.args.lr)
        return opt

    def process_batch(self, batch):
        batch = tuple(t.to(self.device) for t in batch)
        q_ids, c_ids, tag_ids, ans_ids, start_positions, end_positions = batch

        # truncate each tensor to the longest non-pad sequence in the batch
        q_len = torch.sum(torch.sign(q_ids), 1)
        max_len = torch.max(q_len)
        q_ids = q_ids[:, :max_len]

        c_len = torch.sum(torch.sign(c_ids), 1)
        max_len = torch.max(c_len)
        c_ids = c_ids[:, :max_len]
        tag_ids = tag_ids[:, :max_len]

        a_len = torch.sum(torch.sign(ans_ids), 1)
        max_len = torch.max(a_len)
        ans_ids = ans_ids[:, :max_len]
        return q_ids, c_ids, tag_ids, ans_ids, start_positions, end_positions
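
    # torch.sign turns every non-zero token id into 1, so the sums above count
    # non-pad tokens (pad id is 0, matching padding_idx=0 above). For example,
    # with q_ids = [[5, 7, 0, 0]] the length is 2 and the batch is cut to width 2.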

    def train(self):
        batch_num = len(self.train_loader)
        global_step = 1
        avg_qa_loss = 0
        avg_kl = 0
        avg_adv_loss = 0
        kl_div = nn.KLDivLoss(reduction="batchmean")
        best_f1 = 0
        for epoch in range(1, self.args.num_epochs + 1):
            start = time.time()
            self.model.train()
            self.qa_model.train()
            for step, batch in enumerate(self.train_loader, start=1):
                # allocate tensors to device
                q_ids, c_ids, tag_ids, _, \
                    start_positions, end_positions = self.process_batch(batch)
                ans_ids = (tag_ids != 0).long()

                # (1) update the QA model on real and generated questions
                real_loss = self.qa_model(c_ids, q_ids,
                                          start_positions=start_positions,
                                          end_positions=end_positions)
                # generate a question from the prior
                with torch.no_grad():
                    # sample z from the prior distribution
                    prior_z_logits, _ = self.model.prior_encoder(c_ids)
                    # sample (q, a)
                    gen_q_ids, gen_start_positions, gen_end_positions, \
                        _, _ = self.model.generate(prior_z_logits, c_ids)
                # forward the QA model on the generated question
                fake_loss = self.qa_model(c_ids, gen_q_ids,
                                          start_positions=gen_start_positions,
                                          end_positions=gen_end_positions)
                # combined loss over real and generated questions
                qa_loss = real_loss + self.args.adv_lambda * fake_loss
                self.qa_opt.zero_grad()
                qa_loss.backward()
                nn.utils.clip_grad_norm_(self.qa_model.parameters(), 5.0)
                self.qa_opt.step()
                self.scheduler.step()
                self.ema(self.qa_model, global_step)
                global_step += 1

                # (2) update the prior (policy) network
                # sample z from the prior
                prior_z_logits, prior_z_probs = self.model.prior_encoder(c_ids)
                flatten_prior_logits = prior_z_logits.view(-1, self.args.num_classes)
                log_prob_prior = F.log_softmax(flatten_prior_logits, dim=1)
                # sample z from the posterior
                with torch.no_grad():
                    posterior_z_logits, posterior_z_prob = \
                        self.model.posterior_encoder(c_ids, q_ids, ans_ids)
                    flatten_posterior = posterior_z_logits.view(-1, self.args.num_classes)
                # regularize the prior toward the posterior with a KL divergence
                prob_posterior = F.softmax(flatten_posterior, dim=1)
                kl = kl_div(log_prob_prior, prob_posterior.detach())

                with torch.no_grad():
                    gen_q_ids, gen_start_positions, gen_end_positions, \
                        latent_z = self.model.sample(prior_z_logits, c_ids)
                # the reward is the QA loss, so pi maximizes the QA loss
                start_logits, end_logits = self.qa_model(c_ids, gen_q_ids)
                reward = self.get_reward(start_logits, end_logits,
                                         gen_start_positions, gen_end_positions)
                action_probs = torch.sum(prior_z_probs * latent_z, dim=-1)
                log_prob = torch.log(action_probs + 1e-12)  # [b, num_vars]
                adv_loss = -(reward.unsqueeze(1).detach() * log_prob).sum(1).mean()
                # backward pass
                pi_loss = adv_loss + kl
                pi_loss.backward()
                self.pi_opt.step()
                self.pi_opt.zero_grad()

                avg_qa_loss = cal_running_avg_loss(qa_loss.item(), avg_qa_loss)
                avg_kl = cal_running_avg_loss(kl.item(), avg_kl)
                avg_adv_loss = cal_running_avg_loss(adv_loss.item(), avg_adv_loss)
                msg = "{}/{} {} - ETA : {} - QA loss: {:.4f}, KL: {:.4f}, adv loss: {:.4f}" \
                    .format(step, batch_num, progress_bar(step, batch_num),
                            eta(start, step, batch_num), avg_qa_loss, avg_kl, avg_adv_loss)
                print(msg, end="\r")

            if not self.args.debug:
                result_dict = self.eval(msg)
                f1 = result_dict["f1"]
                em = result_dict["exact_match"]
                print("Epoch {} took {} - F1: {:.4f}, EM: {:.4f}"
                      .format(epoch, user_friendly_time(time_since(start)), f1, em))
                if f1 > best_f1:
                    best_f1 = f1
                    self.save_qa_model(epoch, f1, em)
                    self.save_model(epoch, f1)
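
    # The policy update above is the score-function (REINFORCE) estimator:
    # with the reward r detached, the loss -sum(r * log pi(z)) has gradient
    # -r * d log pi(z) / d theta, so minimizing it pushes the prior toward
    # latent codes whose generated questions give the QA model a high loss.
    # A minimal standalone sketch of the same estimator (names hypothetical):
    #
    #   probs = F.softmax(logits, dim=-1)          # [b, num_vars, num_classes]
    #   picked = torch.sum(probs * one_hot_z, -1)  # prob of each sampled class
    #   loss = -(reward.detach().unsqueeze(1)
    #            * torch.log(picked + 1e-12)).sum(1).mean()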

    @staticmethod
    def compute_fake_loss(start_logits, end_logits, context_len, context_mask):
        # push the span distributions for fake questions toward a uniform
        # distribution over the valid (non-masked) context positions
        batch_size, time_step = start_logits.size()
        uniform_dist = torch.ones(batch_size, device=context_len.device) / context_len.float()
        uniform_dist = uniform_dist.unsqueeze(1)
        uniform_dist = uniform_dist.repeat([1, time_step]).masked_fill(context_mask, 0)
        kl_div = nn.KLDivLoss(reduction="batchmean")
        start_log_prob = F.log_softmax(start_logits, dim=1)
        end_log_prob = F.log_softmax(end_logits, dim=1)
        start_loss = kl_div(start_log_prob, uniform_dist)
        end_loss = kl_div(end_log_prob, uniform_dist)
        loss = 0.5 * (start_loss + end_loss)
        return loss
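
    # Note: nn.KLDivLoss expects log-probabilities as its first argument and
    # probabilities as the target, which is why log_softmax is applied to the
    # logits while uniform_dist stays a plain distribution over valid positions.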

    @staticmethod
    def get_seq_len(input_ids, eos_id):
        # input_ids: [b, t]
        # eos_id: scalar
        mask = (input_ids == eos_id).byte()
        num_eos = torch.sum(mask, 1)
        # move to cpu because torch.argmax breaks ties differently on cuda
        # and cpu, whereas np.argmax consistently returns the first index
        # of the maximum element
        mask = mask.cpu().numpy()
        indices = np.argmax(mask, 1)
        # convert the numpy array back to a Tensor
        seq_len = torch.LongTensor(indices).to(input_ids.device)
        # in case there is no eos in the sequence
        max_len = input_ids.size(1)
        seq_len = seq_len.masked_fill(num_eos == 0, max_len - 1)
        # +1 to include the eos token
        seq_len = seq_len + 1
        return seq_len
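
    # Example: get_seq_len(torch.tensor([[5, 7, 102, 0]]), eos_id=102)
    # finds the first eos at position 2 and returns tensor([3]), the
    # length including the eos token.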

    @staticmethod
    def get_reward(start_logits, end_logits, start_positions, end_positions):
        if len(start_positions.size()) > 1:
            start_positions = start_positions.squeeze(-1)
        if len(end_positions.size()) > 1:
            end_positions = end_positions.squeeze(-1)
        # sometimes the start/end positions fall outside the model inputs;
        # clamp them to ignored_index so those terms are ignored
        ignored_index = start_logits.size(1)
        start_positions = start_positions.clamp(0, ignored_index)
        end_positions = end_positions.clamp(0, ignored_index)
        loss_fct = nn.CrossEntropyLoss(ignore_index=ignored_index, reduction="none")
        start_loss = loss_fct(start_logits, start_positions)
        end_loss = loss_fct(end_logits, end_positions)
        total_loss = (start_loss + end_loss) / 2
        return total_loss
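
    # With reduction="none", the reward is a per-example average of the start
    # and end cross-entropy: questions the QA model answers poorly yield a
    # larger reward, which is exactly what the adversarial policy maximizes.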

    @staticmethod
    def post_process(q_ids, q_len, c_ids, cls_id):
        batch_size = q_ids.size(0)
        max_q_len = torch.max(q_len)
        cls_ids = cls_id * torch.ones((batch_size, 1), device=q_ids.device, dtype=torch.long)
        all_input_ids = []
        all_seg_ids = []
        for i in range(batch_size):
            q_length = q_len[i]
            q = q_ids[i, :q_length]  # exclude pad tokens
            c = c_ids[i, 1:]  # exclude [CLS]
            # input ids
            pads = torch.zeros((max_q_len - q_length), device=q_ids.device, dtype=torch.long)
            input_ids = torch.cat([q, c, pads], dim=0)
            all_input_ids.append(input_ids)
            # segment ids: 0 for the question, 1 for the context
            zeros = torch.zeros_like(q)
            ones = torch.ones_like(c)
            seg_ids = torch.cat([zeros, ones, pads], dim=0)
            all_seg_ids.append(seg_ids)
        all_input_ids = torch.stack(all_input_ids, dim=0)
        all_input_ids = torch.cat([cls_ids, all_input_ids], dim=1)
        # segment id for cls
        zeros = torch.zeros_like(cls_ids)
        all_seg_ids = torch.stack(all_seg_ids, dim=0)
        all_seg_ids = torch.cat([zeros, all_seg_ids], dim=1)
        # attention mask: zero out pad positions
        mask = (all_input_ids == 0).byte()
        all_seg_ids = all_seg_ids.masked_fill(mask, 0)
        attention_mask = 1 - mask
        return all_input_ids, all_seg_ids, attention_mask
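
    # The resulting layout is BERT-style: [CLS], question, context, pads,
    # with segment id 0 over [CLS] and the question and 1 over the context:
    #   ids  = [CLS] q1 q2 c1 c2 c3 0 0
    #   segs =   0    0  0  1  1  1 0 0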

    def save_qa_model(self, epoch, f1, em):
        f1 = round(f1, 2)
        em = round(em, 2)
        save_file = os.path.join(self.save_dir, "{}_{:.2f}_{:.2f}".format(epoch, f1, em))
        state_dict = self.qa_model.state_dict()
        torch.save(state_dict, save_file)

    def eval(self, msg):
        # evaluate with the EMA weights, then restore the raw weights
        self.ema.assign(self.qa_model)
        self.qa_model.eval()
        all_results = []
        example_index = -1
        num_val_batches = len(self.dev_loader)
        RawResult = collections.namedtuple("RawResult",
                                           ["unique_id", "start_logits", "end_logits"])
        for i, batch in enumerate(self.dev_loader, start=1):
            q_ids, c_ids, tag_ids, ans_ids, _, _ = self.process_batch(batch)
            with torch.no_grad():
                batch_start_logits, batch_end_logits = self.qa_model(c_ids, q_ids)
                batch_size = batch_end_logits.size(0)
            for j in range(batch_size):
                example_index += 1
                start_logits = batch_start_logits[j].detach().cpu().tolist()
                end_logits = batch_end_logits[j].detach().cpu().tolist()
                eval_feature = self.eval_features[example_index]
                unique_id = int(eval_feature.unique_id)
                all_results.append(RawResult(unique_id=unique_id,
                                             start_logits=start_logits,
                                             end_logits=end_logits))
            msg2 = "{} => Evaluating: {}/{}".format(msg, i, num_val_batches)
            print(msg2, end="\r")

        output_prediction_file = os.path.join(self.save_dir, "adv_prediction.json")
        write_predictions(self.eval_examples, self.eval_features, all_results,
                          n_best_size=20, max_answer_length=30, do_lower_case=True,
                          output_prediction_file=output_prediction_file,
                          verbose_logging=False,
                          version_2_with_negative=False,
                          null_score_diff_threshold=0,
                          noq_position=True)

        with open(self.args.dev_file) as f:
            data_json = json.load(f)
            dataset = data_json["data"]
        with open(output_prediction_file) as prediction_file:
            predictions = json.load(prediction_file)
        results = evaluate(dataset, predictions)
        self.qa_model.train()
        self.ema.resume(self.qa_model)
        return results
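

# A minimal, hypothetical usage sketch; the gist does not include the entry
# point, so the argument object and its fields (qa_file, qg_file, lr, qa_lr,
# adv_lambda, num_classes, num_epochs, dev_file, freeze, debug) are assumed
# from how they are read above:
#
#   args = parse_args()        # hypothetical argparse helper
#   trainer = RLTrainer(args)
#   trainer.train()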