Last active
August 29, 2015 14:08
-
-
Save kputnam/a904c05f7560d6085a23 to your computer and use it in GitHub Desktop.
Gradient Descent Linear Regression
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
module Linear | |
where | |
import Prelude hiding (error) | |
import Control.Arrow | |
import Data.List | |
import Data.Vector (Vector) | |
import qualified Data.Vector as V | |
-- | Independent variables from a single observation
type Observation = Vector Double

-- | Dependent variable
type Response = Double

-- | Vector of coefficients to a linear equation
type Model = Vector Double

-- | Difference between a predicted and an observed response, or an
-- aggregate of such differences over several observations
type Error = Double

-- | Learning rate
type Rate = Double

-- | Number of iterations performed
type Count = Int
-- | Predicted response for an observation: the coefficients of the model
-- are multiplied pairwise with the observation's independent variables and
-- the products are summed.
predict :: Model -> Observation -> Response
predict coefficients observation = V.sum products
  where
    products = V.zipWith (*) coefficients observation
-- | Signed residual of a prediction: the response predicted by the model
-- minus the response that was actually observed.
error :: Model -> (Observation, Response) -> Error
error model (observation, observed) =
  subtract observed (predict model observation)
-- Batch

-- | Fit a linear model by batch gradient descent.
--
-- Starting from an all-zero model with one coefficient per independent
-- variable of the first observation, 'update' is applied repeatedly until
-- the root-mean-square error stops improving by more than a fixed
-- tolerance.
--
-- Returns the final model, the root-mean-square error measured during the
-- last 'update' (that is, the error of the model that update started from),
-- and the number of convergence-loop iterations (the initial step away from
-- the zero model is not counted).
--
-- An empty training set, for which no meaningful result could be produced
-- before, yields an empty model with zero error and zero iterations.
learn :: Rate -> [(Observation, Response)] -> (Model, Error, Count)
learn _ [] = (V.empty, 0, 0)
learn rate xs@((x0, _) : _) = go 0 (step zero)
  where
    -- Warning: this tolerance should really be a function of the learning
    -- rate. The loop stops as soon as an iteration improves the error by no
    -- more than this amount, which can also happen prematurely when the
    -- learning rate is too low!
    tolerance = 0.00000001 :: Double

    -- Number of observations, computed once rather than on every iteration.
    n = genericLength xs :: Double

    zero = V.map (const 0) x0

    -- Turn a sum of squared errors into a root-mean-square error.
    rmse sse = sqrt (sse / n)

    -- One gradient-descent step, reporting the error as an RMSE.
    step = second rmse . update rate xs

    go :: Count -> (Model, Error) -> (Model, Error, Count)
    go k (m, e)
      -- A plain (>) is False for NaN, so a NaN error terminates the loop;
      -- 'compare' would report GT for NaN and iterate forever.
      | e - e' > tolerance = k' `seq` go k' (m', e')
      | otherwise = (m', e', k')
      where
        k' = k + 1
        (m', e') = step m
-- | Perform one batch gradient-descent step.
--
-- Returns the adjusted model together with the sum of squared errors that
-- the /given/ model (not the adjusted one) makes over all observations.
update :: Rate -> [(Observation, Response)] -> Model -> (Model, Error)
update rate xs m = (V.zipWith (-) m scaled, squares)
  where
    -- model-vector := model-vector - rate * gradient-vector
    scaled = V.map (rate *) gradient

    (gradient, squares) = foldr accumulate (V.map (const 0) m, 0) xs

    -- gradient-vector := error * x-vector + gradient-vector
    -- The accumulator is matched lazily so that no pair is forced before
    -- one of its components is actually demanded.
    accumulate (x, y) ~(g, s) =
      (V.zipWith (+) (V.map (residual *) x) g, residual * residual + s)
      where
        residual = error m (x, y)
-- Stochastic

-- | Fit a linear model by stochastic gradient descent.
--
-- Makes a single pass over the observations in list order, applying
-- 'update_' once per observation. The starting point is an all-zero model
-- with one coefficient per independent variable of the first observation.
-- An empty training set yields an empty model.
learn_ :: Rate -> [(Observation, Response)] -> Model
learn_ _ [] = V.empty
learn_ rate xs@((x0, _) : _) = foldl' step (V.map (const 0) x0) xs
  where
    -- Fold-friendly argument order for 'update_'.
    step m observation = update_ rate observation m
-- | Perform one stochastic gradient-descent step using a single
-- observation, returning the adjusted model.
update_ :: Rate -> (Observation, Response) -> Model -> Model
update_ rate observation@(x, _) m = V.zipWith adjust m x
  where
    -- model-vector := model-vector - rate * error * x-vector
    scale = rate * error m observation
    adjust coefficient xi = coefficient - scale * xi
-- | Four sample data sets (the values are those of Anscombe's quartet).
-- Every observation is the vector @[1, x]@, so a fitted model has two
-- coefficients: one multiplying the constant 1 and one multiplying @x@.
anscombe1, anscombe2, anscombe3, anscombe4 :: [(Observation, Response)]
anscombe1 =
  [ (V.fromList [1, x], y)
  | (x, y) <-
      [ (10, 8.04), (8, 6.95), (13, 7.58), (9, 8.81), (11, 8.33), (14, 9.96)
      , (6, 7.24), (4, 4.26), (12, 10.84), (7, 4.82), (5, 5.68)
      ]
  ]
anscombe2 =
  [ (V.fromList [1, x], y)
  | (x, y) <-
      [ (10, 9.14), (8, 8.14), (13, 8.74), (9, 8.77), (11, 9.26), (14, 8.10)
      , (6, 6.13), (4, 3.10), (12, 9.13), (7, 7.26), (5, 4.74)
      ]
  ]
anscombe3 =
  [ (V.fromList [1, x], y)
  | (x, y) <-
      [ (10, 7.46), (8, 6.77), (13, 12.74), (9, 7.11), (11, 7.81), (14, 8.84)
      , (6, 6.08), (4, 5.39), (12, 8.15), (7, 6.42), (5, 5.73)
      ]
  ]
anscombe4 =
  [ (V.fromList [1, x], y)
  | (x, y) <-
      [ (8, 6.58), (8, 5.76), (8, 7.71), (8, 8.84), (8, 8.47), (8, 7.04)
      , (8, 5.25), (19, 12.50), (8, 5.56), (8, 7.91), (8, 6.89)
      ]
  ]
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment