Last active
January 23, 2026 20:36
-
-
Save mchav/6747a2ab3db17745faa55ef0aba3f983 to your computer and use it in GitHub Desktop.
Now dynamic
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| #!/usr/bin/env cabal | |
| {- cabal: | |
| build-depends: base >= 4, dataframe, hegg, text | |
| -} | |
| {-# LANGUAGE OverloadedStrings #-} | |
| {-# LANGUAGE TypeApplications #-} | |
| {-# LANGUAGE DeriveFunctor #-} | |
| {-# LANGUAGE DeriveFoldable #-} | |
| {-# LANGUAGE DeriveTraversable #-} | |
| {-# LANGUAGE FlexibleContexts #-} | |
| {-# LANGUAGE FlexibleInstances #-} | |
| {-# LANGUAGE GADTs #-} | |
| {-# LANGUAGE ScopedTypeVariables #-} | |
| {-# LANGUAGE LambdaCase #-} | |
| {-# LANGUAGE MultiParamTypeClasses #-} | |
| {-# LANGUAGE RankNTypes #-} | |
| {-# LANGUAGE ExplicitNamespaces #-} | |
| {-# LANGUAGE AllowAmbiguousTypes #-} | |
| module Main where | |
| import Control.Applicative | |
| import Data.Equality.Saturation | |
| import Data.Equality.Analysis | |
| import Data.Equality.Graph hiding (add) | |
| import Data.Equality.Graph.Lens ((^.), _class, _data) | |
| import Data.Equality.Matching | |
| import qualified Data.Text as T | |
| import Data.Type.Equality | |
| import Type.Reflection | |
-- | An existentially wrapped constant of any runtime-reflected type.
--
-- Type identity is recovered with 'Type.Reflection', so no closed universe of
-- supported types is needed.  Values are used both as literal payloads
-- ('SLit') and as the constant-folding analysis domain, which means they are
-- compared as keys of the e-graph and checked for equality when e-classes
-- merge.  Those uses require 'Eq' to be reflexive and 'Ord' to be a total
-- order, so values that are not equal to themselves (e.g. a 'Double' NaN)
-- are treated as equal to one another and sorted before every other value of
-- the same type.
data SomeVal where
  SomeVal :: (Typeable t, Show t, Eq t, Ord t) => t -> SomeVal

instance Show SomeVal where
  show (SomeVal v) = show v

-- | Equal iff both values have the same runtime type and are equal at that
-- type; two non-reflexive values (NaN-like) of the same type are also equal,
-- which keeps '==' reflexive.
instance Eq SomeVal where
  SomeVal (x :: a) == SomeVal (y :: b) =
    case testEquality (typeRep @a) (typeRep @b) of
      Nothing -> False
      Just Refl -> x == y || (x /= x && y /= y)

-- | Values of different runtime types are ordered by their type
-- representation.  Within one type, non-reflexive values come first and
-- compare 'EQ' to each other; everything else uses the type's own 'compare'.
instance Ord SomeVal where
  compare (SomeVal (x :: a)) (SomeVal (y :: b)) =
    case testEquality (typeRep @a) (typeRep @b) of
      Nothing -> compare (SomeTypeRep (typeRep @a)) (SomeTypeRep (typeRep @b))
      Just Refl -> case (x == x, y == y) of
        (True, True) -> compare x y
        (False, False) -> EQ
        (False, True) -> LT
        (True, False) -> GT
-- | One layer of the untyped expression tree.  Children live in the type
-- parameter @a@; concrete trees close the recursion with 'Fix'.  Every
-- operator node carries runtime type tags for its operands and result.
data SymExpr a
  = -- | A constant value.
    SLit SomeVal
  | -- | A named column together with its element-type tag.
    SCol SomeTypeRep String
  | -- | @SUnaryOp name inputType outputType operand@.
    SUnaryOp String SomeTypeRep SomeTypeRep a
  | -- | @SBinaryOp name leftType rightType outputType left right@.
    SBinaryOp String SomeTypeRep SomeTypeRep SomeTypeRep a a
  deriving (Show, Eq, Ord, Functor, Foldable, Traversable)
-- | Runtime type tag of the value a node produces.  For literals it is taken
-- from the wrapped value; for columns and operators it is the stored tag.
exprTy :: SymExpr a -> SomeTypeRep
exprTy = \case
  SLit (SomeVal payload) -> SomeTypeRep (typeOf payload)
  SCol colTy _ -> colTy
  SUnaryOp _ _ resultTy _ -> resultTy
  SBinaryOp _ _ _ resultTy _ _ -> resultTy
-- | An untyped 'Fix SymExpr' tree tagged with a phantom result type @t@.
-- The tag is enforced only by the typed smart constructors in this module.
newtype TExpr t = TExpr
  { unTExpr :: Fix SymExpr
  }
-- | Embed a constant of any reflectable, showable, ordered type as a literal
-- node.
lit :: forall t. (Typeable t, Show t, Eq t, Ord t) => t -> TExpr t
lit = TExpr . Fix . SLit . SomeVal

-- | 'lit' specialised to 'Double'.
litDouble :: Double -> TExpr Double
litDouble = lit @Double

-- | 'lit' specialised to 'Int'.
litInt :: Int -> TExpr Int
litInt = lit @Int

-- | 'lit' specialised to 'Bool'.
litBool :: Bool -> TExpr Bool
litBool = lit @Bool
-- | Reference a named column, recording @t@ as its runtime type tag.
col :: forall t. Typeable t => String -> TExpr t
col = TExpr . Fix . SCol (SomeTypeRep (typeRep @t))

-- | 'col' specialised to 'Double'.
colDouble :: String -> TExpr Double
colDouble = col @Double

-- | 'col' specialised to 'Int'.
colInt :: String -> TExpr Int
colInt = col @Int
-- | Build a unary operator node named @name@ with type tags @a -> b@.
--
-- NOTE(review): the Haskell function argument is discarded; only @name@ is
-- stored, and constant folding looks the operator up by name in
-- 'evalUnary'.  An operator not handled there is never folded.
unaryOp :: forall a b. (Typeable a, Typeable b)
        => String -> (a -> b) -> TExpr a -> TExpr b
unaryOp name _semantics (TExpr operand) =
  TExpr (Fix (SUnaryOp name inTy outTy operand))
  where
    inTy = SomeTypeRep (typeRep @a)
    outTy = SomeTypeRep (typeRep @b)

-- | Build a binary operator node named @name@ with type tags
-- @a -> b -> c@.
--
-- NOTE(review): as with 'unaryOp', the function argument is ignored and
-- folding is driven by @name@ in 'evalBinary'.
binaryOp :: forall a b c. (Typeable a, Typeable b, Typeable c)
         => String -> (a -> b -> c) -> TExpr a -> TExpr b -> TExpr c
binaryOp name _semantics (TExpr lhs) (TExpr rhs) =
  TExpr (Fix (SBinaryOp name lhsTy rhsTy outTy lhs rhs))
  where
    lhsTy = SomeTypeRep (typeRep @a)
    rhsTy = SomeTypeRep (typeRep @b)
    outTy = SomeTypeRep (typeRep @c)
-- | Typed addition node (@"add"@).
add :: forall a. (Typeable a, Num a) => TExpr a -> TExpr a -> TExpr a
add = binaryOp @a @a @a "add" (+)

-- | Typed multiplication node (@"mult"@).
mult :: forall a. (Typeable a, Num a) => TExpr a -> TExpr a -> TExpr a
mult = binaryOp @a @a @a "mult" (*)

-- | Typed division node (@"divide"@); only for 'Fractional' types.
divide :: forall a. (Typeable a, Fractional a) => TExpr a -> TExpr a -> TExpr a
divide = binaryOp @a @a @a "divide" (/)

-- | Typed cosine node (@"cos"@).
cosine :: forall a. (Typeable a, Floating a) => TExpr a -> TExpr a
cosine = unaryOp @a @a "cos" cos

-- | Typed sine node (@"sin"@).
sine :: forall a. (Typeable a, Floating a) => TExpr a -> TExpr a
sine = unaryOp @a @a "sin" sin
-- | Widen an 'Int' expression to 'Double' (@"intToDouble"@).
intToDouble :: TExpr Int -> TExpr Double
intToDouble = unaryOp @Int @Double "intToDouble" fromIntegral

-- | Truncate a 'Double' expression to 'Int' using 'floor' (@"doubleToInt"@).
doubleToInt :: TExpr Double -> TExpr Int
doubleToInt = unaryOp @Double @Int "doubleToInt" floor

-- | Widen an 'Int' expression to 'Integer' (@"intToInteger"@).
intToInteger :: TExpr Int -> TExpr Integer
intToInteger = unaryOp @Int @Integer "intToInteger" fromIntegral
-- | Equality comparison node (@"eq"@) producing a 'Bool' expression.
eqOp :: forall a. (Typeable a, Eq a) => TExpr a -> TExpr a -> TExpr Bool
eqOp = binaryOp @a @a @Bool "eq" (==)

-- | Greater-than comparison node (@"gt"@).
gtOp :: forall a. (Typeable a, Ord a) => TExpr a -> TExpr a -> TExpr Bool
gtOp = binaryOp @a @a @Bool "gt" (>)

-- | Less-than comparison node (@"lt"@).
ltOp :: forall a. (Typeable a, Ord a) => TExpr a -> TExpr a -> TExpr Bool
ltOp = binaryOp @a @a @Bool "lt" (<)
-- | Constant-folding analysis: each e-class carries @Just v@ when it is known
-- to evaluate to the constant @v@, and 'Nothing' otherwise.
instance Analysis (Maybe SomeVal) SymExpr where
  -- Fold a node once all of its children have known constant values.
  makeA node = case node of
    SLit v -> Just v
    SCol _ _ -> Nothing
    SUnaryOp name _ _ child -> child >>= evalUnary name
    SBinaryOp name _ _ _ left right ->
      left >>= \l -> right >>= \r -> evalBinary name l r

  -- Merging two classes: two known constants must agree; otherwise keep
  -- whichever side is known.
  joinA left right = case (left, right) of
    (Just l, Just r)
      | l == r -> Just l
      | otherwise -> error "inconsistent analysis"
    _ -> left <|> right

  -- When a class has a known constant, add the literal node for it and merge
  -- it into the class so extraction can pick the literal.
  modifyA c egr = maybe egr addLiteral (egr ^. _class c . _data)
    where
      addLiteral v =
        let (c', egr') = represent (Fix (SLit v)) egr
        in snd (merge c c' egr')
-- | Apply a numeric operation at type @a@ when both wrapped operands have
-- exactly runtime type @a@; 'Nothing' otherwise.
applyNumOp :: forall a . (Typeable a, Num a) => (a -> a -> a) -> SomeVal -> SomeVal -> Maybe SomeVal
applyNumOp op (SomeVal (lhs :: l)) (SomeVal (rhs :: r)) =
  case (eqTypeRep (typeRep @a) (typeRep @l), eqTypeRep (typeRep @l) (typeRep @r)) of
    (Just HRefl, Just HRefl) -> Just (SomeVal (op lhs rhs))
    _ -> Nothing

-- | Like 'applyNumOp', for operations that need a 'Fractional' type.
applyFracOp :: forall a . (Typeable a, Fractional a) => (a -> a -> a) -> SomeVal -> SomeVal -> Maybe SomeVal
applyFracOp op (SomeVal (lhs :: l)) (SomeVal (rhs :: r)) =
  case (eqTypeRep (typeRep @a) (typeRep @l), eqTypeRep (typeRep @l) (typeRep @r)) of
    (Just HRefl, Just HRefl) -> Just (SomeVal (op lhs rhs))
    _ -> Nothing

-- | Apply a 'Bool'-valued comparison at type @a@ when both operands have
-- runtime type @a@; the result is wrapped as a 'Bool' 'SomeVal'.
applyOrdOp :: forall a . (Typeable a, Ord a) => (a -> a -> Bool) -> SomeVal -> SomeVal -> Maybe SomeVal
applyOrdOp op (SomeVal (lhs :: l)) (SomeVal (rhs :: r)) =
  case (eqTypeRep (typeRep @a) (typeRep @l), eqTypeRep (typeRep @l) (typeRep @r)) of
    (Just HRefl, Just HRefl) -> Just (SomeVal (op lhs rhs))
    _ -> Nothing
-- | Constant-fold a unary operator, identified by the name stored in its
-- node, applied to a known value.
--
-- * @"cos"@, @"sin"@: 'Double' or 'Float' operand (matching 'cosine' /
--   'sine', which accept any 'Floating' type, and 'evalBinary', which also
--   folds 'Float').
-- * @"intToDouble"@, @"intToInteger"@: 'Int' operand.
-- * @"doubleToInt"@: 'Double' operand, truncated with 'floor'.
--
-- Returns 'Nothing' for unknown operator names or operands of an
-- unsupported runtime type.
evalUnary :: String -> SomeVal -> Maybe SomeVal
evalUnary name (SomeVal (v :: t)) = case name of
  "cos" -> floating cos
  "sin" -> floating sin
  "intToDouble" -> fromInt (\i -> SomeVal (fromIntegral i :: Double))
  "intToInteger" -> fromInt (\i -> SomeVal (fromIntegral i :: Integer))
  "doubleToInt" -> do
    Refl <- testEquality (typeRep @t) (typeRep @Double)
    pure (SomeVal (floor v :: Int))
  _ -> Nothing
  where
    -- Apply a 'Floating'-polymorphic function when the operand is a
    -- 'Double' or a 'Float'.
    floating :: (forall f. Floating f => f -> f) -> Maybe SomeVal
    floating f
      | Just Refl <- testEquality (typeRep @t) (typeRep @Double) = Just (SomeVal (f v))
      | Just Refl <- testEquality (typeRep @t) (typeRep @Float) = Just (SomeVal (f v))
      | otherwise = Nothing

    -- Run a conversion that expects an 'Int' operand.
    fromInt :: (Int -> SomeVal) -> Maybe SomeVal
    fromInt k = do
      Refl <- testEquality (typeRep @t) (typeRep @Int)
      pure (k v)
-- | Constant-fold a binary operator, identified by the name stored in its
-- node, applied to two known values of the same runtime type.
--
-- * @"add"@, @"mult"@: 'Int', 'Double', 'Integer' or 'Float'.
-- * @"divide"@: 'Double' or 'Float'.
-- * @"eq"@, @"gt"@, @"lt"@: any type, producing a 'Bool'.
--
-- Returns 'Nothing' for unknown names, mismatched operand types, or
-- unsupported types.
evalBinary :: String -> SomeVal -> SomeVal -> Maybe SomeVal
evalBinary op l r = case op of
  "add" -> overNum (+)
  "mult" -> overNum (*)
  "divide" -> applyFracOp @Double (/) l r <|> applyFracOp @Float (/) l r
  "eq" -> sameType (==)
  "gt" -> sameType (>)
  "lt" -> sameType (<)
  _ -> Nothing
  where
    -- Try each supported numeric type in turn; the first match wins.
    overNum :: (forall n. Num n => n -> n -> n) -> Maybe SomeVal
    overNum f =
      applyNumOp @Int f l r
        <|> applyNumOp @Double f l r
        <|> applyNumOp @Integer f l r
        <|> applyNumOp @Float f l r

    -- Apply a comparison whenever both operands share a runtime type.
    sameType :: (forall v. Ord v => v -> v -> Bool) -> Maybe SomeVal
    sameType cmp = case (l, r) of
      (SomeVal (x :: a), SomeVal (y :: b)) -> do
        Refl <- testEquality (typeRep @a) (typeRep @b)
        pure (SomeVal (cmp x y))
-- | Extraction cost: a node's own weight (2 for binary operators, 1 for
-- everything else) plus the summed cost of its children.
cost :: CostFunction SymExpr Int
cost node = ownCost node + sum node
  where
    ownCost = \case
      SBinaryOp {} -> 2
      _ -> 1
-- | Pattern for a binary operator node with the given name and type tags.
pBinOp :: String -> SomeTypeRep -> SomeTypeRep -> SomeTypeRep
       -> Pattern SymExpr -> Pattern SymExpr -> Pattern SymExpr
pBinOp name t1 t2 t3 = \lhs rhs -> pat (SBinaryOp name t1 t2 t3 lhs rhs)

-- | Pattern matching exactly the 'Double' literal given.
pLitDouble :: Double -> Pattern SymExpr
pLitDouble = pat . SLit . SomeVal

-- | Pattern matching exactly the 'Int' literal given.
pLitInt :: Int -> Pattern SymExpr
pLitInt = pat . SLit . SomeVal
-- | Generic algebraic rewrite rules for element type @a@.  Select the type
-- with a type application, e.g. @mkNumericRewrites \@Double@.
--
-- String literals in patterns (@"a"@, @"b"@, @"x"@, …) are pattern
-- /variables/.  The multiplicative unit must therefore be a real literal
-- pattern built from @1 :: a@.  A variable named @"one"@ would match any
-- sub-expression, so @x * y@ would rewrite to @x@.  On the right-hand side of
-- @x / x@ it would not be bound by the left-hand side at all.  That is why
-- @Num a@, @Show a@ and @Ord a@ are required.
--
-- NOTE(review): @x / x => 1@ is unsound when @x@ evaluates to 0; it mirrors
-- the same rule in 'rewrites'.  The divide rules only fire for types that
-- can contain @"divide"@ nodes, i.e. those built with 'divide' (a
-- 'Fractional' type).
mkNumericRewrites :: forall a anl . (Typeable a, Num a, Show a, Ord a) => [Rewrite anl SymExpr]
mkNumericRewrites =
  let t = SomeTypeRep (typeRep @a)
      mul = pBinOp "mult" t t t
      dv = pBinOp "divide" t t t
      one = pat (SLit (SomeVal (1 :: a)))
  in [ -- (a * b) / c => a * (b / c)
       dv (mul "a" "b") "c" := mul "a" (dv "b" "c")
     , -- x / x => 1
       dv "x" "x" := one
     , -- x * 1 => x
       mul "x" one := "x"
     , -- 1 * x => x
       mul one "x" := "x"
     ]
-- | Rewrite rules used by 'main': the 'Double' division/multiplication
-- identities plus the 'Int' multiplicative-unit rules.
--
-- NOTE(review): @x / x => 1.0@ does not hold when @x@ is 0 (the runtime
-- result is NaN); it is kept as a deliberate simplification.
rewrites :: [Rewrite anl SymExpr]
rewrites = doubleRules ++ intRules
  where
    tD = SomeTypeRep (typeRep @Double)
    tI = SomeTypeRep (typeRep @Int)
    mulD = pBinOp "mult" tD tD tD
    divD = pBinOp "divide" tD tD tD
    mulI = pBinOp "mult" tI tI tI
    oneD = pLitDouble 1.0
    oneI = pLitInt 1
    doubleRules =
      [ divD (mulD "a" "b") "c" := mulD "a" (divD "b" "c")
      , divD "x" "x" := oneD
      , mulD "x" oneD := "x"
      , mulD oneD "x" := "x"
      ]
    intRules =
      [ mulI "x" oneI := "x"
      , mulI oneI "x" := "x"
      ]
-- | @(x * 2) / 2@ over a 'Double' column.
e1 :: TExpr Double
e1 = (col "x" `mult` lit 2) `divide` lit 2

-- | @(count * 5) + 10@ over an 'Int' column.
e2 :: TExpr Int
e2 = (col "count" `mult` lit 5) `add` lit 10

-- | @intToDouble(count) + 1.0@: mixes 'Int' and 'Double'.
e3 :: TExpr Double
e3 = intToDouble (col "count") `add` lit 1.0

-- | @intToDouble(x) > 10.0@: a 'Bool'-valued comparison.
e4 :: TExpr Bool
e4 = intToDouble (col "x") `gtOp` lit 10.0

-- | @(bignum * 100) + 50@ over an 'Integer' column.
e5 :: TExpr Integer
e5 = (col "bignum" `mult` lit @Integer 100) `add` lit @Integer 50

-- | @(a * 2) + 3@ at 'Int'.
intExample :: TExpr Int
intExample = (col "a" `mult` lit 2) `add` lit 3

-- | @(a * 2.0) + 3.0@ at 'Double'.
doubleExample :: TExpr Double
doubleExample = (col "a" `mult` lit 2.0) `add` lit 3.0
-- | Print several example expressions and, for some, the result of running
-- equality saturation with 'rewrites' and constant folding.
main :: IO ()
main = do
  putStrLn "=== Original expression e1: (x * 2.0) / 2.0 (Double) ==="
  print (unTExpr e1)
  putStrLn "\n=== After equality saturation ==="
  print (optimize e1)

  putStrLn "\n=== Integer expression e2: (count * 5) + 10 ==="
  print (unTExpr e2)
  print (optimize e2)

  putStrLn "\n=== Heterogeneous expression e3: intToDouble(count) + 1.0 ==="
  print (unTExpr e3)

  putStrLn "\n=== Comparison expression e4: intToDouble(x) > 10.0 ==="
  print (unTExpr e4)

  putStrLn "\n=== Constant folding test: (2 * 3) + 5 (Int) ==="
  print (optimize (add (mult (lit @Int 2) (lit 3)) (lit 5)))
  putStrLn "Expected: 11"

  putStrLn "\n=== Constant folding test: (2.0 * 3.0) + 5.0 (Double) ==="
  print (optimize (add (mult (lit @Double 2.0) (lit 3.0)) (lit 5.0)))
  putStrLn "Expected: 11.0"
  where
    -- Saturate with the constant-folding analysis and keep only the
    -- extracted best expression.
    optimize :: TExpr t -> Fix SymExpr
    optimize e = fst (equalitySaturation @(Maybe SomeVal) (unTExpr e) rewrites cost)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment