Can you explain the .hiddenUnit and .visibleUnit attributes?
package org.deeplearning4j.examples.deepbelief;

import org.deeplearning4j.datasets.iterator.DataSetIterator;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.GradientNormalization;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.conf.layers.RBM;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.api.IterationListener;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.deeplearning4j.ui.weights.HistogramIterationListener;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.Collections;

/**
 * Created by agibsonccc on 9/11/14.
 */
public class DBNMnistFullExample {

    private static Logger log = LoggerFactory.getLogger(DBNMnistFullExample.class);

    public static void main(String[] args) throws Exception {
        final int numRows = 28;     // MNIST image height
        final int numColumns = 28;  // MNIST image width
        int outputNum = 10;         // digit classes 0-9
        int numSamples = 60000;     // full MNIST training set
        int batchSize = 1568;
        int iterations = 10;
        int seed = 123;
        int listenerFreq = batchSize / 5;

        log.info("Load data....");
        DataSetIterator iter = new MnistDataSetIterator(batchSize, numSamples, true);

        log.info("Build model....");
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .seed(seed)
            .gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue)
            .gradientNormalizationThreshold(1.0)
            .iterations(iterations)
            .momentum(0.5)
            .momentumAfter(Collections.singletonMap(3, 0.9)) // momentum schedule: raise to 0.9 after iteration 3
            .optimizationAlgo(OptimizationAlgorithm.CONJUGATE_GRADIENT)
            .list(4)
            .layer(0, new RBM.Builder().nIn(numRows * numColumns).nOut(500)
                .weightInit(WeightInit.XAVIER).lossFunction(LossFunction.RMSE_XENT)
                .visibleUnit(RBM.VisibleUnit.BINARY)
                .hiddenUnit(RBM.HiddenUnit.BINARY)
                .activation("relu")
                .build())
            .layer(1, new RBM.Builder().nIn(500).nOut(250)
                .weightInit(WeightInit.XAVIER).lossFunction(LossFunction.RMSE_XENT)
                .visibleUnit(RBM.VisibleUnit.BINARY)
                .hiddenUnit(RBM.HiddenUnit.BINARY)
                .activation("relu")
                .build())
            .layer(2, new RBM.Builder().nIn(250).nOut(200)
                .weightInit(WeightInit.XAVIER).lossFunction(LossFunction.RMSE_XENT)
                .visibleUnit(RBM.VisibleUnit.BINARY)
                .hiddenUnit(RBM.HiddenUnit.BINARY)
                .activation("relu")
                .build())
            .layer(3, new OutputLayer.Builder(LossFunction.NEGATIVELOGLIKELIHOOD).activation("softmax")
                .nIn(200).nOut(outputNum).build())
            .pretrain(true).backprop(false) // unsupervised layerwise pre-training only; no supervised backprop fine-tuning
            .build();

        MultiLayerNetwork model = new MultiLayerNetwork(conf);
        model.init();
//        model.setListeners(Collections.singletonList((IterationListener) new ScoreIterationListener(listenerFreq)));
        model.setListeners(new HistogramIterationListener(listenerFreq));

        log.info("Train model....");
        model.fit(iter); // achieves end to end pre-training

        log.info("Evaluate model....");
        Evaluation eval = new Evaluation(outputNum);
        DataSetIterator testIter = new MnistDataSetIterator(100, 10000);
        while (testIter.hasNext()) {
            DataSet testMnist = testIter.next();
            INDArray predict2 = model.output(testMnist.getFeatureMatrix());
            eval.eval(testMnist.getLabels(), predict2);
        }
        log.info(eval.stats());
        log.info("****************Example finished********************");
    }
}
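Regarding the .visibleUnit and .hiddenUnit question above: in this (0.4.x-era) RBM layer builder they select the distribution each RBM assumes for its visible and hidden units. BINARY/BINARY, used in every layer here, is the standard Bernoulli-Bernoulli RBM for inputs scaled to [0, 1] such as MNIST pixels; the hidden unit type also determines what the next RBM in the stack sees as its visible input. Below is a minimal sketch of a variant of layer 0, assuming the same API version as the gist and reusing its numRows/numColumns variables; the GAUSSIAN constant is the usual choice for real-valued, unscaled inputs and is shown here purely for illustration.

    // Illustrative variant of layer 0: Gaussian (real-valued) visible units with
    // binary hidden units -- a Gaussian-Bernoulli RBM for continuous inputs.
    RBM gaussianVisibleLayer = new RBM.Builder()
            .nIn(numRows * numColumns).nOut(500)
            .weightInit(WeightInit.XAVIER)
            .lossFunction(LossFunction.RMSE_XENT)
            .visibleUnit(RBM.VisibleUnit.GAUSSIAN) // visible layer models real-valued inputs
            .hiddenUnit(RBM.HiddenUnit.BINARY)     // hidden layer stays stochastic binary
            .build();

Dropping this into the .layer(0, ...) call in place of the binary-visible layer would be the typical adjustment if the input features were not already normalized to [0, 1].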