#!/usr/bin/env python
# -*- coding: utf-8 -*-
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from nmtlab.models.transformer import Transformer
from nmtlab.modules.transformer_modules import TransformerEmbedding
from nmtlab.modules.transformer_modules import TemporalMasking
from nmtlab.modules.transformer_modules import TransformerEncoderLayer
from nmtlab.modules.transformer_modules import PositionalEmbedding
from stnmt.lib_stcode_modules import STCodeDecoderLayer, STCodeEncoderLayer
from nmtlab.modules import MultiHeadAttention
from stnmt.semhash import SemanticHashing
from stnmt.gumbel import GumbelEncoder
from nmtlab.utils import OPTS
from nmtlab.utils import MapDict
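
# STCodeModel: a Transformer NMT model with a discrete code bottleneck.
# The target sentence is compressed into OPTS.codelen code vectors
# (semantic hashing by default, or Gumbel-softmax when OPTS.gumbel is set),
# and the decoder is conditioned on these codes together with the source
# encoder states (see encode_codes() and decode_step() below).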
class STCodeModel(Transformer):

    def __init__(self, num_encoders=3, num_decoders=3, ff_size=None, n_att_heads=2, dropout_ratio=0.1, **kwargs):
        self.num_encoders = num_encoders
        self.num_decoders = num_decoders
        self._ff_size = ff_size
        self._n_att_heads = n_att_heads
        self._dropout_ratio = dropout_ratio
        # Note: super(Transformer, self) skips Transformer.__init__ and runs the
        # constructor of Transformer's parent class instead.
        super(Transformer, self).__init__(**kwargs)
        if ff_size is None:
            self._ff_size = self.hidden_size * 4

    def prepare(self):
        from nmtlab.modules.transformer_modules import LabelSmoothingKLDivLoss
        self.label_smooth = LabelSmoothingKLDivLoss(0.1, self._tgt_vocab_size, 0)
        # Discretization bottleneck: Gumbel-softmax encoder or semantic hashing
        if OPTS.gumbel:
            self.gumbel_encoder = GumbelEncoder(self.hidden_size, bits=OPTS.codebits)
        else:
            self.semhash = SemanticHashing(self.hidden_size, bits=OPTS.codebits)
        # Layer Norm
        self.encoder_norm = nn.LayerNorm(self.hidden_size)
        self.decoder_norm = nn.LayerNorm(self.hidden_size)
        self.tag_norm = nn.LayerNorm(self.hidden_size)
        # Source and target embedding layers
        self.src_embed_layer = TransformerEmbedding(self._src_vocab_size, self.embed_size)
        self.tgt_embed_layer = TransformerEmbedding(self._tgt_vocab_size, self.embed_size)
        self.temporal_mask = TemporalMasking()
        # Encoder
        self.src_encode_layers = nn.ModuleList()
        self.tag_encode_layers = nn.ModuleList()
        if OPTS.linearcode:
            self.code_linear = nn.Linear(self.hidden_size, self.hidden_size * OPTS.codelen)
        else:
            self.positional_encoding = PositionalEmbedding(self.hidden_size)
            self.code_attention = MultiHeadAttention(self.hidden_size, self._n_att_heads, dropout_ratio=self._dropout_ratio)
        for _ in range(self.num_encoders):
            layer = TransformerEncoderLayer(self.hidden_size, self._ff_size, n_att_head=self._n_att_heads,
                                            dropout_ratio=self._dropout_ratio)
            self.src_encode_layers.append(layer)
            layer = STCodeEncoderLayer(self.hidden_size, self._ff_size, n_att_head=self._n_att_heads,
                                       dropout_ratio=self._dropout_ratio)
            self.tag_encode_layers.append(layer)
        # Decoder
        self.decoder_layers = nn.ModuleList()
        for _ in range(self.num_decoders):
            layer = STCodeDecoderLayer(self.hidden_size, self._ff_size, n_att_head=self._n_att_heads,
                                       dropout_ratio=self._dropout_ratio)
            self.decoder_layers.append(layer)
        # Expander
        self.expander_nn = nn.Linear(self.hidden_size, self._tgt_vocab_size)
        # Decoding states need to be remembered for beam search
        state_names = ["embeddings"]
        for i in range(self.num_decoders):
            state_names.append("layer{}".format(i))
        self.set_states(state_names)
        self.set_stepwise_training(False)

    def encode(self, src_seq, src_mask=None):
        x = self.src_embed_layer(src_seq)
        for l, layer in enumerate(self.src_encode_layers):
            x = layer(x, src_mask)
        encoder_states = self.encoder_norm(x)
        encoder_outputs = {
            "encoder_states": encoder_states,
            "src_mask": src_mask
        }
        return encoder_outputs

    def encode_codes(self, tgt_seq, tgt_mask, encoder_states, src_mask, return_code=False):
        x = self.tgt_embed_layer(tgt_seq)
        for l, layer in enumerate(self.tag_encode_layers):
            x = layer(x, encoder_states, src_mask, tgt_mask)
        tag_states = self.tag_norm(x)
        # Semantic Hashing
        B = tag_states.shape[0]
        if OPTS.linearcode:
            # Mean-pool the target states, then project them to codelen vectors
            avg_states = tag_states.mean(1)  # B x H
            precode_states = self.code_linear(avg_states).view(B, OPTS.codelen, self.hidden_size)
        else:
            # Attention pooling: codelen positional queries attend over the target states
            pos_vectors = tag_states.new_zeros((B, OPTS.codelen, self.hidden_size))
            pos_vectors = self.positional_encoding(pos_vectors)
            precode_states, _ = self.code_attention(pos_vectors, tag_states, tag_states, mask=tgt_mask)
        if OPTS.nocode or (OPTS.pretrain and OPTS.trainer.epoch() <= 1):
            # Pass continuous states through when coding is disabled
            # or during the first pretraining epoch
            codes_states = precode_states
        else:
            if OPTS.gumbel:
                codes_states = self.gumbel_encoder(precode_states, return_code=return_code)
                self.monitor("pmax", self.gumbel_encoder.max_gumbel_probs().mean())
            else:
                codes_states = self.semhash(precode_states, return_code=return_code)
        return codes_states
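
    # The code states produced by encode_codes() are attached to the decoding
    # context in forward() (context.code_states) and are passed to every
    # STCodeDecoderLayer in the full-sequence path of decode_step(), so the
    # decoder is conditioned on the discrete codes as well as the encoder states.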

    def decode_step(self, context, states, full_sequence=False):
        if full_sequence:
            # During training: full sequence mode
            x = states.feedback_embed[:, :-1]
            temporal_mask = self.temporal_mask(x)
            # print("full embed", x[1, :, :2])
            for l, layer in enumerate(self.decoder_layers):
                x = layer(context.encoder_states, context.code_states, x, context.src_mask, temporal_mask)
                # print("full {}".format(l), x[1, :, :2])
            states["final_states"] = self.decoder_norm(x)
        else:
            # During beam search: stepwise mode
            feedback_embed = self.tgt_embed_layer(states.prev_token.transpose(0, 1), start=states.t).transpose(0, 1)  # ~ (batch, size)
            # print("embed", feedback_embed[0, 1, :2])
            if states.t == 0:
                states.embeddings = feedback_embed
            else:
                states.embeddings = torch.cat([states.embeddings, feedback_embed], 0)
            x = states.embeddings.transpose(0, 1)
            for l, layer in enumerate(self.decoder_layers):
                x = layer(context.encoder_states, x, last_only=True)  # ~ (batch, 1, size)
                if states.t == 0:
                    states["layer{}".format(l)] = x.transpose(0, 1)
                else:
                    old_states = states["layer{}".format(l)]
                    states["layer{}".format(l)] = torch.cat([old_states, x.transpose(0, 1)], 0)
                x = states["layer{}".format(l)].transpose(0, 1)
            states["final_states"] = self.decoder_norm(x[:, -1].unsqueeze(0))  # ~ (1, batch, size)

    def forward(self, src_seq, tgt_seq, sampling=False, return_code=False):
        """Forward to compute the loss."""
        sampling = False  # the sampling branch below is disabled
        src_mask = torch.ne(src_seq, 0).float()
        tgt_mask = torch.ne(tgt_seq, 0).float()
        encoder_outputs = MapDict(self.encode(src_seq, src_mask))
        code_states = self.encode_codes(tgt_seq, tgt_mask, encoder_outputs.encoder_states, src_mask, return_code=return_code)
        if return_code:
            return code_states
        context, states = self.pre_decode(encoder_outputs, tgt_seq, src_mask=src_mask, tgt_mask=tgt_mask)
        context.code_states = code_states
        decoder_outputs = self.decode(context, states)
        if self._shard_size is not None and self._shard_size > 0:
            self.compute_shard_loss(decoder_outputs, tgt_seq, tgt_mask)
        else:
            logits = self.expand(decoder_outputs)
            loss = self.compute_loss(logits, tgt_seq, tgt_mask)
            acc = self.compute_word_accuracy(logits, tgt_seq, tgt_mask)
            self.monitor("loss", loss)
            self.monitor("word_acc", acc)
        if sampling:
            context, states = self.pre_decode(encoder_outputs, tgt_seq, src_mask=src_mask, tgt_mask=tgt_mask)
            sample_outputs = self.decode(context, states, sampling=True)
            self.monitor("sampled_tokens", sample_outputs.prev_token)
        return self._monitors
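

# ---------------------------------------------------------------------------
# Illustration only: a minimal, self-contained Gumbel-softmax bottleneck in
# plain PyTorch, sketching the kind of discretization that the imported
# GumbelEncoder / SemanticHashing modules apply to the pre-code states above.
# It is an assumption-laden stand-in for those modules, not the stnmt code.
# ---------------------------------------------------------------------------
class _GumbelBottleneckSketch(nn.Module):

    def __init__(self, hidden_size, bits):
        super(_GumbelBottleneckSketch, self).__init__()
        # Each pre-code vector is mapped to 2**bits logits and re-embedded.
        self.to_logits = nn.Linear(hidden_size, 2 ** bits)
        self.codebook = nn.Embedding(2 ** bits, hidden_size)

    def forward(self, precode_states, return_code=False, tau=1.0):
        # precode_states: (batch, codelen, hidden_size)
        logits = self.to_logits(precode_states)
        if self.training:
            # Straight-through Gumbel-softmax: discrete forward pass, soft gradients.
            onehot = F.gumbel_softmax(logits, tau=tau, hard=True)
        else:
            onehot = F.one_hot(logits.argmax(dim=-1), logits.shape[-1]).float()
        if return_code:
            return onehot.argmax(dim=-1)  # (batch, codelen) integer codes
        return onehot @ self.codebook.weight  # (batch, codelen, hidden_size)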