import numpy as np
import transformers
import evaluate

# Load the evaluation metrics from the Hugging Face `evaluate` library
# (each wraps the corresponding scikit-learn implementation).
f1_metric = evaluate.load("f1")
recall_metric = evaluate.load("recall")
accuracy_metric = evaluate.load("accuracy")
precision_metric = evaluate.load("precision")
def compute_metrics(eval_pred):
    # Report macro F1, recall, precision, and accuracy during the Trainer's evaluation loop.
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    results = {}
    results.update(f1_metric.compute(predictions=predictions, references=labels, average="macro"))
    results.update(recall_metric.compute(predictions=predictions, references=labels, average="macro"))
    results.update(accuracy_metric.compute(predictions=predictions, references=labels))
    results.update(precision_metric.compute(predictions=predictions, references=labels, average="macro"))
    return results
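
# NOTE: `lora_model`, `train_dataset`, and `val_dataset` are referenced below but not
# defined in this gist. The following is a minimal sketch of how they might be prepared;
# the base checkpoint, label count, dataset, and LoRA hyperparameters are illustrative
# assumptions, not part of the original gist.
from datasets import load_dataset
from peft import LoraConfig, TaskType, get_peft_model
from transformers import AutoModelForSequenceClassification, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("roberta-base")  # assumed base checkpoint
base_model = AutoModelForSequenceClassification.from_pretrained(
    "roberta-base", num_labels=2  # assumed binary classification task
)

# Wrap the base model with LoRA adapters so only the low-rank matrices are trained.
lora_config = LoraConfig(task_type=TaskType.SEQ_CLS, r=8, lora_alpha=16, lora_dropout=0.1)
lora_model = get_peft_model(base_model, lora_config)

# Tokenise an assumed text-classification dataset; dropping the raw "text" column keeps
# batches tensor-friendly since remove_unused_columns=False is set below.
raw = load_dataset("imdb")  # assumed dataset
def tokenize(batch):
    return tokenizer(batch["text"], truncation=True, padding="max_length")

train_dataset = raw["train"].map(tokenize, batched=True, remove_columns=["text"])
val_dataset = raw["test"].map(tokenize, batched=True, remove_columns=["text"])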

# Configure the Hugging Face Trainer to fine-tune the LoRA-wrapped model.
trainer = transformers.Trainer(
    model=lora_model,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    compute_metrics=compute_metrics,
    args=transformers.TrainingArguments(
        per_device_train_batch_size=8,
        per_device_eval_batch_size=32,
        gradient_accumulation_steps=4,  # effective train batch size of 32
        warmup_steps=100,
        max_steps=12276,
        learning_rate=2e-4,
        fp16=True,
        eval_steps=1000,
        logging_steps=1000,
        save_steps=1000,
        evaluation_strategy="steps",
        do_eval=True,
        load_best_model_at_end=True,
        metric_for_best_model="f1",  # select the checkpoint with the best macro F1
        output_dir='model_outputs',
        logging_dir='model_outputs',
        remove_unused_columns=False,
        report_to='wandb'  # enable logging to W&B
    ),
)

trainer.train()
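
# After training, the best checkpoint (by macro F1) is loaded back automatically because
# load_best_model_at_end=True; a final evaluation pass reports its metrics on val_dataset.
print(trainer.evaluate())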