You should replace a few arguments in pl.Trainer with Ray Train's implementations, as shown in the diff below.
  import pytorch_lightning as pl
+ from ray.train.lightning import (
+     get_devices,
+     prepare_trainer,
+     RayDDPStrategy,
+     RayLightningEnvironment,
+ )

  def train_func():
      model = MyLightningModule()
      datamodule = MyLightningDataModule()
      trainer = pl.Trainer(
-         devices=[0, 1, 2, 3],
-         strategy=DDPStrategy(),
+         devices=get_devices(),
+         strategy=RayDDPStrategy(),
+         plugins=[RayLightningEnvironment()],
      )
+     trainer = prepare_trainer(trainer)
      trainer.fit(model, datamodule=datamodule)

How to differentiate Ray Data and Torch DataLoader?
PyTorch Lightning provides a callback interface for injecting custom logic during training. You can report model checkpoints to Ray Train by implementing a custom callback.
For example, here we implemented a RayTrainReportCallback, which reports metrics and checkpoints at the end of every epoch. You can also report at different stages (e.g. on_train_batch_end, on_validation_epoch_end) or at different frequencies (every 2 epochs, every 100 batches).
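Reporting every 100 batches instead of every epoch, for instance, only requires gating the report call inside on_train_batch_end. The gating itself can be sketched as a small helper (should_report is our illustrative name, not a Ray Train API):

```python
def should_report(batch_idx: int, every_n: int) -> bool:
    """Return True on every N-th batch, given Lightning's 0-indexed batch_idx."""
    return (batch_idx + 1) % every_n == 0

# Inside a Lightning callback, this would gate the report call:
#
# def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
#     if should_report(batch_idx, every_n=100):
#         ray.train.report(metrics=..., checkpoint=...)
```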
import os
from tempfile import TemporaryDirectory

import ray.train
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import Callback
from ray.train import Checkpoint
class RayTrainReportCallback(Callback):
    def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
        with TemporaryDirectory() as tmpdir:
            # Fetch metrics
            metrics = trainer.callback_metrics
            metrics = {k: v.item() for k, v in metrics.items()}

            # Save checkpoint to local
            ckpt_path = os.path.join(tmpdir, f"ckpt_epoch_{trainer.current_epoch}.pth")
            trainer.save_checkpoint(ckpt_path, weights_only=False)

            # Report to train session
            checkpoint = Checkpoint.from_directory(tmpdir)
            ray.train.report(metrics=metrics, checkpoint=checkpoint)
def train_func():
    ...
    trainer = Trainer(
        ...,
        callbacks=[RayTrainReportCallback()],
    )

By default, Ray Train saves all checkpoints that you report. If you want to keep only the top-k checkpoints based on a certain metric, you can specify a CheckpointConfig in the ray.train.torch.TorchTrainer. Here, we select the top-2 checkpoints with the highest validation accuracy.
from ray.train import CheckpointConfig, RunConfig
from ray.train.torch import TorchTrainer

trainer = TorchTrainer(
    train_loop_per_worker=train_func,
    run_config=RunConfig(
        checkpoint_config=CheckpointConfig(
            num_to_keep=2,
            checkpoint_score_attribute="val_accuracy",
            checkpoint_score_order="max",
        )
    ),
)

Reporting checkpoints to Ray Train enables your training to benefit from various advanced Ray Train features, including fault tolerance and distributed checkpointing. Please refer to [] and [] to learn more about these features.
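The effect of this retention policy can be illustrated in plain Python: rank the reported metrics by val_accuracy and keep the best two. This is only a sketch of the policy's behavior, not Ray Train's internal implementation (keep_top_k and the sample metrics are ours):

```python
def keep_top_k(reports, k=2, attribute="val_accuracy", order="max"):
    """Given (metrics, checkpoint_name) pairs in report order, return the
    checkpoint names a top-k retention policy would keep."""
    ranked = sorted(
        reports,
        key=lambda r: r[0][attribute],
        reverse=(order == "max"),
    )
    return [name for _, name in ranked[:k]]

reports = [
    ({"val_accuracy": 0.81}, "ckpt_epoch_0"),
    ({"val_accuracy": 0.90}, "ckpt_epoch_1"),
    ({"val_accuracy": 0.86}, "ckpt_epoch_2"),
]
# With num_to_keep=2 and order "max", only the epoch-1 and epoch-2
# checkpoints survive; the epoch-0 checkpoint is deleted.
```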
Issues:
- The saving part seems somewhat duplicated from the previous section.
- The title looks ambiguous; it should be "Reporting and retrieving checkpoints from training loops".
Loading Checkpoints
def train_func():
    checkpoint = ray.train.get_checkpoint()
    # Resume training from the latest reported checkpoint
    if checkpoint:
        with checkpoint.as_directory() as checkpoint_dir:
            ckpt_path = os.path.join(checkpoint_dir, "path/of/ckpt/file")
            trainer.fit(model, datamodule=datamodule, ckpt_path=ckpt_path)
    # Start training from scratch
    else:
        trainer.fit(model, datamodule=datamodule)

W&B, CometML, MLflow, and TensorBoard are popular tools for managing, visualizing, and tracking ML experiments. If you're using PyTorch Lightning, you can continue to use its built-in logger integrations with Ray Train.
from pytorch_lightning.loggers.wandb import WandbLogger
from pytorch_lightning.loggers.comet import CometLogger
from pytorch_lightning.loggers.mlflow import MLFlowLogger

def train_func():
    wandb_logger = WandbLogger(
        name="demo-run",
        project="demo-project",
        id="unique_id",  # Specify a unique id to avoid creating a new run after restoration
        offline=offline,
    )
    comet_logger = CometLogger(
        api_key=comet_api_key,
        experiment_name="demo-experiment",
        project_name="demo-project",
        offline=offline,
    )
    mlflow_logger = MLFlowLogger(
        run_name=name,
        experiment_name=project_name,
        tracking_uri=f"file:{save_dir}/mlflow",
    )
    trainer = pl.Trainer(
        ...,
        logger=[wandb_logger, comet_logger, mlflow_logger],
    )
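The id passed to WandbLogger must stay stable across restarts so that a restored trial resumes the same W&B run instead of creating a new one. One way to obtain such an id is to derive it deterministically from a fixed experiment name; the stable_run_id helper below is an illustration of that idea, not part of Ray Train or the W&B SDK:

```python
import hashlib

def stable_run_id(experiment_name: str) -> str:
    """Derive a deterministic short run id from a fixed experiment name,
    so restarted trials reuse the same id across process restarts."""
    return hashlib.sha1(experiment_name.encode("utf-8")).hexdigest()[:12]

# The same experiment name always yields the same id, even after a restart.
run_id = stable_run_id("demo-project/demo-run")
```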