Skip to content

Instantly share code, notes, and snippets.

@rebordao
Created October 23, 2020 10:09
Show Gist options
  • Save rebordao/0c24ded683af99a9700adcdacf138655 to your computer and use it in GitHub Desktop.
Save rebordao/0c24ded683af99a9700adcdacf138655 to your computer and use it in GitHub Desktop.
Speech recognizer for TIMIT samples
function batch_nn(architecture, nr_hidden_units, learning_rate, mom, regularization_term, verbose)
% BATCH_NN  Trains and tests a speech recognizer for TIMIT samples in mini-batches.
% Architecture: a FFNN trained by Back Propagation; a RNN trained by BPTT or CTC;
% or a Reservoir trained via CTC.
%
% INPUTS:
% architecture is 'ffnn', 'rnn' or 'reservoir'
% nr_hidden_units is the nr of hidden units
% learning_rate is the learning rate for the weights' update
% mom is the momentum rate
% regularization_term defines the regularization
% verbose defines the level of training information that is displayed (0 or 1)
%
% DATA AND SIDE EFFECTS:
% - loads train_inputs, train_targets, test_inputs and test_targets from
%   timit_mfcc.mat, and the list of classes from dictionary_39.txt (the paths
%   are hard-coded below). Inputs/targets are stored one frame per row.
% - appends a run summary to a time-stamped results_*.txt file; the file is
%   closed when the function exits, also when it exits with an error.
% - relies on trains_nn, forward_pass and phoneme_err, which live in other files.
%
% Antonio Rebordao, 2011
% -----------------------------------------------------------------------
clc; close all; % some cleaning

if ~any(strcmp(architecture, {'ffnn', 'rnn', 'reservoir'}))
    error('Define the type of arquitecture. It can be ffnn, rnn or reservoir.')
end
if strcmp(architecture, 'ffnn')
    training_type = 'bp';
elseif strcmp(architecture, 'reservoir')
    training_type = 'ctc';
end

%% EDITABLE VARIABLES
% -------------------------
std_deviation = 0.1; % initial weights are drawn uniformly from [-std_deviation, std_deviation]
max_epochs = 100; % nr max of epochs
batch_size = 200; % nr of timesteps (frames) for each batch
if strcmp(architecture, 'rnn')
    training_type = 'bptt'; % this value can be bptt or ctc
end
% RESERVOIR PARAMETERS
if strcmp(architecture, 'reservoir')
    spectral_radius = 0.4;
    input_scale_factor = 1.25;
    net.leak_rate = 0.4;
end
% -------------------------

%% LOADS AND PRE-PROCESSES THE DATA
fprintf(1, 'Loading data...\n')
% loading into a struct keeps explicit which variables this function needs
loaded = load(fullfile('~/truecrypt1/Dropbox/experiments', 'timit_mfcc.mat'));
train_inputs = loaded.train_inputs;
train_targets = loaded.train_targets;
test_inputs = loaded.test_inputs;
test_targets = loaded.test_targets;
[nr_seqs, inp_dim] = size(train_inputs); % frames x features

% opens the dictionary file that contains the classes/phonemes for the task
dict_path = '~/truecrypt1/datasets/timit2/TIMIT_MFCC/dictionary_39.txt';
fid = fopen(dict_path, 'r');
if fid < 0
    % report the path that was actually tried
    error('Could not open the dictionary file. Check if the file exists at location: \n%s\n', dict_path);
end
alphabet = textscan(fid, '%s');
fclose(fid);
alphabet = char(cellstr(alphabet{1}));
% reads output dimension
out_dim = size(alphabet, 1); % nr of classes

% mean and standard deviation of every feature over the training set (1 x inp_dim)
mea = mean(train_inputs, 1);
sta = std(train_inputs, 0, 1);

%% CREATES THE NETWORK/TOPOLOGY
% creates weights (the extra "+ 1" columns presumably hold the bias weights)
net.W_inp = rand(nr_hidden_units, inp_dim + 1) .* (2 * std_deviation) - std_deviation;
net.W_rec = rand(nr_hidden_units, nr_hidden_units) .* (2 * std_deviation) - std_deviation;
if strcmp(architecture, 'reservoir')
    net.W_inp = input_scale_factor .* net.W_inp;
    % rescales the recurrent weights so that their spectral radius equals spectral_radius
    eigen = max(abs(eig(net.W_rec)));
    net.W_rec = net.W_rec .* spectral_radius/eigen;
end
if strcmp(training_type, 'ctc')
    net.W_out = rand(out_dim + 1, nr_hidden_units + 1) .* (2 * std_deviation) - std_deviation; % the extra row is for the blank unit
else
    net.W_out = rand(out_dim, nr_hidden_units + 1) .* (2 * std_deviation) - std_deviation;
end
% initialization of the matrices used in the momentum
net.W_out_old = net.W_out;
net.W_rec_old = net.W_rec;
net.W_inp_old = net.W_inp;
% stores the hyper-parameters in the net struct, which is passed to trains_nn/forward_pass
net.training_type = training_type;
net.architecture = architecture;
net.learning_rate = learning_rate;
net.mom = mom;
net.regularization_term = regularization_term;
net.verbose = verbose;
net.batch_size = batch_size;

% opens a file where the results will be written
results_path = fullfile('~/truecrypt1/Dropbox/experiments/timit/', ['results_', datestr(now, 'mmdd_HHMMSS'), '.txt']);
res_fid = fopen(results_path, 'a');
if res_fid < 0
    error('Could not open the results file for writing: \n%s\n', results_path);
end
% closes the results file when this function exits (normally or with an error)
close_results = onCleanup(@() fclose(res_fid)); %#ok<NASGU>
% writes some info into the file
fprintf(res_fid, 'nr hidden units: %g\nlearning rate: %g\nmomentum rate: %g\nregularization term: %g\n', nr_hidden_units, learning_rate, mom, regularization_term);
fprintf(res_fid, '\narchitecture: %s\ntraining_type: %s\n', architecture, training_type);
fprintf(res_fid, 'nr epochs: %g\nbatch_size: %g\nverbose: %g\n\n', max_epochs, batch_size, verbose);

%% TRAINING
epoch = 1;
convergence = false; % never set to true below, so training always runs max_epochs epochs
validation_errors = zeros(1, max_epochs * (ceil(nr_seqs/batch_size) + 1)); %#ok<NASGU> only used by the commented-out validation code
all_val_id = 1; %#ok<NASGU> counts the nr of validation samples of all epochs
tic
while epoch <= max_epochs && ~convergence
    fprintf(1, '\nTRAINING PHASE, epoch %g\n\n', epoch)
    if ~strcmp(training_type, 'ctc')
        % shuffles frames id at every epoch (CTC keeps the frames in their original order)
        frames = randperm(nr_seqs);
    end
    % computes the gradient for every batch
    ct = 0; % controls the pass over the data for each epoch
    %~val_id = 1; % counts the validation samples per epoch
    batch_train_ct = 0; % counts the nr of batches per epoch
    train_err = 0;
    %~ validation_err = 0;
    logliks = 0;
    while ct + batch_size <= nr_seqs
        % loads batches
        if strcmp(training_type, 'ctc')
            batch_rows = ct + 1 : ct + batch_size;
        else
            batch_rows = frames(ct + 1 : ct + batch_size);
        end
        data = train_inputs(batch_rows, :);
        targets = train_targets(batch_rows, :);
        ct = ct + batch_size;
        net.count = ct;
        batch_train_ct = batch_train_ct + 1;
        % normalizes the batch and transposes it into columns = timesteps
        data = normalize_batch(data, architecture, mea, sta);
        targets = targets';
        % TRAINS THE NETWORK
        [net, ~, output] = trains_nn(net, data, targets, alphabet);
        % confirms the gradient
        %~ [gradd] = grad_by_finite_diff(net, data, targets); % to confirm the gradient
        %~ [net.grad(1:200); gradd(1:200)]'
        %~ [net.grad(end-200:end); gradd(end-200:end)]'
        %~ return
        % TRAINING ERROR
        if strcmp(training_type, 'bp') || strcmp(training_type, 'bptt')
            % training error: fraction of frames whose arg-max class is wrong
            [~, idx1] = max(output);
            [~, idx2] = max(targets);
            train_err = train_err + sum(idx1 ~= idx2)/length(idx1);
            % validation error
            if ~mod(ct, 5 * batch_size) % computes/prints the validation error every x samples
                fprintf(1, 'epoch %2.0f (%3.1f%% completed), train_error: %2.1f%%\n', epoch, 100*ct/nr_seqs, 100*train_err/batch_train_ct)
                % picks a random test sample
                %~load(fullfile(inp_dir, files(sample,1).name))
                %~
                %~% normalizes data with the mean and std of the training set
                %~data = (data - repmat(mea, 1, size(data, 2))) ./ repmat(sta, 1, size(data, 2));
                %~% forward pass
                %~[hid_without_bias output] = forward_pass(net, data);
                %~% computes error
                %~[val, idx1] = max(output);
                %~[val, idx2] = max(targets);
                %~validation_err = validation_err + sum(idx1 ~= idx2)/size(idx1, 2);
                %~validation_errors(all_val_id) = validation_err/val_id;
                %~fprintf(1, 'validation_error: %2.1f%%\n', 100*validation_err/val_id)
                %~val_id = val_id + 1;
                %~all_val_id = all_val_id + 1;
            end
        elseif strcmp(training_type, 'ctc')
            % reads phoneme location
            [~, phonemes_id] = max(targets, [], 2);
            % training error for CTC
            [lbl, tot] = phoneme_err(output, phonemes_id);
            train_err = train_err + lbl/tot;
            % stores loglik
            logliks = logliks + net.loglik;
            % validation error
            if ~mod(ct, 5 * batch_size) && train_err/batch_train_ct <= 1 % computes/prints the validation error every x samples
                fprintf(1, 'epoch %2.0f (%3.1f%% completed), train_error: %2.1f%%, loglik %2.1f\n', epoch, 100*ct/nr_seqs, 100*train_err/batch_train_ct, logliks/batch_train_ct)
                %~% picks a random test sample
                %~aux = false;
                %~while aux == false
                %~sample = ceil(rand*(nr_seqs-1));
                %~if strcmp(files(sample,1).name(1), 't') && strcmp(files(sample,1).name(1:4), 'test')
                %~aux = true;
                %~end
                %~end
                %~
                %~load(fullfile(inp_dir, files(sample,1).name))
                %~
                %~% explicit normalization of the mfcc features (following Fabian's paper)
                %~if strcmp(architecture, 'reservoir')
                %~data(1,:) = 0.27 * (data(1,:) - mean(data(1,:))) * 1.75;
                %~data(2:13,:) = 0.10 * (data(2:13,:) - repmat(mean(data(2:13,:), 2), 1, size(data, 2))) * 1.25;
                %~data(14,:) = 1.77 * (data(14,:) - mean(data(14,:))) * 1.25;
                %~data(15:26,:) = 0.61 * (data(15:26,:) - repmat(mean(data(15:26,:), 2), 1, size(data, 2))) * 0.50;
                %~data(27,:) = 4.97 * (data(27,:) - mean(data(27,:)));
                %~data(28:39,:) = 1.75 * (data(28:39,:) - repmat(mean(data(28:39,:), 2), 1, size(data, 2))) * 0.25;
                %~else
                %~% normalizes data with the mean and std of the training set
                %~data = (data - repmat(mea, 1, size(data, 2))) ./ repmat(sta, 1, size(data, 2));
                %~end
                %~
                %~[hid_without_bias output] = forward_pass(net, data);
                %~
                %~% computes edit distance
                %~[lbl tot] = phoneme_err(output, phonemes, alphabet);
                %~
                %~if lbl/tot <= 1
                %~validation_err = validation_err + lbl/tot;
                %~validation_errors(all_val_id) = validation_err/val_id;
                %~fprintf(1, 'validation_error: %2.1f%%\n', 100*validation_err/val_id)
                %~val_id = val_id + 1;
                %~all_val_id = all_val_id + 1;
                %~end
                %~
            elseif ~mod(ct, 5 * batch_size) % in these cases train_err/ct > 1 (e.g. the net is still not trained)
                fprintf(1, 'epoch %2.0f (%3.1f%% completed), loglik %2.1f\n', epoch, 100*ct/nr_seqs, logliks/batch_train_ct)
            end
        end
    end
    epoch = epoch + 1;

    %% TESTING
    fprintf(1, '\nTESTING PHASE\n')
    % computes test error
    ct = 0;
    test_err = 0;
    batch_test_ct = 0;
    while ct + batch_size <= size(test_inputs, 1)
        % loads batches; targets are transposed to columns = timesteps, as in training
        data = test_inputs(ct + 1 : ct + batch_size, :);
        targets = test_targets(ct + 1 : ct + batch_size, :)';
        ct = ct + batch_size;
        batch_test_ct = batch_test_ct + 1;
        % normalizes the batch and transposes it into columns = timesteps
        data = normalize_batch(data, architecture, mea, sta);
        % forward pass
        [~, output] = forward_pass(net, data);
        if strcmp(training_type, 'bp') || strcmp(training_type, 'bptt')
            % testing error for BP or BPTT: fraction of misclassified frames
            [~, idx1] = max(output, [], 1);
            [~, idx2] = max(targets, [], 1);
            test_err = test_err + sum(idx1 ~= idx2) / length(idx2);
        else
            % for CTC: labels are read from THIS test batch, the same way as in training
            [~, phonemes_id] = max(targets, [], 2);
            [lbl, tot] = phoneme_err(output, phonemes_id);
            test_err = test_err + lbl/tot;
        end
    end
    % averages over all batches of the epoch (train_err was summed over batch_train_ct batches)
    mean_test_err = 100*test_err/batch_test_ct;
    mean_train_err = 100*train_err/batch_train_ct;
    fprintf(1, 'epoch %g, mean test_error: %2.1f%%\n', epoch - 1, mean_test_err) % prints on the monitor
    if strcmp(training_type, 'ctc')
        fprintf(res_fid, 'epoch %g, mean test_error: %2.1f%%, mean train_error: %2.1f%%, loglik %2.1f\n', epoch - 1, mean_test_err, mean_train_err, logliks/batch_train_ct); % writes into file
    else
        fprintf(res_fid, 'epoch %g, mean test_error: %2.1f%%, mean train_error: %2.1f%%\n', epoch - 1, mean_test_err, mean_train_err); % writes into file
    end
end
time = toc/60; % time in minutes

%% prints some info on the monitor
fprintf(1, 'mean test_error: %2.1f%%\nmean train_error: %2.1f%%\n', 100*test_err/batch_test_ct, 100*train_err/batch_train_ct)
fprintf(1, 'learning rate: %g\nmomentum rate: %g\nregularization term: %g\n', learning_rate, mom, regularization_term)
fprintf(1, '\nnr epochs: %g/%g\nnr hidden units: %g\n', epoch - 1, max_epochs, nr_hidden_units)
fprintf(1, 'architecture: %s\ntraining_type: %s\ntraining_time: %1.0f min\n', architecture, training_type, time)
% prints into file (the file is closed by the onCleanup object when the function returns)
fprintf(res_fid, 'training_time: %1.0f min', time);


function data = normalize_batch(data, architecture, mea, sta)
% NORMALIZE_BATCH  Normalizes one batch and returns it with columns = timesteps.
%   DATA comes in as timesteps x features (rows sliced from train_inputs or
%   test_inputs) and is returned as features x timesteps.
%   - reservoir: explicit re-scaling/normalization of the mfcc features (refer
%     to Fabian's paper). The per-group formulas index feature ROWS
%     (1, 2:13, 14, 15:26, 27, 28:39), so the batch is transposed first and the
%     means are taken over the timesteps of the batch. Assumes 39 features
%     (TODO confirm against inp_dim).
%   - otherwise: z-scores every feature with the training-set mean MEA and
%     standard deviation STA (both 1 x features).
if strcmp(architecture, 'reservoir')
    data = data';
    data(1,:) = 0.27 * (data(1,:) - mean(data(1,:))) * 1.75;
    data(2:13,:) = 0.10 * (data(2:13,:) - repmat(mean(data(2:13,:), 2), 1, size(data, 2))) * 1.25;
    data(14,:) = 1.77 * (data(14,:) - mean(data(14,:))) * 1.25;
    data(15:26,:) = 0.61 * (data(15:26,:) - repmat(mean(data(15:26,:), 2), 1, size(data, 2))) * 0.50;
    data(27,:) = 4.97 * (data(27,:) - mean(data(27,:)));
    data(28:39,:) = 1.75 * (data(28:39,:) - repmat(mean(data(28:39,:), 2), 1, size(data, 2))) * 0.25;
else
    data = (data - repmat(mea, size(data, 1), 1)) ./ repmat(sta, size(data, 1), 1);
    data = data';
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment