Skip to content

Instantly share code, notes, and snippets.

@fornarat
Created June 30, 2016 11:19
Show Gist options
  • Save fornarat/1999064823bb4e5eee93e45774ff6d45 to your computer and use it in GitHub Desktop.
Save fornarat/1999064823bb4e5eee93e45774ff6d45 to your computer and use it in GitHub Desktop.
package org.deeplearning4j.examples.convolution;
import org.canova.api.records.reader.RecordReader;
import org.canova.api.records.reader.impl.CSVRecordReader;
import org.canova.api.split.FileSplit;
import org.deeplearning4j.datasets.canova.RecordReaderDataSetIterator;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.Updater;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.conf.layers.SubsamplingLayer;
import org.deeplearning4j.nn.conf.layers.setup.ConvolutionLayerSetup;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
//import org.deeplearning4j.nn.conf.LearningRatePolicy;
import java.io.File;
/**
 * LeNet-style convolutional network (two convolution + max-pooling stages, a 500-unit ReLU dense
 * layer and a softmax output layer) trained on a binary classification dataset read from CSV files.
 *
 * <p>Adapted from the DL4J LeNet MNIST example (originally created by agibsonccc on 9/16/15).
 * Each CSV row is expected to hold the class label in column {@link #LABEL_INDEX} followed by the
 * pixel values of one single-channel {@link #IMAGE_HEIGHT} x {@link #IMAGE_WIDTH} image
 * (NOTE(review): inferred from the network setup below — confirm against the data files).
 */
public class LenetMnistExampleCustom {
    private static final Logger log = LoggerFactory.getLogger(LenetMnistExampleCustom.class);

    /** Training CSV used when no path is given as the first program argument. */
    private static final String DEFAULT_TRAIN_PATH = "src/main/resources/classification/train.csv";
    /** Test CSV used when no path is given as the second program argument. */
    private static final String DEFAULT_TEST_PATH = "src/main/resources/classification/test.csv";

    /** Column of each CSV row that holds the class label. */
    private static final int LABEL_INDEX = 0;
    /** Number of rows of each input image, as passed to {@code ConvolutionLayerSetup}. */
    private static final int IMAGE_HEIGHT = 100;
    /** Number of columns of each input image, as passed to {@code ConvolutionLayerSetup}. */
    private static final int IMAGE_WIDTH = 90;

    /**
     * Builds the network, then trains it for a fixed number of epochs, logging evaluation
     * statistics on the test set after every epoch.
     *
     * @param args optional: {@code args[0]} overrides the training CSV path and {@code args[1]}
     *             overrides the test CSV path; missing entries fall back to the defaults under
     *             {@code src/main/resources/classification/}
     * @throws IllegalArgumentException if a CSV path does not name an existing file
     * @throws Exception propagated from CSV reader initialisation or from DL4J
     */
    public static void main(String[] args) throws Exception {
        int iterations = 1;
        int nChannels = 1;
        int seed = 123;
        double learningRate = 0.01;
        int batchSize = 3500;
        int nEpochs = 30;
        int outputNum = 2;
        // Feature maps produced by the first convolution; also the input depth of the second one.
        int conv1Filters = 20;
        int conv2Filters = 50;

        String trainPath = args.length > 0 ? args[0] : DEFAULT_TRAIN_PATH;
        String testPath = args.length > 1 ? args[1] : DEFAULT_TEST_PATH;

        log.info("Load data....");
        org.deeplearning4j.datasets.iterator.DataSetIterator dataSetIteratorTrain =
                csvIterator(trainPath, batchSize, outputNum);
        org.deeplearning4j.datasets.iterator.DataSetIterator dataSetIteratorTest =
                csvIterator(testPath, batchSize, outputNum);

        log.info("Build model....");
        MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder()
                .seed(seed)
                .iterations(iterations)
                .regularization(true).l2(0.0005)
                .learningRate(learningRate)
                .weightInit(WeightInit.XAVIER)
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                .updater(Updater.NESTEROVS).momentum(0.9)
                .list()
                .layer(0, new ConvolutionLayer.Builder(5, 5)
                        .nIn(nChannels)
                        .stride(1, 1)
                        .nOut(conv1Filters)
                        .activation("identity")
                        .build())
                .layer(1, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
                        .kernelSize(2, 2)
                        .stride(2, 2)
                        .build())
                .layer(2, new ConvolutionLayer.Builder(5, 5)
                        // Input depth is the feature maps of layer 0 (pooling keeps depth),
                        // not the raw image channel count.
                        .nIn(conv1Filters)
                        .stride(1, 1)
                        .nOut(conv2Filters)
                        .activation("identity")
                        .build())
                .layer(3, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
                        .kernelSize(2, 2)
                        .stride(2, 2)
                        .build())
                .layer(4, new DenseLayer.Builder().activation("relu")
                        .nOut(500).build())
                .layer(5, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                        .nOut(outputNum)
                        .activation("softmax")
                        .build())
                .backprop(true).pretrain(false);
        // The instance is discarded: it is constructed only for its effect on `builder`, which it
        // configures for IMAGE_HEIGHT x IMAGE_WIDTH x nChannels inputs before build() is called.
        new ConvolutionLayerSetup(builder, IMAGE_HEIGHT, IMAGE_WIDTH, nChannels);
        MultiLayerConfiguration conf = builder.build();
        MultiLayerNetwork model = new MultiLayerNetwork(conf);
        model.init();

        log.info("Train model....");
        model.setListeners(new ScoreIterationListener(1));
        for (int epoch = 0; epoch < nEpochs; epoch++) {
            model.fit(dataSetIteratorTrain);
            // Rewind explicitly so every epoch trains on the full data set instead of relying on
            // fit() to do it.
            dataSetIteratorTrain.reset();
            log.info("*** Completed epoch {} ***", epoch);

            log.info("Evaluate model....");
            Evaluation eval = evaluate(model, dataSetIteratorTest, outputNum);
            log.info(eval.stats());
        }
        log.info("****************Example finished********************");
    }

    /**
     * Creates a mini-batch iterator over a CSV file whose label is in column {@link #LABEL_INDEX}.
     *
     * @param path       CSV file to read
     * @param batchSize  number of examples per mini-batch
     * @param numClasses number of distinct label values
     * @return an iterator over the file's examples
     * @throws IllegalArgumentException if {@code path} does not name an existing file
     * @throws Exception propagated from the record reader's initialisation
     */
    private static org.deeplearning4j.datasets.iterator.DataSetIterator csvIterator(
            String path, int batchSize, int numClasses) throws Exception {
        File file = new File(path);
        if (!file.isFile()) {
            // Fail early with a clear message instead of an opaque reader error later on.
            throw new IllegalArgumentException("CSV file not found: " + file.getAbsolutePath());
        }
        RecordReader reader = new CSVRecordReader();
        reader.initialize(new FileSplit(file));
        return new RecordReaderDataSetIterator(reader, batchSize, LABEL_INDEX, numClasses);
    }

    /**
     * Runs the model over every batch of {@code iterator} and accumulates classification metrics.
     * The iterator is reset afterwards so it can be evaluated again.
     *
     * @param model      trained network (inference mode, i.e. {@code output(..., false)})
     * @param iterator   test data to evaluate on
     * @param numClasses number of output classes
     * @return the accumulated evaluation
     */
    private static Evaluation evaluate(MultiLayerNetwork model,
            org.deeplearning4j.datasets.iterator.DataSetIterator iterator, int numClasses) {
        Evaluation eval = new Evaluation(numClasses);
        while (iterator.hasNext()) {
            DataSet ds = iterator.next();
            INDArray output = model.output(ds.getFeatureMatrix(), false);
            eval.eval(ds.getLabels(), output);
        }
        iterator.reset();
        return eval;
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment