--[[ An experimental quasi-Newton optimizer.

Incorporates Hessian damping, momentum, and per-feature learning rate scaling.
Also implements optional polynomial-decay averaging (similar to ASGD).
See the usage sketch after the function definition.

ARGS:
- 'opfunc' : a function that takes a single input (X), the point of
  evaluation, and returns f(X) and df/dX
- 'x' : the initial point
- 'config' : a table with configuration parameters for the optimizer
- 'config.alwaysDivG2' : if false, scale the initial inverse Hessian by
  s'y / y'y when curvature pairs are available instead of always dividing
  by the Adagrad-style scaling matrix
- 'config.averagingDecay' : if >= 0, averaging decay exponent; if < 0,
  disables averaging
- 'config.epsilon' : for numerical stability
- 'config.learningRate' : learning rate
- 'config.momentum' : momentum
- 'config.nCorrection' : the maximum number of L-BFGS corrections
- 'config.phi' : Hessian damping
- 'state' : a table describing the state of the optimizer; modified in place
  on each call (defaults to 'config')

RETURN:
- 'x' : the new x vector
- 'f(x)' : the function, evaluated after the update
- 'average' : the averaged parameter vector

(Katherine Crowson, 2016)
]]
function dmsqn(opfunc, x, config, state)
   -- Configuration
   local config = config or {}
   local state = state or config
   -- Default is true; 'config.alwaysDivG2 or true' would ignore an explicit false
   local always_div_g2 = config.alwaysDivG2 ~= false
   local avg_decay = config.averagingDecay or -1
   local eps = config.epsilon or 1e-8
   local lr = config.learningRate or 1e-4
   local momentum = config.momentum or 0.9
   local nCorrection = config.nCorrection or 10
   local phi = config.phi or 0.2

   -- Initialization
   state.t = state.t or 0
   -- L-BFGS memory
   state.sk = state.sk or {}
   state.yk = state.yk or {}
   -- Gradient first moment accumulator
   state.g1 = state.g1 or x.new(x:size()):zero()
   -- Gradient second moment accumulator
   state.g2 = state.g2 or x.new(x:size()):fill(eps)
   -- Parameter vector first moment accumulator
   state.p1 = state.p1 or x.new(x:size()):zero()
   -- Reusable buffers for s and y
   state.s = state.s or x.new(x:size())
   state.y = state.y or x.new(x:size())
   local s, y = state.s, state.y
   -- Reusable temporary buffer
   state.tmp = state.tmp or x.new(x:size())
   local tmp = state.tmp
   -- Reusable buffers for s's and y's
   if not state.s_bufs then
      state.s_bufs = state.s_bufs or x.new(nCorrection, x:nElement()):split(1)
      state.y_bufs = state.y_bufs or x.new(nCorrection, x:nElement()):split(1)
      for i = 1, #state.s_bufs do
         state.s_bufs[i] = state.s_bufs[i]:squeeze(1)
         state.y_bufs[i] = state.y_bufs[i]:squeeze(1)
      end
   end

   -- First step: set initial state
   if not state.g then
      _, state.g = opfunc(x)
      state.g1:add(state.g)
      state.g2:addcmul(state.g, state.g)
   end

   -- Decay first moment of gradient
   state.g1:mul(momentum)
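
   -- The two-loop recursion below applies the current inverse Hessian
   -- approximation, built from the stored curvature pairs (state.sk, state.yk),
   -- to the momentum-augmented gradient without ever forming the matrix
   -- explicitly. The initial inverse Hessian is diagonal: either the usual
   -- gamma = s'y / y'y scaling or the Adagrad-style 1/sqrt(g2) matrix,
   -- depending on alwaysDivG2.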
   -- Compute step with L-BFGS two-loop recursion
   s:add(state.g1, state.g) -- Nesterov momentum
   local k = #state.sk
   local rho = torch.zeros(nCorrection)
   for i = 1, k do
      rho[i] = 1 / state.sk[i]:dot(state.yk[i])
   end
   local alpha = torch.zeros(nCorrection)
   for i = k, 1, -1 do
      alpha[i] = state.sk[i]:dot(s) * rho[i]
      s:add(-alpha[i], state.yk[i])
   end
   if not always_div_g2 and k > 0 then
      local sy = state.sk[k]:dot(state.yk[k])
      local yy = state.yk[k]:dot(state.yk[k])
      s:mul(sy / yy)
   else
      s:cdiv(tmp:sqrt(state.g2))
   end
   for i = 1, k do
      local beta = state.yk[i]:dot(s) * rho[i]
      s:add(alpha[i] - beta, state.sk[i])
   end
   -- Two-loop recursion done: take step and update moments
   s:mul(-lr)
   --print(state.t, tmp:abs(s):mean())
   x:add(s)
   local fx, g = opfunc(x)
   state.g1:add(g)
   state.g2:addcmul(g, g)
   -- Compute y
   y:add(g, -1, state.g) -- y = new gradient - old gradient
   y:mul(1 - phi):add(phi, s) -- Hessian damping
   y:cmul(tmp:sqrt(state.g2)) -- Scale by Adagrad scaling matrix
   -- Store gradient
   state.g:copy(g)

   -- Store curvature pair
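   -- When the history is full, the oldest (s, y) pair is recycled into the
   -- spare buffer pools rather than freed, so no new tensors are allocated
   -- in the steady state.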
   if #state.sk == nCorrection then
      -- Shift history by one
      local removed_s = table.remove(state.sk, 1)
      local removed_y = table.remove(state.yk, 1)
      table.insert(state.s_bufs, removed_s)
      table.insert(state.y_bufs, removed_y)
   end
   table.insert(state.sk, table.remove(state.s_bufs):copy(s))
   table.insert(state.yk, table.remove(state.y_bufs):copy(y))

   -- Return x*, f(x) after step
   state.t = state.t + 1
   if avg_decay < 0 then
      return x, {fx}, x
   end
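   -- The averaging weight is w_t = (1 + avg_decay) / (t + avg_decay): with
   -- avg_decay = 0 this is the plain running mean (w_t = 1/t), while larger
   -- values weight recent iterates more heavily.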
   -- Polynomial-decay averaging
   local weight = (1 + avg_decay) / (state.t + avg_decay)
   state.p1:mul(1 - weight):add(weight, x)
   return x, {fx}, state.p1
end
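
--[[ Usage sketch (illustrative, not part of the optimizer itself): minimize
a toy quadratic with dmsqn. The objective, tensor sizes, and config values
below are assumptions made for demonstration only; in real training 'opfunc'
would wrap a model's forward/backward pass, as with the closures used by the
'optim' package. ]]

require 'torch'

local target = torch.linspace(1, 10, 10)
local params = torch.randn(10)

-- f(x) = 0.5 * ||x - target||^2, so df/dx = x - target
local function opfunc(x)
   local diff = x - target
   return 0.5 * diff:dot(diff), diff
end

local config = {learningRate = 1e-2, momentum = 0.9, nCorrection = 10,
                phi = 0.2, averagingDecay = 4}
local state = {}

local avg
for i = 1, 200 do
   local fs
   _, fs, avg = dmsqn(opfunc, params, config, state)
   if i % 50 == 0 then
      print(string.format('step %d: f(x) = %.6f', i, fs[1]))
   end
end
-- The third return value is the polynomial-decay average of the iterates
print(string.format('final f(average) = %.6f', (opfunc(avg))))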