import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from pytorch_pretrained_bert import BertModel
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
from torch_scatter import scatter_max
def return_mask_lengths(ids):
    # mask: 1 for non-padding positions, 0 for padding; lengths: number of tokens per example
    if ids.dim() == 3:  # 3-dim input means one-hot (or relaxed one-hot) token ids
        mask = torch.sum(ids, dim=2)
    else:
        mask = torch.sign(ids).long()
    lengths = mask.sum(dim=1)
    return mask, lengths
def cal_attn(left, right, mask):
    # positions where mask == 0 get a large negative bias so they receive ~zero attention
    mask = (1.0 - mask.float()) * -10000.0
    attn_logits = torch.matmul(left, right.transpose(-1, -2).contiguous())
    attn_logits = attn_logits + mask
    attn_weights = F.softmax(input=attn_logits, dim=-1)
    attn_outputs = torch.matmul(attn_weights, right)
    return attn_outputs, attn_logits
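# Illustrative shape sketch for cal_attn, assuming a single query vector attends
# over a padded sequence:
#   left  : [batch, 1, d]         (e.g. a projected hidden state)
#   right : [batch, seq_len, d]   (encoder outputs)
#   mask  : [batch, 1, seq_len]   (1 for real tokens, 0 for padding)
#   attn_outputs, attn_logits = cal_attn(left, right, mask)
#   attn_outputs : [batch, 1, d]        weighted sum over `right`
#   attn_logits  : [batch, 1, seq_len]  padded positions are shifted by -10000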
class BertEmbedding(nn.Module):
    def __init__(self, bert_model):
        super(BertEmbedding, self).__init__()
        bert_embedding = BertModel.from_pretrained(bert_model).embeddings
        self.word_embeddings = bert_embedding.word_embeddings
        self.position_embeddings = bert_embedding.position_embeddings
        self.token_type_embeddings = bert_embedding.token_type_embeddings
        self.LayerNorm = bert_embedding.LayerNorm
        self.dropout = bert_embedding.dropout

    def forward(self, input_ids, token_type_ids=None, position_ids=None):
        if input_ids.dim() == 3:
            # one-hot (or relaxed one-hot) inputs are projected through the embedding matrix
            word_embeddings = F.linear(input_ids, self.word_embeddings.weight.transpose(-1, -2).contiguous())
            input_size = input_ids[:, :, 0].size()
        else:
            word_embeddings = self.word_embeddings(input_ids)
            input_size = input_ids.size()
        if position_ids is None:
            seq_length = input_ids.size(1)
            position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
            position_ids = position_ids.unsqueeze(0).expand(input_size)
        if token_type_ids is None:
            token_type_ids = torch.zeros(input_size).to(input_ids.device).long()
        position_embeddings = self.position_embeddings(position_ids)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)
        embeddings = word_embeddings + position_embeddings + token_type_embeddings
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings
def gumbel_softmax(logits, tau=1, hard=False, eps=1e-20, dim=-1):
    # type: (Tensor, float, bool, float, int) -> Tensor
    # gumbels = -(torch.empty_like(logits).exponential_() + eps).log()  # ~Gumbel(0,1)
    gumbels = -(-(torch.rand_like(logits) + eps).log() + eps).log()
    gumbels = (logits + gumbels) / tau  # ~Gumbel(logits, tau)
    y_soft = gumbels.softmax(dim)
    if hard:
        # Straight through.
        index = y_soft.max(dim, keepdim=True)[1]
        y_hard = torch.zeros_like(logits).scatter_(dim, index, 1.0)
        ret = y_hard - y_soft.detach() + y_soft
    else:
        # Reparametrization trick.
        ret = y_soft
    return ret
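# Illustrative sketch of straight-through sampling with gumbel_softmax, assuming
# logits shaped like the latent variables used below ([batch, nz, nzdim]):
#   logits = torch.randn(2, 20, 10)
#   z = gumbel_softmax(logits, tau=1.0, hard=True)
#   z.sum(dim=-1)   # all ones: each latent variable is exactly one-hot
# The forward value is discrete (y_hard), while gradients flow through y_soft
# thanks to the (y_hard - y_soft.detach() + y_soft) trick.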
class CatKLLoss(nn.Module):
    def __init__(self):
        super(CatKLLoss, self).__init__()

    def forward(self, log_qy, log_py):
        # KL(q || p) for categorical distributions, summed over categories and then over latent variables
        qy = torch.exp(log_qy)
        kl = torch.sum(qy * (log_qy - log_py), dim=-1)
        return torch.sum(kl, dim=-1)
class CustomLSTM(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, dropout, bidirectional=False):
        super(CustomLSTM, self).__init__()
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        self.bidirectional = bidirectional
        self.dropout = nn.Dropout(dropout)
        if dropout > 0.0 and num_layers == 1:
            # nn.LSTM only applies dropout between layers, so it has no effect
            # (and triggers a warning) for a single-layer LSTM
            dropout = 0.0
        self.lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size,
                            num_layers=num_layers, dropout=dropout,
                            bidirectional=bidirectional, batch_first=True)

    def forward(self, input, input_lengths, state=None):
        batch_size, total_length, _ = input.size()
        input_packed = pack_padded_sequence(input, input_lengths,
                                            batch_first=True, enforce_sorted=False)
        self.lstm.flatten_parameters()
        output_packed, state = self.lstm(input_packed, state)
        output, _ = pad_packed_sequence(output_packed, batch_first=True, total_length=total_length)
        output = self.dropout(output)
        return output, state
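# Illustrative sketch of CustomLSTM, assuming BERT-base embeddings (emsize = 768)
# and the bidirectional setting used by the encoders below:
#   lstm = CustomLSTM(input_size=768, hidden_size=512, num_layers=1,
#                     dropout=0.2, bidirectional=True)
#   x = torch.randn(4, 30, 768)                # [batch, seq_len, emsize]
#   lengths = torch.tensor([30, 25, 12, 7])
#   out, (h, c) = lstm(x, lengths)
#   out.shape                                  # [4, 30, 1024] -- padded back to seq_len
#   h.shape                                    # [2, 4, 512]   -- (num_layers * 2, batch, hidden)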
class PosteriorEncoder(nn.Module):
    # Encodes (question, context + answer) into logits for the categorical latent variables z.
    def __init__(self, embedding, bert_model, emsize,
                 nhidden, ntokens, nlayers,
                 nz, nzdim,
                 dropout=0, freeze=False):
        super(PosteriorEncoder, self).__init__()
        self.nhidden = nhidden
        self.ntokens = ntokens
        self.nlayers = nlayers
        self.nz = nz
        self.nzdim = nzdim
        if embedding is not None:
            self.embedding = embedding
        else:
            self.embedding = BertEmbedding(bert_model)
        self.question_encoder = CustomLSTM(input_size=emsize,
                                           hidden_size=nhidden,
                                           num_layers=nlayers,
                                           dropout=dropout,
                                           bidirectional=True)
        self.question_linear = nn.Linear(2 * nhidden, 2 * nhidden)
        self.context_answer_encoder = CustomLSTM(input_size=emsize,
                                                 hidden_size=nhidden,
                                                 num_layers=nlayers,
                                                 dropout=dropout,
                                                 bidirectional=True)
        self.context_answer_linear = nn.Linear(2 * nhidden, 2 * nhidden)
        self.posterior_linear = nn.Linear(2 * 4 * nhidden, nz * nzdim)

    def forward(self, c_ids, q_ids, a_ids):
        c_mask, c_lengths = return_mask_lengths(c_ids)
        q_mask, q_lengths = return_mask_lengths(q_ids)
        # question encoding
        q_embeddings = self.embedding(q_ids)
        q_hs, q_state = self.question_encoder(q_embeddings, q_lengths)
        q_h = q_state[0].view(self.nlayers, 2, -1, self.nhidden)[-1]
        q_h = q_h.transpose(0, 1).contiguous().view(-1, 2 * self.nhidden)
        # context + answer encoding (answer tags are fed as token type ids)
        c_a_embeddings = self.embedding(c_ids, a_ids, None)
        c_a_hs, c_a_state = self.context_answer_encoder(c_a_embeddings, c_lengths)
        c_a_h = c_a_state[0].view(self.nlayers, 2, -1, self.nhidden)[-1]
        c_a_h = c_a_h.transpose(0, 1).contiguous().view(-1, 2 * self.nhidden)
        # cross-attention between the two encodings
        mask = q_mask.unsqueeze(1)
        q_attned_by_ca, _ = cal_attn(self.question_linear(c_a_h).unsqueeze(1), q_hs, mask)
        q_attned_by_ca = q_attned_by_ca.squeeze(1)
        mask = c_mask.unsqueeze(1)
        ca_attned_by_q, _ = cal_attn(self.context_answer_linear(q_h).unsqueeze(1), c_a_hs, mask)
        ca_attned_by_q = ca_attned_by_q.squeeze(1)
        h = torch.cat([q_h, q_attned_by_ca, c_a_h, ca_attned_by_q], dim=-1)
        posterior_z_logits = self.posterior_linear(h).view(-1, self.nz, self.nzdim).contiguous()
        posterior_z_prob = F.softmax(posterior_z_logits, dim=-1)
        return posterior_z_logits, posterior_z_prob
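# Shape sketch, assuming nz = 20 and nzdim = 10 (the values used by the trainer below):
# the posterior is a set of 20 independent categorical latent variables with 10 categories each.
#   posterior_z_logits : [batch, 20, 10]
#   posterior_z_prob   : [batch, 20, 10], softmax over the last dimension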
class PriorEncoder(nn.Module):
    def __init__(self, embedding, bert_model, emsize,
                 nhidden, ntokens, nlayers,
                 nz, nzdim,
                 dropout=0):
        super(PriorEncoder, self).__init__()
        self.nhidden = nhidden
        self.ntokens = ntokens
        self.nlayers = nlayers
        self.nz = nz
        self.nzdim = nzdim
        if embedding is not None:
            self.embedding = embedding
        else:
            self.embedding = BertEmbedding(bert_model)
        self.context_encoder = CustomLSTM(input_size=emsize,
                                          hidden_size=nhidden,
                                          num_layers=nlayers,
                                          dropout=dropout,
                                          bidirectional=True)
        self.prior_linear = nn.Linear(2 * nhidden, nz * nzdim)

    def forward(self, c_ids):
        c_mask, c_lengths = return_mask_lengths(c_ids)
        # context encoding
        c_embeddings = self.embedding(c_ids)
        _, c_state = self.context_encoder(c_embeddings, c_lengths)
        c_h = c_state[0].view(self.nlayers, 2, -1, self.nhidden)[-1]
        h = c_h.transpose(0, 1).contiguous().view(-1, 2 * self.nhidden)
        prior_z_logits = self.prior_linear(h).view(-1, self.nz, self.nzdim)
        prior_z_prob = F.softmax(prior_z_logits, dim=-1)
        return prior_z_logits, prior_z_prob
class AnswerDecoder(nn.Module):
    def __init__(self, embedding, bert_model, emsize,
                 nhidden, nlayers,
                 dropout=0):
        super(AnswerDecoder, self).__init__()
        self.nhidden = nhidden = int(0.5 * nhidden)
        self.nlayers = nlayers
        if embedding is not None:
            self.embedding = embedding
        else:
            self.embedding = BertEmbedding(bert_model)
        self.context_lstm = CustomLSTM(input_size=emsize,
                                       hidden_size=nhidden,
                                       num_layers=nlayers,
                                       dropout=dropout,
                                       bidirectional=True)
        self.start_lstm = CustomLSTM(input_size=4 * 2 * nhidden,
                                     hidden_size=nhidden,
                                     num_layers=nlayers,
                                     dropout=dropout,
                                     bidirectional=True)
        self.end_lstm = CustomLSTM(input_size=2 * nhidden,
                                   hidden_size=nhidden,
                                   num_layers=nlayers,
                                   dropout=dropout,
                                   bidirectional=True)
        self.start_linear = nn.Linear(4 * 2 * nhidden + 2 * nhidden, 1)
        self.end_linear = nn.Linear(4 * 2 * nhidden + 2 * nhidden, 1)

    def forward(self, init_state, c_ids):
        batch_size, max_c_len = c_ids.size()
        c_mask, c_lengths = return_mask_lengths(c_ids)
        c_embeddings = self.embedding(c_ids)
        H, _ = self.context_lstm(c_embeddings, c_lengths)
        # broadcast the latent init state over time and fuse it with the context states
        U = init_state.unsqueeze(1).repeat(1, max_c_len, 1)
        G = torch.cat([H, U, H * U, torch.abs(H - U)], dim=-1)
        M1, _ = self.start_lstm(G, c_lengths)
        M2, _ = self.end_lstm(M1, c_lengths)
        start_logits = self.start_linear(torch.cat([G, M1], dim=-1)).squeeze(-1)
        end_logits = self.end_linear(torch.cat([G, M2], dim=-1)).squeeze(-1)
        start_end_mask = c_mask == 0
        masked_start_logits = start_logits.masked_fill(start_end_mask, -10000.0)
        masked_end_logits = end_logits.masked_fill(start_end_mask, -10000.0)
        return masked_start_logits, masked_end_logits
class ContextEncoderforQG(nn.Module):
    def __init__(self, embedding, bert_model, emsize,
                 nhidden, nlayers, dropout=0):
        super(ContextEncoderforQG, self).__init__()
        if embedding is not None:
            self.embedding = embedding
        else:
            self.embedding = BertEmbedding(bert_model)
        self.context_lstm = CustomLSTM(input_size=emsize,
                                       hidden_size=nhidden,
                                       num_layers=nlayers,
                                       dropout=dropout,
                                       bidirectional=True)
        self.context_linear = nn.Linear(2 * nhidden, 2 * nhidden)
        self.fusion = nn.Linear(4 * nhidden, 2 * nhidden, bias=False)
        self.gate = nn.Linear(4 * nhidden, 2 * nhidden, bias=False)

    def forward(self, c_ids, a_ids):
        c_mask, c_lengths = return_mask_lengths(c_ids)
        c_embeddings = self.embedding(c_ids, a_ids, None)
        c_outputs, _ = self.context_lstm(c_embeddings, c_lengths)
        # self-attention over the context, followed by a gated fusion of the two views
        mask = torch.matmul(c_mask.unsqueeze(2).float(), c_mask.unsqueeze(1).float())
        c_attned_by_c, _ = cal_attn(self.context_linear(c_outputs), c_outputs, mask)
        c_concat = torch.cat([c_outputs, c_attned_by_c], dim=2)
        c_fused = self.fusion(c_concat).tanh()
        c_gate = self.gate(c_concat).sigmoid()
        c_outputs = c_gate * c_fused + (1 - c_gate) * c_outputs
        return c_outputs
class QuestionDecoder(nn.Module):
    def __init__(self, sos_id, eos_id,
                 embedding, bert_model, emsize,
                 nhidden, ntokens, nlayers,
                 dropout=0, copy=True, max_q_len=64):
        super(QuestionDecoder, self).__init__()
        self.sos_id = sos_id
        self.eos_id = eos_id
        # this max length includes the sos and eos tokens
        self.max_q_len = max_q_len
        self.nhidden = nhidden
        self.ntokens = ntokens
        self.nlayers = nlayers
        self.copy = copy
        if embedding is not None:
            self.embedding = embedding
        else:
            self.embedding = BertEmbedding(bert_model)
        self.context_lstm = ContextEncoderforQG(embedding, bert_model, emsize,
                                                nhidden, nlayers, dropout)
        self.question_lstm = CustomLSTM(input_size=emsize,
                                        hidden_size=2 * nhidden,
                                        num_layers=nlayers,
                                        dropout=dropout,
                                        bidirectional=False)
        self.question_linear = nn.Linear(2 * nhidden, 2 * nhidden)
        self.concat_linear = nn.Linear(4 * nhidden, 2 * nhidden)
        self.logit_linear = nn.Linear(2 * nhidden, ntokens)

    def forward(self, init_state, c_ids, q_ids, a_ids):
        c_mask, c_lengths = return_mask_lengths(c_ids)
        q_mask, q_lengths = return_mask_lengths(q_ids)
        c_outputs = self.context_lstm(c_ids, a_ids)
        batch_size, max_q_len = q_ids.size()
        # question decoding (teacher forcing)
        q_embeddings = self.embedding(q_ids, None, torch.zeros_like(q_ids))
        q_outputs, _ = self.question_lstm(q_embeddings, q_lengths, init_state)
        # attention over the context
        mask = torch.matmul(q_mask.unsqueeze(2).float(), c_mask.unsqueeze(1).float())
        c_attned_by_q, attn_logits = cal_attn(self.question_linear(q_outputs), c_outputs, mask)
        # generation logits
        q_concat = self.concat_linear(torch.cat([q_outputs, c_attned_by_q], dim=2)).tanh()
        logits = self.logit_linear(q_concat)
        if self.copy:
            # copy logits: scatter_max keeps, for each vocabulary id that appears in the
            # context, the maximum attention logit over its occurrences
            bq = batch_size * max_q_len
            c_ids = c_ids.unsqueeze(1).repeat(1, max_q_len, 1).view(bq, -1).contiguous()
            attn_logits = attn_logits.view(bq, -1).contiguous()
            out = torch.zeros(bq, self.ntokens).to(c_ids.device)
            out = out - 10000.0
            out, _ = scatter_max(attn_logits, c_ids, out=out)
            out = out.masked_fill(out == -10000.0, 0)
            out = out.view(batch_size, max_q_len, -1).contiguous()
            logits = logits + out
        return logits
    def generate(self, init_state, c_ids, a_ids):
        c_mask, c_lengths = return_mask_lengths(c_ids)
        c_outputs = self.context_lstm(c_ids, a_ids)
        batch_size = c_ids.size(0)
        start_symbols = torch.LongTensor([self.sos_id] * batch_size).unsqueeze(1)
        start_symbols = start_symbols.to(c_ids.device)
        position_ids = torch.zeros_like(start_symbols)
        q_embeddings = self.embedding(start_symbols, None, position_ids)
        state = init_state
        # greedy unrolling, one token at a time
        all_indices = []
        all_indices.append(start_symbols)
        for _ in range(self.max_q_len - 1):
            q_outputs, state = self.question_lstm.lstm(q_embeddings, state)
            # attention over the context
            mask = c_mask.unsqueeze(1).float()
            c_attned_by_q, attn_logits = cal_attn(self.question_linear(q_outputs), c_outputs, mask)
            # generation logits
            q_concat = self.concat_linear(torch.cat([q_outputs, c_attned_by_q], dim=2)).tanh()
            logits = self.logit_linear(q_concat)
            if self.copy:
                # copy logits
                attn_logits = attn_logits.squeeze(1)
                out = torch.zeros(batch_size, self.ntokens).to(c_ids.device)
                out = out - 10000.0
                out, _ = scatter_max(attn_logits, c_ids, out=out)
                out = out.masked_fill(out == -10000.0, 0)
                logits = logits + out.unsqueeze(1)
            indices = torch.argmax(logits, 2)
            all_indices.append(indices)
            q_embeddings = self.embedding(indices, None, position_ids)
        q_ids = torch.cat(all_indices, 1)
        # truncate each question at its first eos token; sequences without an eos keep the full length
        eos_mask = q_ids == self.eos_id
        no_eos_idx_sum = (eos_mask.sum(dim=1) == 0).long() * (self.max_q_len - 1)
        eos_mask = eos_mask.cpu().numpy()
        q_lengths = np.argmax(eos_mask, axis=1) + 1
        q_lengths = torch.tensor(q_lengths).to(q_ids.device).long() + no_eos_idx_sum
        batch_size, max_len = q_ids.size()
        idxes = torch.arange(0, max_len, out=torch.LongTensor(max_len))
        idxes = idxes.unsqueeze(0).to(q_ids.device).repeat(batch_size, 1)
        q_mask = (idxes < q_lengths.unsqueeze(1))
        q_ids = q_ids.long() * q_mask.long()
        return q_ids
class DiscreteVAE(nn.Module):
    def __init__(self, padding_idx, sos_id, eos_id,
                 bert_model,
                 nhidden, ntokens, nlayers,
                 nz, nzdim, freeze=False,
                 dropout=0, copy=True, max_q_len=64):
        super(DiscreteVAE, self).__init__()
        self.nhidden = nhidden
        if "large" in bert_model:
            emsize = 1024
        else:
            emsize = 768
        self.emsize = emsize
        self.ntokens = ntokens
        self.nlayers = nlayers
        self.nz = nz
        self.nzdim = nzdim
        embedding = BertEmbedding(bert_model)
        if freeze:
            print("freeze bert embedding")
            for param in embedding.parameters():
                param.requires_grad = False
        self.posterior_encoder = PosteriorEncoder(embedding, bert_model, emsize,
                                                  nhidden, ntokens, nlayers, nz, nzdim, dropout)
        self.prior_encoder = PriorEncoder(embedding, bert_model, emsize,
                                          nhidden, ntokens, nlayers, nz, nzdim, dropout)
        self.answer_decoder = AnswerDecoder(embedding, bert_model, emsize,
                                            nhidden, nlayers, dropout)
        self.question_decoder = QuestionDecoder(sos_id, eos_id,
                                                embedding, bert_model, emsize,
                                                nhidden, ntokens, nlayers, dropout,
                                                copy, max_q_len)
        self.q_h_linear = nn.Linear(nz * nzdim, 2 * nlayers * nhidden, False)
        self.q_c_linear = nn.Linear(nz * nzdim, 2 * nlayers * nhidden, False)
        self.a_h_linear = nn.Linear(nz * nzdim, nhidden, False)
        self.q_rec_criterion = nn.CrossEntropyLoss(ignore_index=padding_idx)
        self.kl_criterion = CatKLLoss()

    def return_init_state(self, z_flatten):
        batch_size = z_flatten.size(0)
        q_init_h = self.q_h_linear(z_flatten)
        q_init_c = self.q_c_linear(z_flatten)
        q_init_h = q_init_h.view(batch_size, self.nlayers, 2 * self.nhidden).transpose(0, 1).contiguous()
        q_init_c = q_init_c.view(batch_size, self.nlayers, 2 * self.nhidden).transpose(0, 1).contiguous()
        q_init_state = (q_init_h, q_init_c)
        a_init_h = self.a_h_linear(z_flatten)
        a_init_state = a_init_h
        return q_init_state, a_init_state
    def forward(self, c_ids, q_ids, a_ids, start_positions, end_positions, tau=1.0):
        max_c_len = c_ids.size(1)
        posterior_z_logits, posterior_z_prob = self.posterior_encoder(c_ids, q_ids, a_ids)
        posterior_z = gumbel_softmax(posterior_z_logits, tau=tau, hard=True)
        posterior_z_flatten = posterior_z.view(-1, self.nz * self.nzdim)
        prior_z_logits, prior_z_prob = self.prior_encoder(c_ids)
        q_init_state, a_init_state = self.return_init_state(posterior_z_flatten)
        start_logits, end_logits = self.answer_decoder(a_init_state, c_ids)
        q_logits = self.question_decoder(q_init_state, c_ids, q_ids, a_ids)
        # question reconstruction loss
        loss_q_rec = self.q_rec_criterion(q_logits[:, :-1, :].transpose(1, 2).contiguous(),
                                          q_ids[:, 1:])
        # answer reconstruction loss
        a_rec_criterion = nn.CrossEntropyLoss(ignore_index=max_c_len)
        start_positions.clamp_(0, max_c_len)
        end_positions.clamp_(0, max_c_len)
        loss_start_a_rec = a_rec_criterion(start_logits, start_positions)
        loss_end_a_rec = a_rec_criterion(end_logits, end_positions)
        loss_a_rec = (loss_start_a_rec + loss_end_a_rec) / 2
        # kl loss
        posterior_avg_z_prob = posterior_z_prob.mean(dim=0)
        loss_kl = self.kl_criterion(posterior_avg_z_prob.log(), prior_z_prob.log()).mean(dim=0)
        loss = loss_q_rec + loss_a_rec + loss_kl
        return loss, loss_q_rec, loss_a_rec, loss_kl
    def recon_ans(self, c_ids, q_ids, a_ids):
        posterior_z_logits, posterior_z_prob = self.posterior_encoder(c_ids, q_ids, a_ids)
        posterior_z = gumbel_softmax(posterior_z_logits, hard=True)
        posterior_z_flatten = posterior_z.view(-1, self.nz * self.nzdim)
        q_init_state, a_init_state = self.return_init_state(posterior_z_flatten)
        start_logits, end_logits = self.answer_decoder(a_init_state, c_ids)
        return start_logits, end_logits
    def generate(self, z_logits, c_ids):
        batch_size, max_c_len = c_ids.size()
        c_mask, _ = return_mask_lengths(c_ids)
        z = gumbel_softmax(z_logits, hard=True)
        z_flatten = z.view(-1, self.nz * self.nzdim)
        q_init_state, a_init_state = self.return_init_state(z_flatten)
        # decode the answer span first
        start_logits, end_logits = self.answer_decoder(a_init_state, c_ids)
        # joint start/end scoring; the upper-triangular mask keeps only spans with start <= end
        mask = torch.matmul(c_mask.unsqueeze(2).float(), c_mask.unsqueeze(1).float())
        mask = torch.triu(mask) == 0
        score = (F.log_softmax(start_logits, dim=1).unsqueeze(2)
                 + F.log_softmax(end_logits, dim=1).unsqueeze(1))
        score = score.masked_fill(mask, -10000.0)
        score, start_positions = score.max(dim=1)
        score, end_positions = score.max(dim=1)
        start_positions = torch.gather(start_positions, 1, end_positions.view(-1, 1)).squeeze(1)
        # turn the selected span into answer tag ids (1 inside the span, 0 outside)
        idxes = torch.arange(0, max_c_len, out=torch.LongTensor(max_c_len))
        idxes = idxes.unsqueeze(0).to(start_logits.device).repeat(batch_size, 1)
        start_positions = start_positions.unsqueeze(1)
        start_mask = (idxes >= start_positions).long()
        end_positions = end_positions.unsqueeze(1)
        end_mask = (idxes <= end_positions).long()
        generated_a_ids = start_mask + end_mask - 1
        # then decode the question conditioned on the context and the generated answer
        q_ids = self.question_decoder.generate(q_init_state, c_ids, generated_a_ids)
        return q_ids, start_positions.squeeze(1), end_positions.squeeze(1), start_logits, end_logits
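A minimal usage sketch for the model defined above, assuming pytorch_pretrained_bert and torch_scatter are installed, that this file is saved as new_model.py (the trainer below imports it under that name), and that the hyperparameters mirror those used by the trainer; the dummy id tensors are illustrative placeholders, not real SQuAD batches.

import torch
from pytorch_pretrained_bert import BertTokenizer
from new_model import DiscreteVAE

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = DiscreteVAE(padding_idx=0,
                    sos_id=tokenizer.vocab["[CLS]"],
                    eos_id=tokenizer.vocab["[SEP]"],
                    bert_model="bert-base-uncased",
                    nhidden=512, ntokens=len(tokenizer.vocab), nlayers=1,
                    nz=20, nzdim=10, dropout=0.2, copy=True)

# dummy padded batches: [batch, seq_len] token ids, with 0 reserved for padding
c_ids = torch.randint(1, 1000, (2, 50))
q_ids = torch.randint(1, 1000, (2, 20))
a_ids = torch.randint(0, 2, (2, 50))          # 1 inside the answer span, 0 elsewhere
start_positions = torch.tensor([3, 10])
end_positions = torch.tensor([5, 12])

loss, loss_q_rec, loss_a_rec, loss_kl = model(c_ids, q_ids, a_ids,
                                              start_positions, end_positions)
loss.backward()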
import collections
import json
import os
import time

import numpy as np
import torch
import torch.nn as nn
from torch import optim

from new_model import DiscreteVAE
# NOTE: the base class CatTrainer and the helpers used below (cal_running_avg_loss,
# progress_bar, eta, user_friendly_time, time_since, write_predictions, evaluate)
# are assumed to come from this project's own utility modules, which are not part of the gist.
class AttnTrainer(CatTrainer):
    def __init__(self, args):
        super(AttnTrainer, self).__init__(args)

    def init_model(self, args):
        sos_id = self.tokenizer.vocab["[CLS]"]
        eos_id = self.tokenizer.vocab["[SEP]"]
        model = DiscreteVAE(padding_idx=0,
                            sos_id=sos_id,
                            eos_id=eos_id,
                            bert_model="bert-base-uncased",
                            ntokens=len(self.tokenizer.vocab),
                            nhidden=512,
                            nlayers=1,
                            dropout=0.2,
                            nz=20,
                            nzdim=10,
                            freeze=self.args.freeze,
                            copy=True)
        model = model.to(self.device)
        return model

    def get_opt(self):
        # in case of using a pre-trained vae
        if self.args.save_file is not None:
            params = [param for name, param in self.model.named_parameters() if "vae_net" not in name]
        else:
            params = self.model.parameters()
        opt = optim.Adam(params, self.args.lr)
        return opt
    def process_batch(self, batch):
        # move tensors to the device and trim padding to the longest sequence in the batch
        batch = tuple(t.to(self.device) for t in batch)
        q_ids, c_ids, tag_ids, ans_ids, start_positions, end_positions = batch
        q_len = torch.sum(torch.sign(q_ids), 1)
        max_len = torch.max(q_len)
        q_ids = q_ids[:, :max_len]
        c_len = torch.sum(torch.sign(c_ids), 1)
        max_len = torch.max(c_len)
        c_ids = c_ids[:, :max_len]
        tag_ids = tag_ids[:, :max_len]
        a_len = torch.sum(torch.sign(ans_ids), 1)
        max_len = torch.max(a_len)
        ans_ids = ans_ids[:, :max_len]
        return q_ids, c_ids, tag_ids, ans_ids, start_positions, end_positions
    def train(self):
        batch_num = len(self.train_loader)
        avg_q_rec = 0
        avg_a_rec = 0
        avg_kl = 0
        global_step = 1
        best_f1 = 0
        for epoch in range(1, self.args.num_epochs + 1):
            start = time.time()
            self.model.train()
            for step, batch in enumerate(self.train_loader, start=1):
                # allocate tensors to device
                q_ids, c_ids, tag_ids, _, start_positions, end_positions = self.process_batch(batch)
                # forward pass
                ans_ids = (tag_ids != 0).long()
                loss, q_rec, a_rec, kl = self.model(c_ids, q_ids, ans_ids,
                                                    start_positions, end_positions)
                self.opt.zero_grad()
                loss.backward()
                nn.utils.clip_grad_norm_(self.model.parameters(), 5.0)
                self.opt.step()
                global_step += 1
                avg_q_rec = cal_running_avg_loss(q_rec.item(), avg_q_rec)
                avg_a_rec = cal_running_avg_loss(a_rec.item(), avg_a_rec)
                avg_kl = cal_running_avg_loss(kl.item(), avg_kl)
                msg = "{}/{} {} - ETA : {} - Q recon: {:.4f}, A recon: {:.4f}, kl: {:.4f}" \
                    .format(step, batch_num, progress_bar(step, batch_num),
                            eta(start, step, batch_num), avg_q_rec, avg_a_rec, avg_kl)
                print(msg, end="\r")
            if not self.args.debug:
                eval_dict = self.eval(msg)
                f1 = eval_dict["f1"]
                em = eval_dict["exact_match"]
                print("Epoch {} took {} - final Q-rec : {:.3f}, final A-rec: {:.3f}, "
                      "F1 : {:.2f}, EM: {:.2f} "
                      .format(epoch, user_friendly_time(time_since(start)),
                              avg_q_rec, avg_a_rec, f1, em))
                if f1 > best_f1:
                    best_f1 = f1
                    self.save_model_kl(epoch, f1, em)
    @staticmethod
    def get_seq_len(input_ids, eos_id):
        # input_ids: [b, t]
        # eos_id: scalar
        mask = (input_ids == eos_id).byte()
        num_eos = torch.sum(mask, 1)
        # move the tensor to cpu because torch.argmax behaves differently on cuda and cpu,
        # while np.argmax consistently returns the first index of the maximum element
        mask = mask.cpu().numpy()
        indices = np.argmax(mask, 1)
        # convert the numpy array back to a tensor
        seq_len = torch.LongTensor(indices).to(input_ids.device)
        # in case there is no eos in the sequence
        max_len = input_ids.size(1)
        seq_len = seq_len.masked_fill(num_eos == 0, max_len - 1)
        # +1 for eos
        seq_len = seq_len + 1
        return seq_len
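    # Illustrative sketch for get_seq_len, assuming eos_id = 102 ([SEP] in BERT's
    # uncased vocabulary):
    #   input_ids = torch.tensor([[5, 7, 102, 0],
    #                             [5, 7, 9, 11]])
    #   AttnTrainer.get_seq_len(input_ids, 102)   # -> tensor([3, 4])
    # The first row is cut at its eos token (index 2, so length 3); the second row
    # has no eos, so the length falls back to the full width of the tensor.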
    def save_model_kl(self, epoch, nll, kl):
        # note: train() calls this with (epoch, f1, em), so the checkpoint name encodes epoch, F1 and EM
        nll = round(nll, 2)
        kl = round(kl, 2)
        save_file = os.path.join(self.save_dir, "{}_{:.2f}_{:.2f}".format(epoch, nll, kl))
        state_dict = self.model.state_dict()
        torch.save(state_dict, save_file)
    def eval(self, msg):
        num_val_batches = len(self.dev_loader)
        all_results = []
        RawResult = collections.namedtuple("RawResult",
                                           ["unique_id", "start_logits", "end_logits"])
        example_index = -1
        self.model.eval()
        for i, batch in enumerate(self.dev_loader, start=1):
            q_ids, c_ids, tag_ids, _, _, _ = self.process_batch(batch)
            ans_ids = (tag_ids != 0).long()
            with torch.no_grad():
                batch_start_logits, batch_end_logits = self.model.recon_ans(c_ids, q_ids, ans_ids)
            batch_size = batch_start_logits.size(0)
            for j in range(batch_size):
                example_index += 1
                start_logits = batch_start_logits[j].detach().cpu().tolist()
                end_logits = batch_end_logits[j].detach().cpu().tolist()
                eval_feature = self.eval_features[example_index]
                unique_id = int(eval_feature.unique_id)
                all_results.append(RawResult(unique_id=unique_id,
                                             start_logits=start_logits,
                                             end_logits=end_logits))
            msg2 = "{} => Evaluating :{}/{}".format(msg, i, num_val_batches)
            print(msg2, end="\r")
        output_prediction_file = os.path.join(self.save_dir, "recon_pred.json")
        write_predictions(self.eval_examples, self.eval_features, all_results,
                          n_best_size=20, max_answer_length=30, do_lower_case=True,
                          output_prediction_file=output_prediction_file,
                          verbose_logging=False,
                          version_2_with_negative=False,
                          null_score_diff_threshold=0,
                          noq_position=True)
        with open(self.args.dev_file) as f:
            data_json = json.load(f)
            dataset = data_json["data"]
        with open(output_prediction_file) as prediction_file:
            predictions = json.load(prediction_file)
        results = evaluate(dataset, predictions)
        return results
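A rough sketch of how this trainer might be launched, assuming the args namespace carries the fields referenced above (freeze, lr, num_epochs, debug, save_file, dev_file) and that the CatTrainer base class (not shown in the gist) sets up the tokenizer, the data loaders, and save_dir; the concrete values here are placeholders, not the author's settings.

from argparse import Namespace

args = Namespace(freeze=False, lr=1e-3, num_epochs=10, debug=False,
                 save_file=None, dev_file="dev-v1.1.json")   # placeholder values

trainer = AttnTrainer(args)
trainer.train()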