Trains a network by backpropagation, backpropagation through time (BPTT), or connectionist temporal classification (CTC)
function [net, hid_without_bias, output] = trains_nn(net, data, targets, phonemes_id)
% This function trains a network by backpropagation, by backpropagation
% through time, or by connectionist temporal classification.
%
% INPUTS:
% net contains the topology and the user-defined training parameters
% data is the training data
% targets are the target outputs
% phonemes_id holds the indices of the phonemes in the dictionary file
%
% OUTPUTS:
% net contains the topology and the gradient used for updating the weights
% hid_without_bias is the output of the network's hidden units
% output is the net's output
%
% Antonio Rebordao, 2011
% ----------------------------------------------------------------------
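% Example call (illustrative values; the weight matrices W_inp, W_rec, W_out,
% their *_old copies, and the count/batch_size/verbose fields must also be
% initialized by the caller):
%
%   net.training_type       = 'bptt';   % 'bp', 'bptt' or 'ctc'
%   net.architecture        = 'rnn';    % 'ffnn', 'rnn' or 'reservoir'
%   net.learning_rate       = 1e-4;
%   net.regularization_term = 1e-5;
%   net.mom                 = 0.9;
%   [net, hid, out] = trains_nn(net, data, targets, phonemes_id);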
nr_obser = size(data, 2);
if ~strcmp(net.training_type, 'bp') && ~strcmp(net.training_type, 'bptt') && ~strcmp(net.training_type, 'ctc')
error('The training type needs to be bp, bptt or ctc.')
end
% loads the weights
W_inp = net.W_inp;
W_rec = net.W_rec;
W_out = net.W_out;
% FORWARD PASS
[hid_without_bias, output] = forward_pass(net, data);
if strcmp(net.training_type, 'bp')
%% BP TRAINING
% BACKWARD PASS
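% output - targets is the standard output delta for a squared-error loss with
% linear outputs (or cross-entropy with softmax outputs); the hidden deltas
% below use the tanh derivative 1 - h.^2.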
delta_out = output - targets;
delta_hid = 1 - hid_without_bias.^2;
delta_hid = delta_hid .* (W_out(:, 1:end-1)' * delta_out);
% computes/stores the gradient
d_W_out = delta_out * [hid_without_bias; ones(1, nr_obser)]';
d_W_inp = delta_hid * [data; ones(1, nr_obser)]';
elseif strcmp(net.training_type, 'bptt')
%% BPTT TRAINING
% BACKWARD PASS
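% The loop below walks backwards through time: the hidden error at step t
% combines the error injected from the output layer with the error flowing
% back from step t+1 through the recurrent weights (W_rec'), scaled by the
% tanh derivative 1 - h.^2 at step t.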
delta_out = output - targets;
delta_hid = W_out(:, 1:end-1)' * delta_out;
delta_hid(:, end) = (1 - hid_without_bias(:, end).^2) .* delta_hid(:, end);
for steps = nr_obser-1:-1:1
delta_hid(:, steps) = (1 - hid_without_bias(:, steps).^2) .* (delta_hid(:, steps) + (W_rec' * delta_hid(:, steps + 1)));
end
% computes/stores the gradient
d_W_out = delta_out * [hid_without_bias; ones(1, nr_obser)]';
d_W_rec = delta_hid(:, 2:nr_obser) * hid_without_bias(:, 1:end-1)';
d_W_inp = delta_hid * [data; ones(1, nr_obser)]';
elseif strcmp(net.training_type, 'ctc')
%% CTC training with BPTT
phonemes_id_modif = ones(2 * length(phonemes_id) + 1, 1);
for i = 1:length(phonemes_id)
phonemes_id_modif(2*i) = phonemes_id(i) + 1; % 1 is reserved for blanks so we shift all the others by one unit
end
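% phonemes_id_modif is now the blank-augmented label sequence used by CTC,
% [blank l1 blank l2 ... lN blank], with index 1 reserved for the blank.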
% forward/backward algorithm
[alpha, beta, loglik] = ctc_fw_bw(output, phonemes_id_modif);
net.loglik = loglik;
% p(l|x)
gamma = alpha .* beta;
sum_gamma = sum(gamma, 1);
% computes the estimated signal
alphabet = [0 1:39]'; % 0 is the blank; 1:39 are the phoneme labels
nr_classes = length(alphabet);
pos = cell(1, nr_classes);
lab = false(nr_classes, length(phonemes_id_modif));
est = zeros(nr_classes, nr_obser);
for k = 1:nr_classes
for m = 1:length(phonemes_id_modif)
if alphabet(k) == phonemes_id_modif(m) - 1
lab(k, m) = true; % identifies the positions of the modified label sequence where label k occurs
end
end
pos{k} = find(lab(k,:));
if ~isempty(pos{k})
aux = zeros(length(pos{k}), nr_obser);
for m = 1:length(pos{k})
aux(m,:) = gamma(pos{k}(m),:);
end
est(k,:) = sum(aux, 1) ./ sum_gamma;
end
end
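% With the usual forward-backward conventions, est(k, t) aggregates the
% per-frame-normalized alpha .* beta mass over the positions of label k, i.e.
% the posterior probability of emitting k at frame t given the target
% sequence. Using it as a soft target, output - est below is the standard CTC
% error signal for softmax outputs.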
% BACKWARD PASS and gradient computation
delta_out = output - est;
d_W_out = delta_out * [hid_without_bias; ones(1, nr_obser)]';
if strcmp(net.architecture, 'reservoir')
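% in a reservoir (echo-state) setup only the readout weights W_out are
% trained, so only their gradient is stored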
net.grad = d_W_out(:)';
else
delta_hid = W_out(:, 1:end-1)' * delta_out;
delta_hid(:, end) = (1 - hid_without_bias(:, end).^2) .* delta_hid(:, end);
for steps = (nr_obser-1):-1:1
delta_hid(:, steps) = (1 - hid_without_bias(:, steps).^2) .* (delta_hid(:, steps) + (W_rec' * delta_hid(:, steps + 1)));
end
d_W_rec = delta_hid(:, 2:nr_obser) * hid_without_bias(:, 1:end-1)';
d_W_inp = delta_hid * [data; ones(1, nr_obser)]';
end
% some plots
if ~mod(net.count, 10 * net.batch_size) && net.verbose == 1 % displays plots every x samples if verbose = 1
subplot(511)
plot(alpha')
title('alphas')
subplot(512)
plot(beta')
title('betas')
subplot(513)
plot(output')
axis([0 Inf 0 1])
title('net''s output')
subplot(514)
plot(est')
axis([0 Inf 0 1])
title('prior signal used to train the network')
subplot(515)
plot(delta_out')
title('net''s error')
drawnow
end
end
if ~strcmp(net.training_type, 'bp')
% normalizes gradients by the length of the utterance (currently disabled)
% d_W_out = d_W_out ./ nr_obser;
%if ~strcmp(net.architecture, 'reservoir')
%d_W_rec = d_W_rec ./ nr_obser;
%d_W_inp = d_W_inp ./ nr_obser;
% stores gradients
%net.grad = [d_W_inp(:)' d_W_rec(:)' d_W_out(:)'];
%end
else
% normalizes gradients by the length of the utterance
d_W_out = d_W_out ./ nr_obser;
d_W_inp = d_W_inp ./ nr_obser;
% stores gradients
net.grad = [d_W_inp(:)' d_W_out(:)'];
end
% WEIGHT UPDATE
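% Plain gradient descent with an L2 weight-decay term
% (learning_rate * regularization_term * W) and a momentum term based on the
% previous weights (W_*_old).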
net.W_out = W_out - net.learning_rate .* d_W_out - net.learning_rate * net.regularization_term * W_out + net.mom .* (W_out - net.W_out_old);
net.W_out_old = W_out;
if strcmp(net.architecture, 'ffnn') || strcmp(net.architecture, 'rnn')
if strcmp(net.architecture, 'rnn')
net.W_rec = W_rec - net.learning_rate .* d_W_rec - net.learning_rate * net.regularization_term * W_rec + net.mom .* (W_rec - net.W_rec_old);
net.W_rec_old = W_rec;
end
net.W_inp = W_inp - net.learning_rate .* d_W_inp - net.learning_rate * net.regularization_term * W_inp + net.mom .* (W_inp - net.W_inp_old);
net.W_inp_old = W_inp;
end