An example of a feature-extracting CNN using Torch. Packages needed are torch, nn, optim, xlua and audio.
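A minimal sketch of how the script might be invoked (the script name extract_cnn.lua and the list file filelist.txt are hypothetical placeholders; the list file is expected to contain one audio file path per line):

th extract_cnn.lua -data filelist.txt -optim sgd -nepochs 100 -batchsize 128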
require 'torch'
require 'optim'
require 'nn'
require 'audio'
require 'xlua'

-- Parsing of the commandline arguments
function cmdlineargs(arg)
  local cmd = torch.CmdLine()
  cmd:text()
  cmd:text('Torch-7 CNN Training script')
  cmd:text()
  cmd:text('Options:')
  ------------ General options --------------------
  cmd:option('-data', '', 'File that lists one audio file path per line')
  cmd:option('-manualSeed', 2, 'Manually set RNG seed')
  cmd:option('-mode', 'CPU', 'Default type for Network. Options: GPU | CPU')
  ------------- Training options --------------------
  cmd:option('-nepochs', 100, 'Number of total epochs to run')
  cmd:option('-batchsize', 128, 'mini-batch size (1 = pure stochastic)')
  ---------- Optimization options ----------------------
  cmd:option('-lr', 0.8, 'learning rate')
  cmd:option('-momentum', 0.9, 'momentum')
  cmd:option('-weightdecay', 5e-4, 'weight decay')
  cmd:option('-optim', '', 'Optimization method: adadelta (default) | sgd | adagrad | rms | adam')
  ---------- Model options ----------------------------------
  cmd:option('-framesize', 400, 'Framesize for the CNN in samples (400 = 25 ms at 16 kHz)')
  cmd:text()
  local opt = cmd:parse(arg or {})
  return opt
end

-- Loads every audio file listed (one path per line) in the given file
local function loaddata(data)
  local alldata = {}
  for line in io.lines(data) do
    alldata[#alldata+1] = audio.load(line)
  end
  return alldata
end
-- Splits every loaded signal into non-overlapping chunks of "size" samples
local function chunkdata(data, size)
  local chunked = {}
  local chunks = nil
  for i = 1, #data do
    chunks = data[i]:split(size)
    -- Drop the last chunk since it may be shorter than "size"
    for j = 1, #chunks - 1 do
      chunked[#chunked + 1] = chunks[j]
    end
  end
  return chunked
end
-- Batches the given dataset and returns an iterator over the batches
local function batch(data, batchsize)
  assert(data, "Did not pass data as first arg")
  assert(batchsize, "Did not pass batchsize as 2nd arg")
  local dataset = data
  local numsamples = #dataset
  local sample = 1
  local datax = dataset[1]:size(1)
  local function iterator()
    -- Stop iteration
    if sample > numsamples then
      return
    end
    -- The last batch may be smaller than the batchsize
    local maxbatchsize = math.min(batchsize, numsamples - sample + 1)
    local batcheddata = torch.Tensor(maxbatchsize, 1, 1, datax)
    for i = 1, maxbatchsize do
      batcheddata[i] = dataset[sample]
      sample = sample + 1
    end
    return batcheddata, sample, numsamples
  end
  return iterator
end
opt = cmdlineargs(arg)
assert(opt.data ~= '', "Data not given, please pass -data")
assert(paths.filep(opt.data), "Data file cannot be found")

-- Define optimization parameters and method
local optimState = {
  learningRate = opt.lr,
  learningRateDecay = 1e-7,
  momentum = opt.momentum,
  weightDecay = opt.weightdecay
}
if opt.optim == 'adagrad' then
  optimmethod = optim.adagrad
elseif opt.optim == 'sgd' then
  optimmethod = optim.sgd
elseif opt.optim == 'rms' then
  optimmethod = optim.rmsprop
elseif opt.optim == 'adam' then
  optimmethod = optim.adam
else
  optimmethod = optim.adadelta
end
-- network parameters
local hiddenSize = 100
local nClass = 7 -- output classes

-- build simple CNN model
local timeconvolutions = {1, 39} -- input planes, number of learned time-domain filters
local timestride = 400           -- convolution kernel width in samples (spans one 25 ms frame)
-- Width of the convolution output; the max pooling below collapses it to a single column
local timekernelwidth = opt.framesize - timestride + 1
local timekernelheight = 1

-- The overall model
model = nn.Sequential()
local convolutionpart = nn.Sequential()
convolutionpart:add(nn.SpatialConvolution(timeconvolutions[1], timeconvolutions[2], timestride, 1))
convolutionpart:add(nn.SpatialBatchNormalization(timeconvolutions[2], 1e-3))
-- Pool over the remaining time axis, making the features time-invariant
convolutionpart:add(nn.SpatialMaxPooling(timekernelwidth, timekernelheight))
convolutionpart:add(nn.ReLU())

local classifier = nn.Sequential()
-- Reshape the output of the convolution to be one dimensional
classifier:add(nn.View(timeconvolutions[2]))
classifier:add(nn.Linear(timeconvolutions[2], hiddenSize))
classifier:add(nn.BatchNormalization(hiddenSize))
classifier:add(nn.ReLU())
classifier:add(nn.Linear(hiddenSize, nClass))
classifier:add(nn.LogSoftMax())

-- Add both parts together
model:add(convolutionpart)
model:add(classifier)
-- Print the current model
print(model)
-- Confusion matrix
trainconfusion = optim.ConfusionMatrix(nClass)
-- build criterion
criterion = nn.ClassNLLCriterion()

-- loading all the data
local alldata = loaddata(opt.data)
local chunkeddata = chunkdata(alldata, opt.framesize)

-- Get the model parameters
local parameters, gradParameters = model:getParameters()

-- training
local inputs, targets = torch.Tensor(), torch.Tensor()
for iteration = 1, opt.nepochs do
  local iterator = batch(chunkeddata, opt.batchsize)
  trainconfusion:zero()
  local accerr = 0
  for batch, curid, endid in iterator do
    xlua.progress(curid, endid)
    -- Dummy labels: assign every sample to class 1 for now
    targets:resize(batch:size(1)):fill(1)
    local feval = function(x)
      model:zeroGradParameters()
      local outputs = model:forward(batch)
      local err = criterion:forward(outputs, targets)
      accerr = accerr + err
      local gradOutputs = criterion:backward(outputs, targets)
      -- Accumulate the batch predictions in the confusion matrix
      trainconfusion:batchAdd(outputs, targets)
      model:backward(batch, gradOutputs)
      return err, gradParameters
    end
    optimmethod(feval, parameters, optimState)
  end
  trainconfusion:updateValids()
  local total = trainconfusion.totalValid * 100
  local avgacc = trainconfusion.averageValid * 100
  print(string.format("Iter [%i/%i]:\t Total acc %.2f \t avg %.2f \t err %.3f", iteration, opt.nepochs, total, avgacc, accerr))
end
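
The script above only trains the classifier; a possible follow-up sketch for the actual feature extraction (not part of the original gist, added here purely as an illustration) would run only the convolutional front-end on new frames once training has finished:

-- Hypothetical usage sketch: reuse the trained convolutional part as a
-- frame-level feature extractor (assumes the training loop above has run).
model:evaluate()                                     -- use the accumulated batch-norm statistics
local frames = torch.Tensor(8, 1, 1, opt.framesize)  -- eight dummy 25 ms frames
frames:normal()                                      -- random data stands in for real audio chunks
local features = model:get(1):forward(frames)        -- output of convolutionpart: 8x39x1x1
print(features:view(frames:size(1), timeconvolutions[2])) -- one 39-dimensional feature vector per frame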