TF official BERT customization, old design.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy
import json
import math
import six
import tensorflow as tf
import itertools

from tensorflow.keras.losses import MeanSquaredError
from official.modeling import tf_utils
from official.nlp.bert_modeling import (BertModel, Attention, Dense3D,
                                        Dense2DProjection, get_initializer)

from typing import Any

class CustomTransformerBlock(tf.keras.layers.Layer):
  """Wraps an official-BERT transformer block and adds a trainable bypass."""

  # Class-wide flag: when True, `trainable_weights` reports the wrapped
  # block's weights so that `get_weights`/`set_weights` line up with the
  # unpatched official model.
  mimic_original_weights: bool = False

  def __init__(self, original, bypass_output_ratio, **kwargs):
    super().__init__(**kwargs)
    self.original = original
    self.bypass_output_ratio = bypass_output_ratio
    self.build([])

  def build(self, unused_input):
    tb = self.original
    tb.build(unused_input)
    # A single dense projection that bypasses the original feed-forward part.
    self.bypass_dense = tf.keras.layers.Dense(
        tb.hidden_size,
        kernel_initializer=get_initializer(tb.initializer_range),
        name="bypass")
    self.distill_loss = MeanSquaredError()
    super().build(unused_input)

  def __call__(self, input_tensor, attention_mask=None):
    inputs = tf_utils.pack_inputs([input_tensor, attention_mask])
    return super().__call__(inputs)
  def call(self, inputs):
    tb = self.original
    # Mixed precision is not supported by this block.
    assert tb.float_type != tf.float16
    (input_tensor, attention_mask) = tf_utils.unpack_inputs(inputs)
    # Self-attention sub-layer of the wrapped block.
    attention_output = tb.attention_layer(
        from_tensor=input_tensor,
        to_tensor=input_tensor,
        attention_mask=attention_mask)
    attention_output = tb.attention_output_dense(attention_output)
    attention_output = tb.attention_dropout(attention_output)
    attention_output = tb.attention_layer_norm(input_tensor + attention_output)
    # Original feed-forward path and the cheap bypass projection.
    intermediate_output = tb.intermediate_dense(attention_output)
    dense_output = tb.output_dense(intermediate_output)
    bypass_output = self.bypass_dense(attention_output)
    # Distillation loss pulls the bypass output towards the original output.
    self.add_loss(self.distill_loss(bypass_output, dense_output))
    # Blend the two paths: 0.0 keeps the original FFN, 1.0 uses only the bypass.
    switch = self.bypass_output_ratio
    layer_output = tb.output_dropout(
        (1.0 - switch) * dense_output + switch * bypass_output)
    layer_output = tb.output_layer_norm(layer_output + attention_output)
    return layer_output
  @property
  def trainable_weights(self):
    # While exporting weights from the official model, pretend to be the
    # original block so that the weight lists match.
    if CustomTransformerBlock.mimic_original_weights:
      return self.original.trainable_weights
    else:
      return super().trainable_weights

  def distillation_mode(self, enabled: bool = True) -> None:
    if enabled:
      # Freeze the wrapped block and train only the bypass projection.
      self.original.trainable = False
      self.bypass_dense.trainable = True
      self._trainable = True
    else:
      self.trainable = True
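
# Illustration (not part of the original gist): a minimal, self-contained
# sketch of the bypass mixing performed in `call` above.  The block output is
# a convex combination of the original feed-forward output and the bypass
# projection, and the MSE between the two paths is added as a distillation
# loss.  Tensor shapes below are arbitrary placeholders.
def _bypass_mixing_demo(ratio: float = 0.3):
  dense_output = tf.random.normal([2, 8, 16])   # stand-in for the FFN output
  bypass_output = tf.random.normal([2, 8, 16])  # stand-in for the bypass output
  mixed = (1.0 - ratio) * dense_output + ratio * bypass_output
  distill = MeanSquaredError()(bypass_output, dense_output)
  return mixed, distill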

def get_bert_model_(m: Any):
  """Returns the `BertModel` instance contained in `m` (or `m` itself)."""
  if isinstance(m, BertModel):
    return m
  for l in m.layers:
    if isinstance(l, BertModel):
      return l
    else:
      print('ignoring', type(l))
  raise ValueError("Keras model doesn't contain a layer of type 'BertModel'")

def export_official_bert_weights(f: tf.keras.Model, t: tf.keras.Model):
  """Copies weights from the official BERT model `f` into the custom model `t`.

  New weights that are not present in the official BERT are left uninitialized.
  """
  try:
    CustomTransformerBlock.mimic_original_weights = True
    t.set_weights(f.get_weights())
  finally:
    CustomTransformerBlock.mimic_original_weights = False

def patch_official_bert(m: tf.keras.Model, bypass_output_ratio: float) -> None:
  """Replaces the stock transformer blocks of `m` with `CustomTransformerBlock`s."""
  bert_encoder = get_bert_model_(m).encoder
  new_layers = []
  for i, tb in enumerate(bert_encoder.layers):
    new_layers.append(CustomTransformerBlock(
        tb,
        bypass_output_ratio,
        name=f"custom_layer_{i}"))
  bert_encoder.layers = new_layers
  print('Patched', len(new_layers), 'transformer layers')
  assert len(new_layers) > 0

def is_custom_transformer(l: Any) -> bool:
  """True for layers produced by `patch_official_bert`."""
  return hasattr(l, 'original')
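
# Usage sketch (not part of the original gist): `build_classifier` is a
# hypothetical factory returning a Keras model that contains a `BertModel`
# layer; substitute whatever builds your official BERT model.  The flow
# suggested by the helpers above is: build two copies, patch one, copy the
# official weights into it, then switch the patched blocks into distillation
# mode so that only the bypass projections are trained.
def distillation_setup_example(build_classifier, bypass_output_ratio: float = 0.5):
  official_model = build_classifier()   # frozen reference model
  custom_model = build_classifier()     # copy that receives the bypass blocks
  patch_official_bert(custom_model, bypass_output_ratio)
  export_official_bert_weights(official_model, custom_model)
  for layer in get_bert_model_(custom_model).encoder.layers:
    if is_custom_transformer(layer):
      layer.distillation_mode(enabled=True)
  return official_model, custom_model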