Created
January 10, 2014 13:41
-
-
Save lubomir/5d84c6da5b53c7d95cde to your computer and use it in GitHub Desktop.
Part 3 of the final assignment for the course "Complex-Valued Neural Networks with Multi-Valued Neurons".
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
module Iris where | |
import Data.Vector (Vector) | |
import Control.Arrow | |
import qualified Data.Vector as V | |
import Data.Complex | |
-- |Training set: 90 samples, 30 for each of the class labels 0, 1 and 2.
-- Every sample has four real values that are placed on the unit circle with
-- 'cis' (so they are treated as phases in radians).
-- NOTE(review): the values lie in [0, 1.396] (~4π/9); presumably these are the
-- four Iris measurements rescaled onto that arc - TODO confirm.
learning :: Vector (Vector (Complex Double), Int)
learning = V.fromList [ (V.fromList (map cis phases), label) | (phases, label) <- rows ]
  where
    rows :: [([Double], Int)]
    rows =
        [ ([0.31, 0.873, 0.095, 0.058], 0)
        , ([0.233, 0.582, 0.095, 0.058], 0)
        , ([0.155, 0.698, 0.071, 0.058], 0)
        , ([0.116, 0.64, 0.118, 0.058], 0)
        , ([0.271, 0.931, 0.095, 0.058], 0)
        , ([0.427, 1.105, 0.166, 0.175], 0)
        , ([0.116, 0.814, 0.095, 0.116], 0)
        , ([0.271, 0.814, 0.118, 0.058], 0)
        , ([0.039, 0.524, 0.095, 0.058], 0)
        , ([0.233, 0.64, 0.118, 0], 0)
        , ([0.427, 0.989, 0.118, 0.058], 0)
        , ([0.194, 0.814, 0.142, 0.058], 0)
        , ([0.194, 0.582, 0.095, 0], 0)
        , ([0, 0.582, 0.024, 0], 0)
        , ([0.582, 1.164, 0.047, 0.058], 0)
        , ([0.543, 1.396, 0.118, 0.175], 0)
        , ([0.427, 1.105, 0.071, 0.175], 0)
        , ([0.31, 0.873, 0.095, 0.116], 0)
        , ([0.543, 1.047, 0.166, 0.116], 0)
        , ([0.31, 1.047, 0.118, 0.116], 0)
        , ([0.427, 0.814, 0.166, 0.058], 0)
        , ([0.31, 0.989, 0.118, 0.175], 0)
        , ([0.116, 0.931, 0, 0.058], 0)
        , ([0.31, 0.756, 0.166, 0.233], 0)
        , ([0.194, 0.814, 0.213, 0.058], 0)
        , ([0.271, 0.582, 0.142, 0.058], 0)
        , ([0.271, 0.814, 0.142, 0.175], 0)
        , ([0.349, 0.873, 0.118, 0.058], 0)
        , ([0.349, 0.814, 0.095, 0.058], 0)
        , ([0.155, 0.698, 0.142, 0.058], 0)
        , ([0.621, 0.698, 0.899, 0.989], 1)
        , ([0.698, 0.465, 0.71, 0.698], 1)
        , ([0.776, 0.291, 0.923, 0.814], 1)
        , ([0.698, 0.465, 0.876, 0.64], 1)
        , ([0.814, 0.524, 0.781, 0.698], 1)
        , ([0.892, 0.582, 0.805, 0.756], 1)
        , ([0.97, 0.465, 0.899, 0.756], 1)
        , ([0.931, 0.582, 0.947, 0.931], 1)
        , ([0.659, 0.524, 0.828, 0.814], 1)
        , ([0.543, 0.349, 0.592, 0.524], 1)
        , ([0.465, 0.233, 0.663, 0.582], 1)
        , ([0.465, 0.233, 0.639, 0.524], 1)
        , ([0.582, 0.407, 0.686, 0.64], 1)
        , ([0.659, 0.407, 0.97, 0.873], 1)
        , ([0.427, 0.582, 0.828, 0.814], 1)
        , ([0.659, 0.814, 0.828, 0.873], 1)
        , ([0.931, 0.64, 0.876, 0.814], 1)
        , ([0.776, 0.175, 0.805, 0.698], 1)
        , ([0.504, 0.582, 0.734, 0.698], 1)
        , ([0.465, 0.291, 0.71, 0.698], 1)
        , ([0.465, 0.349, 0.805, 0.64], 1)
        , ([0.698, 0.582, 0.852, 0.756], 1)
        , ([0.582, 0.349, 0.71, 0.64], 1)
        , ([0.271, 0.175, 0.544, 0.524], 1)
        , ([0.504, 0.407, 0.757, 0.698], 1)
        , ([0.543, 0.582, 0.757, 0.64], 1)
        , ([0.543, 0.524, 0.757, 0.698], 1)
        , ([0.737, 0.524, 0.781, 0.698], 1)
        , ([0.31, 0.291, 0.473, 0.582], 1)
        , ([0.543, 0.465, 0.734, 0.698], 1)
        , ([0.776, 0.756, 1.183, 1.396], 2)
        , ([0.582, 0.407, 0.97, 1.047], 2)
        , ([1.086, 0.582, 1.16, 1.164], 2)
        , ([0.776, 0.524, 1.089, 0.989], 2)
        , ([0.853, 0.582, 1.136, 1.222], 2)
        , ([1.28, 0.582, 1.325, 1.164], 2)
        , ([0.233, 0.291, 0.828, 0.931], 2)
        , ([1.164, 0.524, 1.254, 0.989], 2)
        , ([0.931, 0.291, 1.136, 0.989], 2)
        , ([1.125, 0.931, 1.207, 1.396], 2)
        , ([0.853, 0.698, 0.97, 1.105], 2)
        , ([0.814, 0.407, 1.018, 1.047], 2)
        , ([0.97, 0.582, 1.065, 1.164], 2)
        , ([0.543, 0.291, 0.947, 1.105], 2)
        , ([0.582, 0.465, 0.97, 1.338], 2)
        , ([0.814, 0.698, 1.018, 1.28], 2)
        , ([0.853, 0.582, 1.065, 0.989], 2)
        , ([1.319, 1.047, 1.349, 1.222], 2)
        , ([1.319, 0.349, 1.396, 1.28], 2)
        , ([0.659, 0.116, 0.947, 0.814], 2)
        , ([1.008, 0.698, 1.112, 1.28], 2)
        , ([0.504, 0.465, 0.923, 1.105], 2)
        , ([1.319, 0.465, 1.349, 1.105], 2)
        , ([0.776, 0.407, 0.923, 0.989], 2)
        , ([0.931, 0.756, 1.112, 1.164], 2)
        , ([1.125, 0.698, 1.183, 0.989], 2)
        , ([0.737, 0.465, 0.899, 0.989], 2)
        , ([0.698, 0.582, 0.923, 0.989], 2)
        , ([0.814, 0.465, 1.089, 1.164], 2)
        , ([1.125, 0.582, 1.136, 0.873], 2)
        ]
-- |Test set: 60 samples, 20 for each of the class labels 0, 1 and 2.
-- It is encoded the same way as the training data: each of the four values
-- becomes a unit-circle point via 'cis'.
testing :: Vector (Vector (Complex Double), Int)
testing = V.fromList [ (V.fromList (map cis phases), label) | (phases, label) <- rows ]
  where
    rows :: [([Double], Int)]
    rows =
        [ ([0.194, 0.64, 0.142, 0.058], 0)
        , ([0.427, 0.814, 0.118, 0.175], 0)
        , ([0.349, 1.222, 0.118, 0], 0)
        , ([0.465, 1.28, 0.095, 0.058], 0)
        , ([0.233, 0.64, 0.118, 0.058], 0)
        , ([0.271, 0.698, 0.047, 0.058], 0)
        , ([0.465, 0.873, 0.071, 0.058], 0)
        , ([0.233, 0.931, 0.095, 0], 0)
        , ([0.039, 0.582, 0.071, 0.058], 0)
        , ([0.31, 0.814, 0.118, 0.058], 0)
        , ([0.271, 0.873, 0.071, 0.116], 0)
        , ([0.078, 0.175, 0.071, 0.116], 0)
        , ([0.039, 0.698, 0.071, 0.058], 0)
        , ([0.271, 0.873, 0.142, 0.291], 0)
        , ([0.31, 1.047, 0.213, 0.175], 0)
        , ([0.194, 0.582, 0.095, 0.116], 0)
        , ([0.31, 1.047, 0.142, 0.058], 0)
        , ([0.116, 0.698, 0.095, 0.058], 0)
        , ([0.388, 0.989, 0.118, 0.058], 0)
        , ([0.271, 0.756, 0.095, 0.058], 0)
        , ([1.047, 0.698, 0.876, 0.756], 1)
        , ([0.814, 0.698, 0.828, 0.814], 1)
        , ([1.008, 0.64, 0.923, 0.814], 1)
        , ([0.465, 0.175, 0.71, 0.698], 1)
        , ([0.853, 0.465, 0.852, 0.814], 1)
        , ([0.543, 0.465, 0.828, 0.698], 1)
        , ([0.776, 0.756, 0.876, 0.873], 1)
        , ([0.233, 0.233, 0.544, 0.524], 1)
        , ([0.892, 0.524, 0.852, 0.698], 1)
        , ([0.349, 0.407, 0.686, 0.756], 1)
        , ([0.271, 0, 0.592, 0.524], 1)
        , ([0.621, 0.582, 0.757, 0.814], 1)
        , ([0.659, 0.116, 0.71, 0.524], 1)
        , ([0.698, 0.524, 0.876, 0.756], 1)
        , ([0.504, 0.524, 0.615, 0.698], 1)
        , ([0.931, 0.64, 0.805, 0.756], 1)
        , ([0.504, 0.582, 0.828, 0.814], 1)
        , ([0.582, 0.407, 0.734, 0.524], 1)
        , ([0.737, 0.116, 0.828, 0.814], 1)
        , ([0.504, 0.291, 0.686, 0.582], 1)
        , ([1.202, 0.465, 1.207, 1.047], 2)
        , ([1.396, 1.047, 1.278, 1.105], 2)
        , ([0.814, 0.465, 1.089, 1.222], 2)
        , ([0.776, 0.465, 0.97, 0.814], 2)
        , ([0.698, 0.349, 1.089, 0.756], 2)
        , ([1.319, 0.582, 1.207, 1.28], 2)
        , ([0.776, 0.814, 1.089, 1.338], 2)
        , ([0.814, 0.64, 1.065, 0.989], 2)
        , ([0.659, 0.582, 0.899, 0.989], 2)
        , ([1.008, 0.64, 1.041, 1.164], 2)
        , ([0.931, 0.64, 1.089, 1.338], 2)
        , ([1.008, 0.64, 0.97, 1.28], 2)
        , ([0.582, 0.407, 0.97, 1.047], 2)
        , ([0.97, 0.698, 1.16, 1.28], 2)
        , ([0.931, 0.756, 1.112, 1.396], 2)
        , ([0.931, 0.582, 0.994, 1.28], 2)
        , ([0.776, 0.291, 0.947, 1.047], 2)
        , ([0.853, 0.582, 0.994, 1.105], 2)
        , ([0.737, 0.814, 1.041, 1.28], 2)
        , ([0.621, 0.582, 0.97, 0.989], 2)
        ]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import Control.Arrow | |
import Control.Monad (zipWithM_) | |
import Data.Complex | |
import Data.List (minimumBy) | |
import Data.Maybe (isNothing) | |
import Data.Vector (Vector) | |
import qualified Data.Vector as V | |
import System.Random | |
import Text.Printf | |
import Iris | |
-- |How many sectors the complex plane is cut into for classification.
type NumSectors = Int

-- |Zero-based index of a sector.
type Sector = Int

-- |One example: the input vector paired with the output it should produce.
type LearningSample w r = (Vector w, r)

-- |Settings of the periodic multi-valued neuron.
data Options = Options
    { numSectors   :: Int
      -- ^ number of classes; sub-sector indices are reduced modulo this
    , periodicity  :: Int
      -- ^ how many times each class repeats around the circle
      -- (the circle holds @numSectors * periodicity@ sub-sectors)
    , learningRate :: Complex Double
      -- ^ NOTE(review): nothing in this module reads this field;
      -- 'updateWeights' uses a fixed rate of 1.
    }

-- |A periodic multi-valued neuron, stored as its weight vector. The first
-- weight multiplies the constant bias input 1.
data MVNP w = MVNP (Vector w)
    deriving (Show)
-- |Orphan instance. The real part is drawn first and the imaginary part
-- second, each independently, threading the generator through both draws.
-- For 'randomR' each component is drawn from the matching components of the
-- two bounds.
instance (Random a, RealFloat a) => Random (Complex a) where
    randomR ~(lo, hi) gen0 =
        let (re, gen1) = randomR (realPart lo, realPart hi) gen0
            (im, gen2) = randomR (imagPart lo, imagPart hi) gen1
        in  (re :+ im, gen2)
    random gen0 =
        let (re, gen1) = random gen0
            (im, gen2) = random gen1
        in  (re :+ im, gen2)
-- |Classify a point of the complex plane with the periodic activation.
--
-- The circle is split into @numSectors * periodicity@ equal sub-sectors. The
-- index of the sub-sector holding the point is reduced modulo @numSectors@,
-- so sub-sectors @k@, @k + numSectors@, ... all report class @k@.
periodicActivation :: (RealFloat a) => Options -> Complex a -> Sector
periodicActivation opts s = subSector `mod` classes
  where
    classes   = numSectors opts
    subSector = basicActivation (classes * periodicity opts) s
-- |Find the sector, numbered from 0 to @n - 1@, that contains a point of the
-- complex plane. The circle is split into @n@ equal sectors, starting at
-- phase 0 and going counter-clockwise.
--
-- The index is @floor (phase / sectorWidth)@. When the phase is at or just
-- below @2π@, rounding can push that ratio up to exactly @n@, which is not a
-- valid sector. The result is therefore capped at @n - 1@, the sector the
-- point actually lies in.
basicActivation :: (RealFloat a)
                => NumSectors -- ^Number of sectors
                -> Complex a  -- ^Argument
                -> Sector     -- ^Index of result sector
basicActivation n s = min (n - 1) (floor (posPhase s / sectBound n 1))
-- |Feed one input vector through the neuron and return the class it is
-- assigned by the periodic activation. The bias input is added by
-- 'getWeightedSum', so @input@ should not include it.
runNeuron :: (RealFloat w)
          => Options
          -> MVNP (Complex w)
          -> Vector (Complex w)
          -> Sector
runNeuron opts neuron input = periodicActivation opts weighted
  where
    weighted = getWeightedSum neuron input
-- |Phase of a complex number, normalised to the half-open range @[0, 2π)@.
--
-- 'phase' returns values in @(-π, π]@, and negative results are shifted up by
-- @2π@. For a tiny negative phase that sum rounds to exactly @2π@, which
-- points in the same direction as 0. It is folded back to 0 so the result
-- always stays strictly below @2π@. A NaN phase is returned unchanged.
posPhase :: RealFloat a => Complex a -> a
posPhase x
    | not (phi < 0)  = phi      -- already in [0, π] (or NaN)
    | shifted < full = shifted
    | otherwise      = 0        -- phi + 2π rounded up to 2π
  where
    phi     = phase x
    full    = 2 * pi
    shifted = phi + full
-- |Choose the target sub-sector of class @r@ that lies closest to the point
-- @x@, and return the phase of that sub-sector's lower bound.
--
-- With @s@ sectors and periodicity @l@ the circle holds @s*l@ sub-sectors, and
-- sub-sector @m@ belongs to class @m `mod` s@. The candidates for class @r@
-- are therefore @r, r+s, ..., r+(l-1)*s@. Each candidate is scored by the
-- angular distance from @x@ to the nearer of its two bounds.
--
-- Distances are measured around the circle and always lie in @[0, π]@. For
-- example, a bound at @7π/4@ is correctly treated as close to a point with
-- phase @-π/4@.
--
-- @r@ must be a valid class in @[0, s-1]@. Otherwise there are no candidates
-- and 'minimumBy' fails.
closestSector :: (RealFloat a, Ord a) => Options -> Sector -> Complex a -> a
closestSector Options{numSectors = s, periodicity = l} r x =
    sectBound total (minimumBy closer candidates)
  where
    total      = s * l
    candidates = [r, r + s .. total - 1]
    angle      = posPhase x
    -- Both arguments lie within [0, 2π], so |p - q| <= 2π, and going round
    -- the other side of the circle costs 2π minus that.
    arc p q    = let d = abs (p - q) in min d (2 * pi - d)
    distance m = min (arc (sectBound total m) angle)
                     (arc (sectBound total ((m + 1) `mod` total)) angle)
    closer m k = compare (distance m) (distance k)
-- |Reduce @n@ by subtracting @m@ repeatedly until it is no greater than @m@.
-- A value already @<= m@ is returned unchanged. When @n > m > 0@ the result
-- lies in @(0, m]@.
--
-- Note that this is not an angle normalisation: for phases the period is
-- @2π@, not @m@.
--
-- A non-positive step can never bring @n@ down to @m@. The original recursion
-- would loop forever in that case, so it is now reported as an error.
clamp :: (Num a, Ord a) => a -> a -> a
clamp m n
    | n <= m    = n
    | m <= 0    = error "clamp: step must be positive"
    | otherwise = clamp m (n - m)
-- |Update the weights of a neuron for one input sample, using learning rate 1.
-- The error is passed in as an argument and is not checked for correctness.
updateWeights :: (RealFloat a)
              => MVNP (Complex a)
              -> Vector (Complex a) -- ^Input sample
              -> Complex a          -- ^Error on this sample
              -> MVNP (Complex a)   -- ^Updated neuron
updateWeights = updateWeightsWith 1

-- |Same as 'updateWeights', but with an explicit learning rate.
--
-- Each weight @w_j@, including the bias weight whose input is the constant 1,
-- moves by @rate * err * conjugate x_j / n@, where @n@ is the number of
-- weights (inputs + 1). The correction is therefore shared evenly across all
-- weights. If the weight and input lengths disagree, 'V.zipWith' drops the
-- surplus.
updateWeightsWith :: (RealFloat a)
                  => Complex a          -- ^Learning rate
                  -> MVNP (Complex a)
                  -> Vector (Complex a) -- ^Input sample
                  -> Complex a          -- ^Error on this sample
                  -> MVNP (Complex a)   -- ^Updated neuron
updateWeightsWith rate (MVNP ws) i e = MVNP $ V.zipWith update ws (V.cons 1 i)
  where
    n = fromIntegral $ V.length i + 1
    update w x = w + rate * e * conjugate x / n
-- |Phase of the lower bound of sector @x@ when the circle is split into @n@
-- equal sectors, i.e. @2πx / n@.
sectBound :: (Floating a) => NumSectors -> Sector -> a
sectBound n x = fromIntegral x * 2 * pi / fromIntegral n
-- |The weighted sum the neuron computes for one input vector.
--
-- A constant 1 is prepended to the input so that the first weight acts as the
-- bias. The input should therefore have one element fewer than the neuron has
-- weights; any surplus on either side is dropped by the zip.
getWeightedSum :: (Num w)
               => MVNP w   -- ^Actual neuron
               -> Vector w -- ^Input without the leading 1
               -> w        -- ^Result
getWeightedSum (MVNP weights) input = V.sum products
  where
    products = V.zipWith (*) weights (V.cons 1 input)
-- |Print every weight on its own line as @w_<i> = <re> + <im> i@, with the
-- index starting at 0.
printWeights :: Vector (Complex Double) -> IO ()
printWeights vs = mapM_ (uncurry line) (zip [0 ..] (V.toList vs))
  where
    line :: Int -> Complex Double -> IO ()
    line idx (re :+ im) = printf "w_%d = %f + %f i\n" idx re im
-- |Check one sample against the neuron.
--
-- Returns 'Nothing' when the neuron outputs the expected class. Otherwise it
-- returns the learning error @target - current@, where:
--
-- * @target@ is the unit vector at the lower bound of the nearest sub-sector
--   of the expected class (see 'closestSector');
-- * @current@ is the unit vector at the lower bound of the sub-sector that
--   the weighted sum actually falls in.
testSample :: (RealFloat a, Show a)
           => Options
           -> MVNP (Complex a)
           -> LearningSample (Complex a) Sector
           -> Maybe (Complex a) -- ^Maybe error
testSample opts n (i, r)
    | actual == r = Nothing
    | otherwise   = Just (target - current)
  where
    actual  = runNeuron opts n i
    z       = getWeightedSum n i
    total   = numSectors opts * periodicity opts
    target  = cis (closestSector opts r z)
    current = cis (sectBound total (basicActivation total z))
-- |Train the neuron with the error-correction rule.
--
-- Each epoch sweeps over all samples in order and updates the weights after
-- every misclassified one. Training stops after the first epoch with no
-- errors. The result is the trained neuron and the number of epochs run,
-- counting that final clean epoch.
--
-- If no epoch is ever error-free, this never returns.
runErrorCorrection :: (RealFloat a, Show a)
                   => Options
                   -> MVNP (Complex a)
                   -> Vector (LearningSample (Complex a) Sector)
                   -> (MVNP (Complex a), Int)
runErrorCorrection opts start samples = loop 1 start
  where
    loop epoch neuron =
        let (neuron', clean) = V.foldl' step (neuron, True) samples
        in  if clean then (neuron', epoch) else loop (epoch + 1) neuron'
    step (nrn, ok) (inp, expected) =
        case testSample opts nrn (inp, expected) of
            Nothing  -> (nrn, ok)
            Just err -> (updateWeights nrn inp err, False)
-- |Train a periodic MVN (3 classes, periodicity 3) on 'learning', starting
-- from 5 random weights (a bias plus one per input). Then print the weights,
-- the number of training epochs, and the accuracy on 'testing'.
main :: IO ()
main = do
    let opts = Options { numSectors = 3, periodicity = 3, learningRate = 1 }
    start <- V.replicateM 5 (randomRIO ((-0.5) :+ (-0.5), 0.5 :+ 0.5))
    putStrLn "initial weights: "
    printWeights start
    let (trained, iters) = runErrorCorrection opts (MVNP start) learning
        MVNP finalWeights = trained
    putStrLn "final weights: "
    printWeights finalWeights
    let verdicts = V.map (testSample opts trained) testing
        correct  = V.length (V.filter isNothing verdicts)
        total    = V.length testing
        accuracy = fromIntegral correct / fromIntegral total * 100 :: Double
    putStrLn ("Done after " ++ show iters ++ " iterations")
    putStrLn ("Correct " ++ show correct ++ " (" ++ show accuracy ++ " %)")
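{- Benchmark: run the program 50 times and average the number of correctly
   classified test samples, i.e. field 2 of the final "Correct N (P %)" line:

   $ for i in $(seq 1 50); do ./part3 | tail -n1; done | awk '{s+=$2;n+=1}END{print s/n}'
   57.8

   So on average 57.8 of the 60 test samples were classified correctly.
-}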
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment