Created April 5, 2020 08:54
Save yptheangel/d3c26b82fd7aaedbf179ba0c984fdba7 to your computer and use it in GitHub Desktop.
Load DL4J model checkpoint and continue training.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Train your model and save the model if the training loss is better than the previous. Set saveUpdater boolean as "true" if you plan to continue training. Needed for updaters like Adam. | |
for (int i = 1; i < epochs + 1; i++) { | |
tunedModel.fit(trainIterator); | |
if (tunedModel.score() < lowest) { | |
lowest = tunedModel.score(); | |
String modelFilename = new File(".").getAbsolutePath() + "/ImageClassifier_loss" + lowest + "_ep" + i + ".zip"; | |
ModelSerializer.writeModel(tunedModel, modelFilename, true); | |
} | |
System.out.println(String.format("Completed epoch %d.", i)); | |
} | |
//Load your model and continue to train, remember to set loadUpdater as "true". | |
tunedModel = ModelSerializer.restoreComputationGraph("ImageClassifier_loss0.38684930801391604_ep100_ResNet50.zip", true); | |
//Reuse the training block earlier. | |
int previousEpoch = 100 | |
for (int i = 1+previousEpoch; i < epochs +previousEpoch+ 1; i++) { | |
tunedModel.fit(trainIterator); | |
if (tunedModel.score() < lowest) { | |
lowest = tunedModel.score(); | |
String modelFilename = new File(".").getAbsolutePath() + "/ImageClassifier_loss" + lowest + "_ep" + i + ".zip"; | |
ModelSerializer.writeModel(tunedModel, modelFilename, true); | |
} | |
System.out.println(String.format("Completed epoch %d.", i)); | |
} |
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.