Created
June 23, 2016 12:43
-
-
Save simgt/c990c755d7b97f75a677cca259238f3c to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
-- Dependencies: Torch7 core, the nn layer library, the rnn package (its
-- recurrent modules are accessed through the nn namespace below), and
-- gnuplot for rendering the test figure.
local torch = require 'torch'
local nn = require 'nn'
local rnn = require 'rnn'
local gnuplot = require 'gnuplot'

torch.setnumthreads(4)
print('number of threads: ' .. torch.getnumthreads())

-- Hyperparameters. Declared local so they do not leak into the global
-- environment; the rest of the script lives in the same chunk and still
-- sees them.
local batchSize = 16 -- number of sequences trained in parallel
local rho = 16 -- sequence length
local lr = 0.01 -- step size passed to updateParameters
--
-- Model
--
local inputSize = 1
local hiddenSize = 32
local outputSize = 1

-- Recurrent layer: the hidden state is built from a linear map of the
-- current input plus a linear map of the previous hidden state, passed
-- through a sigmoid. rho is handed to nn.Recurrent as its last argument.
local r = nn.Recurrent(
   hiddenSize, -- start
   nn.Linear(inputSize, hiddenSize), -- input
   nn.Linear(hiddenSize, hiddenSize), -- feedback
   nn.Sigmoid(), -- transfer
   rho
)

-- Per-time-step network: recurrent layer followed by a linear read-out.
-- NOTE: from here on the name `rnn` refers to the model, shadowing the
-- handle returned by require 'rnn'; the training code relies on this name.
local rnn = nn.Sequential()
   :add(r)
   :add(nn.Linear(hiddenSize, outputSize))

-- Wrap the step network so it can be fed a whole sequence per call.
rnn = nn.Sequencer(rnn)

-- load a model previously trained
--rnn = torch.load('sine-waves-model.dat', 'ascii', true)

print(rnn)

-- Mean-squared error applied at every time step of the sequence.
local criterion = nn.SequencerCriterion(nn.MSECriterion())
--
-- Dataset
--
-- The network is asked to map sin(t) onto sin(t/2), sampled at numSamples
-- evenly spaced points covering numPeriods periods of the input wave.
local numSamples = 1024
local numPeriods = 10
local t = torch.linspace(0, numPeriods * 2 * math.pi, numSamples)

-- Zero-fill so that any column not explicitly written below holds a
-- defined value rather than uninitialized memory (relevant only when
-- inputSize/outputSize is raised above 1).
local input = torch.Tensor(numSamples, inputSize):zero()
local output = torch.Tensor(numSamples, outputSize):zero()

input:select(2, 1):copy(torch.sin(t))
--input:select(2, 2):copy(torch.sin(t/2))
output:select(2, 1):copy(torch.sin(t/2))
--output:select(2, 2):copy(torch.sin(t*2))
--
-- Training
--
-- Runs forever: each round draws new random start positions, performs a
-- fixed number of parameter updates, saves the model, and plots the
-- model's output over the full signal. Stop the process to end training.
local updatesPerRound = 2000
local it = 1
while true do
   -- One start index per sequence in the batch. math.random(n) yields an
   -- integer in [1, n]; scaling math.random() (range [0, 1)) and taking
   -- the ceiling could yield 0, which is not a valid 1-based row index.
   local offsets = torch.LongTensor(batchSize)
   for i = 1, batchSize do
      offsets[i] = math.random(input:size(1))
   end

   for _ = 1, updatesPerRound do
      -- create a batch of sequences of rho time-steps
      local x, y = {}, {}
      for step = 1, rho do
         -- index() copies the selected rows, so mutating offsets
         -- afterwards does not affect x[step] / y[step]
         x[step] = input:index(1, offsets)
         y[step] = output:index(1, offsets)
         -- increment indices, wrapping back to the first sample
         offsets:add(1)
         for j = 1, batchSize do
            if offsets[j] > numSamples then
               offsets[j] = 1
            end
         end
      end

      -- forward the sequence
      local z = rnn:forward(x)
      local err = criterion:forward(z, y)
      -- the error is divided by rho to report a per-time-step figure
      print(string.format("[%d] err = %f", it, err / rho))

      -- backward the sequence (i.e. BPTT) in reverse order of forward calls
      rnn:zeroGradParameters()
      local gz = criterion:backward(z, y)
      rnn:backward(x, gz)

      -- update
      rnn:updateParameters(lr)
      it = it + 1
   end

   -- save the model
   print("Saving...")
   torch.save('sine-waves-model.dat', rnn, 'ascii', true)

   -- test on the full sequence; switch to evaluation mode for the pass
   -- and back to training mode before the next round
   rnn:evaluate()
   local z = rnn:forward(input)
   rnn:training()

   gnuplot.pngfigure('sine-waves-test.png')
   gnuplot.plot(
      {'input', t, input:select(2, 1), '-'},
      {'truth', t, output:select(2, 1), '-'},
      {'estimate', t, z:select(2, 1), '-'}
   )
   gnuplot.plotflush()
end
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment