Keras BERT layer
import tensorflow as tf
import tensorflow_hub as hub

# NOTE: requires TensorFlow 1.x (hub.Module, graph sessions).
# The build_preprocessor helper is defined elsewhere in the accompanying code;
# given a vocab file, a sequence length, and a casing flag, it returns a callable
# that maps raw strings to (input_ids, input_mask, segment_ids) arrays.


class BertLayer(tf.keras.layers.Layer):
    def __init__(self, bert_path, seq_len=64, n_tune_layers=3,
                 pooling="cls", verbose=False,
                 tune_embeddings=False, **kwargs):
        self.n_tune_layers = n_tune_layers
        self.tune_embeddings = tune_embeddings
        self.seq_len = seq_len
        self.trainable = True
        self.verbose = verbose
        self.pooling = pooling
        self.bert_path = bert_path
        # Each BERT encoder block contributes 16 variables to the module.
        self.var_per_encoder = 16

        if self.pooling not in ["cls", "mean", None]:
            raise ValueError(
                f"Undefined pooling type (must be 'cls', 'mean', or None, but is {self.pooling})"
            )

        super(BertLayer, self).__init__(**kwargs)

    def build(self, input_shape):
        self.bert = hub.Module(self.bert_path, trainable=self.trainable,
                               name=f"{self.name}_module")

        # Collect name prefixes of the variable groups we want to fine-tune.
        trainable_layers = []
        if self.tune_embeddings:
            trainable_layers.append("embeddings")

        if self.pooling == "cls":
            trainable_layers.append("pooler")

        if self.n_tune_layers > 0:
            encoder_var_names = [var.name for var in self.bert.variables
                                 if "encoder" in var.name]
            n_encoder_layers = int(len(encoder_var_names) / self.var_per_encoder)
            # Unfreeze the top n_tune_layers encoder blocks.
            for i in range(self.n_tune_layers):
                trainable_layers.append(f"encoder/layer_{n_encoder_layers - 1 - i}/")

        # Add module variables to the layer's trainable weights.
        for var in self.bert.variables:
            if any(l in var.name for l in trainable_layers):
                self._trainable_weights.append(var)
            else:
                self._non_trainable_weights.append(var)

        if self.verbose:
            print("*** TRAINABLE VARS ***")
            for var in self._trainable_weights:
                print(var)

        self.build_preprocessor()
        self.initialize_module()

        super(BertLayer, self).build(input_shape)

    def build_preprocessor(self):
        # Pull the vocab file and casing flag out of the module, then use them
        # to build the tokenizer.
        sess = tf.keras.backend.get_session()
        tokenization_info = self.bert(signature="tokenization_info", as_dict=True)
        vocab_file, do_lower_case = sess.run([tokenization_info["vocab_file"],
                                              tokenization_info["do_lower_case"]])
        self.preprocessor = build_preprocessor(vocab_file, self.seq_len, do_lower_case)

    def initialize_module(self):
        # Initialize only those module variables the session has not seen yet.
        sess = tf.keras.backend.get_session()

        uninitialized = []
        for var in self.bert.variables:
            if not sess.run(tf.is_variable_initialized(var)):
                uninitialized.append(var)

        if uninitialized:
            sess.run(tf.variables_initializer(uninitialized))

    def call(self, inputs):
        # Tokenize raw strings on the fly inside the graph.
        features = tf.numpy_function(self.preprocessor, [inputs],
                                     [tf.int32, tf.int32, tf.int32])
        for feature in features:
            feature.set_shape((None, self.seq_len))

        input_ids, input_mask, segment_ids = features
        bert_inputs = dict(
            input_ids=input_ids, input_mask=input_mask, segment_ids=segment_ids
        )

        output = self.bert(inputs=bert_inputs, signature="tokens", as_dict=True)

        if self.pooling == "cls":
            pooled = output["pooled_output"]
        else:
            result = output["sequence_output"]

            # Zero out padding positions before averaging.
            input_mask = tf.cast(input_mask, tf.float32)
            mul_mask = lambda x, m: x * tf.expand_dims(m, axis=-1)
            masked_reduce_mean = lambda x, m: tf.reduce_sum(mul_mask(x, m), axis=1) / (
                tf.reduce_sum(m, axis=1, keepdims=True) + 1e-10)

            if self.pooling == "mean":
                pooled = masked_reduce_mean(result, input_mask)
            else:
                # pooling=None: return masked token-level embeddings.
                pooled = mul_mask(result, input_mask)

        return pooled

    def get_config(self):
        config = super(BertLayer, self).get_config()
        config.update({
            "bert_path": self.bert_path,
            "seq_len": self.seq_len,
            "pooling": self.pooling,
            "n_tune_layers": self.n_tune_layers,
            "tune_embeddings": self.tune_embeddings,
            "verbose": self.verbose,
        })
        return config
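
A minimal usage sketch (not part of the original gist): the layer consumes raw strings, so the model input is a string tensor and tokenization happens inside the layer. The TF Hub URL is the public uncased BERT-Base module; the classification head, hyperparameters, and dummy data below are illustrative assumptions.

import numpy as np
import tensorflow as tf

BERT_PATH = "https://tfhub.dev/google/bert_uncased_L-12_H-768_A-12/1"

# One raw string per example; BertLayer tokenizes internally via tf.numpy_function.
inp = tf.keras.layers.Input(shape=(1,), dtype=tf.string)
encoded = BertLayer(BERT_PATH, seq_len=48, n_tune_layers=3, pooling="cls")(inp)
pred = tf.keras.layers.Dense(1, activation="sigmoid")(encoded)

model = tf.keras.models.Model(inputs=[inp], outputs=[pred])
model.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"])

# Illustrative dummy data: (n_examples, 1) arrays of strings and binary labels.
texts = np.array([["this movie was great"], ["this movie was terrible"]])
labels = np.array([1, 0])
model.fit(texts, labels, epochs=1, batch_size=2)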