Skip to content

Instantly share code, notes, and snippets.

@kputnam
Last active August 29, 2015 14:08
Show Gist options
  • Save kputnam/a904c05f7560d6085a23 to your computer and use it in GitHub Desktop.
Save kputnam/a904c05f7560d6085a23 to your computer and use it in GitHub Desktop.
Gradient Descent Linear Regression
module Linear
where
import Prelude hiding (error)
import Control.Arrow
import Data.List
import Data.Vector (Vector)
import qualified Data.Vector as V
-- | Values of the independent variables for a single observation.
type Observation = Vector Double

-- | Value of the dependent variable for a single observation.
type Response = Double

-- | Coefficients of a linear equation, one per component of an
-- 'Observation'.
type Model = Vector Double

-- | An error measure. Depending on the producer this is a signed residual
-- ('error'), a sum of squared residuals ('update'), or a root-mean-square
-- error ('learn').
type Error = Double

-- | Learning rate (gradient-descent step size).
type Rate = Double

-- | Number of iterations performed.
type Count = Int
-- | Evaluate a linear model on one observation: the dot product of the
-- coefficients with the observation's components. When the two vectors
-- differ in length, the surplus components of the longer one are ignored.
predict :: Model -> Observation -> Response
predict coeffs obs = V.sum products
  where
    products = V.zipWith (*) coeffs obs
-- | Signed residual of a model on one observation: the prediction minus the
-- observed response (positive when the model over-predicts).
error :: Model -> (Observation, Response) -> Error
error model (obs, actual) = subtract actual (predict model obs)
-- Batch

-- | Fit a linear model to @xs@ by batch gradient descent, starting from the
-- all-zeros model.
--
-- Each iteration applies one 'update' step that uses every observation.
-- Iteration stops once the root-mean-square (RMS) error fails to improve by
-- more than @tolerance@. The result is the final model, the RMS error of that
-- same model over @xs@, and the number of 'update' steps that produced it.
--
-- The number of coefficients is taken from the first observation. An empty
-- data set yields @('V.empty', 0, 0)@: there is nothing to fit.
learn :: Rate -> [(Observation, Response)] -> (Model, Error, Count)
learn _ [] = (V.empty, 0, 0)
learn rate xs@((x0, _) : _) = aux 0 (second rmse $ update rate xs zero)
  where
    zero = fmap (const 0) x0
    -- Computed once and shared by every iteration.
    n = genericLength xs
    -- 'update' reports a sum of squared residuals; scale it to an RMS error.
    rmse sse = sqrt (sse / n)
    -- Warning: this magic constant should be a function of the learning
    -- rate. Iteration stops when the error doesn't improve by more than
    -- this amount, which can also happen when the learning rate is too low!
    tolerance = 0.00000001
    -- Invariant: m is the model after k + 1 steps, and e is the RMS error of
    -- the model one step earlier.
    aux k (m, e)
      -- A plain '>' (rather than 'compare ... == GT') makes a NaN
      -- improvement stop the loop instead of recursing forever.
      | e - e' > tolerance = let k' = k + 1 in k' `seq` aux k' (m', e')
      -- e' is the RMS error of m itself, so m (not m') is reported with it.
      | otherwise = (m, e', k + 1)
      where
        (m', e') = second rmse $ update rate xs m
-- | One batch gradient-descent step.
--
-- Returns the stepped model @m - rate * g@, where @g@ is the sum over all
-- observations of @e * x@ with @e = error m (x, y)@. Alongside it, returns
-- the sum of squared residuals of the /input/ model @m@ (not of the stepped
-- model).
--
-- @g@ is a plain sum rather than a mean, so the effective step size grows
-- with the number of observations.
update :: Rate -> [(Observation, Response)] -> Model -> (Model, Error)
update rate xs m = (V.zipWith (-) m (fmap (rate *) grad), sse)
  where
    zeros = fmap (const 0) m
    -- grad := sum of e * x-vector; sse := sum of e^2, both over all of xs
    (grad, sse) = foldr step (zeros, 0) xs
    step (x, y) = V.zipWith (+) (fmap (e *) x) *** (e * e +)
      where
        e = error m (x, y)
-- Stochastic

-- | Fit a linear model to @xs@ by stochastic gradient descent, starting from
-- the all-zeros model.
--
-- One epoch applies 'update_' once per observation, in list order. Epochs
-- repeat until the root-mean-square error over @xs@ fails to improve by more
-- than @tolerance@ (the same constant the batch learner uses). Of the last
-- two models, the one with the lower error is returned. An empty data set
-- yields 'V.empty'.
learn_ :: Rate -> [(Observation, Response)] -> Model
learn_ _ [] = V.empty
learn_ rate xs@((x0, _) : _) = go zero (rmse zero)
  where
    zero = fmap (const 0) x0
    n = genericLength xs
    tolerance = 0.00000001
    -- Root-mean-square residual of model m over the whole data set.
    rmse m = sqrt (foldl' (\acc o -> let r = error m o in acc + r * r) 0 xs / n)
    -- One pass of per-observation updates.
    epoch m = foldl' (flip (update_ rate)) m xs
    go m e
      -- A plain '>' also stops the loop when the improvement is NaN.
      | e - e' > tolerance = go m' e'
      | e' <= e = m'
      | otherwise = m
      where
        m' = epoch m
        e' = rmse m'
-- | One stochastic gradient-descent step using a single observation. Each
-- coefficient @m_i@ becomes @m_i - rate * e * x_i@, where @e@ is the residual
-- 'error' of @m@ on that observation.
update_ :: Rate -> (Observation, Response) -> Model -> Model
update_ rate obs@(x, _) m = V.zipWith adjust m x
  where
    step = rate * error m obs
    adjust mi xi = mi - step * xi
-- | The four data sets of Anscombe's quartet. Each @x@ is encoded as the
-- observation @[1, x]@, so the first coefficient of a fitted 'Model' acts as
-- the intercept and the second as the slope.
anscombe1, anscombe2, anscombe3, anscombe4 :: [(Observation, Response)]
anscombe1 =
  map (first (\x -> V.fromList [1, x]))
    [ (10, 8.04), (8, 6.95), (13, 7.58), (9, 8.81), (11, 8.33), (14, 9.96)
    , (6, 7.24), (4, 4.26), (12, 10.84), (7, 4.82), (5, 5.68)
    ]
anscombe2 =
  map (first (\x -> V.fromList [1, x]))
    [ (10, 9.14), (8, 8.14), (13, 8.74), (9, 8.77), (11, 9.26), (14, 8.10)
    , (6, 6.13), (4, 3.10), (12, 9.13), (7, 7.26), (5, 4.74)
    ]
anscombe3 =
  map (first (\x -> V.fromList [1, x]))
    [ (10, 7.46), (8, 6.77), (13, 12.74), (9, 7.11), (11, 7.81), (14, 8.84)
    , (6, 6.08), (4, 5.39), (12, 8.15), (7, 6.42), (5, 5.73)
    ]
anscombe4 =
  map (first (\x -> V.fromList [1, x]))
    [ (8, 6.58), (8, 5.76), (8, 7.71), (8, 8.84), (8, 8.47), (8, 7.04)
    , (8, 5.25), (19, 12.50), (8, 5.56), (8, 7.91), (8, 6.89)
    ]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment