@Akashdesarda
Last active April 29, 2020 02:21
from transformers import TFDistilBertForSequenceClassification, DistilBertConfig
import tensorflow as tf

distil_bert = 'distilbert-base-uncased'

# Configure DistilBERT for 6-class sequence classification; only the final
# logits are needed, so don't return every layer's hidden states
config = DistilBertConfig(num_labels=6)
config.output_hidden_states = False

# from_pretrained returns the model object itself, not a tuple, so no [0] indexing
transformer_model = TFDistilBertForSequenceClassification.from_pretrained(distil_bert, config=config)

# Keras Input layers for the token ids and the attention mask (sequence length 128)
input_ids = tf.keras.layers.Input(shape=(128,), name='input_token', dtype='int32')
input_masks_ids = tf.keras.layers.Input(shape=(128,), name='masked_token', dtype='int32')

# Pass the mask via the attention_mask keyword; the model returns a tuple whose
# first element holds the classification logits
X = transformer_model(input_ids, attention_mask=input_masks_ids)[0]
model = tf.keras.Model(inputs=[input_ids, input_masks_ids], outputs=X)