import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
from torch_scatter import scatter_max
from pytorch_transformers import BertModel, BertTokenizer


def return_mask_lengths(ids):
    mask = torch.sign(ids).float()
    lengths = mask.sum(dim=1).long()
    return mask, lengths


def return_num(model):
    model_parameters = filter(lambda p: p.requires_grad, model.parameters())
    params = sum([np.prod(p.size()) for p in model_parameters])
    return params


def cal_attn(left, right, mask):
    mask = (1.0 - mask.float()) * -10000.0
    attn_logits = torch.matmul(left, right.transpose(-1, -2).contiguous())
    attn_logits = attn_logits + mask
    attn_weights = F.softmax(input=attn_logits, dim=-1)
    attn_outputs = torch.matmul(attn_weights, right)
    return attn_outputs, attn_logits


def gumbel_softmax(logits, tau=1, hard=False, eps=1e-20, dim=-1):
    # type: (Tensor, float, bool, float, int) -> Tensor
    gumbels = -(torch.empty_like(logits).exponential_() + eps).log()  # ~Gumbel(0,1)
    gumbels = (logits + gumbels) / tau  # ~Gumbel(logits,tau)
    y_soft = gumbels.softmax(dim)

    if hard:
        # Straight through.
        index = y_soft.max(dim, keepdim=True)[1]
        y_hard = torch.zeros_like(logits).scatter_(dim, index, 1.0)
        ret = y_hard - y_soft.detach() + y_soft
    else:
        # Re-parametrization trick.
        ret = y_soft
    return ret
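

# A minimal sanity check sketch for the straight-through Gumbel-Softmax above
# (toy shapes only; call manually): with hard=True the forward value is a
# one-hot sample, while gradients still flow through the soft relaxation.
def _demo_gumbel_softmax():
    logits = torch.randn(4, 10, requires_grad=True)
    za = gumbel_softmax(logits, tau=1.0, hard=True)
    assert torch.allclose(za.sum(dim=-1), torch.ones(4))  # each row is one-hot
    za.sum().backward()                                   # gradients reach `logits`
    assert logits.grad is not None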


class CategoricalKLLoss(nn.Module):
    def __init__(self):
        super(CategoricalKLLoss, self).__init__()

    def forward(self, P, Q):
        log_P = P.log()
        log_Q = Q.log()
        kl = (P * (log_P - log_Q)).sum(dim=-1).sum(dim=-1)
        return kl.mean(dim=0)


class GaussianKLLoss(nn.Module):
    def __init__(self):
        super(GaussianKLLoss, self).__init__()

    def forward(self, mu1, logvar1, mu2, logvar2):
        numerator = logvar1.exp() + torch.pow(mu1 - mu2, 2)
        fraction = torch.div(numerator, (logvar2.exp()))
        kl = 0.5 * torch.sum(logvar2 - logvar1 + fraction - 1, dim=1)
        return kl.mean(dim=0)
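

# A quick cross-check sketch for GaussianKLLoss against torch.distributions,
# under the convention used above: inputs are means and log-variances of
# diagonal Gaussians, KL(N1 || N2) is summed over dimensions and averaged over
# the batch. Toy shapes only; call manually.
def _demo_gaussian_kl():
    from torch.distributions import Normal, kl_divergence
    mu1, logvar1 = torch.randn(3, 5), torch.randn(3, 5)
    mu2, logvar2 = torch.randn(3, 5), torch.randn(3, 5)
    ours = GaussianKLLoss()(mu1, logvar1, mu2, logvar2)
    ref = kl_divergence(Normal(mu1, (0.5 * logvar1).exp()),
                        Normal(mu2, (0.5 * logvar2).exp())).sum(dim=1).mean()
    assert torch.allclose(ours, ref, atol=1e-5)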


class Embedding(nn.Module):
    def __init__(self, bert_model):
        super(Embedding, self).__init__()
        bert_embeddings = BertModel.from_pretrained(bert_model).embeddings
        self.word_embeddings = bert_embeddings.word_embeddings
        self.token_type_embeddings = bert_embeddings.token_type_embeddings
        self.position_embeddings = bert_embeddings.position_embeddings
        self.LayerNorm = bert_embeddings.LayerNorm
        self.dropout = bert_embeddings.dropout

    def forward(self, input_ids, token_type_ids=None, position_ids=None):
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)
        if position_ids is None:
            seq_length = input_ids.size(1)
            position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
            position_ids = position_ids.unsqueeze(0).expand_as(input_ids)

        words_embeddings = self.word_embeddings(input_ids)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)
        position_embeddings = self.position_embeddings(position_ids)

        embeddings = words_embeddings + token_type_embeddings + position_embeddings
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings


class ContextualizedEmbedding(nn.Module):
    def __init__(self, bert_model):
        super(ContextualizedEmbedding, self).__init__()
        bert = BertModel.from_pretrained(bert_model)
        self.embedding = bert.embeddings
        self.encoder = bert.encoder
        self.num_hidden_layers = bert.config.num_hidden_layers

    def forward(self, input_ids, attention_mask, token_type_ids=None):
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)

        seq_length = input_ids.size(1)
        position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
        position_ids = position_ids.unsqueeze(0).expand_as(input_ids)

        extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2).float()
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
        head_mask = [None] * self.num_hidden_layers

        embedding_output = self.embedding(input_ids, position_ids=position_ids, token_type_ids=token_type_ids)
        encoder_outputs = self.encoder(embedding_output,
                                       extended_attention_mask,
                                       head_mask=head_mask)
        sequence_output = encoder_outputs[0]
        return sequence_output


class CustomLSTM(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, dropout, bidirectional=False):
        super(CustomLSTM, self).__init__()
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        self.bidirectional = bidirectional
        self.dropout = nn.Dropout(dropout)
        if dropout > 0.0 and num_layers == 1:
            # nn.LSTM only applies dropout between stacked layers, so disable it
            # for a single-layer LSTM; output dropout is applied separately below.
            dropout = 0.0
        self.lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size,
                            num_layers=num_layers, dropout=dropout,
                            bidirectional=bidirectional, batch_first=True)

    def forward(self, input, input_lengths, state=None):
        batch_size, total_length, _ = input.size()
        input_packed = pack_padded_sequence(input, input_lengths,
                                            batch_first=True, enforce_sorted=False)
        self.lstm.flatten_parameters()
        output_packed, state = self.lstm(input_packed, state)
        output = pad_packed_sequence(output_packed, batch_first=True, total_length=total_length)[0]
        output = self.dropout(output)
        return output, state
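

# A small usage sketch for CustomLSTM with variable-length inputs (toy sizes
# chosen only for illustration; call manually): sequences are packed so padded
# steps are ignored, and the output is padded back to the original total length.
def _demo_custom_lstm():
    lstm = CustomLSTM(input_size=8, hidden_size=16, num_layers=1,
                      dropout=0.0, bidirectional=True)
    x = torch.randn(2, 5, 8)            # (batch, seq_len, input_size)
    lengths = torch.tensor([5, 3])      # second sequence has 2 padded steps
    output, (h, c) = lstm(x, lengths)
    assert output.size() == (2, 5, 32)  # 2 * hidden_size for a bi-LSTM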


class PosteriorEncoder(nn.Module):
    def __init__(self, embedding, emsize,
                 nhidden, nlayers,
                 nzqdim, nza, nzadim,
                 dropout=0.0):
        super(PosteriorEncoder, self).__init__()
        self.embedding = embedding
        self.nhidden = nhidden
        self.nlayers = nlayers
        self.nzqdim = nzqdim
        self.nza = nza
        self.nzadim = nzadim

        self.question_encoder = CustomLSTM(input_size=emsize,
                                           hidden_size=nhidden,
                                           num_layers=nlayers,
                                           dropout=dropout,
                                           bidirectional=True)
        self.context_encoder = CustomLSTM(input_size=emsize,
                                          hidden_size=nhidden,
                                          num_layers=nlayers,
                                          dropout=dropout,
                                          bidirectional=True)
        self.context_answer_encoder = CustomLSTM(input_size=emsize,
                                                 hidden_size=nhidden,
                                                 num_layers=nlayers,
                                                 dropout=dropout,
                                                 bidirectional=True)
        self.question_attention = nn.Linear(2 * nhidden, 2 * nhidden)
        self.context_attention = nn.Linear(2 * nhidden, 2 * nhidden)
        self.zq_attention = nn.Linear(nzqdim, 2 * nhidden)
        self.zq_linear = nn.Linear(4 * 2 * nhidden, 2 * nzqdim)
        self.za_linear = nn.Linear(nzqdim + 2 * 2 * nhidden, nza * nzadim)

    def forward(self, c_ids, q_ids, a_ids):
        c_mask, c_lengths = return_mask_lengths(c_ids)
        q_mask, q_lengths = return_mask_lengths(q_ids)

        # question encoding
        q_embeddings = self.embedding(q_ids)
        q_hs, q_state = self.question_encoder(q_embeddings, q_lengths)
        q_h = q_state[0].view(self.nlayers, 2, -1, self.nhidden)[-1]
        q_h = q_h.transpose(0, 1).contiguous().view(-1, 2 * self.nhidden)

        # context encoding
        c_embeddings = self.embedding(c_ids)
        c_hs, c_state = self.context_encoder(c_embeddings, c_lengths)
        c_h = c_state[0].view(self.nlayers, 2, -1, self.nhidden)[-1]
        c_h = c_h.transpose(0, 1).contiguous().view(-1, 2 * self.nhidden)

        # context + answer encoding (answer tags act as token-type ids)
        c_a_embeddings = self.embedding(c_ids, a_ids, None)
        c_a_hs, c_a_state = self.context_answer_encoder(c_a_embeddings, c_lengths)
        c_a_h = c_a_state[0].view(self.nlayers, 2, -1, self.nhidden)[-1]
        c_a_h = c_a_h.transpose(0, 1).contiguous().view(-1, 2 * self.nhidden)

        # attention: question attends over context
        mask = c_mask.unsqueeze(1)
        c_attned_by_q, _ = cal_attn(self.question_attention(q_h).unsqueeze(1), c_hs, mask)
        c_attned_by_q = c_attned_by_q.squeeze(1)

        # attention: context attends over question
        mask = q_mask.unsqueeze(1)
        q_attned_by_c, _ = cal_attn(self.context_attention(c_h).unsqueeze(1), q_hs, mask)
        q_attned_by_c = q_attned_by_c.squeeze(1)

        h = torch.cat([q_h, q_attned_by_c, c_h, c_attned_by_q], dim=-1)
        zq_mu, zq_logvar = torch.split(self.zq_linear(h), self.nzqdim, dim=1)
        zq = zq_mu + torch.randn_like(zq_mu) * torch.exp(0.5 * zq_logvar)

        # attention: zq attends over context + answer
        mask = c_mask.unsqueeze(1)
        c_a_attned_by_zq, _ = cal_attn(self.zq_attention(zq).unsqueeze(1), c_a_hs, mask)
        c_a_attned_by_zq = c_a_attned_by_zq.squeeze(1)

        h = torch.cat([zq, c_a_attned_by_zq, c_a_h], dim=-1)
        za_logits = self.za_linear(h).view(-1, self.nza, self.nzadim)
        za_prob = F.softmax(za_logits, dim=-1)
        za = gumbel_softmax(za_logits, hard=True)

        return zq_mu, zq_logvar, zq, za_prob, za


class PriorEncoder(nn.Module):
    def __init__(self, embedding, emsize,
                 nhidden, nlayers,
                 nzqdim, nza, nzadim,
                 dropout=0.0):
        super(PriorEncoder, self).__init__()
        self.embedding = embedding
        self.nhidden = nhidden
        self.nlayers = nlayers
        self.nzqdim = nzqdim
        self.nza = nza
        self.nzadim = nzadim

        self.context_encoder = CustomLSTM(input_size=emsize,
                                          hidden_size=nhidden,
                                          num_layers=nlayers,
                                          dropout=dropout,
                                          bidirectional=True)
        self.zq_attention = nn.Linear(nzqdim, 2 * nhidden)
        self.zq_linear = nn.Linear(2 * nhidden, 2 * nzqdim)
        self.za_linear = nn.Linear(nzqdim + 2 * 2 * nhidden, nza * nzadim)

    def forward(self, c_ids):
        c_mask, c_lengths = return_mask_lengths(c_ids)

        c_embeddings = self.embedding(c_ids)
        c_hs, c_state = self.context_encoder(c_embeddings, c_lengths)
        c_h = c_state[0].view(self.nlayers, 2, -1, self.nhidden)[-1]
        c_h = c_h.transpose(0, 1).contiguous().view(-1, 2 * self.nhidden)

        zq_mu, zq_logvar = torch.split(self.zq_linear(c_h), self.nzqdim, dim=1)
        zq = zq_mu + torch.randn_like(zq_mu) * torch.exp(0.5 * zq_logvar)

        mask = c_mask.unsqueeze(1)
        c_attned_by_zq, _ = cal_attn(self.zq_attention(zq).unsqueeze(1), c_hs, mask)
        c_attned_by_zq = c_attned_by_zq.squeeze(1)

        h = torch.cat([zq, c_attned_by_zq, c_h], dim=-1)
        za_logits = self.za_linear(h).view(-1, self.nza, self.nzadim)
        za_prob = F.softmax(za_logits, dim=-1)
        za = gumbel_softmax(za_logits, hard=True)

        return zq_mu, zq_logvar, zq, za_prob, za

    def interpolation(self, c_ids, zq):
        c_mask, c_lengths = return_mask_lengths(c_ids)

        c_embeddings = self.embedding(c_ids)
        c_hs, c_state = self.context_encoder(c_embeddings, c_lengths)
        c_h = c_state[0].view(self.nlayers, 2, -1, self.nhidden)[-1]
        c_h = c_h.transpose(0, 1).contiguous().view(-1, 2 * self.nhidden)

        mask = c_mask.unsqueeze(1)
        c_attned_by_zq, _ = cal_attn(self.zq_attention(zq).unsqueeze(1), c_hs, mask)
        c_attned_by_zq = c_attned_by_zq.squeeze(1)

        h = torch.cat([zq, c_attned_by_zq, c_h], dim=-1)
        za_logits = self.za_linear(h).view(-1, self.nza, self.nzadim)
        za_prob = F.softmax(za_logits, dim=-1)
        za = gumbel_softmax(za_logits, hard=True)
        return za


class AnswerDecoder(nn.Module):
    def __init__(self, embedding, emsize,
                 nhidden, nlayers,
                 dropout=0.0):
        super(AnswerDecoder, self).__init__()
        self.embedding = embedding
        self.context_lstm = CustomLSTM(input_size=4 * emsize,
                                       hidden_size=nhidden,
                                       num_layers=nlayers,
                                       dropout=dropout,
                                       bidirectional=True)
        self.start_linear = nn.Linear(2 * nhidden, 1)
        self.end_linear = nn.Linear(2 * nhidden, 1)
        self.ls = nn.LogSoftmax(dim=1)

    def forward(self, init_state, c_ids):
        batch_size, max_c_len = c_ids.size()
        c_mask, c_lengths = return_mask_lengths(c_ids)

        # fuse the contextual embedding H with the latent init state U
        H = self.embedding(c_ids, c_mask)
        U = init_state.unsqueeze(1).repeat(1, max_c_len, 1)
        G = torch.cat([H, U, H * U, torch.abs(H - U)], dim=-1)
        M, _ = self.context_lstm(G, c_lengths)

        start_logits = self.start_linear(M).squeeze(-1)
        end_logits = self.end_linear(M).squeeze(-1)

        start_end_mask = (c_mask == 0)
        masked_start_logits = start_logits.masked_fill(start_end_mask, -10000.0)
        masked_end_logits = end_logits.masked_fill(start_end_mask, -10000.0)
        return masked_start_logits, masked_end_logits

    def generate(self, init_state, c_ids):
        start_logits, end_logits = self.forward(init_state, c_ids)
        c_mask, _ = return_mask_lengths(c_ids)
        batch_size, max_c_len = c_ids.size()

        # joint start/end scores; the upper-triangular mask enforces start <= end
        mask = torch.matmul(c_mask.unsqueeze(2).float(), c_mask.unsqueeze(1).float())
        mask = torch.triu(mask) == 0
        score = (self.ls(start_logits).unsqueeze(2) + self.ls(end_logits).unsqueeze(1))
        score = score.masked_fill(mask, -10000.0)
        score, start_positions = score.max(dim=1)
        score, end_positions = score.max(dim=1)
        start_positions = torch.gather(start_positions, 1, end_positions.view(-1, 1)).squeeze(1)

        # convert the (start, end) span into a 0/1 answer-tag sequence over the context
        idxes = torch.arange(max_c_len, dtype=torch.long, device=start_logits.device)
        idxes = idxes.unsqueeze(0).repeat(batch_size, 1)
        start_positions = start_positions.unsqueeze(1)
        start_mask = (idxes >= start_positions).long()
        end_positions = end_positions.unsqueeze(1)
        end_mask = (idxes <= end_positions).long()
        a_ids = start_mask + end_mask - 1

        return a_ids, start_positions.squeeze(1), end_positions.squeeze(1)


class ContextEncoderforQG(nn.Module):
    def __init__(self, embedding, emsize,
                 nhidden, nlayers,
                 dropout=0.0):
        super(ContextEncoderforQG, self).__init__()
        self.embedding = embedding
        self.context_lstm = CustomLSTM(input_size=emsize,
                                       hidden_size=nhidden,
                                       num_layers=nlayers,
                                       dropout=dropout,
                                       bidirectional=True)
        self.context_linear = nn.Linear(2 * nhidden, 2 * nhidden)
        self.fusion = nn.Linear(4 * nhidden, 2 * nhidden, bias=False)
        self.gate = nn.Linear(4 * nhidden, 2 * nhidden, bias=False)

    def forward(self, c_ids, a_ids):
        c_mask, c_lengths = return_mask_lengths(c_ids)
        c_embeddings = self.embedding(c_ids, c_mask, a_ids)
        c_outputs, _ = self.context_lstm(c_embeddings, c_lengths)

        # self-attention over the context, followed by a gated fusion
        mask = torch.matmul(c_mask.unsqueeze(2), c_mask.unsqueeze(1))
        c_attned_by_c, _ = cal_attn(self.context_linear(c_outputs),
                                    c_outputs,
                                    mask)
        c_concat = torch.cat([c_outputs, c_attned_by_c], dim=2)
        c_fused = self.fusion(c_concat).tanh()
        c_gate = self.gate(c_concat).sigmoid()
        c_outputs = c_gate * c_fused + (1 - c_gate) * c_outputs
        return c_outputs


class QuestionDecoder(nn.Module):
    def __init__(self, sos_id, eos_id,
                 embedding, contextualized_embedding, emsize,
                 nhidden, ntokens, nlayers,
                 dropout=0.0,
                 max_q_len=64):
        super(QuestionDecoder, self).__init__()
        self.sos_id = sos_id
        self.eos_id = eos_id
        self.emsize = emsize
        self.embedding = embedding
        self.nhidden = nhidden
        self.ntokens = ntokens
        self.nlayers = nlayers
        # max_q_len includes the sos and eos tokens
        self.max_q_len = max_q_len

        self.context_lstm = ContextEncoderforQG(contextualized_embedding, emsize,
                                                nhidden // 2, nlayers, dropout)
        self.question_lstm = CustomLSTM(input_size=emsize,
                                        hidden_size=nhidden,
                                        num_layers=nlayers,
                                        dropout=dropout,
                                        bidirectional=False)
        self.question_linear = nn.Linear(nhidden, nhidden)
        self.concat_linear = nn.Sequential(nn.Linear(2 * nhidden, 2 * nhidden),
                                           nn.Dropout(dropout),
                                           nn.Linear(2 * nhidden, 2 * emsize))
        self.logit_linear = nn.Linear(emsize, ntokens, bias=False)

        # tie the output projection to the word embedding matrix and freeze it
        self.logit_linear.weight = embedding.word_embeddings.weight
        for param in self.logit_linear.parameters():
            param.requires_grad = False

        self.discriminator = nn.Bilinear(emsize, nhidden, 1)

    def postprocess(self, q_ids):
        # zero out everything after the first [SEP]; sequences without [SEP] are kept whole
        eos_mask = q_ids == self.eos_id
        no_eos_idx_sum = (eos_mask.sum(dim=1) == 0).long() * (self.max_q_len - 1)
        eos_mask = eos_mask.cpu().numpy()
        q_lengths = np.argmax(eos_mask, axis=1) + 1
        q_lengths = torch.tensor(q_lengths).to(q_ids.device).long() + no_eos_idx_sum

        batch_size, max_len = q_ids.size()
        idxes = torch.arange(max_len, dtype=torch.long, device=q_ids.device)
        idxes = idxes.unsqueeze(0).repeat(batch_size, 1)
        q_mask = (idxes < q_lengths.unsqueeze(1))
        q_ids = q_ids.long() * q_mask.long()
        return q_ids

    def forward(self, init_state, c_ids, q_ids, a_ids):
        batch_size, max_q_len = q_ids.size()

        c_outputs = self.context_lstm(c_ids, a_ids)

        c_mask, c_lengths = return_mask_lengths(c_ids)
        q_mask, q_lengths = return_mask_lengths(q_ids)

        # question decoding
        q_embeddings = self.embedding(q_ids)
        q_outputs, _ = self.question_lstm(q_embeddings, q_lengths, init_state)

        # attention over the context
        mask = torch.matmul(q_mask.unsqueeze(2), c_mask.unsqueeze(1))
        c_attned_by_q, attn_logits = cal_attn(self.question_linear(q_outputs),
                                              c_outputs,
                                              mask)

        # generation logits (maxout over pairs of channels)
        q_concated = torch.cat([q_outputs, c_attned_by_q], dim=2)
        q_concated = self.concat_linear(q_concated)
        q_maxouted, _ = q_concated.view(batch_size, max_q_len, self.emsize, 2).max(dim=-1)
        gen_logits = self.logit_linear(q_maxouted)

        # copy logits: scatter attention scores onto the vocabulary
        bq = batch_size * max_q_len
        c_ids = c_ids.unsqueeze(1).repeat(1, max_q_len, 1).view(bq, -1).contiguous()
        attn_logits = attn_logits.view(bq, -1).contiguous()
        copy_logits = torch.zeros(bq, self.ntokens).to(c_ids.device)
        copy_logits = copy_logits - 10000.0
        copy_logits, _ = scatter_max(attn_logits, c_ids, out=copy_logits)
        copy_logits = copy_logits.masked_fill(copy_logits == -10000.0, 0)
        copy_logits = copy_logits.view(batch_size, max_q_len, -1).contiguous()

        logits = gen_logits + copy_logits

        # mutual information between answer and question
        a_emb = c_outputs * a_ids.float().unsqueeze(2)
        a_mean_emb = torch.sum(a_emb, dim=1) / a_ids.sum(dim=1).unsqueeze(1).float()
        fake_a_mean_emb = torch.cat([a_mean_emb[-1].unsqueeze(0), a_mean_emb[:-1]], dim=0)

        q_emb = q_maxouted * q_mask.unsqueeze(2)
        q_mean_emb = torch.sum(q_emb, dim=1) / q_lengths.unsqueeze(1).float()
        fake_q_mean_emb = torch.cat([q_mean_emb[-1].unsqueeze(0), q_mean_emb[:-1]], dim=0)

        bce_loss = nn.BCEWithLogitsLoss()
        true_logits = self.discriminator(q_mean_emb, a_mean_emb)
        true_labels = torch.ones_like(true_logits)
        fake_a_logits = self.discriminator(q_mean_emb, fake_a_mean_emb)
        fake_q_logits = self.discriminator(fake_q_mean_emb, a_mean_emb)
        fake_logits = torch.cat([fake_a_logits, fake_q_logits], dim=0)
        fake_labels = torch.zeros_like(fake_logits)
        true_loss = bce_loss(true_logits, true_labels)
        fake_loss = 0.5 * bce_loss(fake_logits, fake_labels)
        loss_info = 0.5 * (true_loss + fake_loss)

        return logits, loss_info

    def get_mi(self, init_state, c_ids, q_ids, a_ids):
        batch_size, max_q_len = q_ids.size()

        c_outputs = self.context_lstm(c_ids, a_ids)

        c_mask, c_lengths = return_mask_lengths(c_ids)
        q_mask, q_lengths = return_mask_lengths(q_ids)

        # question decoding
        q_embeddings = self.embedding(q_ids)
        q_outputs, _ = self.question_lstm(q_embeddings, q_lengths, init_state)

        # attention over the context
        mask = torch.matmul(q_mask.unsqueeze(2), c_mask.unsqueeze(1))
        c_attned_by_q, attn_logits = cal_attn(self.question_linear(q_outputs),
                                              c_outputs,
                                              mask)

        # maxout features for the question
        q_concated = torch.cat([q_outputs, c_attned_by_q], dim=2)
        q_concated = self.concat_linear(q_concated)
        q_maxouted, _ = q_concated.view(batch_size, max_q_len, self.emsize, 2).max(dim=-1)

        # discriminator score between mean answer and question representations
        a_emb = c_outputs * a_ids.float().unsqueeze(2)
        a_mean_emb = torch.sum(a_emb, dim=1) / a_ids.sum(dim=1).unsqueeze(1).float()
        q_emb = q_maxouted * q_mask.unsqueeze(2)
        q_mean_emb = torch.sum(q_emb, dim=1) / q_lengths.unsqueeze(1).float()

        logits = self.discriminator(q_mean_emb, a_mean_emb)
        logits = logits.squeeze(1)
        return logits

    def generate(self, init_state, c_ids, a_ids):
        c_mask, c_lengths = return_mask_lengths(c_ids)
        c_outputs = self.context_lstm(c_ids, a_ids)

        batch_size = c_ids.size(0)

        q_ids = torch.LongTensor([self.sos_id] * batch_size).unsqueeze(1)
        q_ids = q_ids.to(c_ids.device)
        token_type_ids = torch.zeros_like(q_ids)
        position_ids = torch.zeros_like(q_ids)
        q_embeddings = self.embedding(q_ids, token_type_ids, position_ids)

        state = init_state

        # unroll (greedy decoding)
        all_q_ids = list()
        all_q_ids.append(q_ids)
        for _ in range(self.max_q_len - 1):
            position_ids = position_ids + 1
            q_outputs, state = self.question_lstm.lstm(q_embeddings, state)

            # attention over the context
            mask = c_mask.unsqueeze(1)
            c_attned_by_q, attn_logits = cal_attn(self.question_linear(q_outputs),
                                                  c_outputs,
                                                  mask)

            # generation logits
            q_concated = torch.cat([q_outputs, c_attned_by_q], dim=2)
            q_concated = self.concat_linear(q_concated)
            q_maxouted, _ = q_concated.view(batch_size, 1, self.emsize, 2).max(dim=-1)
            gen_logits = self.logit_linear(q_maxouted)

            # copy logits
            attn_logits = attn_logits.squeeze(1)
            copy_logits = torch.zeros(batch_size, self.ntokens).to(c_ids.device)
            copy_logits = copy_logits - 10000.0
            copy_logits, _ = scatter_max(attn_logits, c_ids, out=copy_logits)
            copy_logits = copy_logits.masked_fill(copy_logits == -10000.0, 0)

            logits = gen_logits + copy_logits.unsqueeze(1)
            q_ids = torch.argmax(logits, 2)
            all_q_ids.append(q_ids)

            q_embeddings = self.embedding(q_ids, token_type_ids, position_ids)

        q_ids = torch.cat(all_q_ids, 1)
        q_ids = self.postprocess(q_ids)
        return q_ids

    def sample(self, init_state, c_ids, a_ids):
        c_mask, c_lengths = return_mask_lengths(c_ids)
        c_outputs = self.context_lstm(c_ids, a_ids)

        batch_size = c_ids.size(0)

        q_ids = torch.LongTensor([self.sos_id] * batch_size).unsqueeze(1)
        q_ids = q_ids.to(c_ids.device)
        token_type_ids = torch.zeros_like(q_ids)
        position_ids = torch.zeros_like(q_ids)
        q_embeddings = self.embedding(q_ids, token_type_ids, position_ids)

        state = init_state

        # unroll (top-k / nucleus sampling)
        all_q_ids = list()
        all_q_ids.append(q_ids)
        for _ in range(self.max_q_len - 1):
            position_ids = position_ids + 1
            q_outputs, state = self.question_lstm.lstm(q_embeddings, state)

            # attention over the context
            mask = c_mask.unsqueeze(1)
            c_attned_by_q, attn_logits = cal_attn(self.question_linear(q_outputs),
                                                  c_outputs,
                                                  mask)

            # generation logits
            q_concated = torch.cat([q_outputs, c_attned_by_q], dim=2)
            q_concated = self.concat_linear(q_concated)
            q_maxouted, _ = q_concated.view(batch_size, 1, self.emsize, 2).max(dim=-1)
            gen_logits = self.logit_linear(q_maxouted)

            # copy logits
            attn_logits = attn_logits.squeeze(1)
            copy_logits = torch.zeros(batch_size, self.ntokens).to(c_ids.device)
            copy_logits = copy_logits - 10000.0
            copy_logits, _ = scatter_max(attn_logits, c_ids, out=copy_logits)
            copy_logits = copy_logits.masked_fill(copy_logits == -10000.0, 0)

            logits = gen_logits + copy_logits.unsqueeze(1)
            logits = logits.squeeze(1)
            logits = self.top_k_top_p_filtering(logits, top_k=2, top_p=0.8)
            probs = F.softmax(logits, dim=-1)
            q_ids = torch.multinomial(probs, num_samples=1)  # [b, 1]
            all_q_ids.append(q_ids)

            q_embeddings = self.embedding(q_ids, token_type_ids, position_ids)

        q_ids = torch.cat(all_q_ids, 1)
        q_ids = self.postprocess(q_ids)
        return q_ids

    def top_k_top_p_filtering(self, logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
        """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering.

        Args:
            logits: logits distribution of shape (batch size x vocabulary size)
            top_k > 0: keep only the top k tokens with highest probability (top-k filtering).
            top_p > 0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
                Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
        From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
        """
        top_k = min(top_k, logits.size(-1))  # Safety check
        if top_k > 0:
            # Remove all tokens with a probability less than the last token of the top-k
            indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
            logits[indices_to_remove] = filter_value

        if top_p > 0.0:
            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
            cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

            # Remove tokens with cumulative probability above the threshold
            sorted_indices_to_remove = cumulative_probs > top_p
            # Shift the indices to the right to keep also the first token above the threshold
            sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
            sorted_indices_to_remove[..., 0] = 0

            # Scatter sorted tensors back to the original indexing
            indices_to_remove = sorted_indices_to_remove.scatter(
                dim=1, index=sorted_indices, src=sorted_indices_to_remove)
            logits[indices_to_remove] = filter_value
        return logits


class DiscreteVAE(nn.Module):
    def __init__(self, args):
        super(DiscreteVAE, self).__init__()
        tokenizer = BertTokenizer.from_pretrained(args.bert_model)
        padding_idx = tokenizer.vocab['[PAD]']
        sos_id = tokenizer.vocab['[CLS]']
        eos_id = tokenizer.vocab['[SEP]']
        ntokens = len(tokenizer.vocab)

        bert_model = args.bert_model
        if "large" in bert_model:
            emsize = 1024
        else:
            emsize = 768

        enc_nhidden = args.enc_nhidden
        enc_nlayers = args.enc_nlayers
        enc_dropout = args.enc_dropout
        dec_a_nhidden = args.dec_a_nhidden
        dec_a_nlayers = args.dec_a_nlayers
        dec_a_dropout = args.dec_a_dropout
        self.dec_q_nhidden = dec_q_nhidden = args.dec_q_nhidden
        self.dec_q_nlayers = dec_q_nlayers = args.dec_q_nlayers
        dec_q_dropout = args.dec_q_dropout
        self.nzqdim = nzqdim = args.nzqdim
        self.nza = nza = args.nza
        self.nzadim = nzadim = args.nzadim
        self.lambda_kl = args.lambda_kl
        self.lambda_info = args.lambda_info
        max_q_len = args.max_q_len

        # BERT embeddings are kept frozen
        embedding = Embedding(bert_model)
        contextualized_embedding = ContextualizedEmbedding(bert_model)
        for param in embedding.parameters():
            param.requires_grad = False
        for param in contextualized_embedding.parameters():
            param.requires_grad = False

        self.posterior_encoder = PosteriorEncoder(embedding, emsize,
                                                  enc_nhidden, enc_nlayers,
                                                  nzqdim, nza, nzadim,
                                                  enc_dropout)
        self.prior_encoder = PriorEncoder(embedding, emsize,
                                          enc_nhidden, enc_nlayers,
                                          nzqdim, nza, nzadim, enc_dropout)
        self.answer_decoder = AnswerDecoder(contextualized_embedding, emsize,
                                            dec_a_nhidden, dec_a_nlayers,
                                            dec_a_dropout)
        self.question_decoder = QuestionDecoder(sos_id, eos_id,
                                                embedding, contextualized_embedding, emsize,
                                                dec_q_nhidden, ntokens, dec_q_nlayers,
                                                dec_q_dropout,
                                                max_q_len)

        self.q_h_linear = nn.Linear(nzqdim, dec_q_nlayers * dec_q_nhidden)
        self.q_c_linear = nn.Linear(nzqdim, dec_q_nlayers * dec_q_nhidden)
        self.a_linear = nn.Linear(nza * nzadim, emsize, False)

        self.q_rec_criterion = nn.CrossEntropyLoss(ignore_index=padding_idx)
        self.gaussian_kl_criterion = GaussianKLLoss()
        self.categorical_kl_criterion = CategoricalKLLoss()

    def return_init_state(self, zq, za):
        q_init_h = self.q_h_linear(zq)
        q_init_c = self.q_c_linear(zq)
        q_init_h = q_init_h.view(-1, self.dec_q_nlayers, self.dec_q_nhidden).transpose(0, 1).contiguous()
        q_init_c = q_init_c.view(-1, self.dec_q_nlayers, self.dec_q_nhidden).transpose(0, 1).contiguous()
        q_init_state = (q_init_h, q_init_c)

        za_flatten = za.view(-1, self.nza * self.nzadim)
        a_init_state = self.a_linear(za_flatten)
        return q_init_state, a_init_state

    def forward(self, c_ids, q_ids, a_ids, start_positions, end_positions):
        posterior_zq_mu, posterior_zq_logvar, posterior_zq, \
            posterior_za_prob, posterior_za \
            = self.posterior_encoder(c_ids, q_ids, a_ids)
        prior_zq_mu, prior_zq_logvar, prior_zq, \
            prior_za_prob, prior_za \
            = self.prior_encoder(c_ids)

        q_init_state, a_init_state = self.return_init_state(posterior_zq, posterior_za)

        # answer decoding
        start_logits, end_logits = self.answer_decoder(a_init_state, c_ids)
        # question decoding
        q_logits, loss_info = self.question_decoder(q_init_state, c_ids, q_ids, a_ids)

        # question reconstruction loss
        loss_q_rec = self.q_rec_criterion(q_logits[:, :-1, :].transpose(1, 2).contiguous(),
                                          q_ids[:, 1:])

        # answer reconstruction loss
        max_c_len = c_ids.size(1)
        a_rec_criterion = nn.CrossEntropyLoss(ignore_index=max_c_len)
        start_positions.clamp_(0, max_c_len)
        end_positions.clamp_(0, max_c_len)
        loss_start_a_rec = a_rec_criterion(start_logits, start_positions)
        loss_end_a_rec = a_rec_criterion(end_logits, end_positions)
        loss_a_rec = 0.5 * (loss_start_a_rec + loss_end_a_rec)

        # kl loss
        loss_zq_kl = self.gaussian_kl_criterion(posterior_zq_mu,
                                                posterior_zq_logvar,
                                                prior_zq_mu,
                                                prior_zq_logvar)
        loss_za_kl = self.categorical_kl_criterion(posterior_za_prob,
                                                   prior_za_prob)

        loss_kl = self.lambda_kl * (loss_zq_kl + loss_za_kl)
        loss_info = self.lambda_info * loss_info

        loss = loss_q_rec + loss_a_rec + loss_kl + loss_info

        return loss, \
            loss_q_rec, loss_a_rec, \
            loss_zq_kl, loss_za_kl, \
            loss_info

    def generate(self, zq, za, c_ids):
        q_init_state, a_init_state = self.return_init_state(zq, za)

        a_ids, start_positions, end_positions = self.answer_decoder.generate(a_init_state, c_ids)
        q_ids = self.question_decoder.generate(q_init_state, c_ids, a_ids)

        return q_ids, start_positions, end_positions, a_ids

    def nucleus_sample(self, zq, za, c_ids):
        q_init_state, a_init_state = self.return_init_state(zq, za)

        a_ids, start_positions, end_positions = self.answer_decoder.generate(a_init_state, c_ids)
        q_ids = self.question_decoder.sample(q_init_state, c_ids, a_ids)

        return q_ids, start_positions, end_positions, a_ids

    def return_answer_logits(self, zq, za, c_ids):
        q_init_state, a_init_state = self.return_init_state(zq, za)
        start_logits, end_logits = self.answer_decoder(a_init_state, c_ids)
        return start_logits, end_logits

    def question_generate(self, c_ids, a_ids):
        zq_mu, zq_logvar, _, _, _ = self.prior_encoder(c_ids)
        zq = zq_mu + torch.randn_like(zq_mu) * torch.exp(0.5 * zq_logvar)
        zq_rand = torch.rand_like(zq_mu)
        zq = zq[0].unsqueeze(0)
        zq = torch.cat([zq, zq_rand[1:]], dim=0)

        q_init_h = self.q_h_linear(zq)
        q_init_c = self.q_c_linear(zq)
        q_init_h = q_init_h.view(-1, self.dec_q_nlayers, self.dec_q_nhidden).transpose(0, 1).contiguous()
        q_init_c = q_init_c.view(-1, self.dec_q_nlayers, self.dec_q_nhidden).transpose(0, 1).contiguous()
        q_init_state = (q_init_h, q_init_c)

        q_ids = self.question_decoder.generate(q_init_state, c_ids, a_ids)
        return q_ids

    def estimate_mi(self, c_ids, q_ids, a_ids):
        posterior_zq_mu, posterior_zq_logvar, posterior_zq, \
            posterior_za_prob, posterior_za \
            = self.posterior_encoder(c_ids, q_ids, a_ids)
        q_init_state, a_init_state = self.return_init_state(posterior_zq, posterior_za)
        mi = self.question_decoder.get_mi(q_init_state, c_ids, q_ids, a_ids)
        return mi
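

# A minimal end-to-end construction sketch. The hyperparameters below (and the
# use of argparse.Namespace in place of parsed command-line args) are assumptions
# for illustration, not necessarily the settings used to train this model;
# building it downloads pretrained BERT weights via pytorch_transformers.
if __name__ == "__main__":
    from argparse import Namespace

    args = Namespace(bert_model="bert-base-uncased",
                     enc_nhidden=300, enc_nlayers=1, enc_dropout=0.2,
                     dec_a_nhidden=300, dec_a_nlayers=1, dec_a_dropout=0.2,
                     dec_q_nhidden=900, dec_q_nlayers=2, dec_q_dropout=0.3,
                     nzqdim=50, nza=20, nzadim=10,
                     lambda_kl=0.1, lambda_info=1.0, max_q_len=64)
    vae = DiscreteVAE(args)
    print("trainable parameters:", return_num(vae))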