Last active
July 30, 2022 15:00
-
-
Save mrkgnao/a45059869590d59f05100f4120595623 to your computer and use it in GitHub Desktop.
A quick Idris implementation of @mstksg's "dependent Haskell" neural networks
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
module Main | |
import Data.Vect | |
-- %hide transpose | |
-- Inner (dot) product of two equal-length vectors.
dot : Num a => Vect n a -> Vect n a -> a
dot xs ys = sum (zipWith (*) xs ys)
-- A matrix represented as a vector of row vectors:
-- `Matrix r c a` has r rows, each a `Vect c a`.
Matrix : (rows : Nat) -> (cols : Nat) -> Type -> Type
Matrix r c a = Vect r (Vect c a)
-- A fully-connected layer from i inputs to o outputs:
-- one bias per output neuron and an o×i weight matrix
-- (row k holds the input weights of output neuron k).
data Layer : Nat -> Nat -> Type -> Type where
  MkLayer : (biases : Vect o a)
         -> (weights : Matrix o i a)
         -> Layer i o a
infixr 5 :>:

-- A feed-forward network indexed by input width, the list of hidden-layer
-- widths, and output width — the types force adjacent layers to agree.
data Network : Nat -> List Nat -> Nat -> Type -> Type where
  -- A single final layer; no hidden layers remain.
  Output : Layer i o a
        -> Network i [] o a
  -- Cons a layer of width h onto a network expecting h inputs.
  (:>:) : Layer i h a
       -> Network h hs o a
       -> Network i (h :: hs) o a
infixl 9 .*

-- Matrix-vector product: each row of the matrix dotted with the vector.
(.*) : Num a => Matrix m n a -> Vect n a -> Vect m a
mat .* vec = map (\row => dot vec row) mat
-- Pointwise arithmetic on vectors; `fromInteger` yields a constant vector.
-- NOTE(review): this relies on Vect's Applicative being zip-like, so
-- liftA2 here is pointwise — confirm against the Idris 1 prelude.
Num a => Num (Vect n a) where
  (+) = liftA2 (+)
  (*) = liftA2 (*)
  fromInteger {n} = replicate n . fromInteger
-- Values that support scalar multiplication by a Double.
interface Scalable a where
  scale : Double -> a -> a
-- Base case: scaling a Double is ordinary multiplication.
Scalable Double where
  scale = (*)
-- Scale a vector elementwise; works for matrices too, since a
-- Matrix is a Vect of Vects and this instance nests.
Scalable a => Scalable (Vect n a) where
  scale lambda = map (scale lambda)
infixl 9 #*#

-- Matrix multiplication: every row of the left operand is dotted with
-- every column of the right operand (columns obtained via transpose).
(#*#) : Num a => Matrix i j a -> Matrix j k a -> Matrix i k a
a #*# b = let cols = transpose b
          in map (\row => map (dot row) cols) a
-- Pointwise subtraction and negation on vectors.
-- NOTE(review): `abs` being a member of Neg is Idris-1-era; on newer
-- Idris versions abs lives in a separate interface — confirm target version.
Neg a => Neg (Vect n a) where
  (-) = liftA2 (-)
  negate = map negate
  abs = map abs
-- Logistic sigmoid on a scalar: 1 / (1 + e^(-x)).
sigmoidD : Double -> Double
sigmoidD x = let ez = exp (negate x)
             in 1 / (1 + ez)
-- Derivative of the logistic sigmoid, using s' = s * (1 - s).
sigmoidD' : Double -> Double
sigmoidD' x = s * (1 - s)
  where s : Double
        s = sigmoidD x
-- Apply the sigmoid to every component of a vector.
sigmoid : Vect n Double -> Vect n Double
sigmoid xs = map sigmoidD xs
-- Apply the sigmoid derivative to every component of a vector.
sigmoid' : Vect n Double -> Vect n Double
sigmoid' xs = map sigmoidD' xs
-- Pre-activation of a layer: weights · input + bias (no nonlinearity).
calc : Vect i Double
    -> Vect o Double
    -> Matrix o i Double
    -> Vect o Double
calc x b w = let weighted = w .* x
             in weighted + b
-- Run one layer WITHOUT the activation function (the raw weighted sum).
runLayer : Vect i Double
        -> Layer i o Double
        -> Vect o Double
runLayer x layer = case layer of
                        MkLayer b w => calc x b w
-- Run one layer and apply the sigmoid activation to its output.
runLayerS : Vect i Double
         -> Layer i o Double
         -> Vect o Double
runLayerS x layer = sigmoid (runLayer x layer)
-- Evaluate the whole network, threading activations layer by layer.
feedForward : Vect i Double
           -> Network i hs o Double
           -> Vect o Double
feedForward x (Output layer)  = runLayerS x layer
feedForward x (layer :>: net) = feedForward (runLayerS x layer) net
-- Outer product: an m-vector and an n-vector give an m×n matrix,
-- computed as the product of a column matrix and a row matrix.
outer : Num a
     => Vect m a
     -> Vect n a
     -> Matrix m n a
outer u v = transpose [u] #*# [v]
-- Componentwise error of the network's prediction: target − prediction.
predictionError : Vect i Double
             -> Vect o Double
             -> Network i hs o Double
             -> Vect o Double
predictionError x target net = target - feedForward x net
-- One step of backpropagation / gradient descent.
--
-- eta    : learning rate
-- input  : the training input vector
-- target : the desired network output for that input
-- Returns the network with every layer's biases and weights nudged one
-- gradient step toward producing `target`.
backprop : Double
        -> Vect i Double
        -> Vect o Double
        -> Network i hs o Double
        -> Network i hs o Double
backprop eta input target net = fst (go input target net)
  where
    -- `go` returns the updated sub-network paired with the error signal
    -- to propagate back to the layer feeding into it.
    go : Vect i Double
      -> Vect o Double
      -> Network i hs o Double
      -> (Network i hs o Double, Vect i Double)
    go input target (layer@(MkLayer bias weights) :>: rest) =
      let y = runLayer input layer              -- pre-activation
          output = sigmoid y                    -- this layer's activation
          (rest', dWs') = go output target rest -- recurse: updated tail + its error
          dEdy = (sigmoid' y) * dWs'            -- chain rule through the sigmoid
          -- gradient-descent update of this layer's parameters
          bias' = bias - (eta `scale` dEdy)
          weights' = weights - (eta `scale` (outer dEdy input))
          layer' = (MkLayer bias' weights')
          dWs = (transpose weights) .* dEdy     -- error signal for the previous layer
      in (layer' :>: rest', dWs)
    go input target (Output layer@(MkLayer bias weights)) =
      let y = runLayer input layer
          output = sigmoid y
          -- output-layer delta: sigmoid' · (prediction − target), i.e. the
          -- gradient of the squared error at the network output
          dEdy = (sigmoid' y) * (output - target)
          bias' = bias - (eta `scale` dEdy)
          weights' = weights - (eta `scale` (outer dEdy input))
          layer' = (MkLayer bias' weights')
          dWs = (transpose weights) .* dEdy
      in (Output layer', dWs)
-- A fixed 2 → [2] → 2 network with hand-picked starting weights.
-- NOTE(review): these constants presumably come from a standard worked
-- backprop example — confirm the source before changing them.
initialNet : Network 2 [2] 2 Double
initialNet = first :>: second
  where first = MkLayer [0.35, 0.35]
                        [ [0.15, 0.20]
                        , [0.25, 0.30]
                        ]
        second = Output
               $ MkLayer [0.60, 0.60]
                         [ [0.40, 0.45]
                         , [0.50, 0.55]
                         ]
-- Training input vector for the worked example.
input : Vect 2 Double
input = [0.05,0.10]
-- Training target vector (the original author notes this should
-- really be named "target").
output : Vect 2 Double
output = [0.01,0.99]
-- Train for 100 gradient steps and print the prediction error after each.
main : IO ()
main =
  let train   = backprop 0.5 input output
      err     = predictionError input output
      history = take 100 (iterate train initialNet)
  in putStrLn (unlines (map (show . err) history))
{- | |
[-0.7413650695523157, 0.2170715346785375] | |
[-0.7184417622337655, 0.2116230796990295] | |
[-0.693669009336258, 0.2064910449028144] | |
[-0.6672418233176436, 0.2016413051023132] | |
[-0.6394686816095002, 0.1970437287878285] | |
[-0.6107602378741731, 0.1926721758900757] | |
[-0.5815991457088593, 0.1885044535893083] | |
[-0.5524951990947665, 0.1845221406513488] | |
[-0.5239353912308239, 0.1807102331941322] | |
[-0.4963403802849389, 0.1770566276112577] | |
[-0.4700361147284623, 0.1735515109102654] | |
[-0.4452436859544368, 0.1701867509351925] | |
[-0.4220849349263245, 0.1669553649022958] | |
[-0.4005982141168513, 0.1638511093542886] | |
[-0.3807583637710784, 0.1608681981538362] | |
[-0.3624963957286282, 0.1580011304078014] | |
[-0.3457163205239238, 0.1552445998815962] | |
[-0.3303081794851857, 0.152593457729155] | |
[-0.3161573737586014, 0.1500427059155249] | |
[-0.3031508774128172, 0.1475875055225604] | |
[-0.2911810603416536, 0.1452231900792029] | |
[-0.2801477961784186, 0.1429452784460385] | |
[-0.2699594052550016, 0.140749484669611] | |
[-0.2605328461773582, 0.1386317239463343] | |
[-0.2517934502183678, 0.1365881147671741] | |
[-0.2436743990158697, 0.1346149777502912] | |
[-0.2361160771575853, 0.1327088318188883] | |
[-0.2290653827858677, 0.1308663883801339] | |
[-0.222475046421264, 0.1290845440890528] | |
[-0.2163029864420817, 0.1273603726841057] | |
[-0.2105117156600286, 0.1256911162825736] | |
[-0.2050678046903731, 0.1240741764348239] | |
[-0.199941402554362, 0.1225071051611767] | |
[-0.1951058119511725, 0.1209875961337917] | |
[-0.1905371150740613, 0.1195134761175647] | |
[-0.1862138451756125, 0.1180826967465477] | |
[-0.1821166989543435, 0.1166933266839454] | |
[-0.1782282850105121, 0.1153435441924597] | |
[-0.1745329039580395, 0.1140316301261325] | |
[-0.1710163561922812, 0.1127559613435504] | |
[-0.1676657737459223, 0.111515004534339] | |
[-0.1644694730863394, 0.1103073104454442] | |
[-0.1614168261005721, 0.1091315084901476] | |
[-0.1584981468707465, 0.1079863017206288] | |
[-0.1557045921609243, 0.1068704621437445] | |
[-0.1530280738165937, 0.1057828263593242] | |
[-0.1504611815227601, 0.104722291500412] | |
[-0.147997114579025, 0.1036878114553911] | |
[-0.1456296215336469, 0.1026783933526751] | |
[-0.1433529466768198, 0.1016930942895522] | |
[-0.1411617825295137, 0.1007310182877622] | |
[-0.1390512275811833, 0.09979131345942471] | |
[-0.1370167486300981, 0.09887316936797841] | |
[-0.1350541471663034, 0.09797581456982185] | |
[-0.1331595293113357, 0.09709851432334826] | |
[-0.1313292788925082, 0.09624056845301787] | |
[-0.129560033284393, 0.09540130935702407] | |
[-0.1278486616973177, 0.0945801001479607] | |
[-0.12619224563339, 0.09377633291670329] | |
[-0.1245880612656923, 0.09298942711045954] | |
[-0.1230335635266569, 0.09221882801664594] | |
[-0.1215263717179235, 0.09146400534488552] | |
[-0.1200642564767705, 0.09072445190002432] | |
[-0.1186451279540033, 0.08999968233960876] | |
[-0.1172670250753927, 0.08928923200977779] | |
[-0.115928105773743, 0.08859265585399401] | |
[-0.1146266380917467, 0.08790952738946012] | |
[-0.113360992067203, 0.08723943774647369] | |
[-0.112129632322174, 0.08658199476633321] | |
[-0.1109311112864131, 0.08593682215373866] | |
[-0.1097640629930862, 0.08530355867995043] | |
[-0.1086271973915681, 0.08468185743323953] | |
[-0.1075192951280463, 0.0840713851134347] | |
[-0.1064392027499095, 0.08347182136760178] | |
[-0.1053858282945339, 0.08288285816411867] | |
[-0.1043581372271753, 0.0823041992026059] | |
[-0.1033551486963057, 0.0817355593573641] | |
[-0.102375932077947, 0.08117666415213609] | |
[-0.101419603783417, 0.08062724926417586] | |
[-0.1004853243074385, 0.08008706005574517] | |
[-0.09957229549583126, 0.07955585113129793] | |
[-0.09867975801401846, 0.07903338591873366] | |
[-0.09780698899938538, 0.07851943627321689] | |
[-0.09695329988213586, 0.07801378210216214] | |
[-0.09611803436073543, 0.0775162110100851] | |
[-0.09530056651932017, 0.07702651796210702] | |
[-0.09450029907561171, 0.07654450496498333] | |
[-0.09371666174891867, 0.0760699807646058] | |
[-0.09294910973874204, 0.07560276055899551] | |
[-0.09219712230534606, 0.07514266572587369] | |
[-0.09146020144441579, 0.07468952356395442] | |
[-0.09073787064860862, 0.07424316704715994] | |
[-0.09002967374942897, 0.07380343459101679] | |
[-0.08933517383341218, 0.07337016983053202] | |
[-0.08865395222711617, 0.0729432214088982] | |
[-0.08798560754587555, 0.07252244277641773] | |
[-0.08732975480169207, 0.07210769199907441] | |
[-0.08668602456601479, 0.07169883157621615] | |
[-0.08605406218350631, 0.07129572826684816] | |
[-0.08543352703320682, 0.07089825292406515] | |
-} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment