Last active
May 13, 2016 23:15
-
-
Save cormojs/65ef4be25f932ba4365f57cb39fc10ec to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{-# LANGUAGE FlexibleContexts #-} | |
{-# LANGUAGE NamedFieldPuns #-} | |
module NeuralNetwork where | |
import Data.Array.Repa as Repa | |
import Data.Array.Repa.Algorithms.Matrix | |
import qualified Data.Array.Repa.Operators.Mapping as Mapping | |
import Debug.Trace (trace) | |
-- | A column vector, stored as an (n x 1) Repa array with representation @r@.
type Vector r = Array r DIM2 Double

-- | A dense, unboxed weight matrix.
type Matrix = Array U DIM2 Double

-- | A feed-forward network: one weight matrix per layer, ordered from the
-- input side to the output side (the order 'outputs' applies them in).
newtype NN = NN
  { thetas :: [Matrix]
  } deriving (Show)
-- | Train the sample network from 'createSampleNN' with 10000 steps of
-- gradient descent (learning rate 0.2). Then print the trained weights,
-- followed by the network's output (bias entry stripped) for each training
-- input, in sample order.
test :: IO ()
test = do
  let (nn, xys) = createSampleNN
      -- 'iterate' yields the training trajectory; element 10000 is the
      -- network after 10000 descent steps.
      trained = iterate (\net -> gradDescent 0.2 net xys) nn !! 10000
  print trained
  -- Walk the samples directly instead of indexing with (!!), so every
  -- sample is reported and no index can be out of range.
  mapM_ (\(x, _) -> print $ removeBias $ last $ outputs trained x) xys
-- | A 2-2-1 network with fixed starting weights, together with the four XOR
-- training samples as (input, target) column-vector pairs.
--
-- The first weight matrix is 2x3 and the second 1x3; column 0 of each one
-- multiplies the bias unit that 'addBias' prepends. The author notes that
-- training may fail to converge for some choices of starting weights.
createSampleNN :: (NN, [(Vector U, Vector U)])
createSampleNN = (NN initialWeights, samples)
  where
    initialWeights :: [Matrix]
    initialWeights =
      [ fromListUnboxed (Z :. 2 :. 3) [0.4, -0.6, 0.5, -0.2, 0.3, 0.4]
      , fromListUnboxed (Z :. 1 :. 3) [0.6, -0.5, 0.2]
      ]

    samples :: [(Vector U, Vector U)]
    samples = Prelude.zipWith toSample inputs targets

    -- Inputs are 2x1 column vectors, targets are 1x1.
    toSample :: [Double] -> [Double] -> (Vector U, Vector U)
    toSample xs ts =
      (fromListUnboxed (Z :. 2 :. 1) xs, fromListUnboxed (Z :. 1 :. 1) ts)

    inputs, targets :: [[Double]]
    inputs  = [[1.0, 0.0], [0.0, 1.0], [0.0, 0.0], [1.0, 1.0]]
    targets = [[1.0],      [1.0],      [0.0],      [0.0]]
-- | One step of full-batch gradient descent over every training sample,
-- without regularisation.
--
-- @gradDescent alpha nn xys@ moves each weight matrix by @-alpha@ times the
-- backpropagated gradient averaged over @xys@. An empty sample list leaves
-- @nn@ unchanged.
gradDescent :: Double -> NN -> [(Vector U, Vector U)] -> NN
gradDescent alpha = gradDescentReg alpha 0

-- | Like 'gradDescent', but adds an L2 penalty of strength @lambda@: the
-- term @(lambda / m) * theta@ is added to the gradient of every non-bias
-- weight, where @m@ is the number of samples.
--
-- Column 0 of each weight matrix multiplies the bias unit that 'addBias'
-- puts in row 0 of the activation, so column 0 is not regularised.
-- With @lambda == 0@ this is exactly 'gradDescent'.
gradDescentReg :: Double -> Double -> NN -> [(Vector U, Vector U)] -> NN
gradDescentReg _ _ nn [] = nn
gradDescentReg alpha lambda nn@(NN { thetas }) xys =
  NN $ Prelude.zipWith applyStep thetas regularised
  where
    m :: Double
    m = fromIntegral (length xys)

    -- theta - alpha * grad
    applyStep :: Matrix -> Matrix -> Matrix
    applyStep theta grad =
      computeUnboxedS $ theta +^ Mapping.map (\v -> -alpha * v) grad

    -- Skip the extra pass when there is no penalty; this keeps
    -- 'gradDescent' bit-for-bit identical to the unregularised update.
    regularised :: [Matrix]
    regularised
      | lambda == 0 = averaged
      | otherwise   = Prelude.zipWith regularise averaged thetas

    regularise :: Matrix -> Matrix -> Matrix
    regularise d theta =
      computeUnboxedS $ fromFunction (extent d) (entry d theta)

    entry :: Matrix -> Matrix -> DIM2 -> Double
    entry d _ sh@(Z :. _ :. 0) = d ! sh   -- bias column: no penalty
    entry d theta sh = (d ! sh) + (lambda / m) * (theta ! sh)

    -- Per-layer gradient averaged over the samples.
    averaged :: [Matrix]
    averaged = Prelude.map (computeUnboxedS . Mapping.map (/ m)) summed

    -- Right-nested sum d0 + (d1 + (...)); xys is non-empty here.
    summed :: [Matrix]
    summed = foldr1 (Prelude.zipWith addMat) (Prelude.map derivs xys)

    addMat :: Matrix -> Matrix -> Matrix
    addMat a b = computeUnboxedS (a +^ b)

    -- Per-sample gradient for each layer: delta_{l+1} * a_l^T. 'tail'
    -- drops the delta belonging to the input layer.
    derivs :: (Vector U, Vector U) -> [Matrix]
    derivs (x, y) =
      let outs = outputs nn x
      in Prelude.zipWith mmultS (tail $ deltas nn outs y) (Prelude.map transpose2S outs)
-- | Backpropagated error terms for one sample.
--
-- @outs@ are the per-layer activations from 'outputs' (each with a bias
-- entry in row 0) and @y@ is the target. The result has one delta per
-- activation:
--
-- * the last is the output activation (bias stripped) minus @y@;
-- * each earlier one is @theta^T * delta@ with the bias row removed, scaled
--   element-wise by @a * (1 - a)@ of that layer's bias-stripped activation.
--
-- The first element belongs to the input layer; 'gradDescent' discards it.
deltas :: NN -> [Vector U] -> Vector U -> [Vector U]
deltas (NN { thetas }) outs y = scanr backprop outputDelta (zip outs thetas)
  where
    outputDelta = computeUnboxedS $ removeBias (last outs) -^ y

    backprop :: (Vector U, Matrix) -> Vector U -> Vector U
    backprop (activation, theta) next = computeUnboxedS $ propagated *^ slope
      where
        propagated = removeBias (transpose2S theta `mmultS` next)
        -- sigmoid derivative expressed through the activation itself
        slope = Mapping.map (\a -> a * (1 - a)) (removeBias activation)
-- | Forward pass: the activation of every layer, input first.
--
-- Each element, including the final output, has a bias entry prepended by
-- 'addBias'; callers use 'removeBias' to recover the raw values.
outputs :: NN -> Vector U -> [Vector U]
outputs (NN { thetas }) x = scanl forward (addBias x) thetas
  where
    forward :: Vector U -> Matrix -> Vector U
    forward activation theta =
      addBias (Mapping.map sigmoid (theta `mmultS` activation))
-- | Prepend a constant 1.0 bias entry as a new row 0 of a column vector.
addBias :: Source r Double => Vector r -> Vector U
addBias x = computeUnboxedS (transpose withBias)
  where
    -- 'append' joins along the innermost dimension, so work on the
    -- transposed (1 x n) row to put the bias in front, then transpose back.
    withBias = Repa.append biasUnit (transpose x)

    biasUnit :: Vector U
    biasUnit = fromListUnboxed (Z :. 1 :. 1) [1.0]
-- | Drop row 0, which holds the bias entry prepended by 'addBias'.
--
-- For the usual (n x 1) column vector the result is ((n-1) x 1). Inputs
-- with more columns are also accepted (every column loses its first row);
-- the previous version failed with a pattern-match error for them.
-- Assumes the input has at least one row.
removeBias :: Vector U -> Vector U
removeBias vec =
  computeUnboxedS $ extract (Z :. 1 :. 0) (Z :. (rows - 1) :. cols) vec
  where
    Z :. rows :. cols = extent vec
-- | The logistic function @1 / (1 + e^(-x))@, written as a pipeline.
sigmoid :: Double -> Double
sigmoid = recip . (1 +) . exp . negate
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
できたっぽい。初期値によっては収束しなかったりするけど。
$ test
NN {thetas = [AUnboxed ((Z :. 2) :. 3) [3.4749054720321246,-6.9192284553440295,7.133189993063397,-3.2603535736131346,-6.388039812895748,6.127063628036342],AUnboxed ((Z :. 1) :. 3) [5.174441008588869,-11.069352096283652,11.726324991994492]]}
AUnboxed ((Z :. 1) :. 1) [0.9920981666738042]
AUnboxed ((Z :. 1) :. 1) [0.9945165938256663]
AUnboxed ((Z :. 1) :. 1) [5.8871391416373126e-3]
AUnboxed ((Z :. 1) :. 1) [5.025238147754469e-3]