Created May 3, 2020 23:28
-
-
Save galtay/10852bb03b354b2562997973bc29c679 to your computer and use it in GitHub Desktop.
simpletransformers multi-label metrics example
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Example: custom evaluation metrics with simpletransformers' multi-label model.

Trains a tiny RoBERTa multi-label classifier on a toy 3-class dataset.
Extra scikit-learn metrics are passed to ``train_model`` as keyword
arguments, so they are reported each time the model is evaluated during
training.
"""
import logging

import pandas as pd
import sklearn.metrics
import torch
from simpletransformers.classification import MultiLabelClassificationModel

# Route the library's log records (training loss, evaluation results) to stderr.
logging.basicConfig(level=logging.INFO)

# Toy dataset: each row is [text, multi-hot label vector over 3 classes].
texts = [
    ["this is class zero", [1, 0, 0]],
    ["this is class one", [0, 1, 0]],
    ["this is class two", [0, 0, 1]],
    ["this is class zero and one", [1, 1, 0]],
    ["this is class one and two", [0, 1, 1]],
    ["this is class two and zero", [1, 0, 1]],
    ["this is class all 3 classes", [1, 1, 1]],
]
train_df = pd.DataFrame(texts, columns=["text", "labels"])
# Evaluate on the training data itself: this example only demonstrates how
# custom metrics are wired in, not generalization.
eval_df = train_df.copy()

model_args = {
    "overwrite_output_dir": True,
}
model = MultiLabelClassificationModel(
    "roberta",
    "distilroberta-base",
    num_labels=3,
    args=model_args,
    # The library default (use_cuda=True) fails on machines without a GPU.
    use_cuda=torch.cuda.is_available(),
)

train_args = {
    # Required: without this flag the *_during_training_* options below are
    # ignored and the custom metrics are never computed during training.
    "evaluate_during_training": True,
    "evaluate_during_training_steps": 1,
    "evaluate_during_training_verbose": True,
    "logging_steps": 1,
    "train_batch_size": 1,
    "eval_batch_size": 1,
    "num_train_epochs": 4,
}

# Extra metrics, keyed by the name they are reported under. Each is called as
# metric(true_labels, predictions).
# NOTE(review): f1_score is deliberately omitted. Presumably the metrics
# receive continuous model outputs rather than thresholded 0/1 predictions,
# which f1_score rejects. Confirm against the installed simpletransformers
# version before adding it.
eval_metrics = {
    "roc_auc": sklearn.metrics.roc_auc_score,
    "avg_prc": sklearn.metrics.average_precision_score,
}

model.train_model(
    train_df,
    eval_df=eval_df,
    args=train_args,
    **eval_metrics,
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.