Skip to content

Instantly share code, notes, and snippets.

@e96031413
Created December 28, 2023 03:32
Show Gist options
  • Save e96031413/e408f2e6f08cc601a711fa28f2e86c9e to your computer and use it in GitHub Desktop.
Save e96031413/e408f2e6f08cc601a711fa28f2e86c9e to your computer and use it in GitHub Desktop.
def resume_train(self, model):
    """Restore training state from a checkpoint, or load weights for fine-tuning.

    The mode is selected by ``self.args.resume``:

    * Resume (truthy): load the checkpoint at ``self.args.ckpt``, or at
      ``os.path.join(self.file_name, "latest_ckpt.pth")`` when no explicit
      path is given. Restore both the model and ``self.optimizer`` state
      dicts, then set ``self.start_epoch`` from ``self.args.start_epoch``
      (minus one) or from the checkpoint's saved ``"start_epoch"``.
    * Fine-tune (falsy): if ``self.args.ckpt`` is set, load only its
      ``"model"`` entry through ``load_ckpt``. ``self.start_epoch`` is
      always reset to 0.

    Args:
        model: The network to restore. In resume mode it is updated in place
            via ``load_state_dict``. In fine-tune mode it is replaced by the
            object ``load_ckpt`` returns.

    Returns:
        The model with checkpoint weights applied, or the unchanged model
        when no checkpoint was requested.

    Raises:
        KeyError: if the loaded checkpoint mapping lacks ``"model"``, or (in
            resume mode) ``"optimizer"``, or ``"start_epoch"`` when
            ``self.args.start_epoch`` is None.
    """
    if self.args.resume:
        logger.info("resume training")
        if self.args.ckpt is None:
            ckpt_file = os.path.join(self.file_name, "latest" + "_ckpt.pth")
        else:
            ckpt_file = self.args.ckpt

        # NOTE(review): torch.load unpickles arbitrary objects; only load
        # checkpoints from trusted sources (newer PyTorch offers
        # weights_only=True for safer loading).
        ckpt = torch.load(ckpt_file, map_location=self.device)
        # Resume the model/optimizer state dicts.
        model.load_state_dict(ckpt["model"])
        self.optimizer.load_state_dict(ckpt["optimizer"])
        # Resume the training-state variables. An explicit
        # self.args.start_epoch overrides the saved value and is shifted by
        # one (presumably a 1-based user-facing epoch; confirm with the
        # argument parser).
        if self.args.start_epoch is not None:
            self.start_epoch = self.args.start_epoch - 1
        else:
            self.start_epoch = ckpt["start_epoch"]
        # Log the resolved checkpoint path, not the resume switch, so the
        # message shows which file was actually restored.
        logger.info(
            "loaded checkpoint '{}' (epoch {})".format(ckpt_file, self.start_epoch)
        )  # noqa
    else:
        if self.args.ckpt is not None:
            logger.info("loading checkpoint for fine tuning")
            ckpt_file = self.args.ckpt
            # NOTE(review): same trusted-source caveat for torch.load as above.
            ckpt = torch.load(ckpt_file, map_location=self.device)["model"]
            model = load_ckpt(model, ckpt)
        # Fine-tuning (or training from scratch) always starts at epoch 0.
        self.start_epoch = 0

    return model
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment