Last active: August 31, 2016 16:58
-
-
Save P8H/f9576e994bfbfa7ab91e76b593b52c8b to your computer and use it in GitHub Desktop.
Autoencoder with evaluation
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import org.deeplearning4j.datasets.fetchers.MnistDataFetcher; | |
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; | |
import org.deeplearning4j.nn.api.OptimizationAlgorithm; | |
import org.deeplearning4j.nn.conf.MultiLayerConfiguration; | |
import org.deeplearning4j.nn.conf.NeuralNetConfiguration; | |
import org.deeplearning4j.nn.conf.layers.OutputLayer; | |
import org.deeplearning4j.nn.conf.layers.RBM; | |
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; | |
import org.deeplearning4j.optimize.api.IterationListener; | |
import org.deeplearning4j.optimize.listeners.ScoreIterationListener; | |
import org.nd4j.linalg.dataset.DataSet; | |
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; | |
import org.nd4j.linalg.lossfunctions.LossFunctions; | |
import org.slf4j.Logger; | |
import org.slf4j.LoggerFactory; | |
import java.util.Collections; | |
import org.deeplearning4j.eval.Evaluation; | |
import org.nd4j.linalg.api.ndarray.INDArray; | |
/** | |
* ***** NOTE: This example has not been tuned. It requires additional work to | |
* produce sensible results ***** | |
* | |
* @author Adam Gibson | |
*/ | |
public class DeepAutoEncoderExample { | |
private static Logger log = LoggerFactory.getLogger(DeepAutoEncoderExample.class); | |
public static void main(String[] args) throws Exception { | |
final int numRows = 28; | |
final int numColumns = 28; | |
int seed = 123; | |
int numSamples = MnistDataFetcher.NUM_EXAMPLES; | |
int batchSize = 1000; | |
int iterations = 1; | |
int listenerFreq = iterations / 5; | |
log.info("Load data...."); | |
DataSetIterator iter = new MnistDataSetIterator(batchSize, numSamples, true); | |
log.info("Build model...."); | |
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() | |
.seed(seed) | |
.iterations(iterations) | |
.optimizationAlgo(OptimizationAlgorithm.LINE_GRADIENT_DESCENT) | |
.list() | |
.layer(0, new RBM.Builder().nIn(numRows * numColumns).nOut(1000).lossFunction(LossFunctions.LossFunction.RMSE_XENT).build()) | |
.layer(1, new RBM.Builder().nIn(1000).nOut(500).lossFunction(LossFunctions.LossFunction.RMSE_XENT).build()) | |
.layer(2, new RBM.Builder().nIn(500).nOut(250).lossFunction(LossFunctions.LossFunction.RMSE_XENT).build()) | |
.layer(3, new RBM.Builder().nIn(250).nOut(100).lossFunction(LossFunctions.LossFunction.RMSE_XENT).build()) | |
.layer(4, new RBM.Builder().nIn(100).nOut(30).lossFunction(LossFunctions.LossFunction.RMSE_XENT).build()) //encoding stops | |
.layer(5, new RBM.Builder().nIn(30).nOut(100).lossFunction(LossFunctions.LossFunction.RMSE_XENT).build()) //decoding starts | |
.layer(6, new RBM.Builder().nIn(100).nOut(250).lossFunction(LossFunctions.LossFunction.RMSE_XENT).build()) | |
.layer(7, new RBM.Builder().nIn(250).nOut(500).lossFunction(LossFunctions.LossFunction.RMSE_XENT).build()) | |
.layer(8, new RBM.Builder().nIn(500).nOut(1000).lossFunction(LossFunctions.LossFunction.RMSE_XENT).build()) | |
.layer(9, new OutputLayer.Builder(LossFunctions.LossFunction.RMSE_XENT).nIn(1000).nOut(numRows * numColumns).build()) | |
.pretrain(true).backprop(true) | |
.build(); | |
MultiLayerNetwork model = new MultiLayerNetwork(conf); | |
model.init(); | |
model.setListeners(Collections.singletonList((IterationListener) new ScoreIterationListener(listenerFreq))); | |
log.info("Train model...."); | |
while (iter.hasNext()) { | |
DataSet next = iter.next(); | |
model.fit(new DataSet(next.getFeatureMatrix(), next.getFeatureMatrix())); | |
} | |
System.out.println("Evaluate model...."); | |
iter.reset(); | |
Evaluation eval = new Evaluation(numRows*numColumns); | |
while (iter.hasNext()) { | |
DataSet next = iter.next(); | |
INDArray features = next.getFeatureMatrix(); | |
INDArray predicted = model.output(features, false); | |
eval.eval(features, predicted); | |
//check autoencode requirement input == output | |
} | |
System.out.println(eval.stats()); | |
} | |
} |
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.