Created December 28, 2023 03:32
Save e96031413/e408f2e6f08cc601a711fa28f2e86c9e to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def resume_train(self, model):
    """Restore training state from a checkpoint, or load weights for fine-tuning.

    Behaviour is driven by ``self.args``:

    * ``args.resume`` truthy: load the checkpoint at ``args.ckpt``, or at
      ``os.path.join(self.file_name, "latest_ckpt.pth")`` when ``args.ckpt``
      is None. Restore the model and optimizer state dicts, then set
      ``self.start_epoch`` from ``args.start_epoch - 1`` when
      ``args.start_epoch`` is given, otherwise from the checkpoint's
      ``"start_epoch"`` entry.
    * otherwise, if ``args.ckpt`` is set: load only the checkpoint's
      ``"model"`` entry and apply it with ``load_ckpt``; ``self.start_epoch``
      is set to 0.
    * otherwise: leave the model untouched and set ``self.start_epoch`` to 0.

    Args:
        model: the model to populate. In the resume branch its state is
            restored via ``model.load_state_dict``.

    Returns:
        The model. In the fine-tuning branch this is whatever ``load_ckpt``
        returns.

    Raises:
        KeyError: if a resume checkpoint lacks ``"model"`` or ``"optimizer"``,
            or lacks ``"start_epoch"`` when ``args.start_epoch`` is None.
    """
    if self.args.resume:
        logger.info("resume training")
        if self.args.ckpt is None:
            ckpt_file = os.path.join(self.file_name, "latest_ckpt.pth")
        else:
            ckpt_file = self.args.ckpt

        # NOTE(review): torch.load unpickles the file; only resume from
        # checkpoints produced by a trusted source.
        ckpt = torch.load(ckpt_file, map_location=self.device)
        # resume the model/optimizer state dict
        model.load_state_dict(ckpt["model"])
        self.optimizer.load_state_dict(ckpt["optimizer"])
        # resume the training state variables; an explicit args.start_epoch
        # overrides the checkpoint (the -1 suggests it is 1-based — TODO confirm)
        if self.args.start_epoch is not None:
            self.start_epoch = self.args.start_epoch - 1
        else:
            self.start_epoch = ckpt["start_epoch"]
        # Log the file that was actually loaded; args.resume is only used as a
        # branch flag above, so it does not identify the checkpoint.
        logger.info(
            "loaded checkpoint '{}' (epoch {})".format(ckpt_file, self.start_epoch)
        )
    else:
        if self.args.ckpt is not None:
            logger.info("loading checkpoint for fine tuning")
            ckpt_file = self.args.ckpt
            # Only the weights are used here; optimizer/epoch state is ignored.
            ckpt = torch.load(ckpt_file, map_location=self.device)["model"]
            model = load_ckpt(model, ckpt)
        self.start_epoch = 0
    return model
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.