Skip to content

Instantly share code, notes, and snippets.

@gaphex
Created November 17, 2019 17:46
Show Gist options
  • Save gaphex/07e2f9f915bbd9dec5caea6550d9c6e7 to your computer and use it in GitHub Desktop.
Save gaphex/07e2f9f915bbd9dec5caea6550d9c6e7 to your computer and use it in GitHub Desktop.
layer build method
def build(self, input_shape):
    """Instantiate the hub module and register its variables with the layer.

    Each module variable is registered as trainable or non-trainable based
    on a substring match against its name:

    * ``"embeddings"`` when ``self.tune_embeddings`` is truthy,
    * ``"pooler"`` when ``self.pooling == "cls"``,
    * the last ``self.n_tune_layers`` encoder layers.

    Variables matching none of the selected fragments are registered as
    non-trainable.

    Args:
        input_shape: Forwarded unchanged to the parent class ``build``.
    """
    self.bert = hub.Module(
        self.bert_path,
        trainable=self.trainable,
        name=f"{self.name}_module",
    )

    # Name fragments identifying the variables that should be fine-tuned.
    tunable_patterns = []
    if self.tune_embeddings:
        tunable_patterns.append("embeddings")
    if self.pooling == "cls":
        tunable_patterns.append("pooler")

    if self.n_tune_layers > 0:
        # The layer count is derived from the number of encoder variables.
        # NOTE(review): assumes every encoder layer owns exactly
        # ``self.var_per_encoder`` variables -- confirm against the module.
        encoder_names = [v.name for v in self.bert.variables if 'encoder' in v.name]
        layer_count = int(len(encoder_names) / self.var_per_encoder)
        # Count down from the highest-numbered layer. The trailing "/" keeps
        # e.g. "layer_1/" from also matching "layer_11/".
        tunable_patterns.extend(
            f"encoder/layer_{layer_count - 1 - offset}/"
            for offset in range(self.n_tune_layers)
        )

    # Add module variables to the layer's trainable / non-trainable weights.
    for variable in self.bert.variables:
        is_tunable = any(pattern in variable.name for pattern in tunable_patterns)
        target = self._trainable_weights if is_tunable else self._non_trainable_weights
        target.append(variable)

    if self.verbose:
        print("*** TRAINABLE VARS *** ")
        for variable in self._trainable_weights:
            print(variable)

    self.build_preprocessor()
    self.initialize_module()

    super(BertLayer, self).build(input_shape)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment