@kusal1990
Created June 25, 2022 06:51
import tensorflow as tf
from tensorflow.keras.layers import Input, Dense
from tensorflow.keras.models import Model
from transformers import TFRobertaForMultipleChoice

# Load the pre-trained RoBERTa multiple-choice head.
pre_trained_model = TFRobertaForMultipleChoice.from_pretrained('roberta-base')

# Each example has 5 answer choices, each encoded to 128 tokens.
model_input_ids = Input(shape=(5, 128), name='input_tokens', dtype='int32')
masks_input = Input(shape=(5, 128), name='attention_mask', dtype='int32')

x = {'input_ids': model_input_ids,
     'attention_mask': masks_input}
# The multiple-choice head returns one logit per choice: shape (batch_size, 5).
x = pre_trained_model(x)['logits']
# Extra dense layer maps the 5 choice logits to a softmax over the 5 choices.
outputs = Dense(5, activation='softmax')(x)

model = Model(inputs=[model_input_ids, masks_input], outputs=outputs)
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-5),
              loss=tf.keras.losses.CategoricalCrossentropy(),
              metrics=['accuracy'])

# One-hot encode the integer labels for categorical cross-entropy.
valid_y = tf.keras.utils.to_categorical(easy_dev_labels, num_classes=5)
model.fit(x=[easy_train_input_ids, easy_train_attention_mask],
          y=tf.keras.utils.to_categorical(easy_train_labels, num_classes=5),
          batch_size=8, epochs=5,
          validation_data=([easy_dev_input_ids, easy_dev_attention_mask], valid_y),
          validation_batch_size=8)
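The arrays `easy_train_input_ids` and `easy_train_attention_mask` are assumed to have shape `(N, 5, 128)`: five answer choices per question, each padded to 128 tokens. A minimal sketch of how one question could be packed into that shape, using a hypothetical whitespace tokenizer in place of the real `RobertaTokenizer` (the names `toy_encode` and `encode_example` are illustrative, not from the gist):

```python
MAX_LEN = 128
PAD_ID = 1  # RoBERTa's pad token id

def toy_encode(text, vocab):
    # Map each whitespace token to an integer id (toy stand-in tokenizer).
    return [vocab.setdefault(tok, len(vocab) + 2) for tok in text.split()]

def encode_example(question, options, vocab):
    input_ids, attention_mask = [], []
    for option in options:
        # Pair the question with each candidate answer.
        ids = toy_encode(question + " " + option, vocab)[:MAX_LEN]
        mask = [1] * len(ids)
        # Pad every choice out to MAX_LEN tokens; mask marks real tokens.
        ids += [PAD_ID] * (MAX_LEN - len(ids))
        mask += [0] * (MAX_LEN - len(mask))
        input_ids.append(ids)
        attention_mask.append(mask)
    return input_ids, attention_mask  # each: 5 rows of 128 entries

vocab = {}
ids, mask = encode_example(
    "Which gas do plants absorb?",
    ["oxygen", "carbon dioxide", "nitrogen", "helium", "argon"],
    vocab)
print(len(ids), len(ids[0]))  # 5 128
```

Stacking `N` such examples yields the `(N, 5, 128)` tensors the model's two `Input` layers expect; in practice `RobertaTokenizer` (with `padding='max_length'`, `truncation=True`, `max_length=128`) would replace the toy encoder.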