Created January 20, 2022 09:04
PTL/FFCV example
custom_PTL_methods.py
from typing import Any

from pytorch_lightning.utilities.fetching import AbstractDataFetcher
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature

def on_run_start(self, data_fetcher: AbstractDataFetcher, **kwargs: Any) -> None:
    self.trainer.logger_connector.on_epoch_start()
    self.trainer.call_hook("on_epoch_start")
    self.trainer.call_hook("on_train_epoch_start")
    self.trainer.fit_loop.epoch_progress.increment_started()
    # grab the underlying FFCV loader from Lightning's CombinedLoader and iterate it directly,
    # bypassing Lightning's data fetcher
    self._dataloader_iter = enumerate(iter(data_fetcher.dataloader.loaders))

def advance(self, *args: Any, **kwargs: Any) -> None:
    """Runs a single training batch.

    Raises:
        StopIteration: When the epoch is canceled by the user returning -1
    """
    if self.restarting and self._should_check_val_fx(self.batch_idx, self.batch_progress.is_last_batch):
        # skip training and run validation in `on_advance_end`
        return

    try:
        # the FFCV loader signals the end of the epoch with StopIteration
        batch_idx, batch = next(self._dataloader_iter)
        self.batch_progress.is_last_batch = False
    except StopIteration:
        self.batch_progress.is_last_batch = True
        raise StopIteration

    if not self.trainer._data_connector.train_data_fetcher.store_on_device:
        with self.trainer.profiler.profile("training_batch_to_device"):
            batch = self.trainer.accelerator.batch_to_device(batch)

    self.batch_progress.increment_ready()

    self.trainer.logger_connector.on_batch_start(batch_idx, batch)

    if batch is None:
        self._warning_cache.warn("train_dataloader yielded None. If this was on purpose, ignore this warning...")
        batch_output = []
    else:
        # hook
        response = self.trainer.call_hook("on_batch_start")
        if response == -1:
            self.batch_progress.increment_processed()
            raise StopIteration

        # TODO: Update this in v1.7 (deprecation: #9816)
        model_fx = self.trainer.lightning_module.on_train_batch_start
        extra_kwargs = (
            {"dataloader_idx": 0}
            if callable(model_fx) and is_param_in_hook_signature(model_fx, "dataloader_idx", explicit=True)
            else {}
        )

        # hook
        response = self.trainer.call_hook("on_train_batch_start", batch, batch_idx, **extra_kwargs)
        if response == -1:
            self.batch_progress.increment_processed()
            raise StopIteration

        self.batch_progress.increment_started()

        with self.trainer.profiler.profile("run_training_batch"):
            batch_output = self.batch_loop.run(batch, batch_idx)

    self.batch_progress.increment_processed()

    # update non-plateau LR schedulers
    # update epoch-interval ones only when we are at the end of training epoch
    self.update_lr_schedulers("step", update_plateau_schedulers=False)
    if self._num_ready_batches_reached():
        self.update_lr_schedulers("epoch", update_plateau_schedulers=False)

    batch_end_outputs = self._prepare_outputs_training_batch_end(
        batch_output,
        automatic=self.trainer.lightning_module.automatic_optimization,
        num_optimizers=len(self.trainer.optimizers),
    )

    # TODO: Update this in v1.7 (deprecation: #9816)
    model_fx = self.trainer.lightning_module.on_train_batch_end
    extra_kwargs = (
        {"dataloader_idx": 0}
        if callable(model_fx) and is_param_in_hook_signature(model_fx, "dataloader_idx", explicit=True)
        else {}
    )
    self.trainer.call_hook("on_train_batch_end", batch_end_outputs, batch, batch_idx, **extra_kwargs)
    self.trainer.call_hook("on_batch_end")
    self.trainer.logger_connector.on_batch_end()

    self.batch_progress.increment_completed()

    if is_overridden("training_epoch_end", self.trainer.lightning_module):
        self._outputs.append(batch_output)

    # -----------------------------------------
    # SAVE METRICS TO LOGGERS AND PROGRESS_BAR
    # -----------------------------------------
    self.trainer.logger_connector.update_train_step_metrics()
Usage in the training script:
from types import MethodType

import pytorch_lightning as pl

import custom_PTL_methods

...

trainer = pl.Trainer(...)

# bind the patched loop methods onto this trainer's epoch loop
trainer.fit_loop.epoch_loop.on_run_start = MethodType(custom_PTL_methods.on_run_start, trainer.fit_loop.epoch_loop)
trainer.fit_loop.epoch_loop.advance = MethodType(custom_PTL_methods.advance, trainer.fit_loop.epoch_loop)

trainer.fit(model, train_dataloaders=ffcv_train_loader, val_dataloaders=ffcv_val_loader)
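
For context, ffcv_train_loader and ffcv_val_loader above are plain FFCV Loader objects. A minimal sketch of how they might be built follows; the .beton paths, batch size, worker count and the "image"/"label" pipeline names are placeholders that depend on how the dataset was written, not part of the gist:

import torch
from ffcv.loader import Loader, OrderOption
from ffcv.fields.decoders import IntDecoder, SimpleRGBImageDecoder
from ffcv.transforms import Squeeze, ToDevice, ToTensor, ToTorchImage

# placeholder pipelines; the field names must match those used when writing the .beton file
image_pipeline = [SimpleRGBImageDecoder(), ToTensor(), ToTorchImage(), ToDevice(torch.device("cuda:0"))]
label_pipeline = [IntDecoder(), ToTensor(), Squeeze(), ToDevice(torch.device("cuda:0"))]

ffcv_train_loader = Loader(
    "train.beton",             # placeholder path to the FFCV dataset
    batch_size=256,
    num_workers=8,
    order=OrderOption.RANDOM,  # reshuffled each epoch
    pipelines={"image": image_pipeline, "label": label_pipeline},
)

ffcv_val_loader = Loader(
    "val.beton",
    batch_size=256,
    num_workers=8,
    order=OrderOption.SEQUENTIAL,
    pipelines={"image": image_pipeline, "label": label_pipeline},
)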
Comment: Hi, I'm trying to combine PTL and FFCV in my project. Does this workaround support DDP training? Do I need to set ffcv_val_loader(distribute=True)? Thanks!

Reply (gist author): Not sure. Some other people in my lab were using PTL; I don't, and I don't think they were using DDP. I think your best bet is to try it yourself.
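
For what it's worth, FFCV's Loader exposes a distributed flag that shards batches across processes, which is presumably what the question refers to. A sketch of setting it, reusing the placeholder pipelines from the earlier sketch; whether this interacts cleanly with Lightning's DDP strategy and this patch is untested here, and FFCV's documentation should be checked for the ordering/caching constraints of distributed mode:

from ffcv.loader import Loader, OrderOption

ffcv_train_loader = Loader(
    "train.beton",             # placeholder path
    batch_size=256,
    num_workers=8,
    order=OrderOption.RANDOM,
    distributed=True,          # let FFCV shard batches across the DDP processes
    pipelines={"image": image_pipeline, "label": label_pipeline},  # as in the earlier sketch
)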