-- Gist masterdezign/158bab65e4e3df16d0a9355dfd0693a3 (last active November 21, 2018 22:58)
-- One-dimensional gradient descent written in Clash [1],
-- a high-level language that compiles to Verilog and VHDL.
-- Gradient descent is described by the update rule
--
--   a_{n+1} = a_n - gamma * Grad F(a_n),
--
-- where F is the function being minimized and the constant `gamma` is
-- what deep learning calls the learning rate.
--
-- [1] https://clash-lang.org/
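-- Hardware typically performs one update per clock cycle, so one natural
-- shape for the update rule above is a Mealy-style transition function of
-- type s -> i -> (s, o), which (as far as we know) is the shape Clash's
-- `mealy` combinator expects. The sketch below is plain Haskell, not Clash:
-- `descentT` and `simulateMealy` are hypothetical names, feeding gamma as the
-- per-cycle input is an assumed design choice, and the example function
-- F(x) = (x - 2)^2 + 1 used further down is hard-coded.

```haskell
import Data.List (mapAccumL)

-- Mealy-style transition function: state -> input -> (next state, output).
-- State: the current point a_n. Input: the learning rate gamma for this cycle.
-- Output: a_n itself. Assumes F(x) = (x - 2)^2 + 1, so Grad F(x) = 2 * (x - 2).
descentT :: Double -> Double -> (Double, Double)
descentT a gamma = (a - gamma * gradF a, a)
  where
    gradF x = 2 * (x - 2)

-- List-based stand-in for a clocked simulation: one input per "cycle",
-- threading the state through and collecting the outputs in order.
simulateMealy :: (s -> i -> (s, o)) -> s -> [i] -> [o]
simulateMealy f s0 inputs = snd (mapAccumL f s0 inputs)
```

-- With initial state 0 and input `replicate 10 0.2`, `simulateMealy descentT`
-- performs the same arithmetic as `descent1D gradF_test 10 0.2 0.0` and so
-- yields the same ten values; the list merely stands in for a clocked signal.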
{-- Here is how we would run gradient descent in Haskell:

descent1D
  :: (Double -> Double)  -- gradient of F
  -> Int                 -- number of points to return, including x0
  -> Double              -- learning rate gamma
  -> Double              -- starting point x0
  -> [Double]            -- the iterates x0, x1, x2, ...
descent1D gradF iterN gamma x0 = take iterN (iterate step x0)
  where
    -- One update: x_{n+1} = x_n - gamma * Grad F(x_n).
    step x = x - gamma * gradF x
-- Suppose we have the function F(x) = (x - 2)^2 + 1, whose minimum is at x = 2.
-- Its gradient is then Grad F(x) = 2 * (x - 2).
gradF_test :: Double -> Double
gradF_test x = 2 * (x - 2)
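-- A quick way to sanity-check a hand-derived gradient is to compare it with
-- a central finite difference, (F(x + h) - F(x - h)) / (2h). The sketch
-- below assumes h = 1e-5 and a handful of sample points; `numericGrad` and
-- `maxGradError` are hypothetical helper names. Because F is quadratic, the
-- central difference is exact apart from floating-point rounding.

```haskell
-- Example function F(x) = (x - 2)^2 + 1 and its hand-derived gradient.
f :: Double -> Double
f x = (x - 2) ^ (2 :: Int) + 1

gradF :: Double -> Double
gradF x = 2 * (x - 2)

-- Central finite difference (F(x + h) - F(x - h)) / (2h), with an assumed
-- step h = 1e-5.
numericGrad :: (Double -> Double) -> Double -> Double
numericGrad g x = (g (x + h) - g (x - h)) / (2 * h)
  where
    h = 1e-5

-- Largest disagreement between the two over a few sample points.
maxGradError :: Double
maxGradError =
  maximum [abs (numericGrad f x - gradF x) | x <- [-3, 0, 1.5, 2, 4, 10]]
```

-- `maxGradError` should come out many orders of magnitude smaller than the
-- gradient values themselves, which supports Grad F(x) = 2 * (x - 2).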
main :: IO ()
main = do
  let testCase gamma =
        putStrLn $ "Descending with gamma=" ++ show gamma ++ " "
          ++ show (descent1D gradF_test 10 gamma 0.0)
  mapM_ testCase [0.2, 0.6, 1.0]
-- Here is the output:
-- Descending with gamma=0.2 [0.0,0.8,1.28,1.568,1.7408000000000001,1.8444800000000001,1.9066880000000002,1.9440128,1.96640768,1.979844608]
-- Descending with gamma=0.6 [0.0,2.4,1.92,2.016,1.9968,2.00064,1.9998719999999999,2.0000256,1.99999488,2.000001024]
-- Descending with gamma=1.0 [0.0,4.0,0.0,4.0,0.0,4.0,0.0,4.0,0.0,4.0]
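-- These outputs can also be reproduced without iterating. Subtracting the
-- minimum x = 2 from both sides of the update for this F gives
-- a_{n+1} - 2 = (1 - 2 * gamma) * (a_n - 2), hence
-- a_n = 2 + (1 - 2 * gamma)^n * (a_0 - 2). The sketch below compares that
-- closed form with direct iteration from a_0 = 0; `closedForm`, `iterates`
-- and `maxGap` are hypothetical names.

```haskell
-- Closed form for F(x) = (x - 2)^2 + 1:
--   a_n = 2 + (1 - 2 * gamma)^n * (a_0 - 2)
closedForm :: Double -> Double -> Int -> Double
closedForm gamma a0 n = 2 + (1 - 2 * gamma) ^ n * (a0 - 2)

-- The first k points of direct iteration, as descent1D computes them.
iterates :: Double -> Double -> Int -> [Double]
iterates gamma a0 k = take k (iterate step a0)
  where
    step x = x - gamma * (2 * (x - 2))

-- Largest absolute difference between the two over the first k points,
-- starting from a_0 = 0 as in the test cases above.
maxGap :: Double -> Int -> Double
maxGap gamma k =
  maximum [abs (closedForm gamma 0 n - x) | (n, x) <- zip [0 ..] (iterates gamma 0 k)]
```

-- For all three learning rates the gap should be at the level of rounding
-- error; with gamma = 1.0 the factor (1 - 2 * gamma) is -1, so `closedForm`
-- alternates between 0 and 4 exactly as in the third output line.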
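-- With gamma=0.2 the iterates approach the minimum x = 2 steadily, but
-- convergence is relatively slow. With gamma=0.6 there is a slight overshoot,
-- but convergence is much faster. Finally, gamma=1.0 is too large: the
-- iterates oscillate between 0 and 4 and never converge.
-- The reason is that a_{n+1} - 2 = (1 - 2 * gamma) * (a_n - 2), so each update
-- multiplies the signed distance to the minimum by 0.6, -0.2 and -1 respectively
-- (a negative factor is what causes the overshoot). The iteration therefore
-- converges only when |1 - 2 * gamma| < 1, i.e. 0 < gamma < 1.
-}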
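-- To quantify how the three learning rates compare, one option is to stop
-- once the gradient is small, |Grad F(a_n)| < tol, and count the updates.
-- The sketch assumes tol = 1e-6, a cap of 1000 updates and a_0 = 0;
-- `stepsUntil` and `gradExample` are hypothetical helpers, not part of the gist.

```haskell
-- Run x_{n+1} = x_n - gamma * gradF x_n from x0 and return the index of the
-- first point whose gradient magnitude is below tol (= number of updates),
-- or Nothing if that does not happen within maxIter updates.
stepsUntil
  :: (Double -> Double) -- gradient of F
  -> Double             -- tolerance on |Grad F(x)|
  -> Int                -- maximum number of updates
  -> Double             -- learning rate gamma
  -> Double             -- starting point x0
  -> Maybe Int
stepsUntil gradF tol maxIter gamma x0 =
  case [n | (n, x) <- zip [0 .. maxIter] xs, abs (gradF x) < tol] of
    (n : _) -> Just n
    [] -> Nothing
  where
    xs = iterate (\x -> x - gamma * gradF x) x0

-- Gradient of the example F(x) = (x - 2)^2 + 1.
gradExample :: Double -> Double
gradExample x = 2 * (x - 2)
```

-- Since |Grad F(a_n)| = 4 * |1 - 2 * gamma|^n when a_0 = 0, calling
-- `stepsUntil gradExample 1e-6 1000 gamma 0` should give `Just 30` for
-- gamma = 0.2, `Just 10` for gamma = 0.6 and `Nothing` for gamma = 1.0.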