See test.lua for verification that the new code works.
lbfgs_new.lua
--[[ An implementation of L-BFGS, heavily inspired by minFunc (Mark Schmidt)

This implementation of L-BFGS relies on a user-provided line
search function (state.lineSearch). If this function is not
provided, then a simple learningRate is used to produce fixed
size steps. Fixed size steps are much less costly than line
searches, and can be useful for stochastic problems.

The learning rate is used even when a line search is provided.
This is also useful for large-scale stochastic problems, where
opfunc is a noisy approximation of f(x). In that case, the learning
rate allows a reduction of confidence in the step size.

ARGS:
- `opfunc` : a function that takes a single input (X), the point of
             evaluation, and returns f(X) and df/dX
- `x` : the initial point
- `state` : a table describing the state of the optimizer; after each
            call the state is modified
- `state.maxIter` : Maximum number of iterations allowed (1 skips end-of-loop checks)
- `state.maxEval` : Maximum number of function evaluations
- `state.tolFun` : Termination tolerance on the first-order optimality
- `state.tolX` : Termination tol on progress in terms of func/param changes
- `state.lineSearch` : A line search function
- `state.learningRate` : If no line search provided, then a fixed step size is used

RETURN:
- `x*` : the new `x` vector, at the optimal point
- `f` : a table of all function values:
  `f[1]` is the value of the function before any optimization and
  `f[#f]` is the final fully optimized value, at `x*`

(Clement Farabet, 2012)
]]
function optim.lbfgs2(opfunc, x, config, state)
   -- get/update state
   local config = config or {}
   local state = state or config
   local maxIter = tonumber(config.maxIter) or 20
   local maxEval = tonumber(config.maxEval) or maxIter*1.25
   local tolFun = config.tolFun or 1e-5
   local tolX = config.tolX or 1e-9
   local nCorrection = config.nCorrection or 100
   local lineSearch = config.lineSearch
   local lineSearchOpts = config.lineSearchOptions
   local learningRate = config.learningRate or 1
   local isverbose = config.verbose or false
   state.funcEval = state.funcEval or 0
   state.nIter = state.nIter or 0
   -- verbose function
   local verbose
   if isverbose then
      verbose = function(...) print('<optim.lbfgs> ', ...) end
   else
      verbose = function() end
   end
   -- import some functions
   local abs = math.abs
   local min = math.min
   -- evaluate initial f(x) and df/dx
   local f,g = opfunc(x)
   local f_hist = {f}
   local currentFuncEval = 1
   state.funcEval = state.funcEval + 1
   local p = g:size(1)
   -- check optimality of initial point
   state.tmp1 = state.tmp1 or g.new(g:size()):zero(); local tmp1 = state.tmp1
   tmp1:copy(g):abs()
   if tmp1:sum() <= tolFun then
      -- optimality condition below tolFun
      verbose('optimality condition below tolFun')
      return x,f_hist
   end
   -- reusable buffers for y's and s's, and their histories
   state.dir_bufs = state.dir_bufs or g.new(nCorrection+1, p):split(1)
   state.stp_bufs = state.stp_bufs or g.new(nCorrection+1, p):split(1)
   for i=1,#state.dir_bufs do
      state.dir_bufs[i] = state.dir_bufs[i]:squeeze(1)
      state.stp_bufs[i] = state.stp_bufs[i]:squeeze(1)
   end
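   -- Design note: the original optim.lbfgs (below) creates fresh temporaries on
   -- every iteration (clones of y, s, g_old and rows of full (nCorrection+1) x p
   -- q/r work matrices) and calls collectgarbage() after each memory update.
   -- Here all per-pair storage is preallocated once and recycled through the
   -- dir_bufs/stp_bufs pools, so steady-state iterations avoid new allocations.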
   -- variables cached in state (for tracing)
   local d = state.d
   local t = state.t
   local old_dirs = state.old_dirs
   local old_stps = state.old_stps
   local Hdiag = state.Hdiag
   local g_old = state.g_old
   local f_old = state.f_old
   -- optimize for a max of maxIter iterations
   local nIter = 0
   while nIter < maxIter do
      -- keep track of nb of iterations
      nIter = nIter + 1
      state.nIter = state.nIter + 1
      ------------------------------------------------------------
      -- compute gradient descent direction
      ------------------------------------------------------------
      if state.nIter == 1 then
         d = g:clone():mul(-1) -- -g
         old_dirs = {}
         old_stps = {}
         Hdiag = 1
      else
         -- do lbfgs update (update memory)
         local y = table.remove(state.dir_bufs) -- pop
         local s = table.remove(state.stp_bufs)
         y:copy(g):add(-1, g_old) -- g - g_old
         s:copy(d):mul(t) -- d*t
         local ys = y:dot(s) -- y*s
         if ys > 1e-10 then
            -- updating memory
            if #old_dirs == nCorrection then
               -- shift history by one (limited-memory)
               local removed1 = table.remove(old_dirs, 1)
               local removed2 = table.remove(old_stps, 1)
               table.insert(state.dir_bufs, removed1)
               table.insert(state.stp_bufs, removed2)
            end
            -- store new direction/step
            table.insert(old_dirs, s)
            table.insert(old_stps, y)
            -- update scale of initial Hessian approximation
            Hdiag = ys / y:dot(y) -- (y*y)
         else
            -- put y and s back into the buffer pool
            table.insert(state.dir_bufs, y)
            table.insert(state.stp_bufs, s)
         end
         -- compute the approximate (L-BFGS) inverse Hessian
         -- multiplied by the gradient
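         -- Standard L-BFGS two-loop recursion (Nocedal & Wright, Alg. 7.4),
         -- spelled out for reference. With s_i = x_{i+1}-x_i (stored in
         -- old_dirs), y_i = g_{i+1}-g_i (stored in old_stps) and
         -- ro_i = 1/(y_i . s_i):
         --   q <- -g
         --   for i = k..1:  al_i = ro_i*(s_i . q);  q <- q - al_i*y_i
         --   r <- Hdiag*q,   where Hdiag = (y_k . s_k)/(y_k . y_k)
         --   for i = 1..k:  be_i = ro_i*(y_i . r);  r <- r + (al_i - be_i)*s_i
         -- On exit r = -H*g, the quasi-Newton search direction.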
         local k = #old_dirs -- number of stored correction pairs (at most nCorrection)
         -- need to be accessed element-by-element, so don't re-type tensor:
         state.ro = state.ro or torch.Tensor(nCorrection); local ro = state.ro
         for i = 1,k do
            ro[i] = 1 / old_stps[i]:dot(old_dirs[i])
         end
         -- iteration in L-BFGS loop collapsed to use just one buffer
         local q = tmp1 -- reuse tmp1 for the q buffer
         -- need to be accessed element-by-element, so don't re-type tensor:
         state.al = state.al or torch.zeros(nCorrection); local al = state.al
         q:copy(g):mul(-1) -- -g
         for i = k,1,-1 do
            al[i] = old_dirs[i]:dot(q) * ro[i]
            q:add(-al[i], old_stps[i])
         end
         -- multiply by initial Hessian
         local r = d -- share the same buffer, since we don't need the old d
         r:copy(q):mul(Hdiag) -- q * Hdiag
         for i = 1,k do
            local be_i = old_stps[i]:dot(r) * ro[i]
            r:add(al[i]-be_i, old_dirs[i])
         end
         -- final direction is in r/d (same object)
      end
      g_old = g_old or g:clone()
      g_old:copy(g)
      f_old = f
      ------------------------------------------------------------
      -- compute step length
      ------------------------------------------------------------
      -- directional derivative
      local gtd = g:dot(d) -- g * d
      -- check that progress can be made along that direction
      if gtd > -tolX then
         break
      end
      -- reset initial guess for step size
      if state.nIter == 1 then
         tmp1:copy(g):abs()
         t = min(1,1/tmp1:sum()) * learningRate
      else
         t = learningRate
      end
      -- optional line search: user function
      local lsFuncEval = 0
      if lineSearch and type(lineSearch) == 'function' then
         -- perform line search, using user function
         f,g,x,t,lsFuncEval = lineSearch(opfunc,x,t,d,f,g,gtd,lineSearchOpts)
         table.insert(f_hist, f)
      else
         -- no line search, simply move with fixed-step
         x:add(t,d)
         if nIter ~= maxIter then
            -- re-evaluate function only if not in last iteration
            -- the reason we do this: in a stochastic setting,
            -- no use to re-evaluate that function here
            f,g = opfunc(x)
            lsFuncEval = 1
            table.insert(f_hist, f)
         end
      end
      -- update func eval
      currentFuncEval = currentFuncEval + lsFuncEval
      state.funcEval = state.funcEval + lsFuncEval
      ------------------------------------------------------------
      -- check conditions
      ------------------------------------------------------------
      if nIter == maxIter then
         -- no use to run tests
         verbose('reached max number of iterations')
         break
      end
      if currentFuncEval >= maxEval then
         -- max nb of function evals
         verbose('max nb of function evals')
         break
      end
      tmp1:copy(g):abs()
      if tmp1:sum() <= tolFun then
         -- check optimality
         verbose('optimality condition below tolFun')
         break
      end
      tmp1:copy(d):mul(t):abs()
      if tmp1:sum() <= tolX then
         -- step size below tolX
         verbose('step size below tolX')
         break
      end
      if abs(f-f_old) < tolX then
         -- function value changing less than tolX
         verbose('function value changing less than tolX')
         break
      end
   end
   -- save state
   state.old_dirs = old_dirs
   state.old_stps = old_stps
   state.Hdiag = Hdiag
   state.g_old = g_old
   state.f_old = f_old
   state.t = t
   state.d = d
   -- return optimal x, and history of f(x)
   return x,f_hist,currentFuncEval
end
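-- Minimal usage sketch (illustration only, not part of the implementation):
-- minimize a toy quadratic f(x) = 0.5*||x - c||^2 with the fixed-step variant.
-- The objective, dimensions and settings below are made up for the example.
--
--   local c = torch.randn(10)
--   local function feval(x)
--      local diff = x - c
--      return 0.5 * diff:dot(diff), diff   -- f(x), df/dx
--   end
--   local x = torch.zeros(10)
--   local cfg = {maxIter = 25, nCorrection = 10, learningRate = 1}
--   local xstar, f_hist = optim.lbfgs2(feval, x, cfg)
--   print(f_hist[1], f_hist[#f_hist])      -- initial vs. final objective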
lbfgs_orig.lua
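-- Original optim.lbfgs, kept unmodified here so test.lua can compare it against
-- optim.lbfgs2 above. Unlike the rewrite, it clones y, s and g_old on every
-- iteration, fills rows of full (nCorrection+1) x p work matrices q and r, and
-- calls collectgarbage() after each memory update.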
--[[ An implementation of L-BFGS, heavily inspired by minFunc (Mark Schmidt)

This implementation of L-BFGS relies on a user-provided line
search function (state.lineSearch). If this function is not
provided, then a simple learningRate is used to produce fixed
size steps. Fixed size steps are much less costly than line
searches, and can be useful for stochastic problems.

The learning rate is used even when a line search is provided.
This is also useful for large-scale stochastic problems, where
opfunc is a noisy approximation of f(x). In that case, the learning
rate allows a reduction of confidence in the step size.

ARGS:
- `opfunc` : a function that takes a single input (X), the point of
             evaluation, and returns f(X) and df/dX
- `x` : the initial point
- `state` : a table describing the state of the optimizer; after each
            call the state is modified
- `state.maxIter` : Maximum number of iterations allowed
- `state.maxEval` : Maximum number of function evaluations
- `state.tolFun` : Termination tolerance on the first-order optimality
- `state.tolX` : Termination tol on progress in terms of func/param changes
- `state.lineSearch` : A line search function
- `state.learningRate` : If no line search provided, then a fixed step size is used

RETURN:
- `x*` : the new `x` vector, at the optimal point
- `f` : a table of all function values:
  `f[1]` is the value of the function before any optimization and
  `f[#f]` is the final fully optimized value, at `x*`

(Clement Farabet, 2012)
]]
function optim.lbfgs(opfunc, x, config, state)
   -- get/update state
   local config = config or {}
   local state = state or config
   local maxIter = tonumber(config.maxIter) or 20
   local maxEval = tonumber(config.maxEval) or maxIter*1.25
   local tolFun = config.tolFun or 1e-5
   local tolX = config.tolX or 1e-9
   local nCorrection = config.nCorrection or 100
   local lineSearch = config.lineSearch
   local lineSearchOpts = config.lineSearchOptions
   local learningRate = config.learningRate or 1
   local isverbose = config.verbose or false
   state.funcEval = state.funcEval or 0
   state.nIter = state.nIter or 0
   -- verbose function
   local function verbose(...)
      if isverbose then print('<optim.lbfgs> ', ...) end
   end
   -- import some functions
   local zeros = torch.zeros
   local randn = torch.randn
   local append = table.insert
   local abs = math.abs
   local min = math.min
   -- evaluate initial f(x) and df/dx
   local f,g = opfunc(x)
   local f_hist = {f}
   local currentFuncEval = 1
   state.funcEval = state.funcEval + 1
   -- check optimality of initial point
   state.tmp1 = state.abs_g or zeros(g:size()); local tmp1 = state.tmp1
   tmp1:copy(g):abs()
   if tmp1:sum() <= tolFun then
      -- optimality condition below tolFun
      verbose('optimality condition below tolFun')
      return x,f_hist
   end
   -- variables cached in state (for tracing)
   local d = state.d
   local t = state.t
   local old_dirs = state.old_dirs
   local old_stps = state.old_stps
   local Hdiag = state.Hdiag
   local g_old = state.g_old
   local f_old = state.f_old
   -- optimize for a max of maxIter iterations
   local nIter = 0
   while nIter < maxIter do
      -- keep track of nb of iterations
      nIter = nIter + 1
      state.nIter = state.nIter + 1
      ------------------------------------------------------------
      -- compute gradient descent direction
      ------------------------------------------------------------
      if state.nIter == 1 then
         d = g:clone():mul(-1) -- -g
         old_dirs = {}
         old_stps = {}
         Hdiag = 1
      else
         -- do lbfgs update (update memory)
         local y = g:clone():add(-1, g_old) -- g - g_old
         local s = d:clone():mul(t) -- d*t
         local ys = y:dot(s) -- y*s
         if ys > 1e-10 then
            -- updating memory
            if #old_dirs == nCorrection then
               -- shift history by one (limited-memory)
               local prev_old_dirs = old_dirs
               local prev_old_stps = old_stps
               old_dirs = {}
               old_stps = {}
               for i = 2,#prev_old_dirs do
                  append(old_dirs, prev_old_dirs[i])
                  append(old_stps, prev_old_stps[i])
               end
            end
            -- store new direction/step
            append(old_dirs, s)
            append(old_stps, y)
            -- update scale of initial Hessian approximation
            Hdiag = ys / y:dot(y) -- (y*y)
            -- cleanup
            collectgarbage()
         end
         -- compute the approximate (L-BFGS) inverse Hessian
         -- multiplied by the gradient
         local p = g:size(1)
         local k = #old_dirs
         state.ro = state.ro or zeros(nCorrection); local ro = state.ro
         for i = 1,k do
            ro[i] = 1 / old_stps[i]:dot(old_dirs[i])
         end
         state.q = state.q or zeros(nCorrection+1,p):typeAs(g)
         local q = state.q
         state.r = state.r or zeros(nCorrection+1,p):typeAs(g)
         local r = state.r
         state.al = state.al or zeros(nCorrection):typeAs(g)
         local al = state.al
         state.be = state.be or zeros(nCorrection):typeAs(g)
         local be = state.be
         q[k+1] = g:clone():mul(-1) -- -g
         for i = k,1,-1 do
            al[i] = old_dirs[i]:dot(q[i+1]) * ro[i]
            q[i] = q[i+1]
            q[i]:add(-al[i], old_stps[i])
         end
         -- multiply by initial Hessian
         r[1] = q[1]:clone():mul(Hdiag) -- q[1] * Hdiag
         for i = 1,k do
            be[i] = old_stps[i]:dot(r[i]) * ro[i]
            r[i+1] = r[i]
            r[i+1]:add((al[i] - be[i]), old_dirs[i])
         end
         -- final direction:
         d:copy(r[k+1])
      end
      g_old = g:clone()
      f_old = f
      ------------------------------------------------------------
      -- compute step length
      ------------------------------------------------------------
      -- directional derivative
      local gtd = g:dot(d) -- g * d
      -- check that progress can be made along that direction
      if gtd > -tolX then
         break
      end
      -- reset initial guess for step size
      if state.nIter == 1 then
         tmp1:copy(g):abs()
         t = min(1,1/tmp1:sum()) * learningRate
      else
         t = learningRate
      end
      -- optional line search: user function
      local lsFuncEval = 0
      if lineSearch and type(lineSearch) == 'function' then
         -- perform line search, using user function
         f,g,x,t,lsFuncEval = lineSearch(opfunc,x,t,d,f,g,gtd,lineSearchOpts)
         append(f_hist, f)
      else
         -- no line search, simply move with fixed-step
         x:add(t,d)
         if nIter ~= maxIter then
            -- re-evaluate function only if not in last iteration
            -- the reason we do this: in a stochastic setting,
            -- no use to re-evaluate that function here
            f,g = opfunc(x)
            lsFuncEval = 1
            append(f_hist, f)
         end
      end
      -- update func eval
      currentFuncEval = currentFuncEval + lsFuncEval
      state.funcEval = state.funcEval + lsFuncEval
      ------------------------------------------------------------
      -- check conditions
      ------------------------------------------------------------
      if nIter == maxIter then
         -- no use to run tests
         verbose('reached max number of iterations')
         break
      end
      if currentFuncEval >= maxEval then
         -- max nb of function evals
         verbose('max nb of function evals')
         break
      end
      tmp1:copy(g):abs()
      if tmp1:sum() <= tolFun then
         -- check optimality
         verbose('optimality condition below tolFun')
         break
      end
      tmp1:copy(d):mul(t):abs()
      if tmp1:sum() <= tolX then
         -- step size below tolX
         verbose('step size below tolX')
         break
      end
      if abs(f-f_old) < tolX then
         -- function value changing less than tolX
         verbose('function value changing less than tolX')
         break
      end
   end
   -- save state
   state.old_dirs = old_dirs
   state.old_stps = old_stps
   state.Hdiag = Hdiag
   state.g_old = g_old
   state.f_old = f_old
   state.t = t
   state.d = d
   -- return optimal x, and history of f(x)
   return x,f_hist,currentFuncEval
end
test.lua
--[[
Verifies the correctness of the updated L-BFGS implementation wrt the old one.
To run this test:
  1. put original L-BFGS in lbfgs_orig.lua
  2. put new L-BFGS in lbfgs_new.lua, and rename optim.lbfgs to optim.lbfgs2
     at the top of the file
Alternatively, clone this gist and the files will be in the right place.
]]
optim = {}
require 'lbfgs_orig'
lbfgs_original = optim.lbfgs
require 'lbfgs_new'
lbfgs2 = optim.lbfgs2

local COUNT = 19
local N = 123
torch.manualSeed(1)
local grads = {}
local losses = {}
for i=1,COUNT do
   grads[#grads+1] = torch.randn(N)
   losses[i] = (losses[i-1] or 1) / 2
end

function new_feval()
   local i = 0
   return function()
      i = i + 1
      return losses[i], grads[i]
   end
end
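-- Both optimizers are driven by identical, precomputed (loss, gradient)
-- sequences: the i-th call to the returned feval yields a loss that halves at
-- each step and a fixed random gradient, so any difference in the resulting x
-- can only come from the optimizer internals.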
local x = torch.zeros(N)
local feval = new_feval()
local optim_state = {maxIter=1,nCorrection=6}
local orig = {}
for i=1,COUNT do
   lbfgs_original(feval, x, optim_state)
   orig[#orig + 1] = x:clone()
end

x:zero()
local feval = new_feval()
local optim_state = {maxIter=1,nCorrection=6}
local new = {}
for i=1,COUNT do
   lbfgs2(feval, x, optim_state)
   new[#new + 1] = x:clone()
end

for i=1,COUNT do
   print((orig[i] - new[i]):abs():sum())
end
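-- How to run (a sketch, assuming a Torch7 install with the `th` interpreter on
-- the PATH): `th test.lua` from the gist directory. Each printed value is the
-- L1 distance between the parameter vectors produced by the two implementations
-- after the corresponding update; if the rewrite is correct, every value should
-- be (numerically) zero.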