Skip to content

Instantly share code, notes, and snippets.

@sato-cloudian
Last active December 18, 2015 06:44
Show Gist options
  • Select an option

  • Save sato-cloudian/0b11bc517b509cdc59ce to your computer and use it in GitHub Desktop.

Select an option

Save sato-cloudian/0b11bc517b509cdc59ce to your computer and use it in GitHub Desktop.
MyMLPMnistSingleLayerExample experiments (one additional hidden layer, i.e. two hidden layers in total)
784 => 400 => 90 => 10
[SGD]
==========================Scores========================================
Accuracy: 0.9105
Precision: 0.9119
Recall: 0.9086
F1 Score: 0.910244372714764
===========================================================================
[ADAGRAD]
==========================Scores========================================
Accuracy: 0.904
Precision: 0.9032
Recall: 0.9011
F1 Score: 0.9021182909434786
===========================================================================
[ADADELTA]
==========================Scores========================================
Accuracy: 0.904
Precision: 0.903
Recall: 0.9012
F1 Score: 0.9021191365508152
===========================================================================
[ADAM]
==========================Scores========================================
Accuracy: 0.903
Precision: 0.9021
Recall: 0.9001
F1 Score: 0.9011163063629107
===========================================================================
[NESTEROVS]
==========================Scores========================================
Accuracy: 0.9105
Precision: 0.9119
Recall: 0.9086
F1 Score: 0.9102523126318371
===========================================================================
[RMSPROP]
==========================Scores========================================
Accuracy: 0.905
Precision: 0.9051
Recall: 0.903
F1 Score: 0.9040834740468716
===========================================================================
784 => 400 => 90 => 10
Updater.NESTEROVS (used for every run in this section)
[leakyrelu]
==========================Scores========================================
Accuracy: 0.912
Precision: 0.9133
Recall: 0.9102
F1 Score: 0.9117608412734705
===========================================================================
[WeightInit.RELU]
==========================Scores========================================
Accuracy: 0.9205
Precision: 0.9213
Recall: 0.9191
F1 Score: 0.9201900650263835
===========================================================================
[regularization=true]
==========================Scores========================================
Accuracy: 0.9105
Precision: 0.9119
Recall: 0.9086
F1 Score: 0.9102523126318371
===========================================================================
[ALL of the above]
==========================Scores========================================
Accuracy: 0.921
Precision: 0.9218
Recall: 0.9196
F1 Score: 0.9206599678042522
===========================================================================
package org.deeplearning4j.examples.mlp;
import org.deeplearning4j.datasets.iterator.DataSetIterator;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.GradientNormalization;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.Updater;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.api.IterationListener;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.deeplearning4j.ui.weights.HistogramIterationListener;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.SplitTestAndTrain;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.*;
/**
 * Trains and evaluates a fully connected network on a subset of MNIST.
 *
 * <p>Topology: 784 inputs (28x28 pixels) =&gt; 400 ReLU =&gt; 88 ReLU =&gt; 10 softmax outputs,
 * trained with multi-class cross-entropy loss. Every mini-batch read from the iterator is split
 * 80/20: the 80% part is fitted immediately, and the 20% part is held back. All held-back parts
 * are evaluated together after every batch has been consumed.
 *
 * <p>Per the original header (created by agibsonccc on 9/11/14), this is a variant of the small
 * single-layer MLP MNIST example that adds a second hidden layer.
 */
public class MyMLPMnistSingleLayerExample {
    private static final Logger log = LoggerFactory.getLogger(MyMLPMnistSingleLayerExample.class);

    public static void main(String[] args) throws Exception {
        Nd4j.ENFORCE_NUMERICAL_STABILITY = true;

        final int numRows = 28;
        final int numColumns = 28;
        final int outputNum = 10;
        final int hiddenLayer1Size = 400;
        // NOTE(review): the recorded size sweep reports 90 as the best size for this layer,
        // but this run uses 88. Confirm which value is intended.
        final int hiddenLayer2Size = 88;
        final int numSamples = 10000;
        final int batchSize = 500;
        final int iterations = 10;
        final int seed = 123;
        // Report about ten times per fit() call. Clamp to 1 so that a small `iterations`
        // value cannot produce a listener frequency of 0.
        final int listenerFreq = Math.max(1, iterations / 10);
        // Examples from each batch used for training; the remainder is held out for testing.
        final int splitTrainNum = (int) (batchSize * .8);

        // Held-out features and labels, collected batch by batch and evaluated at the end.
        List<INDArray> testInput = new ArrayList<>();
        List<INDArray> testLabels = new ArrayList<>();

        log.info("Load data....");
        // NOTE(review): the third argument is presumably the "binarize" flag; confirm it
        // against the MnistDataSetIterator documentation.
        DataSetIterator mnistIter = new MnistDataSetIterator(batchSize, numSamples, true);

        log.info("Build model....");
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .seed(seed)
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                .iterations(iterations)
                // Experiment toggles, disabled for this run:
                //.gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
                .learningRate(1e-1)
                //.momentum(0.5)
                //.momentumAfter(Collections.singletonMap(3, 0.9))
                //.useDropConnect(true)
                .list(3)
                .layer(0, new DenseLayer.Builder()
                        .nIn(numRows * numColumns) // 28*28=784
                        .nOut(hiddenLayer1Size)
                        .activation("relu")
                        .weightInit(WeightInit.XAVIER)
                        .build())
                .layer(1, new DenseLayer.Builder()
                        .nIn(hiddenLayer1Size)
                        .nOut(hiddenLayer2Size)
                        .activation("relu")
                        .weightInit(WeightInit.XAVIER)
                        .build())
                .layer(2, new OutputLayer.Builder(LossFunction.MCXENT)
                        .nIn(hiddenLayer2Size)
                        .nOut(outputNum)
                        .activation("softmax")
                        .weightInit(WeightInit.XAVIER)
                        // Note: the updater is set only on the output layer here.
                        .updater(Updater.SGD)
                        .build())
                .backprop(true)
                .pretrain(false)
                .build();

        MultiLayerNetwork model = new MultiLayerNetwork(conf);
        model.init();
        model.setListeners(Arrays.asList(
                (IterationListener) new ScoreIterationListener(listenerFreq),
                new HistogramIterationListener(listenerFreq)));

        log.info("Train model....");
        while (mnistIter.hasNext()) {
            DataSet mnist = mnistIter.next();
            // A freshly seeded Random for every batch keeps the split reproducible across runs.
            SplitTestAndTrain trainTest = mnist.splitTestAndTrain(splitTrainNum, new Random(seed));
            DataSet trainInput = trainTest.getTrain();
            DataSet heldOut = trainTest.getTest();
            testInput.add(heldOut.getFeatureMatrix());
            testLabels.add(heldOut.getLabels());
            model.fit(trainInput);
        }

        log.info("Evaluate model....");
        Evaluation eval = new Evaluation(outputNum);
        // testInput and testLabels are parallel lists, one entry per training batch.
        for (int i = 0; i < testInput.size(); i++) {
            INDArray output = model.output(testInput.get(i));
            eval.eval(testLabels.get(i), output);
        }
        log.info(eval.stats());
        log.info("****************Example finished********************");
    }
}
Second hidden layer size sweep (Updater.SGD)
784 => 400 => 300 => 10
==========================Scores========================================
Accuracy: 0.907
Precision: 0.9069
Recall: 0.905
F1 Score: 0.905947279966959
===========================================================================
784 => 400 => 200 => 10
==========================Scores========================================
Accuracy: 0.9005
Precision: 0.9014
Recall: 0.8981
F1 Score: 0.8997049679660892
===========================================================================
784 => 400 => 110 => 10
==========================Scores========================================
Accuracy: 0.902
Precision: 0.9021
Recall: 0.8997
F1 Score: 0.9008750162695861
===========================================================================
784 => 400 => 100 => 10
==========================Scores========================================
Accuracy: 0.9065
Precision: 0.9074
Recall: 0.9046
F1 Score: 0.9060417265046306
===========================================================================
784 => 400 => 93 => 10
==========================Scores========================================
Accuracy: 0.908
Precision: 0.9078
Recall: 0.9063
F1 Score: 0.9070248307551632
===========================================================================
784 => 400 => 90 => 10
==========================Scores========================================
Accuracy: 0.9105
Precision: 0.9119
Recall: 0.9086
F1 Score: 0.910244372714764
===========================================================================
784 => 400 => 88 => 10
==========================Scores========================================
Accuracy: 0.9065
Precision: 0.907
Recall: 0.9043
F1 Score: 0.9056646604023598
===========================================================================
784 => 400 => 80 => 10
==========================Scores========================================
Accuracy: 0.9065
Precision: 0.9074
Recall: 0.9045
F1 Score: 0.9059277269050349
===========================================================================
784 => 400 => 50 => 10
==========================Scores========================================
Accuracy: 0.903
Precision: 0.9025
Recall: 0.9004
F1 Score: 0.901468182506639
===========================================================================
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment