-
-
Save treo/aa30149afd89a0ece035e3876177bf29 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/* | |
* To change this license header, choose License Headers in Project Properties. | |
* To change this template file, choose Tools | Templates | |
* and open the template in the editor. | |
*/ | |
package com.ldt.plateregconition; | |
import java.io.DataInputStream; | |
import java.io.DataOutputStream; | |
import java.io.File; | |
import java.io.FileInputStream; | |
import java.io.IOException; | |
import java.io.OutputStream; | |
import java.nio.file.Files; | |
import java.nio.file.Paths; | |
import java.util.Arrays; | |
import java.util.ArrayList; | |
import java.util.List; | |
import java.util.Random; | |
import org.apache.commons.io.FileUtils; | |
import org.deeplearning4j.datasets.iterator.DataSetIterator; | |
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; | |
import org.deeplearning4j.eval.Evaluation; | |
import org.deeplearning4j.nn.api.OptimizationAlgorithm; | |
import org.deeplearning4j.nn.conf.MultiLayerConfiguration; | |
import org.deeplearning4j.nn.conf.NeuralNetConfiguration; | |
import org.deeplearning4j.nn.conf.Updater; | |
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer; | |
import org.deeplearning4j.nn.conf.layers.DenseLayer; | |
import org.deeplearning4j.nn.conf.layers.OutputLayer; | |
import org.deeplearning4j.nn.conf.layers.SubsamplingLayer; | |
import org.deeplearning4j.nn.conf.layers.setup.ConvolutionLayerSetup; | |
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; | |
import org.deeplearning4j.nn.weights.WeightInit; | |
import org.deeplearning4j.optimize.api.IterationListener; | |
import org.deeplearning4j.optimize.listeners.ScoreIterationListener; | |
import org.nd4j.linalg.api.ndarray.INDArray; | |
import org.nd4j.linalg.dataset.DataSet; | |
import org.nd4j.linalg.dataset.SplitTestAndTrain; | |
import org.nd4j.linalg.factory.Nd4j; | |
import org.nd4j.linalg.lossfunctions.LossFunctions; | |
import org.slf4j.Logger; | |
import org.slf4j.LoggerFactory; | |
public class LicensePlatePCNN { | |
private static final Logger log = LoggerFactory.getLogger(LicensePlatePCNN.class); | |
private int height; | |
private int width; | |
private int channels; | |
private int outputNum; | |
private long seed; | |
private int numExamples; | |
private int iterations; | |
private int batchSize; | |
private int nEpochs; | |
private int splitTrainNum; | |
private int listenerFreq; | |
private String directory; | |
public static String FILE_PARAMS = "params.bin"; | |
public static String FILE_MODEL = "model.json"; | |
public LicensePlatePCNN() { | |
} | |
public LicensePlatePCNN(int height, int width, int channels, int outputNum, long seed, int iterations, int batchSize, int nEpochs) { | |
this.height = height; | |
this.width = width; | |
this.channels = channels; | |
this.outputNum = outputNum; | |
this.seed = seed; | |
this.iterations = iterations; | |
this.nEpochs = nEpochs; | |
this.batchSize = batchSize; | |
splitTrainNum = (int) (batchSize * 0.8); | |
listenerFreq = iterations / 5; | |
} | |
public MultiLayerNetwork init() { | |
log.info("Build model......"); | |
MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder() | |
.seed(seed) | |
.iterations(iterations) | |
.learningRate(0.001) | |
.weightInit(WeightInit.XAVIER) | |
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) | |
.updater(Updater.NESTEROVS).momentum(0.9) | |
.list(4) | |
.layer(0, new ConvolutionLayer.Builder(5, 5) | |
.nIn(channels) | |
.stride(1, 1) | |
.nOut(50) | |
.activation("relu") | |
.build()) | |
.layer(1, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX) | |
.stride(2, 2) | |
.kernelSize(2, 2) | |
.build()) | |
.layer(2, new DenseLayer.Builder().activation("relu") | |
.nIn(20 * 2 * 4) | |
.nOut(1000) | |
.build()) | |
.layer(3, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD) | |
.nOut(outputNum) | |
.activation("softmax") | |
.build()) | |
.backprop(true).pretrain(false); | |
new ConvolutionLayerSetup(builder, height, width, channels); | |
MultiLayerConfiguration conf = builder.build(); | |
MultiLayerNetwork model = new MultiLayerNetwork(conf); | |
model.init(); | |
return model; | |
} | |
public void trainModel(DataSetIterator iterator, MultiLayerNetwork model) { | |
log.info("Train model......"); | |
model.setListeners(Arrays.asList((IterationListener) new ScoreIterationListener(listenerFreq))); | |
List<DataSet> featuresTest = new ArrayList<>(); | |
List<DataSet> featuresTrain = new ArrayList<>(); | |
log.info("Get Data"); | |
while (iterator.hasNext()) { | |
DataSet ds = iterator.next(); | |
ds.normalizeZeroMeanZeroUnitVariance(); | |
ds.shuffle(); | |
SplitTestAndTrain split = ds.splitTestAndTrain(splitTrainNum, new Random(seed)); | |
DataSet train = split.getTrain(); | |
featuresTrain.add(train); | |
DataSet dsTest = split.getTest(); | |
featuresTest.add(dsTest); | |
} | |
log.info("Num of train minibatch: " + featuresTrain.size()); | |
for (int i = 0; i < nEpochs; i++) { | |
log.info("*** Begin epoch {} ***", i); | |
for(DataSet s : featuresTrain){ | |
model.fit(s); | |
} | |
log.info("*** Completed epoch {} ***", i); | |
log.info("Evaluate model...."); | |
Evaluation eval = new Evaluation(outputNum); | |
for (DataSet ds1 : featuresTest) { | |
INDArray output = model.output(ds1.getFeatureMatrix()); | |
eval.eval(ds1.getLabels(), output); | |
} | |
log.info(eval.stats()); | |
} | |
log.info("******************Finished********************"); | |
} | |
public void trainModelWithMNIST(MultiLayerNetwork model) throws IOException { | |
log.info("Train model......"); | |
log.info("Load data...."); | |
DataSetIterator mnistTrain = new MnistDataSetIterator(batchSize, true, 12345); | |
DataSetIterator mnistTest = new MnistDataSetIterator(batchSize, false, 12345); | |
log.info("Train model...."); | |
model.setListeners(new ScoreIterationListener(1)); | |
for (int i = 0; i < nEpochs; i++) { | |
model.fit(mnistTrain); | |
log.info("*** Completed epoch {} ***", i); | |
log.info("Evaluate model...."); | |
Evaluation eval = new Evaluation(outputNum); | |
while (mnistTest.hasNext()) { | |
DataSet ds = mnistTest.next(); | |
INDArray output = model.output(ds.getFeatureMatrix()); | |
eval.eval(ds.getLabels(), output); | |
} | |
log.info(eval.stats()); | |
mnistTest.reset(); | |
} | |
log.info("****************Example finished********************"); | |
} | |
public void predict(INDArray input, MultiLayerNetwork model) { | |
INDArray output = model.output(input); | |
output = Nd4j.argMax(output, 1); | |
System.out.println(output.getDouble(0)); | |
} | |
public void saveModel(MultiLayerNetwork model) throws IOException { | |
log.info("Save model............"); | |
OutputStream fos = Files.newOutputStream(Paths.get(FILE_PARAMS)); | |
DataOutputStream dos = new DataOutputStream(fos); | |
Nd4j.write(model.params(), dos); | |
dos.flush(); | |
dos.close(); | |
FileUtils.write(new File(FILE_MODEL), model.getLayerWiseConfigurations().toJson()); | |
} | |
public MultiLayerNetwork loadModel() throws IOException { | |
log.info("Load model..........."); | |
MultiLayerConfiguration confFromJson = MultiLayerConfiguration.fromJson(FileUtils.readFileToString(new File(FILE_MODEL))); | |
DataInputStream dis = new DataInputStream(new FileInputStream(FILE_PARAMS)); | |
INDArray newParams = Nd4j.read(dis); | |
dis.close(); | |
MultiLayerNetwork savedNetwork = new MultiLayerNetwork(confFromJson); | |
savedNetwork.init(); | |
savedNetwork.setParameters(newParams); | |
savedNetwork.setListeners(Arrays.asList((IterationListener) new ScoreIterationListener(listenerFreq))); | |
return savedNetwork; | |
} | |
public int getHeight() { | |
return height; | |
} | |
public void setHeight(int height) { | |
this.height = height; | |
} | |
public int getWidth() { | |
return width; | |
} | |
public void setWidth(int width) { | |
this.width = width; | |
} | |
public int getChannels() { | |
return channels; | |
} | |
public void setChannels(int channels) { | |
this.channels = channels; | |
} | |
public int getOutputNum() { | |
return outputNum; | |
} | |
public void setOutputNum(int outputNum) { | |
this.outputNum = outputNum; | |
} | |
public long getSeed() { | |
return seed; | |
} | |
public void setSeed(long seed) { | |
this.seed = seed; | |
} | |
public int getIterations() { | |
return iterations; | |
} | |
public void setIterations(int iterations) { | |
this.iterations = iterations; | |
} | |
public int getBatchSize() { | |
return batchSize; | |
} | |
public void setBatchSize(int batchSize) { | |
this.batchSize = batchSize; | |
} | |
public int getnEpochs() { | |
return nEpochs; | |
} | |
public void setnEpochs(int nEpochs) { | |
this.nEpochs = nEpochs; | |
} | |
public int getSplitTrainNum() { | |
return splitTrainNum; | |
} | |
public void setSplitTrainNum(int splitTrainNum) { | |
this.splitTrainNum = splitTrainNum; | |
} | |
public int getListenerFreq() { | |
return listenerFreq; | |
} | |
public void setListenerFreq(int listenerFreq) { | |
this.listenerFreq = listenerFreq; | |
} | |
public String getDirectory() { | |
return directory; | |
} | |
public void setDirectory(String directory) { | |
this.directory = directory; | |
} | |
public int getNumExamples() { | |
return numExamples; | |
} | |
public void setNumExamples(int numExamples) { | |
this.numExamples = numExamples; | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment