Lightning snippets for use with transformers
from weakref import proxy

import lightning.pytorch as pl
from lightning.pytorch.callbacks import ModelCheckpoint


class AdapterModelCheckpoint(ModelCheckpoint):
    """Checkpoint callback that saves the wrapped transformers/PEFT model with
    `save_pretrained` instead of writing a full Lightning checkpoint."""

    def _save_checkpoint(self, trainer: "pl.Trainer", filepath: str) -> None:
        # save only the model (e.g. adapter) weights in transformers format
        trainer.model.save_pretrained(filepath)
        # trainer.save_checkpoint(filepath, self.save_weights_only)

        self._last_global_step_saved = trainer.global_step
        self._last_checkpoint_saved = filepath

        # notify loggers
        if trainer.is_global_zero:
            for logger in trainer.loggers:
                logger.after_save_checkpoint(proxy(self))
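Because `_save_checkpoint` writes weights with `save_pretrained`, the output directory is a plain transformers/PEFT checkpoint rather than a Lightning `.ckpt` file. A minimal reload sketch, assuming the saved weights are a PEFT adapter on top of a causal LM; the base model name and checkpoint path below are placeholders, not part of this gist:

from transformers import AutoModelForCausalLM
from peft import PeftModel

# Placeholders: substitute your own base model and the directory the callback wrote.
base = AutoModelForCausalLM.from_pretrained("base-model-name")
model = PeftModel.from_pretrained(base, "./checkpoints/last")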
from lightning.pytorch.callbacks import Callback


class GenCallback(Callback):
    """A callback that generates one sample every `every` training batches."""

    def __init__(self, every=50):
        self.every = every

    def do_gen(self, model):
        # `get_model_generations` is a helper defined elsewhere alongside these snippets
        get_model_generations(model, model.tokenizer, max_new_tokens=32, N=1)

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        n = batch_idx + 1
        if n % self.every == 0:
            print(f"Generated on batch {batch_idx}")
            self.do_gen(trainer.model._model)

    def on_train_epoch_end(self, trainer, pl_module):
        self.do_gen(trainer.model._model)
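Both callbacks can then be handed to the Trainer together; a sketch under the assumption that the LightningModule exposes the `_model` and `tokenizer` attributes the code above expects (`lit_module` and `dm` are hypothetical stand-ins):

import lightning.pytorch as pl

trainer = pl.Trainer(
    max_steps=1000,
    callbacks=[
        AdapterModelCheckpoint(dirpath="./checkpoints", save_top_k=1, monitor="val/loss"),
        GenCallback(every=50),
    ],
)
# trainer.fit(lit_module, datamodule=dm)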