Created
September 26, 2012 06:33
-
-
Save mmitou/3786452 to your computer and use it in GitHub Desktop.
PRML 1章の最尤推定を実装した
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import System.Random | |
| import Control.Monad | |
| import Numeric.LinearAlgebra | |
| import Graphics.Gnuplot.Simple | |
| import Data.List | |
| import System.Environment | |
| import System.Exit | |
-- | Name of a training-data file.
--
-- NOTE(review): nothing in the code visible here references this constant;
-- 'main' writes the training points to "train.dat" instead. Either use it
-- there or remove it — confirm which file name downstream plotting expects.
trainingDataFileName :: String
trainingDataFileName = "training.dat"
-- | Outcome of a polynomial curve fit: the fitted curve together with a
-- single scalar spread estimate (as produced by
-- 'maximunLikelihoodEstimation', the square root of the estimated
-- inverse precision).
data FittingResult a
  = FittingResult
      (a -> a) -- ^ fitted curve, evaluated at an input x
      a        -- ^ spread estimate attached to the fit
-- | Sampled curves for plotting a fit: a grid of inputs and, for each grid
-- point, the fitted value plus an upper and a lower band value.
data PlotData a = PlotData
  { xs     :: [a] -- ^ sample inputs
  , means  :: [a] -- ^ fitted curve at each input
  , uppers :: [a] -- ^ upper band at each input
  , lowers :: [a] -- ^ lower band at each input
  }
-- | Maximum-likelihood fit of an order-@m@ polynomial to training data
-- (PRML chapter 1).
--
-- The coefficients @w_0 .. w_m@ come from solving the normal equations
-- @A w = T@ with @A_ij = sum_n x_n^(i+j)@ and @T_i = sum_n t_n * x_n^i@
-- (PRML exercise 1.1). The scalar in the result is @sqrt (1 / beta_ML)@,
-- where @1 / beta_ML@ is the mean squared residual of the fitted curve on
-- the training points.
--
-- The (misspelled) name is kept as-is so existing callers keep working.
--
-- NOTE(review): solving the normal equations directly can become
-- numerically ill-conditioned as @m@ grows — confirm acceptable orders.
maximunLikelihoodEstimation
  :: (Field a, Floating a)
  => [a]  -- ^ inputs x_n
  -> [a]  -- ^ targets t_n
  -> Int  -- ^ polynomial order m
  -> FittingResult a
maximunLikelihoodEstimation xs ts m = FittingResult polynomialCurve (sqrt ppi)
  where
    n = genericLength xs
    m' = m + 1

    -- Polynomial coefficients w_0 .. w_m, read out of the (m+1) x 1 solution.
    ws = [mtrxW @@> (i, 0) | i <- [0 .. m]]

    -- A_ij and T_i as defined in PRML exercise 1.1.
    a i j = sum (map (^ (i + j)) xs)
    t i = sum (zipWith (*) ts (map (^ i) xs))

    -- Treat A w = T as a linear system and solve it.
    mtrxA = (m' >< m') [a i j | i <- [0 .. m], j <- [0 .. m]]
    mtrxT = (m' >< 1) [t i | i <- [0 .. m]]
    mtrxW = linearSolve mtrxA mtrxT

    -- y(x, w) = sum_k w_k * x^k; zipWith stops after the m + 1 coefficients,
    -- so the infinite power list [1, x, x^2, ...] is safe.
    polynomialCurve x = sum (zipWith (*) ws (iterate (* x) 1))

    -- Inverse precision 1 / beta_ML: mean squared residual over the data.
    ppi = sum (zipWith (\x tn -> (polynomialCurve x - tn) ^ (2 :: Int)) xs ts) / n
-- | Perturb every element of a signal with a value drawn from @range@ by
-- 'randomRs' using the given generator. The noise stream is infinite, so
-- the result has exactly the length of the input signal.
addNoise :: (RandomGen g, Random a, Num a) => (a, a) -> [a] -> g -> [a]
addNoise range signal = zipWith (+) signal . randomRs range
-- | Build plot data for a fitted curve: sample @f@ on the grid
-- @linearScale 100 (xmin, xmax)@ and attach a band of half-width @sigma@
-- (mean + sigma above, mean - sigma below) at every grid point.
calcPlotData :: Fractional a => (a -> a) -> a -> a -> a -> PlotData a
calcPlotData f sigma xmin xmax =
  PlotData
    { xs = grid
    , means = centre
    , uppers = [y + sigma | y <- centre]
    , lowers = [y - sigma | y <- centre]
    }
  where
    grid = linearScale 100 (xmin, xmax)
    centre = f <$> grid
-- | Render two parallel lists as text, one @"x y"@ pair per line (values
-- rendered with 'show', separated by a single space), every line ending in
-- a newline. Surplus elements of the longer list are dropped.
pointsToString :: (Show a) => [a] -> [a] -> String
pointsToString xCoords yCoords =
  unlines [unwords [show x, show y] | (x, y) <- zip xCoords yCoords]
-- | Print the command-line synopsis to stdout, then terminate the program
-- with a failure exit code (never returns normally).
usage :: IO a
usage = mapM_ putStrLn ["usage:", " ./mle orderNum trainingDataNum"] >> exitFailure
-- | Entry point: @./mle orderNum trainingDataNum@.
--
-- Samples @sin x@ on [0, 2*pi] with 'linearScale', adds uniform-range noise
-- from [-0.5, 0.5] to form training targets, fits a polynomial of order
-- @orderNum@ by maximum likelihood, and writes four two-column data files:
-- the training points (train.dat), the fitted mean curve (means.dat), and
-- the mean plus / minus the estimated spread (uppers.dat, lowers.dat).
--
-- Invalid arguments (wrong count, non-numeric, negative order, or a
-- non-positive sample count) print the usage message and exit with failure
-- instead of crashing in 'read' or fitting degenerate input.
main :: IO ()
main = do
  args <- getArgs
  (m, tnum) <- case args of
    [orderArg, numArg]
      | Just order <- parseArg orderArg
      , Just num <- parseArg numArg
      , order >= 0
      , num >= 1
      -> return (order, num)
    _ -> usage
  gen <- getStdGen
  let xs = linearScale tnum (0, 2 * pi)
      -- Training targets: sin x plus noise.
      ts = addNoise (-0.5 :: Double, 0.5 :: Double) (map sin xs) gen
      -- Maximum-likelihood fit of the polynomial and its spread.
      FittingResult f s = maximunLikelihoodEstimation xs ts m
      -- Plot data for the fitted curve and its +/- s band.
      PlotData xs' ms us ls = calcPlotData f s 0 (2 * pi)
  writeFile "train.dat" (pointsToString xs ts)
  writeFile "means.dat" (pointsToString xs' ms)
  writeFile "uppers.dat" (pointsToString xs' us)
  writeFile "lowers.dat" (pointsToString xs' ls)
  where
    -- Total counterpart of 'read': the same acceptance rule the Haskell 2010
    -- Prelude uses for 'read' (one parse, only whitespace left over), but
    -- returning Nothing instead of calling 'error'.
    parseArg :: Read a => String -> Maybe a
    parseArg s = case [v | (v, rest) <- reads s, ("", "") <- lex rest] of
      [v] -> Just v
      _ -> Nothing
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment