Speech recognizer for TIMIT samples
function batch_nn(architecture, nr_hidden_units, learning_rate, mom, regularization_term, verbose)
% This function implements a speech recognizer for TIMIT samples.
% Architecture: an FFNN trained by backpropagation; an RNN trained by BPTT or CTC; or a reservoir trained via CTC.
%
% INPUTS:
% architecture can be 'ffnn', 'rnn' or 'reservoir'
% nr_hidden_units is the nr of hidden units
% learning_rate is the learning rate for the weight updates
% mom is the momentum rate
% regularization_term defines the regularization
% verbose defines the level of training information that is displayed (0 or 1)
%
% Antonio Rebordao, 2011
% -----------------------------------------------------------------------
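% Example call (hypothetical hyperparameter values, for illustration only):
%   batch_nn('rnn', 100, 1e-3, 0.9, 1e-4, 1)
% trains a 100-unit RNN by BPTT (per the defaults set below) with learning
% rate 1e-3, momentum 0.9 and regularization term 1e-4, at verbosity 1.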
clc; close all; % some cleaning
if ~strcmp(architecture, 'ffnn') && ~strcmp(architecture, 'rnn') && ~strcmp(architecture, 'reservoir')
    error('Define the type of architecture. It can be ffnn, rnn or reservoir.')
end
if strcmp(architecture, 'ffnn')
    training_type = 'bp';
elseif strcmp(architecture, 'reservoir')
    training_type = 'ctc';
end
%% EDITABLE VARIABLES
% -------------------------
std_deviation = 0.1; % half-width of the uniform interval for the initial weights
max_epochs = 100; % max nr of epochs
batch_size = 200; % nr of timesteps for each batch
if strcmp(architecture, 'rnn')
    training_type = 'bptt'; % this value can be bptt or ctc
end
% RESERVOIR PARAMETERS
if strcmp(architecture, 'reservoir')
    spectral_radius = 0.4;
    input_scale_factor = 1.25;
    net.leak_rate = 0.4;
end
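% (The leak rate is presumably applied inside forward_pass as a standard
% leaky-integrator (ESN) state update of the form
%   x(t) = (1 - leak_rate)*x(t-1) + leak_rate*f(W_inp*[u(t); 1] + W_rec*x(t-1))
% The actual update lives in forward_pass/trains_nn, not in this file.)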
% -------------------------
%% LOADS AND PRE-PROCESSES THE DATA
% loads the training data and reads the input dimension
fprintf(1, 'Loading data...\n')
load(fullfile('~/truecrypt1/Dropbox/experiments', 'timit_mfcc.mat'))
[nr_seqs inp_dim] = size(train_inputs);
% opens the dictionary file that contains the classes/phonemes for the task
dict_path = '~/truecrypt1/datasets/timit2/TIMIT_MFCC/dictionary_39.txt';
fid = fopen(dict_path, 'r');
if fid < 0
    error('Could not open the dictionary file. Check if the file exists at location: \n%s\n', dict_path);
else
    alphabet = textscan(fid, '%s');
    alphabet = char(cellstr(alphabet{1}));
end
fclose(fid);
% reads the output dimension
out_dim = size(alphabet, 1); % nr of classes
% computes/stores the mean and standard deviation of the training set
mea = mean(train_inputs, 1);
sta = std(train_inputs, 0, 1);
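% (These statistics drive the per-feature z-scoring of each non-reservoir
% batch below, x_norm = (x - mea) ./ sta, so every input dimension has
% roughly zero mean and unit variance over the training set.)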
%% CREATES THE NETWORK/TOPOLOGY
% creates weights
net.W_inp = rand(nr_hidden_units, inp_dim + 1) .* (2 * std_deviation) - std_deviation;
net.W_rec = rand(nr_hidden_units, nr_hidden_units) .* (2 * std_deviation) - std_deviation;
if strcmp(architecture, 'reservoir')
    net.W_inp = input_scale_factor .* net.W_inp;
    eigen = max(abs(eig(net.W_rec)));
    net.W_rec = net.W_rec .* spectral_radius/eigen;
end
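% Optional sanity check (a sketch, not in the original): scaling W_rec by
% spectral_radius/eigen scales its eigenvalues by the same factor, so
%   assert(abs(max(abs(eig(net.W_rec))) - spectral_radius) < 1e-8)
% should hold up to numerical error.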
if strcmp(training_type, 'ctc')
    net.W_out = rand(out_dim + 1, nr_hidden_units + 1) .* (2 * std_deviation) - std_deviation; % the extra row is for the blank unit
else
    net.W_out = rand(out_dim, nr_hidden_units + 1) .* (2 * std_deviation) - std_deviation;
end
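% (CTC trains against whole label sequences rather than framewise targets;
% the extra 'blank' output lets the network emit 'no label' between
% phonemes, and decoding collapses repeats and blanks into the final string.)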
% initialization of the matrices used by the momentum update
net.W_out_old = net.W_out;
net.W_rec_old = net.W_rec;
net.W_inp_old = net.W_inp;
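% The *_old copies presumably feed a classical momentum update inside
% trains_nn; a sketch of the assumed rule (not shown in this file):
%   step = -learning_rate*grad + mom*(W - W_old);  W_old = W;  W = W + step;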
% stores some variables
net.training_type = training_type;
net.architecture = architecture;
net.learning_rate = learning_rate;
net.mom = mom;
net.regularization_term = regularization_term;
net.verbose = verbose;
net.batch_size = batch_size;
% opens a file where the results will be written
path = fullfile('~/truecrypt1/Dropbox/experiments/timit/', ['results_', datestr(now, 'mmdd_HHMMSS'), '.txt']);
fid = fopen(path, 'a');
% writes some info into the file
fprintf(fid, 'nr hidden units: %g\nlearning rate: %g\nmomentum rate: %g\nregularization term: %g\n', nr_hidden_units, learning_rate, mom, regularization_term);
fprintf(fid, '\narchitecture: %s\ntraining_type: %s\n', architecture, training_type);
fprintf(fid, 'nr epochs: %g\nbatch_size: %g\nverbose: %g\n\n', max_epochs, batch_size, verbose);
%% TRAINING
epoch = 1;
convergence = false;
validation_errors = zeros(1, max_epochs * (ceil(nr_seqs/batch_size) + 1));
all_val_id = 1; % counts the nr of validation samples across all epochs
tic
while epoch <= max_epochs && ~convergence
    fprintf(1, '\nTRAINING PHASE, epoch %g\n\n', epoch)
    if ~strcmp(training_type, 'ctc')
        % shuffles the frame ids at every epoch
        frames = randperm(nr_seqs);
    end
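    % (no shuffling under CTC: batches are taken in file order, presumably
    % so that frames stay contiguous and whole phoneme label sequences
    % survive intact within a batch)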
    % computes the gradient for every batch
    ct = 0; % tracks the pass over the data within each epoch
    %~val_id = 1; % counts the validation samples per epoch
    batch_train_ct = 0; % counts the nr of batches per epoch
    train_err = 0;
    %~ validation_err = 0;
    logliks = 0;
    while ct + batch_size <= nr_seqs
        % loads batches
        if strcmp(training_type, 'ctc')
            data = train_inputs(ct + 1: ct + batch_size, :);
            targets = train_targets(ct + 1: ct + batch_size, :);
        else
            data = train_inputs(frames(ct + 1: ct + batch_size), :);
            targets = train_targets(frames(ct + 1: ct + batch_size), :);
        end
        ct = ct + batch_size;
        net.count = ct;
        batch_train_ct = batch_train_ct + 1;
        % normalizes data
        if strcmp(architecture, 'reservoir')
            % explicit re-scaling/normalization of the mfcc features (refer to Fabian's paper);
            % features are columns at this point (the transpose below makes them rows)
            data(:,1) = 0.27 * (data(:,1) - mean(data(:,1))) * 1.75;
            data(:,2:13) = 0.10 * (data(:,2:13) - repmat(mean(data(:,2:13), 1), size(data, 1), 1)) * 1.25;
            data(:,14) = 1.77 * (data(:,14) - mean(data(:,14))) * 1.25;
            data(:,15:26) = 0.61 * (data(:,15:26) - repmat(mean(data(:,15:26), 1), size(data, 1), 1)) * 0.50;
            data(:,27) = 4.97 * (data(:,27) - mean(data(:,27)));
            data(:,28:39) = 1.75 * (data(:,28:39) - repmat(mean(data(:,28:39), 1), size(data, 1), 1)) * 0.25;
        else
            % normalizes data with the mean and std of the training set
            data = (data - repmat(mea, size(data, 1), 1)) ./ repmat(sta, size(data, 1), 1);
        end
        % transposes the data so that columns = timesteps and rows = features
        data = data';
        targets = targets';
        % TRAINS THE NETWORK
        [net hid_without_bias output] = trains_nn(net, data, targets, alphabet);
        % confirms the gradient
        %~ [gradd] = grad_by_finite_diff(net, data, targets); % to confirm the gradient
        %~ [net.grad(1:200); gradd(1:200)]'
        %~ [net.grad(end-200:end); gradd(end-200:end)]'
        %~ return
        % TRAINING ERROR
        if strcmp(training_type, 'bp') || strcmp(training_type, 'bptt')
            % training error
            [val, idx1] = max(output);
            [val, idx2] = max(targets);
            train_err = train_err + sum(idx1 ~= idx2)/length(idx1);
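            % (framewise error rate: fraction of timesteps whose output
            % argmax disagrees with the argmax of the one-hot target)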
            % validation error
            if ~mod(ct, 5 * batch_size) % computes/prints the validation error every x samples
                fprintf(1, 'epoch %2.0f (%3.1f%% completed), train_error: %2.1f%%\n', epoch, 100*ct/nr_seqs, 100*train_err/batch_train_ct)
                % picks a random test sample
                %~load(fullfile(inp_dir, files(sample,1).name))
                %~
                %~% normalizes data with the mean and std of the training set
                %~data = (data - repmat(mea, 1, size(data, 2))) ./ repmat(sta, 1, size(data, 2));
                %~% forward pass
                %~[hid_without_bias output] = forward_pass(net, data);
                %~% computes error
                %~[val, idx1] = max(output);
                %~[val, idx2] = max(targets);
                %~validation_err = validation_err + sum(idx1 ~= idx2)/size(idx1, 2);
                %~validation_errors(all_val_id) = validation_err/val_id;
                %~fprintf(1, 'validation_error: %2.1f%%\n', 100*validation_err/val_id)
                %~val_id = val_id + 1;
                %~all_val_id = all_val_id + 1;
            end
        elseif strcmp(training_type, 'ctc')
            % reads phoneme location
            [val phonemes_id] = max(targets, [], 2);
            % training error for CTC
            [lbl tot] = phoneme_err(output, phonemes_id);
            train_err = train_err + lbl/tot;
            % stores loglik
            logliks = logliks + net.loglik;
            % validation error
            if ~mod(ct, 5 * batch_size) && train_err/batch_train_ct <= 1 % computes/prints the validation error every x samples
                fprintf(1, 'epoch %2.0f (%3.1f%% completed), train_error: %2.1f%%, loglik %2.1f\n', epoch, 100*ct/nr_seqs, 100*train_err/batch_train_ct, logliks/batch_train_ct)
                %~% picks a random test sample
                %~aux = false;
                %~while aux == false
                %~sample = ceil(rand*(nr_seqs-1));
                %~if strcmp(files(sample,1).name(1), 't') && strcmp(files(sample,1).name(1:4), 'test')
                %~aux = true;
                %~end
                %~end
                %~
                %~load(fullfile(inp_dir, files(sample,1).name))
                %~
                %~% explicit normalization of the mfcc features (following Fabian's paper)
                %~if strcmp(architecture, 'reservoir')
                %~data(1,:) = 0.27 * (data(1,:) - mean(data(1,:))) * 1.75;
                %~data(2:13,:) = 0.10 * (data(2:13,:) - repmat(mean(data(2:13,:), 2), 1, size(data, 2))) * 1.25;
                %~data(14,:) = 1.77 * (data(14,:) - mean(data(14,:))) * 1.25;
                %~data(15:26,:) = 0.61 * (data(15:26,:) - repmat(mean(data(15:26,:), 2), 1, size(data, 2))) * 0.50;
                %~data(27,:) = 4.97 * (data(27,:) - mean(data(27,:)));
                %~data(28:39,:) = 1.75 * (data(28:39,:) - repmat(mean(data(28:39,:), 2), 1, size(data, 2))) * 0.25;
                %~else
                %~% normalizes data with the mean and std of the training set
                %~data = (data - repmat(mea, 1, size(data, 2))) ./ repmat(sta, 1, size(data, 2));
                %~end
                %~
                %~[hid_without_bias output] = forward_pass(net, data);
                %~
                %~% computes edit distance
                %~[lbl tot] = phoneme_err(output, phonemes, alphabet);
                %~
                %~if lbl/tot <= 1
                %~validation_err = validation_err + lbl/tot;
                %~validation_errors(all_val_id) = validation_err/val_id;
                %~fprintf(1, 'validation_error: %2.1f%%\n', 100*validation_err/val_id)
                %~val_id = val_id + 1;
                %~all_val_id = all_val_id + 1;
                %~end
                %~
            elseif ~mod(ct, 5 * batch_size) % in these cases train_err/ct > 1 (e.g. the net is still not trained)
                fprintf(1, 'epoch %2.0f (%3.1f%% completed), loglik %2.1f\n', epoch, 100*ct/nr_seqs, logliks/batch_train_ct)
            end
        end
    end
    epoch = epoch + 1;
    %% TESTING
    fprintf(1, '\nTESTING PHASE\n')
    % computes the test error
    ct = 0;
    test_err = 0;
    batch_test_ct = 0;
    while ct + batch_size <= size(test_inputs, 1)
        % loads batches
        data = test_inputs(ct + 1: ct + batch_size, :);
        targets = test_targets(ct + 1: ct + batch_size, :);
        ct = ct + batch_size;
        batch_test_ct = batch_test_ct + 1;
        % normalizes data
        if strcmp(architecture, 'reservoir')
            % explicit re-scaling/normalization of the mfcc features (refer to Fabian's paper);
            % features are columns at this point
            data(:,1) = 0.27 * (data(:,1) - mean(data(:,1))) * 1.75;
            data(:,2:13) = 0.10 * (data(:,2:13) - repmat(mean(data(:,2:13), 1), size(data, 1), 1)) * 1.25;
            data(:,14) = 1.77 * (data(:,14) - mean(data(:,14))) * 1.25;
            data(:,15:26) = 0.61 * (data(:,15:26) - repmat(mean(data(:,15:26), 1), size(data, 1), 1)) * 0.50;
            data(:,27) = 4.97 * (data(:,27) - mean(data(:,27)));
            data(:,28:39) = 1.75 * (data(:,28:39) - repmat(mean(data(:,28:39), 1), size(data, 1), 1)) * 0.25;
        else
            % normalizes data with the mean and std of the training set
            data = (data - repmat(mea, size(data, 1), 1)) ./ repmat(sta, size(data, 1), 1);
        end
        % forward pass (data transposed so that columns = timesteps)
        [hid_without_bias output] = forward_pass(net, data');
        % testing error for BP or BPTT
        if strcmp(training_type, 'bp') || strcmp(training_type, 'bptt')
            % computes error
            [val idx1] = max(output, [], 1);
            [val idx2] = max(targets, [], 2);
            test_err = test_err + sum(idx1 ~= idx2') / length(idx2);
        else
            % for CTC
            [lbl tot] = phoneme_err(output, phonemes_id);
            test_err = test_err + lbl/tot;
        end
    end
    fprintf(1, 'epoch %g, mean test_error: %2.1f%%\n', epoch - 1, 100*test_err/batch_test_ct) % prints on the monitor
    if strcmp(training_type, 'ctc')
        fprintf(fid, 'epoch %g, mean test_error: %2.1f%%, mean train_error: %2.1f%%, loglik %2.1f\n', epoch - 1, 100*test_err/batch_test_ct, 100*train_err/(batch_train_ct - 1), logliks/batch_train_ct); % writes into the file
    else
        fprintf(fid, 'epoch %g, mean test_error: %2.1f%%, mean train_error: %2.1f%%\n', epoch - 1, 100*test_err/batch_test_ct, 100*train_err/(batch_train_ct - 1)); % writes into the file
    end
end
time = toc/60; % time in minutes
%% prints some info on the monitor
fprintf(1, 'mean test_error: %2.1f%%\nmean train_error: %2.1f%%\n', 100*test_err/batch_test_ct, 100*train_err/(batch_train_ct - 1))
fprintf(1, 'learning rate: %g\nmomentum rate: %g\nregularization term: %g\n', learning_rate, mom, regularization_term)
fprintf(1, '\nnr epochs: %g/%g\nnr hidden units: %g\n', epoch - 1, max_epochs, nr_hidden_units)
fprintf(1, 'architecture: %s\ntraining_type: %s\ntraining_time: %1.0f min\n', architecture, training_type, time)
% prints into the file and closes it
fprintf(fid, 'training_time: %1.0f min\n', time);
fclose(fid);
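% ------------------------------------------------------------------------
% Note: trains_nn, forward_pass and phoneme_err must be on the MATLAB path;
% they are not part of this file. phoneme_err presumably returns the edit
% (Levenshtein) distance between the decoded and reference phoneme strings
% plus the reference length; a minimal sketch under those assumed semantics
% (hypothetical helper, not the original implementation):
%   function [lbl tot] = edit_distance_sketch(dec, ref)
%   d = zeros(numel(dec) + 1, numel(ref) + 1);
%   d(:,1) = 0:numel(dec); % cost of deleting all decoded labels
%   d(1,:) = 0:numel(ref); % cost of inserting all reference labels
%   for i = 2:numel(dec) + 1
%       for j = 2:numel(ref) + 1
%           d(i,j) = min([d(i-1,j) + 1, ...                          % deletion
%                         d(i,j-1) + 1, ...                          % insertion
%                         d(i-1,j-1) + (dec(i-1) ~= ref(j-1))]);     % substitution
%       end
%   end
%   lbl = d(end,end);
%   tot = numel(ref);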