TF Official BERT customization, new design. Each Transformer block of the official BERT encoder is wrapped in a CustomTransformer that adds a trainable dense "bypass" path, so the heavy intermediate/output feed-forward sub-block can be distilled into a single dense layer and blended in via `bypass_output_ratio`.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy
import itertools
import json
import math
from typing import Any

import six
import tensorflow as tf
from tensorflow.keras.losses import MeanSquaredError
from tensorflow.python.keras.engine import network  # pylint: disable=g-direct-tensorflow-import

from official.modeling import activations
from official.modeling import tf_utils
from official.nlp import bert_modeling
from official.nlp.bert_modeling import get_initializer
from official.nlp.modeling import layers
from official.nlp.modeling.networks.bert_classifier import BertClassifier
from official.nlp.modeling.networks.transformer_encoder import TransformerEncoder, MakeAttentionMaskLayer

from ipdb import set_trace  # debugging helper, unused at import time
class CustomTransformer(tf.keras.layers.Layer):
  """Wraps an original `layers.Transformer` block and adds a single dense
  "bypass" path that can gradually replace the intermediate/output
  feed-forward sub-block."""

  # When True, `trainable_weights` reports only the wrapped layer's weights,
  # keeping the wrapper checkpoint-compatible with the original model.
  mimic_original_weights: bool = False

  def __init__(self, original, bypass_output_ratio, **kwargs):
    super().__init__(**kwargs)
    self.original = original
    self.bypass_output_ratio = bypass_output_ratio

  def build(self, input_shape):
    input_tensor = input_shape[0] if len(input_shape) == 2 else input_shape
    input_tensor_shape = tf.TensorShape(input_tensor)
    if len(input_tensor_shape) != 3:
      raise ValueError("CustomTransformer expects a three-dimensional input of "
                       "shape [batch, sequence, width].")
    batch_size, sequence_length, hidden_size = input_tensor_shape
    self.original.build(input_shape)
    # A single dense layer that bypasses the original intermediate/output
    # feed-forward sub-block; it reuses the wrapped layer's kernel initializer.
    self.bypass_dense = tf.keras.layers.Dense(
        hidden_size,
        kernel_initializer=self.original._kernel_initializer,
        name="bypass")
    self.distill_loss = MeanSquaredError()
    super().build(input_shape)
  def call(self, inputs):
    orig = self.original
    if isinstance(inputs, (list, tuple)) and len(inputs) == 2:
      input_tensor, attention_mask = inputs
    else:
      input_tensor, attention_mask = (inputs, None)
    attention_inputs = [input_tensor, input_tensor]
    if attention_mask is not None:
      attention_inputs.append(attention_mask)
    # Self-attention sub-block, reusing the wrapped layer's weights.
    attention_output = orig._attention_layer(attention_inputs)
    attention_output = orig._attention_output_dense(attention_output)
    attention_output = orig._attention_dropout(attention_output)
    # Use float32 in keras layer norm and the gelu activation in the
    # intermediate dense layer for numeric stability
    attention_output = orig._attention_layer_norm(input_tensor +
                                                  attention_output)
    # Original feed-forward path and the one-layer bypass path.
    intermediate_output = orig._intermediate_dense(attention_output)
    dense_output = orig._output_dense(intermediate_output)
    bypass_output = self.bypass_dense(attention_output)
    # FIXME!! self.add_loss(self.distill_loss(bypass_output, dense_output))
    # Linearly blend the two paths: ratio 0.0 keeps the original feed-forward
    # output, ratio 1.0 uses only the bypass output.
    switch = self.bypass_output_ratio
    layer_output = orig._output_dropout(
        (1.0 - switch) * dense_output + switch * bypass_output)
    layer_output = orig._output_layer_norm(layer_output + attention_output)
    return layer_output
  @property
  def trainable_weights(self):
    if CustomTransformer.mimic_original_weights:
      return self.original.trainable_weights
    else:
      return super().trainable_weights

  def distillation_mode(self, enabled: bool = True) -> None:
    """Freeze the wrapped Transformer and train only the bypass path."""
    if enabled:
      self.original.trainable = False
      self.bypass_dense.trainable = True
      self._trainable = True
    else:
      # Re-enable training for the whole layer, including the wrapped block.
      self.trainable = True
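# --- Usage sketch (illustrative addition, not part of the original gist) ----
# Shows how a CustomTransformer wraps one official Transformer block, is called
# with an attention mask built the same way the encoder below builds it, and
# how distillation_mode() freezes everything except the bypass dense layer.
# All shapes and hyper-parameters here are arbitrary assumptions for the demo.
def _demo_custom_transformer():
  base = layers.Transformer(
      num_attention_heads=4,
      intermediate_size=128,
      intermediate_activation=activations.gelu,
      name='transformer/demo')
  layer = CustomTransformer(base, bypass_output_ratio=0.5)
  data = tf.random.uniform([2, 16, 64])    # [batch, sequence, hidden]
  mask = tf.ones([2, 16], dtype=tf.int32)  # [batch, sequence]
  attention_mask = MakeAttentionMaskLayer()([data, mask])
  out = layer([data, attention_mask])
  layer.distillation_mode(True)  # freeze `base`, train only the bypass path
  return out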
class CustomTransformerEncoder(network.Network):
  """Bi-directional Transformer-based encoder network that mirrors the
  official TransformerEncoder, but stacks CustomTransformer-wrapped layers."""

  def __init__(self,
               vocab_size,
               bypass_output_ratio,
               hidden_size=768,
               num_layers=12,
               num_attention_heads=12,
               sequence_length=512,
               max_sequence_length=None,
               type_vocab_size=16,
               intermediate_size=3072,
               activation=activations.gelu,
               dropout_rate=0.1,
               attention_dropout_rate=0.1,
               initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02),
               float_dtype='float32',
               **kwargs):
    activation = tf.keras.activations.get(activation)
    initializer = tf.keras.initializers.get(initializer)

    if not max_sequence_length:
      max_sequence_length = sequence_length
    self._self_setattr_tracking = False
    self._config_dict = {
        'vocab_size': vocab_size,
        'hidden_size': hidden_size,
        'num_layers': num_layers,
        'num_attention_heads': num_attention_heads,
        'sequence_length': sequence_length,
        'max_sequence_length': max_sequence_length,
        'type_vocab_size': type_vocab_size,
        'intermediate_size': intermediate_size,
        'activation': tf.keras.activations.serialize(activation),
        'dropout_rate': dropout_rate,
        'attention_dropout_rate': attention_dropout_rate,
        'initializer': tf.keras.initializers.serialize(initializer),
        'float_dtype': float_dtype,
    }
    word_ids = tf.keras.layers.Input(
        shape=(sequence_length,), dtype=tf.int32, name='input_word_ids')
    mask = tf.keras.layers.Input(
        shape=(sequence_length,), dtype=tf.int32, name='input_mask')
    type_ids = tf.keras.layers.Input(
        shape=(sequence_length,), dtype=tf.int32, name='input_type_ids')
    self._embedding_layer = layers.OnDeviceEmbedding(
        vocab_size=vocab_size,
        embedding_width=hidden_size,
        initializer=initializer,
        name='word_embeddings')
    word_embeddings = self._embedding_layer(word_ids)
    # Always uses dynamic slicing for simplicity.
    self._position_embedding_layer = layers.PositionEmbedding(
        initializer=initializer,
        use_dynamic_slicing=True,
        max_sequence_length=max_sequence_length)
    position_embeddings = self._position_embedding_layer(word_embeddings)
    type_embeddings = (
        layers.OnDeviceEmbedding(
            vocab_size=type_vocab_size,
            embedding_width=hidden_size,
            initializer=initializer,
            use_one_hot=True,
            name='type_embeddings')(type_ids))
    embeddings = tf.keras.layers.Add()(
        [word_embeddings, position_embeddings, type_embeddings])
    embeddings = (
        tf.keras.layers.LayerNormalization(
            name='embeddings/layer_norm',
            axis=-1,
            epsilon=1e-12,
            dtype=tf.float32)(embeddings))
    embeddings = (
        tf.keras.layers.Dropout(rate=dropout_rate,
                                dtype=tf.float32)(embeddings))
    if float_dtype == 'float16':
      embeddings = tf.cast(embeddings, tf.float16)

    data = embeddings
    attention_mask = MakeAttentionMaskLayer()([data, mask])
    # Stack num_layers Transformer blocks, each wrapped in a CustomTransformer
    # so that every block carries its own trainable bypass path.
    for i in range(num_layers):
      layer = CustomTransformer(
          layers.Transformer(
              num_attention_heads=num_attention_heads,
              intermediate_size=intermediate_size,
              intermediate_activation=activation,
              dropout_rate=dropout_rate,
              attention_dropout_rate=attention_dropout_rate,
              kernel_initializer=initializer,
              dtype=float_dtype,
              name='transformer/layer_%d' % i),
          bypass_output_ratio)
      data = layer([data, attention_mask])

    first_token_tensor = (
        tf.keras.layers.Lambda(lambda x: tf.squeeze(x[:, 0:1, :], axis=1))(data))
    cls_output = tf.keras.layers.Dense(
        units=hidden_size,
        activation='tanh',
        kernel_initializer=initializer,
        name='pooler_transform')(first_token_tensor)

    super().__init__(
        inputs=[word_ids, mask, type_ids],
        outputs=[data, cls_output],
        **kwargs)
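# --- Usage sketch (illustrative addition) ------------------------------------
# Builds a tiny CustomTransformerEncoder and runs it on dummy inputs as a
# smoke test; all sizes here are arbitrary assumptions, chosen small so the
# demo runs quickly.
def _demo_encoder():
  encoder = CustomTransformerEncoder(
      vocab_size=100,
      bypass_output_ratio=0.5,
      hidden_size=64,
      num_layers=2,
      num_attention_heads=4,
      sequence_length=16,
      intermediate_size=128)
  word_ids = tf.zeros([2, 16], dtype=tf.int32)
  mask = tf.ones([2, 16], dtype=tf.int32)
  type_ids = tf.zeros([2, 16], dtype=tf.int32)
  sequence_output, cls_output = encoder([word_ids, mask, type_ids])
  # sequence_output: [2, 16, 64]; cls_output: [2, 64]
  return sequence_output, cls_output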
def custom_classifier_model(bert_config, float_type, num_labels, max_seq_length,
                            bypass_output_ratio):
  """Builds a BertClassifier on top of a CustomTransformerEncoder.

  Returns a (classifier model, encoder) pair so callers can reach the encoder
  directly, e.g. to toggle distillation mode on its layers.
  """
  initializer = tf.keras.initializers.TruncatedNormal(
      stddev=bert_config.initializer_range)
  bert_encoder = CustomTransformerEncoder(
      vocab_size=bert_config.vocab_size,
      bypass_output_ratio=bypass_output_ratio,
      hidden_size=bert_config.hidden_size,
      num_layers=bert_config.num_hidden_layers,
      num_attention_heads=bert_config.num_attention_heads,
      intermediate_size=bert_config.intermediate_size,
      activation=tf_utils.get_activation('gelu'),
      dropout_rate=bert_config.hidden_dropout_prob,
      attention_dropout_rate=bert_config.attention_probs_dropout_prob,
      sequence_length=max_seq_length,
      max_sequence_length=bert_config.max_position_embeddings,
      type_vocab_size=bert_config.type_vocab_size,
      initializer=initializer,
      # Pass the dtype as a string to match the constructor's default and its
      # `float_dtype == 'float16'` comparison.
      float_dtype='float32')
  return BertClassifier(
      bert_encoder,
      num_classes=num_labels,
      dropout_rate=bert_config.hidden_dropout_prob,
      initializer=initializer), bert_encoder
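# --- Usage sketch (illustrative addition) ------------------------------------
# Builds the classifier from a standard BertConfig and switches every wrapped
# layer into distillation mode. `bert_config_file` is a hypothetical path;
# num_labels, max_seq_length, and bypass_output_ratio are arbitrary choices.
def _demo_classifier(bert_config_file='bert_config.json'):
  bert_config = bert_modeling.BertConfig.from_json_file(bert_config_file)
  classifier, encoder = custom_classifier_model(
      bert_config,
      float_type=tf.float32,
      num_labels=2,
      max_seq_length=128,
      bypass_output_ratio=0.1)
  # Freeze the original weights and train only the bypass paths.
  for layer in encoder.layers:
    if isinstance(layer, CustomTransformer):
      layer.distillation_mode(True)
  return classifier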