TF Official BERT customization, new design
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
import copy
import json
import math
import six
import itertools
from tensorflow.python.keras.engine import network # pylint: disable=g-direct-tensorflow-import
from official.modeling import activations
from official.nlp import bert_modeling
from official.nlp.modeling import layers
from official.nlp.modeling.networks.transformer_encoder import TransformerEncoder, MakeAttentionMaskLayer
from official.nlp.modeling.networks.bert_classifier import BertClassifier
from official.modeling import tf_utils
from official.nlp.bert_modeling import get_initializer
from ipdb import set_trace
from typing import Any
from tensorflow.keras.losses import MeanSquaredError


class CustomTransformer(tf.keras.layers.Layer):
  """Wraps a stock `layers.Transformer` block and blends its feed-forward
  output with a single-layer `bypass_dense` projection, weighted by
  `bypass_output_ratio`."""

  mimic_original_weights: bool = False

  def __init__(self, original, bypass_output_ratio, **kwargs):
    super().__init__(**kwargs)
    self.original = original
    self.bypass_output_ratio = bypass_output_ratio

  def build(self, input_shape):
    input_tensor = input_shape[0] if len(input_shape) == 2 else input_shape
    input_tensor_shape = tf.TensorShape(input_tensor)
    if len(input_tensor_shape) != 3:
      raise ValueError("CustomTransformer expects a three-dimensional input of "
                       "shape [batch, sequence, width].")
    batch_size, sequence_length, hidden_size = input_tensor_shape
    tb = self.original
    self.original.build(input_shape)
    # Single dense layer that bypasses the original intermediate/output block.
    self.bypass_dense = tf.keras.layers.Dense(
        hidden_size, kernel_initializer=tb._kernel_initializer, name="bypass")
    self.distill_loss = MeanSquaredError()
    super().build(input_shape)

  def call(self, inputs):
    orig = self.original
    if isinstance(inputs, (list, tuple)) and len(inputs) == 2:
      input_tensor, attention_mask = inputs
    else:
      input_tensor, attention_mask = (inputs, None)
    attention_inputs = [input_tensor, input_tensor]
    if attention_mask is not None:
      attention_inputs.append(attention_mask)
    attention_output = orig._attention_layer(attention_inputs)
    attention_output = orig._attention_output_dense(attention_output)
    attention_output = orig._attention_dropout(attention_output)
    # Use float32 in keras layer norm and the gelu activation in the
    # intermediate dense layer for numeric stability.
    attention_output = orig._attention_layer_norm(input_tensor +
                                                  attention_output)
    intermediate_output = orig._intermediate_dense(attention_output)
    dense_output = orig._output_dense(intermediate_output)
    bypass_output = self.bypass_dense(attention_output)
    # FIXME!! self.add_loss(self.distill_loss(bypass_output, dense_output))
    # Convex mix of the original feed-forward output and the bypass output.
    switch = self.bypass_output_ratio
    layer_output = orig._output_dropout(
        (1.0 - switch) * dense_output + switch * bypass_output)
    layer_output = orig._output_layer_norm(layer_output + attention_output)
    return layer_output

  @property
  def trainable_weights(self):
    if CustomTransformer.mimic_original_weights:
      return self.original.trainable_weights
    else:
      return super().trainable_weights

  def distillation_mode(self, enabled: bool = True) -> None:
    # Freeze the wrapped transformer and train only the bypass projection.
    if enabled:
      self.original.trainable = False
      self.bypass_dense.trainable = True
      self._trainable = True
    else:
      self.trainable = True


class CustomTransformerEncoder(network.Network):  # {{{

  def __init__(self,
               vocab_size,
               bypass_output_ratio,
               hidden_size=768,
               num_layers=12,
               num_attention_heads=12,
               sequence_length=512,
               max_sequence_length=None,
               type_vocab_size=16,
               intermediate_size=3072,
               activation=activations.gelu,
               dropout_rate=0.1,
               attention_dropout_rate=0.1,
               initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02),
               float_dtype='float32',
               **kwargs):
    activation = tf.keras.activations.get(activation)
    initializer = tf.keras.initializers.get(initializer)

    if not max_sequence_length:
      max_sequence_length = sequence_length

    self._self_setattr_tracking = False
    self._config_dict = {
        'vocab_size': vocab_size,
        'hidden_size': hidden_size,
        'num_layers': num_layers,
        'num_attention_heads': num_attention_heads,
        'sequence_length': sequence_length,
        'max_sequence_length': max_sequence_length,
        'type_vocab_size': type_vocab_size,
        'intermediate_size': intermediate_size,
        'activation': tf.keras.activations.serialize(activation),
        'dropout_rate': dropout_rate,
        'attention_dropout_rate': attention_dropout_rate,
        'initializer': tf.keras.initializers.serialize(initializer),
        'float_dtype': float_dtype,
    }

    word_ids = tf.keras.layers.Input(
        shape=(sequence_length,), dtype=tf.int32, name='input_word_ids')
    mask = tf.keras.layers.Input(
        shape=(sequence_length,), dtype=tf.int32, name='input_mask')
    type_ids = tf.keras.layers.Input(
        shape=(sequence_length,), dtype=tf.int32, name='input_type_ids')

    self._embedding_layer = layers.OnDeviceEmbedding(
        vocab_size=vocab_size,
        embedding_width=hidden_size,
        initializer=initializer,
        name='word_embeddings')
    word_embeddings = self._embedding_layer(word_ids)

    # Always uses dynamic slicing for simplicity.
    self._position_embedding_layer = layers.PositionEmbedding(
        initializer=initializer,
        use_dynamic_slicing=True,
        max_sequence_length=max_sequence_length)
    position_embeddings = self._position_embedding_layer(word_embeddings)

    type_embeddings = (
        layers.OnDeviceEmbedding(
            vocab_size=type_vocab_size,
            embedding_width=hidden_size,
            initializer=initializer,
            use_one_hot=True,
            name='type_embeddings')(type_ids))

    embeddings = tf.keras.layers.Add()(
        [word_embeddings, position_embeddings, type_embeddings])
    embeddings = (
        tf.keras.layers.LayerNormalization(
            name='embeddings/layer_norm',
            axis=-1,
            epsilon=1e-12,
            dtype=tf.float32)(embeddings))
    embeddings = (
        tf.keras.layers.Dropout(rate=dropout_rate,
                                dtype=tf.float32)(embeddings))

    if float_dtype == 'float16':
      embeddings = tf.cast(embeddings, tf.float16)

    data = embeddings
    attention_mask = MakeAttentionMaskLayer()([data, mask])
    for i in range(num_layers):
      layer = CustomTransformer(
          layers.Transformer(
              num_attention_heads=num_attention_heads,
              intermediate_size=intermediate_size,
              intermediate_activation=activation,
              dropout_rate=dropout_rate,
              attention_dropout_rate=attention_dropout_rate,
              kernel_initializer=initializer,
              dtype=float_dtype,
              name='transformer/layer_%d' % i),
          bypass_output_ratio)
      data = layer([data, attention_mask])

    first_token_tensor = (
        tf.keras.layers.Lambda(lambda x: tf.squeeze(x[:, 0:1, :], axis=1))(data))
    cls_output = tf.keras.layers.Dense(
        units=hidden_size,
        activation='tanh',
        kernel_initializer=initializer,
        name='pooler_transform')(first_token_tensor)

    super().__init__(
        inputs=[word_ids, mask, type_ids],
        outputs=[data, cls_output],
        **kwargs)  # }}}


def custom_classifier_model(bert_config, float_type, num_labels,
                            max_seq_length, bypass_output_ratio):
  # Note: float_type is accepted for interface compatibility, but the encoder
  # below is constructed with float_dtype=tf.float32.
  initializer = tf.keras.initializers.TruncatedNormal(
      stddev=bert_config.initializer_range)
  bert_encoder = CustomTransformerEncoder(
      vocab_size=bert_config.vocab_size,
      bypass_output_ratio=bypass_output_ratio,
      hidden_size=bert_config.hidden_size,
      num_layers=bert_config.num_hidden_layers,
      num_attention_heads=bert_config.num_attention_heads,
      intermediate_size=bert_config.intermediate_size,
      activation=tf_utils.get_activation('gelu'),
      dropout_rate=bert_config.hidden_dropout_prob,
      attention_dropout_rate=bert_config.attention_probs_dropout_prob,
      sequence_length=max_seq_length,
      max_sequence_length=bert_config.max_position_embeddings,
      type_vocab_size=bert_config.type_vocab_size,
      initializer=tf.keras.initializers.TruncatedNormal(
          stddev=bert_config.initializer_range),
      float_dtype=tf.float32)
  return BertClassifier(
      bert_encoder,
      num_classes=num_labels,
      dropout_rate=bert_config.hidden_dropout_prob,
      initializer=initializer), bert_encoder
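

# Usage sketch: build the classifier on top of the bypass-enabled encoder,
# switch every wrapped transformer block into distillation mode (so only the
# bypass projections stay trainable), and compile. The hyper-parameters below
# (vocab size, num_labels, sequence length, bypass ratio, learning rate) are
# placeholder assumptions, not values from the original setup, and the loss
# assumes the classifier emits logits.
if __name__ == '__main__':
  demo_config = bert_modeling.BertConfig(vocab_size=30522)
  classifier, encoder = custom_classifier_model(
      demo_config, tf.float32, num_labels=2, max_seq_length=128,
      bypass_output_ratio=0.5)
  # Freeze the original transformer weights; train only the bypass layers.
  for enc_layer in encoder.layers:
    if isinstance(enc_layer, CustomTransformer):
      enc_layer.distillation_mode(enabled=True)
  classifier.compile(
      optimizer=tf.keras.optimizers.Adam(learning_rate=2e-5),
      loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
      metrics=['accuracy'])
  classifier.summary()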