Created
October 24, 2016 14:10
-
-
Save RicherMans/c86a171035035e6e8c4cb1e9a98a3320 to your computer and use it in GitHub Desktop.
torch example of using audiodataload
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
require 'torch' | |
require 'optim' | |
require 'nn' | |
require 'audio' | |
require 'xlua' | |
local adl = require 'audiodataload' | |
--- Parse the command-line arguments of the training script.
-- Builds a fresh torch.CmdLine parser, registers the help banner and every
-- supported option on it, and parses the given argument list.
-- @param arg  list of command-line strings; an empty list is parsed when nil
-- @return whatever the parser's `parse` returns (the parsed options, with
--         one field per option name without the leading dash)
function cmdlineargs(arg)
  local parser = torch.CmdLine()

  -- Help banner.
  parser:text()
  parser:text('Torch-7 CNN Training script')
  parser:text()
  parser:text('Options:')

  -- Every option as { flag, default, help }; registered in the listed order.
  local specs = {
    -- General options
    { '-data', '', 'Datafile which consists of every line dataset' },
    { '-manualSeed', 2, 'Manually set RNG seed' },
    { '-mode', 'CPU', 'Default type for Network. Options: GPU | CPU' },
    -- Training options
    { '-nepochs', 100, 'Number of total epochs to run' },
    { '-batchsize', 128, 'mini-batch size (1 = pure stochastic)' },
    -- Optimization options
    { '-lr', 0.8, 'learning rate' },
    { '-momentum', 0.9, 'momentum' },
    { '-weightdecay', 5e-4, 'weight decay' },
    { '-optim', 'sgd', 'Optimization method default adadelta | sgd | adagrad' },
    -- Model options
    { '-framesize', 400, 'Framesize for CNN, 400 = 25 ms' },
  }
  for i = 1, #specs do
    local flag, default, help = specs[i][1], specs[i][2], specs[i][3]
    parser:option(flag, default, help)
  end

  parser:text()

  local parsed = parser:parse(arg or {})
  return parsed
end
-- Parse the command line and make sure a readable data file was given.
opt = cmdlineargs(arg)
assert(opt.data ~= '', "Data not given, please pass -data")
assert(paths.filep(opt.data), "Data file cannot be found")

-- Apply the -manualSeed option; it was declared but never used before.
-- NOTE(review): assumes the dataloader's shuffle draws from torch's RNG, so
-- that seeding here makes runs reproducible -- confirm against audiodataload.
torch.manualSeed(opt.manualSeed)

-- Define optimization parameters and method
local optimState = {
  learningRate = opt.lr,
  learningRateDecay = 1e-7,
  momentum = opt.momentum,
  -- The option is registered as '-weightdecay', so the parsed field is
  -- all lower-case; reading opt.weightDecay always yielded nil.
  weightDecay = opt.weightdecay
}

-- Look the optimizer up by name in the optim package; unknown names fall
-- back to plain SGD (the fallback lives in `optim`, not in the options).
local optimmethod = optim[opt.optim]
if optimmethod == nil then
  optimmethod = optim.sgd
end

torch.setdefaulttensortype('torch.FloatTensor')

-- loading all the data
local htkdataloader = adl.HtkDataloader(opt.data)
local dataloader = adl.JoinedDataloader{module=htkdataloader,dirpath="mypath",cachesize=1000}
dataloader:shuffle()

-- network parameters
local hiddenSize = 100
local nClass = dataloader:nClasses():squeeze() -- output classes
--- Build one hidden block of the network.
-- The block applies, in order: a Linear layer from `inp` to `out` units,
-- BatchNormalization over the `out` features, and an in-place ReLU.
-- @param inp  number of input features of the Linear layer
-- @param out  number of output features of the Linear layer
-- @return the nn.Sequential container holding the three layers
local function linearblock(inp,out)
  -- Sequential:add returns the container, so the layers can be chained.
  return nn.Sequential()
    :add(nn.Linear(inp, out))
    :add(nn.BatchNormalization(out))
    :add(nn.ReLU(true))
end
-- The overall model: two hidden blocks followed by a linear classifier with
-- log-softmax output. The first layer takes 39 input features.
-- NOTE(review): -mode (GPU | CPU) and -framesize are parsed but not used
-- anywhere in this script; the model always stays on the default tensor type.
model = nn.Sequential()
model:add(linearblock(39,hiddenSize))
model:add(linearblock(hiddenSize,hiddenSize))
model:add(nn.Linear(hiddenSize,nClass))
model:add(nn.LogSoftMax())

-- Print the current model
print(model)

-- Confusion matrix used to track the training accuracy of each epoch
trainconfusion = optim.ConfusionMatrix(nClass)

-- Negative log-likelihood criterion, matching the LogSoftMax output layer
criterion = nn.ClassNLLCriterion()

-- Flattened views of all model parameters and their gradients
local parameters,gradParameters = model:getParameters()

-- training
for iteration = 1, opt.nepochs do
  trainconfusion:zero()
  -- Sum of the per-batch criterion values over this epoch
  local accerr = 0
  for curid,endid,input,target in dataloader:sampleiterator(opt.batchsize,nil,true) do
    xlua.progress(curid,endid)
    target = target:squeeze()

    -- Closure evaluated by the optimizer: returns the loss at `x` and the
    -- gradient of the loss with respect to the parameters.
    local feval = function(x)
      -- The optimizer hands in the point to evaluate at; sync the model if
      -- it is not already the flattened parameter tensor.
      if x ~= parameters then
        parameters:copy(x)
      end
      model:zeroGradParameters()

      local outputs = model:forward(input)
      local err = criterion:forward(outputs, target)
      accerr = accerr + err
      local gradOutputs = criterion:backward(outputs, target)

      -- Record the predictions of this batch in the confusion matrix
      trainconfusion:batchAdd(outputs,target)
      model:backward(input, gradOutputs)
      return err, gradParameters
    end
    optimmethod(feval,parameters,optimState)
  end

  trainconfusion:updateValids()
  local total = trainconfusion.totalValid * 100
  local avgacc = trainconfusion.averageValid * 100
  -- '%.2f' (two decimals) replaces the mistyped '%2.f', which printed the
  -- average accuracy with zero decimals.
  print(string.format("Iter [%i/%i]:\t Total acc %.2f \t avg %.2f \t err %.3f ",iteration,opt.nepochs,total,avgacc,accerr))
end
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment