#!/usr/bin/env python
# -*- coding: utf-8 -*-
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from nmtlab.models.transformer import Transformer
from nmtlab.modules.transformer_modules import TransformerEmbedding
from nmtlab.modules.transformer_modules import PositionalEmbedding
from nmtlab.modules.transformer_modules import LabelSmoothingKLDivLoss
from nmtlab.utils import OPTS
from nmtlab.utils import MapDict, TensorMap
from mcgen.lib_namt_modules import TransformerCrossEncoder, TransformerEncoder
from mcgen.lib_namt_modules import LengthConverter
from mcgen.lib_rupdate_model import LatentUpdateModel
from mcgen.lib_padding import pad_z, pad_z_with_delta
# !!! Deprecated
class NAMTUnifiedModel(Transformer):
    def __init__(self, enc_layers=3, dec_layers=3, **kwargs):
        self.enc_layers = enc_layers
        self.dec_layers = dec_layers
        self.training_criteria = "tok+kl"
        assert OPTS.bottleneck == "vae"
        if OPTS.rupdate:
            self.training_criteria = "loss"
        super(NAMTUnifiedModel, self).__init__(**kwargs)

    def prepare(self):
        # Shared embedding layer
        max_size = max(self._src_vocab_size, self._tgt_vocab_size)
        self.x_embed_layer = TransformerEmbedding(max_size, self.embed_size)
        self.y_embed_layer = TransformerEmbedding(self._tgt_vocab_size, self.embed_size)
        self.pos_embed_layer = PositionalEmbedding(self.hidden_size)
        # Length transform
        self.length_converter = LengthConverter()
        self.length_embed_layer = nn.Embedding(500, self.hidden_size)
        # Encoder and decoder
        if OPTS.zdim > 0:
            z_size = OPTS.zdim
            self.postz_nn = nn.Linear(z_size, self.hidden_size)
        else:
            z_size = self.hidden_size
        self.latent_size = z_size
        self.xz_encoders = nn.ModuleList()
        self.xz_softmax = nn.ModuleList()
        encoder = TransformerEncoder(self.x_embed_layer, self.hidden_size, self.enc_layers)
        self.xz_encoders.append(encoder)
        xz_predictor = nn.Linear(self.hidden_size, z_size * 2)
        self.xz_softmax.append(xz_predictor)
        self.y_encoder = TransformerEncoder(self.y_embed_layer, self.hidden_size, self.enc_layers)
        self.yz_encoder = TransformerCrossEncoder(None, self.hidden_size, self.enc_layers)
        self.y_decoder = TransformerCrossEncoder(None, self.hidden_size, self.dec_layers, skip_connect=True)
        # Discretization
        if OPTS.bottleneck == "vae":
            from mcgen.lib_vae import VAEBottleneck
            self.bottleneck = VAEBottleneck(self.hidden_size, z_size=z_size)
        else:
            raise NotImplementedError
        # Length prediction
        self.length_dense = nn.Linear(self.hidden_size, 100)
        # Expander
        self.expander_nn = nn.Linear(self.hidden_size, self._tgt_vocab_size)
        self.label_smooth = LabelSmoothingKLDivLoss(0.1, self._tgt_vocab_size, 0)
        # Latent update model
        if OPTS.rupdate:
            # Freeze the base model and only train the latent update network.
            for param in self.parameters():
                param.requires_grad = False
            self.rupdate_nn = LatentUpdateModel(self.hidden_size, z_size)
        self.set_stepwise_training(False)

    def encode_y(self, x_states, x_mask, y, y_mask):
        y_states = self.y_encoder(y, y_mask)
        states = self.yz_encoder(x_states, x_mask, y_states, y_mask)
        return states

    def sample_Q(self, states, sampling=True, prior=None):
        """Return z and p(z|y,x)."""
        extra = {}
        if OPTS.bottleneck == "vae":
            if OPTS.bindpq:
                residual_q = prior
            else:
                residual_q = None
            quantized_vector, code_prob = self.bottleneck(states, sampling=sampling, residual_q=residual_q)
            quantized_vector = self.postz_nn(quantized_vector)
        else:
            raise NotImplementedError
        return quantized_vector, code_prob, extra
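
    # --- Illustrative sketch (not in the original gist): what the imported
    # VAEBottleneck is assumed to do here, based on how its outputs are consumed
    # in sample_Q and compute_vae_KL. A dense layer is assumed to map hidden
    # states to (mu, pre-softplus std); with sampling=True it draws z via the
    # reparameterization trick, otherwise it returns the mean. The names
    # _vae_bottleneck_sketch and dense_layer are hypothetical.
    @staticmethod
    def _vae_bottleneck_sketch(dense_layer, states, sampling=True):
        prob = dense_layer(states)                 # (B, T, 2 * z_size)
        z_size = prob.size(-1) // 2
        mu = prob[:, :, :z_size]
        stddev = F.softplus(prob[:, :, z_size:])
        if sampling:
            z = mu + stddev * torch.randn_like(stddev)   # reparameterization trick
        else:
            z = mu
        return z, prob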

    def compute_length_pred_loss(self, xz_states, z, z_mask, y_mask):
        y_lens = y_mask.sum(1) - 1
        # Length gap between target and latent sequences, shifted by +50 and
        # clamped so it maps onto the 100 classes of self.length_dense.
        delta = (y_lens - z_mask.sum(1) + 50.).long().clamp(0, 99)
        mean_z = ((z + xz_states) * z_mask[:, :, None]).sum(1) / z_mask.sum(1)[:, None]
        logits = self.length_dense(mean_z)
        length_loss = F.cross_entropy(logits, delta, reduction="mean")
        length_acc = ((logits.argmax(-1) == delta).float()).mean()
        length_monitors = {
            "lenloss": length_loss,
            "lenacc": length_acc
        }
        return length_monitors

    def compute_vae_KL(self, xz_prob, yz_prob):
        # Both tensors hold (mu, pre-softplus std); the softplus outputs are
        # treated as standard deviations in the closed-form Gaussian KL below.
        mu1 = yz_prob[:, :, :self.latent_size]
        var1 = F.softplus(yz_prob[:, :, self.latent_size:])
        mu2 = xz_prob[:, :, :self.latent_size]
        var2 = F.softplus(xz_prob[:, :, self.latent_size:])
        kl = torch.log(var2 / (var1 + 1e-8) + 1e-8) + (
            (torch.pow(var1, 2) + torch.pow(mu1 - mu2, 2)) / (2 * torch.pow(var2, 2))) - 0.5
        kl = kl.sum(-1)
        return kl
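
    # --- Illustrative sketch (not in the original gist): compute_vae_KL above is
    # the closed-form KL between two diagonal Gaussians, with the softplus outputs
    # acting as standard deviations. This hypothetical helper cross-checks the
    # formula against torch.distributions on explicit (mu, std) tensors.
    @staticmethod
    def _gaussian_kl_check_sketch(mu_q, std_q, mu_p, std_p):
        from torch.distributions import Normal, kl_divergence
        # KL( N(mu_q, std_q^2) || N(mu_p, std_p^2) ), summed over the latent dimension.
        return kl_divergence(Normal(mu_q, std_q), Normal(mu_p, std_p)).sum(-1)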

    def compute_final_loss(self, yz_prob, xz_prob, x_mask, score_map):
        """Register the KL divergence and bottleneck losses."""
        if not OPTS.withkl:
            yz_prob = yz_prob.detach()
        if OPTS.bottleneck == "vae":
            kl = self.compute_vae_KL(xz_prob, yz_prob)
        else:
            raise NotImplementedError
        if OPTS.klbudget:
            budget = float(OPTS.budgetn) / 100.
            if OPTS.annealkl and not OPTS.klft and not OPTS.origft:
                # Anneal the budget linearly towards 0.1 over the second half of training.
                step = OPTS.trainer.global_step()
                half_maxsteps = float(OPTS.maxsteps / 2)
                if step > half_maxsteps:
                    rate = (float(step) - half_maxsteps) / half_maxsteps
                    min_budget = 0.1
                    budget = min_budget + (budget - min_budget) * (1. - rate)
                score_map["budget"] = torch.tensor(budget)
            # Free-bits style clamp: positions whose KL is below the budget contribute the budget instead.
            max_mask = ((kl - budget) > 0.).float()
            kl = kl * max_mask + (1. - max_mask) * budget
        if OPTS.sumloss:
            kl_loss = (kl * x_mask).sum() / x_mask.shape[0]
            score_map["wkl"] = (kl * x_mask).sum() / x_mask.sum()
        else:
            kl_loss = (kl * x_mask).sum() / x_mask.sum()
        score_map["kl"] = kl_loss
        # Combine all losses
        score_map["tokloss"] = score_map["loss"]
        score_map["tok+kl"] = score_map["loss"] + kl_loss
        if OPTS.withkl:
            klweight = float(OPTS.klweight) / 100
            shard_loss = score_map["kl"].clone() * klweight
        else:
            shard_loss = score_map["kl"].clone()
        if "neckloss" in score_map:
            shard_loss += score_map["neckloss"]
        if "lenloss" in score_map:
            if OPTS.sumloss:
                shard_loss += score_map["lenloss"] * float(OPTS.lenweight)
            else:
                shard_loss += score_map["lenloss"] * 0.1
        score_map["shard_loss"] = shard_loss
        score_map["loss"] = shard_loss + score_map["tokloss"]
        return score_map
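
    # --- Illustrative sketch (not in the original gist): the KL-budget schedule
    # used in compute_final_loss, isolated as a plain function of the training
    # step. The budget stays constant for the first half of training, is then
    # annealed linearly towards min_budget, and per-position KL values below the
    # current budget are replaced by the budget (a free-bits style clamp).
    # _kl_budget_sketch and its parameter names are hypothetical.
    @staticmethod
    def _kl_budget_sketch(kl, step, max_steps, budget, min_budget=0.1):
        half = float(max_steps) / 2.
        if step > half:
            rate = (float(step) - half) / half
            budget = min_budget + (budget - min_budget) * (1. - rate)
        clamp_mask = ((kl - budget) > 0.).float()
        clamped_kl = kl * clamp_mask + (1. - clamp_mask) * budget
        return clamped_kl, budget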

    def forward(self, x, y, sampling=False, return_code=False):
        """Forward to compute the loss."""
        score_map = {}
        x_mask = torch.ne(x, 0).float()
        y_mask = torch.ne(y, 0).float()
        # Compute p(z|x)
        xz_states = self.xz_encoders[0](x, x_mask)
        full_xz_states = xz_states
        xz_prob = self.xz_softmax[0](xz_states)
        # Compute p(z|y,x) and sample z
        yz_states = self.encode_y(self.x_embed_layer(x), x_mask, y, y_mask)
        # Create latents
        z_mask = x_mask
        z, yz_prob, bottleneck_scores = self.sample_Q(yz_states, prior=xz_prob)
        score_map.update(bottleneck_scores)
        # Compute the length prediction loss
        length_scores = self.compute_length_pred_loss(xz_states, z, z_mask, y_mask)
        score_map.update(length_scores)
        z_expand, z_expand_mask = z, z_mask
        tgt_states_mask = y_mask
        # Pad z to fit the target states
        z_pad, _ = pad_z(self, z_expand, z_expand_mask, tgt_states_mask)
        if OPTS.tanhz:
            z_pad = torch.tanh(z_pad)
        # -------------------------- Decoder -------------------------
        decoder_states = self.y_decoder(z_pad, y_mask, full_xz_states, x_mask)
        # Compute loss
        decoder_outputs = TensorMap({"final_states": decoder_states})
        if OPTS.sumloss:
            denom = x.shape[0]
        else:
            denom = None
        if self._shard_size is not None and self._shard_size > 0:
            loss_scores, decoder_tensors, decoder_grads = self.compute_shard_loss(
                decoder_outputs, y, y_mask, denominator=denom, ignore_first_token=False, backward=False
            )
            loss_scores["word_acc"] *= float(y_mask.shape[0]) / y_mask.sum().float()
            score_map.update(loss_scores)
        else:
            logits = self.expand(decoder_outputs)
            loss = self.compute_loss(logits, y, y_mask, denominator=denom, ignore_first_token=False)
            acc = self.compute_word_accuracy(logits, y, y_mask, ignore_first_token=False)
            score_map["loss"] = loss
            score_map["word_acc"] = acc
        score_map = self.compute_final_loss(yz_prob, xz_prob, z_mask, score_map)
        # Backward for the shard loss
        if self._shard_size is not None and self._shard_size > 0 and decoder_tensors is not None:
            decoder_tensors.append(score_map["shard_loss"])
            decoder_grads.append(None)
            torch.autograd.backward(decoder_tensors, decoder_grads)
            del score_map["shard_loss"]
        return score_map

    def sample_z(self, z_prob):
        """Return the quantized vector given the probability distribution over z."""
        if OPTS.bottleneck == "vae":
            # Deterministic: take the predicted Gaussian mean.
            quantized_vector = z_prob[:, :, :self.latent_size]
            quantized_vector = self.postz_nn(quantized_vector)
        return quantized_vector

    def predict_length(self, xz_states, z, z_mask):
        mean_z = ((z + xz_states) * z_mask[:, :, None]).sum(1) / z_mask.sum(1)[:, None]
        logits = self.length_dense(mean_z)
        # Undo the +50 class offset used in compute_length_pred_loss.
        delta = logits.argmax(-1) - 50
        return delta
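
    # --- Illustrative sketch (not in the original gist): the length-difference
    # encoding shared by compute_length_pred_loss and predict_length. The gap
    # between the target length and the latent (source) length is shifted by +50
    # and clamped so it fits the 100 classes of self.length_dense; prediction
    # recovers the signed difference by subtracting 50 again. Inputs are assumed
    # to be 1-D length tensors; the helper name is hypothetical.
    @staticmethod
    def _length_delta_roundtrip_sketch(y_lens, z_lens):
        delta_class = (y_lens.float() - z_lens.float() + 50.).long().clamp(0, 99)
        decoded_delta = delta_class - 50
        return delta_class, decoded_delta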

    def translate(self, x, y=None, q=None, xz_states=None):
        """Testing code."""
        x_mask = torch.ne(x, 0).float()
        # Compute p(z|x)
        if xz_states is None:
            xz_states = self.xz_encoders[0](x, x_mask)
        # Sample a z
        if q is not None:
            # z is provided
            z = q
        elif y is not None:
            # y is provided
            y_mask = torch.ne(y, 0).float()
            x_embeds = self.x_embed_layer(x)
            yz_states = self.encode_y(x_embeds, x_mask, y, y_mask)
            _, yz_prob, bottleneck_scores = self.sample_Q(yz_states)
            z = self.sample_z(yz_prob)
        else:
            # Compute the prior to get z
            xz_prob = self.xz_softmax[0](xz_states)
            z = self.sample_z(xz_prob)
        # Predict length
        if y is None or True:
            # NOTE: the "or True" forces length prediction even when y is given.
            length_delta = self.predict_length(xz_states, z, x_mask)
        else:
            length_delta = (y_mask.sum(1) - 1 - x_mask.sum(1)).long()
        # Pad z to cover the length of y
        z_pad, z_pad_mask, y_lens = pad_z_with_delta(self, z, x_mask, length_delta + 1)
        if z_pad.size(1) == 0:
            # NOTE: xz_prob is only defined on the prior path (neither q nor y given).
            return None, y_lens, xz_prob.argmax(-1)
        # Run the decoder to predict the target words
        decoder_states = self.y_decoder(z_pad, z_pad_mask, xz_states, x_mask)
        # Get the predictions
        logits = self.expander_nn(decoder_states)
        pred = logits.argmax(-1)
        return pred, y_lens, z, xz_states

    def compute_Q(self, x, y):
        """Compute the approximate posterior q(z|x, y) without sampling."""
        x_mask = torch.ne(x, 0).float()
        y_mask = torch.ne(y, 0).float()
        # Compute p(z|y,x) and take the deterministic z
        yz_states = self.encode_y(self.x_embed_layer(x), x_mask, y, y_mask)
        z, yz_prob, _ = self.sample_Q(yz_states, sampling=False)
        return z, yz_prob

    def load_state_dict(self, state_dict):
        """Remove deep generative model weights before loading."""
        keys = list(state_dict.keys())
        for k in keys:
            if "xz_encoders.1" in k or "xz_encoders.2" in k or "xz_softmax.1" in k or "xz_softmax.2" in k:
                del state_dict[k]
        if OPTS.rupdate:
            strict = False
        else:
            strict = True
        super(NAMTUnifiedModel, self).load_state_dict(state_dict, strict=strict)

    def measure_ELBO(self, x, y):
        """Measure the ELBO at inference time."""
        x_mask = torch.ne(x, 0).float()
        y_mask = torch.ne(y, 0).float()
        # Compute p(z|x)
        xz_states = self.xz_encoders[0](x, x_mask)
        xz_prob = self.xz_softmax[0](xz_states)
        # Compute p(z|y,x) and sample z
        yz_states = self.encode_y(self.x_embed_layer(x), x_mask, y, y_mask)
        # Sample 20 times to estimate the reconstruction term
        likelihood_list = []
        for _ in range(20):
            z, yz_prob, bottleneck_scores = self.sample_Q(yz_states)
            z_pad, _ = pad_z(self, z, x_mask, y_mask)
            if OPTS.tanhz:
                z_pad = torch.tanh(z_pad)
            decoder_states = self.y_decoder(z_pad, y_mask, xz_states, x_mask)
            logits = self.expander_nn(decoder_states)
            # Token log-likelihood of the first sentence in the batch
            likelihood = - F.cross_entropy(logits[0], y[0], reduction="sum")
            likelihood_list.append(likelihood)
        kl = self.compute_vae_KL(xz_prob, yz_prob).sum()
        mean_likelihood = sum(likelihood_list) / len(likelihood_list)
        elbo = mean_likelihood - kl
        return elbo
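

# --- Illustrative sketch (not the imported mcgen LengthConverter, whose code is
# not part of this gist): one simple way to map source-length latent states onto
# a target length, by position-wise linear interpolation along the time axis.
# The actual LengthConverter used by pad_z / pad_z_with_delta may differ; the
# function name below is hypothetical.
def _length_transform_sketch(z, tgt_len):
    # z: (batch, src_len, hidden) -> (batch, tgt_len, hidden)
    return F.interpolate(z.transpose(1, 2), size=tgt_len, mode="linear",
                         align_corners=False).transpose(1, 2)

# Example: _length_transform_sketch(torch.randn(2, 7, 256), 9) has shape (2, 9, 256).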