
@woshiyyya
Last active August 7, 2023 18:23

PyTorch Lightning User Guides

Converting an existing training loop

To convert an existing training loop, replace the relevant pl.Trainer arguments with Ray Train's implementations:

import pytorch_lightning as pl
+ from ray.train.lightning import (
+     get_devices,
+     prepare_trainer,
+     RayDDPStrategy,
+     RayLightningEnvironment,
+ )

def train_func():
    model = MyLightningModule()
    datamodule = MyLightningDataModule()
    
    trainer = pl.Trainer(
      - devices=[0,1,2,3],
      - strategy=DDPStrategy(),
      + devices=get_devices(),
      + strategy=RayDDPStrategy(),
      + plugins=[RayLightningEnvironment()]
    )
    + trainer = prepare_trainer(trainer)
    
    trainer.fit(model, datamodule=datamodule)

Data Loading and Preprocessing

How should users choose between Ray Data and the Torch DataLoader?

Reporting results to Ray Train

PyTorch Lightning provides a callback interface for injecting custom logic during training. You can report model checkpoints to Ray Train by implementing a custom callback.

For example, here we implement a RayTrainReportCallback, which reports metrics and a checkpoint at the end of every epoch. You can also report at different stages (e.g. on_train_batch_end, on_validation_epoch_end) or at different frequencies (every 2 epochs, every 100 batches).

import os
from tempfile import TemporaryDirectory

import ray.train
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import Callback
from ray.train import Checkpoint

class RayTrainReportCallback(Callback):
    def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
        with TemporaryDirectory() as tmpdir:
            # Fetch metrics
            metrics = trainer.callback_metrics
            metrics = {k: v.item() for k, v in metrics.items()}
           
            # Save checkpoint to local
            ckpt_path = os.path.join(tmpdir, f"ckpt_epoch_{trainer.current_epoch}.pth")
            trainer.save_checkpoint(ckpt_path, weights_only=False)

            # Report to train session
            checkpoint = Checkpoint.from_directory(tmpdir)
            ray.train.report(metrics=metrics, checkpoint=checkpoint)
 
def train_func():
    ...
    trainer = Trainer(
       ...
       callbacks=[RayTrainReportCallback()]  
    )
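To report every N batches instead of every epoch, the trigger condition can be isolated in a small helper. Both the should_report name and the every_n default below are illustrative, not part of the Ray or Lightning API:

```python
def should_report(batch_idx: int, every_n: int = 100) -> bool:
    # Report on batches 99, 199, ... so that each report covers a full
    # window of `every_n` batches (batch_idx is zero-based).
    return (batch_idx + 1) % every_n == 0

# Inside a Lightning callback it would be wired up roughly like this:
#
# class BatchReportCallback(Callback):
#     def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
#         if should_report(batch_idx):
#             metrics = {k: v.item() for k, v in trainer.callback_metrics.items()}
#             ray.train.report(metrics=metrics)
```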

By default, Ray Train saves all checkpoints that you report. If you want to keep only the top-k checkpoints based on a certain metric, you can specify a CheckpointConfig in the ray.train.torch.TorchTrainer. Here, we select the top-2 checkpoints with the highest validation accuracy.

from ray.train import CheckpointConfig, RunConfig
from ray.train.torch import TorchTrainer

trainer = TorchTrainer(
    train_loop_per_worker=train_func,
    run_config=RunConfig(
        checkpoint_config=CheckpointConfig(
            num_to_keep=2,
            checkpoint_score_attribute="val_accuracy",
            checkpoint_score_order="max",
        )
    )
)
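As a rough mental model (not Ray's actual implementation), num_to_keep with a score attribute behaves like a top-k selection over the metrics you report. Keeping the top-2 checkpoints by val_accuracy looks like:

```python
def keep_top_k(checkpoints, num_to_keep=2, score_attribute="val_accuracy", order="max"):
    # checkpoints: list of reported metric dicts, one per checkpoint.
    # Rank by the score attribute and keep only the best `num_to_keep`.
    reverse = order == "max"
    ranked = sorted(checkpoints, key=lambda c: c[score_attribute], reverse=reverse)
    return ranked[:num_to_keep]

reports = [
    {"epoch": 0, "val_accuracy": 0.71},
    {"epoch": 1, "val_accuracy": 0.83},
    {"epoch": 2, "val_accuracy": 0.79},
]
best = keep_top_k(reports)  # epochs 1 and 2 survive; epoch 0 is dropped
```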

Reporting checkpoints to Ray Train enables your training to benefit from various advanced Ray features, including fault tolerance and distributed checkpointing. Please refer to [] and [] to learn more about these features.

Saving and loading checkpoints

Issues:

  • The saving part seems to be a little duplicated from the previous section.
  • The title looks ambiguous; it should be "Reporting and retrieving checkpoints from training loops".

Loading Checkpoint

import os

import ray.train

def train_func():
    # model, trainer, and datamodule are assumed to be set up as in the
    # previous sections
    checkpoint = ray.train.get_checkpoint()

    # Resume training from the latest reported checkpoint
    if checkpoint:
        with checkpoint.as_directory() as checkpoint_dir:
            ckpt_path = os.path.join(checkpoint_dir, "path/of/ckpt/file")
            trainer.fit(model, datamodule=datamodule, ckpt_path=ckpt_path)
    # Start training from scratch
    else:
        trainer.fit(model, datamodule=datamodule)
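The RayTrainReportCallback above saves files named ckpt_epoch_{N}.pth, so after the checkpoint directory is restored you can locate the latest file by scanning for that pattern. The find_latest_ckpt helper below is an illustration, not a Ray API:

```python
import os
import re

def find_latest_ckpt(checkpoint_dir: str) -> str:
    # Match files like ckpt_epoch_3.pth and pick the highest epoch number.
    pattern = re.compile(r"ckpt_epoch_(\d+)\.pth$")
    candidates = []
    for name in os.listdir(checkpoint_dir):
        m = pattern.match(name)
        if m:
            candidates.append((int(m.group(1)), name))
    if not candidates:
        raise FileNotFoundError(f"no checkpoint files in {checkpoint_dir}")
    _, latest = max(candidates)
    return os.path.join(checkpoint_dir, latest)
```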

Experiment tracking and Callbacks

W&B, CometML, MLFlow, and Tensorboard are popular tools for managing, visualizing, and tracking ML experiments. If you're using PyTorch Lightning, you can continue to use its built-in logger integrations with Ray Train.

from pytorch_lightning.loggers.wandb import WandbLogger
from pytorch_lightning.loggers.comet import CometLogger
from pytorch_lightning.loggers.mlflow import MLFlowLogger

def train_func():
    wandb_logger = WandbLogger(
        name="demo-run", 
        project="demo-project", 
        id="unique_id",  # Specify a unique id to avoid creating a new run after restoration
        offline=offline
    )
    
    comet_logger = CometLogger(
        api_key=comet_api_key,
        experiment_name="demo-experiment",
        project_name="demo-project",
        offline=offline,
    )
    
    mlflow_logger = MLFlowLogger(
        run_name=name,
        experiment_name=project_name,
        tracking_uri=f"file:{save_dir}/mlflow",
    )
    
    trainer = pl.Trainer(
        ...,
        logger=[wandb_logger, comet_logger, mlflow_logger],
    )
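One caveat when combining these loggers with Ray Train: every worker runs train_func, so each worker would otherwise create its own W&B/Comet/MLflow run. A common pattern is to attach loggers only on the rank-0 worker. The loggers_for_rank helper below is an illustrative sketch, not part of Ray Train:

```python
def loggers_for_rank(world_rank, loggers):
    # Only the rank-0 worker keeps its loggers; other workers log nothing.
    return loggers if world_rank == 0 else []

# Inside train_func, the rank comes from Ray Train's context:
#
# rank = ray.train.get_context().get_world_rank()
# trainer = pl.Trainer(..., logger=loggers_for_rank(rank, [wandb_logger]))
```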