Created
April 9, 2020 07:29
-
-
Save yptheangel/17562d4b621285742ce0b4afe779ce11 to your computer and use it in GitHub Desktop.
An example of DL4J transfer learning: smoker detection with YOLOv2
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/*
 *
 *  * ******************************************************************************
 *  *  * Copyright (c) 2019 Skymind AI Bhd.
 *  *  * Copyright (c) 2020 CertifAI Sdn. Bhd.
 *  *  *
 *  *  * This program and the accompanying materials are made available under the
 *  *  * terms of the Apache License, Version 2.0 which is available at
 *  *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *  *
 *  *  * Unless required by applicable law or agreed to in writing, software
 *  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  *  * License for the specific language governing permissions and limitations
 *  *  * under the License.
 *  *  *
 *  *  * SPDX-License-Identifier: Apache-2.0
 *  *  *****************************************************************************
 *
 *
 */
import org.bytedeco.javacv.CanvasFrame; | |
import org.bytedeco.javacv.Frame; | |
import org.bytedeco.javacv.OpenCVFrameConverter; | |
import org.bytedeco.opencv.opencv_core.Mat; | |
import org.bytedeco.opencv.opencv_core.Point; | |
import org.bytedeco.opencv.opencv_core.Scalar; | |
import org.bytedeco.opencv.opencv_core.Size; | |
import org.datavec.image.loader.NativeImageLoader; | |
import org.deeplearning4j.api.storage.StatsStorage; | |
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator; | |
import org.deeplearning4j.nn.api.OptimizationAlgorithm; | |
import org.deeplearning4j.nn.conf.ConvolutionMode; | |
import org.deeplearning4j.nn.conf.GradientNormalization; | |
import org.deeplearning4j.nn.conf.WorkspaceMode; | |
import org.deeplearning4j.nn.conf.inputs.InputType; | |
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer; | |
import org.deeplearning4j.nn.conf.layers.objdetect.Yolo2OutputLayer; | |
import org.deeplearning4j.nn.graph.ComputationGraph; | |
import org.deeplearning4j.nn.layers.objdetect.DetectedObject; | |
import org.deeplearning4j.nn.layers.objdetect.YoloUtils; | |
import org.deeplearning4j.nn.transferlearning.FineTuneConfiguration; | |
import org.deeplearning4j.nn.transferlearning.TransferLearning; | |
import org.deeplearning4j.nn.weights.WeightInit; | |
import org.deeplearning4j.optimize.listeners.ScoreIterationListener; | |
import org.deeplearning4j.ui.api.UIServer; | |
import org.deeplearning4j.ui.stats.StatsListener; | |
import org.deeplearning4j.ui.storage.InMemoryStatsStorage; | |
import org.deeplearning4j.util.ModelSerializer; | |
import org.deeplearning4j.zoo.model.TinyYOLO; | |
import org.deeplearning4j.zoo.model.YOLO2; | |
import org.nd4j.linalg.activations.Activation; | |
import org.nd4j.linalg.api.buffer.DataType; | |
import org.nd4j.linalg.api.ndarray.INDArray; | |
import org.nd4j.linalg.factory.Nd4j; | |
import org.nd4j.linalg.learning.config.Adam; | |
import org.slf4j.Logger; | |
import org.slf4j.LoggerFactory; | |
import java.io.File; | |
import java.util.List; | |
import static org.bytedeco.opencv.global.opencv_core.CV_8U; | |
import static org.bytedeco.opencv.global.opencv_imgproc.*; | |
import static org.bytedeco.opencv.helper.opencv_core.RGB; | |
public class SmokerDetectionYOLO2 { | |
private static final Logger log = LoggerFactory.getLogger(SmokerDetectionYOLO2.class); | |
private static int seed = 123; | |
private static double detectionThreshold = 0.3; | |
private static int nBoxes = 5; | |
// private static double lambdaNoObj = 0.5; | |
// private static double lambdaCoord = 5.0; | |
private static double lambdaNoObj = 0.5; | |
private static double lambdaCoord = 1.0; | |
// private static double[][] priorBoxes = {{1, 3}, {2.5, 6}, {3, 4}, {3.5, 8}, {4, 9}}; | |
private static double[][] priorBoxes ={{1.2872442749452366,0.4364610640580757},{0.77041163566991,0.7527746716574517},{2.1883878199554867,0.9689619659628125},{1.1871515452191663,1.8599969129812142},{2.7383268509632757,2.586608240477847}}; | |
// private static int batchSize = 16; | |
private static int batchSize = 8; | |
private static int nEpochs = 20; | |
private static double learningRate = 1e-4; | |
// private static int nClasses = 2; | |
private static int nClasses = 1; | |
private static List<String> labels; | |
private static File modelFilename = new File(System.getProperty("user.dir"), "SmokerDetector_yolo2.zip"); | |
private static ComputationGraph model; | |
private static Frame frame = null; | |
private static final Scalar GREEN = RGB(0, 255.0, 0); | |
private static final Scalar YELLOW = RGB(255, 255, 0); | |
private static Scalar[] colormap = {GREEN, YELLOW}; | |
private static String labeltext = null; | |
public static void main(String[] args) throws Exception { | |
// STEP 1 : Create iterators | |
SmokerDatasetIterator.setup(); | |
RecordReaderDataSetIterator trainIter = SmokerDatasetIterator.trainIterator(batchSize); | |
RecordReaderDataSetIterator testIter = SmokerDatasetIterator.testIterator(1); | |
labels = trainIter.getLabels(); | |
// If model does not exist, train the model, else directly go to model evaluation and then run real time object detection inference. | |
if (modelFilename.exists()) { | |
// STEP 2 : Load trained model from previous execution | |
Nd4j.getRandom().setSeed(seed); | |
log.info("Load model..."); | |
model = ModelSerializer.restoreComputationGraph(modelFilename); | |
} else { | |
Nd4j.getRandom().setSeed(seed); | |
INDArray priors = Nd4j.create(priorBoxes); | |
// STEP 2 : Train the model using Transfer Learning | |
// STEP 2.1: Transfer Learning steps - Load TinyYOLO prebuilt model. | |
log.info("Build model..."); | |
ComputationGraph pretrained = (ComputationGraph) YOLO2.builder().build().initPretrained(); | |
// STEP 2.2: Transfer Learning steps - Model Configurations. | |
FineTuneConfiguration fineTuneConf = getFineTuneConfiguration(); | |
// STEP 2.3: Transfer Learning steps - Modify prebuilt model's architecture | |
model = getComputationGraph(pretrained, priors, fineTuneConf); | |
System.out.println(model.summary(InputType.convolutional( | |
SmokerDatasetIterator.yoloheight, | |
SmokerDatasetIterator.yolowidth, | |
nClasses))); | |
// STEP 2.4: Training and Save model. | |
log.info("Train model..."); | |
UIServer server = UIServer.getInstance(); | |
StatsStorage storage = new InMemoryStatsStorage(); | |
server.attach(storage); | |
model.setListeners(new ScoreIterationListener(1), new StatsListener(storage)); | |
for (int i = 1; i < nEpochs + 1; i++) { | |
trainIter.reset(); | |
while (trainIter.hasNext()) { | |
model.fit(trainIter.next()); | |
} | |
log.info("*** Completed epoch {} ***", i); | |
} | |
// ModelSerializer.writeModel(model, modelFilename, true); | |
ModelSerializer.writeModel(model, modelFilename, false); | |
System.out.println("Model saved."); | |
} | |
// STEP 3: Evaluate the model's accuracy by using the test iterator. | |
OfflineValidationWithTestDataset(testIter); | |
// STEP 4: Inference the model and process the webcam stream and make predictions. | |
} | |
private static ComputationGraph getComputationGraph(ComputationGraph pretrained, INDArray priors, FineTuneConfiguration fineTuneConf) { | |
return new TransferLearning.GraphBuilder(pretrained) | |
.fineTuneConfiguration(fineTuneConf) | |
.removeVertexKeepConnections("conv2d_23") | |
.removeVertexKeepConnections("outputs") | |
.addLayer("conv2d_23", | |
new ConvolutionLayer.Builder(1, 1) | |
.nIn(1024) | |
.nOut(nBoxes * (5 + nClasses)) | |
.stride(1, 1) | |
.convolutionMode(ConvolutionMode.Same) | |
.weightInit(WeightInit.XAVIER) | |
.activation(Activation.IDENTITY) | |
.build(), | |
"leaky_re_lu_22") | |
.addLayer("outputs", | |
new Yolo2OutputLayer.Builder() | |
.lambdaNoObj(lambdaNoObj) | |
.lambdaCoord(lambdaCoord) | |
.boundingBoxPriors(priors.castTo(DataType.FLOAT)) | |
.build(), | |
"conv2d_23") | |
.setOutputs("outputs") | |
.build(); | |
} | |
private static FineTuneConfiguration getFineTuneConfiguration() { | |
return new FineTuneConfiguration.Builder() | |
.seed(seed) | |
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) | |
.gradientNormalization(GradientNormalization.RenormalizeL2PerLayer) | |
.gradientNormalizationThreshold(1.0) | |
.updater(new Adam.Builder().learningRate(learningRate).build()) | |
.l2(0.00001) | |
.activation(Activation.IDENTITY) | |
.trainingWorkspaceMode(WorkspaceMode.ENABLED) | |
.inferenceWorkspaceMode(WorkspaceMode.ENABLED) | |
.build(); | |
} | |
// Evaluate visually the performance of the trained object detection model | |
private static void OfflineValidationWithTestDataset(RecordReaderDataSetIterator test) throws InterruptedException { | |
NativeImageLoader imageLoader = new NativeImageLoader(); | |
CanvasFrame canvas = new CanvasFrame("Validate Test Dataset"); | |
OpenCVFrameConverter.ToMat converter = new OpenCVFrameConverter.ToMat(); | |
org.deeplearning4j.nn.layers.objdetect.Yolo2OutputLayer yout = (org.deeplearning4j.nn.layers.objdetect.Yolo2OutputLayer) model.getOutputLayer(0); | |
Mat convertedMat = new Mat(); | |
Mat convertedMat_big = new Mat(); | |
while (test.hasNext() && canvas.isVisible()) { | |
org.nd4j.linalg.dataset.DataSet ds = test.next(); | |
INDArray features = ds.getFeatures(); | |
INDArray results = model.outputSingle(features); | |
List<DetectedObject> objs = yout.getPredictedObjects(results, detectionThreshold); | |
YoloUtils.nms(objs, 0.4); | |
Mat mat = imageLoader.asMat(features); | |
mat.convertTo(convertedMat, CV_8U, 255, 0); | |
int w = mat.cols() * 2; | |
int h = mat.rows() * 2; | |
resize(convertedMat, convertedMat_big, new Size(w, h)); | |
convertedMat_big = drawResults(objs, convertedMat_big, w, h); | |
canvas.showImage(converter.convert(convertedMat_big)); | |
canvas.waitKey(); | |
} | |
canvas.dispose(); | |
} | |
private static Mat drawResults(List<DetectedObject> objects, Mat mat, int w, int h) { | |
for (DetectedObject obj : objects) { | |
double[] xy1 = obj.getTopLeftXY(); | |
double[] xy2 = obj.getBottomRightXY(); | |
String label = labels.get(obj.getPredictedClass()); | |
int x1 = (int) Math.round(w * xy1[0] / SmokerDatasetIterator.gridWidth); | |
int y1 = (int) Math.round(h * xy1[1] / SmokerDatasetIterator.gridHeight); | |
int x2 = (int) Math.round(w * xy2[0] / SmokerDatasetIterator.gridWidth); | |
int y2 = (int) Math.round(h * xy2[1] / SmokerDatasetIterator.gridHeight); | |
//Draw bounding box | |
rectangle(mat, new Point(x1, y1), new Point(x2, y2), colormap[obj.getPredictedClass()], 2, 0, 0); | |
//Display label text | |
labeltext = label + " " + String.format("%.2f", obj.getConfidence() * 100) + "%"; | |
int[] baseline = {0}; | |
Size textSize = getTextSize(labeltext, FONT_HERSHEY_DUPLEX, 1, 1, baseline); | |
rectangle(mat, new Point(x1 + 2, y2 - 2), new Point(x1 + 2 + textSize.get(0), y2 - 2 - textSize.get(1)), colormap[obj.getPredictedClass()], FILLED, 0, 0); | |
putText(mat, labeltext, new Point(x1 + 2, y2 - 2), FONT_HERSHEY_DUPLEX, 1, RGB(0, 0, 0)); | |
} | |
return mat; | |
} | |
} | |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment