Created December 11, 2015 08:59
-
-
Save erutuf/bc0af70577d9f8e37384 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
module NeuralNetwork where | |
import Data.List | |
import System.Random | |
import Numeric.LinearAlgebra | |
-- | Learning rate (step size) used by 'updateWeight'.
--
-- The explicit polymorphic signature keeps this usable both as a plain
-- 'Double' and as a matrix operand. Without a signature, the monomorphism
-- restriction fixes its type from its single use site in this module.
learnCnst :: Fractional a => a
learnCnst = 0.1
-- | Logistic activation function: @1 / (1 + e^(-x))@.
sigmoid :: Floating a => a -> a
sigmoid = recip . (1 +) . exp . negate
-- | Derivative of 'sigmoid', expressed through the sigmoid itself:
-- @sigmoid' x = (1 - sigmoid x) * sigmoid x@.
dSigmoid :: Floating a => a -> a
dSigmoid x = (1 - s) * s
  where
    s = sigmoid x
-- | Pre-activation of the next layer: the weight matrix applied to the
-- current layer's output vector.
forward :: Matrix Double -> Vector Double -> Vector Double
forward = (<>)
-- | Output-layer error: the target vector minus the actual output,
-- element-wise.
errorOut :: Vector Double -> Vector Double -> Vector Double
errorOut = (-)
-- | Propagates an error signal one layer backwards.
--
-- The downstream error is pushed back through the transposed weight
-- matrix. The result is then scaled element-wise by the sigmoid slope at
-- this layer's pre-activation @i@.
errorMid :: Matrix Double -> Vector Double -> Vector Double -> Vector Double
errorMid w nextErr i = propagated * slope
  where
    propagated = trans w <> nextErr
    slope      = cmap dSigmoid i
-- | Back-propagated error signals, one per layer, ordered from the input
-- layer to the output layer (same order and length as @ios@).
--
-- Each element is the negated gradient of the squared error
-- @E = 0.5 * |teacher - output|^2@ with respect to that layer's
-- pre-activation vector @i@ ("delta"):
--
-- * output layer @L@: @(teacher - o_L) * dSigmoid i_L@
-- * layer @k < L@:   @(trans w_k <> delta_(k+1)) * dSigmoid i_k@
--
-- The output term includes the @dSigmoid i_L@ factor so that it follows the
-- same convention as the hidden-layer terms produced by 'errorMid'.
-- Returns @[]@ when @ios@ is empty.
errors :: Vector Double -> [Matrix Double] ->
          [(Vector Double, Vector Double)] -> [Vector Double]
errors teacher weights ios =
  case reverse ios of
    [] -> []
    (iOut, oOut) : _ ->
      scanr step (errorOut teacher oOut * cmap dSigmoid iOut) (zip ios weights)
  where
    -- Layer k's pre-activation is paired with the matrix leaving layer k.
    step ((i, _), w) downstream = errorMid w downstream i

-- | Gradient-descent update for one weight matrix.
--
-- @err@ is the delta of the layer this matrix feeds into, as produced by
-- 'errors'. @o@ is the output of the layer the matrix reads from. Because
-- @err@ is already the /negated/ gradient, the step is added:
-- @w + learnCnst * outer err o@.
--
-- The update has the same shape as @w@ (rows = size of @err@,
-- cols = size of @o@), so layers of different widths are supported.
--
-- The third argument is unused, because the activation derivative is
-- already folded into @err@ by 'errors'. It is kept so existing callers
-- need no change.
updateWeight :: Matrix Double -> Vector Double -> Vector Double -> Vector Double -> Matrix Double
updateWeight w err _i o = w + learnCnst * (err `outer` o)
-- | Forward pass through the network.
--
-- Returns one @(i, o)@ pair per layer, input layer first. @i@ is the
-- layer's pre-activation and @o = cmap sigmoid i@ is its output. Layer
-- @k+1@ is computed from layer @k@ as @i = forward w_k o_k@, so the result
-- has @length weights + 1@ elements.
--
-- With no weight matrices, the input is passed through unchanged and the
-- single pair is @(input, input)@. This base case was previously misspelled
-- @comptue@, so it never matched and @compute []@ crashed in 'init'.
compute :: [Matrix Double] -> Vector Double -> [(Vector Double, Vector Double)]
compute [] input = [(input, input)]
compute weights input = scanl step (input, cmap sigmoid input) weights
  where
    step (_, o) w =
      let net = forward w o
      in  (net, cmap sigmoid net)
-- | One back-propagation step for a single @(input, teacher)@ sample.
--
-- Runs the forward pass and computes the per-layer error signals. It then
-- returns every weight matrix after its update, in the same order as the
-- input list.
backProp :: Vector Double -> Vector Double -> [Matrix Double] -> [Matrix Double]
backProp input teacher weights = zipWith3 adjust weights layers downstreamErrs
  where
    layers         = compute weights input
    -- Matrix k is paired with the error of layer k+1, so the first
    -- (input-layer) error is dropped.
    downstreamErrs = tail (errors teacher weights layers)
    adjust w (i, o) e = updateWeight w e i o
-- | Half the squared Euclidean distance between target and prediction:
-- @0.5 * |teacher - out|^2@.
sqError :: Vector Double -> Vector Double -> Double
sqError teacher out = 0.5 * (residual <.> residual)
  where
    residual = teacher - out
-- | One training step on a single sample.
--
-- Returns the squared error of the network's prediction under the
-- /current/ weights, paired with the weights after one back-propagation
-- update. The forward pass is run here and again inside 'backProp'.
trainMain :: [Matrix Double] -> Vector Double -> Vector Double ->
             (Double, [Matrix Double])
trainMain weights input teacher =
  let (_, prediction) = last (compute weights input)
  in  (sqError teacher prediction, backProp input teacher weights)
-- | One pass of online training over a list of @(input, teacher)@ samples.
-- The weights are updated after every sample.
--
-- The third argument is the running error carried in. Each sample's squared
-- error (measured with the weights before that sample's update) is added to
-- it. The total is returned together with the final weights.
--
-- Previously the empty-list case returned @0.0@, discarding the accumulated
-- error, so callers always received zero.
trainLoop :: [Matrix Double] -> [(Vector Double, Vector Double)] ->
             Double -> (Double, [Matrix Double])
trainLoop weights samples err0 = foldl' step (err0, weights) samples
  where
    -- The running total is forced at every step, so a long sample list
    -- does not build up a chain of unevaluated (+) thunks.
    step (acc, ws) (input, teacher) =
      let (err', ws') = trainMain ws input teacher
          acc'        = acc + err'
      in  acc' `seq` (acc', ws')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.