@yptheangel
Created April 5, 2020 08:54
Load DL4J model checkpoint and continue training.
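// Train the model and save a checkpoint whenever the training score (loss) drops below the lowest value seen so far.
// Pass saveUpdater = true to writeModel if you plan to continue training later: updaters such as Adam keep
// per-parameter state (running moment estimates) that is only stored in the checkpoint when this flag is set.
// Assumes tunedModel (a ComputationGraph), trainIterator, epochs and lowest (e.g. Double.MAX_VALUE) are defined earlier.
for (int i = 1; i <= epochs; i++) {
    tunedModel.fit(trainIterator);
    // score() reports the loss of the most recent minibatch, not an average over the epoch.
    if (tunedModel.score() < lowest) {
        lowest = tunedModel.score();
        String modelFilename = new File(".").getAbsolutePath() + "/ImageClassifier_loss" + lowest + "_ep" + i + ".zip";
        ModelSerializer.writeModel(tunedModel, modelFilename, true);
    }
    System.out.println(String.format("Completed epoch %d.", i));
}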
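Why the updater state matters: Adam keeps a first moment m, a second moment v and a step counter t for every parameter. If a checkpoint is saved without them, the resumed run starts Adam from zero. The plain-Java sketch below is not DL4J code; it is a hypothetical scalar Adam with the usual default hyperparameters (learning rate 0.001, beta1 0.9, beta2 0.999, epsilon 1e-8) and a synthetic constant gradient, used only to show that resuming with and without the saved state gives different next steps.

```java
/** Illustrative scalar Adam optimizer; a plain-Java sketch, not DL4J's implementation. */
public class AdamStateDemo {
    // Common Adam defaults (learning rate, beta1, beta2, epsilon).
    static final double LR = 1e-3, BETA1 = 0.9, BETA2 = 0.999, EPS = 1e-8;

    /** Updater state: first moment m, second moment v and step counter t. */
    static final class AdamState {
        double m, v;
        int t;

        AdamState copy() {
            AdamState c = new AdamState();
            c.m = m;
            c.v = v;
            c.t = t;
            return c;
        }
    }

    /** Applies one Adam update to parameter w with gradient g, mutating state s; returns the new w. */
    static double step(double w, double g, AdamState s) {
        s.t++;
        s.m = BETA1 * s.m + (1 - BETA1) * g;
        s.v = BETA2 * s.v + (1 - BETA2) * g * g;
        double mHat = s.m / (1 - Math.pow(BETA1, s.t)); // bias-corrected first moment
        double vHat = s.v / (1 - Math.pow(BETA2, s.t)); // bias-corrected second moment
        return w - LR * mHat / (Math.sqrt(vHat) + EPS);
    }

    public static void main(String[] args) {
        // "First run": 100 steps with a synthetic constant gradient of +1.
        double w = 0.0;
        AdamState state = new AdamState();
        for (int i = 0; i < 100; i++) {
            w = step(w, 1.0, state);
        }
        // "Resumed run": the gradient has flipped sign to -1.
        double deltaSaved = step(w, -1.0, state.copy()) - w;    // checkpoint saved with updater state
        double deltaFresh = step(w, -1.0, new AdamState()) - w; // updater state lost, Adam restarts from zero
        System.out.println("step with saved updater state: " + deltaSaved);
        System.out.println("step with fresh updater state: " + deltaFresh);
    }
}
```

With the saved state, the momentum built up during the first run still dominates, so the next step keeps moving w in the old direction (negative delta). With a fresh updater, Adam's first bias-corrected step is roughly the learning rate in the new direction (positive delta), so training does not really "continue" from where it stopped.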
// Load the saved checkpoint and continue training; pass loadUpdater = true so the updater state is restored as well.
// Replace the file name with the path of your own checkpoint.
tunedModel = ModelSerializer.restoreComputationGraph("ImageClassifier_loss0.38684930801391604_ep100_ResNet50.zip", true);
// Reuse the training loop from above, offsetting the epoch counter so new checkpoint names continue from the saved epoch.
// If this runs in a new JVM session, re-initialize lowest before the loop (for example to the loss in the checkpoint name).
int previousEpoch = 100;
for (int i = previousEpoch + 1; i <= previousEpoch + epochs; i++) {
    tunedModel.fit(trainIterator);
    if (tunedModel.score() < lowest) {
        lowest = tunedModel.score();
        String modelFilename = new File(".").getAbsolutePath() + "/ImageClassifier_loss" + lowest + "_ep" + i + ".zip";
        ModelSerializer.writeModel(tunedModel, modelFilename, true);
    }
    System.out.println(String.format("Completed epoch %d.", i));
}
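Instead of hard-coding previousEpoch (and losing the best loss when the JVM restarts), both values can be recovered from the checkpoint name, since the loop writes files named ImageClassifier_loss<loss>_ep<epoch>.zip. A minimal sketch, assuming that naming scheme; CheckpointName and parse are hypothetical helpers, not part of DL4J. The restored example file carries an extra _ResNet50 suffix, so the pattern also tolerates an optional suffix before .zip.

```java
import java.nio.file.Paths;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

/** Hypothetical helper: recovers the loss and epoch encoded in a checkpoint file name. */
public class CheckpointName {
    // Matches e.g. "ImageClassifier_loss0.3868_ep100.zip"; an optional suffix such as "_ResNet50" is allowed.
    private static final Pattern NAME =
            Pattern.compile("ImageClassifier_loss([0-9.eE+-]+)_ep(\\d+)(?:_[^.]*)?\\.zip");

    final double loss;
    final int epoch;

    private CheckpointName(double loss, int epoch) {
        this.loss = loss;
        this.epoch = epoch;
    }

    /** Parses a checkpoint path; throws IllegalArgumentException if the file name does not follow the pattern. */
    static CheckpointName parse(String path) {
        String fileName = Paths.get(path).getFileName().toString();
        Matcher m = NAME.matcher(fileName);
        if (!m.matches()) {
            throw new IllegalArgumentException("Unrecognized checkpoint name: " + fileName);
        }
        return new CheckpointName(Double.parseDouble(m.group(1)), Integer.parseInt(m.group(2)));
    }

    public static void main(String[] args) {
        CheckpointName c = parse("ImageClassifier_loss0.38684930801391604_ep100_ResNet50.zip");
        int previousEpoch = c.epoch; // resume epoch numbering after this epoch
        double lowest = c.loss;      // best score so far, so only improved checkpoints are written
        System.out.println("previousEpoch=" + previousEpoch + ", lowest=" + lowest);
    }
}
```

The two values can then seed previousEpoch and lowest before the resumed loop, so a restart in a fresh session neither overwrites the epoch numbering nor saves a checkpoint that is worse than the one it was resumed from.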