Created
January 25, 2018 13:48
-
-
Save brlauuu/cab4396f48f368df755f9df06483827d to your computer and use it in GitHub Desktop.
Training CNN using Matlab R2017b NN and CV toolbox
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
% Train a small semantic-segmentation CNN on grayscale images and apply it
% to one test image. Written against the R2017b API (pixelLabelImageSource).
%
% Inputs read from disk (relative to the current folder):
%   training_data/  - training images (sub-folders are included)
%   label_data/     - label images; pixel value 255 = "good", 0 = "bad"
%   fil028.jpg      - test image
% Output written to disk:
%   mem_net.mat     - the workspace, including the trained network mem_net

% Load training data.
imageDir = fullfile('training_data');
labelDir = fullfile('label_data');

% Create an image datastore for the images.
imds = imageDatastore(imageDir, 'IncludeSubfolders',true, ...
    'LabelSource','foldernames');

% Create a pixelLabelDatastore for the ground truth pixel labels.
% classNames(k) is the class assigned to pixels whose value is labelIDs(k).
classNames = ["good","bad"];
labelIDs = [255 0];
pxds = pixelLabelDatastore(labelDir, classNames, labelIDs);

%%%%%% Visualize training images and ground truth pixel labels.
% Show the first image next to its label map, both enlarged 5x.
I = read(imds);
C = read(pxds);
figure
I = imresize(I,5);
L = imresize(uint8(C),5);
imshowpair(I,L,'montage')

% Create a semantic segmentation network. This network uses a simple
% semantic segmentation network based on a downsampling and upsampling
% design.
numFilters = 40;
filterSize = 15;

% The final 1x1 convolution needs one channel per *pixel* class, so count
% the classes given to pixelLabelDatastore. (The folder-name labels of imds
% describe whole images and need not match the pixel classes.)
numClasses = numel(classNames);

% NOTE(review): with a 64x64 input and zero padding the spatial size works
% out to 64 -> 50 -> 25 -> 11 -> 5 (pooling, rounded down) -> 23 after the
% transposed convolution, i.e. not 64x64. Confirm this matches the label
% size; adjust padding/cropping if training reports a size mismatch.
layers = [
    imageInputLayer([64 64 1])
    convolution2dLayer(filterSize,numFilters,'Padding',0)
    maxPooling2dLayer(2,'Stride',2)
    convolution2dLayer(filterSize,numFilters/2,'Padding',0)
    maxPooling2dLayer(2,'Stride',2)
    reluLayer()
    transposedConv2dLayer(filterSize,numFilters,'Stride',2)
    convolution2dLayer(1,numClasses)
    softmaxLayer()
    pixelClassificationLayer()
    ]

% Setup training options.
opts = trainingOptions('sgdm', ...
    'InitialLearnRate', 1e-3, ...
    'MaxEpochs', 50, ...
    'MiniBatchSize', 100);

% Create a data source for training data.
trainingData = pixelLabelImageSource(imds,pxds);

% Train the network.
net = trainNetwork(trainingData,layers,opts);

% Keep the trained network under the name mem_net and write the current
% workspace (all variables, mem_net included) to mem_net.mat.
mem_net = net
save mem_net

% Read and display a test image.
testImage = imread('fil028.jpg');
figure
imshow(testImage)

% Segment the test image and display the results.
%C = semanticseg(testImage,net);
%B = labeloverlay(testImage,C);
%figure
%imshow(B)

%%%%%% Improve the results if necessary
% If the network fails to properly segment the image and classifies every
% pixel as the same class, even though training appeared to be going well
% with training accuracies greater than 90%, it has probably only learned
% to predict the dominant class. To understand why this happens, you
% can count the occurrence of each pixel label across the dataset.
% When the majority of pixel labels belong to one class, the poor results
% are due to class imbalance. Class imbalance biases the learning process
% in favor of the dominant class. That's why every pixel is classified as
% that class. To fix this, use class weighting to balance the classes.
% There are several methods for computing class weights. One common method
% is inverse frequency weighting where the class weights are the inverse of
% the class frequencies. This increases weight given to under-represented
% classes.
%tbl = countEachLabel(trainingData)
%totalNumberOfPixels = sum(tbl.PixelCount);
%frequency = tbl.PixelCount / totalNumberOfPixels;
%classWeights = 1./frequency
%layers(end) = pixelClassificationLayer('ClassNames',tbl.Name,'ClassWeights',classWeights);
%net = trainNetwork(trainingData,layers,opts);

% Segment the test image with the trained network and overlay the
% predicted labels on the image.
C = semanticseg(testImage,net);
B = labeloverlay(testImage,C);
figure
imshow(B)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment