package org.deeplearning4j.examples.convolution;

import org.canova.api.records.reader.RecordReader;
import org.canova.api.split.FileSplit;
import org.canova.image.loader.BaseImageLoader;
import org.canova.image.loader.NativeImageLoader;
import org.canova.image.recordreader.ImageRecordReader;
import org.deeplearning4j.datasets.canova.RecordReaderDataSetIterator;
import org.deeplearning4j.datasets.iterator.DataSetIterator;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.GradientNormalization;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.Updater;
import org.deeplearning4j.nn.conf.layers.*;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.lossfunctions.LossFunctions;

import java.io.File;
import java.util.Arrays;
import java.util.List;
/**
 * Image classification with an AlexNet-style CNN in DL4J.
 * Created by junyong on 30.05.2016.
 */
public class ImageClassificationExample
{
    public static void main(String[] args) throws Exception {
        int imageWidth = 600;
        int imageHeight = 600;
        int channels = 3;

        // Create the dataset.
        // Root directory containing one sub-directory of images per category;
        // the sub-directory names are used as the labels.
        File directory = new File("D:\\training_set");
        int batchSize = 100;
        boolean appendLabels = true;
        List<String> labels = Arrays.asList(directory.list());
        int numLabels = labels.size();
        RecordReader recordReader = new ImageRecordReader(imageHeight, imageWidth, channels, appendLabels);
        recordReader.initialize(new FileSplit(directory, BaseImageLoader.ALLOWED_FORMATS));
        DataSetIterator dataIter = new RecordReaderDataSetIterator(recordReader, batchSize, -1, numLabels);
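        // The iterator yields batches of up to 'batchSize' images as feature vectors,
        // paired with one-hot labels derived from the sub-directory names.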

        // Set up the model.
        MultiLayerNetwork model = new MultiLayerNetwork(buildConfig(imageWidth, imageHeight, channels, numLabels));
        model.init();
        model.setListeners(new ScoreIterationListener(1));

        // Train the model: each call to fit() consumes one full pass over the iterator.
        int epochs = 5;
        for (int i = 0; i < epochs; i++) {
            dataIter.reset();
            model.fit(dataIter);
        }

        // Test the model. Note that this evaluates on the same data used for training.
        dataIter.reset();
        Evaluation eval = new Evaluation(dataIter.getLabels());
        while (dataIter.hasNext()) {
            DataSet next = dataIter.next();
            INDArray prediction = model.output(next.getFeatureMatrix());
            eval.eval(next.getLabels(), prediction);
        }
        System.out.println(eval.stats());

        // Predict a new image.
        File imageToPredict = new File("D:\\classified images\\100079.jpg");
        NativeImageLoader imageLoader = new NativeImageLoader(imageHeight, imageWidth, channels);
        INDArray imageVector = imageLoader.asRowVector(imageToPredict);
        INDArray prediction = model.output(imageVector);
        System.out.println("done");
        // 'prediction' holds one probability per label; the softmax outputs sum to 1.
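        // A minimal sketch (not in the original) of reading out the most likely class,
        // assuming org.nd4j.linalg.factory.Nd4j is imported:
        // int predictedIndex = Nd4j.argMax(prediction, 1).getInt(0);
        // String predictedLabel = labels.get(predictedIndex);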
    }
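
    /**
     * Builds an AlexNet-style configuration: five convolution layers with local response
     * normalization and max pooling, followed by two dense layers and a softmax output,
     * using the DL4J 0.4.x builder API used elsewhere in this file.
     */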
    private static MultiLayerConfiguration buildConfig(int imageWidth, int imageHeight, int channels, int numOfClasses) {
        int seed = 123;
        int iterations = 1;
        WeightInit weightInit = WeightInit.XAVIER;
        String activation = "relu";
        Updater updater = Updater.NESTEROVS;
        double lr = 1e-3;
        double mu = 0.9;
        double l2 = 5e-4;
        boolean regularization = true;
        SubsamplingLayer.PoolingType poolingType = SubsamplingLayer.PoolingType.MAX;
        double nonZeroBias = 1;
        double dropOut = 0.5;

        MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder()
            .seed(seed)
            .iterations(iterations)
            .activation(activation)
            .weightInit(weightInit)
            .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
            .learningRate(lr)
            .momentum(mu)
            .regularization(regularization)
            .l2(l2)
            .updater(updater)
            .useDropConnect(true)
            // AlexNet
            .list()
            .layer(0, new ConvolutionLayer.Builder(new int[] { 11, 11 }, new int[] { 4, 4 }, new int[] { 3, 3 }).name("cnn1").nIn(channels).nOut(96).build())
            .layer(1, new LocalResponseNormalization.Builder().name("lrn1").build())
            .layer(2, new SubsamplingLayer.Builder(poolingType, new int[] { 3, 3 }, new int[] { 2, 2 }).name("maxpool1").build())
            .layer(3, new ConvolutionLayer.Builder(new int[] { 5, 5 }, new int[] { 1, 1 }, new int[] { 2, 2 }).name("cnn2").nOut(256).biasInit(nonZeroBias).build())
            .layer(4, new LocalResponseNormalization.Builder().name("lrn2").k(2).n(5).alpha(1e-4).beta(0.75).build())
            .layer(5, new SubsamplingLayer.Builder(poolingType, new int[] { 3, 3 }, new int[] { 2, 2 }).name("maxpool2").build())
            .layer(6, new ConvolutionLayer.Builder(new int[] { 3, 3 }, new int[] { 1, 1 }, new int[] { 1, 1 }).name("cnn3").nOut(384).build())
            .layer(7, new ConvolutionLayer.Builder(new int[] { 3, 3 }, new int[] { 1, 1 }, new int[] { 1, 1 }).name("cnn4").nOut(384).biasInit(nonZeroBias).build())
            .layer(8, new ConvolutionLayer.Builder(new int[] { 3, 3 }, new int[] { 1, 1 }, new int[] { 1, 1 }).name("cnn5").nOut(256).biasInit(nonZeroBias).build())
            .layer(9, new SubsamplingLayer.Builder(poolingType, new int[] { 3, 3 }, new int[] { 2, 2 }).name("maxpool3").build())
            .layer(10, new DenseLayer.Builder().name("ffn1").nOut(4096).biasInit(nonZeroBias).dropOut(dropOut).build())
            .layer(11, new DenseLayer.Builder().name("ffn2").nOut(4096).biasInit(nonZeroBias).dropOut(dropOut).build())
            .layer(12, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD).name("output").nOut(numOfClasses).activation("softmax").build())
            .backprop(true)
            .pretrain(false)
            .cnnInputSize(imageHeight, imageWidth, channels);
        return builder.build();
    }
}