Skip to content

Instantly share code, notes, and snippets.

@YugeTen
Created July 13, 2022 13:41
Show Gist options
  • Save YugeTen/42e3b2a15144f262f27b25ad2fad6a9b to your computer and use it in GitHub Desktop.
Save YugeTen/42e3b2a15144f262f27b25ad2fad6a9b to your computer and use it in GitHub Desktop.
@YugeTen
Copy link
Author

YugeTen commented Jul 25, 2022

    def save(self, trainer: pl.Trainer):
        """Saves current checkpoint.

        The checkpoint is written to ``self.path`` under the file name
        ``self.ckpt_placeholder.format(epoch)`` and is dumped with
        ``weights_only=False``. Nothing is written unless this process is
        global rank zero and the trainer is not in its sanity check.

        Args:
            trainer (pl.Trainer): pytorch lightning trainer object.
        """
        if trainer.is_global_zero and not trainer.sanity_checking:
            epoch = trainer.current_epoch  # type: ignore
            ckpt_path = os.path.join(self.path, self.ckpt_placeholder.format(epoch))
            # ``Accelerator.save_checkpoint`` is deprecated in Lightning v1.5 and
            # removed in v1.6. ``Trainer.save_checkpoint`` dumps the checkpoint
            # and hands it to the active plugin itself, so it replaces the
            # former ``dump_checkpoint`` + ``accelerator.save_checkpoint`` pair.
            trainer.save_checkpoint(ckpt_path, weights_only=False)

based on: https://github.com/YugeTen/adios/blob/main/src/utils/checkpointer.py#L84

@eleanortrollope
Copy link

eleanortrollope commented Jul 26, 2022

'''

def save(self, trainer: pl.Trainer):
    """Saves current checkpoint.

    The checkpoint is written to ``self.path`` under the file name
    ``self.ckpt_placeholder.format(epoch)`` and is dumped with
    ``weights_only=True``. Unless ``self.keep_previous_checkpoints`` is
    set, the previously saved checkpoint file is deleted once the new one
    has been written. Nothing happens unless this process is global rank
    zero and the trainer is not in its sanity check.

    Args:
        trainer (pl.Trainer): pytorch lightning trainer object.
    """
    if trainer.is_global_zero and not trainer.sanity_checking:
        epoch = trainer.current_epoch  # type: ignore
        ckpt_path = os.path.join(self.path, self.ckpt_placeholder.format(epoch))
        # ``Accelerator.save_checkpoint`` is deprecated in Lightning v1.5 and
        # removed in v1.6. ``Trainer.save_checkpoint`` dumps the checkpoint
        # and hands it to the active plugin itself, so it replaces the
        # former ``dump_checkpoint`` + ``accelerator.save_checkpoint`` pair.
        trainer.save_checkpoint(ckpt_path, weights_only=True)

        # Drop the previous file only after the new one exists, and never
        # delete the file that was just written (same path, e.g. same epoch).
        if self.last_ckpt and self.last_ckpt != ckpt_path and not self.keep_previous_checkpoints:
            os.remove(self.last_ckpt)
        self.last_ckpt = ckpt_path

'''

@eleanortrollope
Copy link

ERROR:

/homes/55/eleanor/anaconda3/lib/python3.9/site-packages/pytorch_lightning/accelerators/accelerator.py:514: LightningDeprecationWarning: Accelerator.save_checkpoint is deprecated in v1.5 and will be removed in v1.6. save_checkpoint logic is implemented directly in the TrainingTypePlugin implementations.
rank_zero_deprecation(
Traceback (most recent call last):
File "/homes/55/eleanor/anaconda3/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 682, in _call_and_handle_interrupt
return trainer_fn(*args, **kwargs)
File "/homes/55/eleanor/anaconda3/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 772, in _fit_impl
self._run(model, ckpt_path=ckpt_path)
File "/homes/55/eleanor/anaconda3/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1195, in _run
self._dispatch()
File "/homes/55/eleanor/anaconda3/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1274, in _dispatch
self.training_type_plugin.start_training(self)
File "/homes/55/eleanor/anaconda3/lib/python3.9/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py", line 202, in start_training
self._results = trainer.run_stage()
File "/homes/55/eleanor/anaconda3/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1284, in run_stage
return self._run_train()
File "/homes/55/eleanor/anaconda3/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1314, in _run_train
self.fit_loop.run()
File "/homes/55/eleanor/anaconda3/lib/python3.9/site-packages/pytorch_lightning/loops/base.py", line 145, in run
self.advance(*args, **kwargs)
File "/homes/55/eleanor/anaconda3/lib/python3.9/site-packages/pytorch_lightning/loops/fit_loop.py", line 234, in advance
self.epoch_loop.run(data_fetcher)
File "/homes/55/eleanor/anaconda3/lib/python3.9/site-packages/pytorch_lightning/loops/base.py", line 151, in run
output = self.on_run_end()
File "/homes/55/eleanor/anaconda3/lib/python3.9/site-packages/pytorch_lightning/loops/epoch/training_epoch_loop.py", line 298, in on_run_end
self.trainer.call_hook("on_train_epoch_end")
File "/homes/55/eleanor/anaconda3/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1490, in call_hook
callback_fx(*args, **kwargs)
File "/homes/55/eleanor/anaconda3/lib/python3.9/site-packages/pytorch_lightning/trainer/callback_hook.py", line 93, in on_train_epoch_end
callback.on_train_epoch_end(self, self.lightning_module)
File "/homes/55/eleanor/anaconda3/lib/python3.9/site-packages/pytorch_lightning/callbacks/model_checkpoint.py", line 321, in on_train_epoch_end
self.save_checkpoint(trainer)
File "/homes/55/eleanor/anaconda3/lib/python3.9/site-packages/pytorch_lightning/callbacks/model_checkpoint.py", line 390, in save_checkpoint
self._save_none_monitor_checkpoint(trainer, monitor_candidates)
File "/homes/55/eleanor/anaconda3/lib/python3.9/site-packages/pytorch_lightning/callbacks/model_checkpoint.py", line 687, in _save_none_monitor_checkpoint
trainer.save_checkpoint(filepath, self.save_weights_only)
File "/homes/55/eleanor/anaconda3/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1914, in save_checkpoint
self.checkpoint_connector.save_checkpoint(filepath, weights_only)
File "/homes/55/eleanor/anaconda3/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/checkpoint_connector.py", line 472, in save_checkpoint
self.trainer.training_type_plugin.save_checkpoint(_checkpoint, filepath)
File "/homes/55/eleanor/anaconda3/lib/python3.9/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py", line 294, in save_checkpoint
return self.checkpoint_io.save_checkpoint(checkpoint, filepath)
File "/homes/55/eleanor/anaconda3/lib/python3.9/site-packages/pytorch_lightning/plugins/io/torch_plugin.py", line 37, in save_checkpoint
atomic_save(checkpoint, path)
File "/homes/55/eleanor/anaconda3/lib/python3.9/site-packages/pytorch_lightning/utilities/cloud_io.py", line 68, in atomic_save
torch.save(checkpoint, bytesbuffer)
File "/homes/55/eleanor/anaconda3/lib/python3.9/site-packages/torch/serialization.py", line 379, in save
_save(obj, opened_zipfile, pickle_module, pickle_protocol)
File "/homes/55/eleanor/anaconda3/lib/python3.9/site-packages/torch/serialization.py", line 589, in _save
pickler.dump(obj)
TypeError: cannot pickle 'generator' object

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment