Skip to content

Instantly share code, notes, and snippets.

@yanatan16
Last active January 4, 2016 03:19
Show Gist options
  • Save yanatan16/8560908 to your computer and use it in GitHub Desktop.
Save yanatan16/8560908 to your computer and use it in GitHub Desktop.
Rewriting haskell-spsa
-- | The primary monadic type threaded through the SPSA API: a 'State'
-- computation whose state is the 'SPSA' record (the setters below update
-- its fields with record-update syntax).
type StateSPSA = State SPSA
-- | Replace the loss function held in the SPSA state.
--
-- Only the 'lossFn' field is overwritten; all other fields are kept.
setLoss :: LossFn -> StateSPSA ()
setLoss newLoss = modify withLoss
  where
    withLoss cfg = cfg { lossFn = newLoss }
-- | Add a stopping criterion to the SPSA state.
--
-- The criterion is consed onto the front of 'stoppingCrits'; any criteria
-- already present are kept behind it.
pushStopCrit :: StoppingCriteria -> StateSPSA ()
pushStopCrit crit = modify prepend
  where
    prepend cfg = cfg { stoppingCrits = crit : stoppingCrits cfg }
-- | Configure SPSA in the monadic style.
--
-- The steps, in order:
--
-- * installs 'rosenbrock' as the loss,
-- * pushes an 'Iterations' stopping criterion with the given count,
-- * sets the perturbation produced by 'bernoulli' from the seed and dimension,
-- * runs 'semiautomaticTuning' with the iteration count and the two tuning
--   constants.
--
-- NOTE(review): 'makeSPSA' in this same file calls 'semiautomaticTuning' as a
-- pure two-argument function returning a pair. Here it is used as a
-- three-argument 'StateSPSA' action. Both cannot type-check in one module;
-- confirm which signature is current.
newMakeSPSA :: Int -> (Int, Int, Double, Double) -> StateSPSA ()
newMakeSPSA seed (iters, dimension, a0, c0) =
  setLoss rosenbrock
    >> pushStopCrit (Iterations iters)
    >> setPerturbation (bernoulli seed dimension)
    >> semiautomaticTuning iters a0 c0
-- | Build an SPSA configuration in the older 'IO' style.
--
-- Gains come from 'semiautomaticTuning' with the constants 0.0001 and 0.05.
-- They are handed, together with 'rosenbrock' as the loss, to
-- 'mkUnconstrainedSPSA'. The trailing literal 10 is presumably the problem
-- dimension; confirm against 'mkUnconstrainedSPSA'.
makeSPSA :: IO SPSA
makeSPSA = mkUnconstrainedSPSA rosenbrock gainA gainC 10
  where
    (gainA, gainC) = semiautomaticTuning 0.0001 0.05
-- | Iterate SPSA from the given parameter vector until the stopping check
-- succeeds, then return the last iterate.
--
-- Each round does three things, in order:
--
-- 1. computes the next iterate with 'singleIteration',
-- 2. asks 'checkStop' about the (previous, next) pair,
-- 3. calls 'incrementIteration'.
--
-- The stop test therefore runs before this round's increment.
runSPSA' :: Vector Double -> StateSPSA (Vector Double)
runSPSA' = go
  where
    go current = do
      candidate <- singleIteration current
      finished <- checkStop current candidate
      incrementIteration
      if finished
        then return candidate
        else go candidate
-- | Perform one SPSA step from the current parameter vector.
--
-- How the step works:
--
-- * Takes this step's gain @a@, perturbation size @c@ and perturbation
--   vector @d@ from 'peelAll', then reads the loss and constraint functions
--   from the state.
-- * Evaluates the loss at @t + c*d@ and @t - c*d@.
-- * Forms the two-sided gradient estimate @g_i = (y+ - y-) / (2 * c * d_i)@
--   with 'scaleRecip'. This assumes hmatrix semantics, where
--   @scaleRecip x v@ divides @x@ by each element of @v@; confirm.
-- * Returns the constrained update @constrain (t - a*g)@.
singleIteration :: Vector Double -> StateSPSA (Vector Double)
singleIteration theta = do
  (gain, width, delta) <- peelAll
  loss <- getLoss
  project <- getConstraint
  let perturb  = width `scale` delta
      halfDiff = (loss (theta + perturb) - loss (theta - perturb)) / 2
      gradEst  = halfDiff `scaleRecip` perturb
  return . project $ theta - (gain `scale` gradEst)
-- | Public entry point: run a configuration action, validate, and optimize.
--
-- Steps:
--
-- * Runs the configuration action starting from 'defaultSPSA'. Its result
--   value is discarded; only its state changes are kept.
-- * Validates the resulting state with 'checkSPSA' for the starting vector.
-- * Iterates from that vector with 'runSPSA'' and returns the final
--   parameters.
runSPSA :: StateSPSA a -> Vector Double -> Vector Double
runSPSA setup theta0 = flip evalState defaultSPSA $ do
  _ <- setup
  checkSPSA theta0
  runSPSA' theta0
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment