This is Matcha's train.py for when you only want to load the model weights from a checkpoint and not the optimizer state.
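As a minimal sketch of what "loading only the weights" means here: a Lightning .ckpt file is an ordinary Python dict, and the script below keeps just its "state_dict" entry while deliberately ignoring the saved optimizer states. The checkpoint path is a placeholder.

import torch

# A Lightning ".ckpt" file is a plain dict; "path/to/checkpoint.ckpt" is a placeholder.
checkpoint = torch.load("path/to/checkpoint.ckpt", map_location="cpu")
print(sorted(checkpoint.keys()))  # typically includes 'epoch', 'optimizer_states', 'state_dict', ...

# Only the model parameters are used; everything else in the checkpoint is ignored.
state_dict = checkpoint["state_dict"]
print(f"{len(state_dict)} parameter tensors")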
from typing import Any, Dict, List, Optional, Tuple

import hydra
import lightning as L
import rootutils
import torch
from lightning import Callback, LightningDataModule, LightningModule, Trainer
from lightning.pytorch.loggers import Logger
from omegaconf import DictConfig

from matcha import utils

rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
# ------------------------------------------------------------------------------------ #
# the setup_root above is equivalent to:
# - adding project root dir to PYTHONPATH
#       (so you don't need to force user to install project as a package)
#       (necessary before importing any local modules e.g. `from src import utils`)
# - setting up PROJECT_ROOT environment variable
#       (which is used as a base for paths in "configs/paths/default.yaml")
#       (this way all filepaths are the same no matter where you run the code)
# - loading environment variables from ".env" in root dir
#
# you can remove it if you:
# 1. either install project as a package or move entry files to project root dir
# 2. set `root_dir` to "." in "configs/paths/default.yaml"
#
# more info: https://github.com/ashleve/rootutils
# ------------------------------------------------------------------------------------ #

log = utils.get_pylogger(__name__)


@utils.task_wrapper
def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Trains the model. Can additionally evaluate on a test set, using the best weights obtained
    during training.

    This method is wrapped in an optional @task_wrapper decorator that controls the behavior during
    failure. Useful for multiruns, saving info about the crash, etc.

    :param cfg: A DictConfig configuration composed by Hydra.
    :return: A tuple with metrics and a dict with all instantiated objects.
    """
    # set seed for random number generators in pytorch, numpy and python.random
    if cfg.get("seed"):
        L.seed_everything(cfg.seed, workers=True)
log.info(f"Instantiating datamodule <{cfg.data._target_}>") # pylint: disable=protected-access | |
datamodule: LightningDataModule = hydra.utils.instantiate(cfg.data) | |
log.info(f"Instantiating model <{cfg.model._target_}>") # pylint: disable=protected-access | |
model: LightningModule = hydra.utils.instantiate(cfg.model) | |
log.info("Instantiating callbacks...") | |
callbacks: List[Callback] = utils.instantiate_callbacks(cfg.get("callbacks")) | |
log.info("Instantiating loggers...") | |
logger: List[Logger] = utils.instantiate_loggers(cfg.get("logger")) | |
log.info(f"Instantiating trainer <{cfg.trainer._target_}>") # pylint: disable=protected-access | |
trainer: Trainer = hydra.utils.instantiate(cfg.trainer, callbacks=callbacks, logger=logger) | |
object_dict = { | |
"cfg": cfg, | |
"datamodule": datamodule, | |
"model": model, | |
"callbacks": callbacks, | |
"logger": logger, | |
"trainer": trainer, | |
} | |
if logger: | |
log.info("Logging hyperparameters!") | |
utils.log_hyperparameters(object_dict) | |

    ckpt_path = cfg.get("ckpt_path")
    if ckpt_path is not None and cfg.get("load_only_weights", False):
        # Copy only the model parameters from the checkpoint, then clear ckpt_path so the
        # Trainer does not also restore the optimizer/scheduler states and training progress.
        log.info("Loading only the weights from the checkpoint, ignoring all optimizer states")
        model.load_state_dict(torch.load(ckpt_path, map_location=torch.device("cpu"))["state_dict"])
        ckpt_path = None
if cfg.get("train"): | |
log.info("Starting training!") | |
trainer.fit(model=model, datamodule=datamodule, ckpt_path=ckpt_path) | |
train_metrics = trainer.callback_metrics | |
if cfg.get("test"): | |
log.info("Starting testing!") | |
ckpt_path = trainer.checkpoint_callback.best_model_path | |
if ckpt_path == "": | |
log.warning("Best ckpt not found! Using current weights for testing...") | |
ckpt_path = None | |
trainer.test(model=model, datamodule=datamodule, ckpt_path=ckpt_path) | |
log.info(f"Best ckpt path: {ckpt_path}") | |
test_metrics = trainer.callback_metrics | |
# merge train and test metrics | |
metric_dict = {**train_metrics, **test_metrics} | |
return metric_dict, object_dict | |


@hydra.main(version_base="1.3", config_path="../configs", config_name="train.yaml")
def main(cfg: DictConfig) -> Optional[float]:
    """Main entry point for training.

    :param cfg: DictConfig configuration composed by Hydra.
    :return: Optional[float] with optimized metric value.
    """
    # apply extra utilities
    # (e.g. ask for tags if none are provided in cfg, print cfg tree, etc.)
    utils.extras(cfg)

    # train the model
    metric_dict, _ = train(cfg)

    # safely retrieve metric value for hydra-based hyperparameter optimization
    metric_value = utils.get_metric_value(metric_dict=metric_dict, metric_name=cfg.get("optimized_metric"))

    # return optimized metric
    return metric_value


if __name__ == "__main__":
    main()  # pylint: disable=no-value-for-parameter
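Note on the design choice: after the weights are copied into the model, ckpt_path is reset to None so that trainer.fit() starts a fresh run rather than also restoring the optimizer, LR scheduler, epoch counter, and other training state from the checkpoint. Assuming load_only_weights is not already declared in the Hydra config, a run could then be launched with an override along the lines of python matcha/train.py ckpt_path=/path/to/checkpoint.ckpt +load_only_weights=true (drop the "+" if the key already exists in the config; the paths and any additional experiment overrides here are placeholders).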