import torch
import torch.nn as nn
import torch.nn.functional as F
from pytorch_pretrained_bert import BertModel
from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence
from torch_scatter import scatter_max
import numpy as np
from torch.distributions.categorical import Categorical

INF = 1e12
EOS_ID = 102  # BERT [SEP] token id

class CatEncoder(nn.Module):
    def __init__(self, embedding_size, hidden_size,
                 num_vars, num_classes):
        super(CatEncoder, self).__init__()
        self.embedding = BertModel.from_pretrained("bert-base-uncased").embeddings
        # freeze the pretrained BERT embeddings
        for param in self.embedding.parameters():
            param.requires_grad = False
        self.lstm = nn.LSTM(embedding_size, hidden_size, batch_first=True,
                            bidirectional=True, num_layers=1)
        self.recog_layer = nn.Linear(2 * hidden_size, num_vars * num_classes)

    def forward(self, q_ids, q_len):
        if q_ids.dim() == 2:
            # hard token ids
            embedded = self.embedding(q_ids)
        else:
            # soft distribution over the vocabulary
            embedded = self.get_embedding(q_ids)
        packed = pack_padded_sequence(embedded, q_len,
                                      batch_first=True,
                                      enforce_sorted=False)
        output, states = self.lstm(packed)
        output, _ = pad_packed_sequence(output, batch_first=True)
        hiddens = states[0]  # [2, b, d]
        _, b, d = hiddens.size()
        concat_hidden = torch.cat([h for h in hiddens], dim=-1)  # [b, 2*d]
        # logits for K categorical variables
        qz_logits = self.recog_layer(concat_hidden)
        return qz_logits

    def get_embedding(self, vocab_dist):
        # vocab_dist: [b, t, |V|]
        batch_size, nsteps, _ = vocab_dist.size()
        token_type_ids = torch.zeros((batch_size, nsteps), dtype=torch.long).to(vocab_dist.device)
        position_ids = torch.arange(nsteps, dtype=torch.long, device=vocab_dist.device)
        position_ids = position_ids.unsqueeze(0).repeat([batch_size, 1])
        embedding_matrix = self.embedding.word_embeddings.weight
        # expected word embedding under the soft distribution
        word_embeddings = torch.matmul(vocab_dist, embedding_matrix)
        position_embeddings = self.embedding.position_embeddings(position_ids)
        token_type_embeddings = self.embedding.token_type_embeddings(token_type_ids)
        embeddings = word_embeddings + position_embeddings + token_type_embeddings
        embeddings = self.embedding.LayerNorm(embeddings)
        embeddings = self.embedding.dropout(embeddings)
        return embeddings

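# Sanity-check sketch for the soft-embedding path: feeding a one-hot vocabulary
# distribution through get_embedding should reproduce the hard-ID embedding path,
# since word_embeddings(ids) == one_hot(ids) @ word_embeddings.weight.
# The 30522-token vocab size assumes bert-base-uncased.
def _check_soft_embedding(encoder, q_ids, vocab_size=30522):
    encoder.eval()  # disable embedding dropout so the two paths can be compared exactly
    with torch.no_grad():
        one_hot = F.one_hot(q_ids, num_classes=vocab_size).float()  # [b, t, |V|]
        soft = encoder.get_embedding(one_hot)
        hard = encoder.embedding(q_ids)
    return torch.allclose(soft, hard, atol=1e-5)
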
class CatDecoder(nn.Module):
    def __init__(self, vocab_size, embedding_size, hidden_size):
        super(CatDecoder, self).__init__()
        self.embedding = BertModel.from_pretrained("bert-base-uncased").embeddings
        # freeze the pretrained BERT embeddings
        for param in self.embedding.parameters():
            param.requires_grad = False
        self.lstm = nn.LSTM(embedding_size, hidden_size, batch_first=True,
                            num_layers=1)
        self.logit_layer = nn.Linear(hidden_size, vocab_size)

    def forward(self, q_ids, init_states):
        batch_size, max_len = q_ids.size()
        logits = []
        states = init_states
        for i in range(max_len):
            q_i = q_ids[:, i]
            embedded = self.embedding(q_i.unsqueeze(1))
            hidden, states = self.lstm(embedded, states)
            logit = self.logit_layer(hidden)  # [b, 1, |V|]
            logits.append(logit)
        logits = torch.cat(logits, dim=1)
        return logits

    def decode(self, sos_tokens, init_states, max_step):
        # greedy decoding
        inputs = sos_tokens
        prev_states = init_states
        decoded_ids = []
        for i in range(max_step):
            embedded = self.embedding(inputs.unsqueeze(1))
            output, prev_states = self.lstm(embedded, prev_states)
            logit = self.logit_layer(output).squeeze(1)  # [b, |V|]
            inputs = torch.argmax(logit, 1)
            decoded_ids.append(inputs)
        decoded_ids = torch.stack(decoded_ids, dim=1)
        return decoded_ids

class CatVAE(nn.Module):
    def __init__(self, vocab_size, embedding_size,
                 hidden_size, num_vars, num_classes):
        super(CatVAE, self).__init__()
        # encoder-decoder for the question
        self.encoder = CatEncoder(embedding_size, hidden_size,
                                  num_vars, num_classes)
        self.decoder = CatDecoder(vocab_size, embedding_size,
                                  hidden_size)
        self.linear_h = nn.Linear(num_vars * num_classes, hidden_size, bias=False)
        self.linear_c = nn.Linear(num_vars * num_classes, hidden_size, bias=False)
        self.num_vars = num_vars
        self.num_classes = num_classes

    def forward(self, q_ids):
        sos_q_ids = q_ids[:, :-1]
        eos_q_ids = q_ids[:, 1:]
        q_len = torch.sum(torch.sign(eos_q_ids), 1)
        qz_logits = self.encoder(eos_q_ids, q_len)  # exclude [CLS]
        flatten_logits = qz_logits.view(-1, self.num_classes)
        # sample categorical variables with Gumbel-Softmax
        z_samples = self.gumbel_softmax(flatten_logits, tau=1.0).view(-1, self.num_vars * self.num_classes)
        init_h = self.linear_h(z_samples).unsqueeze(0)  # [1, b, d]
        init_c = self.linear_c(z_samples).unsqueeze(0)  # [1, b, d]
        init_states = (init_h, init_c)
        criterion = nn.CrossEntropyLoss(ignore_index=0)
        logits = self.decoder(sos_q_ids, init_states)
        batch_size, nsteps, _ = logits.size()
        preds = logits.view(batch_size * nsteps, -1)
        targets = eos_q_ids.contiguous().view(-1)
        nll = criterion(preds, targets)
        # KL(q(z) || p(z)), p(z) ~ uniform distribution
        log_qz = F.log_softmax(flatten_logits, dim=-1)  # [b*num_vars, num_classes]
        qz_prob = torch.exp(log_qz.view(-1, self.num_vars, self.num_classes))
        avg_log_qz = torch.log(torch.mean(qz_prob, dim=0) + 1e-15)  # [num_vars, num_classes]
        log_uniform_z = torch.log(torch.ones(1, device=log_qz.device) / self.num_classes)
        avg_qz = torch.exp(avg_log_qz)
        kl = torch.sum(avg_qz * (avg_log_qz - log_uniform_z), dim=1)
        avg_kl = kl.mean()
        # mutual information H(Z) - H(Z|X)
        mi = self.entropy(avg_log_qz) - self.entropy(log_qz)
        return nll, avg_kl, mi

    @staticmethod
    def gumbel_softmax(logits, tau=1.0, eps=1e-20):
        # sample from Gumbel(0, 1) and apply the softmax relaxation
        u = torch.rand_like(logits)
        sample = -torch.log(-torch.log(u + eps) + eps)
        y = logits + sample.to(logits.device)
        return F.softmax(y / tau, dim=-1)

    @staticmethod
    def entropy(log_prob):
        # log_prob: [b, K]
        prob = torch.exp(log_prob)
        h = torch.sum(-log_prob * prob, dim=1)
        return h.mean()

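# Minimal training-step sketch for the question VAE, assuming `q_ids` is a padded
# batch of BERT token ids shaped [b, t] that starts with [CLS] and ends with [SEP].
# Since p(z) is uniform, KL(q(z) || p(z)) = log(num_classes) - H(q(z)), so avg_kl is
# bounded by log(num_classes). The optimizer and kl_weight here are assumptions.
def _vae_train_step(vae, q_ids, optimizer, kl_weight=1.0):
    nll, avg_kl, mi = vae(q_ids)
    loss = nll + kl_weight * avg_kl  # ELBO-style objective; mi is useful to monitor
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item(), nll.item(), avg_kl.item(), mi.item()
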
class AnsEncoder(nn.Module):
    def __init__(self, embedding_size, hidden_size, num_layers, dropout):
        super(AnsEncoder, self).__init__()
        self.embedding = BertModel.from_pretrained("bert-base-uncased").embeddings
        # freeze the pretrained BERT embeddings
        for param in self.embedding.parameters():
            param.requires_grad = False
        self.num_layers = num_layers
        if self.num_layers == 1:
            dropout = 0.0
        self.lstm = nn.LSTM(embedding_size, hidden_size, dropout=dropout,
                            num_layers=num_layers, bidirectional=True, batch_first=True)

    def forward(self, ans_ids, ans_len):
        embedded = self.embedding(ans_ids)
        packed = pack_padded_sequence(embedded, ans_len, batch_first=True,
                                      enforce_sorted=False)
        _, states = self.lstm(packed)
        h, c = states
        _, b, d = h.size()
        # concatenate forward and backward states: [num_layers, b, 2*d]
        h = h.view(self.num_layers, 2, b, d)  # [n_layers, bi, b, d]
        h = torch.cat((h[:, 0, :, :], h[:, 1, :, :]), dim=-1)
        c = c.view(self.num_layers, 2, b, d)
        c = torch.cat((c[:, 0, :, :], c[:, 1, :, :]), dim=-1)
        concat_states = (h, c)
        return concat_states

class Encoder(nn.Module):
    def __init__(self, embedding_size,
                 hidden_size, num_layers, dropout, use_tag):
        super(Encoder, self).__init__()
        self.use_tag = use_tag
        self.num_layers = num_layers
        # answer-tag embedding
        if use_tag:
            self.tag_embedding = nn.Embedding(3, 3)
            lstm_input_size = embedding_size + 3
        else:
            lstm_input_size = embedding_size
        self.embedding = BertModel.from_pretrained("bert-base-uncased").embeddings
        # freeze the pretrained BERT embeddings
        for param in self.embedding.parameters():
            param.requires_grad = False
        if self.num_layers == 1:
            dropout = 0.0
        self.lstm = nn.LSTM(lstm_input_size, hidden_size, dropout=dropout,
                            num_layers=num_layers, bidirectional=True, batch_first=True)
        self.linear_trans = nn.Linear(2 * hidden_size, 2 * hidden_size, bias=False)
        self.update_layer = nn.Linear(4 * hidden_size, 2 * hidden_size, bias=False)
        self.gate = nn.Linear(4 * hidden_size, 2 * hidden_size, bias=False)

    def gated_self_attn(self, queries, memories, mask):
        # queries: [b, t, d]
        # memories: [b, t, d]
        # mask: [b, t]
        energies = torch.matmul(queries, memories.transpose(1, 2))  # [b, t, t]
        energies = energies.masked_fill(mask.unsqueeze(1), value=-1e12)
        scores = F.softmax(energies, dim=2)
        context = torch.matmul(scores, queries)
        inputs = torch.cat((queries, context), dim=2)
        f_t = torch.tanh(self.update_layer(inputs))
        g_t = torch.sigmoid(self.gate(inputs))
        updated_output = g_t * f_t + (1 - g_t) * queries
        return updated_output

    def forward(self, src_seq, src_len, tag_seq):
        total_length = src_seq.size(1)
        embedded = self.embedding(src_seq)
        if self.use_tag and tag_seq is not None:
            tag_embedded = self.tag_embedding(tag_seq)
            embedded = torch.cat((embedded, tag_embedded), dim=2)
        packed = pack_padded_sequence(embedded, src_len, batch_first=True, enforce_sorted=False)
        self.lstm.flatten_parameters()
        outputs, states = self.lstm(packed)  # states: tuple of [2*num_layers, b, d]
        outputs, _ = pad_packed_sequence(outputs, batch_first=True,
                                         total_length=total_length)  # [b, t, 2*hidden_size]
        h, c = states
        # gated self-attention over the padded outputs
        pad_sums = outputs.sum(dim=-1)
        mask = (pad_sums == 0).byte()
        memories = self.linear_trans(outputs)
        outputs = self.gated_self_attn(outputs, memories, mask)
        _, b, d = h.size()
        h = h.view(self.num_layers, 2, b, d)  # [n_layers, bi, b, d]
        h = torch.cat((h[:, 0, :, :], h[:, 1, :, :]), dim=-1)
        c = c.view(self.num_layers, 2, b, d)
        c = torch.cat((c[:, 0, :, :], c[:, 1, :, :]), dim=-1)
        concat_states = (h, c)
        return outputs, concat_states

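# Shape walk-through for the passage encoder, assuming pad id 0 and an answer-span
# tag sequence with values in {0, 1, 2} (the tag embedding is nn.Embedding(3, 3));
# the exact tagging scheme is an assumption here.
def _encoder_shape_demo(encoder, vocab_size=30522, b=2, t=16):
    src_seq = torch.randint(1, vocab_size, (b, t))
    src_seq[:, -4:] = 0                      # treat the last 4 positions as padding
    tag_seq = torch.zeros_like(src_seq)
    tag_seq[:, 3:6] = 1                      # mark tokens 3-5 as the answer span
    src_len = torch.sum(torch.sign(src_seq), 1)
    outputs, (h, c) = encoder(src_seq, src_len, tag_seq)
    # outputs: [b, t, 2 * hidden_size]; h, c: [num_layers, b, 2 * hidden_size]
    return outputs.shape, h.shape, c.shape
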
class Decoder(nn.Module):
    def __init__(self, embedding_size, vocab_size, enc_size,
                 hidden_size, num_layers, dropout,
                 pointer=False):
        super(Decoder, self).__init__()
        self.vocab_size = vocab_size
        self.embedding = BertModel.from_pretrained("bert-base-uncased").embeddings
        # freeze the pretrained BERT embeddings
        for param in self.embedding.parameters():
            param.requires_grad = False
        if num_layers == 1:
            dropout = 0.0
        self.encoder_trans = nn.Linear(enc_size, hidden_size)
        self.lstm = nn.LSTM(embedding_size, hidden_size, batch_first=True,
                            num_layers=num_layers, bidirectional=False, dropout=dropout)
        self.concat_layer = nn.Linear(2 * hidden_size, hidden_size, bias=False)
        self.logit_layer = nn.Linear(hidden_size, vocab_size)
        self.pointer = pointer

    @staticmethod
    def attention(query, memories, mask):
        # query: [b, 1, d]
        energy = torch.matmul(query, memories.transpose(1, 2))  # [b, 1, t]
        energy = energy.squeeze(1).masked_fill(mask, value=-1e12)
        attn_dist = F.softmax(energy, dim=1).unsqueeze(dim=1)  # [b, 1, t]
        context_vector = torch.matmul(attn_dist, memories)  # [b, 1, d]
        return context_vector, energy

    def get_encoder_features(self, encoder_outputs):
        return self.encoder_trans(encoder_outputs)

    def forward(self, q_ids, c_ids, init_states, encoder_outputs, enc_mask):
        # q_ids: [b, t]
        # init_states: a tuple of [num_layers, b, d]
        # encoder_outputs: [b, t, d]
        batch_size, max_len = q_ids.size()
        memories = self.get_encoder_features(encoder_outputs)
        logits = []
        prev_states = init_states
        self.lstm.flatten_parameters()
        for i in range(max_len):
            y_i = q_ids[:, i].unsqueeze(dim=1)
            embedded = self.embedding(y_i)
            hidden, states = self.lstm(embedded, prev_states)
            # encoder-decoder attention
            context, energy = self.attention(hidden, memories, enc_mask)
            concat_input = torch.cat((hidden, context), dim=2).squeeze(dim=1)
            logit_input = torch.tanh(self.concat_layer(concat_input))
            logit = self.logit_layer(logit_input)  # [b, |V|]
            # maxout pointer network
            if self.pointer:
                num_oov = max(torch.max(c_ids - self.vocab_size + 1), 0)
                zeros = torch.zeros((batch_size, num_oov), device=logit.device)
                extended_logit = torch.cat((logit, zeros), dim=1)
                out = torch.zeros_like(extended_logit) - INF
                out, _ = scatter_max(energy, c_ids, out=out)
                out = out.masked_fill(out == -INF, 0)
                logit = extended_logit + out
                logit = logit.masked_fill(logit == 0, -INF)
            logits.append(logit)
            # update previous states
            prev_states = states
        logits = torch.stack(logits, dim=1)  # [b, t, |V|]
        return logits

    def decode(self, q_id, c_ids, prev_states, memories, enc_mask):
        # single greedy decoding step
        self.lstm.flatten_parameters()
        embedded = self.embedding(q_id.unsqueeze(1))
        batch_size = c_ids.size(0)
        hidden, states = self.lstm(embedded, prev_states)
        # encoder-decoder attention
        context, energy = self.attention(hidden, memories, enc_mask)
        concat_input = torch.cat((hidden, context), dim=2).squeeze(dim=1)
        logit_input = torch.tanh(self.concat_layer(concat_input))
        logit = self.logit_layer(logit_input)  # [b, |V|]
        if self.pointer:
            num_oov = max(torch.max(c_ids - self.vocab_size + 1), 0)
            zeros = torch.zeros((batch_size, num_oov), device=logit.device)
            extended_logit = torch.cat((logit, zeros), dim=1)
            out = torch.zeros_like(extended_logit) - INF
            out, _ = scatter_max(energy, c_ids, out=out)
            out = out.masked_fill(out == -INF, 0)
            logit = extended_logit + out
            logit = logit.masked_fill(logit == 0, -INF)
        return logit, states

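# Toy illustration of the maxout-pointer block above: for every source token id in
# c_ids, scatter_max keeps the maximum attention energy over all of its occurrences,
# placed into an extended vocabulary (base vocab + per-batch OOV slots).
# The tiny vocab and energies below are made-up values for illustration only.
def _maxout_pointer_demo():
    vocab_size = 5
    c_ids = torch.tensor([[2, 3, 2, 6]])                   # id 6 is OOV (>= vocab_size)
    energy = torch.tensor([[0.1, 0.5, 0.9, 0.3]])          # attention energy per source position
    num_oov = max(int(c_ids.max()) - vocab_size + 1, 0)    # -> 2 extra slots (ids 5 and 6)
    extended_logit = torch.zeros(1, vocab_size + num_oov)
    out = torch.zeros_like(extended_logit) - INF
    out, _ = scatter_max(energy, c_ids, out=out)           # max energy per token id
    out = out.masked_fill(out == -INF, 0)
    # id 2 -> 0.9 (max over two occurrences), id 3 -> 0.5, id 6 -> 0.3, others 0
    return extended_logit + out
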
class AttnQG(nn.Module):
    def __init__(self, vocab_size, embedding_size, hidden_size, vae_hidden_size, num_layers,
                 dropout, num_vars, num_classes, save_path, pointer=False):
        super(AttnQG, self).__init__()
        self.num_vars = num_vars
        self.num_classes = num_classes
        self.vae_net = CatVAE(vocab_size,
                              embedding_size,
                              vae_hidden_size,
                              num_vars,
                              num_classes)
        # optionally load a pretrained question VAE
        if save_path is not None:
            state_dict = torch.load(save_path, map_location="cpu")
            self.vae_net.load_state_dict(state_dict)
        self.encoder = Encoder(embedding_size, hidden_size, num_layers, dropout, use_tag=True)
        self.ans_encoder = AnsEncoder(embedding_size, hidden_size, num_layers, dropout)
        self.decoder = Decoder(embedding_size, vocab_size, 2 * hidden_size, 2 * hidden_size,
                               num_layers, dropout, pointer)
        self.linear_trans = nn.Linear(4 * hidden_size, 4 * hidden_size)
        self.prior_net = nn.Linear(4 * hidden_size, num_vars * num_classes)
        self.linear_h = nn.Linear(num_vars * num_classes, 2 * hidden_size)
        self.linear_c = nn.Linear(num_vars * num_classes, 2 * hidden_size)

    def forward(self, c_ids, tag_ids, q_ids, ans_ids, use_prior=False):
        sos_q_ids = q_ids[:, :-1]
        eos_q_ids = q_ids[:, 1:]
        # sample z from the (frozen) question VAE posterior
        with torch.no_grad():
            q_len = torch.sum(torch.sign(eos_q_ids), 1)
            qz_logits = self.vae_net.encoder(eos_q_ids, q_len)
            flatten_logits = qz_logits.view(-1, self.num_classes)
        # encode passage
        c_len = torch.sum(torch.sign(c_ids), 1)
        enc_outputs, states = self.encoder(c_ids, c_len, tag_ids)
        last_c_hidden = states[0][-1]
        # encode answer
        ans_len = torch.sum(torch.sign(ans_ids), 1)
        ans_states = self.ans_encoder(ans_ids, ans_len)
        last_ans_hidden = ans_states[0][-1]
        last_hidden = torch.cat([last_c_hidden, last_ans_hidden], -1)
        # compute prior p(z|c, a)
        prior_logits = self.prior_net(torch.relu(self.linear_trans(last_hidden)))
        log_pz = F.log_softmax(prior_logits.view(-1, self.num_classes), dim=-1)
        log_qz = F.log_softmax(flatten_logits, dim=-1).detach()
        if use_prior:
            probs = torch.exp(log_pz)
        else:
            # z_samples = self.vae_net.gumbel_softmax(flatten_logits, tau)
            # z_samples = z_samples.view(-1, self.num_vars * self.num_classes)
            probs = F.softmax(flatten_logits, dim=1)
        m = Categorical(probs)
        z_samples = m.sample()
        z_samples = F.one_hot(z_samples, num_classes=self.num_classes)
        z_samples = z_samples.view(-1, self.num_vars * self.num_classes).float()
        init_h = self.linear_h(z_samples.detach())
        init_c = self.linear_c(z_samples.detach())
        init_h = init_h + states[0]
        init_c = init_c + states[1]
        new_states = (init_h, init_c)
        # KL(q(z|x) || p(z|c))
        qz = torch.exp(log_qz)
        prior_kl = torch.sum(qz * (log_qz - log_pz), dim=1)
        prior_kl = prior_kl.mean()
        c_mask = (c_ids == 0).byte()
        logits = self.decoder(sos_q_ids, c_ids, new_states, enc_outputs, c_mask)
        # \hat{x} ~ p(x|c,z)
        decoded_ids = torch.argmax(logits, dim=-1)
        seq_len = self.get_seq_len(decoded_ids, EOS_ID)
        vocab_dist = F.softmax(logits, dim=-1)
        # z ~ p(z|\hat{x}): auxiliary loss to recover z from the generated question
        z_logits = self.vae_net.encoder(vocab_dist, seq_len)
        flatten_z_logits = z_logits.view(-1, self.num_classes)
        true_z = z_samples.view(-1, self.num_classes)
        true_z = torch.argmax(true_z, dim=-1)
        aux_criterion = nn.CrossEntropyLoss()
        aux_loss = aux_criterion(flatten_z_logits, true_z)
        batch_size, nsteps, _ = logits.size()
        criterion = nn.CrossEntropyLoss(ignore_index=0)
        preds = logits.view(batch_size * nsteps, -1)
        targets = eos_q_ids.contiguous().view(-1)
        nll = criterion(preds, targets)
        return nll, prior_kl, aux_loss

    def generate(self, c_ids, tag_ids, ans_ids, sos_ids, max_step):
        # encode context
        c_len = torch.sum(torch.sign(c_ids), 1)
        enc_outputs, states = self.encoder(c_ids, c_len, tag_ids)
        c_mask = (c_ids == 0).byte()
        last_c_hidden = states[0][-1]
        # encode answer
        ans_len = torch.sum(torch.sign(ans_ids), 1)
        ans_states = self.ans_encoder(ans_ids, ans_len)
        last_ans_hidden = ans_states[0][-1]
        last_hidden = torch.cat([last_c_hidden, last_ans_hidden], -1)
        # sample z from the prior distribution
        prior_logits = self.prior_net(torch.relu(self.linear_trans(last_hidden)))
        prior_logits = prior_logits.view(-1, self.num_classes)
        prior_prob = F.softmax(prior_logits, dim=-1)  # [b*num_vars, num_classes]
        m = Categorical(prior_prob)
        z_samples = m.sample()  # [b*num_vars]
        # one-hot vector
        z_samples = F.one_hot(z_samples, num_classes=self.num_classes)
        z_samples = z_samples.view(-1, self.num_vars * self.num_classes).float()
        init_h = self.linear_h(z_samples)
        init_c = self.linear_c(z_samples)
        new_h = init_h + states[0]
        new_c = init_c + states[1]
        new_states = (new_h, new_c)
        # greedy decoding
        inputs = sos_ids
        prev_states = new_states
        decoded_ids = []
        memories = self.decoder.get_encoder_features(enc_outputs)
        for _ in range(max_step):
            logit, prev_states = self.decoder.decode(inputs, c_ids,
                                                     prev_states, memories,
                                                     c_mask)
            inputs = torch.argmax(logit, 1)  # [b]
            decoded_ids.append(inputs)
        decoded_ids = torch.stack(decoded_ids, 1)
        return decoded_ids

    @staticmethod
    def get_seq_len(input_ids, eos_id):
        # input_ids: [b, t]
        # eos_id: scalar
        mask = input_ids == eos_id
        num_eos = torch.sum(mask, 1)
        # move to CPU and use np.argmax, which consistently returns the
        # first index of the maximum element (i.e., the first EOS position)
        mask = mask.cpu().numpy()
        indices = np.argmax(mask, 1)
        # convert the numpy array back to a tensor
        seq_len = torch.LongTensor(indices).to(input_ids.device)
        # in case there is no EOS in the sequence
        max_len = input_ids.size(1)
        seq_len = seq_len.masked_fill(num_eos == 0, max_len - 1)
        # +1 for EOS
        seq_len = seq_len + 1
        return seq_len

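# End-to-end usage sketch with assumed hyperparameters (bert-base-uncased vocab and
# embedding sizes, 300-d encoder, 20 categorical variables of 10 classes, no pretrained
# VAE checkpoint, pointer disabled). [CLS] = 101 is used as the start token and
# EOS_ID (= [SEP]) terminates generation; the dummy batch below is for illustration only.
def _attn_qg_demo():
    model = AttnQG(vocab_size=30522, embedding_size=768, hidden_size=300,
                   vae_hidden_size=300, num_layers=2, dropout=0.3,
                   num_vars=20, num_classes=10, save_path=None, pointer=False)
    b, c_len, q_len = 2, 32, 12
    c_ids = torch.randint(1, 30522, (b, c_len))
    tag_ids = torch.zeros_like(c_ids)
    tag_ids[:, 5:8] = 1                                  # mark a fake answer span
    ans_ids = c_ids[:, 5:8]
    q_ids = torch.randint(1, 30522, (b, q_len))
    q_ids[:, 0] = 101                                    # [CLS]
    q_ids[:, -1] = EOS_ID                                # [SEP]
    nll, prior_kl, aux_loss = model(c_ids, tag_ids, q_ids, ans_ids)
    sos_ids = torch.full((b,), 101, dtype=torch.long)
    decoded_ids = model.generate(c_ids, tag_ids, ans_ids, sos_ids, max_step=20)
    return (nll, prior_kl, aux_loss), decoded_ids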