@sergei-mironov
Created November 20, 2019 14:38
TF Official bert customization, old design
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import copy
import json
import math
import six
import tensorflow as tf
import itertools
from tensorflow.keras.losses import MeanSquaredError
from official.modeling import tf_utils
from official.nlp.bert_modeling import BertModel, Attention, Dense3D, Dense2DProjection, get_initializer
from ipdb import set_trace
from typing import Any
class CustomTransformerBlock(tf.keras.layers.Layer):
  """Wraps a stock TransformerBlock and adds a trainable single-layer bypass.

  The bypass output is mixed into the block output according to
  `bypass_output_ratio` and distilled against the original feed-forward
  output via an MSE loss.
  """

  # When True, `trainable_weights` reports the wrapped block's weights so the
  # weight lists line up with the unpatched official model.
  mimic_original_weights: bool = False

  def __init__(self, original, bypass_output_ratio, **kwargs):
    super().__init__(**kwargs)
    self.original = original
    self.bypass_output_ratio = bypass_output_ratio
    self.build([])

  def build(self, unused_input):
    tb = self.original
    self.original.build(unused_input)

    def _init():
      def _call(*args, **kwargs):
        return get_initializer(tb.initializer_range)(*args, **kwargs)
      return _call

    # Single dense layer that bypasses the original intermediate/output pair.
    self.bypass_dense = tf.keras.layers.Dense(
        tb.hidden_size, kernel_initializer=_init(), name="bypass")
    self.distill_loss = MeanSquaredError()
    super().build(unused_input)

  def __call__(self, input_tensor, attention_mask=None):
    inputs = tf_utils.pack_inputs([input_tensor, attention_mask])
    return super().__call__(inputs)
  def call(self, inputs):
    tb = self.original
    assert tb.float_type != tf.float16
    (input_tensor, attention_mask) = tf_utils.unpack_inputs(inputs)
    # Self-attention sub-block, reusing the wrapped layer's weights.
    attention_output = tb.attention_layer(
        from_tensor=input_tensor,
        to_tensor=input_tensor,
        attention_mask=attention_mask)
    attention_output = tb.attention_output_dense(attention_output)
    attention_output = tb.attention_dropout(attention_output)
    attention_output = tb.attention_layer_norm(input_tensor + attention_output)
    # Original feed-forward path and the cheap bypass path.
    intermediate_output = tb.intermediate_dense(attention_output)
    dense_output = tb.output_dense(intermediate_output)
    bypass_output = self.bypass_dense(attention_output)
    # Distill the feed-forward output into the bypass projection.
    self.add_loss(self.distill_loss(bypass_output, dense_output))
    # Mix the two paths according to the bypass ratio.
    switch = self.bypass_output_ratio
    layer_output = tb.output_dropout(
        (1.0 - switch) * dense_output + switch * bypass_output)
    layer_output = tb.output_layer_norm(layer_output + attention_output)
    return layer_output
  @property
  def trainable_weights(self):
    # While exporting weights from the stock model, report the wrapped block's
    # weights so both models expose the same weight ordering.
    if CustomTransformerBlock.mimic_original_weights:
      return self.original.trainable_weights
    else:
      return super().trainable_weights

  def distillation_mode(self, enabled: bool = True) -> None:
    if enabled:
      # Freeze the wrapped block and train only the bypass projection.
      self.original.trainable = False
      self.bypass_dense.trainable = True
      # Set the private flag directly so the frozen sub-layer stays frozen.
      self._trainable = True
    else:
      self.trainable = True
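
# Illustrative-only sketch (not part of the original gist; `stock_block`,
# `hidden_states` and `attention_mask` are placeholders): wrapping one stock
# transformer block keeps its output shape and registers a single MSE
# distillation loss per forward pass.
#
#   custom = CustomTransformerBlock(stock_block, bypass_output_ratio=0.5)
#   output = custom(hidden_states, attention_mask)  # (batch, seq_len, hidden)
#   assert len(custom.losses) == 1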

def get_bert_model_(m: Any):
  """Return the BertModel layer contained in `m`, or `m` itself."""
  if isinstance(m, BertModel):
    return m
  for l in m.layers:
    if isinstance(l, BertModel):
      return l
    else:
      print('ignoring', type(l))
  raise ValueError("Keras model doesn't contain a layer of type 'BertModel'")

def export_official_bert_weights(f: tf.keras.Model, t: tf.keras.Model):
  """Export weights from the official BERT model `f` into the custom model `t`.
  New weights that are not present in the official BERT model are left
  uninitialized."""
  try:
    CustomTransformerBlock.mimic_original_weights = True
    t.set_weights(f.get_weights())
  finally:
    CustomTransformerBlock.mimic_original_weights = False

def patch_official_bert(m: tf.keras.Model, bypass_output_ratio: float) -> None:
  """Replaces stock TransformerBlocks of `m` with CustomTransformerBlocks."""
  bert_encoder = get_bert_model_(m).encoder
  num_layers = len(bert_encoder.layers)
  new_layers = []
  for i, tb in enumerate(bert_encoder.layers):
    new_layers.append(CustomTransformerBlock(
        tb,
        bypass_output_ratio,
        name=f"custom_layer_{i}"))
  bert_encoder.layers = new_layers
  print('Patched', len(new_layers), 'layers of', num_layers)
  assert len(new_layers) > 0

def is_custom_transformer(l: Any) -> bool:
  return hasattr(l, 'original')
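
# A minimal end-to-end sketch, not part of the original gist. It assumes some
# builder `create_bert_classifier()` that returns a Keras model wrapping a
# `BertModel`; that name is purely an assumption made for illustration.
def _example_distillation_setup(bypass_output_ratio: float = 0.5):
  official = create_bert_classifier()  # pre-trained official BERT (assumed helper)
  patched = create_bert_classifier()   # structurally identical copy to patch
  patch_official_bert(patched, bypass_output_ratio)
  # Copy the pre-trained weights; the new bypass kernels stay uninitialized
  # because `mimic_original_weights` hides them during the export.
  export_official_bert_weights(official, patched)
  # Freeze the wrapped blocks and train only the bypass projections against
  # the per-layer MSE losses registered via add_loss().
  for layer in get_bert_model_(patched).encoder.layers:
    if is_custom_transformer(layer):
      layer.distillation_mode(True)
  return patched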