/*
 * Source: GitHub gist gembin/050a8995ac9ac15bf43ddca6f2d8866c.
 *
 * Examples of DL4J's Keras model import syntax (assumes Keras Functional API models and a
 * DL4J ComputationGraph).
 *
 * Note: GitHub warned that the original gist may contain hidden or bidirectional Unicode
 * text. Review it in an editor that reveals hidden Unicode characters before reusing it.
 */
package org.deeplearning4j.nn.modelimport.keras;

import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.HashMap;
import java.util.List;
import java.util.Map;
public class KerasImportVgg16 { | |
private static final Logger log = LoggerFactory.getLogger(KerasImportVgg16.class); | |
public static void main(String[] args) throws Exception { | |
String modelJsonFilename = "PATH TO EXPORTED JSON FILE"; | |
String weightsHdf5Filename = "PATH TO EXPORTED WEIGHTS HDF5 ARCHIVE"; | |
String modelHdf5Filename = "PATH TO EXPORTED FULL MODEL HDF5 ARCHIVE"; | |
boolean enforceTrainingConfig = false; //Controls whether unsupported training-related configs | |
//will throw an exception or just generate a warning. | |
/* Import VGG 16 model from separate model config JSON and weights HDF5 files. | |
* Will not include loss layer or training configuration. | |
*/ | |
// Static helper from KerasModelImport | |
ComputationGraph model = KerasModelImport.importKerasModelAndWeights(modelJsonFilename, weightsHdf5Filename, enforceTrainingConfig); | |
// KerasModel builder pattern | |
model = new KerasModel.ModelBuilder() | |
.modelJsonFilename(modelJsonFilename) | |
.weightsHdf5Filename(weightsHdf5Filename) | |
.enforceTrainingConfig(enforceTrainingConfig) | |
.buildModel() | |
.getComputationGraph(); | |
/* Import VGG 16 model from full model HDF5 file. Includes loss layer, if any. */ | |
// Static helper from KerasModelImport | |
model = KerasModelImport.importKerasModelAndWeights(modelHdf5Filename, enforceTrainingConfig); | |
// KerasModel builder pattern | |
model = new KerasModel.ModelBuilder() | |
.modelHdf5Filename(modelHdf5Filename) | |
.enforceTrainingConfig(enforceTrainingConfig) | |
.buildModel() | |
.getComputationGraph(); | |
/* Import VGG 16 model config from model config JSON. Will not include loss | |
* layer or training configuration. | |
*/ | |
// Static helper from KerasModelImport | |
ComputationGraphConfiguration config = KerasModelImport.importKerasModelConfiguration(modelJsonFilename, enforceTrainingConfig); | |
// KerasModel builder pattern | |
config = new KerasModel.ModelBuilder() | |
.modelJsonFilename(modelJsonFilename) | |
.enforceTrainingConfig(enforceTrainingConfig) | |
.buildModel() | |
.getComputationGraphConfiguration(); | |
} | |
} |
// End of gist. (GitHub page footer: sign up for free, or sign in, to comment on this gist.)