Created
June 13, 2018 16:25
-
-
Save ragnarok22/9803f534aa70019d39aa48532a855d00 to your computer and use it in GitHub Desktop.
Sample code for training and evaluating a DL4J model on the MNIST dataset
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// --- Dataset locations --------------------------------------------------
// Expects the MNIST PNG dump under <user.home>/datasets/mnist_png/{training,testing},
// with one subdirectory per digit class (the parent folder name is the label).
String dataset_home = System.getProperty("user.home") + File.separator + "datasets" + File.separator;
String mnist_home = dataset_home + "mnist_png" + File.separator;
String trainPath = mnist_home + "training";
String testPath = mnist_home + "testing";

// --- Hyper-parameters ---------------------------------------------------
int numRows = 28;    // image height in pixels
int numColumns = 28; // image width in pixels
int channels = 1;    // grayscale images -> single channel
int outputNum = 10;  // digit classes 0-9
int batchSize = 128;
int seed = 123;      // fixed seed for reproducible init/shuffling
Random rng = new Random(seed);
int numEpochs = 1;
// NOTE(review): Windows-specific hard-coded save location; consider deriving
// from user.home so the script is portable.
String pathToSaveModel = "D:\\";
boolean saveModel = true;

// --- Training data pipeline ---------------------------------------------
File trainData = new File(trainPath);
FileSplit train = new FileSplit(trainData, NativeImageLoader.ALLOWED_FORMATS, rng);
// Each image's label is taken from the name of its parent directory (0..9).
ParentPathLabelGenerator labelMaker = new ParentPathLabelGenerator();
ImageRecordReader recordReader = new ImageRecordReader(numRows, numColumns, channels, labelMaker);
recordReader.initialize(train);
DataSetIterator mnistTrain = new RecordReaderDataSetIterator(recordReader, batchSize, 1, outputNum);
// Scale raw pixel values from [0,255] into [0,1].
DataNormalization scaler = new ImagePreProcessingScaler(0, 1);
scaler.fit(mnistTrain);
mnistTrain.setPreProcessor(scaler);

// fixed typo in message: "contruyendo" -> "construyendo"
System.out.println("construyendo el modelo");
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
        .seed(seed)
        .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
        .updater(new Nesterovs())              // default learning rate / momentum
        // .updater(new Nesterovs(0.006, 0.9)) // explicit alternative kept for reference
        .l2(1e-4)                              // L2 weight decay
        .list()
        .layer(0, new DenseLayer.Builder()
                .nIn(numRows * numColumns)     // 784 flattened pixels
                .nOut(100)
                .activation(Activation.RELU)
                .weightInit(WeightInit.XAVIER)
                .build()
        )
        .layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                .nOut(outputNum)
                .activation(Activation.SOFTMAX)
                .weightInit(WeightInit.XAVIER)
                .build()
        )
        .pretrain(false).backprop(true)
        // Convolutional input type lets DL4J insert the CNN->feed-forward
        // flattening preprocessor ahead of the dense layer.
        .setInputType(InputType.convolutional(numRows, numColumns, channels))
        .build();

MultiLayerNetwork model = new MultiLayerNetwork(conf);
model.init();
model.setListeners(new ScoreIterationListener(100)); // log score every 100 iterations

System.out.println("entrenando el modelo");
for (int i = 0; i < numEpochs; i++) {
    model.fit(mnistTrain);
}

if (saveModel)
    model.save(new File(pathToSaveModel + "modelDL4J.zip"));

// --- Evaluation on the held-out test set --------------------------------
System.out.println("evaluando el modelo");
File testData = new File(testPath);
FileSplit test = new FileSplit(testData, NativeImageLoader.ALLOWED_FORMATS, rng);
recordReader.reset();
recordReader.initialize(test); // reuse the same reader, now pointed at the test split
DataSetIterator mnistTest = new RecordReaderDataSetIterator(recordReader, batchSize, 1, outputNum);
// Reuse the scaler fitted on the TRAINING data; do not re-fit on the test set
// (train/test leakage pattern). For ImagePreProcessingScaler(0,1) fit() is a
// no-op, so dropping the re-fit does not change the produced values.
mnistTest.setPreProcessor(scaler);
Evaluation evaluation = new Evaluation(outputNum);
while (mnistTest.hasNext()) {
    DataSet next = mnistTest.next();
    // getFeatures() replaces the deprecated getFeatureMatrix().
    INDArray output = model.output(next.getFeatures());
    evaluation.eval(next.getLabels(), output);
}
System.out.println(evaluation.stats());
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment