Skip to content

Instantly share code, notes, and snippets.

@sato-cloudian
Created December 22, 2015 07:28
Show Gist options
  • Select an option

  • Save sato-cloudian/a5438c80e69848121960 to your computer and use it in GitHub Desktop.

Select an option

Save sato-cloudian/a5438c80e69848121960 to your computer and use it in GitHub Desktop.
MyDBNMnistExample experimentations
package org.deeplearning4j.examples.deepbelief;
import org.deeplearning4j.datasets.iterator.DataSetIterator;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.GradientNormalization;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.Updater;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.conf.layers.RBM;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.api.IterationListener;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.deeplearning4j.ui.weights.HistogramIterationListener;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.SplitTestAndTrain;
import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.*;
/**
 * Experimental deep belief network (DBN) run on MNIST.
 *
 * <p>The network is a single RBM layer ({@code 28 * 28} visible inputs, {@code outputNum}
 * hidden units) followed by a softmax output layer. Every mini-batch produced by
 * {@code MnistDataSetIterator} is split with {@code DataSet.splitTestAndTrain}: the train part
 * is passed to {@code MultiLayerNetwork.fit}, while the test part is collected and scored with
 * {@code Evaluation} once all batches have been consumed.
 *
 * <p>Originally created by agibsonccc on 9/11/14; hyperparameters here are experimental.
 */
public class MyDBNMnistExample {

    private static final Logger log = LoggerFactory.getLogger(MyDBNMnistExample.class);

    /** Fraction of every mini-batch used for training; the rest is held out for evaluation. */
    private static final double TRAIN_FRACTION = .8;

    /**
     * Loads MNIST, trains the network batch by batch, and logs evaluation statistics computed
     * on the examples held out from each batch.
     *
     * @param args ignored
     * @throws Exception propagated from data loading, training, or evaluation
     */
    public static void main(String[] args) throws Exception {
        final int numRows = 28;
        final int numColumns = 28;
        int outputNum = 10;
        int numSamples = 10000;
        int batchSize = 500;
        int iterations = 10;
        int seed = 123;
        int listenerFreq = batchSize / 10;
        // Held-out features/labels from every batch, evaluated after training completes.
        List<INDArray> testInput = new ArrayList<>();
        List<INDArray> testLabels = new ArrayList<>();

        log.info("Load data....");
        DataSetIterator mnistIter = new MnistDataSetIterator(batchSize, numSamples, true);

        log.info("Build model....");
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .seed(seed)
                // Alternatives tried during experimentation; re-enable to compare:
                //.gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue)
                //.gradientNormalizationThreshold(1.0)
                .iterations(iterations)
                .learningRate(1e-6)
                // NOTE(review): no l1/l2 coefficient is set in this builder; confirm whether
                // enabling regularization has any effect here.
                .regularization(true)
                //.momentum(0.5)
                //.momentumAfter(Collections.singletonMap(3, 0.9))
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                .list(2)
                // Layer 0: numRows*numColumns pixel inputs compressed to outputNum hidden units.
                .layer(0, new RBM.Builder()
                        .nIn(numRows * numColumns)
                        .nOut(outputNum)
                        .weightInit(WeightInit.RELU)
                        .activation("relu")
                        .k(1) // default
                        .sparsity(0.0D) // default
                        .visibleUnit(RBM.VisibleUnit.BINARY) // default
                        .hiddenUnit(RBM.HiddenUnit.BINARY) // default
                        .lossFunction(LossFunction.RECONSTRUCTION_CROSSENTROPY) // default
                        .build())
                // Layer 1: softmax classifier over the outputNum classes.
                .layer(1, new OutputLayer.Builder(LossFunction.MCXENT)
                        .activation("softmax")
                        .nIn(outputNum)
                        .nOut(outputNum)
                        .build())
                // Layer-wise pretraining is on and backprop fine-tuning is off.
                // NOTE(review): confirm how this DL4J version trains the output layer in this mode.
                .pretrain(true)
                .backprop(false)
                .build();

        MultiLayerNetwork model = new MultiLayerNetwork(conf);
        model.init();
        model.setListeners(Arrays.asList(
                (IterationListener) new ScoreIterationListener(listenerFreq),
                new HistogramIterationListener(listenerFreq)));

        log.info("Train model....");
        while (mnistIter.hasNext()) {
            DataSet mnist = mnistIter.next();
            // Size the split from the batch actually returned rather than the nominal batchSize,
            // so a short final batch never requests more training rows than it contains.
            int splitTrainNum = (int) (mnist.numExamples() * TRAIN_FRACTION);
            if (splitTrainNum < 1) {
                log.warn("Skipping batch of {} example(s): too small to split", mnist.numExamples());
                continue;
            }
            // A fresh Random with the same seed for every batch keeps each split reproducible.
            SplitTestAndTrain trainTest = mnist.splitTestAndTrain(splitTrainNum, new Random(seed));
            DataSet trainInput = trainTest.getTrain(); // feature matrix and labels for training
            testInput.add(trainTest.getTest().getFeatureMatrix());
            testLabels.add(trainTest.getTest().getLabels());
            model.fit(trainInput);
        }

        log.info("Evaluate model....");
        Evaluation eval = new Evaluation(outputNum);
        for (int i = 0; i < testInput.size(); i++) {
            INDArray output = model.output(testInput.get(i));
            eval.eval(testLabels.get(i), output);
        }
        log.info(eval.stats());
        log.info("****************Example finished********************");
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment