Skip to content

Instantly share code, notes, and snippets.

@sato-cloudian
Created December 16, 2015 03:30
Show Gist options
  • Select an option

  • Save sato-cloudian/5d3576388d1d81ed7a85 to your computer and use it in GitHub Desktop.

Select an option

Save sato-cloudian/5d3576388d1d81ed7a85 to your computer and use it in GitHub Desktop.
MyMLPBackpropIrisSplitExample experiments
package org.deeplearning4j.examples.mlp;
import org.deeplearning4j.datasets.iterator.DataSetIterator;
import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.Updater;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.params.DefaultParamInitializer;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.api.IterationListener;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.deeplearning4j.ui.weights.HistogramIterationListener;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.SplitTestAndTrain;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.IOException;
import java.util.Arrays;
import java.util.Random;
/**
 * Trains a small multi-layer perceptron on the Iris data set and evaluates it
 * on a held-out portion of the same batch.
 *
 * <p>The program loads one batch of Iris samples, normalizes it, splits it
 * 80/20 into train and test sets, fits a two-layer network (dense ReLU layer
 * followed by a softmax output layer), logs the learned weights, and finally
 * logs per-row predictions together with the evaluation statistics.
 *
 * <p>Created by agibsonccc on 9/12/14.
 */
public class MyMLPBackpropIrisSplitExample {

    private static final Logger log = LoggerFactory.getLogger(MyMLPBackpropIrisSplitExample.class);

    /**
     * Runs the whole load / split / train / evaluate pipeline.
     *
     * @param args command-line arguments; not used
     * @throws IOException declared for the data-loading and UI-listener calls
     */
    public static void main(String[] args) throws IOException {
        // Limit how much of each INDArray is rendered when it is logged.
        Nd4j.MAX_SLICES_TO_PRINT = 10;
        Nd4j.MAX_ELEMENTS_PER_SLICE = 10;

        final int numInputs = 4;
        // Width of the hidden layer; shared by layer 0's nOut and layer 1's nIn
        // so the two can never drift apart.
        final int hiddenLayerSize = 3;
        final int outputNum = 3;
        final int numSamples = 150;
        final int batchSize = 150;
        final int iterations = 100;
        // 80% of the loaded batch is used for training, the rest for testing.
        final int splitTrainNum = (int) (batchSize * .8);
        final long seed = 6;
        // Report five times per run. Integer division yields 0 when
        // iterations < 5, so clamp to 1 to never hand the listeners a zero
        // frequency.
        final int listenerFreq = Math.max(1, iterations / 5);

        log.info("Load data....");
        DataSetIterator iter = new IrisDataSetIterator(batchSize, numSamples);
        DataSet next = iter.next();
        next.normalizeZeroMeanZeroUnitVariance();

        log.info("Split data....");
        SplitTestAndTrain testAndTrain = next.splitTestAndTrain(splitTrainNum, new Random(seed));
        DataSet train = testAndTrain.getTrain();
        DataSet test = testAndTrain.getTest();

        Nd4j.ENFORCE_NUMERICAL_STABILITY = true;

        log.info("Build model....");
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .seed(seed)
                .iterations(iterations)
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                // the best score in (1.191448, 1.6035) in non-split training and evaluation
                // see MyMLPBackpropIrisExample
                .learningRate(1.11245)
                .regularization(true)
                .list(2)
                .layer(0, new DenseLayer.Builder()
                        .nIn(numInputs)
                        .nOut(hiddenLayerSize)
                        .activation("relu")
                        .weightInit(WeightInit.XAVIER)
                        .build())
                .layer(1, new OutputLayer.Builder(LossFunction.MCXENT)
                        .weightInit(WeightInit.XAVIER)
                        .activation("softmax")
                        .updater(Updater.SGD)
                        .nIn(hiddenLayerSize).nOut(outputNum).build())
                .backprop(true)
                .pretrain(false)
                .build();

        MultiLayerNetwork model = new MultiLayerNetwork(conf);
        model.init();
        // The cast on the first element makes Arrays.asList infer
        // List<IterationListener> rather than a list of the concrete type.
        model.setListeners(Arrays.asList(
                (IterationListener) new ScoreIterationListener(listenerFreq),
                new HistogramIterationListener(listenerFreq)));

        log.info("Train model....");
        model.fit(train);

        log.info("Evaluate weights....");
        for (org.deeplearning4j.nn.api.Layer layer : model.getLayers()) {
            INDArray w = layer.getParam(DefaultParamInitializer.WEIGHT_KEY);
            log.info("Weights: {}", w);
        }

        log.info("Evaluate model....");
        Evaluation eval = new Evaluation(outputNum);
        INDArray output = model.output(test.getFeatureMatrix());
        // Log each test row's label vector next to the network's output row.
        for (int i = 0; i < output.rows(); i++) {
            String actual = test.getLabels().getRow(i).toString().trim();
            String predicted = output.getRow(i).toString().trim();
            log.info("actual {} vs predicted {}", actual, predicted);
        }
        eval.eval(test.getLabels(), output);
        log.info(eval.stats());
        log.info("****************Example finished********************");
    }
}
learningRate = 1.6035
Examples labeled as 2 classified by model as 1: 5 times
Examples labeled as 2 classified by model as 2: 25 times
==========================Scores========================================
Accuracy: 0.8333
Precision: 1
Recall: 0.8333
F1 Score: 0.9090909090909091
===========================================================================
learningRate = 1.3
Examples labeled as 2 classified by model as 1: 4 times
Examples labeled as 2 classified by model as 2: 26 times
==========================Scores========================================
Accuracy: 0.8667
Precision: 1
Recall: 0.8667
F1 Score: 0.9285714285714286
===========================================================================
learningRate = 1.191448
Examples labeled as 2 classified by model as 1: 4 times
Examples labeled as 2 classified by model as 2: 26 times
==========================Scores========================================
Accuracy: 0.8667
Precision: 1
Recall: 0.8667
F1 Score: 0.9285714285714286
===========================================================================
learningRate = 1.15
Examples labeled as 2 classified by model as 1: 4 times
Examples labeled as 2 classified by model as 2: 26 times
==========================Scores========================================
Accuracy: 0.8667
Precision: 1
Recall: 0.8667
F1 Score: 0.9285714285714286
===========================================================================
learningRate = 1.125
Examples labeled as 2 classified by model as 1: 4 times
Examples labeled as 2 classified by model as 2: 26 times
==========================Scores========================================
Accuracy: 0.8667
Precision: 1
Recall: 0.8667
F1 Score: 0.9285714285714286
===========================================================================
learningRate = 1.1245
Examples labeled as 2 classified by model as 1: 3 times
Examples labeled as 2 classified by model as 2: 27 times
==========================Scores========================================
Accuracy: 0.9
Precision: 1
Recall: 0.9
F1 Score: 0.9473684210526316
===========================================================================
learningRate = 1.124
Examples labeled as 2 classified by model as 1: 3 times
Examples labeled as 2 classified by model as 2: 27 times
==========================Scores========================================
Accuracy: 0.9
Precision: 1
Recall: 0.9
F1 Score: 0.9473684210526316
===========================================================================
learningRate = 1.123
Examples labeled as 2 classified by model as 1: 3 times
Examples labeled as 2 classified by model as 2: 27 times
==========================Scores========================================
Accuracy: 0.9
Precision: 1
Recall: 0.9
F1 Score: 0.9473684210526316
===========================================================================
learningRate = 1.1225
Examples labeled as 2 classified by model as 1: 3 times
Examples labeled as 2 classified by model as 2: 27 times
==========================Scores========================================
Accuracy: 0.9
Precision: 1
Recall: 0.9
F1 Score: 0.9473684210526316
===========================================================================
learningRate = 1.12
Examples labeled as 2 classified by model as 1: 3 times
Examples labeled as 2 classified by model as 2: 27 times
==========================Scores========================================
Accuracy: 0.9
Precision: 1
Recall: 0.9
F1 Score: 0.9473684210526316
===========================================================================
learningRate = 1.11
Examples labeled as 2 classified by model as 1: 3 times
Examples labeled as 2 classified by model as 2: 27 times
==========================Scores========================================
Accuracy: 0.9
Precision: 1
Recall: 0.9
F1 Score: 0.9473684210526316
===========================================================================
learningRate = 1.1
Examples labeled as 2 classified by model as 1: 3 times
Examples labeled as 2 classified by model as 2: 27 times
==========================Scores========================================
Accuracy: 0.9
Precision: 1
Recall: 0.9
F1 Score: 0.9473684210526316
===========================================================================
learningRate = 1.0
Examples labeled as 2 classified by model as 1: 3 times
Examples labeled as 2 classified by model as 2: 27 times
==========================Scores========================================
Accuracy: 0.9
Precision: 1
Recall: 0.9
F1 Score: 0.9473684210526316
===========================================================================
learningRate = 0.5
Examples labeled as 2 classified by model as 1: 3 times
Examples labeled as 2 classified by model as 2: 27 times
==========================Scores========================================
Accuracy: 0.9
Precision: 1
Recall: 0.9
F1 Score: 0.9473684210526316
===========================================================================
learningRate = 0.3
Examples labeled as 2 classified by model as 1: 3 times
Examples labeled as 2 classified by model as 2: 27 times
==========================Scores========================================
Accuracy: 0.9
Precision: 1
Recall: 0.9
F1 Score: 0.9473684210526316
===========================================================================
learningRate = 0.25
Examples labeled as 2 classified by model as 1: 3 times
Examples labeled as 2 classified by model as 2: 27 times
==========================Scores========================================
Accuracy: 0.9
Precision: 1
Recall: 0.9
F1 Score: 0.9473684210526316
===========================================================================
learningRate = 0.225
Examples labeled as 2 classified by model as 1: 3 times
Examples labeled as 2 classified by model as 2: 27 times
==========================Scores========================================
Accuracy: 0.9
Precision: 1
Recall: 0.9
F1 Score: 0.9473684210526316
===========================================================================
learningRate = 0.2
Examples labeled as 2 classified by model as 1: 4 times
Examples labeled as 2 classified by model as 2: 26 times
==========================Scores========================================
Accuracy: 0.8667
Precision: 1
Recall: 0.8667
F1 Score: 0.9285714285714286
===========================================================================
learningRate = 0.1
Examples labeled as 2 classified by model as 1: 11 times
Examples labeled as 2 classified by model as 2: 19 times
==========================Scores========================================
Accuracy: 0.6333
Precision: 1
Recall: 0.6333
F1 Score: 0.7755102040816326
===========================================================================
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment