Skip to content

Instantly share code, notes, and snippets.

@joschuck
Created July 14, 2017 16:16
Show Gist options
  • Save joschuck/69b77d86c50f4d71d0cb824a242bf284 to your computer and use it in GitHub Desktop.
Save joschuck/69b77d86c50f4d71d0cb824a242bf284 to your computer and use it in GitHub Desktop.
AdaBoost for Lua with Torch
--- AdaBoost ensemble over user-supplied weak rules.
-- Rules are plain functions `rule(x)` that receive one sample row and return
-- a label in the same domain as `ground_truth` (the demo uses -1 / +1).
-- Rules are added one at a time with `set_rule`; each is weighted by the
-- classic AdaBoost alpha computed from the current sample weights.
Adaboost = {}
Adaboost.__index = Adaboost

-- The weighted error is clamped into [EPSILON, 1 - EPSILON] before alpha is
-- computed. A perfect rule (e == 0) or a perfectly wrong one (e == 1) then
-- gets a large but finite alpha. An infinite alpha would zero every weight
-- and make the normalisation below produce NaN.
local EPSILON = 1e-10

--- Sign of a number: -1, 0 or 1 (NaN maps to 0).
local function sign(x)
  if x < 0 then
    return -1
  elseif x > 0 then
    return 1
  end
  return 0
end

--- Weighted training error of `func` under `self.weights`.
-- @return e (sum of weights of misclassified samples) and a per-sample
--   array whose entry i is true when sample i is misclassified
local function weighted_error(self, func)
  local missed = {}
  local e = 0
  for i = 1, self.N do
    local wrong = self.ground_truth[i] ~= func(self.training_set[i])
    missed[i] = wrong
    if wrong then
      e = e + self.weights[i]
    end
  end
  return e, missed
end

--- Create a model bound to a training set, with uniform sample weights.
-- @param training_set samples, one per row (needs `:size()[1]` and `[i]`)
-- @param ground_truth labels indexable 1..N
-- @return a new Adaboost instance
function Adaboost:create(training_set, ground_truth)
  local n = training_set:size()[1]
  local adaboost = setmetatable({}, Adaboost)
  adaboost.training_set = training_set
  adaboost.ground_truth = ground_truth
  adaboost.N = n
  adaboost.weights = torch.ones(n):div(n)
  adaboost.RULES = {}
  adaboost.ALPHA = {}
  return adaboost
end

--- Add a weak rule to the ensemble, or only measure it.
-- @param func rule(x) returning a label for sample row x
-- @param test when truthy, return the weighted error and change nothing
-- @return the (unclamped) weighted error when `test` is truthy
function Adaboost:set_rule(func, test)
  local e, missed = weighted_error(self, func)
  if test then
    return e
  end
  local clamped = math.min(math.max(e, EPSILON), 1 - EPSILON)
  local alpha = 0.5 * math.log((1 - clamped) / clamped)
  -- Misclassified samples are scaled by exp(alpha) and correct ones by
  -- exp(-alpha). When e > 0.5, alpha is negative, so the rule's vote is
  -- effectively inverted.
  local grow, shrink = math.exp(alpha), math.exp(-alpha)
  local w = torch.zeros(self.N)
  for i = 1, self.N do
    w[i] = self.weights[i] * (missed[i] and grow or shrink)
  end
  self.weights = w:div(w:sum())
  self.RULES[#self.RULES + 1] = func
  self.ALPHA[#self.ALPHA + 1] = alpha
end

--- Weighted vote of all registered rules for one sample.
-- @return the sign of the vote (-1, 0 or 1) and the raw score
function Adaboost:predict(x)
  local score = 0
  for j = 1, #self.RULES do
    score = score + self.ALPHA[j] * self.RULES[j](x)
  end
  return sign(score), score
end

--- Print, for each sample, whether the ensemble's prediction matches its label.
-- @param data optional samples (defaults to the training set)
-- @param labels optional labels for `data` (defaults to the ground truth)
-- @return number of correctly classified samples, number of samples
function Adaboost:evaluate(data, labels)
  assert((data == nil) == (labels == nil),
    "evaluate: pass both data and labels, or neither")
  data = data or self.training_set
  labels = labels or self.ground_truth
  local total = data:size()[1]
  local correct = 0
  for i = 1, total do
    -- `predict` returns two values; `==` keeps only the first (the sign).
    local ok = sign(labels[i]) == self:predict(data[i])
    if ok then
      correct = correct + 1
    end
    print(i, ok)
  end
  return correct, total
end
-- Toy 2-D problem, written as { x1, x2, label } triples: the first five
-- points are labelled +1, the last five -1.
local training_set, ground_truth
do
  local samples = {
    { 1,   2,    1 },
    { 1,   4,    1 },
    { 2.5, 5.5,  1 },
    { 3.5, 6.5,  1 },
    { 4,   5.4,  1 },
    { 2,   1,   -1 },
    { 2,   4,   -1 },
    { 3.5, 3.5, -1 },
    { 5,   2,   -1 },
    { 5,   5.5, -1 },
  }
  local points, labels = {}, {}
  for i, s in ipairs(samples) do
    points[i] = { s[1], s[2] }
    labels[i] = s[3]
  end
  training_set = torch.Tensor(points)
  ground_truth = torch.Tensor(labels)
end
--- Convert a Lua truth value to a number: 1 when truthy, 0 for nil/false.
function bool_to_num(bool)
  return bool and 1 or 0
end
--- Sign of a number: 1 if positive, -1 if negative, otherwise 0
-- (this also covers zero and NaN).
function math.sign(x)
  if x > 0 then return 1 end
  if x < 0 then return -1 end
  return 0
end
-- Demo: boost three axis-aligned threshold rules (each returns +1 or -1)
-- and print the per-sample training result.
local m = Adaboost:create(training_set, ground_truth)
local stumps = {
  function (x) return 2 * bool_to_num(x[1] < 1.5) - 1 end,
  function (x) return 2 * bool_to_num(x[1] < 4.5) - 1 end,
  function (x) return 2 * bool_to_num(x[2] > 5) - 1 end,
}
for _, stump in ipairs(stumps) do
  m:set_rule(stump)
end
m:evaluate()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment