Skip to content

Instantly share code, notes, and snippets.

@galtay
Created May 3, 2020 23:28
Show Gist options
  • Save galtay/10852bb03b354b2562997973bc29c679 to your computer and use it in GitHub Desktop.
Save galtay/10852bb03b354b2562997973bc29c679 to your computer and use it in GitHub Desktop.
simpletransformers multilabel metrics example
import logging
import pandas as pd
from simpletransformers.classification import MultiLabelClassificationModel
import sklearn.metrics
# Toy multi-label corpus. Every record is [sentence, indicator vector]; the
# vector has one 0/1 slot per class, in the order (zero, one, two).
texts = [
    ["this is class zero", [1, 0, 0]],
    ["this is class one", [0, 1, 0]],
    ["this is class two", [0, 0, 1]],
    ["this is class zero and one", [1, 1, 0]],
    ["this is class one and two", [0, 1, 1]],
    ["this is class two and zero", [1, 0, 1]],
    ["this is class all 3 classes", [1, 1, 1]],
]

# One row per record: the sentence goes to "text", the vector to "labels".
train_df = pd.DataFrame.from_records(texts, columns=["text", "labels"])

# The evaluation frame is a separate copy of the very same rows, so any score
# computed on it describes fit to the training data, not generalisation.
eval_df = train_df.copy(deep=True)
# Options handed to the model constructor; only "overwrite_output_dir" is set.
model_args = dict(overwrite_output_dir=True)

# Multi-label classifier created from the "distilroberta-base" checkpoint
# (passed together with the "roberta" type string) with three output labels.
model = MultiLabelClassificationModel(
    "roberta",
    "distilroberta-base",
    args=model_args,
    num_labels=3,
)
# Options applied for this training run (passed as `args` to train_model).
train_args = {
    # Master switch for in-training evaluation. It is off by default in
    # simpletransformers; without it the two "evaluate_during_training_*"
    # settings below, the eval_df and the extra metrics are never used.
    "evaluate_during_training": True,
    # Evaluate after every optimisation step and report each result.
    "evaluate_during_training_steps": 1,
    "evaluate_during_training_verbose": True,
    "logging_steps": 1,
    # One example per batch, so the 7-row toy set gives 7 steps per epoch.
    "train_batch_size": 1,
    "eval_batch_size": 1,
    "num_train_epochs": 4,
}

# Extra evaluation metrics, forwarded to train_model as keyword arguments in
# the form name=callable; each callable is invoked as metric(labels, preds).
eval_metrics = {
    # NOTE(review): f1 is left disabled as in the original; presumably because
    # f1_score needs hard 0/1 predictions rather than scores - confirm.
    # "f1": sklearn.metrics.f1_score,
    "roc_auc": sklearn.metrics.roc_auc_score,
    "avg_prc": sklearn.metrics.average_precision_score,
}

# NOTE(review): `logging` is imported at the top of the file but never
# configured; the verbose evaluation output may only appear once
# logging.basicConfig(level=logging.INFO) is called - confirm.
model.train_model(
    train_df,
    eval_df=eval_df,
    args=train_args,
    **eval_metrics,
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment