Skip to content

Instantly share code, notes, and snippets.

@wmalarski
Created June 9, 2019 16:32
Show Gist options
  • Save wmalarski/b6a0f8c220b15f39fa74f67f215ad544 to your computer and use it in GitHub Desktop.
Save wmalarski/b6a0f8c220b15f39fa74f67f215ad544 to your computer and use it in GitHub Desktop.
clear; clc;
% Based on the MathWorks example:
% https://www.mathworks.com/help/deeplearning/examples/train-deep-learning-network-to-classify-new-images.html
%% Load the monkey-species images
% Both splits are read recursively; every image is labelled with the name of
% the sub-folder that contains it.
datastoreOpts = {'LabelSource','foldernames', 'IncludeSubfolders',true};
imdsTrain = imageDatastore('10-monkey-species/training', datastoreOpts{:});
imdsValidation = imageDatastore('10-monkey-species/validation', datastoreOpts{:});
%% Choose the pretrained CNN
% ResNet-50 is the backbone. The InputSize of its first layer is later used
% to resize every image (see inputSize(1:2) in the datastore section).
net = resnet50();
inputLayer = net.Layers(1);
inputSize = inputLayer.InputSize;
% Uncomment to inspect the architecture interactively:
% analyzeNetwork(net);
%% Extract the layer graph from the trained network
% A SeriesNetwork (e.g. AlexNet, VGG-16, VGG-19) is converted from its Layers
% array; any other network (such as ResNet-50) is passed to layerGraph as-is.
if ~isa(net,'SeriesNetwork')
    lgraph = layerGraph(net);
else
    lgraph = layerGraph(net.Layers);
end
% findLayersToReplace is not defined in this script; presumably it is the
% helper from the MathWorks transfer-learning example, returning the final
% learnable layer and the classification output layer - verify.
[learnableLayer,classLayer] = findLayersToReplace(lgraph);
% Echo both layers to the Command Window for inspection.
[learnableLayer,classLayer]
%% Replace the final learnable layer (fully connected or 1x1 convolution)
% The replacement has one output per class present in the training labels.
% WeightLearnRateFactor/BiasLearnRateFactor of 10 scale the global learning
% rate for this new layer only.
numClasses = numel(categories(imdsTrain.Labels));
if isa(learnableLayer,'nnet.cnn.layer.FullyConnectedLayer')
    newLearnableLayer = fullyConnectedLayer(numClasses, ...
        'Name','new_fc', ...
        'WeightLearnRateFactor',10, ...
        'BiasLearnRateFactor',10);
elseif isa(learnableLayer,'nnet.cnn.layer.Convolution2DLayer')
    newLearnableLayer = convolution2dLayer(1,numClasses, ...
        'Name','new_conv', ...
        'WeightLearnRateFactor',10, ...
        'BiasLearnRateFactor',10);
else
    % Without this branch newLearnableLayer would be undefined and the
    % replaceLayer call below would fail with an unrelated error.
    error('Unsupported learnable layer type "%s" (layer "%s"); expected a fully connected or 2-D convolution layer.', ...
        class(learnableLayer), learnableLayer.Name);
end
lgraph = replaceLayer(lgraph,learnableLayer.Name,newLearnableLayer);
%% Replace the classification output layer
% The new output layer is created without explicit class names; they are
% taken from the training labels when the network is trained.
outputLayerName = 'new_classoutput';
newClassLayer = classificationLayer('Name', outputLayerName);
lgraph = replaceLayer(lgraph, classLayer.Name, newClassLayer);
% Uncomment to inspect the modified graph:
% analyzeNetwork(lgraph);
%% Freeze the transferred layers
% Every layer that precedes the new learnable layer is frozen, so training
% only updates the new head.
% The boundary is found by the new layer's name rather than a fixed offset
% from the end. For ResNet-50 this equals layers(1:end-3). A fixed offset
% would wrongly freeze the new layer itself whenever more than two layers
% follow it.
% freezeWeights and createLgraphUsingConnections are not defined in this
% script. They are presumably the helpers from the MathWorks
% transfer-learning example, which zero the learn-rate factors and rebuild
% the graph - verify.
layers = lgraph.Layers;
connections = lgraph.Connections;
headIdx = find(strcmp({layers.Name}, newLearnableLayer.Name), 1);
if isempty(headIdx)
    error('Layer "%s" not found in the layer graph.', newLearnableLayer.Name);
end
layers(1:headIdx-1) = freezeWeights(layers(1:headIdx-1));
lgraph = createLgraphUsingConnections(layers,connections);
%% Augmented image datastores
% Training images get random horizontal reflection, +/-30 pixel shifts and
% 0.9-1.1x scaling along each axis. Validation images are only resized.
% Both are resized to the network's input resolution.
% 'ColorPreprocessing','gray2rgb' turns any single-channel image into
% 3 channels so it fits ResNet-50's RGB input. It leaves RGB images unchanged.
pixelRange = [-30 30];
scaleRange = [0.9 1.1];
imageAugmenter = imageDataAugmenter( ...
    'RandXReflection',true, ...
    'RandXTranslation',pixelRange, ...
    'RandYTranslation',pixelRange, ...
    'RandXScale',scaleRange, ...
    'RandYScale',scaleRange);
augimdsTrain = augmentedImageDatastore(inputSize(1:2),imdsTrain, ...
    'DataAugmentation',imageAugmenter, ...
    'ColorPreprocessing','gray2rgb');
augimdsValidation = augmentedImageDatastore(inputSize(1:2),imdsValidation, ...
    'ColorPreprocessing','gray2rgb');
%% Training options
% Validation runs roughly once per epoch: one epoch is N/miniBatchSize
% iterations.
% max(1, ...) keeps ValidationFrequency a positive integer even when the
% training set has fewer images than one mini-batch. There, floor(...)
% would give 0, which trainingOptions rejects.
miniBatchSize = 10;
valFrequency = max(1, floor(numel(augimdsTrain.Files)/miniBatchSize));
options = trainingOptions('sgdm', ...
    'MiniBatchSize',miniBatchSize, ...
    'MaxEpochs',6, ...
    'InitialLearnRate',3e-4, ...
    'Shuffle','every-epoch', ...
    'ValidationData',augimdsValidation, ...
    'ValidationFrequency',valFrequency, ...
    'Verbose',false, ...
    'Plots','training-progress');
%% Training and evaluation
trainedNet = trainNetwork(augimdsTrain,lgraph,options);

% Validation predictions. augimdsValidation applies no random augmentation.
[YPredValid,scoresValid] = classify(trainedNet, augimdsValidation);
validationLabels = grp2idx(imdsValidation.Labels);
validationResult = grp2idx(YPredValid);

% Training-set predictions must be made on un-augmented images.
% Classifying augimdsTrain would score randomly reflected/shifted/scaled
% copies, which gives a distorted and non-reproducible training accuracy.
augimdsTrainEval = augmentedImageDatastore(inputSize(1:2),imdsTrain, ...
    'ColorPreprocessing','gray2rgb');
[YPredTrain,scoresTrain] = classify(trainedNet, augimdsTrainEval);
trainLabels = grp2idx(imdsTrain.Labels);
trainResult = grp2idx(YPredTrain);

% Accuracy as in the MathWorks example.
% NOTE(review): assumes the training and validation folders use the same
% class names.
fprintf('Validation accuracy: %.2f%%\n', 100*mean(YPredValid == imdsValidation.Labels));
fprintf('Training accuracy:   %.2f%%\n', 100*mean(YPredTrain == imdsTrain.Labels));
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment