Skip to content

Instantly share code, notes, and snippets.

@lotz84
Last active December 26, 2019 07:35
Show Gist options
  • Save lotz84/4f6ae2816da6a9f5c81b29bba3844873 to your computer and use it in GitHub Desktop.
Save lotz84/4f6ae2816da6a9f5c81b29bba3844873 to your computer and use it in GitHub Desktop.
"Fast and Accurate Least-Mean-Squares Solvers" in Haskell https://arxiv.org/abs/1906.04705
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE NoStarIsType #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
module Main where
import Prelude hiding ((<>))
import Data.List
import Data.Maybe
import Data.Proxy
import GHC.TypeLits
import Data.IntMap (IntMap)
import qualified Data.IntMap as IntMap
import qualified Numeric.LinearAlgebra as H
import Numeric.LinearAlgebra.Static
import Numeric.LinearAlgebra.Static.Vector
import System.Random.MWC hiding (create)
-- | A point in @d@-dimensional space paired with its scalar weight.
-- The Caratheodory routines in this module take and return maps of these,
-- returning subsets of the input points with recomputed weights.
type WeightedPoint d = (R d, Double)
-- | Split a list into consecutive chunks of length @k@; the last chunk may
-- be shorter.
--
-- A non-positive @k@ is treated as 1. The naive
-- @take k xs : chunksOf k (drop k xs)@ recursion never shrinks its input
-- when @k <= 0@ and therefore never terminates on a non-empty list.
chunksOf :: Int -> [a] -> [[a]]
chunksOf k = go
  where
    size = max 1 k
    go [] = []
    go xs = let (chunk, rest) = splitAt size xs in chunk : go rest
-- | Caratheodory reduction of a weighted point set.
--
-- Returns a subset of at most @d + 1@ of the input points, with recomputed
-- weights, such that the total weight and the weighted sum of the points are
-- the same as for the input. Keys of the input map are preserved, so callers
-- can map the surviving entries back to where they came from.
--
-- Each step:
--
--   1. takes the point @p1@ with the smallest key and forms the differences
--      @p_i - p1@ for all other points;
--   2. picks a null-space vector @v@ of those differences (one exists because
--      there are more than @d@ of them in @d@ dimensions) and sets
--      @v_1 = - sum v@, so that @sum v_i == 0@ and @sum v_i * p_i == 0@;
--   3. shifts the weights by @- alpha * v@, where @alpha = min (u_i / v_i)@
--      over @v_i > 0@. This keeps every weight non-negative and zeroes the
--      minimising one, which is then dropped.
caratheodorySet :: forall d. KnownNat d => IntMap (WeightedPoint d) -> IntMap (WeightedPoint d)
caratheodorySet ps
  | IntMap.size ps <= d + 1 = ps
  | otherwise = caratheodorySet survivors
  where
    d = fromIntegral (natVal (Proxy @d)) :: Int
    us = IntMap.map snd ps
    -- Express every other point relative to the one with the smallest key.
    ((k1, (p1, _)), rest) = IntMap.deleteFindMin ps
    diffs = IntMap.map (subtract p1 . fst) rest
    -- More than d difference vectors in d dimensions are linearly dependent,
    -- so the null space has at least one column and 'head' is safe here.
    nullVec = IntMap.elems diffs
      `withColumns` (`withNullspace` (H.toList . extract . head . toColumns))
    vs = IntMap.fromDistinctAscList (zip (IntMap.keys diffs) nullVec)
    vs' = IntMap.insert k1 (negate (sum vs)) vs
    -- alpha = min { u_i / v_i | v_i > 0 }. Some v_i is positive because the
    -- entries of the non-zero vector v' sum to zero.
    ratios = IntMap.mapMaybe (\(u, v) -> if v > 0 then Just (u / v) else Nothing)
           $ IntMap.intersectionWith (,) us vs'
    (kMin, alpha) = minimumBy (\x y -> compare (snd x) (snd y)) (IntMap.toList ratios)
    -- Drop the minimising point explicitly. In floating point its updated
    -- weight u - (u / v) * v can come out as a tiny positive number; relying
    -- on the "> 0" filter alone would then leave the set unchanged and the
    -- recursion would never terminate.
    ws = IntMap.delete kMin
       $ IntMap.intersectionWith (\u v -> u - alpha * v) us vs'
    survivors = IntMap.filter ((> 0) . snd)
              $ IntMap.intersectionWith (\(p, _) w -> (p, w)) ps ws
-- | Fast Caratheodory reduction.
--
-- Repeatedly:
--
--   * splits the points (in key order) into roughly equal-size consecutive
--     clusters;
--   * runs 'caratheodorySet' on the clusters' weighted means, each weighted
--     by its cluster's total weight;
--   * keeps only the points of the surviving clusters. A point with weight
--     @u_i@ in a cluster of total weight @u@ that received new weight @w@
--     gets weight @w * u_i / u@.
--
-- This continues until at most @d + 1@ points remain. Total weight and
-- weighted sum are preserved, and keys of the input map are kept.
--
-- The requested cluster count @k@ is raised to at least @d + 2@. With
-- @d + 1@ or fewer clusters the inner reduction cannot drop any cluster, so
-- the recursion would never terminate.
fastCaratheodorySet :: forall d. KnownNat d => Int -> IntMap (WeightedPoint d) -> IntMap (WeightedPoint d)
fastCaratheodorySet k ps
  | n <= d + 1 = ps
  | otherwise = fastCaratheodorySet k survivors
  where
    d = fromIntegral (natVal (Proxy @d)) :: Int
    n = IntMap.size ps
    clusters = max (d + 2) k
    -- Floor division yields at least `clusters` chunks. The lower bound of 1
    -- avoids a zero chunk size when n < clusters (then every point is its
    -- own cluster, and there are still more than d + 1 of them).
    chunkSize = max 1 (n `div` clusters)
    partitions = IntMap.fromList $ zip [1 ..] $ chunksOf chunkSize (IntMap.toList ps)
    -- Total weight of each cluster.
    us = IntMap.map (sum . map (snd . snd)) partitions
    -- Weighted mean of each cluster.
    ms = IntMap.intersectionWith clusterMean us partitions
    cs = caratheodorySet $ IntMap.intersectionWith (,) ms us
    survivors =
      IntMap.fromList . concat . IntMap.elems $
        IntMap.intersectionWith rescale us $
          IntMap.intersectionWith (\(_, w) xs -> (w, xs)) cs partitions

    clusterMean :: Double -> [(Int, WeightedPoint d)] -> R d
    clusterMean u xs =
      map (fst . snd) xs `withColumns` (\m -> m #> vector (map ((/ u) . snd . snd) xs))

    -- Spread a surviving cluster's new weight over its members, in
    -- proportion to their original weights.
    rescale :: Double -> (Double, [(Int, WeightedPoint d)]) -> [(Int, WeightedPoint d)]
    rescale u (w, xs) = [ (key, (p, w * ui / u)) | (key, (p, ui)) <- xs ]
-- | Row-subset reduction preserving the Gram matrix.
--
-- Each input row @p@ is lifted to the flattened outer product @p p^T@, a
-- point in @d*d@ dimensions, with weight @1/n@. 'fastCaratheodorySet' with
-- @k@ clusters then selects a weighted subset of these lifted points. Every
-- selected original row is returned scaled by @sqrt (n * w)@, where @w@ is
-- its new weight, in ascending order of its position in the input.
caratheodoryMatrix :: forall d. KnownNat d => Int -> [R d] -> [R d]
caratheodoryMatrix k ps = IntMap.elems scaled
  where
    n = fromIntegral (length ps) :: Double
    rows = IntMap.fromList (zip [1 ..] ps)
    lift v = (fromJust . create . H.flatten . extract $ v `outer` v, 1 / n)
    reduced :: IntMap (WeightedPoint (d * d))
    reduced = fastCaratheodorySet k (IntMap.map lift rows)
    scaled = IntMap.intersectionWith (\(_, w) p -> dvmap (* sqrt (n * w)) p) reduced rows
-- | Gram (scatter) matrix @A^T A@ of a matrix @A@ whose rows are samples.
--
-- Despite the name, no mean is subtracted and no @1/n@ normalisation is
-- applied. @A^T A@ is the quantity the row reduction in
-- 'caratheodoryMatrix' is constructed to preserve, so it can be compared
-- directly before and after that reduction.
covariance :: (KnownNat m, KnownNat n) => L m n -> L n n
covariance m = tr m <> m
-- | A random vector of length @n@ whose entries are independent draws from
-- @uniformR (0, 1)@. Each call obtains its own generator through
-- 'withSystemRandom'.
randVec :: forall n. KnownNat n => IO (R n)
randVec = fmap vector (withSystemRandom (asGenIO draw))
  where
    size = fromIntegral (natVal (Proxy @n)) :: Int
    draw gen = mapM (const (uniformR (0, 1 :: Double) gen)) [1 .. size]
-- | Demo program.
--
-- It samples 1000 random 2-D points with normalised random weights, then:
--
--   * prints the Gram matrix A^T A before and after 'caratheodoryMatrix';
--   * prints the weighted mean of the full sample;
--   * prints the subsets (and their weighted means) returned by
--     'caratheodorySet' and 'fastCaratheodorySet'.
main :: IO ()
main = do
  let sampleN = 1000 :: Int
  points <- sequence (replicate sampleN (randVec :: IO (R 2)))
  weights <- withSystemRandom . asGenIO $ \gen -> do
    raw <- sequence (replicate sampleN (uniformR (0, 1 :: Double) gen))
    pure (map (/ sum raw) raw)
  mapM_ print (zip points weights)
  putStrLn "分散共分散行列(Before)"
  print (withRows points covariance)
  putStrLn "分散共分散行列(After)"
  print (withRows (caratheodoryMatrix 6 points) covariance)
  printMean points weights
  let indexed = IntMap.fromList (zip [1 ..] (zip points weights))
  report "カラテオドリ集合:_" (caratheodorySet indexed)
  report "カラテオドリ集合(fast):_" (fastCaratheodorySet 4 indexed)
  where
    -- Print the weighted sum (mean, since weights sum to 1) of the points.
    printMean :: [R 2] -> [Double] -> IO ()
    printMean qs vs = do
      putStr "内分点: "
      print (withColumns qs (\m -> m #> vector vs))

    -- Print a titled weighted subset followed by its weighted mean.
    report :: String -> IntMap (WeightedPoint 2) -> IO ()
    report title subset = do
      let (qs, vs) = unzip (IntMap.elems subset)
      putStrLn title
      mapM_ print (zip qs vs)
      printMean qs vs
...
executables:
  caratheodory-set-exe:
    main: Main.hs
    source-dirs: app
    ghc-options:
    - -threaded
    - -rtsopts
    - -with-rtsopts=-N
    dependencies:
    - containers
    # Type-checker plugins loaded via OPTIONS_GHC -fplugin in Main.hs.
    - ghc-typelits-knownnat
    - ghc-typelits-natnormalise
    - hmatrix
    - hmatrix-vector-sized
    - mwc-random
    - vector-sized
    - caratheodory-set
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment