Created
December 16, 2015 03:30
-
-
Save sato-cloudian/5d3576388d1d81ed7a85 to your computer and use it in GitHub Desktop.
MyMLPBackpropIrisSplitExample experiments
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| package org.deeplearning4j.examples.mlp; | |
| import org.deeplearning4j.datasets.iterator.DataSetIterator; | |
| import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator; | |
| import org.deeplearning4j.eval.Evaluation; | |
| import org.deeplearning4j.nn.api.OptimizationAlgorithm; | |
| import org.deeplearning4j.nn.conf.MultiLayerConfiguration; | |
| import org.deeplearning4j.nn.conf.NeuralNetConfiguration; | |
| import org.deeplearning4j.nn.conf.Updater; | |
| import org.deeplearning4j.nn.conf.layers.DenseLayer; | |
| import org.deeplearning4j.nn.conf.layers.OutputLayer; | |
| import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; | |
| import org.deeplearning4j.nn.params.DefaultParamInitializer; | |
| import org.deeplearning4j.nn.weights.WeightInit; | |
| import org.deeplearning4j.optimize.api.IterationListener; | |
| import org.deeplearning4j.optimize.listeners.ScoreIterationListener; | |
| import org.deeplearning4j.ui.weights.HistogramIterationListener; | |
| import org.nd4j.linalg.api.ndarray.INDArray; | |
| import org.nd4j.linalg.dataset.DataSet; | |
| import org.nd4j.linalg.dataset.SplitTestAndTrain; | |
| import org.nd4j.linalg.factory.Nd4j; | |
| import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction; | |
| import org.slf4j.Logger; | |
| import org.slf4j.LoggerFactory; | |
| import java.io.IOException; | |
| import java.util.Arrays; | |
| import java.util.Random; | |
| /** | |
| * Created by agibsonccc on 9/12/14. | |
| */ | |
| public class MyMLPBackpropIrisSplitExample { | |
| private static Logger log = LoggerFactory.getLogger(MyMLPBackpropIrisSplitExample.class); | |
| public static void main(String[] args) throws IOException { | |
| // Customizing params | |
| Nd4j.MAX_SLICES_TO_PRINT = 10; | |
| Nd4j.MAX_ELEMENTS_PER_SLICE = 10; | |
| final int numInputs = 4; | |
| int outputNum = 3; | |
| int numSamples = 150; | |
| int batchSize = 150; | |
| int iterations = 100; | |
| int splitTrainNum = (int) (batchSize * .8); | |
| long seed = 6; | |
| int listenerFreq = iterations/5; | |
| log.info("Load data...."); | |
| DataSetIterator iter = new IrisDataSetIterator(batchSize, numSamples); | |
| DataSet next = iter.next(); | |
| next.normalizeZeroMeanZeroUnitVariance(); | |
| log.info("Split data...."); | |
| SplitTestAndTrain testAndTrain = next.splitTestAndTrain(splitTrainNum, new Random(seed)); | |
| DataSet train = testAndTrain.getTrain(); | |
| DataSet test = testAndTrain.getTest(); | |
| Nd4j.ENFORCE_NUMERICAL_STABILITY = true; | |
| log.info("Build model...."); | |
| MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() | |
| .seed(seed) | |
| .iterations(iterations) | |
| .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) | |
| // the best score in (1.191448, 1.6035) in non-split training and evaluation | |
| // see MyMLPBackpropIrisExample | |
| .learningRate(1.11245) | |
| .regularization(true) | |
| .list(2) | |
| .layer(0, new DenseLayer.Builder() | |
| .nIn(numInputs) | |
| .nOut(3) | |
| .activation("relu") | |
| .weightInit(WeightInit.XAVIER) | |
| .build()) | |
| .layer(1, new OutputLayer.Builder(LossFunction.MCXENT) | |
| .weightInit(WeightInit.XAVIER) | |
| .activation("softmax") | |
| .updater(Updater.SGD) | |
| .nIn(3).nOut(outputNum).build()) | |
| .backprop(true) | |
| .pretrain(false) | |
| .build(); | |
| MultiLayerNetwork model = new MultiLayerNetwork(conf); | |
| model.init(); | |
| model.setListeners(Arrays.asList((IterationListener) new ScoreIterationListener(listenerFreq), new HistogramIterationListener(listenerFreq))); | |
| log.info("Train model...."); | |
| model.fit(train); | |
| /* | |
| while(iter.hasNext()) { | |
| DataSet iris = iter.next(); | |
| iris.normalizeZeroMeanZeroUnitVariance(); | |
| model.fit(iris); | |
| } | |
| iter.reset(); | |
| */ | |
| log.info("Evaluate weights...."); | |
| for(org.deeplearning4j.nn.api.Layer layer : model.getLayers()) { | |
| INDArray w = layer.getParam(DefaultParamInitializer.WEIGHT_KEY); | |
| log.info("Weights: " + w); | |
| } | |
| log.info("Evaluate model...."); | |
| Evaluation eval = new Evaluation(outputNum); | |
| INDArray output = model.output(test.getFeatureMatrix()); | |
| for (int i = 0; i < output.rows(); i++) { | |
| String actual = test.getLabels().getRow(i).toString().trim(); | |
| String predicted = output.getRow(i).toString().trim(); | |
| log.info("actual " + actual + " vs predicted " + predicted); | |
| } | |
| eval.eval(test.getLabels(), output); | |
| log.info(eval.stats()); | |
| log.info("****************Example finished********************"); | |
| } | |
| } |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| learningRate = 1.6035 | |
| Examples labeled as 2 classified by model as 1: 5 times | |
| Examples labeled as 2 classified by model as 2: 25 times | |
| ==========================Scores======================================== | |
| Accuracy: 0.8333 | |
| Precision: 1 | |
| Recall: 0.8333 | |
| F1 Score: 0.9090909090909091 | |
| =========================================================================== | |
| learningRate = 1.3 | |
| Examples labeled as 2 classified by model as 1: 4 times | |
| Examples labeled as 2 classified by model as 2: 26 times | |
| ==========================Scores======================================== | |
| Accuracy: 0.8667 | |
| Precision: 1 | |
| Recall: 0.8667 | |
| F1 Score: 0.9285714285714286 | |
| =========================================================================== | |
| learningRate = 1.191448 | |
| Examples labeled as 2 classified by model as 1: 4 times | |
| Examples labeled as 2 classified by model as 2: 26 times | |
| ==========================Scores======================================== | |
| Accuracy: 0.8667 | |
| Precision: 1 | |
| Recall: 0.8667 | |
| F1 Score: 0.9285714285714286 | |
| =========================================================================== | |
| learningRate = 1.15 | |
| Examples labeled as 2 classified by model as 1: 4 times | |
| Examples labeled as 2 classified by model as 2: 26 times | |
| ==========================Scores======================================== | |
| Accuracy: 0.8667 | |
| Precision: 1 | |
| Recall: 0.8667 | |
| F1 Score: 0.9285714285714286 | |
| =========================================================================== | |
| learningRate = 1.125 | |
| Examples labeled as 2 classified by model as 1: 4 times | |
| Examples labeled as 2 classified by model as 2: 26 times | |
| ==========================Scores======================================== | |
| Accuracy: 0.8667 | |
| Precision: 1 | |
| Recall: 0.8667 | |
| F1 Score: 0.9285714285714286 | |
| =========================================================================== | |
| learningRate = 1.1245 | |
| Examples labeled as 2 classified by model as 1: 3 times | |
| Examples labeled as 2 classified by model as 2: 27 times | |
| ==========================Scores======================================== | |
| Accuracy: 0.9 | |
| Precision: 1 | |
| Recall: 0.9 | |
| F1 Score: 0.9473684210526316 | |
| =========================================================================== | |
| learningRate = 1.124 | |
| Examples labeled as 2 classified by model as 1: 3 times | |
| Examples labeled as 2 classified by model as 2: 27 times | |
| ==========================Scores======================================== | |
| Accuracy: 0.9 | |
| Precision: 1 | |
| Recall: 0.9 | |
| F1 Score: 0.9473684210526316 | |
| =========================================================================== | |
| learningRate = 1.123 | |
| Examples labeled as 2 classified by model as 1: 3 times | |
| Examples labeled as 2 classified by model as 2: 27 times | |
| ==========================Scores======================================== | |
| Accuracy: 0.9 | |
| Precision: 1 | |
| Recall: 0.9 | |
| F1 Score: 0.9473684210526316 | |
| =========================================================================== | |
| learningRate = 1.1225 | |
| Examples labeled as 2 classified by model as 1: 3 times | |
| Examples labeled as 2 classified by model as 2: 27 times | |
| ==========================Scores======================================== | |
| Accuracy: 0.9 | |
| Precision: 1 | |
| Recall: 0.9 | |
| F1 Score: 0.9473684210526316 | |
| =========================================================================== | |
| learningRate = 1.12 | |
| Examples labeled as 2 classified by model as 1: 3 times | |
| Examples labeled as 2 classified by model as 2: 27 times | |
| ==========================Scores======================================== | |
| Accuracy: 0.9 | |
| Precision: 1 | |
| Recall: 0.9 | |
| F1 Score: 0.9473684210526316 | |
| =========================================================================== | |
| learningRate = 1.11 | |
| Examples labeled as 2 classified by model as 1: 3 times | |
| Examples labeled as 2 classified by model as 2: 27 times | |
| ==========================Scores======================================== | |
| Accuracy: 0.9 | |
| Precision: 1 | |
| Recall: 0.9 | |
| F1 Score: 0.9473684210526316 | |
| =========================================================================== | |
| learningRate = 1.1 | |
| Examples labeled as 2 classified by model as 1: 3 times | |
| Examples labeled as 2 classified by model as 2: 27 times | |
| ==========================Scores======================================== | |
| Accuracy: 0.9 | |
| Precision: 1 | |
| Recall: 0.9 | |
| F1 Score: 0.9473684210526316 | |
| =========================================================================== | |
| learningRate = 1.0 | |
| Examples labeled as 2 classified by model as 1: 3 times | |
| Examples labeled as 2 classified by model as 2: 27 times | |
| ==========================Scores======================================== | |
| Accuracy: 0.9 | |
| Precision: 1 | |
| Recall: 0.9 | |
| F1 Score: 0.9473684210526316 | |
| =========================================================================== | |
| learningRate = 0.5 | |
| Examples labeled as 2 classified by model as 1: 3 times | |
| Examples labeled as 2 classified by model as 2: 27 times | |
| ==========================Scores======================================== | |
| Accuracy: 0.9 | |
| Precision: 1 | |
| Recall: 0.9 | |
| F1 Score: 0.9473684210526316 | |
| =========================================================================== | |
| learningRate = 0.3 | |
| Examples labeled as 2 classified by model as 1: 3 times | |
| Examples labeled as 2 classified by model as 2: 27 times | |
| ==========================Scores======================================== | |
| Accuracy: 0.9 | |
| Precision: 1 | |
| Recall: 0.9 | |
| F1 Score: 0.9473684210526316 | |
| =========================================================================== | |
| learningRate = 0.25 | |
| Examples labeled as 2 classified by model as 1: 3 times | |
| Examples labeled as 2 classified by model as 2: 27 times | |
| ==========================Scores======================================== | |
| Accuracy: 0.9 | |
| Precision: 1 | |
| Recall: 0.9 | |
| F1 Score: 0.9473684210526316 | |
| =========================================================================== | |
| learningRate = 0.225 | |
| Examples labeled as 2 classified by model as 1: 3 times | |
| Examples labeled as 2 classified by model as 2: 27 times | |
| ==========================Scores======================================== | |
| Accuracy: 0.9 | |
| Precision: 1 | |
| Recall: 0.9 | |
| F1 Score: 0.9473684210526316 | |
| =========================================================================== | |
| learningRate = 0.2 | |
| Examples labeled as 2 classified by model as 1: 4 times | |
| Examples labeled as 2 classified by model as 2: 26 times | |
| ==========================Scores======================================== | |
| Accuracy: 0.8667 | |
| Precision: 1 | |
| Recall: 0.8667 | |
| F1 Score: 0.9285714285714286 | |
| =========================================================================== | |
| learningRate = 0.1 | |
| Examples labeled as 2 classified by model as 1: 11 times | |
| Examples labeled as 2 classified by model as 2: 19 times | |
| ==========================Scores======================================== | |
| Accuracy: 0.6333 | |
| Precision: 1 | |
| Recall: 0.6333 | |
| F1 Score: 0.7755102040816326 | |
| =========================================================================== |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment