Created
October 16, 2019 09:29
-
-
Save choowilson/1272bb2d94f230c089871e8d4ecaaedd to your computer and use it in GitHub Desktop.
Research to check the effect of ImageTransform on the dataset's original labels/annotations.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import net.lingala.zip4j.core.ZipFile; | |
import net.lingala.zip4j.exception.ZipException; | |
import org.bytedeco.javacv.CanvasFrame; | |
import org.bytedeco.javacv.Frame; | |
import org.bytedeco.javacv.FrameGrabber; | |
import org.bytedeco.javacv.OpenCVFrameConverter; | |
import org.bytedeco.opencv.opencv_core.Mat; | |
import org.bytedeco.opencv.opencv_core.Point; | |
import org.bytedeco.opencv.opencv_core.Scalar; | |
import org.bytedeco.opencv.opencv_core.Size; | |
import org.datavec.api.split.FileSplit; | |
import org.datavec.image.loader.NativeImageLoader; | |
//import org.datavec.image.recordreader.objdetect.ObjectDetectionRecordReader; | |
import org.datavec.image.transform.BoxImageTransform; | |
import org.datavec.image.transform.ColorConversionTransform; | |
import org.deeplearning4j.api.storage.StatsStorage; | |
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator; | |
import org.deeplearning4j.nn.api.OptimizationAlgorithm; | |
import org.deeplearning4j.nn.conf.ConvolutionMode; | |
import org.deeplearning4j.nn.conf.GradientNormalization; | |
import org.deeplearning4j.nn.conf.WorkspaceMode; | |
import org.deeplearning4j.nn.conf.inputs.InputType; | |
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer; | |
import org.deeplearning4j.nn.conf.layers.objdetect.Yolo2OutputLayer; | |
import org.deeplearning4j.nn.graph.ComputationGraph; | |
import org.deeplearning4j.nn.layers.objdetect.DetectedObject; | |
import org.deeplearning4j.nn.transferlearning.FineTuneConfiguration; | |
import org.deeplearning4j.nn.transferlearning.TransferLearning; | |
import org.deeplearning4j.nn.weights.WeightInit; | |
import org.deeplearning4j.optimize.listeners.ScoreIterationListener; | |
import org.deeplearning4j.ui.stats.StatsListener; | |
import org.deeplearning4j.ui.storage.InMemoryStatsStorage; | |
import org.deeplearning4j.util.ModelSerializer; | |
import org.deeplearning4j.zoo.model.TinyYOLO; | |
import org.nd4j.linalg.activations.Activation; | |
import org.nd4j.linalg.api.buffer.DataType; | |
import org.nd4j.linalg.api.ndarray.INDArray; | |
import org.nd4j.linalg.dataset.api.preprocessor.ImagePreProcessingScaler; | |
import org.nd4j.linalg.factory.Nd4j; | |
import org.nd4j.linalg.io.ClassPathResource; | |
import java.io.File; | |
import java.io.IOException; | |
import java.util.List; | |
import java.util.Random; | |
import org.nd4j.linalg.learning.config.Adam; | |
import org.slf4j.Logger; | |
import org.slf4j.LoggerFactory; | |
import org.deeplearning4j.ui.api.UIServer; | |
import static org.bytedeco.opencv.global.opencv_core.CV_8U; | |
import static org.bytedeco.opencv.global.opencv_core.flip; | |
import static org.bytedeco.opencv.global.opencv_imgproc.*; | |
import java.awt.event.KeyEvent; | |
import org.deeplearning4j.nn.layers.objdetect.YoloUtils.*; | |
public class RealTimeABDetector { | |
private static final Logger log = LoggerFactory.getLogger(RealTimeABDetector.class); | |
private static int nChannels = 3; | |
private static final int gridWidth = 13; | |
private static final int gridHeight = 13; | |
private static double detectionThreshold = 0.1; | |
private static final int tinyyolowidth = 416; | |
private static final int tinyyoloheight = 416; | |
private static int nBoxes = 5; | |
private static double lambdaNoObj = 0.5; | |
private static double lambdaCoord = 1.0; | |
private static double[][] priorBoxes = {{2, 5}, {2.5, 6}, {3, 7}, {3.5, 8}, {4, 9}}; | |
private static int batchSize = 8; | |
// private static int nEpochs = 40; | |
private static int nEpochs = 10; | |
private static double learningRate = 1e-4; | |
private static int nClasses = 2; | |
private static List<String> labels; | |
private static int seed = 100; | |
private static Random rng = new Random(seed); | |
private static File modelFilename = new File(System.getProperty("user.dir"),"generated-models/Avocado_Banana_Detector.zip"); | |
private static ComputationGraph model; | |
private static Frame frame = null; | |
public static void main(String[] args) throws Exception { | |
unzipAllDataSet(); | |
File trainDir = new File(System.getProperty("user.home"), ".deeplearning4j/data/fruits/train/"); | |
File testDir = new File(System.getProperty("user.home"), ".deeplearning4j/data/fruits/test/"); | |
log.info("Load data..."); | |
FileSplit trainData = new FileSplit(trainDir, NativeImageLoader.ALLOWED_FORMATS, rng); | |
FileSplit testData = new FileSplit(testDir, NativeImageLoader.ALLOWED_FORMATS, rng); | |
// ObjectDetectionRecordReader recordReaderTrain = new ObjectDetectionRecordReader(tinyyoloheight, tinyyolowidth, nChannels, | |
// gridHeight, gridWidth, new LabelImgXmlLabelProvider(trainDir), new BoxImageTransform(tinyyoloheight,tinyyolowidth)); | |
// recordReaderTrain.initialize(trainData); | |
// ObjectDetectionRecordReader recordReaderTest = new ObjectDetectionRecordReader(tinyyoloheight, tinyyolowidth, nChannels, | |
// gridHeight, gridWidth, new LabelImgXmlLabelProvider(testDir), new BoxImageTransform(tinyyoloheight,tinyyolowidth)); | |
// recordReaderTest.initialize(testData); | |
// ObjectDetectionRRCustom recordReaderTrain = new ObjectDetectionRRCustom(tinyyoloheight, tinyyolowidth, nChannels, | |
// gridHeight, gridWidth, new LabelImgXmlLabelProvider(trainDir), new BoxImageTransform(tinyyoloheight,tinyyolowidth)); | |
// recordReaderTrain.initialize(trainData); | |
// ObjectDetectionRRCustom recordReaderTest = new ObjectDetectionRRCustom(tinyyoloheight, tinyyolowidth, nChannels, | |
// gridHeight, gridWidth, new LabelImgXmlLabelProvider(testDir), new BoxImageTransform(tinyyoloheight,tinyyolowidth)); | |
// recordReaderTest.initialize(testData); | |
ObjectDetectionRRCustom recordReaderTrain = new ObjectDetectionRRCustom(tinyyoloheight, tinyyolowidth, nChannels, | |
gridHeight, gridWidth, new LabelImgXmlLabelProvider(trainDir)); | |
recordReaderTrain.initialize(trainData); | |
ObjectDetectionRRCustom recordReaderTest = new ObjectDetectionRRCustom(tinyyoloheight, tinyyolowidth, nChannels, | |
gridHeight, gridWidth, new LabelImgXmlLabelProvider(testDir)); | |
recordReaderTest.initialize(testData); | |
// | |
// RecordReaderDataSetIterator train = new RecordReaderDataSetIterator(recordReaderTrain, batchSize, 1, 1, true); | |
RecordReaderDataSetIterator train = new RecordReaderDataSetIterator(recordReaderTrain, 1, 1, 1, true); | |
train.setPreProcessor(new ImagePreProcessingScaler(0, 1)); | |
RecordReaderDataSetIterator test = new RecordReaderDataSetIterator(recordReaderTest, 1, 1, 1, true); | |
test.setPreProcessor(new ImagePreProcessingScaler(0, 1)); | |
// | |
// // Print Labels | |
labels = train.getLabels(); | |
// System.out.println(labels); | |
// | |
// //If model already exist, evaluate it and then run real time object detection inference, else train the model. | |
if (modelFilename.exists()) { | |
// Load trained model from previous execution | |
Nd4j.getRandom().setSeed(seed); | |
log.info("Load model..."); | |
model = ModelSerializer.restoreComputationGraph(modelFilename); | |
} else { | |
Nd4j.getRandom().setSeed(seed); | |
ComputationGraph pretrained = null; | |
FineTuneConfiguration fineTuneConf = null; | |
INDArray priors = Nd4j.create(priorBoxes); | |
/* STEP 1: Transfer Learning steps - Load TinyYOLO prebuilt model. */ | |
log.info("Build model..."); | |
pretrained = (ComputationGraph)TinyYOLO.builder().build().initPretrained(); | |
/* STEP 2: Transfer Learning steps - Model Configurations. */ | |
fineTuneConf = getFineTuneConfiguration(); | |
/* STEP 3: Transfer Learning steps - Modify prebuilt model's architecture */ | |
model = getNewComputationGraph(pretrained, priors, fineTuneConf); | |
System.out.println(model.summary(InputType.convolutional(tinyyoloheight, tinyyolowidth, nClasses))); | |
/* STEP 4: Training and Save model. */ | |
log.info("Train model..."); | |
UIServer server = UIServer.getInstance(); | |
StatsStorage storage = new InMemoryStatsStorage(); | |
server.attach(storage); | |
model.setListeners(new ScoreIterationListener(1), new StatsListener(storage)); | |
for (int i = 1; i < nEpochs+1; i++) { | |
train.reset(); | |
while (train.hasNext()) { | |
model.fit(train.next()); | |
} | |
log.info("*** Completed epoch {} ***", i); | |
} | |
ModelSerializer.writeModel(model, modelFilename, true); | |
System.out.println("Model saved."); | |
} | |
// /* STEP 5: Perform offline validation with Test data. */ | |
// OfflineValidationWithTestDataset(test); | |
OfflineValidationWithTestDataset(train); | |
//// doInference(); | |
} | |
private static ComputationGraph getNewComputationGraph(ComputationGraph pretrained, INDArray priors, FineTuneConfiguration fineTuneConf) { | |
ComputationGraph _ComputationGraph = new TransferLearning.GraphBuilder(pretrained) | |
.fineTuneConfiguration(fineTuneConf) | |
.removeVertexKeepConnections("conv2d_9") | |
.removeVertexKeepConnections("outputs") | |
.addLayer("convolution2d_9", | |
new ConvolutionLayer.Builder(1, 1) | |
.nIn(1024) | |
.nOut(nBoxes * (5 + nClasses)) | |
.stride(1, 1) | |
.convolutionMode(ConvolutionMode.Same) | |
.weightInit(WeightInit.XAVIER) | |
.activation(Activation.IDENTITY) | |
.build(), | |
"leaky_re_lu_8") | |
.addLayer("outputs", | |
new Yolo2OutputLayer.Builder() | |
.lambbaNoObj(lambdaNoObj) | |
.lambdaCoord(lambdaCoord) | |
.boundingBoxPriors(priors.castTo(DataType.FLOAT)) | |
.build(), | |
"convolution2d_9") | |
.setOutputs("outputs") | |
.build(); | |
return _ComputationGraph; | |
} | |
private static FineTuneConfiguration getFineTuneConfiguration() { | |
FineTuneConfiguration _FineTuneConfiguration = new FineTuneConfiguration.Builder() | |
.seed(seed) | |
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) | |
.gradientNormalization(GradientNormalization.RenormalizeL2PerLayer) | |
.gradientNormalizationThreshold(1.0) | |
.updater(new Adam.Builder().learningRate(learningRate).build()) | |
.l2(0.00001) | |
.activation(Activation.IDENTITY) | |
.trainingWorkspaceMode(WorkspaceMode.ENABLED) | |
.inferenceWorkspaceMode(WorkspaceMode.ENABLED) | |
.build(); | |
return _FineTuneConfiguration; | |
} | |
// Manually Evaluate the performance of the object detection model | |
private static void OfflineValidationWithTestDataset(RecordReaderDataSetIterator test)throws InterruptedException{ | |
System.out.println("Start validation"); | |
NativeImageLoader imageLoader = new NativeImageLoader(); | |
CanvasFrame canvas = new CanvasFrame("Validate Test Dataset"); | |
OpenCVFrameConverter.ToMat converter = new OpenCVFrameConverter.ToMat(); | |
org.deeplearning4j.nn.layers.objdetect.Yolo2OutputLayer yout = (org.deeplearning4j.nn.layers.objdetect.Yolo2OutputLayer)model.getOutputLayer(0); | |
Mat convertedMat = new Mat(); | |
Mat convertedMat_big = new Mat(); | |
while (test.hasNext() && canvas.isVisible()) { | |
// while (train.hasNext() && canvas.isVisible()) { | |
org.nd4j.linalg.dataset.DataSet ds = test.next(); | |
INDArray features = ds.getFeatures(); | |
INDArray results = model.outputSingle(features); | |
List<DetectedObject> objs = yout.getPredictedObjects(results, detectionThreshold); | |
List<DetectedObject> objects = NonMaxSuppression.getObjects(objs); | |
Mat mat = imageLoader.asMat(features); | |
mat.convertTo(convertedMat, CV_8U, 255, 0); | |
// int w = mat.cols() * 2; | |
// int h = mat.rows() * 2; | |
int w = mat.cols(); | |
int h = mat.rows(); | |
resize(convertedMat,convertedMat_big, new Size(w, h)); | |
for (DetectedObject obj : objects) { | |
double[] xy1 = obj.getTopLeftXY(); | |
double[] xy2 = obj.getBottomRightXY(); | |
String label = labels.get(obj.getPredictedClass()); | |
int x1 = (int) Math.round(w * xy1[0] / gridWidth); | |
int y1 = (int) Math.round(h * xy1[1] / gridHeight); | |
int x2 = (int) Math.round(w * xy2[0] / gridWidth); | |
int y2 = (int) Math.round(h * xy2[1] / gridHeight); | |
rectangle(convertedMat_big, new Point(x1, y1), new Point(x2, y2), Scalar.RED, 2, 0, 0); | |
putText(convertedMat_big, label, new Point(x1 + 2, y2 - 2), FONT_HERSHEY_DUPLEX, 1, Scalar.GREEN); | |
} | |
canvas.showImage(converter.convert(convertedMat_big)); | |
canvas.waitKey(); | |
} | |
canvas.dispose(); | |
} | |
// | |
// // Stream video frames from Webcam and run them through TinyYOLO model and get predictions | |
// private static void doInference(){ | |
// | |
// String cameraPos = "front"; | |
// int cameraNum = 0; | |
// Thread thread = null; | |
// NativeImageLoader loader = new NativeImageLoader(tinyyolowidth, tinyyoloheight, 3, new ColorConversionTransform(COLOR_BGR2RGB)); | |
// ImagePreProcessingScaler scaler = new ImagePreProcessingScaler(0, 1); | |
// | |
// if( !cameraPos.equals("front") && !cameraPos.equals("back") ) | |
// { | |
// try { | |
// throw new Exception("Unknown argument for camera position. Choose between front and back"); | |
// } catch (Exception e) { | |
// e.printStackTrace(); | |
// } | |
// } | |
// | |
// FrameGrabber grabber = null; | |
// try { | |
// grabber = FrameGrabber.createDefault(cameraNum); | |
// } catch (FrameGrabber.Exception e) { | |
// e.printStackTrace(); | |
// } | |
// OpenCVFrameConverter.ToMat converter = new OpenCVFrameConverter.ToMat(); | |
// | |
// try { | |
// grabber.start(); | |
// } catch (FrameGrabber.Exception e) { | |
// e.printStackTrace(); | |
// } | |
// | |
// String winName = "Object Detection"; | |
// CanvasFrame canvas = new CanvasFrame(winName); | |
// | |
// int w = grabber.getImageWidth(); | |
// int h = grabber.getImageHeight(); | |
// | |
// | |
// canvas.setCanvasSize(w, h); | |
// while (true) | |
// { | |
// try { | |
// frame = grabber.grab(); | |
// } catch (FrameGrabber.Exception e) { | |
// e.printStackTrace(); | |
// } | |
// | |
// //if a thread is null, create new thread | |
// if (thread == null) | |
// { | |
// thread = new Thread(() -> | |
// { | |
// while (frame != null) | |
// { | |
// try | |
// { | |
// Mat rawImage = new Mat(); | |
// | |
// //Flip the camera if opening front camera | |
// if(cameraPos.equals("front")) | |
// { | |
// Mat inputImage = converter.convert(frame); | |
// flip(inputImage, rawImage, 1); | |
// } | |
// else | |
// { | |
// rawImage = converter.convert(frame); | |
// } | |
// | |
// Mat resizeImage = new Mat(); | |
// resize(rawImage, resizeImage, new Size(tinyyolowidth, tinyyoloheight)); | |
// | |
// INDArray inputImage = loader.asMatrix(resizeImage); | |
// scaler.transform(inputImage); | |
// INDArray outputs = model.outputSingle(inputImage); | |
// org.deeplearning4j.nn.layers.objdetect.Yolo2OutputLayer yout = (org.deeplearning4j.nn.layers.objdetect.Yolo2OutputLayer)model.getOutputLayer(0); | |
// List<DetectedObject> objs = yout.getPredictedObjects(outputs, detectionThreshold); | |
// List<DetectedObject> objects = NonMaxSuppression.getObjects(objs); | |
// | |
// for (DetectedObject obj : objects) { | |
// | |
// double[] xy1 = obj.getTopLeftXY(); | |
// double[] xy2 = obj.getBottomRightXY(); | |
// String label = labels.get(obj.getPredictedClass()); | |
// int x1 = (int) Math.round(w * xy1[0] / gridWidth); | |
// int y1 = (int) Math.round(h * xy1[1] / gridHeight); | |
// int x2 = (int) Math.round(w * xy2[0] / gridWidth); | |
// int y2 = (int) Math.round(h * xy2[1] / gridHeight); | |
// rectangle(rawImage, new Point(x1, y1), new Point(x2, y2), Scalar.RED, 2, 0, 0); | |
// putText(rawImage, label, new Point(x1 + 2, y2 - 2), FONT_HERSHEY_DUPLEX, 1, Scalar.GREEN); | |
// } | |
// canvas.showImage(converter.convert(rawImage)); | |
// } | |
// catch (Exception e) | |
// { | |
// throw new RuntimeException(e); | |
// } | |
// } | |
// }); | |
// thread.start(); | |
// } | |
// | |
// KeyEvent t = null; | |
// try { | |
// t = canvas.waitKey(33); | |
// } catch (InterruptedException e) { | |
// e.printStackTrace(); | |
// } | |
// | |
// if ((t != null) && (t.getKeyCode() == KeyEvent.VK_Q)) { | |
// break; | |
// } | |
// } | |
// } | |
// | |
//To unzip the training and test datset | |
public static void unzip(String source, String destination){ | |
try { | |
ZipFile zipFile = new ZipFile(source); | |
zipFile.extractAll(destination); | |
} catch (ZipException e) { | |
e.printStackTrace(); | |
} | |
} | |
public static void unzipAllDataSet(){ | |
//unzip training data set | |
File resourceDir = new File(System.getProperty("user.home"), ".deeplearning4j/data/fruits"); | |
if (!resourceDir.exists()) resourceDir.mkdirs(); | |
String zipTrainFilePath = null; | |
String zipTestFilePath = null; | |
try { | |
zipTrainFilePath = new ClassPathResource("fruits/train.zip").getFile().toString(); | |
zipTestFilePath = new ClassPathResource("fruits/test.zip").getFile().toString(); | |
System.out.println(zipTrainFilePath); | |
System.out.println(zipTestFilePath); | |
} catch (IOException e) { | |
e.printStackTrace(); | |
} | |
File trainFolder = new File(resourceDir+"/train"); | |
if (!trainFolder.exists()) unzip(zipTrainFilePath, resourceDir.toString()); | |
System.out.println("unziptrain done"); | |
System.out.println(trainFolder); | |
File testFolder = new File(resourceDir+"/test"); | |
if (!testFolder.exists()) unzip(zipTestFilePath, resourceDir.toString()); | |
System.out.println(testFolder); | |
System.out.println("unziptest done"); | |
} | |
} | |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment