import tensorflow as tf
from transformers import DistilBertConfig, TFDistilBertModel

# Load a pre-trained DistilBERT encoder with custom dropout settings
distil_bert = 'distilbert-base-uncased'
config = DistilBertConfig(dropout=0.2, attention_dropout=0.2)
config.output_hidden_states = False
transformer_model = TFDistilBertModel.from_pretrained(distil_bert, config=config)

# Two inputs: token ids and attention mask, both padded/truncated to 128 tokens
input_ids_in = tf.keras.layers.Input(shape=(128,), name='input_token', dtype='int32')
input_masks_in = tf.keras.layers.Input(shape=(128,), name='masked_token', dtype='int32')

# Last hidden state of the transformer; keep only the [CLS] token embedding
embedding_layer = transformer_model(input_ids_in, attention_mask=input_masks_in)[0]
cls_token = embedding_layer[:, 0, :]

# Classification head on top of the [CLS] representation
X = tf.keras.layers.BatchNormalization()(cls_token)
X = tf.keras.layers.Dense(192, activation='relu')(X)
X = tf.keras.layers.Dropout(0.2)(X)
X = tf.keras.layers.Dense(6, activation='softmax')(X)

model = tf.keras.Model(inputs=[input_ids_in, input_masks_in], outputs=X)

# Freeze the pre-trained part: the first three layers are the two Input layers
# and the DistilBERT model, so only the new classification head is trained
for layer in model.layers[:3]:
    layer.trainable = False
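
# --- Usage sketch (not part of the original gist) ---
# A minimal example of feeding text through the model above, assuming the
# matching DistilBERT tokenizer and 6 illustrative target classes; the texts,
# labels, and training settings are placeholders to adapt to your own data.
from transformers import DistilBertTokenizerFast

tokenizer = DistilBertTokenizerFast.from_pretrained(distil_bert)
texts = ["an example sentence", "another example sentence"]
encoded = tokenizer(texts, padding='max_length', truncation=True,
                    max_length=128, return_tensors='np')

model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
labels = tf.keras.utils.to_categorical([0, 1], num_classes=6)
model.fit([encoded['input_ids'], encoded['attention_mask']], labels, epochs=1)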