Created June 9, 2019 16:32
-
-
Save wmalarski/b6a0f8c220b15f39fa74f67f215ad544 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
clear; clc;
% Based on the MathWorks example:
% https://www.mathworks.com/help/deeplearning/examples/train-deep-learning-network-to-classify-new-images.html
%% Loading monkey data
% Both datastores take each image's label from the name of the subfolder
% that contains it.
imdsValidation = imageDatastore('10-monkey-species/validation', ...
    'IncludeSubfolders', true, ...
    'LabelSource', 'foldernames');
imdsTrain = imageDatastore('10-monkey-species/training', ...
    'IncludeSubfolders', true, ...
    'LabelSource', 'foldernames');
%% Selecting the CNN model
% Pretrained ResNet-50 is the starting point for transfer learning.
net = resnet50;
% Input size of the network's first layer; only the first two elements
% (image height and width) are used later, to resize images.
firstLayer = net.Layers(1);
inputSize = firstLayer.InputSize;
% analyzeNetwork(net);
%% Extract the layer graph from the trained network
% A SeriesNetwork (such as AlexNet, VGG-16 or VGG-19) is converted from its
% flat layer array; any other network (e.g. a DAG network like ResNet-50)
% is passed to layerGraph directly.
if ~isa(net, 'SeriesNetwork')
    lgraph = layerGraph(net);
else
    lgraph = layerGraph(net.Layers);
end
% findLayersToReplace is a helper defined outside this file. It returns
% the learnable layer and the classification layer that are swapped out
% for new ones in the following sections.
[learnableLayer, classLayer] = findLayersToReplace(lgraph);
% No semicolon: display both layers for inspection.
[learnableLayer, classLayer]
%% Replace the final learnable (FullyConnected / Convolution2D) layer
% The replacement has one output per class in the training labels. Its
% weights and biases use 10x the global learning rate, so the new head
% adapts faster than the pretrained layers.
numClasses = numel(categories(imdsTrain.Labels));
if isa(learnableLayer, 'nnet.cnn.layer.FullyConnectedLayer')
    newLearnableLayer = fullyConnectedLayer(numClasses, ...
        'Name', 'new_fc', ...
        'WeightLearnRateFactor', 10, ...
        'BiasLearnRateFactor', 10);
elseif isa(learnableLayer, 'nnet.cnn.layer.Convolution2DLayer')
    % Some networks end in a 1x1 convolution instead of a fully connected layer.
    newLearnableLayer = convolution2dLayer(1, numClasses, ...
        'Name', 'new_conv', ...
        'WeightLearnRateFactor', 10, ...
        'BiasLearnRateFactor', 10);
else
    % Without this branch newLearnableLayer would be undefined and the
    % replaceLayer call below would fail with an unrelated-looking error.
    error('Unsupported learnable layer type: %s', class(learnableLayer));
end
lgraph = replaceLayer(lgraph, learnableLayer.Name, newLearnableLayer);
%% Replace the output (classification) layer
% The new classification layer is created without class names; they are
% determined from the training data when the network is trained.
newClassLayer = classificationLayer('Name', 'new_classoutput');
lgraph = replaceLayer(lgraph, classLayer.Name, newClassLayer);
%analyzeNetwork(lgraph);
%% Freeze layers
% Every layer except the last three is passed through freezeWeights, so
% only the tail of the network keeps training. For ResNet-50 the tail is
% presumably the new learnable layer, the softmax and the new output layer
% (NOTE(review): confirm against lgraph.Layers).
% freezeWeights and createLgraphUsingConnections are helpers defined outside
% this file. The graph is rebuilt from the modified layer array and the
% original connections.
layers = lgraph.Layers;
connections = lgraph.Connections;
numHeadLayers = 3;
frozenIdx = 1:numel(layers) - numHeadLayers;
layers(frozenIdx) = freezeWeights(layers(frozenIdx));
lgraph = createLgraphUsingConnections(layers, connections);
%% Augmented image datastores
% Training images get random left-right reflection, translation of up to
% +/-30 pixels along X and Y, and scaling by 0.9-1.1 along X and Y.
% Validation images are only resized. Both datastores resize images to the
% network's input height and width.
pixelRange = [-30 30];
scaleRange = [0.9 1.1];
imageAugmenter = imageDataAugmenter( ...
    'RandXReflection', true, ...
    'RandXTranslation', pixelRange, 'RandYTranslation', pixelRange, ...
    'RandXScale', scaleRange, 'RandYScale', scaleRange);
targetSize = inputSize(1:2);
augimdsTrain = augmentedImageDatastore(targetSize, imdsTrain, ...
    'DataAugmentation', imageAugmenter);
augimdsValidation = augmentedImageDatastore(targetSize, imdsValidation);
%% Training options
% SGD with momentum, mini-batches of 10, 6 epochs, initial learning rate
% 3e-4, data reshuffled every epoch, progress plotted instead of printed.
miniBatchSize = 10;
% Validate about once per epoch: trainNetwork discards the final partial
% mini-batch, so one epoch is floor(N / miniBatchSize) iterations.
% Clamp to at least 1, because with fewer training images than
% miniBatchSize the quotient is 0 and trainingOptions requires a positive
% ValidationFrequency.
valFrequency = max(1, floor(numel(augimdsTrain.Files) / miniBatchSize));
options = trainingOptions('sgdm', ...
    'MiniBatchSize', miniBatchSize, ...
    'MaxEpochs', 6, ...
    'InitialLearnRate', 3e-4, ...
    'Shuffle', 'every-epoch', ...
    'ValidationData', augimdsValidation, ...
    'ValidationFrequency', valFrequency, ...
    'Verbose', false, ...
    'Plots', 'training-progress');
%% Training and evaluation
trainedNet = trainNetwork(augimdsTrain, lgraph, options);

% Validation predictions (the validation datastore only resizes images).
[YPredValid, scoresValid] = classify(trainedNet, augimdsValidation);
% NOTE(review): label and prediction indices come from separate grp2idx
% calls, so they only line up if both categoricals share the same class
% list/order (expected when training and validation have the same class
% folders) - confirm before comparing them.
validationLabels = grp2idx(imdsValidation.Labels);
validationResult = grp2idx(YPredValid);

% Training-set predictions. augimdsTrain applies random reflection,
% translation and scaling, so classifying it would score randomly distorted
% copies that differ on every run, not the training images themselves.
% Use a resize-only datastore, the same way validation is handled.
augimdsTrainEval = augmentedImageDatastore(inputSize(1:2), imdsTrain);
[YPredTrain, scoresTrain] = classify(trainedNet, augimdsTrainEval);
trainLabels = grp2idx(imdsTrain.Labels);
trainResult = grp2idx(YPredTrain);
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment