Skip to content

Instantly share code, notes, and snippets.

@woshiyyya
Created August 8, 2023 04:42
Show Gist options
  • Select an option

  • Save woshiyyya/f9cfa1a50598be2f2d0a2c40a1cbe21a to your computer and use it in GitHub Desktop.

Select an option

Save woshiyyya/f9cfa1a50598be2f2d0a2c40a1cbe21a to your computer and use it in GitHub Desktop.
Run Transformers Trainer with Ray TorchTrainer
import os
import evaluate
import numpy as np
from datasets import load_dataset
from ray.train import RunConfig, ScalingConfig, CheckpointConfig, Checkpoint
from ray.train.torch import TorchTrainer
from transformers import AutoTokenizer
from transformers import (
AutoModelForSequenceClassification,
DataCollatorWithPadding,
TrainingArguments,
Trainer,
)
import ray
from transformers.trainer_callback import TrainerCallback
from transformers.trainer_utils import get_last_checkpoint
class RayTrainReportCallback(TrainerCallback):
    """Transformers callback that forwards each saved HF checkpoint,
    together with the most recent logged metrics, to Ray Train."""

    def on_save(self, args, state, control, **kwargs):
        """Event called after a checkpoint save."""
        # The last entry of log_history holds the newest metrics; it can be
        # empty if nothing has been logged yet, so fall back to {}.
        latest_metrics = state.log_history[-1] if state.log_history else {}
        ckpt_dir = get_last_checkpoint(args.output_dir)
        ray.train.report(
            metrics=latest_metrics,
            checkpoint=Checkpoint.from_directory(ckpt_dir),
        )
def train_loop_per_worker():
    """Per-worker training loop: fine-tune bert-base-cased on the
    tweet_eval/irony dataset with a HuggingFace ``Trainer``, reporting
    checkpoints back to Ray Train via ``RayTrainReportCallback``."""
    # NOTE(review): placeholder credential — must be replaced with a real
    # key (or injected via the environment) before W&B logging works.
    os.environ["WANDB_API_KEY"] = "YOUR_WANDB_API_KEY"

    dataset = load_dataset("tweet_eval", "irony")
    tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")

    def tokenize(examples):
        # Fixed-length encoding: pad/truncate every example to 256 tokens.
        return tokenizer(
            examples["text"],
            max_length=256,
            truncation=True,
            padding="max_length",
        )

    train_split = dataset["train"].map(tokenize, batched=True)
    eval_split = dataset["test"].map(tokenize, batched=True)

    collator = DataCollatorWithPadding(tokenizer=tokenizer)
    accuracy_metric = evaluate.load("accuracy")

    def compute_metrics(eval_pred):
        # eval_pred is a (logits, labels) pair; argmax over the class axis.
        logits, labels = eval_pred
        preds = np.argmax(logits, axis=1)
        return accuracy_metric.compute(predictions=preds, references=labels)

    # Load pretrained model (binary classification head).
    model = AutoModelForSequenceClassification.from_pretrained(
        "bert-base-cased", num_labels=2
    )

    # Define Transformers Trainer
    training_args = TrainingArguments(
        output_dir="/mnt/cluster_storage/hf_results",
        learning_rate=1e-4,
        per_device_train_batch_size=32,
        per_device_eval_batch_size=32,
        num_train_epochs=4,
        weight_decay=0.01,
        evaluation_strategy="epoch",
        save_strategy="epoch",
        logging_strategy="epoch",
        push_to_hub=False,
        report_to="wandb",
        metric_for_best_model="eval_accuracy",
        save_total_limit=2,
    )
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_split,
        eval_dataset=eval_split,
        tokenizer=tokenizer,
        data_collator=collator,
        compute_metrics=compute_metrics,
    )

    # This will be available in Ray 2.7
    # trainer = ray.train.huggingface.transformers.prepare_trainer(trainer)
    trainer.add_callback(RayTrainReportCallback())

    # Train your model
    trainer.train()
if __name__ == "__main__":
    # Artifact-storage bucket provided by the Anyscale environment;
    # raises KeyError if the variable is not set (fail fast).
    s3_bucket = os.environ["ANYSCALE_ARTIFACT_STORAGE"]

    # Wrap the HF training loop in a Ray TorchTrainer: 4 GPU workers,
    # keeping the 2 best checkpoints ranked by eval accuracy.
    ray_trainer = TorchTrainer(
        train_loop_per_worker,
        scaling_config=ScalingConfig(num_workers=4, use_gpu=True),
        run_config=RunConfig(
            name="exp",
            storage_path=f"{s3_bucket}/ray_results",
            checkpoint_config=CheckpointConfig(
                num_to_keep=2,
                checkpoint_score_attribute="eval_accuracy",
                checkpoint_score_order="max",
            ),
        ),
    )
    ray_trainer.fit()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment