Skip to content

Instantly share code, notes, and snippets.

@geekprogramming
Created June 1, 2016 07:02
Show Gist options
  • Save geekprogramming/3b19de708f2f9f02d0df4a2d0c5770f9 to your computer and use it in GitHub Desktop.
Save geekprogramming/3b19de708f2f9f02d0df4a2d0c5770f9 to your computer and use it in GitHub Desktop.
/**
 * Writes the current network to disk as two files: the parameters returned by
 * {@code model.params()} (serialized with {@code Nd4j.write}) to
 * {@code Config.FILE_PARAMS}, and the layer-wise configuration as JSON to
 * {@code Config.FILE_MODEL}.
 *
 * <p>NOTE(review): these paths are used as-is, whereas {@code loadModel} resolves
 * the same names through {@code FileHelper.getFilePath}; confirm both refer to the
 * same locations.
 *
 * @return this instance, allowing call chaining
 * @throws IOException if either file cannot be written
 */
public Model saveModel() throws IOException {
    log.info("Save model............");
    // try-with-resources guarantees the file handle is released even if
    // Nd4j.write throws; close() also flushes, so no explicit flush is needed.
    // The buffer avoids one OS write per primitive emitted by DataOutputStream.
    try (DataOutputStream dos = new DataOutputStream(new java.io.BufferedOutputStream(
            Files.newOutputStream(Paths.get(Config.FILE_PARAMS))))) {
        Nd4j.write(model.params(), dos);
    }
    // Explicit charset: the two-argument overload depends on the platform default
    // encoding and is deprecated in Commons IO.
    FileUtils.writeStringToFile(new File(Config.FILE_MODEL),
            model.getLayerWiseConfigurations().toJson(),
            java.nio.charset.StandardCharsets.UTF_8);
    return this;
}
/**
 * Rebuilds a network from disk. It reads the JSON configuration at
 * {@code FileHelper.getFilePath(Config.FILE_MODEL)} and the parameter array
 * (via {@code Nd4j.read}) at {@code FileHelper.getFilePath(Config.FILE_PARAMS)}.
 * It then initializes a new {@link MultiLayerNetwork} and installs the loaded
 * parameters.
 *
 * @return the restored network, or {@code null} if any step fails (the failure
 *         is logged rather than propagated)
 */
private MultiLayerNetwork loadModel() {
    log.info("Load model...........");
    try {
        // Explicit charset instead of the deprecated platform-default overload.
        String json = FileUtils.readFileToString(
                new File(FileHelper.getFilePath(Config.FILE_MODEL)),
                java.nio.charset.StandardCharsets.UTF_8);
        MultiLayerConfiguration confFromJson = MultiLayerConfiguration.fromJson(json);

        INDArray newParams;
        // try-with-resources releases the file handle even if Nd4j.read throws;
        // buffering avoids a syscall per byte read by DataInputStream.
        try (DataInputStream dis = new DataInputStream(new java.io.BufferedInputStream(
                new FileInputStream(FileHelper.getFilePath(Config.FILE_PARAMS))))) {
            newParams = Nd4j.read(dis);
        }

        MultiLayerNetwork savedNetwork = new MultiLayerNetwork(confFromJson);
        savedNetwork.init();
        savedNetwork.setParameters(newParams);
        log.info("Num params: " + savedNetwork.numParams());
        return savedNetwork;
    } catch (Exception ex) {
        // Keep the original best-effort contract (callers receive null on failure),
        // but report the cause through the logger instead of printStackTrace().
        // NOTE(review): assumes `log` is an SLF4J/log4j-style logger with
        // error(String, Throwable); confirm against the field declaration.
        log.error("Failed to load model", ex);
        return null;
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment