Created
May 14, 2021 03:04
-
-
Save Garciat/ee399adc6d11b58b3527fe35eaefc35c to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{-# LANGUAGE RankNTypes #-} | |
module Symbolic where | |
-- | Syntax tree of an expression in the single variable @x@.
data Expr
  = Var                   -- ^ the variable @x@
  | Lit Double            -- ^ a numeric constant
  | BinOp BinOp Expr Expr -- ^ a binary operator applied to two operands
  | UnaOp UnaOp Expr      -- ^ a unary function applied to one operand
  deriving (Eq, Show)
-- | Binary operators understood by the evaluator, differentiator and parser.
data BinOp
  = Add -- ^ addition
  | Sub -- ^ subtraction
  | Mul -- ^ multiplication
  | Div -- ^ division
  | Pow -- ^ exponentiation
  deriving (Eq, Show)
-- | Does the operator satisfy @a <> b = b <> a@?
-- Only addition and multiplication do.
isCommut :: BinOp -> Bool
isCommut op = op `elem` [Add, Mul]
-- | Does the operator satisfy @(a <> b) <> c = a <> (b <> c)@?
-- Only addition and multiplication do.
isAssoc :: BinOp -> Bool
isAssoc op =
  case op of
    Add -> True
    Mul -> True
    _ -> False
-- | The 'Double' function that a binary operator denotes.
runBinOp :: BinOp -> Double -> Double -> Double
runBinOp op =
  case op of
    Add -> (+)
    Sub -> (-)
    Mul -> (*)
    Div -> (/)
    Pow -> (**)
-- | Unary functions understood by the evaluator, differentiator and parser.
data UnaOp
  = Cos -- ^ cosine
  | Sin -- ^ sine
  | Tan -- ^ tangent
  | Exp -- ^ exponential
  | Log -- ^ natural logarithm (printed and parsed as @ln@)
  deriving (Eq, Show)
-- | |
-- | Arithmetic on expressions builds syntax-tree nodes instead of computing
-- values, so ordinary numeric code applied to 'Var' yields an 'Expr'.
-- Negation is encoded as multiplication by the literal @-1@.
-- 'abs' and 'signum' have no 'Expr' representation and are not defined.
instance Num Expr where
  fromInteger n = Lit (fromInteger n)
  a + b = BinOp Add a b
  a - b = BinOp Sub a b
  a * b = BinOp Mul a b
  negate e = BinOp Mul (Lit (-1)) e
-- | Division builds a 'Div' node; fractional literals become 'Lit' constants.
instance Fractional Expr where
  fromRational r = Lit (fromRational r)
  recip e = BinOp Div (Lit 1) e
  a / b = BinOp Div a b
-- | Floating-point operations build expression nodes.
--
-- Every operation that has a matching 'UnaOp' / 'BinOp' constructor is
-- defined, including 'log' (previously missing, which made @log e@ a runtime
-- missing-method error even though 'UnaOp' 'Log' exists). 'pi' is simply a
-- constant. The inverse and hyperbolic functions have no representation in
-- 'Expr' and remain undefined; calling them is a runtime error.
instance Floating Expr where
  pi = Lit pi
  exp = UnaOp Exp
  log = UnaOp Log
  sin = UnaOp Sin
  cos = UnaOp Cos
  tan = UnaOp Tan
  (**) = BinOp Pow
-- | |
-- | Interpret an expression as a function of @x@ in any 'Floating' type.
--
-- 'Var' evaluates to the supplied argument. Literals are converted with
-- 'realToFrac' so that fractional constants keep their value; the previous
-- @fromIntegral (truncate n)@ silently turned e.g. @2.5@ into @2@.
exprNum :: Expr -> (forall a. Floating a => a -> a)
exprNum ex x = go ex
  where
    go Var = x
    go (Lit n) = realToFrac n
    go (BinOp Add l r) = go l + go r
    go (BinOp Sub l r) = go l - go r
    go (BinOp Mul l r) = go l * go r
    go (BinOp Div l r) = go l / go r
    go (BinOp Pow l r) = go l ** go r
    go (UnaOp Cos e) = cos (go e)
    go (UnaOp Sin e) = sin (go e)
    go (UnaOp Tan e) = tan (go e)
    go (UnaOp Exp e) = exp (go e)
    go (UnaOp Log e) = log (go e)
-- | Reify a polymorphic numeric function as an 'Expr' by running it on 'Var'
-- (the instances above make every operation build a syntax node).
numExpr :: (forall a. Floating a => a -> a) -> Expr
numExpr polyFn = polyFn Var
-- | |
-- | Simplify an expression bottom-up.
--
-- Both operands of a binary node are simplified first. The node itself is
-- then rewritten with:
--
--   * algebraic identities (@0 + e@, @e - 0@, @1 * e@, @0 * e@, @e / 1@,
--     @e ^ 0@, @e ^ 1@, ...);
--   * merging of nested literals through associativity;
--   * constant folding;
--   * for commutative operators, moving a literal operand to the left.
--
-- The associativity step turns @k <> (k' <> e)@ into @(k <> k') <> e@.
--
-- Two defects of the previous version are fixed:
--
--   * A commutative node with two literal operands (e.g. @2 + 3@) swapped
--     its operands forever.
--   * Identities were only tried on the unsimplified operands.
--
-- Termination: the associativity step shrinks the tree. The literal swap
-- only fires when the left operand is not a literal, and afterwards the
-- right operand is not a literal, so it cannot fire again on that node.
simp :: Expr -> Expr
simp (BinOp op x y) = bin op (simp x) (simp y)
  where
    -- Rewrite one binary node whose operands are already simplified.
    bin Add (Lit 0) e = e
    bin Add e (Lit 0) = e
    bin Sub e (Lit 0) = e
    bin Mul (Lit 1) e = e
    bin Mul e (Lit 1) = e
    bin Mul (Lit 0) _ = Lit 0
    bin Mul _ (Lit 0) = Lit 0
    bin Div e (Lit 1) = e
    bin Pow _ (Lit 0) = Lit 1
    bin Pow e (Lit 1) = e
    -- k <> (k' <> e) ==> (k <> k') <> e ==> k'' <> e
    bin o (Lit n) (BinOp o' (Lit m) e)
      | o == o' && isAssoc o = bin o (Lit (runBinOp o n m)) e
    -- Constant folding; must precede the swap below so it cannot loop.
    bin o (Lit n) (Lit m) = Lit (runBinOp o n m)
    -- Canonical form for commutative operators: literal on the left, which
    -- lets the associativity rule above merge literals across nodes.
    bin o e (Lit m) | isCommut o = bin o (Lit m) e
    bin o l r = BinOp o l r
simp (UnaOp op x) = UnaOp op (simp x)
simp e = e
-- | Symbolic derivative with respect to the variable @x@.
--
-- The result is built with the 'Expr' numeric instances and is not
-- simplified; run 'simp' on it for a compact form.
dd :: Expr -> Expr
dd Var = Lit 1
dd (Lit _) = Lit 0
-- d(f + g) = df + dg
dd (BinOp Add f g) = dd f + dd g
-- d(f - g) = df - dg
dd (BinOp Sub f g) = dd f - dd g
-- d(f * g) = f * dg + g * df
dd (BinOp Mul f g) = f * dd g + g * dd f
-- d(f / g) = (g * df - f * dg) / g^2
dd (BinOp Div f g) = (g * dd f - f * dd g) / (g ** 2)
-- d(f ^ n) = n * f ^ (n - 1) * df     (constant exponent: power rule)
dd (BinOp Pow f n@Lit{}) = n * f ** (n - 1) * dd f
-- d(f ^ g) = f ^ g * (dg * ln f + g * df / f)     (general exponent)
-- The logarithm node is built directly with 'UnaOp' 'Log' rather than via
-- 'log', so this rule does not rely on the 'Floating' 'Expr' instance.
dd (BinOp Pow f g) = f ** g * (dd g * UnaOp Log f + g * dd f / f)
-- Chain rule: d(h(f)) = h'(f) * df
dd (UnaOp Cos f) = negate (sin f) * dd f
dd (UnaOp Sin f) = cos f * dd f
dd (UnaOp Tan f) = (1 + tan f ** 2) * dd f
dd (UnaOp Exp f) = exp f * dd f
dd (UnaOp Log f) = recip f * dd f
-- | |
-- | Lexical units of the s-expression input syntax.
data Token
  = LParen       -- ^ @(@
  | RParen       -- ^ @)@
  | TSym String  -- ^ an operator, function name or variable
  | TNum Double  -- ^ a numeric literal
  deriving (Eq, Show)
-- | Characters that may appear in a symbol token: lowercase ASCII letters and
-- the operator characters @+ - * / ^@.
isSymChar :: Char -> Bool
isSymChar c = lowerAscii || c `elem` "+-*/^"
  where
    lowerAscii = 'a' <= c && c <= 'z'
-- | Split an input string into tokens.
--
-- * Spaces, tabs and newlines separate tokens and are otherwise ignored.
-- * A number is read only when the input starts with a digit, or with @-@
--   immediately followed by a digit (a negative literal such as @-1@).
--   Previously any text 'reads' accepted was a number, so the subtraction
--   form @(- 3 x)@ lexed as the single literal @-3@ and failed to parse.
-- * Any other run of 'isSymChar' characters becomes a 'TSym'.
--
-- A character that cannot start any token is reported with 'error'. The
-- previous version emitted @TSym ""@ and recursed on the same input forever.
tokens :: [Char] -> [Token]
tokens [] = []
tokens cs@(c : rest)
  | c `elem` " \t\n\r" = tokens rest
  | c == '(' = LParen : tokens rest
  | c == ')' = RParen : tokens rest
  | startsNumber cs, [(n, rest')] <- reads cs = TNum n : tokens rest'
  | not (null symChars) = TSym symChars : tokens afterSym
  | otherwise = error ("tokens: unexpected character " ++ show c)
  where
    (symChars, afterSym) = span isSymChar cs
    -- A literal starts with a digit, or a minus sign glued to a digit.
    startsNumber (d : _) | isDigitChar d = True
    startsNumber ('-' : d : _) = isDigitChar d
    startsNumber _ = False
    isDigitChar d = '0' <= d && d <= '9'
-- | A parser inspects a token stream and, on success, returns the parsed value
-- together with the tokens it did not consume; 'Nothing' means failure.
type Parser a =
  [Token] -> Maybe (a, [Token])
-- | Sequence two parsers: run @pa@, then feed its result and the remaining
-- tokens to the parser produced by @fpb@. Fails if either step fails.
andThen :: Parser a -> (a -> Parser b) -> Parser b
andThen pa fpb ts = pa ts >>= uncurry fpb
-- | Ordered choice: use the first parser's result if it succeeds, otherwise
-- run the second parser on the same input.
(<|>) :: Parser a -> Parser a -> Parser a
(first <|> second) ts = maybe (second ts) Just (first ts)
-- | A parser that always succeeds with @a@ and consumes no tokens.
trivial :: a -> Parser a
trivial a = \ts -> Just (a, ts)
-- | Consume exactly one token equal to @target@ and return @out@; fail on any
-- other token or on empty input.
token :: Token -> a -> Parser a
token target out (t : rest)
  | t == target = Just (out, rest)
token _ _ _ = Nothing
-- | Consume the symbol token spelled @s@ and return the given value.
sym :: String -> a -> Parser a
sym = token . TSym
-- | Parse a numeric literal token into a 'Lit' expression.
--
-- The signature now uses the 'Parser' synonym like every other parser in
-- this module; the underlying type is unchanged.
lit :: Parser Expr
lit (TNum n : rest) = Just (Lit n, rest)
lit _ = Nothing
-- | Parse the variable symbol @x@.
var :: Parser Expr
var = token (TSym "x") Var
-- | Consume one opening parenthesis.
lparen :: Parser ()
lparen = token LParen ()

-- | Consume one closing parenthesis.
rparen :: Parser ()
rparen = token RParen ()
-- | Parse a binary-operator symbol into the constructor that builds its node.
-- Alternatives are tried in table order; the first match wins.
binop :: Parser (Expr -> Expr -> Expr)
binop = foldr (<|>) failing [sym s (BinOp o) | (s, o) <- table]
  where
    failing _ = Nothing
    table = [("+", Add), ("-", Sub), ("*", Mul), ("/", Div), ("^", Pow)]
-- | Parse a unary-function name into the constructor that builds its node.
-- Alternatives are tried in table order; the first match wins.
unaop :: Parser (Expr -> Expr)
unaop = foldr (<|>) failing [sym s (UnaOp o) | (s, o) <- table]
  where
    failing _ = Nothing
    table = [("cos", Cos), ("sin", Sin), ("tan", Tan), ("exp", Exp), ("ln", Log)]
-- | The contents of a parenthesised form, without the parentheses:
-- either @op a b@ for a binary operator or @f a@ for a unary function.
nested :: Parser Expr
nested = binary <|> unary
  where
    binary =
      binop `andThen` \op ->
        expr `andThen` \l ->
          expr `andThen` \r ->
            trivial (op l r)
    unary =
      unaop `andThen` \f ->
        expr `andThen` \e ->
          trivial (f e)
-- | An expression: the variable, a numeric literal, or a parenthesised form.
expr :: Parser Expr
expr = var <|> lit <|> parenthesised
  where
    parenthesised =
      lparen `andThen` \_ ->
        nested `andThen` \inner ->
          rparen `andThen` \_ ->
            trivial inner
-- | Parse a complete token stream into a single expression.
--
-- Calls 'error' if no expression can be parsed, or if tokens remain after
-- the first complete expression. Only a bounded prefix of the leftover
-- tokens is shown, so the message is always finite even for an unbounded
-- token list.
parse :: [Token] -> Expr
parse ts =
  case expr ts of
    Just (e, []) -> e
    Just (_, rest) ->
      error ("parse: unexpected trailing tokens " ++ show (take 10 rest))
    Nothing -> error "parse: no expression could be parsed"
-- | Render an expression in the prefix s-expression syntax read by 'parse'.
-- Integral constants are printed without a trailing @.0@.
pretty :: Expr -> String
pretty Var = "x"
pretty (Lit n)
  | n == fromInteger whole = show whole
  | otherwise = show n
  where
    whole = truncate n :: Integer
pretty (BinOp op l r) = "(" ++ binName op ++ " " ++ pretty l ++ " " ++ pretty r ++ ")"
  where
    binName Add = "+"
    binName Sub = "-"
    binName Mul = "*"
    binName Div = "/"
    binName Pow = "^"
pretty (UnaOp op e) = "(" ++ unaName op ++ " " ++ pretty e ++ ")"
  where
    unaName Cos = "cos"
    unaName Sin = "sin"
    unaName Tan = "tan"
    unaName Exp = "exp"
    unaName Log = "ln"
-- | Round-trip through the textual form (print, lex, parse); intended for
-- property tests, where the result should equal the input.
exprIdentity :: Expr -> Expr
exprIdentity e = parse (tokens (pretty e))
-- | Public entry point: differentiate an expression given in prefix
-- s-expression syntax and return the simplified derivative in that syntax.
diff :: String -> String
diff src =
  let parsed = parse (tokens src)
   in pretty (simp (dd parsed))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment