Keras BERT layer
import tensorflow as tf
import tensorflow_hub as hub


class BertLayer(tf.keras.layers.Layer):
    def __init__(self, bert_path, seq_len=64, n_tune_layers=3,
                 pooling="cls", verbose=False,
                 tune_embeddings=False, **kwargs):
        self.n_tune_layers = n_tune_layers
        self.tune_embeddings = tune_embeddings
        self.seq_len = seq_len
        self.trainable = True
        self.verbose = verbose
        self.pooling = pooling
        self.bert_path = bert_path

        # Each BERT encoder block exposes 16 variables in the TF-Hub module.
        self.var_per_encoder = 16

        if self.pooling not in ["cls", "mean", None]:
            raise NameError(
                f"Undefined pooling type (must be 'cls', 'mean', or None, but is {self.pooling})"
            )

        super(BertLayer, self).__init__(**kwargs)

    def build(self, input_shape):
        self.bert = hub.Module(self.bert_path, trainable=self.trainable,
                               name=f"{self.name}_module")

        # Collect name prefixes of the variable scopes we want to fine-tune.
        trainable_layers = []
        if self.tune_embeddings:
            trainable_layers.append("embeddings")

        if self.pooling == "cls":
            trainable_layers.append("pooler")

        if self.n_tune_layers > 0:
            encoder_var_names = [var.name for var in self.bert.variables
                                 if "encoder" in var.name]
            n_encoder_layers = int(len(encoder_var_names) / self.var_per_encoder)
            # Unfreeze the top `n_tune_layers` encoder blocks.
            for i in range(self.n_tune_layers):
                trainable_layers.append(f"encoder/layer_{str(n_encoder_layers - 1 - i)}/")

        # Add module variables to the layer's trainable weights.
        for var in self.bert.variables:
            if any([l in var.name for l in trainable_layers]):
                self._trainable_weights.append(var)
            else:
                self._non_trainable_weights.append(var)

        if self.verbose:
            print("*** TRAINABLE VARS *** ")
            for var in self._trainable_weights:
                print(var)

        self.build_preprocessor()
        self.initialize_module()

        super(BertLayer, self).build(input_shape)

    def build_preprocessor(self):
        sess = tf.keras.backend.get_session()
        tokenization_info = self.bert(signature="tokenization_info", as_dict=True)
        vocab_file, do_lower_case = sess.run([tokenization_info["vocab_file"],
                                              tokenization_info["do_lower_case"]])
        # `build_preprocessor` is the tokenizer factory defined elsewhere in this gist.
        self.preprocessor = build_preprocessor(vocab_file, self.seq_len, do_lower_case)

    def initialize_module(self):
        sess = tf.keras.backend.get_session()

        # Initialize only the module variables the session hasn't seen yet.
        uninitialized = []
        for var in self.bert.variables:
            if not sess.run(tf.is_variable_initialized(var)):
                uninitialized.append(var)

        if len(uninitialized):
            sess.run(tf.variables_initializer(uninitialized))

    def call(self, input):
        # Tokenize raw string tensors into (input_ids, input_mask, segment_ids).
        features = tf.numpy_function(self.preprocessor, [input],
                                     [tf.int32, tf.int32, tf.int32])
        for feature in features:
            feature.set_shape((None, self.seq_len))

        input_ids, input_mask, segment_ids = features

        bert_inputs = dict(
            input_ids=input_ids, input_mask=input_mask, segment_ids=segment_ids
        )
        output = self.bert(inputs=bert_inputs, signature="tokens", as_dict=True)

        if self.pooling == "cls":
            pooled = output["pooled_output"]
        else:
            result = output["sequence_output"]

            input_mask = tf.cast(input_mask, tf.float32)
            mul_mask = lambda x, m: x * tf.expand_dims(m, axis=-1)
            masked_reduce_mean = lambda x, m: tf.reduce_sum(mul_mask(x, m), axis=1) / (
                tf.reduce_sum(m, axis=1, keepdims=True) + 1e-10)

            if self.pooling == "mean":
                # Mean over non-padding token embeddings only.
                pooled = masked_reduce_mean(result, input_mask)
            else:
                # pooling=None: return the full masked sequence output.
                pooled = mul_mask(result, input_mask)

        return pooled

    def get_config(self):
        config = {
            "bert_path": self.bert_path,
            "seq_len": self.seq_len,
            "pooling": self.pooling,
            "n_tune_layers": self.n_tune_layers,
            "tune_embeddings": self.tune_embeddings,
            "verbose": self.verbose,
        }
        # Merge with the base layer config so the layer can be re-instantiated.
        base_config = super(BertLayer, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))