Rewriting haskell-spsa

yanatan16/8560908 (GitHub gist), last active January 4, 2016 03:19
```haskell
import Control.Monad.State (State, modify)  -- mtl

-- | The primary monadic type to be passed around
-- (SPSA, LossFn and StoppingCriteria are defined elsewhere in the package)
type StateSPSA = State SPSA

-- | Set the loss function
setLoss :: LossFn -> StateSPSA ()
setLoss loss = modify (\spsa -> spsa { lossFn = loss })

-- | Push a stopping criterion onto SPSA
pushStopCrit :: StoppingCriteria -> StateSPSA ()
pushStopCrit sc = modify (\spsa -> spsa { stoppingCrits = sc : stoppingCrits spsa })
```
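The builder style above relies on `State` plus record-update syntax: each setter is a `modify` that rewrites one field, and a configuration is simply a sequence of setters run against a default record. Below is a minimal, self-contained sketch of that pattern. It assumes simplified stand-ins for `SPSA`, `LossFn` and `StoppingCriteria` (the real library types carry more fields); the `Tolerance` constructor, the contents of `defaultSPSA`, and the helper `buildSPSA` are hypothetical. It uses `Control.Monad.Trans.State.Strict` from transformers, which ships with GHC.

```haskell
import Control.Monad.Trans.State.Strict (State, execState, modify)

-- Hypothetical, simplified stand-ins for the library's types.
type LossFn = [Double] -> Double

data StoppingCriteria = Iterations Int | Tolerance Double
  deriving (Show, Eq)

data SPSA = SPSA
  { lossFn        :: LossFn
  , stoppingCrits :: [StoppingCriteria]
  }

type StateSPSA = State SPSA

-- Hypothetical starting point that every configuration modifies.
defaultSPSA :: SPSA
defaultSPSA = SPSA { lossFn = const 0, stoppingCrits = [] }

-- Each setter rewrites exactly one field of the record in the state.
setLoss :: LossFn -> StateSPSA ()
setLoss loss = modify (\spsa -> spsa { lossFn = loss })

pushStopCrit :: StoppingCriteria -> StateSPSA ()
pushStopCrit sc = modify (\spsa -> spsa { stoppingCrits = sc : stoppingCrits spsa })

-- Hypothetical helper: run a configuration action and keep the final record.
buildSPSA :: StateSPSA () -> SPSA
buildSPSA builder = execState builder defaultSPSA

exampleSPSA :: SPSA
exampleSPSA = buildSPSA $ do
  setLoss (sum . map (^ (2 :: Int)))
  pushStopCrit (Iterations 100)
  pushStopCrit (Tolerance 1e-6)
```

Because `pushStopCrit` conses onto the list, the most recently pushed criterion ends up first: `stoppingCrits exampleSPSA` is `[Tolerance 1e-6, Iterations 100]`.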
```haskell
-- | Configure SPSA with the new State-based API
newMakeSPSA :: Int -> (Int, Int, Double, Double) -> StateSPSA ()
newMakeSPSA seed (iter, dim, a, c) = do
  setLoss rosenbrock
  pushStopCrit (Iterations iter)
  setPerturbation (bernoulli seed dim)
  semiautomaticTuning iter a c
```
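`bernoulli seed dim` presumably supplies the random perturbation vectors d_k, whose entries are +1 or -1 with equal probability (the usual SPSA choice). A minimal sketch of such a stream follows. It assumes a hand-rolled 64-bit linear congruential generator so that it needs nothing beyond `base`; `bernoulliStream`, `lcgNext` and `toSign` are hypothetical names, and the library's own generator is likely different.

```haskell
import Data.Word (Word64)

-- Hypothetical generator: a 64-bit LCG with Knuth's MMIX constants.
-- Word64 arithmetic wraps modulo 2^64, which is exactly what an LCG needs.
lcgNext :: Word64 -> Word64
lcgNext s = 6364136223846793005 * s + 1442695040888963407

-- Map one generator state to +1 or -1 using its top bit
-- (the low bits of a power-of-two-modulus LCG are poorly distributed).
toSign :: Word64 -> Double
toSign s = if s >= 2 ^ (63 :: Int) then 1 else -1

-- An infinite, seed-deterministic stream of perturbation vectors,
-- each with `dim` entries drawn from {+1, -1}.
bernoulliStream :: Int -> Int -> [[Double]]
bernoulliStream seed dim
  | dim <= 0  = repeat []
  | otherwise = chunk (map toSign (drop 1 (iterate lcgNext (fromIntegral seed))))
  where
    chunk xs = let (v, rest) = splitAt dim xs in v : chunk rest
```

Taking the sign from a single high bit keeps every entry exactly +/-1, which matters because the SPSA gradient estimate divides by each d_i.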
```haskell
-- | The original IO-based construction, for comparison with newMakeSPSA
makeSPSA :: IO SPSA
makeSPSA = do
  let (ak, ck) = semiautomaticTuning 0.0001 0.05
  mkUnconstrainedSPSA rosenbrock ak ck 10
```
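In the older API, `semiautomaticTuning 0.0001 0.05` returns a pair of gain sequences `(ak, ck)`. A common way to build such sequences is Spall's standard form a_k = a / (A + k + 1)^alpha and c_k = c / (k + 1)^gamma, with alpha = 0.602 and gamma = 0.101. The sketch below assumes that form and sets the stability constant A to 10% of the iteration budget (one of Spall's guidelines; `newMakeSPSA` passing `iter` is consistent with this, but that is a guess). Whether haskell-spsa uses exactly these constants is an assumption, and `spallGains` is a hypothetical name.

```haskell
-- Spall's commonly recommended decay exponents (assumed, not taken from the library).
alpha, gamma :: Double
alpha = 0.602
gamma = 0.101

-- | Infinite gain sequences (a_k, c_k) for k = 0, 1, 2, ...
-- The stability constant A is 10% of the iteration budget (an assumption).
spallGains :: Int -> Double -> Double -> ([Double], [Double])
spallGains iters a c = (map ak ks, map ck ks)
  where
    ks   = [0 ..] :: [Double]
    bigA = 0.1 * fromIntegral iters
    ak k = a / (bigA + k + 1) ** alpha
    ck k = c / (k + 1) ** gamma
```

With a 1000-iteration budget, A = 100, so the first step size is a / 101^0.602, roughly a/16; the offset keeps early steps from being too large while still letting a_k decay slowly.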
```haskell
-- | Run the SPSA iteration from t until a stopping criterion is met,
-- returning the final parameter estimate
runSPSA' :: Vector Double -> StateSPSA (Vector Double)
runSPSA' t = do
  t' <- singleIteration t
  stop <- checkStop t t'
  incrementIteration
  if stop then return t' else runSPSA' t'

-- | Perform a single iteration of SPSA
singleIteration :: Vector Double -> StateSPSA (Vector Double)
singleIteration t = do
  (a, c, d) <- peelAll              -- next gains a_k, c_k and perturbation d_k
  lossF <- getLoss
  constrainF <- getConstraint
  let cd = c `scale` d
  let ya = lossF (t + cd)
  let yb = lossF (t - cd)
  -- gradient estimate: g_i = (ya - yb) / (2 * c * d_i)
  let grad = ((ya - yb) / 2) `scaleRecip` cd
  return $ constrainF (t - (a `scale` grad))
```
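`singleIteration` is the heart of SPSA: given a perturbation d and gains a and c, it evaluates the loss only twice, at t + c*d and t - c*d, and estimates every gradient coordinate at once as g_i = (ya - yb) / (2 * c * d_i). That is what ``((ya - yb) / 2) `scaleRecip` cd`` computes, since hmatrix's `scaleRecip` divides the scalar by each vector entry. The sketch below reproduces one step with plain lists standing in for `Vector Double`, and gives the standard n-dimensional Rosenbrock function as a stand-in for the library's `rosenbrock`. `spsaStep` and `Vec` are hypothetical names, and the gains and perturbation are passed as arguments rather than peeled from the state.

```haskell
-- Plain lists stand in for hmatrix's Vector Double.
type Vec = [Double]

-- Standard n-dimensional Rosenbrock function; its minimum is 0 at (1, ..., 1).
rosenbrock :: Vec -> Double
rosenbrock xs = sum (zipWith term xs (drop 1 xs))
  where term x y = 100 * (y - x * x) ^ (2 :: Int) + (1 - x) ^ (2 :: Int)

-- | One SPSA step, mirroring singleIteration with explicit gains and perturbation.
spsaStep :: (Vec -> Double)  -- loss
         -> (Vec -> Vec)     -- constraint / projection (id if unconstrained)
         -> Double           -- a_k
         -> Double           -- c_k
         -> Vec              -- perturbation d_k, entries +/-1
         -> Vec              -- current estimate t
         -> Vec
spsaStep loss constrain a c d t = constrain (zipWith (-) t (map (a *) grad))
  where
    cd   = map (c *) d
    ya   = loss (zipWith (+) t cd)
    yb   = loss (zipWith (-) t cd)
    -- same as ((ya - yb) / 2) `scaleRecip` cd: divide the scalar by each entry
    grad = map (((ya - yb) / 2) /) cd
```

For example, with the loss `sum . map (^ 2)`, t = [1, 2], a = c = 0.1 and d = [1, -1], the two evaluations are 4.82 and 5.22, the estimate is [-2, 2], and the step returns approximately [1.2, 1.8]. The true gradient is [2, 4]: a single SPSA estimate is noisy and is correct only in expectation over random d.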
```haskell
-- | Exported entry point: run the configuration action on defaultSPSA,
-- check the resulting setup against t0, then optimize starting from t0
runSPSA :: StateSPSA a -> Vector Double -> Vector Double
runSPSA st t0 = evalState checkAndRunSpsa defaultSPSA
  where checkAndRunSpsa = st >> checkSPSA t0 >> runSPSA' t0
```
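`runSPSA` separates configuration from execution: the caller's `StateSPSA a` action modifies `defaultSPSA`, and the same state thread then runs the iteration loop, so the result is a pure function of the configuration and `t0`. The sketch below imitates that shape end to end with a hypothetical `Opt` record, fixed gains, deterministic perturbations, and an iteration limit as the only stopping rule. All names (`Opt`, `peel`, `step`, `loop`, `runOpt`) are illustrative, not the library's API.

```haskell
import Control.Monad.Trans.State.Strict (State, evalState, get, gets, modify, put)

type Vec = [Double]

-- Hypothetical optimizer state; field names are illustrative only.
data Opt = Opt
  { loss     :: Vec -> Double
  , gains    :: [(Double, Double)]   -- remaining (a_k, c_k) pairs
  , perturbs :: [Vec]                -- remaining perturbation vectors d_k
  , maxIter  :: Int
  , iter     :: Int
  }

defaultOpt :: Opt
defaultOpt = Opt { loss = const 0, gains = [], perturbs = [], maxIter = 0, iter = 0 }

-- Pop the next (a_k, c_k, d_k), playing the role of peelAll.
peel :: State Opt (Double, Double, Vec)
peel = do
  o <- get
  case (gains o, perturbs o) of
    ((a, c) : gs, d : ds) -> do
      put o { gains = gs, perturbs = ds }
      return (a, c, d)
    _ -> error "peel: ran out of gains or perturbations"

-- One SPSA update using two loss evaluations.
step :: Vec -> State Opt Vec
step t = do
  (a, c, d) <- peel
  f <- gets loss
  let cd   = map (c *) d
      g    = (f (zipWith (+) t cd) - f (zipWith (-) t cd)) / 2
      grad = map (g /) cd
  return (zipWith (-) t (map (a *) grad))

-- Like runSPSA', always take one step, then decide whether to stop.
loop :: Vec -> State Opt Vec
loop t = do
  t' <- step t
  modify (\o -> o { iter = iter o + 1 })
  done <- gets (\o -> iter o >= maxIter o)
  if done then return t' else loop t'

-- Same shape as runSPSA: configure the default state, then iterate from t0.
runOpt :: State Opt a -> Vec -> Vec
runOpt config t0 = evalState (config >> loop t0) defaultOpt

example :: Vec
example = runOpt config [3, -2]
  where
    config = modify (\o -> o { loss     = sum . map (^ (2 :: Int))
                             , gains    = repeat (0.1, 0.01)
                             , perturbs = cycle [[1, 1], [1, -1]]
                             , maxIter  = 40 })
```

On this quadratic, each step along d = [1, 1] or [1, -1] shrinks the component of t in that direction by a factor of 0.6 and leaves the orthogonal component unchanged, so 40 iterations bring [3, -2] to within roughly 1e-4 of the origin.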