Skip to content

Instantly share code, notes, and snippets.

@geekprogramming
Last active May 10, 2016 09:55
Show Gist options
  • Save geekprogramming/662c290d27bb9caf81de55e1411d1a27 to your computer and use it in GitHub Desktop.
Save geekprogramming/662c290d27bb9caf81de55e1411d1a27 to your computer and use it in GitHub Desktop.
/*
* To change this license header, choose License Headers in Project Properties.
* To change this template file, choose Tools | Templates
* and open the template in the editor.
*/
package com.ldt.plateregconition;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.OutputStream;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.Arrays;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import org.apache.commons.io.FileUtils;
import org.deeplearning4j.datasets.iterator.DataSetIterator;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.Updater;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.conf.layers.SubsamplingLayer;
import org.deeplearning4j.nn.conf.layers.setup.ConvolutionLayerSetup;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.api.IterationListener;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.SplitTestAndTrain;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
 * Builds, trains, evaluates, saves and reloads a small convolutional network
 * ({@link MultiLayerNetwork}) for license-plate recognition.
 *
 * <p>Hyper-parameters come from the full constructor or the setters. The full constructor
 * also derives two values: {@code splitTrainNum} (80% of {@code batchSize}, the number of
 * rows of each minibatch used for training in {@link #trainModel}) and {@code listenerFreq}
 * ({@code iterations / 5}). The setters for {@code batchSize} and {@code iterations} do NOT
 * recompute these derived values.
 *
 * <p>{@link #saveModel} / {@link #loadModel} use {@link #FILE_PARAMS} (raw parameter array)
 * and {@link #FILE_MODEL} (JSON configuration), resolved against the working directory.
 */
public class LicensePlatePCNN {
    private static final Logger log = LoggerFactory.getLogger(LicensePlatePCNN.class);

    /** Fraction of {@code batchSize} used as the per-minibatch training row count. */
    private static final double TRAIN_FRACTION = 0.8;

    private int height;
    private int width;
    private int channels;
    private int outputNum;
    private long seed;
    private int numExamples;
    private int iterations;
    private int batchSize;
    private int nEpochs;
    private int splitTrainNum;
    private int listenerFreq;
    private String directory;

    /** File the flattened network parameters are written to / read from. */
    public static String FILE_PARAMS = "params.bin";
    /** File the network configuration JSON is written to / read from. */
    public static String FILE_MODEL = "model.json";

    /** Creates an instance with all hyper-parameters zeroed; configure it via the setters. */
    public LicensePlatePCNN() {
    }

    /**
     * Creates an instance with the given hyper-parameters.
     *
     * @param height     input image height in pixels
     * @param width      input image width in pixels
     * @param channels   number of input channels
     * @param outputNum  number of output classes
     * @param seed       random seed for weight init and data splitting
     * @param iterations optimizer iterations per {@code fit} call
     * @param batchSize  minibatch size
     * @param nEpochs    number of training epochs
     */
    public LicensePlatePCNN(int height, int width, int channels, int outputNum, long seed,
            int iterations, int batchSize, int nEpochs) {
        this.height = height;
        this.width = width;
        this.channels = channels;
        this.outputNum = outputNum;
        this.seed = seed;
        this.iterations = iterations;
        this.nEpochs = nEpochs;
        this.batchSize = batchSize;
        this.splitTrainNum = (int) (batchSize * TRAIN_FRACTION);
        this.listenerFreq = iterations / 5;
    }

    /**
     * Builds and initializes the network: 5x5 convolution (50 maps, ReLU), 2x2 max pooling,
     * a 1000-unit ReLU dense layer and a softmax output layer with negative log-likelihood loss.
     *
     * @return the initialized, untrained network
     */
    public MultiLayerNetwork init() {
        log.info("Build model......");
        MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder()
                .seed(seed)
                .iterations(iterations)
                .learningRate(0.001)
                .weightInit(WeightInit.XAVIER)
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                .updater(Updater.NESTEROVS).momentum(0.9)
                .list(4)
                .layer(0, new ConvolutionLayer.Builder(5, 5)
                        .nIn(channels)
                        .stride(1, 1)
                        .nOut(50)
                        .activation("relu")
                        .build())
                .layer(1, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
                        .stride(2, 2)
                        .kernelSize(2, 2)
                        .build())
                // NOTE(review): 20 * 2 * 4 does not match the 50 feature maps produced by
                // layer 0; presumably ConvolutionLayerSetup below recomputes this nIn — confirm.
                .layer(2, new DenseLayer.Builder().activation("relu")
                        .nIn(20 * 2 * 4)
                        .nOut(1000)
                        .build())
                .layer(3, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                        .nOut(outputNum)
                        .activation("softmax")
                        .build())
                .backprop(true).pretrain(false);
        // Instantiated only for its effect on `builder` (it receives the image dimensions).
        new ConvolutionLayerSetup(builder, height, width, channels);
        MultiLayerConfiguration conf = builder.build();
        MultiLayerNetwork model = new MultiLayerNetwork(conf);
        model.init();
        return model;
    }

    /**
     * Consumes {@code iterator} once. Each minibatch is normalized, shuffled and split into a
     * training part and a test part. The model is then trained for {@code nEpochs} epochs,
     * and evaluation statistics on the test parts are logged after every epoch.
     *
     * @param iterator source of minibatches; it is read to exhaustion and not reset
     * @param model    network to train in place
     */
    public void trainModel(DataSetIterator iterator, MultiLayerNetwork model) {
        log.info("Train model......");
        model.setListeners(Arrays.asList(
                (IterationListener) new ScoreIterationListener(effectiveListenerFreq())));
        List<DataSet> featuresTest = new ArrayList<>();
        List<DataSet> featuresTrain = new ArrayList<>();
        log.info("Get Data");
        while (iterator.hasNext()) {
            DataSet ds = iterator.next();
            // NOTE(review): each minibatch is normalized with its own statistics.
            ds.normalizeZeroMeanZeroUnitVariance();
            ds.shuffle();
            int available = ds.numExamples();
            if (available < 2) {
                // Too small to yield non-empty train and test parts; use it for training only.
                featuresTrain.add(ds);
                continue;
            }
            // A short final minibatch may have <= splitTrainNum rows; keep at least one test row.
            int trainCount = Math.min(splitTrainNum, available - 1);
            SplitTestAndTrain split = ds.splitTestAndTrain(trainCount, new Random(seed));
            featuresTrain.add(split.getTrain());
            featuresTest.add(split.getTest());
        }
        log.info("Num of train minibatch: {}", featuresTrain.size());
        for (int i = 0; i < nEpochs; i++) {
            log.info("*** Begin epoch {} ***", i);
            for (DataSet s : featuresTrain) {
                model.fit(s);
            }
            log.info("*** Completed epoch {} ***", i);
            if (featuresTest.isEmpty()) {
                log.info("No test data available; skipping evaluation.");
                continue;
            }
            log.info("Evaluate model....");
            Evaluation eval = new Evaluation(outputNum);
            for (DataSet ds1 : featuresTest) {
                INDArray output = model.output(ds1.getFeatureMatrix());
                eval.eval(ds1.getLabels(), output);
            }
            log.info(eval.stats());
        }
        log.info("******************Finished********************");
    }

    /**
     * Trains {@code model} on MNIST for {@code nEpochs} epochs. After every epoch it evaluates
     * on the MNIST test set and logs the result.
     *
     * @param model network to train in place
     * @throws IOException if the MNIST data cannot be loaded
     */
    public void trainModelWithMNIST(MultiLayerNetwork model) throws IOException {
        log.info("Train model......");
        log.info("Load data....");
        DataSetIterator mnistTrain = new MnistDataSetIterator(batchSize, true, 12345);
        DataSetIterator mnistTest = new MnistDataSetIterator(batchSize, false, 12345);
        log.info("Train model....");
        model.setListeners(new ScoreIterationListener(1));
        for (int i = 0; i < nEpochs; i++) {
            model.fit(mnistTrain);
            // Rewind explicitly so every epoch sees the full training set.
            mnistTrain.reset();
            log.info("*** Completed epoch {} ***", i);
            log.info("Evaluate model....");
            Evaluation eval = new Evaluation(outputNum);
            while (mnistTest.hasNext()) {
                DataSet ds = mnistTest.next();
                INDArray output = model.output(ds.getFeatureMatrix());
                eval.eval(ds.getLabels(), output);
            }
            log.info(eval.stats());
            mnistTest.reset();
        }
        log.info("****************Example finished********************");
    }

    /**
     * Runs {@code input} through {@code model} and prints to stdout the arg-max class index
     * (taken along dimension 1) of the first row.
     *
     * @param input feature matrix to classify
     * @param model trained network
     */
    public void predict(INDArray input, MultiLayerNetwork model) {
        INDArray output = model.output(input);
        output = Nd4j.argMax(output, 1);
        System.out.println(output.getDouble(0));
    }

    /**
     * Writes the model parameters to {@link #FILE_PARAMS} and its configuration JSON
     * (UTF-8) to {@link #FILE_MODEL}.
     *
     * @param model network to persist
     * @throws IOException if either file cannot be written
     */
    public void saveModel(MultiLayerNetwork model) throws IOException {
        log.info("Save model............");
        // try-with-resources closes the stream even if Nd4j.write throws.
        try (OutputStream fos = Files.newOutputStream(Paths.get(FILE_PARAMS));
                DataOutputStream dos = new DataOutputStream(fos)) {
            Nd4j.write(model.params(), dos);
            dos.flush();
        }
        FileUtils.writeStringToFile(new File(FILE_MODEL),
                model.getLayerWiseConfigurations().toJson(), "UTF-8");
    }

    /**
     * Rebuilds a network from {@link #FILE_MODEL} (configuration JSON, UTF-8) and
     * {@link #FILE_PARAMS} (parameters), and attaches a score listener.
     *
     * @return the restored network
     * @throws IOException if either file cannot be read
     */
    public MultiLayerNetwork loadModel() throws IOException {
        log.info("Load model...........");
        MultiLayerConfiguration confFromJson = MultiLayerConfiguration.fromJson(
                FileUtils.readFileToString(new File(FILE_MODEL), "UTF-8"));
        INDArray newParams;
        // try-with-resources closes the stream even if Nd4j.read throws.
        try (DataInputStream dis = new DataInputStream(new FileInputStream(FILE_PARAMS))) {
            newParams = Nd4j.read(dis);
        }
        MultiLayerNetwork savedNetwork = new MultiLayerNetwork(confFromJson);
        savedNetwork.init();
        savedNetwork.setParameters(newParams);
        savedNetwork.setListeners(Arrays.asList(
                (IterationListener) new ScoreIterationListener(effectiveListenerFreq())));
        return savedNetwork;
    }

    /**
     * Returns {@code listenerFreq} clamped to at least 1. The raw value is 0 when
     * {@code iterations < 5} or when the no-arg constructor is used.
     */
    private int effectiveListenerFreq() {
        return Math.max(1, listenerFreq);
    }

    public int getHeight() {
        return height;
    }

    public void setHeight(int height) {
        this.height = height;
    }

    public int getWidth() {
        return width;
    }

    public void setWidth(int width) {
        this.width = width;
    }

    public int getChannels() {
        return channels;
    }

    public void setChannels(int channels) {
        this.channels = channels;
    }

    public int getOutputNum() {
        return outputNum;
    }

    public void setOutputNum(int outputNum) {
        this.outputNum = outputNum;
    }

    public long getSeed() {
        return seed;
    }

    public void setSeed(long seed) {
        this.seed = seed;
    }

    public int getIterations() {
        return iterations;
    }

    /** Sets the iteration count. Does not recompute {@code listenerFreq}. */
    public void setIterations(int iterations) {
        this.iterations = iterations;
    }

    public int getBatchSize() {
        return batchSize;
    }

    /** Sets the minibatch size. Does not recompute {@code splitTrainNum}. */
    public void setBatchSize(int batchSize) {
        this.batchSize = batchSize;
    }

    public int getnEpochs() {
        return nEpochs;
    }

    public void setnEpochs(int nEpochs) {
        this.nEpochs = nEpochs;
    }

    public int getSplitTrainNum() {
        return splitTrainNum;
    }

    public void setSplitTrainNum(int splitTrainNum) {
        this.splitTrainNum = splitTrainNum;
    }

    public int getListenerFreq() {
        return listenerFreq;
    }

    public void setListenerFreq(int listenerFreq) {
        this.listenerFreq = listenerFreq;
    }

    public String getDirectory() {
        return directory;
    }

    public void setDirectory(String directory) {
        this.directory = directory;
    }

    public int getNumExamples() {
        return numExamples;
    }

    public void setNumExamples(int numExamples) {
        this.numExamples = numExamples;
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment