import logging

import pandas as pd

from simpletransformers.classification import (
    MultiLabelClassificationModel,
    MultiLabelClassificationArgs,
)

logging.basicConfig(level=logging.INFO)
transformers_logger = logging.getLogger("transformers")
transformers_logger.setLevel(logging.WARNING)
# Train and evaluation data needs to be in a Pandas DataFrame with two columns. The first column is the text (type str), and the second column is the labels (a list of ints, one 0/1 value per label).
train_data = [
    ["Example sentence 1 for multilabel", [1, 1, 1, 1, 0, 1]],
    ["This thing is entirely different from the other thing. ", [0, 1, 1, 0, 0, 0]],
]
train_df = pd.DataFrame(train_data, columns=["text", "labels"])

eval_data = [
    ["Example sentence belonging to class 1", [1, 1, 1, 1, 0, 1]],
    [
        "This thing should be entirely different from the other thing. ",
        [0, 0, 0, 0, 1, 0],
    ],
]
eval_df = pd.DataFrame(eval_data, columns=["text", "labels"])
model_args = MultiLabelClassificationArgs()
model_args.train_custom_parameters_only = True
model_args.reprocess_input_data = True
model_args.overwrite_output_dir = True
model_args.learning_rate = 3e-3
model_args.custom_parameter_groups = [
    {
        "params": ["classifier.weight"],
        "lr": 1e-3,
    },
    {
        "params": ["classifier.bias"],
        "lr": 1e-3,
        "weight_decay": 0.0,
    },
    {
        "params": ["classifier_1.weight"],
        "lr": 1e-3,
    },
    {
        "params": ["classifier_1.bias"],
        "lr": 1e-3,
        "weight_decay": 0.0,
    },
]
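# Note: with train_custom_parameters_only=True, only the parameters named in
# custom_parameter_groups should be updated during training; the rest of the
# model stays frozen. The "classifier_1.*" entries are kept from the original
# script and assume the loaded model actually has parameters with those names.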
# train_args = {
#     "reprocess_input_data": True,
#     "overwrite_output_dir": True,
#     "evaluate_during_training": True,
#     "evaluate_during_training_verbose": True,
#     "logging_steps": 50,
#     "num_train_epochs": 1,
#     # "evaluate_during_training_steps": 1,
#     "best_model_dir": "TESTING BEST",
#     "wandb_project": "test-new-project",
#     "no_cache": True,
#     "use_early_stopping": True,
#     "max_seq_length": 3,
#     # "sliding_window": True,
#     # "fp16": False,
# }
# Create a MultiLabelClassificationModel
model = MultiLabelClassificationModel(
    "roberta",
    "pdelobelle/robbert-v2-dutch-base",
    num_labels=6,
    use_cuda=False,
    args=model_args,
    cuda_device=1,
)
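# Note: cuda_device is only used when use_cuda=True; with use_cuda=False the
# model should run on CPU and the cuda_device argument has no effect.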
# Train the model
model.train_model(train_df)
# Evaluate the model
result, model_outputs, wrong_predictions = model.eval_model(eval_df)
print(result)
# print("\nmodel_outputs from model.eval_model()")
# print(model_outputs)
predictions, raw_outputs = model.predict(
    [
        "Example sentence belonging to class 1",
        "This thing should be entirely different from the other thing.",
        "This thing should be entirely different from the other thing.",
    ]
)
print(predictions)
print(raw_outputs)
predictions, raw_outputs = model.predict(
    [
        "Example sentence belonging to class 1",
        "This thing should be entirely different from the other thing.",
    ]
)
print("\npredictions from model.predict()")
print(predictions)
print("\nraw_outputs from model.predict()")
print(raw_outputs)
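# predictions should be a list of 0/1 label vectors (one per input text, each of
# length num_labels), while raw_outputs should hold the per-label sigmoid scores
# before thresholding.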
# model_args is a dataclass, so set the threshold as an attribute rather than by item assignment.
model.args.threshold = [0.9, 0.6, 0.5, 0.3, 0.2, 0.5]
print("\nAfter variable thresholds.")

predictions, raw_outputs = model.predict(
    [
        "Example sentence belonging to class 1",
        "This thing should be entirely different from the other thing.",
    ]
)
print("\npredictions from model.predict()")
print(predictions)
print("\nraw_outputs from model.predict()")
print(raw_outputs)
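# For reference, the per-label thresholds above should be roughly equivalent to
# post-processing raw_outputs by hand (a sketch, assuming raw_outputs holds the
# per-label sigmoid scores for each input text):
manual_preds = [
    [1 if score > t else 0 for score, t in zip(row, model.args.threshold)]
    for row in raw_outputs
]
print("\nmanual thresholding of raw_outputs")
print(manual_preds)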
# ### SAVE AND LOAD TEST ###
# model = MultiLabelClassificationModel('albert', 'outputs/checkpoint-5', use_cuda=False, pos_weight=[1, 1, 1, 1, 1, 1], args=train_args)
# # Evaluate the model
# result, model_outputs, wrong_predictions = model.eval_model(eval_df)
# print(result)
# print('\nmodel_outputs from model.eval_model()')
# print(model_outputs)
# predictions, raw_outputs = model.predict(['Example sentence belonging to class 1', 'This thing should be entirely different from the other thing.'])
# print('\npredictions from model.predict()')
# print(predictions)
# print('\nraw_outputs from model.predict()')
# print(raw_outputs)
# model.args['threshold'] = [0.9, 0.6, 0.5, 0.3, 0.2, 0.5]
# print('\nAfter variable thresholds.')
# predictions, raw_outputs = model.predict(['Example sentence belonging to class 1', 'This thing should be entirely different from the other thing.'])
# print('\npredictions from model.predict()')
# print(predictions)
# print('\nraw_outputs from model.predict()')
# print(raw_outputs)
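# To reload a trained model later, point the second argument at the saved output
# directory (simpletransformers writes checkpoints to "outputs/" by default), e.g.:
# loaded_model = MultiLabelClassificationModel("roberta", "outputs/", use_cuda=False)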