Created
June 1, 2010 01:07
-
-
Save yatsuta/420442 to your computer and use it in GitHub Desktop.
A symbolic differentiation system, applied to the Gaussian function and its log-likelihood
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import Data.List | |
import Data.Maybe | |
-- **************************************************************** | |
-- * data type definition | |
-- **************************************************************** | |
-- | Name of a symbolic variable, e.g. @"x"@, @"mu"@ or @"sigma^2"@.
type VarName = String

-- | A symbolic expression tree.
--
-- Constructor order is significant: the derived 'Ord' instance is what
-- 'evalAdd' and 'evalMult' use to 'sort' operands into a canonical order.
data Expr
  = Num Double
    -- ^ numeric literal
  | Var VarName
    -- ^ variable; 'eval' replaces it when it is bound in the environment
  | Add [Expr]
    -- ^ n-ary sum
  | Mult [Expr]
    -- ^ n-ary product
  | Minus Expr
    -- ^ negation; 'eval' rewrites it as @Mult [Num (-1), e]@
  | Recip Expr
    -- ^ reciprocal; 'eval' rewrites it as @Pow e (Num (-1))@
  | Pow Expr Expr
    -- ^ @Pow base exponent@
  | Exp Expr
    -- ^ natural exponential
  | Log Expr
    -- ^ natural logarithm
  | Sqrt Expr
    -- ^ square root; 'eval' rewrites it as @Pow e (Num 0.5)@
  | Sum VarName [VarName] Expr Expr Expr
    -- ^ @Sum i xs begin end body@: sum of @body@ for @i@ from @begin@ to
    -- @end@.  When the bounds are numeric, each name in @xs@ is replaced
    -- inside @body@ by its indexed copy (@x@ becomes @x1@, @x2@, ...).
  | Prod VarName [VarName] Expr Expr Expr
    -- ^ product counterpart of 'Sum', with the same field layout
  deriving (Show, Eq, Ord)
-- **************************************************************** | |
-- * hasVar | |
-- **************************************************************** | |
-- | @expr `hasVar` v@ is 'True' when the variable @v@ occurs free in @expr@.
--
-- Every constructor is covered; previously 'Minus', 'Recip', 'Sqrt' and
-- 'Prod' fell off the patterns.  For the binding forms 'Sum' and 'Prod',
-- the index and indexed variables are bound only inside the body.  The
-- bounds lie outside that scope and are always inspected.
hasVar :: Expr -> VarName -> Bool
Num _ `hasVar` _ = False
Var x `hasVar` varName = x == varName
Add exprs `hasVar` varName = any (`hasVar` varName) exprs
Mult exprs `hasVar` varName = any (`hasVar` varName) exprs
Minus x `hasVar` varName = x `hasVar` varName
Recip x `hasVar` varName = x `hasVar` varName
Pow x n `hasVar` varName = (x `hasVar` varName) || (n `hasVar` varName)
Exp x `hasVar` varName = x `hasVar` varName
Log x `hasVar` varName = x `hasVar` varName
Sqrt x `hasVar` varName = x `hasVar` varName
Sum n xs begin end body `hasVar` varName =
  binderHasVar n xs begin end body varName
Prod n xs begin end body `hasVar` varName =
  binderHasVar n xs begin end body varName

-- | Free-variable test shared by the binding forms 'Sum' and 'Prod'.
binderHasVar :: VarName -> [VarName] -> Expr -> Expr -> Expr -> VarName -> Bool
binderHasVar i vars begin end body varName =
  begin `hasVar` varName
    || end `hasVar` varName
    || (body `hasVar` varName && varName `notElem` (i : vars))
-- **************************************************************** | |
-- * split... | |
-- **************************************************************** | |
-- | 'True' exactly for the 'Var' constructor.
isVar expr = case expr of
  Var _ -> True
  _ -> False
-- | Split into (variables, everything else), preserving relative order.
splitVar exprs = (filter isVar exprs, filter (not . isVar) exprs)
-- | Split into (numeric literals, everything else), preserving order.
splitNum exprs = partition (\e -> case e of { Num _ -> True; _ -> False }) exprs
-- | Split into (nested sums, everything else), preserving order.
splitAdd exprs = partition (\e -> case e of { Add _ -> True; _ -> False }) exprs
-- | Split into (products with at least one 'Var' factor, everything else).
splitMultVar exprs = partition hasVarFactor exprs
  where
    hasVarFactor (Mult factors) = any isVar factors
    hasVarFactor _ = False
-- | Split into (nested products, everything else), preserving order.
splitMult exprs = partition (\e -> case e of { Mult _ -> True; _ -> False }) exprs
-- | Split into (powers whose base is a bare 'Var', everything else).
splitPowVar exprs = partition powerOfVar exprs
  where
    powerOfVar e = case e of
      Pow (Var _) _ -> True
      _ -> False
-- | Split into (expressions mentioning any of @vars@, the rest).
splitHasVars vars exprs = partition (\e -> any (hasVar e) vars) exprs
-- **************************************************************** | |
-- * addNums, multNums | |
-- **************************************************************** | |
-- | Sum a list of numeric literals into a single 'Num' (empty list gives 0).
-- Every element must be a 'Num'; callers pass the output of 'splitNum'.
addNums nums = Num (foldl' (+) 0 (map fromNum nums))
  where
    fromNum (Num n) = n
-- | Multiply a list of numeric literals into a single 'Num' (empty list
-- gives 1).  Every element must be a 'Num'; callers pass 'splitNum' output.
multNums nums = Num (foldl' (*) 1 (map fromNum nums))
  where
    fromNum (Num n) = n
-- **************************************************************** | |
-- * count... | |
-- **************************************************************** | |
-- | Pair every variable with a coefficient of one.
countVar vars = zip vars (repeat (Num 1))
-- | For each product in @multVars@, split off its first 'Var' factor.  Pair
-- that variable with the evaluated product of the remaining factors, i.e.
-- its coefficient.  Non-'Mult' entries are ignored.  Each product is
-- expected to contain a variable, as selected by 'splitMultVar'.
countMultVar multVars b = map splitOffVar [factors | Mult factors <- multVars]
  where
    splitOffVar factors =
      case break isVar factors of
        (before, var : after) -> (var, eval (Mult (before ++ after)) b)
        (_, []) -> error "countMultVar: product has no variable factor"
-- | Turn each power @base ^ ex@ into a @(base, ex)@ pair; others are dropped.
countPowVar powVars = concatMap asPair powVars
  where
    asPair (Pow base ex) = [(base, ex)]
    asPair _ = []
-- | Group @(term, amount)@ pairs by term and add up the amounts of each
-- group.  Rebuild one evaluated expression per distinct term with @combine@,
-- in the sorted order of the terms.
--
-- Shared by 'sumUpVars' and 'sumUpVars'', which used to duplicate this
-- logic verbatim.
groupAndCombine :: (Expr -> Expr -> Expr) -> [(Expr, Expr)] -> [(VarName, Expr)] -> [Expr]
groupAndCombine combine counts b =
  [ eval (combine term total) b
  | grp@((term, _) : _) <- groupBy (\p q -> fst p == fst q) (sort counts)
  , let total = eval (Add (map snd grp)) b
  ]

-- | Collect like terms of a sum: each distinct term @t@ with coefficients
-- @c1, c2, ...@ becomes @t * (c1 + c2 + ...)@.
sumUpVars counts b = groupAndCombine (\var count -> Mult [var, count]) counts b

-- | Collect like factors of a product: each distinct base @t@ with exponents
-- @e1, e2, ...@ becomes @t ^ (e1 + e2 + ...)@.
sumUpVars' counts b = groupAndCombine Pow counts b
-- | Indexed copy of a variable: @indexedVar "x" 3.0@ is @Var "x3"@.
-- The index is rounded to the nearest integer before being appended.
indexedVar v i = Var (v ++ show (round i :: Integer))
-- **************************************************************** | |
-- * eval... | |
-- **************************************************************** | |
-- | Simplify a sum.
--
-- The steps are:
--
-- * Evaluate the operands and flatten nested 'Add's.
-- * Collect repeated variables, and products containing a variable, into
--   single terms ('countVar', 'countMultVar', 'sumUpVars').
-- * Fold the numeric literals into one constant, dropped when it is zero.
-- * Sort the remaining operands.
--
-- An empty sum is @Num 0@; a single operand is returned unwrapped.
evalAdd (Add terms) env = finish (withConstant (addNums nums) others)
  where
    evaluated = map (`eval` env) terms
    (nestedSums, flatTerms) = splitAdd evaluated
    flattened = concat [inner | Add inner <- nestedSums] ++ flatTerms
    (plainVars, nonVars) = splitVar flattened
    (scaledVars, remaining) = splitMultVar nonVars
    counts = countVar plainVars ++ countMultVar scaledVars env
    collected = sumUpVars counts env
    (nums, others) = splitNum (collected ++ remaining)
    -- A zero constant contributes nothing to the sum.
    withConstant (Num 0) ts = ts
    withConstant c ts = c : ts
    finish [] = Num 0
    finish [t] = t
    finish ts = Add (sort ts)
-- | Simplify a product.
--
-- The steps are:
--
-- * Evaluate the factors and flatten nested 'Mult's.
-- * Merge repeated variables and powers of variables into single powers
--   ('countVar', 'countPowVar', 'sumUpVars'').
-- * Fold the numeric literals into one constant.  A zero constant collapses
--   the whole product to zero; a constant of one is dropped.
-- * Sort the remaining factors.
--
-- An empty product is @Num 1@; a single factor is returned unwrapped.
evalMult (Mult factors) env = finish (withConstant (multNums nums) others)
  where
    evaluated = map (`eval` env) factors
    (nestedProducts, flatFactors) = splitMult evaluated
    flattened = concat [inner | Mult inner <- nestedProducts] ++ flatFactors
    (plainVars, nonVars) = splitVar flattened
    (varPowers, remaining) = splitPowVar nonVars
    exponents = countVar plainVars ++ countPowVar varPowers
    collected = sumUpVars' exponents env
    (nums, others) = splitNum (collected ++ remaining)
    withConstant (Num 0) _ = [Num 0]
    withConstant (Num 1) fs = fs
    withConstant c fs = c : fs
    finish [] = Num 1
    finish [f] = f
    finish fs = Mult (sort fs)
-- | Pull the factors that do not mention the index (or indexed variables)
-- out of a summed product: @sum (c * f_i) = c * sum f_i@.
evalSumMult i vars begin end exprs b =
  let (dependent, constant) = splitHasVars (i : vars) exprs
      innerSum = Sum i vars begin end (Mult dependent)
  in eval (Mult (innerSum : constant)) b
-- | Simplify a summation @Sum i vars begin end body@.
--
-- The rules below are tried in order:
--
-- * If the body does not mention the index or indexed variables, the sum
--   is @(end - begin + 1) * body@.
-- * For a product body, index-independent factors are pulled out.  This
--   applies only when at least one such factor exists; otherwise
--   'evalSumMult' would rebuild the very same 'Sum' and recurse forever.
-- * A sum body is distributed into a sum of sums.
-- * With numeric bounds, the sum is expanded term by term.  Each term binds
--   @i@ to the index and each name in @vars@ to its indexed copy.
-- * Otherwise the (partially evaluated) 'Sum' is returned.
--
-- The index and indexed variables are bound by the sum, so they shadow any
-- outer bindings of the same names in @b@.
evalSum (Sum i vars begin end body) b =
  case (eval begin b, eval end b, eval body bodyEnv) of
    (begin', end', body')
      | not (dependsOnIndex body') ->
          eval (Mult [Add [end', Minus begin', Num 1], body']) b
    (begin', end', Mult exprs)
      | not (all dependsOnIndex exprs) ->
          evalSumMult i vars begin' end' exprs b
    (begin', end', Add exprs) ->
      eval (Add [Sum i vars begin' end' expr | expr <- exprs]) b
    (Num n1, Num n2, expr') ->
      eval (Add (map (expandAt expr') [n1 .. n2])) b
    (begin', end', body') -> Sum i vars begin' end' body'
  where
    bound = i : vars
    dependsOnIndex e = any (e `hasVar`) bound
    -- Outer environment with the sum's own bound names removed.
    bodyEnv = [binding | binding@(name, _) <- b, name `notElem` bound]
    -- Index bindings come first so they take precedence in 'lookup'.
    expandAt expr' n = eval expr' (indexBindings n ++ b)
    indexBindings n = (i, Num n) : zip vars [indexedVar var n | var <- vars]
-- | Simplify a product @Prod i vars begin end body@.
--
-- With numeric bounds, the product is expanded factor by factor.  Each
-- factor binds @i@ to the index and each name in @vars@ to its indexed
-- copy.  Otherwise the (partially evaluated) 'Prod' is returned.
--
-- The former rule rewriting a product of logs into a sum of logs was
-- removed: it is mathematically wrong.  The valid identity
-- @log (prod f) = sum (log f)@ is applied by 'eval' on 'Log' of a 'Prod'.
--
-- The index and indexed variables are bound by the product, so they shadow
-- any outer bindings of the same names in @b@.
evalProd (Prod i vars begin end body) b =
  case (eval begin b, eval end b, eval body bodyEnv) of
    (Num n1, Num n2, expr') ->
      eval (Mult (map (expandAt expr') [n1 .. n2])) b
    (begin', end', body') -> Prod i vars begin' end' body'
  where
    bound = i : vars
    -- Outer environment with the product's own bound names removed.
    bodyEnv = [binding | binding@(name, _) <- b, name `notElem` bound]
    -- Index bindings come first so they take precedence in 'lookup'.
    expandAt expr' n = eval expr' (indexBindings n ++ b)
    indexBindings n = (i, Num n) : zip vars [indexedVar var n | var <- vars]
-- **************************************************************** | |
-- * eval | |
-- **************************************************************** | |
-- | Evaluate and simplify an expression under an environment of bindings.
--
-- Variables bound in @env@ are replaced by their simplified values.
-- 'Minus', 'Recip' and 'Sqrt' are rewritten via 'Mult' and 'Pow', and
-- numeric sub-expressions are folded.  A few algebraic identities for
-- powers, exponentials and logarithms are applied.  Sums, products,
-- summations and repeated products are delegated to their own evaluators.
eval :: Expr -> [(VarName, Expr)] -> Expr
eval expr env = case expr of
  Num _ -> expr
  Var x -> maybe (Var x) simplify (lookup x env)
  Add _ -> evalAdd expr env
  Mult _ -> evalMult expr env
  Minus e -> eval (Mult [Num (-1), e]) env
  Recip e -> eval (Pow e (Num (-1))) env
  Sqrt e -> eval (Pow e (Num 0.5)) env
  Pow base ex -> powOf (eval base env) (eval ex env)
  Exp e -> expOf (eval e env)
  Log e -> logOf (eval e env)
  Sum {} -> evalSum expr env
  Prod {} -> evalProd expr env
  where
    -- Clause order matters: x^0 and x^1 are checked before anything else.
    powOf _ (Num 0) = Num 1
    powOf base (Num 1) = base
    powOf (Pow inner ex1) ex2 = eval (Pow inner (Mult [ex1, ex2])) env
    powOf (Mult fs) ex = eval (Mult [Pow f ex | f <- fs]) env
    powOf (Num x) (Num y) = Num (x ** y)
    powOf base ex = Pow base ex

    expOf (Num n) = Num (exp n)
    expOf (Log e) = e
    expOf e = Exp e

    logOf (Num n) = Num (log n)
    logOf (Pow base ex) = eval (Mult [ex, Log base]) env
    logOf (Mult fs) = eval (Add [Log f | f <- fs]) env
    logOf (Exp e) = e
    logOf (Prod i vars begin end body) =
      eval (Sum i vars begin end (Log body)) env
    logOf e = Log e
-- **************************************************************** | |
-- * gaussian, bern | |
-- **************************************************************** | |
-- | Normal density in @x@ with mean @mu@ and variance @sigma^2@:
-- @1 / sqrt (2 * PI * sigma^2) * exp (-(x - mu)^2 / (2 * sigma^2))@.
-- The constant pi is kept symbolic as @Var "PI"@.
gaussian :: Expr
gaussian = Mult [normalizer, Exp (Minus quadratic)]
  where
    sigma2 = Var "sigma^2"
    normalizer = Recip (Sqrt (Mult [Num 2, Var "PI", sigma2]))
    deviation = Add [Var "x", Minus (Var "mu")]
    quadratic = Mult [Recip (Mult [Num 2, sigma2]), Pow deviation (Num 2)]
-- | 'gaussian' specialised to the standard normal (@mu = 0@, @sigma^2 = 1@).
--
-- @PI@ is bound to Prelude 'pi' instead of the previous hard-coded,
-- truncated 3.141592, so the normalizing constant is accurate to full
-- 'Double' precision.
standardGaussian :: Expr
standardGaussian = eval gaussian [("PI", Num pi),
                                  ("mu", Num 0),
                                  ("sigma^2", Num 1)]
-- | Log-likelihood of a Gaussian sample: the log of the product of
-- 'gaussian' over @n = 1 .. N@, with @x@ as the per-sample indexed variable.
-- 'eval' rewrites a 'Log' of a 'Prod' into a 'Sum' of 'Log's.
logLikelyhood :: Expr
logLikelyhood = Log likelihood
  where
    likelihood = Prod "n" ["x"] (Num 1) (Var "N") gaussian
-- | Bernoulli probability mass @mu^x * (1 - mu)^(1 - x)@.
bern :: Expr
bern = Mult [Pow mu x, Pow (oneMinus mu) (oneMinus x)]
  where
    mu = Var "mu"
    x = Var "x"
    oneMinus e = Add [Num 1, Minus e]
-- | Log-likelihood of a Bernoulli sample: the log of the product of 'bern'
-- over @n = 1 .. N@, with @x@ as the per-sample indexed variable.
logLikelyhoodBern :: Expr
logLikelyhoodBern = Log likelihood
  where
    likelihood = Prod "n" ["x"] (Num 1) (Var "N") bern
-- **************************************************************** | |
-- * simplify | |
-- **************************************************************** | |
-- | Simplify an expression with no variable bindings.
simplify :: Expr -> Expr
simplify e = e `eval` []
-- **************************************************************** | |
-- * deriv | |
-- **************************************************************** | |
-- | Raw symbolic derivative of an expression with respect to a variable.
-- The result is not simplified; 'deriv' wraps this with 'simplify'.
--
-- Every constructor is handled.  Previously @Mult []@, 'Minus', 'Recip',
-- 'Sqrt' and 'Prod' fell off the patterns.  A power whose exponent depends
-- on the variable now uses the general rule instead of the power rule.
deriv' :: Expr -> VarName -> Expr
deriv' (Num _) _ = Num 0
deriv' (Var x) v
  | x == v = Num 1
  | otherwise = Num 0
deriv' (Add exprs) v = Add [deriv' expr v | expr <- exprs]
-- The empty product is the constant 1.
deriv' (Mult []) _ = Num 0
deriv' (Mult [expr]) v = deriv' expr v
-- Product rule, peeling off one factor at a time.
deriv' (Mult (e : exprs)) v = Add [Mult [e, deriv' (Mult exprs) v],
                                   Mult [deriv' e v, Mult exprs]]
deriv' (Minus x) v = Minus (deriv' x v)
deriv' (Recip x) v = deriv' (Pow x (Num (-1))) v
deriv' (Sqrt x) v = deriv' (Pow x (Num 0.5)) v
deriv' (Pow x n) v =
  case simplify dn of
    -- Exponent constant in v: ordinary power rule n * x^(n-1) * x'.
    Num 0 -> Mult [Mult [n,
                         Pow x (Add [n,
                                     Minus (Num 1)])],
                   deriv' x v]
    -- Exponent depends on v: d(x^n) = x^n * (n' * log x + n * x' / x).
    _ -> Mult [Pow x n,
               Add [Mult [dn, Log x],
                    Mult [n, deriv' x v, Recip x]]]
  where
    dn = deriv' n v
deriv' (Exp x) v = Mult [Exp x, deriv' x v]
deriv' (Log x) v = Mult [Recip x, deriv' x v]
deriv' (Sum i vars begin end body) v = Sum i vars begin end (deriv' body v)
-- Logarithmic derivative: d(prod f_n) = (prod f_n) * sum (f_n' / f_n).
-- Valid where every factor f_n is non-zero.
deriv' p@(Prod i vars begin end body) v =
  Mult [p, Sum i vars begin end (Mult [deriv' body v, Recip body])]
-- | Simplified derivative of @expr@ with respect to @v@.  The input is
-- simplified first, then the raw derivative is simplified again.
deriv :: Expr -> VarName -> Expr
deriv expr v =
  let normalized = simplify expr
  in simplify (deriv' normalized v)
-- **************************************************************** | |
-- * solve | |
-- **************************************************************** | |
-- | Isolate the variable @v@ in @lhs = rhs@ by peeling inverse operations
-- off the left-hand side, one layer at a time.
--
-- In a sum or product, only the first operand mentioning @v@ is isolated;
-- the remaining operands are moved to the right-hand side.  Forms that
-- cannot be inverted fail with a descriptive 'error' instead of a bare
-- pattern-match or 'fromJust' failure.
solve' :: Expr -> Expr -> VarName -> Expr
solve' (Var x) rhs v
  | x == v = rhs
solve' (Add exprs) rhs v = solve' term (Add [rhs, Minus (Add others)]) v
  where
    (term, others) = pickTerm v exprs
solve' (Mult exprs) rhs v = solve' factor (Mult [rhs, Recip (Mult others)]) v
  where
    (factor, others) = pickTerm v exprs
solve' (Minus x) rhs v = solve' x (Minus rhs) v
solve' (Recip x) rhs v = solve' x (Recip rhs) v
-- sqrt x = r  =>  x = r^2 (assumes r is non-negative).
solve' (Sqrt x) rhs v = solve' x (Pow rhs (Num 2)) v
solve' (Pow x n) rhs v
  -- Unknown in the base: x^n = r  =>  x = r^(1/n).
  | x `hasVar` v = solve' x (Pow rhs (Recip n)) v
  -- Unknown only in the exponent: b^n = r  =>  n = log r / log b.
  | otherwise = solve' n (Mult [Log rhs, Recip (Log x)]) v
solve' (Exp x) rhs v = solve' x (Log rhs) v
solve' (Log x) rhs v = solve' x (Exp rhs) v
solve' lhs _ v = error ("solve': cannot isolate " ++ v ++ " in " ++ show lhs)

-- | Split @exprs@ into the first element mentioning @v@ and all the others.
pickTerm :: VarName -> [Expr] -> (Expr, [Expr])
pickTerm v exprs =
  case find (`hasVar` v) exprs of
    Just term -> (term, delete term exprs)
    Nothing -> error ("solve': " ++ v ++ " does not occur in " ++ show exprs)
-- | Solve @lhs = rhs@ for @v@: both sides are simplified, @v@ is isolated
-- with 'solve'', and the resulting expression is simplified.
solve lhs rhs v = simplify (solve' lhs' rhs' v)
  where
    lhs' = simplify lhs
    rhs' = simplify rhs
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment