Skip to content

Instantly share code, notes, and snippets.

@madmann91
Last active April 29, 2019 16:52
Show Gist options
  • Save madmann91/ab54c9f22459a87da01a2ffb252c58f8 to your computer and use it in GitHub Desktop.
Save madmann91/ab54c9f22459a87da01a2ffb252c58f8 to your computer and use it in GitHub Desktop.
Simplification of a simple language of boolean and comparison expressions
-- Compile and run with:
-- ghc Simplify.hs
-- ./Simplify
import Data.Bits
import Data.List
import Data.Maybe
import Data.Function
import System.Random
import Control.Monad
import Debug.Trace
import qualified Data.Map as M
import qualified Data.Set as S
import Test.QuickCheck
import Test.QuickCheck.Gen
import Test.QuickCheck.Random
-- Simple boolean language.

-- | Names of variables: boolean identifiers ('Id') and the integer
-- variables compared in 'Less'.
type Var = String

-- | Integer literals that variables are compared against in 'Less'.
type Lit = Int

-- | Expressions of the boolean + comparison language.
--
-- The constructor order below is significant: the derived 'Ord' instance
-- is used by 'foldAnd' / 'foldOr' to put operands in a canonical order.
data Expr
  = And Expr Expr -- ^ conjunction
  | Or Expr Expr  -- ^ disjunction
  | Not Expr      -- ^ negation
  | Less Var Lit  -- ^ @Less "x" 2@ means @x < 2@
  | Val Bool      -- ^ boolean constant
  | Id Var        -- ^ boolean variable
  deriving (Eq, Ord, Show)
-- | Flattens the nested 'And' nodes at the top of an expression into the
-- list of their operands, left to right. Descent stops at the first node
-- that is not an 'And'; such a node is returned as a single element.
ands :: Expr -> [Expr]
ands expr = collect expr []
  where
    -- Prepends the conjuncts of the given node to the already-collected rest.
    collect (And l r) rest = collect l (collect r rest)
    collect other rest = other : rest
-- | Flattens the nested 'Or' nodes at the top of an expression into the
-- list of their operands, left to right. Descent stops at the first node
-- that is not an 'Or'; such a node is returned as a single element.
ors :: Expr -> [Expr]
ors expr = collect expr []
  where
    -- Prepends the disjuncts of the given node to the already-collected rest.
    collect (Or l r) rest = collect l (collect r rest)
    collect other rest = other : rest
-- | Counts the constructors (And, Or, Not, Less, Val, Id) in an expression.
size :: Expr -> Int
size expr = 1 + sum (map size (subterms expr))
  where
    -- Direct children of a node; leaves have none.
    subterms (And a b) = [a, b]
    subterms (Or a b) = [a, b]
    subterms (Not e) = [e]
    subterms _ = []
-- | Depth of an expression (a.k.a. the critical path length): the number of
-- nodes on the longest root-to-leaf path, so a leaf has depth 1.
depth :: Expr -> Int
depth expr = 1 + maximum (0 : map depth (subterms expr))
  where
    -- Direct children of a node; leaves have none.
    subterms (And a b) = [a, b]
    subterms (Or a b) = [a, b]
    subterms (Not e) = [e]
    subterms _ = []
-- | True when the first expression is strictly better than the second:
-- smaller 'size' wins, and on equal size, smaller 'depth' wins.
better :: Expr -> Expr -> Bool
better e e' = score e < score e'
  where
    -- Tuples compare lexicographically: size first, then depth.
    score x = (size x, depth x)
-- | Picks the better of two expressions according to 'better'. When neither
-- is strictly better, the second one is returned.
best :: Expr -> Expr -> Expr
best e e'
  | better e e' = e
  | otherwise = e'
-- | Selects the best expression of a list using 'best', scanning left to
-- right; on ties the later candidate wins. The list must be non-empty
-- ('foldl1'' fails on an empty list).
bestOf :: [Expr] -> Expr
bestOf candidates = foldl1' best candidates
-- | True iff the expression's top constructor is 'And'.
isAnd :: Expr -> Bool
isAnd e = case e of
  And {} -> True
  _ -> False

-- | True iff the expression's top constructor is 'Or'.
isOr :: Expr -> Bool
isOr e = case e of
  Or {} -> True
  _ -> False

-- | True iff the expression's top constructor is 'Not'.
isNot :: Expr -> Bool
isNot e = case e of
  Not {} -> True
  _ -> False
-- | Smart constructor for negation: flips boolean constants, cancels a
-- double negation, and otherwise wraps the expression in 'Not'.
foldNot :: Expr -> Expr
foldNot expr = case expr of
  Val b -> Val (not b)
  Not inner -> inner
  _ -> Not expr
-- | Smart constructor for conjunction. Simplifications are applied in this order:
--
--   * a @Val False@ operand makes the result @Val False@;
--   * a @Val True@ operand is dropped;
--   * @a & !a@ (in either order) becomes @Val False@;
--   * @a & a@ becomes @a@.
--
-- Otherwise an 'And' node is built with the smaller operand (by the derived
-- 'Ord') first, so both argument orders give the same term.
foldAnd :: Expr -> Expr -> Expr
foldAnd lhs rhs
  | lhs == Val False || rhs == Val False = Val False
  | lhs == Val True = rhs
  | rhs == Val True = lhs
  | rhs == Not lhs || lhs == Not rhs = Val False
  | lhs == rhs = lhs
  | otherwise = And (min lhs rhs) (max lhs rhs)
-- | Smart constructor for disjunction. Simplifications are applied in this order:
--
--   * a @Val False@ operand is dropped;
--   * a @Val True@ operand makes the result @Val True@;
--   * @a | !a@ (in either order) becomes @Val True@;
--   * @a | a@ becomes @a@.
--
-- Otherwise an 'Or' node is built with the smaller operand (by the derived
-- 'Ord') first, so both argument orders give the same term.
foldOr :: Expr -> Expr -> Expr
foldOr lhs rhs
  | lhs == Val False = rhs
  | rhs == Val False = lhs
  | lhs == Val True || rhs == Val True = Val True
  | rhs == Not lhs || lhs == Not rhs = Val True
  | lhs == rhs = lhs
  | otherwise = Or (min lhs rhs) (max lhs rhs)
-- | Produces an expression by Or'ing every expression in the given list
-- with 'foldOr', left to right.
--
-- The fold starts from @Val False@, the identity of disjunction, so the
-- empty list yields @Val False@. For a non-empty list the result is the same
-- as folding the elements alone, because @foldOr (Val False) e == e@.
foldOrs :: [Expr] -> Expr
foldOrs = foldl' foldOr (Val False)
-- | Produces an expression by And'ing every expression in the given list
-- with 'foldAnd', left to right.
--
-- The fold starts from @Val True@, the identity of conjunction, so the
-- empty list yields @Val True@. For a non-empty list the result is the same
-- as folding the elements alone, because @foldAnd (Val True) e == e@.
foldAnds :: [Expr] -> Expr
foldAnds = foldl' foldAnd (Val True)
-- | Predicate that holds when @x@ lies in the half-open interval [lo, hi),
-- i.e. @!(x < lo) & (x < hi)@. Raw constructors are used, so no folding or
-- operand reordering takes place.
inRange :: Var -> Lit -> Lit -> Expr
inRange x lo hi = And lowerBound upperBound
  where
    lowerBound = Not (Less x lo)
    upperBound = Less x hi
-- | Returns True when the first expression implies the second.
--
-- The check is conservative. True means the implication holds. False only
-- means it could not be established; for example, @a -> (b | c)@ is proved
-- only when @a@ implies @b@ or @a@ implies @c@.
implies :: Expr -> Expr -> Bool
implies a b = implies' (False, a) (False, b)
  where
    -- Each operand carries a flag. When the flag is True the operand stands
    -- for its negation, so 'Not' can be pushed through And/Or via De Morgan
    -- without rebuilding the expression.
    implies' (f, Val a) (g, Val b) = (not c) || (c && d)
      where
        c = a `xor` f
        d = b `xor` g
    -- An antecedent that is constantly false implies anything.
    implies' (f, Val a) _ | a == f = True
    -- Anything implies a consequent that is constantly true.
    implies' _ (g, Val b) | b /= g = True
    -- (x < i) -> (x < j) iff i <= j; (x >= i) -> (x >= j) iff j <= i.
    implies' (f, Less v i) (g, Less w j) = v == w && f == g && (if f then j <= i else i <= j)
    implies' (f, Not a) b = implies' (not f, a) b
    implies' a (f, Not b) = implies' a (not f, b)
    -- Decompose the antecedent (negated forms use De Morgan).
    implies' (False, And a b) c = (implies' (False, a) c) || (implies' (False, b) c)
    implies' (True, And a b) c = (implies' (True, a) c) && (implies' (True, b) c)
    implies' (False, Or a b) c = (implies' (False, a) c) && (implies' (False, b) c)
    implies' (True, Or a b) c = (implies' (True, a) c) || (implies' (True, b) c)
    -- Decompose the consequent (negated forms use De Morgan).
    implies' a (False, And b c) = (implies' a (False, b)) && (implies' a (False, c))
    implies' a (True, And b c) = (implies' a (True, b)) || (implies' a (True, c))
    implies' a (False, Or b c) = (implies' a (False, b)) || (implies' a (False, c))
    implies' a (True, Or b c) = (implies' a (True, b)) && (implies' a (True, c))
    -- Fallback: syntactic equality of the (flag, expression) pairs.
    implies' a b = a == b
-- Simplifies an expression by removing redundant tests.
-- It works by constructing a knowledge set k, represented as an expression
-- that is assumed to be true.
-- When looking at the expression (And a b), we enter a with the knowledge
-- that b must be true, and vice-versa. This is logical since for the
-- expression to be true, both operands must be true.
-- When looking at the expression (Or a b), we enter a with the knowledge
-- that b must be false, and vice-versa. In this case, we encode the fact
-- that if b was true when entering a, then there is no need to even
-- evaluate a.
-- Note that the second operand is entered with the knowledge built from
-- the already-simplified first operand (a'), not the original one.
--
-- Arguments:
--   * a depth budget d: while d > 0, 'rewrite' also tries the commute,
--     distribute, factorize and deMorgan transformations and keeps the
--     candidate preferred by 'bestOf' (smaller size, then smaller depth);
--     descending into a sub-expression decrements the budget;
--   * the knowledge k, assumed true ('main' passes Val True, i.e. nothing known);
--   * the expression to simplify.
--
-- NOTE(review): every level produces several candidates, each of which is
-- reduced again, so running time likely grows quickly with the depth
-- budget. Keep it small.
simplify :: Int -> Expr -> Expr -> Expr
simplify = rewrite
  where
    -- With budget left: reduce, then try every transformation on the
    -- reduced expression and keep the best candidate (the reduced expression
    -- itself is always among them).
    rewrite d k e | d > 0 = bestOf $ apply d k [commute, distribute, factorize, deMorgan] $ reduce d k e
    rewrite d k e = reduce d k e
    -- reduce simplifies the children under the extended knowledge and
    -- rebuilds the node with the folding smart constructors.
    reduce d k (And a b) = foldAnd a' b'
      where
        d' = d - 1
        a' = rewrite d' (foldAnd k b ) a
        b' = rewrite d' (foldAnd k a') b
    reduce d k (Or a b) = foldOr a' b'
      where
        d' = d - 1
        a' = rewrite d' (foldAnd k $ foldNot b ) a
        b' = rewrite d' (foldAnd k $ foldNot a') b
    reduce d k (Not e) = foldNot $ rewrite (d - 1) k e
    reduce d _ (Val b) = Val b
    -- Leaves (Less / Id): replace by a constant when k decides them.
    reduce d k e =
      if implies k e then Val True
      else if implies k (Not e) then Val False
      else e
    -- Runs each transformation in turn; each one prepends its candidates
    -- (for e) to the list, which starts as [e].
    apply d k ts e = foldl' (\l t -> t d k l e) [e] ts
    -- commute: reassociate/reorder operands of nested And/Or chains.
    commute d' k l (And (And a b) (And c d)) = (reduce d' k $ foldAnd (foldAnd a c) (foldAnd b d)) : (reduce d' k $ foldAnd (foldAnd a d) (foldAnd b c)) : l
    commute d' k l (Or (Or a b) (Or c d)) = (reduce d' k $ foldOr (foldOr a c) (foldOr b d)) : (reduce d' k $ foldOr (foldOr a d) (foldOr b c)) : l
    commute d k l e = commuteR d k (commuteL d k l e) e
    commuteL d k l (And (And a b) c) = (reduce d k $ foldAnd a $ foldAnd b c) : (reduce d k $ foldAnd b $ foldAnd a c) : l
    commuteL d k l (Or (Or a b) c) = (reduce d k $ foldOr a $ foldOr b c) : (reduce d k $ foldOr b $ foldOr a c) : l
    commuteL d k l _ = l
    -- commuteR handles the nested chain on the right by swapping operands
    -- and reusing commuteL.
    commuteR d k l (And a (And b c)) = commuteL d k l (And (And b c) a)
    commuteR d k l (Or a (Or b c)) = commuteL d k l (Or (Or b c) a)
    commuteR d k l _ = l
    -- distribute: (a | b) & c -> (a & c) | (b & c) and
    -- (a & b) | c -> (a | c) & (b | c), with the nested node on either side.
    distribute d k l e = distributeR d k (distributeL d k l e) e
    distributeL d k l (And (Or a b) c) = (reduce d k $ foldOr (foldAnd a c) (foldAnd b c)) : l
    distributeL d k l (Or (And a b) c) = (reduce d k $ foldAnd (foldOr a c) (foldOr b c)) : l
    distributeL d k l _ = l
    distributeR d k l (And a (Or b c)) = distributeL d k l (And (Or b c) a)
    distributeR d k l (Or a (And b c)) = distributeL d k l (Or (And b c) a)
    distributeR d k l _ = l
    -- factorize: undo distribution when both operands share a term:
    -- (a | b) & (a | d) -> a | (b & d) and (a & b) | (a & d) -> a & (b | d).
    -- All four pairings of the operands are tried.
    factorize d' k l (And (Or a b) (Or c d)) = foldl (factorizeAnd d' k) l [(a, b, c, d), (b, a, c, d), (a, b, d, c), (b, a, d, c)]
    factorize d' k l (Or (And a b) (And c d)) = foldl (factorizeOr d' k) l [(a, b, c, d), (b, a, c, d), (a, b, d, c), (b, a, d, c)]
    factorize d' k l _ = l
    factorizeAnd d' k l (a, b, c, d) | a == c = (reduce d' k $ foldOr a $ foldAnd b d) : l
    factorizeAnd d' k l _ = l
    factorizeOr d' k l (a, b, c, d) | a == c = (reduce d' k $ foldAnd a $ foldOr b d) : l
    factorizeOr d' k l _ = l
    -- deMorgan: a & b -> !(!a | !b) and a | b -> !(!a & !b) (the inner dual
    -- form is reduced); a negated And/Or is pushed inward and rewritten at
    -- the same depth budget.
    deMorgan d k l (And a b) = (foldNot $ reduce d k $ foldOr (foldNot a) (foldNot b)) : l
    deMorgan d k l (Or a b) = (foldNot $ reduce d k $ foldAnd (foldNot a) (foldNot b)) : l
    deMorgan d k l (Not (And a b)) = (rewrite d k $ foldOr (foldNot a) (foldNot b)) : l
    deMorgan d k l (Not (Or a b)) = (rewrite d k $ foldAnd (foldNot a) (foldNot b)) : l
    deMorgan d k l _ = l
-- | Random generator for expressions.
--
-- The QuickCheck size is split between the two operands of And/Or and
-- decremented for Not. Sizes <= 0 produce a leaf: a comparison of "x" or "y"
-- against 1..4, a boolean constant, or the identifier "a" or "b". Nodes are
-- built with the folding smart constructors, so generated terms are already
-- partially simplified.
instance Arbitrary Expr where
  arbitrary = sized arbitrary'
    where
      -- Non-positive sizes (possible via 'resize') fall back to leaves
      -- instead of failing a pattern match.
      arbitrary' n
        | n <= 0 = oneof [arbitraryLess, arbitraryVal, arbitraryId]
        | otherwise = oneof [arbitraryAnd n, arbitraryOr n, arbitraryNot n]
      arbitraryAnd n = foldAnd <$> arbitrary' m <*> arbitrary' (n - m) where m = n `div` 2
      arbitraryOr n = foldOr <$> arbitrary' m <*> arbitrary' (n - m) where m = n `div` 2
      arbitraryNot n = foldNot <$> arbitrary' (n - 1)
      arbitraryLess = Less <$> elements ["x", "y"] <*> choose (1, 4)
      arbitraryVal = Val <$> arbitrary
      arbitraryId = Id <$> elements ["a", "b"]
  -- Shrinks a node to its direct children; leaves do not shrink.
  shrink (And a b) = [a, b]
  shrink (Or a b) = [a, b]
  shrink (Not e) = [e]
  shrink _ = []
-- | Renders an expression in compact infix notation.
--
-- And/Or chains are flattened ('ands' / 'ors') and parenthesised with
-- " & " / " | " separators. Negated comparisons print as ">=". Constants
-- print as 1 / 0, and other negations use a '!' prefix.
prettyPrint :: Expr -> String
prettyPrint expr = case expr of
    And _ _ -> joinWith " & " (ands expr)
    Or _ _ -> joinWith " | " (ors expr)
    Not (Val b) -> if b then "0" else "1"
    Not (Less x i) -> comparison x ">=" i
    Not inner -> '!' : prettyPrint inner
    Less x i -> comparison x "<" i
    Val b -> if b then "1" else "0"
    Id x -> x
  where
    joinWith sep parts = "(" ++ intercalate sep (map prettyPrint parts) ++ ")"
    comparison x op i = "(" ++ x ++ " " ++ op ++ " " ++ show i ++ ")"
-- | Renders an expression as a Z3 / SMT-LIB s-expression, e.g.
-- @(and (< x 2) a)@. Constants print as true / false, and a negated
-- constant is folded directly into the opposite literal.
prettyPrintZ3 :: Expr -> String
prettyPrintZ3 expr = case expr of
    And a b -> sexpr "and" [prettyPrintZ3 a, prettyPrintZ3 b]
    Or a b -> sexpr "or" [prettyPrintZ3 a, prettyPrintZ3 b]
    Not (Val b) -> if b then "false" else "true"
    Not inner -> sexpr "not" [prettyPrintZ3 inner]
    Less x i -> sexpr "<" [x, show i]
    Val b -> if b then "true" else "false"
    Id x -> x
  where
    -- Parenthesised operator followed by its space-separated arguments.
    sexpr op args = "(" ++ unwords (op : args) ++ ")"
-- | Converts an expression to conjunctive normal form.
--
-- Negations are pushed inward with De Morgan's laws, double negations are
-- removed, and Or is distributed over And. Results are built with the folding
-- smart constructors, so trivially true clauses and constants get simplified
-- away.
--
-- NOTE: naive distribution can make the result exponentially larger than
-- the input.
cnf :: Expr -> Expr
cnf (And a b) = foldAnd (cnf a) (cnf b)
cnf (Or a b) = distribute (cnf a) (cnf b)
  where
    -- Both arguments are already in CNF; distribute the Or over their
    -- conjunctions so the result is again a conjunction of clauses.
    distribute (And x y) (And z w) = foldAnds [distribute x z, distribute x w, distribute y z, distribute y w]
    distribute (And x y) z = foldAnd (distribute x z) (distribute y z)
    distribute x (And y z) = foldAnd (distribute x y) (distribute x z)
    distribute x y = foldOr x y
-- !!e == e. Without this case a double negation would reach the catch-all
-- below and be returned unconverted.
cnf (Not (Not e)) = cnf e
cnf (Not (And a b)) = cnf $ foldOr (foldNot a) (foldNot b)
cnf (Not (Or a b)) = foldAnd (cnf $ foldNot a) (cnf $ foldNot b)
cnf e = e
-- | Partially evaluates an expression.
--
-- @v@ assigns values to boolean identifiers ('Id') and @w@ assigns values to
-- the integer variables used in 'Less'. Leaves whose variable is not
-- assigned are kept as they are. The tree is rebuilt with the folding smart
-- constructors, so a fully assigned expression reduces to a 'Val'.
eval :: M.Map String Bool -> M.Map String Int -> Expr -> Expr
eval v w (And a b) = foldAnd (eval v w a) (eval v w b)
eval v w (Or a b) = foldOr (eval v w a) (eval v w b)
eval v w (Not e) = foldNot $ eval v w e
eval _ w e@(Less x i) = maybe e (\j -> Val (j < i)) (M.lookup x w)
eval v _ e@(Id x) = maybe e Val (M.lookup x v)
-- Constants are already fully evaluated.
eval _ _ e@(Val _) = e
-- | Compares two simplifiers on @n@ random expressions.
--
-- Each round generates an expression with the 'Arbitrary' generator at size
-- 30 from the current 'QCGen', applies both simplifiers and compares the
-- results with 'better'. Returns a pair:
--
--   * the number of rounds where @s1@ gave the strictly better result;
--   * the inputs on which @s2@ gave the strictly better result, most recent first.
versus :: QCGen -> Int -> (Expr -> Expr) -> (Expr -> Expr) -> (Int, [Expr])
versus q n s1 s2 = snd $ foldl' match (q, (0, [])) [1 .. n]
  where
    MkGen gen = arbitrary :: Gen Expr
    -- The win counter is forced every round so the fold runs in constant
    -- space for the counter instead of building a chain of (+1) thunks.
    match (g, (wins, losses)) _ = wins' `seq` (snd (next g), (wins', losses'))
      where
        e = gen g 30
        e1 = s1 e
        e2 = s2 e
        wins' = if better e1 e2 then wins + 1 else wins
        losses' = if better e2 e1 then e : losses else losses
-- Simplifies the expression: ((x >= 1) & (x < 2)) | ((x >= 2) & (x < 3)) | ((x >= 4) & (x < 5))
-- This represents the following union of ranges:
--
--                |----| ((x >= 4) & (x < 5))
--      |----|           ((x >= 2) & (x < 3))
-- |----|                ((x >= 1) & (x < 2))
-- |----|----|----|----|
-- 1    2    3    4    5
--
-- The result should only contain 4 comparisons (with 1, 3, 4, and 5).

-- | Depth budget passed to 'simplify'.
maxDepth :: Int
maxDepth = 10

-- | Simplifies the union of ranges above, starting with no knowledge
-- (Val True), and prints the result with 'show'.
main :: IO ()
main = print simplified
  where
    ranges = [inRange "x" 1 2, inRange "x" 2 3, inRange "x" 4 5]
    simplified = simplify maxDepth (Val True) (foldOrs ranges)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment