Skip to content

Instantly share code, notes, and snippets.

@sato-cloudian
Last active December 29, 2015 04:16
Show Gist options
  • Select an option

  • Save sato-cloudian/f2851fab651b94deeba5 to your computer and use it in GitHub Desktop.

Select an option

Save sato-cloudian/f2851fab651b94deeba5 to your computer and use it in GitHub Desktop.
MyCNNIrisExample experiments
package org.deeplearning4j.examples.convolution;
import org.deeplearning4j.datasets.iterator.DataSetIterator;
import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.GradientNormalization;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.Updater;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.conf.layers.SubsamplingLayer;
import org.deeplearning4j.nn.conf.layers.setup.ConvolutionLayerSetup;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.params.DefaultParamInitializer;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.api.IterationListener;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.deeplearning4j.ui.weights.HistogramIterationListener;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.SplitTestAndTrain;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.Arrays;
import java.util.Random;
/**
 * Experiment: trains a small convolutional network on the Iris dataset and evaluates it on a
 * held-out split.
 *
 * <p>Each Iris sample (4 features) is presented to the network as a {@code numRows x numColumns}
 * (2x2) single-channel grid. The network is a 1x1 ReLU convolution layer followed by a softmax
 * output layer. After training, every layer's weight matrix is logged, followed by the
 * evaluation statistics on the held-out samples.
 *
 * @author sonali
 */
public class MyCNNIrisExample {
    private static final Logger log = LoggerFactory.getLogger(MyCNNIrisExample.class);

    /**
     * Loads Iris, trains the network on 100 samples and evaluates it on the remaining 50.
     *
     * @param args unused
     */
    public static void main(String[] args) {
        // 4 Iris features laid out as a 2x2 grid with a single channel.
        final int numRows = 2;
        final int numColumns = 2;
        final int nChannels = 1;
        // Number of Iris classes; must match the softmax output layer's nOut.
        final int outputNum = 3;
        // Depth of the convolution layer's output (the value varied in the recorded experiments).
        final int convNOut = 6;
        final int numSamples = 150;
        // Equal to numSamples, so a single next() returns the whole dataset.
        final int batchSize = 150;
        final int iterations = 100;
        // Number of samples kept for training; the remainder is held out for evaluation.
        final int splitTrainNum = 100;
        final int seed = 123;
        final int listenerFreq = 1;

        log.info("Load data....");
        // IrisDataSetIterator takes (batch size, total number of examples).
        DataSetIterator irisIter = new IrisDataSetIterator(batchSize, numSamples);
        DataSet iris = irisIter.next();
        // NOTE(review): normalization statistics are computed over all samples, including the ones
        // held out for testing below, so the test split is not fully independent of preprocessing.
        iris.normalizeZeroMeanZeroUnitVariance();
        log.info("Loaded {}", iris.labelCounts());
        // Features and labels are shuffled in two separate calls. They presumably stay row-aligned
        // only because each call uses an identically seeded Random on arrays with the same number
        // of rows — confirm against the ND4J version in use.
        Nd4j.shuffle(iris.getFeatureMatrix(), new Random(seed), 1);
        Nd4j.shuffle(iris.getLabels(), new Random(seed), 1);
        SplitTestAndTrain trainTest = iris.splitTestAndTrain(splitTrainNum, new Random(seed));
        DataSet train = trainTest.getTrain();
        DataSet test = trainTest.getTest();

        MultiLayerConfiguration conf =
                buildConfiguration(seed, iterations, nChannels, convNOut, outputNum, numRows, numColumns);

        log.info("Build model....");
        MultiLayerNetwork model = new MultiLayerNetwork(conf);
        model.init();
        model.setListeners(Arrays.asList(
                (IterationListener) new ScoreIterationListener(listenerFreq),
                new HistogramIterationListener(listenerFreq)));

        log.info("Train model....");
        log.info("Training on {}", train.labelCounts());
        model.fit(train);

        log.info("Evaluate weights....");
        for (org.deeplearning4j.nn.api.Layer layer : model.getLayers()) {
            INDArray w = layer.getParam(DefaultParamInitializer.WEIGHT_KEY);
            log.info("Weights: {}", w);
        }

        log.info("Evaluate model....");
        log.info("Testing on {}", test.labelCounts());
        Evaluation eval = new Evaluation(outputNum);
        INDArray output = model.output(test.getFeatureMatrix());
        eval.eval(test.getLabels(), output);
        log.info(eval.stats());
        log.info("****************Example finished********************");
    }

    /**
     * Builds the two-layer network configuration: a 1x1 ReLU convolution followed by a softmax
     * output layer trained with multi-class cross-entropy.
     *
     * @param seed       value passed to {@code NeuralNetConfiguration.Builder#seed}
     * @param iterations value passed to {@code NeuralNetConfiguration.Builder#iterations}
     * @param nChannels  input channel count of the convolution layer
     * @param convNOut   output depth of the convolution layer
     * @param outputNum  number of classes produced by the output layer
     * @param numRows    input grid height handed to {@code ConvolutionLayerSetup}
     * @param numColumns input grid width handed to {@code ConvolutionLayerSetup}
     * @return the built multi-layer configuration
     */
    private static MultiLayerConfiguration buildConfiguration(int seed, int iterations, int nChannels,
            int convNOut, int outputNum, int numRows, int numColumns) {
        MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder()
                .seed(seed)
                .iterations(iterations)
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                // NOTE(review): originally annotated as the default value; confirm against the
                // DL4J version in use.
                .learningRate(1.0)
                // NOTE(review): no l1/l2 coefficient is set here — confirm this flag has any effect.
                .regularization(true)
                .list(2)
                .layer(0, new ConvolutionLayer.Builder(new int[]{1, 1})
                        .nIn(nChannels)
                        .nOut(convNOut)
                        .activation("relu")
                        .weightInit(WeightInit.RELU)
                        .build())
                .layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                        .nOut(outputNum)
                        .weightInit(WeightInit.XAVIER)
                        .activation("softmax")
                        .updater(Updater.SGD)
                        .build())
                .backprop(true).pretrain(false);
        // Constructed for its side effect on the builder: registers the numRows x numColumns x
        // nChannels input shape. The output layer's nIn is not set above and is presumably
        // inferred here — confirm.
        new ConvolutionLayerSetup(builder, numRows, numColumns, nChannels);
        return builder.build();
    }
}
Convolution layer (layer 0) nOut = 3
(iterations = 100)
Examples labeled as 0 classified by model as 0: 11 times
Examples labeled as 1 classified by model as 1: 17 times
Examples labeled as 1 classified by model as 2: 5 times
Examples labeled as 2 classified by model as 1: 1 times
Examples labeled as 2 classified by model as 2: 16 times
==========================Scores========================================
Accuracy: 0.88
Precision: 0.9021
Recall: 0.9046
F1 Score: 0.9033737367410006
===========================================================================
(iterations = 1000)
Examples labeled as 0 classified by model as 0: 11 times
Examples labeled as 1 classified by model as 1: 21 times
Examples labeled as 1 classified by model as 2: 1 times
Examples labeled as 2 classified by model as 1: 2 times
Examples labeled as 2 classified by model as 2: 15 times
==========================Scores========================================
Accuracy: 0.94
Precision: 0.9502
Recall: 0.9456
F1 Score: 0.9479015228746314
===========================================================================
Convolution layer (layer 0) nOut = 6
(iterations = 100)
Examples labeled as 0 classified by model as 0: 11 times
Examples labeled as 1 classified by model as 1: 18 times
Examples labeled as 1 classified by model as 2: 4 times
Examples labeled as 2 classified by model as 1: 1 times
Examples labeled as 2 classified by model as 2: 16 times
==========================Scores========================================
Accuracy: 0.9
Precision: 0.9158
Recall: 0.9198
F1 Score: 0.9177834340212825
===========================================================================
(iterations = 100, learningRate = 1.0)
Examples labeled as 0 classified by model as 0: 11 times
Examples labeled as 1 classified by model as 1: 22 times
Examples labeled as 2 classified by model as 1: 2 times
Examples labeled as 2 classified by model as 2: 15 times
==========================Scores========================================
Accuracy: 0.96
Precision: 0.9722
Recall: 0.9608
F1 Score: 0.9664694280078895
===========================================================================
(iterations = 1000)
Examples labeled as 0 classified by model as 0: 11 times
Examples labeled as 1 classified by model as 1: 21 times
Examples labeled as 1 classified by model as 2: 1 times
Examples labeled as 2 classified by model as 1: 1 times
Examples labeled as 2 classified by model as 2: 16 times
==========================Scores========================================
Accuracy: 0.96
Precision: 0.9652
Recall: 0.9652
F1 Score: 0.9652406417112299
===========================================================================
Convolution layer (layer 0) nOut = 9
Examples labeled as 0 classified by model as 0: 11 times
Examples labeled as 1 classified by model as 1: 18 times
Examples labeled as 1 classified by model as 2: 4 times
Examples labeled as 2 classified by model as 1: 1 times
Examples labeled as 2 classified by model as 2: 16 times
==========================Scores========================================
Accuracy: 0.9
Precision: 0.9158
Recall: 0.9198
F1 Score: 0.9177834340212825
===========================================================================
Convolution layer (layer 0) nOut = 12
Examples labeled as 0 classified by model as 0: 11 times
Examples labeled as 1 classified by model as 1: 19 times
Examples labeled as 1 classified by model as 2: 3 times
Examples labeled as 2 classified by model as 1: 2 times
Examples labeled as 2 classified by model as 2: 15 times
==========================Scores========================================
Accuracy: 0.9
Precision: 0.9127
Recall: 0.9153
F1 Score: 0.9140121966319961
===========================================================================
Convolution layer (layer 0) nOut = 100
Examples labeled as 0 classified by model as 0: 11 times
Examples labeled as 1 classified by model as 1: 19 times
Examples labeled as 1 classified by model as 2: 3 times
Examples labeled as 2 classified by model as 1: 2 times
Examples labeled as 2 classified by model as 2: 15 times
==========================Scores========================================
Accuracy: 0.9
Precision: 0.9127
Recall: 0.9153
F1 Score: 0.9140121966319961
===========================================================================
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment