Skip to content

Instantly share code, notes, and snippets.

@Garciat
Created May 14, 2021 03:04
Show Gist options
  • Save Garciat/ee399adc6d11b58b3527fe35eaefc35c to your computer and use it in GitHub Desktop.
Save Garciat/ee399adc6d11b58b3527fe35eaefc35c to your computer and use it in GitHub Desktop.
{-# LANGUAGE RankNTypes #-}
module Symbolic where
-- | Symbolic expression in the single free variable @x@.
data Expr
  = Var                  -- ^ the variable @x@
  | Lit Double           -- ^ a numeric constant
  | BinOp BinOp Expr Expr -- ^ binary operator applied to two operands
  | UnaOp UnaOp Expr     -- ^ unary function applied to one operand
  deriving (Eq, Show)
-- | Binary operators that may appear in an 'Expr'.
data BinOp
  = Add -- ^ addition
  | Sub -- ^ subtraction
  | Mul -- ^ multiplication
  | Div -- ^ division
  | Pow -- ^ exponentiation
  deriving (Eq, Show)
-- | True for commutative operators, i.e. those where @a <> b == b <> a@.
isCommut :: BinOp -> Bool
isCommut op = op `elem` [Add, Mul]
-- | True for associative operators, i.e. those where
-- @(a <> b) <> c == a <> (b <> c)@, so a chain may be written @a <> b <> c@.
isAssoc :: BinOp -> Bool
isAssoc op =
  case op of
    Add -> True
    Mul -> True
    _ -> False
-- | Evaluate a binary operator on two concrete 'Double' operands.
runBinOp :: BinOp -> Double -> Double -> Double
runBinOp op =
  case op of
    Add -> (+)
    Sub -> (-)
    Mul -> (*)
    Div -> (/)
    Pow -> (**)
-- | Unary functions that may appear in an 'Expr'.
data UnaOp
  = Cos -- ^ cosine
  | Sin -- ^ sine
  | Tan -- ^ tangent
  | Exp -- ^ exponential
  | Log -- ^ natural logarithm (written @ln@ in the text syntax)
  deriving (Eq, Show)
--
-- | Arithmetic on expressions builds syntax trees instead of computing numbers.
--
-- NOTE: 'abs' and 'signum' are not defined here (and 'Num' has no defaults for
-- them), so calling either on an 'Expr' fails at runtime.
instance Num Expr where
  a + b = BinOp Add a b
  a - b = BinOp Sub a b
  a * b = BinOp Mul a b
  -- Negation is encoded as multiplication by the constant -1.
  negate e = BinOp Mul (Lit (-1)) e
  fromInteger n = Lit (fromInteger n)
-- | Division builds a 'Div' node; rational literals become 'Lit' constants.
instance Fractional Expr where
  a / b = BinOp Div a b
  recip e = BinOp Div (Lit 1) e
  fromRational = Lit . fromRational
-- | Transcendental functions build 'UnaOp' / 'Pow' nodes.
--
-- Only functions representable in 'Expr' are defined. 'log' maps to 'Log', so
-- @numExpr log@ and @log e@ work. 'sqrt' and 'logBase' come from the class
-- defaults (@x ** 0.5@ and @log y / log x@). 'asin', 'acos', 'atan' and the
-- hyperbolic functions have no 'Expr' representation and fail at runtime if
-- called.
instance Floating Expr where
  pi = Lit pi
  exp = UnaOp Exp
  log = UnaOp Log
  sin = UnaOp Sin
  cos = UnaOp Cos
  tan = UnaOp Tan
  (**) = BinOp Pow
--
-- | Interpret an 'Expr' as a polymorphic numeric function of @x@.
--
-- Constants are converted with 'realToFrac', so fractional literals such as
-- @0.5@ keep their value (truncating them to an integer would change the
-- function).
exprNum :: Expr -> (forall a. Floating a => a -> a)
exprNum ex x = go ex
  where
    -- Sub-expression names (l, r, e) are distinct from the argument @x@ to
    -- avoid shadowing it.
    go Var = x
    go (Lit n) = realToFrac n
    go (BinOp Add l r) = go l + go r
    go (BinOp Sub l r) = go l - go r
    go (BinOp Mul l r) = go l * go r
    go (BinOp Div l r) = go l / go r
    go (BinOp Pow l r) = go l ** go r
    go (UnaOp Cos e) = cos (go e)
    go (UnaOp Sin e) = sin (go e)
    go (UnaOp Tan e) = tan (go e)
    go (UnaOp Exp e) = exp (go e)
    go (UnaOp Log e) = log (go e)
-- | Reify a polymorphic numeric function as an 'Expr' by applying it to the
-- symbolic variable 'Var'.
numExpr :: (forall a. Floating a => a -> a) -> Expr
numExpr build = build Var
--
-- | Algebraically simplify an expression.
--
-- Works bottom-up: operands are simplified first, then local rewrites are
-- applied to the rebuilt node. The rewrites are identity/absorbing elements,
-- constant folding, folding nested literals under associative operators, and
-- moving literals to the left of commutative operators.
simp :: Expr -> Expr
simp (BinOp op l r) = rewrite op (simp l) (simp r)
  where
    -- Identity and absorbing elements.
    rewrite Add (Lit 0) e = e
    rewrite Add e (Lit 0) = e
    rewrite Sub e (Lit 0) = e
    rewrite Mul (Lit 1) e = e
    rewrite Mul e (Lit 1) = e
    rewrite Mul (Lit 0) _ = Lit 0
    rewrite Mul _ (Lit 0) = Lit 0
    rewrite Div e (Lit 1) = e
    rewrite Pow _ (Lit 0) = Lit 1
    rewrite Pow e (Lit 1) = e
    -- Constant folding.
    rewrite o (Lit a) (Lit b) = Lit (runBinOp o a b)
    -- k <> (k' <> e) ==> (k <> k') <> e for associative operators.
    -- Recurses on the strictly smaller e.
    rewrite o (Lit a) (BinOp o' (Lit b) e)
      | o == o' && isAssoc o = rewrite o (Lit (runBinOp o a b)) e
    -- e <> k ==> k <> e for commutative operators. Literal/literal pairs were
    -- folded above, so e is never a literal here. The swapped call therefore
    -- cannot swap back, which guarantees termination.
    rewrite o e (Lit b)
      | isCommut o = rewrite o (Lit b) e
    rewrite o x y = BinOp o x y
simp (UnaOp op e) = UnaOp op (simp e)
-- 'Var' and 'Lit' are already in simplest form.
simp e = e
-- | Symbolic derivative with respect to @x@.
--
-- The result is not simplified; pass it through 'simp' for a compact form.
dd :: Expr -> Expr
dd Var = Lit 1
dd (Lit _) = Lit 0
-- d(f + g) = df + dg
dd (BinOp Add f g) = dd f + dd g
-- d(f - g) = df - dg
dd (BinOp Sub f g) = dd f - dd g
-- d(f * g) = f * dg + g * df
dd (BinOp Mul f g) = f * dd g + g * dd f
-- d(f / g) = (g * df - f * dg) / g^2
dd (BinOp Div f g) = (g * dd f - f * dd g) / (g * g)
-- d(f ^ k) = k * f ^ (k - 1) * df, for a constant exponent k
dd (BinOp Pow f (Lit k)) = Lit k * f ** Lit (k - 1) * dd f
-- d(f ^ g) = f ^ g * (dg * ln f + g * df / f), general case (valid for f > 0)
dd (BinOp Pow f g) = f ** g * (dd g * UnaOp Log f + g * dd f / f)
-- Chain rule: d(h(f)) = h'(f) * df
dd (UnaOp Cos f) = negate (sin f) * dd f
dd (UnaOp Sin f) = cos f * dd f
dd (UnaOp Tan f) = (1 + tan f ** 2) * dd f
dd (UnaOp Exp f) = exp f * dd f
dd (UnaOp Log f) = recip f * dd f
--
-- | Lexical tokens of the prefix (s-expression) input syntax.
data Token
  = LParen      -- ^ an opening parenthesis
  | RParen      -- ^ a closing parenthesis
  | TSym String -- ^ an operator, function name or variable
  | TNum Double -- ^ a numeric literal
  deriving (Eq, Show)
-- | Characters allowed inside a symbol token: lowercase ASCII letters and the
-- operator characters @+ - * / ^@.
isSymChar :: Char -> Bool
isSymChar c = ('a' <= c && c <= 'z') || c `elem` "+-*/^"
-- | Split input text into tokens.
--
-- * Whitespace (space, tab, newline, carriage return) separates tokens.
-- * Parentheses are single-character tokens.
-- * Text starting with a digit, or with @-@ immediately followed by a digit,
--   is read as a number with 'reads'.
-- * A run of 'isSymChar' characters is a symbol.
-- * Any other character becomes a one-character symbol. The parser then
--   rejects it, so tokenizing always makes progress instead of looping.
tokens :: [Char] -> [Token]
tokens [] = []
tokens cs@(c : rest)
  | c `elem` " \t\n\r" = tokens rest
  | c == '(' = LParen : tokens rest
  | c == ')' = RParen : tokens rest
  | startsNumber, [(n, rest')] <- reads cs = TNum n : tokens rest'
  | isSymChar c = let (s, rest') = span isSymChar cs in TSym s : tokens rest'
  | otherwise = TSym [c] : tokens rest
  where
    isDigitChar d = d >= '0' && d <= '9'
    -- A '-' starts a number only when a digit follows immediately. Otherwise
    -- it is the subtraction operator. 'reads' alone would parse "- 1" as -1
    -- and break inputs like "(- 1 x)".
    startsNumber =
      case cs of
        d : _ | isDigitChar d -> True
        '-' : d : _ -> isDigitChar d
        _ -> False
-- | A parser consumes a prefix of the token stream. On success it returns the
-- parsed value together with the remaining tokens; on failure, 'Nothing'.
type Parser r = [Token] -> Maybe (r, [Token])
-- | Sequence two parsers: run the first, then feed its result into the
-- function that builds the second. Fails if either step fails.
andThen :: Parser a -> (a -> Parser b) -> Parser b
andThen pa fpb ts = maybe Nothing (uncurry fpb) (pa ts)
-- | Left-biased choice: use the first parser's result if it succeeds,
-- otherwise run the second parser on the same input.
(<|>) :: Parser a -> Parser a -> Parser a
(first <|> second) ts = maybe (second ts) Just (first ts)
-- | A parser that consumes no input and always succeeds with the given value.
trivial :: a -> Parser a
trivial value = \rest -> Just (value, rest)
-- | Consume one token equal to @target@ and produce @out@; fail on any other
-- token or on empty input.
token :: Token -> a -> Parser a
token target out (t : rest)
  | t == target = Just (out, rest)
token _ _ _ = Nothing
-- | Consume the symbol token with the given text and produce the given value.
sym :: String -> a -> Parser a
sym = token . TSym
-- | Parse a numeric literal token into a 'Lit'.
--
-- Declared with the 'Parser' alias like the other parsers. The expanded type
-- is identical, so callers are unaffected.
lit :: Parser Expr
lit (TNum n : rest) = Just (Lit n, rest)
lit _ = Nothing
-- | Parse the variable symbol @x@ into 'Var'.
var :: Parser Expr
var = token (TSym "x") Var
-- | Consume an opening / closing parenthesis token.
lparen, rparen :: Parser ()
(lparen, rparen) = (token LParen (), token RParen ())
-- | Parse a binary operator symbol into the matching 'BinOp' constructor.
-- Alternatives are tried in table order; fails if none matches.
binop :: Parser (Expr -> Expr -> Expr)
binop = foldr (<|>) (const Nothing) [sym s (BinOp op) | (s, op) <- table]
  where
    table = [("+", Add), ("-", Sub), ("*", Mul), ("/", Div), ("^", Pow)]
-- | Parse a unary function name into the matching 'UnaOp' constructor.
-- @ln@ is the natural logarithm. Fails if no name matches.
unaop :: Parser (Expr -> Expr)
unaop = foldr (<|>) (const Nothing) (map toParser table)
  where
    toParser (name, op) = sym name (UnaOp op)
    table = [("cos", Cos), ("sin", Sin), ("tan", Tan), ("exp", Exp), ("ln", Log)]
-- | Parse the inside of a parenthesised form; the parentheses themselves are
-- handled by 'expr'.
--
-- The form is either a binary operator with two operands, e.g. @+ x 1@, or a
-- unary function with one operand, e.g. @sin x@.
nested :: Parser Expr
nested = binary <|> unary
  where
    binary =
      binop `andThen` \op ->
        expr `andThen` \l ->
          expr `andThen` \r ->
            trivial (op l r)
    unary =
      unaop `andThen` \f ->
        expr `andThen` \e ->
          trivial (f e)
-- | Parse an expression: the variable, a numeric literal, or a parenthesised
-- form. The alternatives are tried in that order.
expr :: Parser Expr
expr = var <|> lit <|> parenthesised
  where
    parenthesised =
      lparen `andThen` \_ ->
        nested `andThen` \e ->
          rparen `andThen` \_ ->
            trivial e
-- | Parse a complete token stream into an 'Expr'.
--
-- Calls 'error' with a diagnostic message in two cases:
--
-- * the tokens do not begin with a valid expression;
-- * tokens remain after the expression.
parse :: [Token] -> Expr
parse ts =
  case expr ts of
    Just (e, []) -> e
    Just (_, rest) -> error ("parse: unexpected trailing tokens " ++ show rest)
    Nothing -> error ("parse: no valid expression in " ++ show ts)
-- | Render an expression in the prefix (s-expression) syntax read by 'tokens'
-- and 'parse', e.g. @(+ x (* 2 (sin x)))@.
--
-- Whole-number constants are printed without a decimal point. Other values
-- (including NaN and infinities) use 'show'.
pretty :: Expr -> String
pretty Var = "x"
pretty (Lit n)
  | isWhole = show (truncate n :: Integer) -- remove decimal point from integers
  | otherwise = show n
  where
    -- Infinities must be excluded explicitly. Truncating one yields a huge
    -- Integer that converts back to the same infinity, so the equality test
    -- alone would print it as a long string of digits.
    isWhole =
      not (isInfinite n || isNaN n)
        && n == fromIntegral (truncate n :: Integer)
pretty (BinOp op l r) = "(" ++ opName ++ " " ++ pretty l ++ " " ++ pretty r ++ ")"
  where
    opName =
      case op of
        Add -> "+"
        Sub -> "-"
        Mul -> "*"
        Div -> "/"
        Pow -> "^"
pretty (UnaOp op e) = "(" ++ fnName ++ " " ++ pretty e ++ ")"
  where
    fnName =
      case op of
        Cos -> "cos"
        Sin -> "sin"
        Tan -> "tan"
        Exp -> "exp"
        Log -> "ln"
-- | Round-trip an expression through 'pretty', 'tokens' and 'parse'.
-- Used as a testing property: it is intended to return its input unchanged.
exprIdentity :: Expr -> Expr
exprIdentity e = parse (tokens (pretty e))
-- | Public entry point.
--
-- Takes an expression in prefix syntax, differentiates it with respect to
-- @x@, simplifies the result, and renders it back in prefix syntax.
diff :: String -> String
diff input = pretty (simp (dd (parse (tokens input))))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment