#!/usr/bin/env python
# -*- coding: utf-8 -*-
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from nmtlab.models.transformer import Transformer
from nmtlab.modules.transformer_modules import TransformerEmbedding
from nmtlab.modules.transformer_modules import TemporalMasking
from nmtlab.modules.transformer_modules import TransformerEncoderLayer
from nmtlab.modules.transformer_modules import PositionalEmbedding
from stnmt.lib_stcode_modules import STCodeDecoderLayer, STCodeEncoderLayer
from nmtlab.modules import MultiHeadAttention
from stnmt.semhash import SemanticHashing
from stnmt.gumbel import GumbelEncoder
from nmtlab.utils import OPTS
from nmtlab.utils import MapDict
class STCodeModel(Transformer):
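    """Transformer NMT model whose decoder is conditioned on discrete target codes.

    The target sentence is summarized by the ``tag_encode_layers`` stack into
    ``OPTS.codelen`` latent vectors, which are discretized with semantic hashing
    or a Gumbel-softmax encoder (depending on ``OPTS.gumbel``). The decoder layers
    receive both the source encoder states and these code states.
    """
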
    def __init__(self, num_encoders=3, num_decoders=3, ff_size=None, n_att_heads=2, dropout_ratio=0.1, **kwargs):
        self.num_encoders = num_encoders
        self.num_decoders = num_decoders
        self._ff_size = ff_size
        self._n_att_heads = n_att_heads
        self._dropout_ratio = dropout_ratio
        # super(Transformer, self) bypasses Transformer.__init__ and runs the
        # initializer of the next class in the MRO instead.
        super(Transformer, self).__init__(**kwargs)
        if ff_size is None:
            self._ff_size = self.hidden_size * 4

    def prepare(self):
        """Create all submodules of the model."""
        from nmtlab.modules.transformer_modules import LabelSmoothingKLDivLoss
        self.label_smooth = LabelSmoothingKLDivLoss(0.1, self._tgt_vocab_size, 0)
        # Discretization bottleneck: Gumbel-softmax encoder or semantic hashing
        if OPTS.gumbel:
            self.gumbel_encoder = GumbelEncoder(self.hidden_size, bits=OPTS.codebits)
        else:
            self.semhash = SemanticHashing(self.hidden_size, bits=OPTS.codebits)
        # Layer Norm
        self.encoder_norm = nn.LayerNorm(self.hidden_size)
        self.decoder_norm = nn.LayerNorm(self.hidden_size)
        self.tag_norm = nn.LayerNorm(self.hidden_size)
        # Embedding layers
        self.src_embed_layer = TransformerEmbedding(self._src_vocab_size, self.embed_size)
        self.tgt_embed_layer = TransformerEmbedding(self._tgt_vocab_size, self.embed_size)
        self.temporal_mask = TemporalMasking()
        # Encoder stacks: one for the source sentence, one for producing target codes
        self.src_encode_layers = nn.ModuleList()
        self.tag_encode_layers = nn.ModuleList()
        if OPTS.linearcode:
            self.code_linear = nn.Linear(self.hidden_size, self.hidden_size * OPTS.codelen)
        else:
            self.positional_encoding = PositionalEmbedding(self.hidden_size)
            self.code_attention = MultiHeadAttention(self.hidden_size, self._n_att_heads, dropout_ratio=self._dropout_ratio)
        for _ in range(self.num_encoders):
            layer = TransformerEncoderLayer(self.hidden_size, self._ff_size, n_att_head=self._n_att_heads,
                                            dropout_ratio=self._dropout_ratio)
            self.src_encode_layers.append(layer)
            layer = STCodeEncoderLayer(self.hidden_size, self._ff_size, n_att_head=self._n_att_heads,
                                       dropout_ratio=self._dropout_ratio)
            self.tag_encode_layers.append(layer)
        # Decoder
        self.decoder_layers = nn.ModuleList()
        for _ in range(self.num_decoders):
            layer = STCodeDecoderLayer(self.hidden_size, self._ff_size, n_att_head=self._n_att_heads,
                                       dropout_ratio=self._dropout_ratio)
            self.decoder_layers.append(layer)
        # Expander
        self.expander_nn = nn.Linear(self.hidden_size, self._tgt_vocab_size)
        # Decoding states need to be remembered for beam search
        state_names = ["embeddings"]
        for i in range(self.num_decoders):
            state_names.append("layer{}".format(i))
        self.set_states(state_names)
        self.set_stepwise_training(False)

    def encode(self, src_seq, src_mask=None):
        """Encode the source sequence with the Transformer encoder stack."""
        x = self.src_embed_layer(src_seq)
        for l, layer in enumerate(self.src_encode_layers):
            x = layer(x, src_mask)
        encoder_states = self.encoder_norm(x)
        encoder_outputs = {
            "encoder_states": encoder_states,
            "src_mask": src_mask
        }
        return encoder_outputs

    def encode_codes(self, tgt_seq, tgt_mask, encoder_states, src_mask, return_code=False):
        """Compress the target sequence into OPTS.codelen latent code vectors."""
        x = self.tgt_embed_layer(tgt_seq)
        for l, layer in enumerate(self.tag_encode_layers):
            x = layer(x, encoder_states, src_mask, tgt_mask)
        tag_states = self.tag_norm(x)
        # Semantic Hashing
        B = tag_states.shape[0]
        if OPTS.linearcode:
            avg_states = tag_states.mean(1)  # B x H
            precode_states = self.code_linear(avg_states).view(B, OPTS.codelen, self.hidden_size)
        else:
            pos_vectors = tag_states.new_zeros((B, OPTS.codelen, self.hidden_size))
            pos_vectors = self.positional_encoding(pos_vectors)
            precode_states, _ = self.code_attention(pos_vectors, tag_states, tag_states, mask=tgt_mask)
        if OPTS.nocode or (OPTS.pretrain and OPTS.trainer.epoch() <= 1):
            # Skip discretization when codes are disabled or in the early pretraining epochs
            codes_states = precode_states
        else:
            if OPTS.gumbel:
                codes_states = self.gumbel_encoder(precode_states, return_code=return_code)
                self.monitor("pmax", self.gumbel_encoder.max_gumbel_probs().mean())
            else:
                codes_states = self.semhash(precode_states, return_code=return_code)
        return codes_states
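
    # Shape sketch for encode_codes (B = batch size, T = target length, H = hidden size):
    #   tag_states:     B x T x H
    #   precode_states: B x OPTS.codelen x H   (via code_linear or code_attention)
    #   codes_states:   assumed to match precode_states when return_code is False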

    def decode_step(self, context, states, full_sequence=False):
        if full_sequence:
            # During training: full sequence mode
            x = states.feedback_embed[:, :-1]
            temporal_mask = self.temporal_mask(x)
            # print("full embed", x[1, :, :2])
            for l, layer in enumerate(self.decoder_layers):
                x = layer(context.encoder_states, context.code_states, x, context.src_mask, temporal_mask)
                # print("full {}".format(l), x[1, :, :2])
            states["final_states"] = self.decoder_norm(x)
        else:
            # During beam search: stepwise mode
            feedback_embed = self.tgt_embed_layer(states.prev_token.transpose(0, 1), start=states.t).transpose(0, 1)  # ~ (batch, size)
            # print("embed", feedback_embed[0, 1, :2])
            if states.t == 0:
                states.embeddings = feedback_embed
            else:
                states.embeddings = torch.cat([states.embeddings, feedback_embed], 0)
            x = states.embeddings.transpose(0, 1)
            for l, layer in enumerate(self.decoder_layers):
                x = layer(context.encoder_states, x, last_only=True)  # ~ (batch, 1, size)
                # Cache this layer's outputs across steps; the full history feeds the next layer
                if states.t == 0:
                    states["layer{}".format(l)] = x.transpose(0, 1)
                else:
                    old_states = states["layer{}".format(l)]
                    states["layer{}".format(l)] = torch.cat([old_states, x.transpose(0, 1)], 0)
                x = states["layer{}".format(l)].transpose(0, 1)
            states["final_states"] = self.decoder_norm(x[:, -1].unsqueeze(0))  # ~ (1, batch, size)

    def forward(self, src_seq, tgt_seq, sampling=False, return_code=False):
        """Forward to compute the loss."""
        # Sampling is always disabled here; the argument is ignored.
        sampling = False
        src_mask = torch.ne(src_seq, 0).float()
        tgt_mask = torch.ne(tgt_seq, 0).float()
        encoder_outputs = MapDict(self.encode(src_seq, src_mask))
        code_states = self.encode_codes(tgt_seq, tgt_mask, encoder_outputs.encoder_states, src_mask, return_code=return_code)
        if return_code:
            return code_states
        context, states = self.pre_decode(encoder_outputs, tgt_seq, src_mask=src_mask, tgt_mask=tgt_mask)
        context.code_states = code_states
        decoder_outputs = self.decode(context, states)
        if self._shard_size is not None and self._shard_size > 0:
            self.compute_shard_loss(decoder_outputs, tgt_seq, tgt_mask)
        else:
            logits = self.expand(decoder_outputs)
            loss = self.compute_loss(logits, tgt_seq, tgt_mask)
            acc = self.compute_word_accuracy(logits, tgt_seq, tgt_mask)
            self.monitor("loss", loss)
            self.monitor("word_acc", acc)
        if sampling:
            context, states = self.pre_decode(encoder_outputs, tgt_seq, src_mask=src_mask, tgt_mask=tgt_mask)
            sample_outputs = self.decode(context, states, sampling=True)
            self.monitor("sampled_tokens", sample_outputs.prev_token)
        return self._monitors
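
# Usage sketch: forward(src_seq, tgt_seq) computes the training loss and fills the monitor
# dict with "loss" and "word_acc", while forward(src_seq, tgt_seq, return_code=True) returns
# the latent codes produced by encode_codes instead. The OPTS flags referenced above
# (gumbel, codebits, linearcode, codelen, nocode, pretrain) are assumed to be set by the
# surrounding training script before the model is constructed.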