Created February 18, 2016 19:42
-
-
Save Tachyon5/b6c254529753dd280f01 to your computer and use it in GitHub Desktop.
I want to add a HistogramIterationListener to the following example, but I can't get the UI to run.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package org.deeplearning4j.examples.autoencoder; | |
import org.deeplearning4j.datasets.iterator.DataSetIterator; | |
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; | |
import org.deeplearning4j.eval.Evaluation; | |
import org.deeplearning4j.nn.api.OptimizationAlgorithm; | |
import org.deeplearning4j.nn.conf.GradientNormalization; | |
import org.deeplearning4j.nn.conf.MultiLayerConfiguration; | |
import org.deeplearning4j.nn.conf.NeuralNetConfiguration; | |
import org.deeplearning4j.nn.conf.layers.AutoEncoder; | |
import org.deeplearning4j.nn.conf.layers.OutputLayer; | |
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; | |
import org.deeplearning4j.nn.weights.WeightInit; | |
import org.deeplearning4j.optimize.api.IterationListener; | |
import org.deeplearning4j.optimize.listeners.ScoreIterationListener; | |
import org.deeplearning4j.ui.weights.HistogramIterationListener; | |
import org.nd4j.linalg.api.ndarray.INDArray; | |
import org.nd4j.linalg.dataset.DataSet; | |
import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction; | |
import org.slf4j.Logger; | |
import org.slf4j.LoggerFactory; | |
import java.util.Arrays; | |
import java.util.Collections; | |
/** | |
* Created by agibsonccc on 9/11/14. | |
*/ | |
public class StackedAutoEncoderMnistExample { | |
private static Logger log = LoggerFactory.getLogger(StackedAutoEncoderMnistExample.class); | |
public static void main(String[] args) throws Exception { | |
final int numRows = 28; | |
final int numColumns = 28; | |
int outputNum = 10; | |
int numSamples = 60000; | |
int batchSize = 100; | |
int iterations = 10; | |
int seed = 123; | |
int listenerFreq = batchSize / 5; | |
log.info("Load data...."); | |
DataSetIterator iter = new MnistDataSetIterator(batchSize,numSamples,true); | |
log.info("Build model...."); | |
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() | |
.seed(seed) | |
.gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue) | |
.gradientNormalizationThreshold(1.0) | |
.iterations(iterations) | |
.momentum(0.5) | |
.momentumAfter(Collections.singletonMap(3, 0.9)) | |
.optimizationAlgo(OptimizationAlgorithm.CONJUGATE_GRADIENT) | |
.list(4) | |
.layer(0, new AutoEncoder.Builder().nIn(numRows * numColumns).nOut(500) | |
.weightInit(WeightInit.XAVIER).lossFunction(LossFunction.RMSE_XENT) | |
.corruptionLevel(0.3) | |
.build()) | |
.layer(1, new AutoEncoder.Builder().nIn(500).nOut(250) | |
.weightInit(WeightInit.XAVIER).lossFunction(LossFunction.RMSE_XENT) | |
.corruptionLevel(0.3) | |
.build()) | |
.layer(2, new AutoEncoder.Builder().nIn(250).nOut(200) | |
.weightInit(WeightInit.XAVIER).lossFunction(LossFunction.RMSE_XENT) | |
.corruptionLevel(0.3) | |
.build()) | |
.layer(3, new OutputLayer.Builder(LossFunction.NEGATIVELOGLIKELIHOOD).activation("softmax") | |
.nIn(200).nOut(outputNum).build()) | |
.pretrain(true).backprop(false) | |
.build(); | |
MultiLayerNetwork model = new MultiLayerNetwork(conf); | |
model.init(); | |
model.setListeners(Collections.singletonList((IterationListener) new ScoreIterationListener(listenerFreq))); | |
model.setListeners(Arrays.asList(new ScoreIterationListener(listenerFreq), new HistogramIterationListener(listenerFreq))); | |
log.info("Train model...."); | |
model.fit(iter); // achieves end to end pre-training | |
log.info("Evaluate model...."); | |
Evaluation eval = new Evaluation(outputNum); | |
DataSetIterator testIter = new MnistDataSetIterator(100,10000); | |
while(testIter.hasNext()) { | |
DataSet testMnist = testIter.next(); | |
INDArray predict2 = model.output(testMnist.getFeatureMatrix()); | |
eval.eval(testMnist.getLabels(), predict2); | |
} | |
log.info(eval.stats()); | |
log.info("****************Example finished********************"); | |
} | |
} |
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.