Skip to content

Instantly share code, notes, and snippets.

@erutuf
Created December 11, 2015 08:59
Show Gist options
  • Save erutuf/bc0af70577d9f8e37384 to your computer and use it in GitHub Desktop.
Save erutuf/bc0af70577d9f8e37384 to your computer and use it in GitHub Desktop.
module NeuralNetwork where
import Data.List
import System.Random
import Numeric.LinearAlgebra
-- | Step-size factor that scales every weight correction in 'updateWeight'.
--
-- NOTE(review): there is no type signature, so the monomorphism restriction
-- fixes this binding's type from its use site in 'updateWeight', where it is
-- multiplied with the result of @outer@ -- presumably @Matrix Double@,
-- relying on hmatrix's numeric-literal instance for matrices. Confirm before
-- adding an annotation such as @:: Double@, which may not type-check there.
learnCnst = 0.1
-- | Logistic function: @1 / (1 + e^(-x))@.
sigmoid :: Floating a => a -> a
sigmoid x = recip (1 + exp (negate x))
-- | Derivative of 'sigmoid' at @x@, computed as @(1 - s) * s@ where
-- @s = sigmoid x@.
dSigmoid :: Floating a => a -> a
dSigmoid x = (1 - s) * s
  where
    s = sigmoid x
-- | Multiply a layer's weight matrix by the vector fed into that layer.
--
-- No activation function is applied here; 'compute' wraps the result with
-- @cmap sigmoid@.
--
-- NOTE(review): this relies on @<>@ accepting a matrix and a vector. Newer
-- hmatrix releases appear to restrict @<>@ to matrix-matrix products and
-- offer @#>@ for this case -- confirm against the hmatrix version in use.
forward :: Matrix Double -> Vector Double -> Vector Double
forward weightMatrix activation = weightMatrix <> activation
-- | Element-wise difference between the teacher (target) vector and the
-- network's output: @teacher - output@.
errorOut :: Vector Double -> Vector Double -> Vector Double
errorOut = (-)
-- | Propagate an error vector one layer backwards.
--
-- The downstream error @nextErr@ is pulled back through the transpose of the
-- weight matrix @w@ and then scaled element-wise by the sigmoid derivative
-- evaluated at @i@.
errorMid :: Matrix Double -> Vector Double -> Vector Double -> Vector Double
errorMid w nextErr i = pulledBack * slope
  where
    pulledBack = trans w <> nextErr
    slope = cmap dSigmoid i
-- | Back-propagate the output error through the network.
--
-- Arguments: the teacher (target) vector, the weight matrices (first layer
-- first), and the @(net input, activation)@ pairs as produced by 'compute'.
--
-- The result is built by a right scan over @zip ios weights@, so it has one
-- more element than that zipped list. Its last element is the output-layer
-- delta; each earlier element is obtained from its successor via 'errorMid'.
-- An empty @ios@ yields an empty list instead of failing on @last@.
--
-- The output-layer delta is @(teacher - output) * dSigmoid netOut@. Folding
-- the sigmoid derivative in here means every element of the result is a delta
-- with respect to that layer's net input, consistent with what 'errorMid'
-- computes for the earlier layers.
errors :: Vector Double -> [Matrix Double] ->
  [(Vector Double, Vector Double)] -> [Vector Double]
errors _ _ [] = []
errors teacher weights ios = scanr step outputDelta (zip ios weights)
  where
    step ((net, _), w) nextDelta = errorMid w nextDelta net
    -- 'last' is safe: the empty case is handled by the first equation.
    (netOut, actOut) = last ios
    outputDelta = errorOut teacher actOut * cmap dSigmoid netOut

-- | Apply one gradient-descent step to a single weight matrix.
--
-- Arguments, in order:
--
-- * @w@   - the weight matrix, used as @w <> o@ in 'forward'
-- * @err@ - delta of the layer this matrix feeds into (length = rows of @w@)
-- * @_i@  - net input of the layer feeding this matrix; unused, retained so
--           the signature and existing call sites stay the same
-- * @o@   - activation fed into this matrix (length = columns of @w@)
--
-- The correction is @err `outer` o@, which has the same shape as @w@. It is
-- added rather than subtracted because 'errorOut' defines the error as
-- @teacher - output@, so the deltas already point in the direction that
-- reduces the squared error. The sigmoid derivative is not applied here; it
-- is already part of every delta produced by 'errors'.
updateWeight :: Matrix Double -> Vector Double -> Vector Double -> Vector Double -> Matrix Double
updateWeight w err _i o = w + learnCnst * (err `outer` o)
-- | Run a forward pass and record every layer's state.
--
-- Returns one @(net input, activation)@ pair per layer, input layer first:
--
-- * with no weight matrices the result is @[(input, input)]@;
-- * otherwise the first pair is @(input, cmap sigmoid input)@ and each
--   following pair is obtained by applying the next weight matrix to the
--   previous activation ('forward') and then the sigmoid.
--
-- The result therefore has one more element than @weights@ and is never
-- empty.
compute :: [Matrix Double] -> Vector Double -> [(Vector Double, Vector Double)]
compute [] input = [(input, input)]
compute weights input = scanl layer (input, cmap sigmoid input) weights
  where
    -- Only the previous activation feeds the next layer; compute the net
    -- input once and reuse it for both components of the pair.
    layer (_, prevOut) w = (net, cmap sigmoid net)
      where
        net = forward w prevOut
-- | Perform one back-propagation step for a single training example.
--
-- Runs a forward pass with 'compute', derives the per-layer error vectors
-- with 'errors', and returns the weight matrices after 'updateWeight' has
-- been applied to each of them. The first error vector is dropped so that
-- every weight matrix is paired with the error of the layer it feeds into.
backProp :: Vector Double -> Vector Double -> [Matrix Double] -> [Matrix Double]
backProp input teacher weights = zipWith3 adjust weights layerStates downstreamErrs
  where
    layerStates = compute weights input
    downstreamErrs = drop 1 (errors teacher weights layerStates)
    adjust w (net, act) e = updateWeight w e net act
-- | Half of the squared Euclidean distance between the teacher vector and
-- the network output.
sqError :: Vector Double -> Vector Double -> Double
sqError teacher out = 0.5 * (residual <.> residual)
  where
    residual = teacher - out
-- | Train on a single example.
--
-- Returns the squared error of the network's current output for @input@
-- (measured before any update) together with the weights produced by one
-- 'backProp' step.
trainMain :: [Matrix Double] -> Vector Double -> Vector Double ->
  (Double, [Matrix Double])
trainMain weights input teacher = (loss, updated)
  where
    (_, prediction) = last (compute weights input)
    loss = sqError teacher prediction
    updated = backProp input teacher weights
-- | Train on a list of @(input, teacher)@ examples, one after another.
--
-- Arguments: the starting weights, the training examples, and the error
-- accumulated so far (pass @0@ to start a fresh pass).
--
-- Returns the accumulated error plus the sum of the per-example squared
-- errors reported by 'trainMain', together with the weights after the last
-- example. Each example is trained against the weights produced by the
-- previous one.
trainLoop :: [Matrix Double] -> [(Vector Double, Vector Double)] ->
  Double -> (Double, [Matrix Double])
trainLoop weights [] err = (err, weights)
trainLoop weights ((input, teacher) : rest) err =
  -- Force the running total at every step so a long training set does not
  -- build up a chain of unevaluated additions.
  total `seq` trainLoop weights' rest total
  where
    (err', weights') = trainMain weights input teacher
    total = err + err'
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment