#!/usr/bin/env python
# -*- coding: utf-8 -*-
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math
import torch
import torch.nn as nn
import torch.nn.functional as F

from nmtlab.models.transformer import Transformer
from nmtlab.modules.transformer_modules import TransformerEmbedding
from nmtlab.modules.transformer_modules import PositionalEmbedding
from nmtlab.modules.transformer_modules import LabelSmoothingKLDivLoss
from nmtlab.utils import OPTS
from nmtlab.utils import MapDict, TensorMap

from mcgen.lib_namt_modules import TransformerCrossEncoder, TransformerEncoder
from mcgen.lib_namt_modules import LengthConverter
from mcgen.lib_rupdate_model import LatentUpdateModel
from mcgen.lib_padding import pad_z, pad_z_with_delta

# !!! Deprecated


class NAMTUnifiedModel(Transformer):

    def __init__(self, enc_layers=3, dec_layers=3, **kwargs):
        self.enc_layers = enc_layers
        self.dec_layers = dec_layers
        self.training_criteria = "tok+kl"
        assert OPTS.bottleneck == "vae"
        if OPTS.rupdate:
            self.training_criteria = "loss"
        super(NAMTUnifiedModel, self).__init__(**kwargs)

    def prepare(self):
        # Shared embedding layer
        max_size = self._src_vocab_size if self._src_vocab_size > self._tgt_vocab_size else self._tgt_vocab_size
        self.x_embed_layer = TransformerEmbedding(max_size, self.embed_size)
        self.y_embed_layer = TransformerEmbedding(self._tgt_vocab_size, self.embed_size)
        self.pos_embed_layer = PositionalEmbedding(self.hidden_size)
        # Length transform
        self.length_converter = LengthConverter()
        self.length_embed_layer = nn.Embedding(500, self.hidden_size)
        # Encoder and decoder
        if OPTS.zdim > 0:
            z_size = OPTS.zdim
            self.postz_nn = nn.Linear(z_size, self.hidden_size)
        else:
            z_size = self.hidden_size
        self.latent_size = z_size
        self.xz_encoders = nn.ModuleList()
        self.xz_softmax = nn.ModuleList()
        encoder = TransformerEncoder(self.x_embed_layer, self.hidden_size, self.enc_layers)
        self.xz_encoders.append(encoder)
        xz_predictor = nn.Linear(self.hidden_size, z_size * 2)
        self.xz_softmax.append(xz_predictor)
        self.y_encoder = TransformerEncoder(self.y_embed_layer, self.hidden_size, self.enc_layers)
        self.yz_encoder = TransformerCrossEncoder(None, self.hidden_size, self.enc_layers)
        self.y_decoder = TransformerCrossEncoder(None, self.hidden_size, self.dec_layers, skip_connect=True)
        # Discretization
        if OPTS.bottleneck == "vae":
            from mcgen.lib_vae import VAEBottleneck
            self.bottleneck = VAEBottleneck(self.hidden_size, z_size=z_size)
        else:
            raise NotImplementedError
        # Length prediction
        self.length_dense = nn.Linear(self.hidden_size, 100)
        # Expander
        self.expander_nn = nn.Linear(self.hidden_size, self._tgt_vocab_size)
        self.label_smooth = LabelSmoothingKLDivLoss(0.1, self._tgt_vocab_size, 0)
        # Latent update model
        if OPTS.rupdate:
            for param in self.parameters():
                param.requires_grad = False
            self.rupdate_nn = LatentUpdateModel(self.hidden_size, z_size)
        self.set_stepwise_training(False)

    def encode_y(self, x_states, x_mask, y, y_mask):
        y_states = self.y_encoder(y, y_mask)
        states = self.yz_encoder(x_states, x_mask, y_states, y_mask)
        return states

    def sample_Q(self, states, sampling=True, prior=None):
        """Return z and p(z|y,x)."""
        extra = {}
        if OPTS.bottleneck == "vae":
            if OPTS.bindpq:
                residual_q = prior
            else:
                residual_q = None
            quantized_vector, code_prob = self.bottleneck(states, sampling=sampling, residual_q=residual_q)
            quantized_vector = self.postz_nn(quantized_vector)
        else:
            raise NotImplementedError
        return quantized_vector, code_prob, extra

    def compute_length_pred_loss(self, xz_states, z, z_mask, y_mask):
        y_lens = y_mask.sum(1) - 1
        delta = (y_lens - z_mask.sum(1) + 50.).long().clamp(0, 99)
        mean_z = ((z + xz_states) * z_mask[:, :, None]).sum(1) / z_mask.sum(1)[:, None]
        logits = self.length_dense(mean_z)
        length_loss = F.cross_entropy(logits, delta, reduction="mean")
        length_acc = ((logits.argmax(-1) == delta).float()).mean()
        length_monitors = {
            "lenloss": length_loss,
            "lenacc": length_acc
        }
        return length_monitors

    def compute_vae_KL(self, xz_prob, yz_prob):
        mu1 = yz_prob[:, :, :self.latent_size]
        var1 = F.softplus(yz_prob[:, :, self.latent_size:])
        mu2 = xz_prob[:, :, :self.latent_size]
        var2 = F.softplus(xz_prob[:, :, self.latent_size:])
        kl = torch.log(var2 / (var1 + 1e-8) + 1e-8) + (
            (torch.pow(var1, 2) + torch.pow(mu1 - mu2, 2)) / (2 * torch.pow(var2, 2))) - 0.5
        kl = kl.sum(-1)
        return kl
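
    # Illustrative sketch: compute_vae_KL above appears to treat each half of
    # the *_prob tensors as the mean and the softplus-activated standard
    # deviation of a diagonal Gaussian, in which case the expression matches
    # the closed-form KL(N(mu1, var1^2) || N(mu2, var2^2)). The hypothetical
    # helper below checks that reading against torch.distributions on random
    # inputs; it is not used anywhere in the model.
    @staticmethod
    def _sketch_check_diag_gaussian_kl():
        from torch.distributions import Normal, kl_divergence
        torch.manual_seed(0)
        mu1, mu2 = torch.randn(4, 8), torch.randn(4, 8)
        std1, std2 = F.softplus(torch.randn(4, 8)), F.softplus(torch.randn(4, 8))
        # Same expression as compute_vae_KL (without the 1e-8 stabilizers)
        manual = (torch.log(std2 / std1)
                  + (std1 ** 2 + (mu1 - mu2) ** 2) / (2 * std2 ** 2) - 0.5).sum(-1)
        reference = kl_divergence(Normal(mu1, std1), Normal(mu2, std2)).sum(-1)
        assert torch.allclose(manual, reference, atol=1e-5)
        return manual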

    def compute_final_loss(self, yz_prob, xz_prob, x_mask, score_map):
        """Register the KL divergence and bottleneck losses."""
        if not OPTS.withkl:
            yz_prob = yz_prob.detach()
        if OPTS.bottleneck == "vae":
            kl = self.compute_vae_KL(xz_prob, yz_prob)
        else:
            raise NotImplementedError
        if OPTS.klbudget:
            budget = float(OPTS.budgetn) / 100.
            if OPTS.annealkl and not OPTS.klft and not OPTS.origft:
                step = OPTS.trainer.global_step()
                half_maxsteps = float(OPTS.maxsteps / 2)
                if step > half_maxsteps:
                    rate = (float(step) - half_maxsteps) / half_maxsteps
                    min_budget = 0.1
                    budget = min_budget + (budget - min_budget) * (1. - rate)
            score_map["budget"] = torch.tensor(budget)
            max_mask = ((kl - budget) > 0.).float()
            kl = kl * max_mask + (1. - max_mask) * budget
        if OPTS.sumloss:
            kl_loss = (kl * x_mask).sum() / x_mask.shape[0]
            score_map["wkl"] = (kl * x_mask).sum() / x_mask.sum()
        else:
            kl_loss = (kl * x_mask).sum() / x_mask.sum()
        score_map["kl"] = kl_loss
        # Combine all losses
        score_map["tokloss"] = score_map["loss"]
        score_map["tok+kl"] = score_map["loss"] + kl_loss
        if OPTS.withkl:
            klweight = float(OPTS.klweight) / 100
            shard_loss = score_map["kl"].clone() * klweight
        else:
            shard_loss = score_map["kl"].clone()
        if "neckloss" in score_map:
            shard_loss += score_map["neckloss"]
        if "lenloss" in score_map:
            if OPTS.sumloss:
                shard_loss += score_map["lenloss"] * float(OPTS.lenweight)
            else:
                shard_loss += score_map["lenloss"] * 0.1
        score_map["shard_loss"] = shard_loss
        score_map["loss"] = shard_loss + score_map["tokloss"]
        return score_map
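
    # Illustrative sketch: the OPTS.klbudget branch in compute_final_loss
    # floors the per-position KL at the budget (a free-bits style clamp), and
    # the budget itself is annealed from budgetn/100 towards 0.1 over the
    # second half of training when OPTS.annealkl is set. The hypothetical
    # helper below shows that the mask arithmetic is an elementwise maximum.
    @staticmethod
    def _sketch_kl_budget_clamp(kl, budget):
        max_mask = ((kl - budget) > 0.).float()
        masked = kl * max_mask + (1. - max_mask) * budget
        assert torch.allclose(masked, torch.clamp(kl, min=budget))
        return masked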

    def forward(self, x, y, sampling=False, return_code=False):
        """Forward pass that computes the training loss."""
        score_map = {}
        x_mask = torch.ne(x, 0).float()
        y_mask = torch.ne(y, 0).float()
        # Compute p(z|x)
        xz_states = self.xz_encoders[0](x, x_mask)
        full_xz_states = xz_states
        xz_prob = self.xz_softmax[0](xz_states)
        # Compute p(z|y,x) and sample z
        yz_states = self.encode_y(self.x_embed_layer(x), x_mask, y, y_mask)
        # Create latents
        z_mask = x_mask
        z, yz_prob, bottleneck_scores = self.sample_Q(yz_states, prior=xz_prob)
        score_map.update(bottleneck_scores)
        # Compute the length prediction loss
        length_scores = self.compute_length_pred_loss(xz_states, z, z_mask, y_mask)
        score_map.update(length_scores)
        z_expand, z_expand_mask = z, z_mask
        tgt_states_mask = y_mask
        # Pad z to fit the target states
        z_pad, _ = pad_z(self, z_expand, z_expand_mask, tgt_states_mask)
        if OPTS.tanhz:
            z_pad = F.tanh(z_pad)
        # -------------------------- Decoder -------------------------
        decoder_states = self.y_decoder(z_pad, y_mask, full_xz_states, x_mask)
        # Compute loss
        decoder_outputs = TensorMap({"final_states": decoder_states})
        if OPTS.sumloss:
            denom = x.shape[0]
        else:
            denom = None
        if self._shard_size is not None and self._shard_size > 0:
            loss_scores, decoder_tensors, decoder_grads = self.compute_shard_loss(
                decoder_outputs, y, y_mask, denominator=denom, ignore_first_token=False, backward=False
            )
            loss_scores["word_acc"] *= float(y_mask.shape[0]) / y_mask.sum().float()
            score_map.update(loss_scores)
        else:
            logits = self.expand(decoder_outputs)
            loss = self.compute_loss(logits, y, y_mask, denominator=denom, ignore_first_token=False)
            acc = self.compute_word_accuracy(logits, y, y_mask, ignore_first_token=False)
            score_map["loss"] = loss
            score_map["word_acc"] = acc
        score_map = self.compute_final_loss(yz_prob, xz_prob, z_mask, score_map)
        # Backward for the shard loss
        if self._shard_size is not None and self._shard_size > 0 and decoder_tensors is not None:
            decoder_tensors.append(score_map["shard_loss"])
            decoder_grads.append(None)
            torch.autograd.backward(decoder_tensors, decoder_grads)
        del score_map["shard_loss"]
        return score_map
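
    # Illustrative sketch of how the score_map returned by forward() might be
    # consumed in a hand-rolled training step. This assumes sharding is
    # disabled (self._shard_size is None) and that `optimizer`, `x` and `y`
    # (padded LongTensor batches with 0 as the padding id) are supplied by the
    # caller; nmtlab's own trainer normally drives this through
    # self.training_criteria instead.
    def _sketch_training_step(self, optimizer, x, y):
        optimizer.zero_grad()
        score_map = self(x, y)
        loss = score_map[self.training_criteria]  # "tok+kl" unless OPTS.rupdate
        loss.backward()
        optimizer.step()
        return {k: float(v) for k, v in score_map.items()}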

    def sample_z(self, z_prob):
        """Return the quantized vector given the probability distribution over z."""
        if OPTS.bottleneck == "vae":
            quantized_vector = z_prob[:, :, :self.latent_size]
            quantized_vector = self.postz_nn(quantized_vector)
        return quantized_vector

    def predict_length(self, xz_states, z, z_mask):
        mean_z = ((z + xz_states) * z_mask[:, :, None]).sum(1) / z_mask.sum(1)[:, None]
        logits = self.length_dense(mean_z)
        delta = logits.argmax(-1) - 50
        return delta
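
    # Illustrative sketch: compute_length_pred_loss and predict_length share a
    # length-difference encoding in which the delta between the target length
    # (y_mask.sum(1) - 1 during training) and the latent length is shifted by
    # +50 and clamped into the 100 classes of length_dense, then shifted back
    # by -50 at prediction time. The hypothetical helper below restates that
    # round trip on a single example.
    @staticmethod
    def _sketch_length_delta_round_trip(y_len=12, z_len=9):
        # Encode: offset the difference by +50 into the class range [0, 99]
        delta_class = int(torch.clamp(torch.tensor(y_len - z_len + 50), 0, 99))
        # Decode: shift the predicted class back by -50
        recovered_delta = delta_class - 50
        assert recovered_delta == y_len - z_len
        return delta_class, recovered_delta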

    def translate(self, x, y=None, q=None, xz_states=None):
        """Inference code."""
        x_mask = torch.ne(x, 0).float()
        # Compute p(z|x)
        if xz_states is None:
            xz_states = self.xz_encoders[0](x, x_mask)
        # Sample a z
        if q is not None:
            # z is provided
            z = q
        elif y is not None:
            # y is provided
            y_mask = torch.ne(y, 0).float()
            x_embeds = self.x_embed_layer(x)
            yz_states = self.encode_y(x_embeds, x_mask, y, y_mask)
            _, yz_prob, bottleneck_scores = self.sample_Q(yz_states)
            z = self.sample_z(yz_prob)
        else:
            # Compute the prior to get z
            xz_prob = self.xz_softmax[0](xz_states)
            z = self.sample_z(xz_prob)
        # Predict the length
        if y is None or True:  # the y-based length computation below is disabled
            length_delta = self.predict_length(xz_states, z, x_mask)
        else:
            length_delta = (y_mask.sum(1) - 1 - x_mask.sum(1)).long()
        # Pad z to cover the length of y
        z_pad, z_pad_mask, y_lens = pad_z_with_delta(self, z, x_mask, length_delta + 1)
        if z_pad.size(1) == 0:
            # NOTE: xz_prob only exists when z was drawn from the prior branch above
            return None, y_lens, xz_prob.argmax(-1)
        # Run the decoder to predict the target words
        decoder_states = self.y_decoder(z_pad, z_pad_mask, xz_states, x_mask)
        # Get the predictions
        logits = self.expander_nn(decoder_states)
        pred = logits.argmax(-1)
        return pred, y_lens, z, xz_states
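
    # Illustrative sketch of how translate() might be called at test time,
    # assuming `x` is a padded LongTensor batch of source token ids with 0 as
    # the padding id. Note that translate() returns a 4-tuple
    # (pred, y_lens, z, xz_states) on the normal path and a 3-tuple with
    # pred=None when the padded latent is empty.
    def _sketch_greedy_translate(self, x):
        self.eval()
        with torch.no_grad():
            outputs = self.translate(x)
        pred = outputs[0]  # (batch, max_target_len) argmax token ids, or None
        return pred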

    def compute_Q(self, x, y):
        """Compute z and the posterior q(z|x, y) without sampling."""
        x_mask = torch.ne(x, 0).float()
        y_mask = torch.ne(y, 0).float()
        # Compute p(z|y,x) and sample z
        yz_states = self.encode_y(self.x_embed_layer(x), x_mask, y, y_mask)
        z, yz_prob, _ = self.sample_Q(yz_states, sampling=False)
        return z, yz_prob

    def load_state_dict(self, state_dict):
        """Remove deep generative model weights."""
        keys = list(state_dict.keys())
        for k in keys:
            if "xz_encoders.1" in k or "xz_encoders.2" in k or "xz_softmax.1" in k or "xz_softmax.2" in k:
                del state_dict[k]
        if OPTS.rupdate:
            strict = False
        else:
            strict = True
        super(NAMTUnifiedModel, self).load_state_dict(state_dict, strict=strict)

    def measure_ELBO(self, x, y):
        """Estimate the ELBO at inference time."""
        x_mask = torch.ne(x, 0).float()
        y_mask = torch.ne(y, 0).float()
        # Compute p(z|x)
        xz_states = self.xz_encoders[0](x, x_mask)
        xz_prob = self.xz_softmax[0](xz_states)
        # Compute p(z|y,x) and sample z
        yz_states = self.encode_y(self.x_embed_layer(x), x_mask, y, y_mask)
        # Sample z 20 times and average the likelihood
        likelihood_list = []
        for _ in range(20):
            z, yz_prob, bottleneck_scores = self.sample_Q(yz_states)
            z_pad, _ = pad_z(self, z, x_mask, y_mask)
            if OPTS.tanhz:
                z_pad = F.tanh(z_pad)
            decoder_states = self.y_decoder(z_pad, y_mask, xz_states, x_mask)
            logits = self.expander_nn(decoder_states)
            likelihood = - F.cross_entropy(logits[0], y[0], reduction="sum")
            likelihood_list.append(likelihood)
        kl = self.compute_vae_KL(xz_prob, yz_prob).sum()
        mean_likelihood = sum(likelihood_list) / len(likelihood_list)
        elbo = mean_likelihood - kl
        return elbo
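

# Illustrative sketch: measure_ELBO above estimates, for the first sentence in
# the batch, roughly
#     ELBO ~= (1/20) * sum_s log p(y | z_s) - KL(q(z | x, y) || p(z | x)),
# with the z_s sampled from the approximate posterior and the KL term computed
# once outside the loop. The hypothetical wrapper below averages that estimate
# over an iterator of (x, y) batches supplied by the caller.
def sketch_average_elbo(model, batches):
    model.eval()
    with torch.no_grad():
        elbos = [float(model.measure_ELBO(x, y)) for x, y in batches]
    return sum(elbos) / max(len(elbos), 1)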