Created
December 3, 2012 11:04
-
-
Save osdf/4194221 to your computer and use it in GitHub Desktop.
First steps with Torch — logistic regression on the MNIST dataset
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
-- First steps with Torch: logistic regression on MNIST.
-- Loads the 32x32 MNIST t7 archives, flattens each image into a
-- 1024-dimensional row vector, and builds a linear + log-softmax
-- classifier trained with negative log-likelihood.
require 'torch'
require 'nn'
require 'optim'

-- Training set: 60000 images, each flattened to a 32*32 double vector.
trainset = torch.load("mnist.t7/train_32x32.t7", "ascii")
trainset.data = trainset.data:resize(60000, 32*32):double()

-- Test set: 10000 held-out images, same treatment.
test = torch.load("mnist.t7/test_32x32.t7", "ascii")
test.data = test.data:resize(10000, 32*32):double()

-- Give both dataset tables a :size() method (number of samples).
function trainset:size() return self.data:size(1) end
function test:size() return self.data:size(1) end

-- Model: one linear layer (1024 -> 10 classes) followed by LogSoftMax,
-- i.e. multinomial logistic regression.
sz = trainset.data:size(2)
model = nn.Sequential()
model:add(nn.Linear(sz, 10))
model:add(nn.LogSoftMax())

-- Negative log-likelihood, summed (not averaged) over the batch.
criterion = nn.ClassNLLCriterion()
criterion.sizeAverage = false

-- Flat views of all weights and gradients; start from zero weights.
params, grad_params = model:getParameters()
params:zero()

-- SGD hyper-parameters, mini-batch size, and epoch counter.
config = {learningRate = 1e-5, momentum = 0.9}
btchsz = 100
epoch = 1
-- Run one epoch of mini-batch SGD over the training set.
-- Reads/writes globals: trainset, model, criterion, params,
-- grad_params, config, btchsz, epoch; reports per-epoch training
-- stats via tst().
function train()
   epoch = epoch or 1
   print("Epoch:", epoch)
   for t = 1, trainset:size(), btchsz do
      -- Fix: clamp the batch end so a final partial batch does not
      -- index past the dataset when trainset:size() is not a multiple
      -- of btchsz (identical behavior for the 60000/100 case).
      local last = math.min(t + btchsz - 1, trainset:size())
      -- Closure evaluated by optim.sgd: returns loss and gradient at x.
      local feval = function(x)
         if x ~= params then
            params:copy(x)
         end
         grad_params:zero()
         local inputs = trainset.data[{{t, last}, {}}]
         local targets = trainset.labels[{{t, last}}]
         local output = model:forward(inputs)
         -- Summed NLL over the batch (criterion.sizeAverage == false).
         local err = criterion:forward(output, targets)
         local dfdo = criterion:backward(output, targets)
         model:backward(inputs, dfdo)
         -- Average the gradient over the actual batch size.
         grad_params:div(inputs:size(1))
         return err, grad_params
      end
      optim.sgd(feval, params, config)
   end
   -- Report loss and error count on the training set after the epoch.
   tst(trainset)
   print()
   epoch = epoch + 1
end
-- Evaluate the model on a dataset table with .data, .labels, :size().
-- Prints the per-sample NLL and the number of misclassified samples.
-- Fix: 'max'/'idx' were accidental globals in the original; the max
-- values are unused and are now discarded, and idx is a proper local.
function tst(test)
   local output = model:forward(test.data)
   local err = criterion:forward(output, test.labels)
   -- argmax over the class log-probabilities of each row
   local _, idx = output:max(2)
   -- count predictions that differ from the byte labels
   local zo = torch.ne(idx:byte(), test.labels):sum()
   print("NLL per sample:", err/test:size())
   print("Wrong classifications:", zo)
end
-- Train for ten epochs, then report final performance on the test set.
for _ = 1, 10 do
   train()
end
print("Test set evaluation")
tst(test)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
-- Logistic regression on MNIST with Torch (per-sample training variant).
require 'torch'
require 'nn'
require 'optim'

-- Load both t7 archives, then flatten each 32x32 image to a
-- 1024-dimensional double row vector.
trainset = torch.load("mnist.t7/train_32x32.t7", "ascii")
test = torch.load("mnist.t7/test_32x32.t7", "ascii")
trainset.data = trainset.data:resize(60000, 32*32):double()
test.data = test.data:resize(10000, 32*32):double()

-- Number-of-samples accessor for both dataset tables.
function trainset:size() return self.data:size(1) end
function test:size() return self.data:size(1) end

-- Linear (1024 -> 10) + LogSoftMax = multinomial logistic regression.
sz = trainset.data:size(2)
model = nn.Sequential():add(nn.Linear(sz, 10)):add(nn.LogSoftMax())

-- Summed (not averaged) negative log-likelihood criterion.
criterion = nn.ClassNLLCriterion()
criterion.sizeAverage = false

-- Flat parameter/gradient views; initialize all weights to zero.
params, grad_params = model:getParameters()
params:zero()

config = {learningRate = 1e-5, momentum = 0.9}  -- SGD settings
btchsz = 100                                    -- mini-batch size
epoch = 1                                       -- epoch counter
-- One epoch of mini-batch SGD: samples are gathered into Lua tables,
-- processed one at a time inside the optimizer closure, and the
-- accumulated gradient is averaged over the batch.
function train()
   epoch = epoch or 1
   print("Epoch:", epoch)
   local n = trainset:size()
   for first = 1, n, btchsz do
      -- Collect this mini-batch (the last one may be shorter).
      local last = math.min(first + btchsz - 1, n)
      local batch_inputs, batch_targets = {}, {}
      for i = first, last do
         batch_inputs[#batch_inputs + 1] = trainset.data[i]
         batch_targets[#batch_targets + 1] = trainset.labels[i]
      end
      -- Loss/gradient closure evaluated by optim.sgd.
      local feval = function(x)
         if x ~= params then
            params:copy(x)
         end
         grad_params:zero()
         local loss = 0
         for i = 1, #batch_inputs do
            local out = model:forward(batch_inputs[i])
            loss = loss + criterion:forward(out, batch_targets[i])
            local dloss_dout = criterion:backward(out, batch_targets[i])
            model:backward(batch_inputs[i], dloss_dout)
         end
         -- Average the gradient over the batch; loss stays a sum.
         grad_params:div(#batch_inputs)
         return loss, grad_params
      end
      optim.sgd(feval, params, config)
   end
   tst(trainset)
   epoch = epoch + 1
   print()
end
-- Evaluate the model on a dataset table with .data, .labels, :size().
-- Prints the per-sample NLL and the number of misclassified samples.
-- Fix: 'max'/'idx' were accidental globals in the original; the max
-- values are unused and are now discarded, and idx is a proper local.
function tst(test)
   local output = model:forward(test.data)
   local err = criterion:forward(output, test.labels)
   -- argmax over the class log-probabilities of each row
   local _, idx = output:max(2)
   -- count predictions that differ from the byte labels
   local zo = torch.ne(idx:byte(), test.labels):sum()
   print("NLL per sample:", err/test:size())
   print("Wrong classifications:", zo)
end
-- Ten training epochs followed by a final test-set evaluation.
local num_epochs = 10
for i = 1, num_epochs do
   train()
end
print("Test set evaluation")
tst(test)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment