Skip to content

Instantly share code, notes, and snippets.

@mmitou
Created September 26, 2012 06:33
Show Gist options
  • Save mmitou/3786452 to your computer and use it in GitHub Desktop.
Save mmitou/3786452 to your computer and use it in GitHub Desktop.
PRML 1章の最尤推定を実装した
import Control.Monad
import Data.List
import System.Environment
import System.Exit
import System.IO (hPutStrLn, stderr)
import Text.Read (readMaybe)

import Graphics.Gnuplot.Simple
import Numeric.LinearAlgebra
import System.Random
-- | Name of a training-data file.
--
-- NOTE(review): not referenced by the code visible here; 'main' writes its
-- training points to "train.dat". Either use this constant there or remove it.
trainingDataFileName :: String
trainingDataFileName = "training.dat"
-- | Outcome of a maximum likelihood polynomial fit, as built by
-- 'maximunLikelihoodEstimation'.
data FittingResult a
  = FittingResult
      (a -> a) -- ^ the fitted polynomial curve
      a        -- ^ square root of the ML inverse precision (1 / beta), used as sigma
-- | Sampled curves for plotting a fit, as filled in by 'calcPlotData'.
data PlotData a = PlotData
  { xs     :: [a] -- ^ sample points on the x axis
  , means  :: [a] -- ^ fitted curve evaluated at each sample point
  , uppers :: [a] -- ^ fitted curve plus sigma
  , lowers :: [a] -- ^ fitted curve minus sigma
  }
-- | Maximum likelihood estimation (PRML chapter 1) of a polynomial of order
-- @m@ fitted to inputs @xs@ and targets @ts@.
--
-- The coefficients w_0 .. w_m solve the linear system A w = T with
-- A_{ij} = sum_n x_n^(i+j) and T_i = sum_n t_n * x_n^i (PRML Exercise 1.1).
--
-- Returns the fitted curve y(x, w) together with the square root of the
-- mean squared residual, i.e. sqrt (1 / beta_ML).
--
-- NOTE(review): assumes @xs@ and @ts@ have the same length; @n@ is taken
-- from @xs@ while residuals are formed with 'zipWith'.
maximunLikelihoodEstimation :: (Field a, Floating a) => [a] -> [a] -> Int -> FittingResult a
maximunLikelihoodEstimation xs ts m = FittingResult polynomialCurve (ppi ** 0.5)
  where
    n = genericLength xs

    -- Polynomial coefficients w_0 .. w_m, read out of the (m+1) x 1 solution.
    ws = [mtrxW @@> (i, 0) | i <- [0 .. m]]
      where
        m' = m + 1
        -- A_{ij} and T_i from PRML Exercise 1.1.
        a i j = sum [x ^ (i + j) | x <- xs]
        t i = sum (zipWith (*) ts [x ^ i | x <- xs])
        -- Treat A w = T as a matrix equation and solve it.
        mtrxA = (m' >< m') [a i j | i <- [0 .. m], j <- [0 .. m]]
        mtrxT = (m' >< 1) [t i | i <- [0 .. m]]
        mtrxW = linearSolve mtrxA mtrxT

    -- The fitted polynomial y(x, w) = sum_k w_k * x^k; the powers
    -- 1, x, x^2, ... are generated lazily and truncated by 'zipWith'.
    polynomialCurve x = sum (zipWith (*) ws (iterate (* x) 1))

    -- Inverse of the precision parameter beta: the mean squared residual.
    ppi = sum (zipWith (\x t -> (polynomialCurve x - t) ** 2.0) xs ts) / n
-- | Add noise to a list of values: each element is offset by the next value
-- of @randomRs range g@. The result is as long as @xs@.
addNoise :: (RandomGen g, Random a, Num a) => (a, a) -> [a] -> g -> [a]
addNoise range xs g = [x + e | (x, e) <- zip xs noise]
  where
    noise = randomRs range g
-- | Compute plotting data for a fit result: evaluate the curve @f@ on the
-- grid @linearScale 100 (xmin, xmax)@ and offset it by @sigma@ in both
-- directions to form the upper and lower bands.
calcPlotData :: Fractional a => (a -> a) -> a -> a -> a -> PlotData a
calcPlotData f sigma xmin xmax =
  PlotData
    { xs = grid
    , means = centre
    , uppers = map (+ sigma) centre
    , lowers = map (subtract sigma) centre
    }
  where
    grid = linearScale 100 (xmin, xmax)
    centre = map f grid
-- | Render paired coordinates as text with one @"x y"@ pair per line, each
-- line terminated by a newline. Elements beyond the shorter list are ignored.
--
-- The x and y lists may hold different 'Show' types (previously both had to
-- share one type; all existing calls still type-check).
pointsToString :: (Show a, Show b) => [a] -> [b] -> String
pointsToString xs ys = unlines (zipWith showPoint xs ys)
  where
    showPoint x y = show x ++ " " ++ show y
-- | Print the command-line usage and terminate with a failure exit code.
--
-- The text goes to stderr because this is only reached on invalid input.
-- The polymorphic result type lets callers use it in any IO context.
usage :: IO a
usage = do
  hPutStrLn stderr "usage:"
  hPutStrLn stderr " ./mle orderNum trainingDataNum"
  exitFailure
-- | Entry point: @mle orderNum trainingDataNum@.
--
-- Samples @sin@ on the grid @linearScale trainingDataNum (0, 2*pi)@, adds
-- noise via 'addNoise' with range (-0.5, 0.5), fits a polynomial of order
-- @orderNum@ by maximum likelihood, and writes four two-column text files:
--
--   * @train.dat@  - the noisy training points
--   * @means.dat@  - the fitted curve
--   * @uppers.dat@ - the fitted curve plus sigma
--   * @lowers.dat@ - the fitted curve minus sigma
--
-- Arguments are parsed with 'readMaybe' instead of the partial 'read'.
-- A wrong argument count, non-numeric input, a negative order or a
-- non-positive sample count prints the usage text and exits with failure.
main :: IO ()
main = do
  args <- getArgs
  (m, tnum) <- case args of
    [orderArg, countArg]
      | Just order <- readMaybe orderArg
      , Just count <- readMaybe countArg
      , order >= 0
      , count > 0
      -> return (order, count)
    _ -> usage
  gen <- getStdGen
  let xs = linearScale tnum (0, 2 * pi)
      -- Training data: sin x plus noise.
      ts = addNoise (-0.5 :: Double, 0.5 :: Double) (fmap sin xs) gen
      -- Maximum likelihood fit.
      FittingResult f s = maximunLikelihoodEstimation xs ts m
      -- Plot data for the fitted curve and the +/- sigma band.
      PlotData xs' ms us ls = calcPlotData f s 0 (2 * pi)
  writeFile "train.dat" (pointsToString xs ts)
  writeFile "means.dat" (pointsToString xs' ms)
  writeFile "uppers.dat" (pointsToString xs' us)
  writeFile "lowers.dat" (pointsToString xs' ls)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment