import torch
import torch.nn as nn
import torch.nn.functional as F
from pytorch_pretrained_bert import BertModel
from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence
from torch_scatter import scatter_max
import numpy as np
from torch.distributions.categorical import Categorical

INF = 1e12
EOS_ID = 102  # [SEP] token id in the bert-base-uncased vocabulary

class CatEncoder(nn.Module):
    def __init__(self, embedding_size, hidden_size,
                 num_vars, num_classes):
        super(CatEncoder, self).__init__()
        self.embedding = BertModel.from_pretrained("bert-base-uncased").embeddings
        # freeze the pretrained BERT embedding parameters
        for param in self.embedding.parameters():
            param.requires_grad = False
        self.lstm = nn.LSTM(embedding_size, hidden_size, batch_first=True,
                            bidirectional=True, num_layers=1)
        self.recog_layer = nn.Linear(2 * hidden_size, num_vars * num_classes)
    def forward(self, q_ids, q_len):
        if q_ids.dim() == 2:
            # hard token ids
            embedded = self.embedding(q_ids)
        else:
            # soft distribution over the vocabulary, e.g. softmax(decoder logits)
            embedded = self.get_embedding(q_ids)
        packed = pack_padded_sequence(embedded, q_len,
                                      batch_first=True,
                                      enforce_sorted=False)
        output, states = self.lstm(packed)
        output, _ = pad_packed_sequence(output, batch_first=True)
        hiddens = states[0]  # [2, b, d]
        _, b, d = hiddens.size()
        concat_hidden = torch.cat([h for h in hiddens], dim=-1)  # [b, 2*d]
        # logits for K categorical variables
        qz_logits = self.recog_layer(concat_hidden)
        return qz_logits
    def get_embedding(self, vocab_dist):
        # vocab_dist: [b, t, |V|]
        batch_size, nsteps, _ = vocab_dist.size()
        token_type_ids = torch.zeros((batch_size, nsteps), dtype=torch.long).to(vocab_dist.device)
        position_ids = torch.arange(nsteps, dtype=torch.long, device=vocab_dist.device)
        position_ids = position_ids.unsqueeze(0).repeat([batch_size, 1])
        embedding_matrix = self.embedding.word_embeddings.weight
        word_embeddings = torch.matmul(vocab_dist, embedding_matrix)
        position_embeddings = self.embedding.position_embeddings(position_ids)
        token_type_embeddings = self.embedding.token_type_embeddings(token_type_ids)
        embeddings = word_embeddings + position_embeddings + token_type_embeddings
        embeddings = self.embedding.LayerNorm(embeddings)
        embeddings = self.embedding.dropout(embeddings)
        return embeddings
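    # get_embedding lets the encoder consume a *soft* question (a distribution
    # over the vocabulary at every step) instead of hard token ids, by taking
    # the expectation of the BERT word embeddings under that distribution.
    # Gradients can then flow through the recognition layer back into whatever
    # produced the distribution (see the auxiliary loss in AttnQG.forward,
    # which feeds the QG decoder's softmax output here).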

class CatDecoder(nn.Module):
    def __init__(self, vocab_size, embedding_size, hidden_size):
        super(CatDecoder, self).__init__()
        self.embedding = BertModel.from_pretrained("bert-base-uncased").embeddings
        # freeze the pretrained BERT embedding parameters
        for param in self.embedding.parameters():
            param.requires_grad = False
        self.lstm = nn.LSTM(embedding_size, hidden_size, batch_first=True,
                            num_layers=1)
        self.logit_layer = nn.Linear(hidden_size, vocab_size)
    def forward(self, q_ids, init_states):
        batch_size, max_len = q_ids.size()
        logits = []
        states = init_states
        for i in range(max_len):
            q_i = q_ids[:, i]
            embedded = self.embedding(q_i.unsqueeze(1))
            hidden, states = self.lstm(embedded, states)
            logit = self.logit_layer(hidden)  # [b, 1, |V|]
            logits.append(logit)
        logits = torch.cat(logits, dim=1)
        return logits
    def decode(self, sos_tokens, init_states, max_step):
        inputs = sos_tokens
        prev_states = init_states
        decoded_ids = []
        for i in range(max_step):
            embedded = self.embedding(inputs.unsqueeze(1))
            output, prev_states = self.lstm(embedded, prev_states)
            logit = self.logit_layer(output).squeeze(1)  # [b, |V|]
            inputs = torch.argmax(logit, 1)
            decoded_ids.append(inputs)
        decoded_ids = torch.stack(decoded_ids, dim=1)
        return decoded_ids

class CatVAE(nn.Module):
    def __init__(self, vocab_size, embedding_size,
                 hidden_size, num_vars, num_classes):
        super(CatVAE, self).__init__()
        # Encoder-Decoder for question
        self.encoder = CatEncoder(embedding_size, hidden_size,
                                  num_vars, num_classes)
        self.decoder = CatDecoder(vocab_size, embedding_size,
                                  hidden_size)
        self.linear_h = nn.Linear(num_vars * num_classes, hidden_size, bias=False)
        self.linear_c = nn.Linear(num_vars * num_classes, hidden_size, bias=False)
        self.num_vars = num_vars
        self.num_classes = num_classes
    def forward(self, q_ids):
        sos_q_ids = q_ids[:, :-1]
        eos_q_ids = q_ids[:, 1:]
        q_len = torch.sum(torch.sign(eos_q_ids), 1)
        qz_logits = self.encoder(eos_q_ids, q_len)  # exclude [CLS]
        flatten_logits = qz_logits.view(-1, self.num_classes)
        # sample the categorical variables with the Gumbel-softmax relaxation
        z_samples = self.gumbel_softmax(flatten_logits, tau=1.0).view(-1, self.num_vars * self.num_classes)
        init_h = self.linear_h(z_samples).unsqueeze(0)  # [1, b, d]
        init_c = self.linear_c(z_samples).unsqueeze(0)  # [1, b, d]
        init_states = (init_h, init_c)
        criterion = nn.CrossEntropyLoss(ignore_index=0)
        logits = self.decoder(sos_q_ids, init_states)
        batch_size, nsteps, _ = logits.size()
        preds = logits.view(batch_size * nsteps, -1)
        targets = eos_q_ids.contiguous().view(-1)
        nll = criterion(preds, targets)
        # KL(q(z) || p(z)), where p(z) is the uniform distribution over classes
        log_qz = F.log_softmax(flatten_logits, dim=-1)  # [b*num_vars, num_classes]
        avg_qz = torch.exp(log_qz.view(-1, self.num_vars, self.num_classes))
        avg_log_qz = torch.log(torch.mean(avg_qz, dim=0) + 1e-15)  # [num_vars, num_classes]
        log_uniform_z = torch.log(torch.ones(1, device=log_qz.device) / self.num_classes)
        qz = torch.exp(avg_log_qz)
        kl = torch.sum(qz * (avg_log_qz - log_uniform_z), dim=1)
        avg_kl = kl.mean()
        # mutual information I(X; Z) = H(Z) - H(Z|X)
        mi = self.entropy(avg_log_qz) - self.entropy(log_qz)
        return nll, avg_kl, mi
    @staticmethod
    def gumbel_softmax(logits, tau=1.0, eps=1e-20):
        # g = -log(-log(u)) with u ~ Uniform(0, 1) is standard Gumbel noise
        u = torch.rand_like(logits)
        sample = -torch.log(-torch.log(u + eps) + eps)
        y = logits + sample.to(logits.device)
        return F.softmax(y / tau, dim=-1)

    @staticmethod
    def entropy(log_prob):
        # log_prob: [b, K]
        prob = torch.exp(log_prob)
        h = torch.sum(-log_prob * prob, dim=1)
        return h.mean()
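# CatVAE.forward returns three scalars:
#   nll:    token-level cross entropy for reconstructing the question from z
#   avg_kl: KL between the batch-averaged posterior and a uniform prior over
#           the num_classes classes, averaged over the num_vars variables
#   mi:     a mutual-information estimate H(q(z)) - E_x[H(q(z|x))], likewise
#           averaged over the latent variables
# How these three terms are weighted into the training loss is left to the
# training script, which is not part of this gist.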

class AnsEncoder(nn.Module):
    def __init__(self, embedding_size, hidden_size, num_layers, dropout):
        super(AnsEncoder, self).__init__()
        self.embedding = BertModel.from_pretrained("bert-base-uncased").embeddings
        # freeze the pretrained BERT embedding parameters
        for param in self.embedding.parameters():
            param.requires_grad = False
        self.num_layers = num_layers
        if self.num_layers == 1:
            dropout = 0.0
        self.lstm = nn.LSTM(embedding_size, hidden_size, dropout=dropout,
                            num_layers=num_layers, bidirectional=True, batch_first=True)
    def forward(self, ans_ids, ans_len):
        embedded = self.embedding(ans_ids)
        packed = pack_padded_sequence(embedded, ans_len, batch_first=True,
                                      enforce_sorted=False)
        _, states = self.lstm(packed)
        h, c = states
        _, b, d = h.size()
        h = h.view(self.num_layers, 2, b, d)  # [n_layers, bi, b, d]
        h = torch.cat((h[:, 0, :, :], h[:, 1, :, :]), dim=-1)
        c = c.view(self.num_layers, 2, b, d)
        c = torch.cat((c[:, 0, :, :], c[:, 1, :, :]), dim=-1)
        concat_states = (h, c)
        return concat_states

class Encoder(nn.Module):
    def __init__(self, embedding_size,
                 hidden_size, num_layers, dropout, use_tag):
        super(Encoder, self).__init__()
        self.use_tag = use_tag
        self.num_layers = num_layers
        # tag embedding
        if use_tag:
            self.tag_embedding = nn.Embedding(3, 3)
            lstm_input_size = embedding_size + 3
        else:
            lstm_input_size = embedding_size
        self.embedding = BertModel.from_pretrained("bert-base-uncased").embeddings
        # freeze the pretrained BERT embedding parameters
        for param in self.embedding.parameters():
            param.requires_grad = False
        if self.num_layers == 1:
            dropout = 0.0
        self.lstm = nn.LSTM(lstm_input_size, hidden_size, dropout=dropout,
                            num_layers=num_layers, bidirectional=True, batch_first=True)
        self.linear_trans = nn.Linear(2 * hidden_size, 2 * hidden_size, bias=False)
        self.update_layer = nn.Linear(4 * hidden_size, 2 * hidden_size, bias=False)
        self.gate = nn.Linear(4 * hidden_size, 2 * hidden_size, bias=False)
    def gated_self_attn(self, queries, memories, mask):
        # queries:  [b, t, d]
        # memories: [b, t, d]
        # mask:     [b, t]
        energies = torch.matmul(queries, memories.transpose(1, 2))  # [b, t, t]
        energies = energies.masked_fill(mask.unsqueeze(1), value=-1e12)
        scores = F.softmax(energies, dim=2)
        context = torch.matmul(scores, queries)
        inputs = torch.cat((queries, context), dim=2)
        f_t = torch.tanh(self.update_layer(inputs))
        g_t = torch.sigmoid(self.gate(inputs))
        updated_output = g_t * f_t + (1 - g_t) * queries
        return updated_output
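    # gated_self_attn implements gated self-matching: every passage position
    # attends over the projected passage encoding (memories), and the sigmoid
    # gate g_t interpolates between the fused representation
    # f_t = tanh(W [query; context]) and the original LSTM output,
    # i.e. g_t * f_t + (1 - g_t) * queries.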
    def forward(self, src_seq, src_len, tag_seq):
        total_length = src_seq.size(1)
        embedded = self.embedding(src_seq)
        if self.use_tag and tag_seq is not None:
            tag_embedded = self.tag_embedding(tag_seq)
            embedded = torch.cat((embedded, tag_embedded), dim=2)
        packed = pack_padded_sequence(embedded, src_len, batch_first=True, enforce_sorted=False)
        self.lstm.flatten_parameters()
        outputs, states = self.lstm(packed)  # states: tuple of [2 * n_layers, b, d]
        outputs, _ = pad_packed_sequence(outputs, batch_first=True,
                                         total_length=total_length)  # [b, t, 2*d]
        h, c = states
        # self attention (padding positions are all-zero vectors after padding)
        zeros = outputs.sum(dim=-1)
        mask = (zeros == 0).byte()
        memories = self.linear_trans(outputs)
        outputs = self.gated_self_attn(outputs, memories, mask)
        _, b, d = h.size()
        h = h.view(self.num_layers, 2, b, d)  # [n_layers, bi, b, d]
        h = torch.cat((h[:, 0, :, :], h[:, 1, :, :]), dim=-1)
        c = c.view(self.num_layers, 2, b, d)
        c = torch.cat((c[:, 0, :, :], c[:, 1, :, :]), dim=-1)
        concat_states = (h, c)
        return outputs, concat_states

class Decoder(nn.Module):
    def __init__(self, embedding_size, vocab_size, enc_size,
                 hidden_size, num_layers, dropout,
                 pointer=False):
        super(Decoder, self).__init__()
        self.vocab_size = vocab_size
        self.embedding = BertModel.from_pretrained("bert-base-uncased").embeddings
        # freeze the pretrained BERT embedding parameters
        for param in self.embedding.parameters():
            param.requires_grad = False
        if num_layers == 1:
            dropout = 0.0
        self.encoder_trans = nn.Linear(enc_size, hidden_size)
        self.lstm = nn.LSTM(embedding_size, hidden_size, batch_first=True,
                            num_layers=num_layers, bidirectional=False, dropout=dropout)
        self.concat_layer = nn.Linear(2 * hidden_size, hidden_size, bias=False)
        self.logit_layer = nn.Linear(hidden_size, vocab_size)
        self.pointer = pointer
    @staticmethod
    def attention(query, memories, mask):
        # query: [b, 1, d]
        energy = torch.matmul(query, memories.transpose(1, 2))  # [b, 1, t]
        energy = energy.squeeze(1).masked_fill(mask, value=-1e12)
        attn_dist = F.softmax(energy, dim=1).unsqueeze(dim=1)  # [b, 1, t]
        context_vector = torch.matmul(attn_dist, memories)  # [b, 1, d]
        return context_vector, energy

    def get_encoder_features(self, encoder_outputs):
        return self.encoder_trans(encoder_outputs)
    def forward(self, q_ids, c_ids, init_states, encoder_outputs, enc_mask):
        # q_ids:           [b, t]
        # c_ids:           [b, t_c]
        # init_states:     a tuple of [n_layers, b, d]
        # encoder_outputs: [b, t_c, enc_size]
        batch_size, max_len = q_ids.size()
        memories = self.get_encoder_features(encoder_outputs)
        logits = []
        prev_states = init_states
        self.lstm.flatten_parameters()
        for i in range(max_len):
            y_i = q_ids[:, i].unsqueeze(dim=1)
            embedded = self.embedding(y_i)
            hidden, states = self.lstm(embedded, prev_states)
            # encoder-decoder attention
            context, energy = self.attention(hidden, memories, enc_mask)
            concat_input = torch.cat((hidden, context), dim=2).squeeze(dim=1)
            logit_input = torch.tanh(self.concat_layer(concat_input))
            logit = self.logit_layer(logit_input)  # [b, |V|]
            # maxout pointer network
            if self.pointer:
                num_oov = max(torch.max(c_ids - self.vocab_size + 1), 0)
                zeros = torch.zeros((batch_size, num_oov), device=logit.device)
                extended_logit = torch.cat((logit, zeros), dim=1)
                out = torch.zeros_like(extended_logit) - INF
                out, _ = scatter_max(energy, c_ids, out=out)
                out = out.masked_fill(out == -INF, 0)
                logit = extended_logit + out
                logit = logit.masked_fill(logit == 0, -INF)
            logits.append(logit)
            # update the previous state
            prev_states = states
        logits = torch.stack(logits, dim=1)  # [b, t, |V|]
        return logits
    def decode(self, q_id, c_ids, prev_states, memories, enc_mask):
        self.lstm.flatten_parameters()
        embedded = self.embedding(q_id.unsqueeze(1))
        batch_size = c_ids.size(0)
        hidden, states = self.lstm(embedded, prev_states)
        # attention
        context, energy = self.attention(hidden, memories, enc_mask)
        concat_input = torch.cat((hidden, context), dim=2).squeeze(dim=1)
        logit_input = torch.tanh(self.concat_layer(concat_input))
        logit = self.logit_layer(logit_input)  # [b, |V|]
        if self.pointer:
            num_oov = max(torch.max(c_ids - self.vocab_size + 1), 0)
            zeros = torch.zeros((batch_size, num_oov), device=logit.device)
            extended_logit = torch.cat((logit, zeros), dim=1)
            out = torch.zeros_like(extended_logit) - INF
            out, _ = scatter_max(energy, c_ids, out=out)
            out = out.masked_fill(out == -INF, 0)
            logit = extended_logit + out
            logit = logit.masked_fill(logit == 0, -INF)
        return logit, states
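    # Maxout pointer (copy) mechanism used in forward() and decode() above:
    # the logit vector is first extended by num_oov slots so context tokens
    # outside the fixed vocabulary can still be predicted. scatter_max then
    # collects, for every token id occurring in c_ids, the maximum attention
    # energy over all of its occurrences in the context, and that copy score
    # is added to the generator logit of the same id; ids that never occur in
    # the context receive no copy score.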

class AttnQG(nn.Module):
    def __init__(self, vocab_size, embedding_size, hidden_size, vae_hidden_size, num_layers,
                 dropout, num_vars, num_classes, save_path, pointer=False):
        super(AttnQG, self).__init__()
        self.num_vars = num_vars
        self.num_classes = num_classes
        self.vae_net = CatVAE(vocab_size,
                              embedding_size,
                              vae_hidden_size,
                              num_vars,
                              num_classes)
        if save_path is not None:
            state_dict = torch.load(save_path, map_location="cpu")
            self.vae_net.load_state_dict(state_dict)
        self.encoder = Encoder(embedding_size, hidden_size, num_layers, dropout, use_tag=True)
        self.ans_encoder = AnsEncoder(embedding_size, hidden_size, num_layers, dropout)
        self.decoder = Decoder(embedding_size, vocab_size, 2 * hidden_size, 2 * hidden_size,
                               num_layers, dropout, pointer)
        self.linear_trans = nn.Linear(4 * hidden_size, 4 * hidden_size)
        self.prior_net = nn.Linear(4 * hidden_size, num_vars * num_classes)
        self.linear_h = nn.Linear(num_vars * num_classes, 2 * hidden_size)
        self.linear_c = nn.Linear(num_vars * num_classes, 2 * hidden_size)
    def forward(self, c_ids, tag_ids, q_ids, ans_ids, use_prior=False):
        sos_q_ids = q_ids[:, :-1]
        eos_q_ids = q_ids[:, 1:]
        # sample z
        with torch.no_grad():
            q_len = torch.sum(torch.sign(eos_q_ids), 1)
            qz_logits = self.vae_net.encoder(eos_q_ids, q_len)
            flatten_logits = qz_logits.view(-1, self.num_classes)
        # encode passage
        c_len = torch.sum(torch.sign(c_ids), 1)
        enc_outputs, states = self.encoder(c_ids, c_len, tag_ids)
        last_c_hidden = states[0][-1]
        # encode answer
        ans_len = torch.sum(torch.sign(ans_ids), 1)
        ans_states = self.ans_encoder(ans_ids, ans_len)
        last_ans_hidden = ans_states[0][-1]
        last_hidden = torch.cat([last_c_hidden, last_ans_hidden], -1)
        # compute prior
        prior_logits = self.prior_net(torch.relu(self.linear_trans(last_hidden)))
        log_pz = F.log_softmax(prior_logits.view(-1, self.num_classes), dim=-1)
        log_qz = F.log_softmax(flatten_logits, dim=-1).detach()
        if use_prior:
            probs = torch.exp(log_pz)
        else:
            # z_samples = self.vae_net.gumbel_softmax(flatten_logits, tau)
            # z_samples = z_samples.view(-1, self.num_vars * self.num_classes)
            probs = F.softmax(flatten_logits, dim=1)
        m = Categorical(probs)
        z_samples = m.sample()
        z_samples = F.one_hot(z_samples, num_classes=self.num_classes)
        z_samples = z_samples.view(-1, self.num_vars * self.num_classes).float()
        init_h = self.linear_h(z_samples.detach())
        init_c = self.linear_c(z_samples.detach())
        init_h = init_h + states[0]
        init_c = init_c + states[1]
        new_states = (init_h, init_c)
        # KL(q(z|x) || p(z|c))
        qz = torch.exp(log_qz)
        prior_kl = torch.sum(qz * (log_qz - log_pz), dim=1)
        prior_kl = prior_kl.mean()
        c_mask = (c_ids == 0).byte()
        logits = self.decoder(sos_q_ids, c_ids, new_states, enc_outputs, c_mask)
        # \hat{x} ~ p(x|c,z)
        decoded_ids = torch.argmax(logits, dim=-1)
        seq_len = self.get_seq_len(decoded_ids, EOS_ID)
        vocab_dist = F.softmax(logits, dim=-1)
        # z ~ p(z|\hat{x})
        z_logits = self.vae_net.encoder(vocab_dist, seq_len)
        flatten_z_logits = z_logits.view(-1, self.num_classes)
        true_z = z_samples.view(-1, self.num_classes)
        true_z = torch.argmax(true_z, dim=-1)
        aux_criterion = nn.CrossEntropyLoss()
        aux_loss = aux_criterion(flatten_z_logits, true_z)
        batch_size, nsteps, _ = logits.size()
        criterion = nn.CrossEntropyLoss(ignore_index=0)
        preds = logits.view(batch_size * nsteps, -1)
        targets = eos_q_ids.contiguous().view(-1)
        nll = criterion(preds, targets)
        return nll, prior_kl, aux_loss
    def generate(self, c_ids, tag_ids, ans_ids, sos_ids, max_step):
        # encode context
        c_len = torch.sum(torch.sign(c_ids), 1)
        enc_outputs, states = self.encoder(c_ids, c_len, tag_ids)
        c_mask = (c_ids == 0).byte()
        last_c_hidden = states[0][-1]
        # encode answer
        ans_len = torch.sum(torch.sign(ans_ids), 1)
        ans_states = self.ans_encoder(ans_ids, ans_len)
        last_ans_hidden = ans_states[0][-1]
        last_hidden = torch.cat([last_c_hidden, last_ans_hidden], -1)
        # sample z from the prior distribution
        prior_logits = self.prior_net(torch.relu(self.linear_trans(last_hidden)))
        prior_logits = prior_logits.view(-1, self.num_classes)
        prior_prob = F.softmax(prior_logits, dim=-1)  # [b*num_vars, num_classes]
        m = Categorical(prior_prob)
        z_samples = m.sample()  # [b*num_vars]
        # one-hot vector
        z_samples = F.one_hot(z_samples, num_classes=self.num_classes)
        z_samples = z_samples.view(-1, self.num_vars * self.num_classes).float()
        init_h = self.linear_h(z_samples)
        init_c = self.linear_c(z_samples)
        new_h = init_h + states[0]
        new_c = init_c + states[1]
        new_states = (new_h, new_c)
        inputs = sos_ids
        prev_states = new_states
        decoded_ids = []
        memories = self.decoder.get_encoder_features(enc_outputs)
        for _ in range(max_step):
            logit, prev_states = self.decoder.decode(inputs, c_ids,
                                                     prev_states, memories,
                                                     c_mask)
            inputs = torch.argmax(logit, 1)  # [b]
            decoded_ids.append(inputs)
        decoded_ids = torch.stack(decoded_ids, 1)
        return decoded_ids
    @staticmethod
    def get_seq_len(input_ids, eos_id):
        # input_ids: [b, t]
        # eos_id: scalar
        mask = input_ids == eos_id
        num_eos = torch.sum(mask, 1)
        # move the mask to cpu: np.argmax consistently returns the first index
        # of the maximum element, whereas torch.argmax does not guarantee this
        # and can behave differently on cpu and cuda
        mask = mask.cpu().numpy()
        indices = np.argmax(mask, 1)
        # convert the numpy array back to a tensor on the original device
        seq_len = torch.LongTensor(indices).to(input_ids.device)
        # if there is no eos in the sequence, use the full length
        max_len = input_ids.size(1)
        seq_len = seq_len.masked_fill(num_eos == 0, max_len - 1)
        # +1 to include the eos token
        seq_len = seq_len + 1
        return seq_len
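
if __name__ == "__main__":
    # Hypothetical smoke test, not part of the original gist: it only checks
    # that CatVAE runs end to end on a toy batch. All hyperparameters below
    # are made up for illustration; 30522 and 768 match bert-base-uncased,
    # whose embeddings are downloaded on first use. AttnQG additionally needs
    # context, tag and answer tensors and is not exercised here.
    vae = CatVAE(vocab_size=30522, embedding_size=768, hidden_size=300,
                 num_vars=20, num_classes=10)
    # [CLS], five arbitrary WordPiece ids, [SEP], one padding token
    toy_q_ids = torch.tensor([[101, 2054, 2003, 1996, 3007, 1029, 102, 0]])
    nll, kl, mi = vae(toy_q_ids)
    print("nll={:.3f} kl={:.3f} mi={:.3f}".format(nll.item(), kl.item(), mi.item()))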