Created
June 1, 2016 07:02
-
-
Save geekprogramming/3b19de708f2f9f02d0df4a2d0c5770f9 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
public Model saveModel() throws IOException { | |
log.info("Save model............"); | |
OutputStream fos = Files.newOutputStream(Paths.get(Config.FILE_PARAMS)); | |
DataOutputStream dos = new DataOutputStream(fos); | |
Nd4j.write(model.params(), dos); | |
dos.flush(); | |
dos.close(); | |
FileUtils.writeStringToFile(new File(Config.FILE_MODEL), model.getLayerWiseConfigurations().toJson()); | |
return this; | |
} | |
private MultiLayerNetwork loadModel() { | |
log.info("Load model..........."); | |
try { | |
MultiLayerConfiguration confFromJson = MultiLayerConfiguration.fromJson(FileUtils.readFileToString(new File(FileHelper.getFilePath(Config.FILE_MODEL)))); | |
DataInputStream dis = new DataInputStream(new FileInputStream(FileHelper.getFilePath(Config.FILE_PARAMS))); | |
INDArray newParams = Nd4j.read(dis); | |
dis.close(); | |
MultiLayerNetwork savedNetwork = new MultiLayerNetwork(confFromJson); | |
savedNetwork.init(); | |
savedNetwork.setParameters(newParams); | |
System.out.println("Num params: "+savedNetwork.numParams()); | |
return savedNetwork; | |
} catch (Exception ex) { | |
ex.printStackTrace(); | |
} | |
return null; | |
} | |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment