Last active
January 26, 2016 23:35
-
-
Save sato-cloudian/d30adda1f35ad4d1809d to your computer and use it in GitHub Desktop.
TrainPeople
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| WARN [2016-01-26 23:33:12,931] org.deeplearning4j.optimize.solvers.BaseOptimizer: Objective function automatically set to minimize. Set stepFunction in neural net configuration to change default settings. | |
| Exception in thread "main" java.lang.IllegalArgumentException: Shapes do not match: x.shape=[10, 17], y.shape=[10, 549] | |
| at org.nd4j.linalg.api.parallel.tasks.cpu.CPUTaskFactory.getTransformAction(CPUTaskFactory.java:92) | |
| at org.nd4j.linalg.api.ops.executioner.DefaultOpExecutioner.doTransformOp(DefaultOpExecutioner.java:409) | |
| at org.nd4j.linalg.api.ops.executioner.DefaultOpExecutioner.exec(DefaultOpExecutioner.java:62) | |
| at org.nd4j.linalg.api.ndarray.BaseNDArray.subi(BaseNDArray.java:2660) | |
| at org.nd4j.linalg.api.ndarray.BaseNDArray.subi(BaseNDArray.java:2641) | |
| at org.nd4j.linalg.api.ndarray.BaseNDArray.sub(BaseNDArray.java:2419) | |
| at org.deeplearning4j.nn.layers.BaseOutputLayer.getGradientsAndDelta(BaseOutputLayer.java:154) | |
| at org.deeplearning4j.nn.layers.BaseOutputLayer.backpropGradient(BaseOutputLayer.java:133) | |
| at org.deeplearning4j.nn.multilayer.MultiLayerNetwork.calcBackpropGradients(MultiLayerNetwork.java:1224) | |
| at org.deeplearning4j.nn.multilayer.MultiLayerNetwork.backprop(MultiLayerNetwork.java:1178) | |
| at org.deeplearning4j.nn.multilayer.MultiLayerNetwork.computeGradientAndScore(MultiLayerNetwork.java:1753) | |
| at org.deeplearning4j.optimize.solvers.BaseOptimizer.gradientAndScore(BaseOptimizer.java:132) | |
| at org.deeplearning4j.optimize.solvers.StochasticGradientDescent.optimize(StochasticGradientDescent.java:56) | |
| at org.deeplearning4j.optimize.Solver.optimize(Solver.java:52) | |
| at org.deeplearning4j.nn.multilayer.MultiLayerNetwork.fit(MultiLayerNetwork.java:1497) | |
| at org.deeplearning4j.nn.multilayer.MultiLayerNetwork.fit(MultiLayerNetwork.java:1529) | |
| at org.deeplearning4j.examples.convolution.TrainPeople.execute(TrainPeople.java:147) | |
| at org.deeplearning4j.examples.convolution.TrainPeople.main(TrainPeople.java:169) | |
| at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method) | |
| at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:57) | |
| at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43) | |
| at java.lang.reflect.Method.invoke(Method.java:606) | |
| at com.intellij.rt.execution.application.AppMain.main(AppMain.java:144) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| package org.deeplearning4j.examples.convolution; | |
| import org.canova.api.records.reader.RecordReader; | |
| import org.canova.api.split.FileSplit; | |
| import org.canova.image.recordreader.ImageRecordReader; | |
| import org.deeplearning4j.datasets.canova.RecordReaderDataSetIterator; | |
| import org.deeplearning4j.datasets.iterator.DataSetIterator; | |
| import org.deeplearning4j.eval.Evaluation; | |
| import org.deeplearning4j.nn.api.OptimizationAlgorithm; | |
| import org.deeplearning4j.nn.conf.MultiLayerConfiguration; | |
| import org.deeplearning4j.nn.conf.NeuralNetConfiguration; | |
| import org.deeplearning4j.nn.conf.Updater; | |
| import org.deeplearning4j.nn.conf.layers.ConvolutionLayer; | |
| import org.deeplearning4j.nn.conf.layers.DenseLayer; | |
| import org.deeplearning4j.nn.conf.layers.OutputLayer; | |
| import org.deeplearning4j.nn.conf.layers.SubsamplingLayer; | |
| import org.deeplearning4j.nn.conf.layers.setup.ConvolutionLayerSetup; | |
| import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; | |
| import org.deeplearning4j.nn.weights.WeightInit; | |
| import org.deeplearning4j.optimize.api.IterationListener; | |
| import org.deeplearning4j.optimize.listeners.ScoreIterationListener; | |
| import org.deeplearning4j.ui.weights.HistogramIterationListener; | |
| import org.nd4j.linalg.api.ndarray.INDArray; | |
| import org.nd4j.linalg.api.rng.Random; | |
| import org.nd4j.linalg.dataset.SplitTestAndTrain; | |
| import org.nd4j.linalg.dataset.api.DataSet; | |
| import org.nd4j.linalg.factory.Nd4j; | |
| import org.nd4j.linalg.lossfunctions.LossFunctions; | |
| import org.slf4j.Logger; | |
| import org.slf4j.LoggerFactory; | |
| import java.io.File; | |
| import java.io.IOException; | |
| import java.util.ArrayList; | |
| import java.util.Arrays; | |
| import java.util.List; | |
| /** | |
| * Created by tsato on 16/01/26. | |
| */ | |
| public class TrainPeople { | |
| private static final Logger log = LoggerFactory.getLogger(TrainPeople.class); | |
| private final File trainingFolder; | |
| public TrainPeople(File trainingFolder) { | |
| this.trainingFolder = trainingFolder; | |
| } | |
| private void execute() throws IOException{ | |
| // create labels | |
| int samples = 0; | |
| int outputs = 0; | |
| List<String> labels = new ArrayList<String>(); | |
| for (String labelName : this.trainingFolder.list()) { | |
| outputs++; | |
| System.out.println("generating labels for " + labelName); | |
| File labelFolder = new File(this.trainingFolder, labelName); | |
| for (String image : labelFolder.list()) { | |
| labels.add(labelName); | |
| samples++; | |
| log.info("added " + labelName + " on " + new File(labelFolder, image).getAbsolutePath()); | |
| } | |
| } | |
| log.info("outputs, samples = " + outputs + ", " + samples); | |
| // read images | |
| int width = 40; | |
| int height = 32; | |
| RecordReader recordReader = new ImageRecordReader(width, height, true, labels); | |
| try{ | |
| recordReader.initialize(new FileSplit(this.trainingFolder)); | |
| } catch(InterruptedException ie) { | |
| ie.printStackTrace(); | |
| } | |
| DataSetIterator iter = new RecordReaderDataSetIterator(recordReader, width * height, labels.size()); | |
| Nd4j.ENFORCE_NUMERICAL_STABILITY = true; | |
| log.info("Build model...."); | |
| int numRows = height; | |
| int numColumns = width; | |
| int nChannels = 1; | |
| int outputNum = outputs; | |
| int numSamples = samples; | |
| int batchSize = 10; | |
| int iterations = 1; | |
| int seed = 123; | |
| int listenerFreq = 5; | |
| MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder() | |
| .seed(seed) | |
| .iterations(iterations) | |
| //.gradientNormalization(GradientNormalization.RenormalizeL2PerLayer) | |
| .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) | |
| .learningRate(0.01) // default | |
| .regularization(true) | |
| .list(6) | |
| .layer(0, new ConvolutionLayer.Builder(3, 3) // 40*32*3 => 40*32*10 | |
| .nIn(nChannels) | |
| .nOut(10) | |
| .padding(1, 1) | |
| .stride(1, 1) | |
| .weightInit(WeightInit.RELU) | |
| .activation("relu") | |
| .build()) | |
| .layer(1, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX, new int[] {2,2}) // 40*32*10 => 20*16*10 | |
| .stride(2, 2) | |
| .build()) | |
| .layer(2, new ConvolutionLayer.Builder(3, 3) // 20*16*10 => 20*16*20 | |
| .nIn(nChannels) | |
| .nOut(20) | |
| .padding(1, 1) | |
| .stride(1, 1) | |
| .weightInit(WeightInit.RELU) | |
| .activation("relu") | |
| .build()) | |
| .layer(3, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX, new int[] {2,2}) // 20*16*20 => 10*8*20 = 1,600 | |
| .stride(2, 2) | |
| .build()) | |
| .layer(4, new DenseLayer.Builder().activation("relu") | |
| .nOut(100).build()) | |
| .layer(5, new OutputLayer.Builder(LossFunctions.LossFunction.RMSE_XENT) | |
| .nOut(outputNum) | |
| .weightInit(WeightInit.RELU) | |
| .activation("softmax") | |
| .updater(Updater.SGD) | |
| .build()) | |
| .backprop(true).pretrain(false); | |
| new ConvolutionLayerSetup(builder,numRows,numColumns,nChannels); | |
| MultiLayerConfiguration conf = builder.build(); | |
| MultiLayerNetwork model = new MultiLayerNetwork(conf); | |
| model.init(); | |
| log.info("Train model...."); | |
| model.setListeners(Arrays.asList((IterationListener) new ScoreIterationListener(listenerFreq), new HistogramIterationListener(listenerFreq))); | |
| while(iter.hasNext()) { | |
| DataSet dataSet = iter.next(); | |
| model.fit(dataSet); | |
| } | |
| log.info("Evaluate weights...."); | |
| iter.reset(); | |
| log.info("Evaluate model...."); | |
| Evaluation eval = new Evaluation(outputNum); | |
| while(iter.hasNext()) { | |
| DataSet dataSet = iter.next(); | |
| INDArray output = model.output(dataSet.getFeatureMatrix()); | |
| eval.eval(dataSet.getLabels(), output); | |
| log.info(eval.stats()); | |
| } | |
| log.info("****************Example finished********************"); | |
| } | |
| public static void main(String[] args) { | |
| TrainPeople trainPeople = new TrainPeople(new File(args[0])); | |
| try { | |
| trainPeople.execute(); | |
| } catch (IOException e) { | |
| e.printStackTrace(); | |
| } | |
| } | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment