Skip to content

Instantly share code, notes, and snippets.

@mchav
Last active January 23, 2026 20:36
Show Gist options
  • Select an option

  • Save mchav/6747a2ab3db17745faa55ef0aba3f983 to your computer and use it in GitHub Desktop.

Select an option

Save mchav/6747a2ab3db17745faa55ef0aba3f983 to your computer and use it in GitHub Desktop.
Now dynamic
#!/usr/bin/env cabal
{- cabal:
build-depends: base >= 4, dataframe, hegg, text
-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE DeriveFoldable #-}
{-# LANGUAGE DeriveTraversable #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ExplicitNamespaces #-}
{-# LANGUAGE AllowAmbiguousTypes #-}
module Main where
import Control.Applicative
import Data.Equality.Saturation
import Data.Equality.Analysis
import Data.Equality.Graph hiding (add)
import Data.Equality.Graph.Lens ((^.), _class, _data)
import Data.Equality.Matching
import qualified Data.Text as T
import Data.Type.Equality
import Type.Reflection
-- Now using reflection instead of a closed type universe!
-- No more Ty or STy types needed.
-- | A literal of any type, boxed together with the evidence needed to
-- identify its type at runtime ('Typeable'), print it ('Show') and compare
-- it ('Eq', 'Ord').
data SomeVal where
  SomeVal :: forall t. (Typeable t, Show t, Eq t, Ord t) => t -> SomeVal
-- | Renders the wrapped value exactly as its own 'Show' instance does; the
-- wrapper itself leaves no trace in the output.
instance Show SomeVal where
  show = \case
    SomeVal inner -> show inner
-- | Two boxed values are equal only when they have the same runtime type
-- and the underlying values compare equal at that type.
instance Eq SomeVal where
  SomeVal (lhs :: a) == SomeVal (rhs :: b)
    | Just Refl <- testEquality (typeRep @a) (typeRep @b) = lhs == rhs
    | otherwise = False
-- | Values of the same runtime type are ordered by that type's own 'Ord'
-- instance; values of different types are ordered by their type
-- representations.
instance Ord SomeVal where
  compare (SomeVal (lhs :: a)) (SomeVal (rhs :: b))
    -- Same type: defer to the payload's ordering.
    | Just Refl <- testEquality (typeRep @a) (typeRep @b) = compare lhs rhs
    -- Different types: fall back to the ordering on type representations.
    | otherwise =
        compare (SomeTypeRep (typeRep @a)) (SomeTypeRep (typeRep @b))
-- | One layer of the untyped expression tree; @a@ is the type of the child
-- positions. Type information is carried as runtime 'SomeTypeRep' tags.
data SymExpr a
  = -- | A literal value.
    SLit SomeVal
  | -- | A column reference: the column's type tag and its name.
    SCol SomeTypeRep String
  | -- | A unary operator: name, argument type, result type, argument.
    SUnaryOp String SomeTypeRep SomeTypeRep a
  | -- | A binary operator: name, left type, right type, result type,
    -- left operand, right operand.
    SBinaryOp String SomeTypeRep SomeTypeRep SomeTypeRep a a
  deriving (Eq, Ord, Show, Functor, Foldable, Traversable)
-- | The result type of an expression node: the runtime type of a literal,
-- the declared type of a column, or the recorded output type of an operator.
exprTy :: SymExpr a -> SomeTypeRep
exprTy = \case
  SLit (SomeVal (_ :: t)) -> SomeTypeRep (typeRep @t)
  SCol colTy _ -> colTy
  SUnaryOp _ _ resultTy _ -> resultTy
  SBinaryOp _ _ _ resultTy _ _ -> resultTy
-- | A typed view of an expression tree. The parameter @t@ is phantom: it
-- does not appear in the wrapped 'Fix' 'SymExpr' and exists only so the
-- builder functions can be given type-checked signatures.
newtype TExpr t = TExpr { unTExpr :: Fix SymExpr }
-- | Lift a Haskell value into a literal expression of the same type.
lit :: forall t. (Typeable t, Show t, Eq t, Ord t) => t -> TExpr t
lit = TExpr . Fix . SLit . SomeVal
-- | 'lit' specialised to 'Double'.
litDouble :: Double -> TExpr Double
litDouble = lit @Double

-- | 'lit' specialised to 'Int'.
litInt :: Int -> TExpr Int
litInt = lit @Int

-- | 'lit' specialised to 'Bool'.
litBool :: Bool -> TExpr Bool
litBool = lit @Bool
-- | Reference a named column whose values have type @t@. The type is
-- recorded in the node as a runtime tag.
col :: forall t. Typeable t => String -> TExpr t
col name = TExpr (Fix (SCol colTy name))
  where
    colTy = SomeTypeRep (typeRep @t)
-- | 'col' specialised to 'Double' columns.
colDouble :: String -> TExpr Double
colDouble = col @Double

-- | 'col' specialised to 'Int' columns.
colInt :: String -> TExpr Int
colInt = col @Int
-- | Build a named unary operator node. The function argument is used only
-- to fix the argument and result types; it is never stored or called here.
unaryOp :: forall a b. (Typeable a, Typeable b)
        => String -> (a -> b) -> TExpr a -> TExpr b
unaryOp name _f (TExpr child) = TExpr (Fix node)
  where
    argTy = SomeTypeRep (typeRep @a)
    resTy = SomeTypeRep (typeRep @b)
    node = SUnaryOp name argTy resTy child
-- | Build a named binary operator node. The function argument is used only
-- to fix the operand and result types; it is never stored or called here.
binaryOp :: forall a b c. (Typeable a, Typeable b, Typeable c)
         => String -> (a -> b -> c) -> TExpr a -> TExpr b -> TExpr c
binaryOp name _f (TExpr left) (TExpr right) = TExpr (Fix node)
  where
    leftTy = SomeTypeRep (typeRep @a)
    rightTy = SomeTypeRep (typeRep @b)
    resTy = SomeTypeRep (typeRep @c)
    node = SBinaryOp name leftTy rightTy resTy left right
-- | Addition node, tagged @"add"@.
add :: (Typeable a, Num a) => TExpr a -> TExpr a -> TExpr a
add lhs rhs = binaryOp "add" (+) lhs rhs

-- | Multiplication node, tagged @"mult"@.
mult :: (Typeable a, Num a) => TExpr a -> TExpr a -> TExpr a
mult lhs rhs = binaryOp "mult" (*) lhs rhs

-- | Division node, tagged @"divide"@.
divide :: (Typeable a, Fractional a) => TExpr a -> TExpr a -> TExpr a
divide lhs rhs = binaryOp "divide" (/) lhs rhs
-- | Cosine node, tagged @"cos"@.
cosine :: (Typeable a, Floating a) => TExpr a -> TExpr a
cosine arg = unaryOp "cos" cos arg

-- | Sine node, tagged @"sin"@.
sine :: (Typeable a, Floating a) => TExpr a -> TExpr a
sine arg = unaryOp "sin" sin arg
-- | Conversion node from 'Int' to 'Double', tagged @"intToDouble"@.
intToDouble :: TExpr Int -> TExpr Double
intToDouble arg = unaryOp "intToDouble" fromIntegral arg

-- | Conversion node from 'Double' to 'Int' (via 'floor'), tagged
-- @"doubleToInt"@.
doubleToInt :: TExpr Double -> TExpr Int
doubleToInt arg = unaryOp "doubleToInt" floor arg

-- | Conversion node from 'Int' to 'Integer', tagged @"intToInteger"@.
intToInteger :: TExpr Int -> TExpr Integer
intToInteger arg = unaryOp "intToInteger" fromIntegral arg
-- | Equality comparison node, tagged @"eq"@.
eqOp :: (Typeable a, Eq a) => TExpr a -> TExpr a -> TExpr Bool
eqOp lhs rhs = binaryOp "eq" (==) lhs rhs

-- | Greater-than comparison node, tagged @"gt"@.
gtOp :: (Typeable a, Ord a) => TExpr a -> TExpr a -> TExpr Bool
gtOp lhs rhs = binaryOp "gt" (>) lhs rhs

-- | Less-than comparison node, tagged @"lt"@.
ltOp :: (Typeable a, Ord a) => TExpr a -> TExpr a -> TExpr Bool
ltOp lhs rhs = binaryOp "lt" (<) lhs rhs
-- | Constant-folding analysis: each e-class carries @Just v@ when its
-- expressions are known to evaluate to the constant @v@, and 'Nothing'
-- otherwise.
instance Analysis (Maybe SomeVal) SymExpr where
  -- A literal is its own constant, a column is never constant, and an
  -- operator is constant when all operands are and the evaluator knows it.
  makeA (SLit v) = Just v
  makeA (SCol _ _) = Nothing
  makeA (SUnaryOp name _ _ arg) = arg >>= evalUnary name
  makeA (SBinaryOp name _ _ _ lhs rhs) =
    lhs >>= \l -> rhs >>= evalBinary name l

  -- Two known constants must agree; otherwise keep whichever side is known.
  -- NOTE(review): a folded Double NaN compares unequal to itself, so joining
  -- it would land in the error branch -- confirm whether that is reachable.
  joinA (Just l) (Just r)
    | l == r = Just l
    | otherwise = error "inconsistent analysis"
  joinA lhs rhs = lhs <|> rhs

  -- When a class has a known constant, add that literal to the graph and
  -- merge it with the class.
  modifyA cls graph = maybe graph foldInto (graph ^. _class cls . _data)
    where
      foldInto v =
        let (litCls, graph') = represent (Fix (SLit v)) graph
         in snd (merge cls litCls graph')
-- | Apply a binary numeric operation at type @a@ to two boxed values.
-- Yields 'Nothing' unless both values have runtime type @a@.
applyNumOp :: forall a . (Typeable a, Num a) => (a -> a -> a) -> SomeVal -> SomeVal -> Maybe SomeVal
applyNumOp op (SomeVal (x :: b)) (SomeVal (y :: c))
  | Just Refl <- testEquality (typeRep @a) (typeRep @b)
  , Just Refl <- testEquality (typeRep @b) (typeRep @c) =
      Just (SomeVal (op x y))
  | otherwise = Nothing
-- | Apply a binary fractional operation at type @a@ to two boxed values.
-- Yields 'Nothing' unless both values have runtime type @a@.
applyFracOp :: forall a . (Typeable a, Fractional a) => (a -> a -> a) -> SomeVal -> SomeVal -> Maybe SomeVal
applyFracOp op (SomeVal (x :: b)) (SomeVal (y :: c))
  | Just Refl <- testEquality (typeRep @a) (typeRep @b)
  , Just Refl <- testEquality (typeRep @b) (typeRep @c) =
      Just (SomeVal (op x y))
  | otherwise = Nothing
-- | Apply a binary comparison at type @a@ to two boxed values, boxing the
-- 'Bool' result. Yields 'Nothing' unless both values have runtime type @a@.
applyOrdOp :: forall a . (Typeable a, Ord a) => (a -> a -> Bool) -> SomeVal -> SomeVal -> Maybe SomeVal
applyOrdOp op (SomeVal (x :: b)) (SomeVal (y :: c))
  | Just Refl <- testEquality (typeRep @a) (typeRep @b)
  , Just Refl <- testEquality (typeRep @b) (typeRep @c) =
      Just (SomeVal (op x y))
  | otherwise = Nothing
-- | Evaluate a named unary operator on a constant. Returns 'Nothing' when
-- the operator name is unknown or the argument does not have the runtime
-- type that operator expects ('Double' for @cos@, @sin@ and @doubleToInt@;
-- 'Int' for @intToDouble@ and @intToInteger@).
evalUnary :: String -> SomeVal -> Maybe SomeVal
evalUnary opName (SomeVal (arg :: t)) =
  case opName of
    "cos"
      | Just Refl <- testEquality (typeRep @t) (typeRep @Double) ->
          Just (SomeVal (cos arg))
    "sin"
      | Just Refl <- testEquality (typeRep @t) (typeRep @Double) ->
          Just (SomeVal (sin arg))
    "intToDouble"
      | Just Refl <- testEquality (typeRep @t) (typeRep @Int) ->
          Just (SomeVal (fromIntegral arg :: Double))
    "intToInteger"
      | Just Refl <- testEquality (typeRep @t) (typeRep @Int) ->
          Just (SomeVal (fromIntegral arg :: Integer))
    "doubleToInt"
      | Just Refl <- testEquality (typeRep @t) (typeRep @Double) ->
          Just (SomeVal (floor arg :: Int))
    -- Unknown operator, or a known one applied at the wrong type.
    _ -> Nothing
-- | Evaluate a named binary operator on two constants.
--
-- Returns 'Nothing' when the operands have different runtime types, when
-- the operator name is unknown, or when an arithmetic operator is applied
-- at a type outside the supported set:
--
-- * @add@, @mult@: 'Int', 'Double', 'Integer', 'Float'
-- * @divide@: 'Double', 'Float'
-- * @eq@, @gt@, @lt@: any type (the result is a boxed 'Bool')
evalBinary :: String -> SomeVal -> SomeVal -> Maybe SomeVal
evalBinary opName (SomeVal (x :: a)) (SomeVal (y :: b)) =
  case testEquality (typeRep @a) (typeRep @b) of
    Nothing -> Nothing
    Just Refl -> case opName of
      "add" -> arith (+) x y
      "mult" -> arith (*) x y
      "divide" -> frac (/) x y
      "eq" -> Just (SomeVal (x == y))
      "gt" -> Just (SomeVal (x > y))
      "lt" -> Just (SomeVal (x < y))
      _ -> Nothing
  where
    -- Run a 'Num' operation when the operand type is one of the known
    -- numeric types; the operation is rank-2 so it can be used at each.
    arith
      :: forall n. (Typeable n, Show n, Ord n)
      => (forall k. Num k => k -> k -> k) -> n -> n -> Maybe SomeVal
    arith op p q
      | Just Refl <- testEquality (typeRep @n) (typeRep @Int) = Just (SomeVal (op p q))
      | Just Refl <- testEquality (typeRep @n) (typeRep @Double) = Just (SomeVal (op p q))
      | Just Refl <- testEquality (typeRep @n) (typeRep @Integer) = Just (SomeVal (op p q))
      | Just Refl <- testEquality (typeRep @n) (typeRep @Float) = Just (SomeVal (op p q))
      | otherwise = Nothing

    -- Run a 'Fractional' operation when the operand type is 'Double' or
    -- 'Float'.
    frac
      :: forall n. (Typeable n, Show n, Ord n)
      => (forall k. Fractional k => k -> k -> k) -> n -> n -> Maybe SomeVal
    frac op p q
      | Just Refl <- testEquality (typeRep @n) (typeRep @Double) = Just (SomeVal (op p q))
      | Just Refl <- testEquality (typeRep @n) (typeRep @Float) = Just (SomeVal (op p q))
      | otherwise = Nothing
-- | Extraction cost of a node given the costs of its children: leaves cost
-- 1, a unary operator adds 1 to its argument, a binary operator adds 2 to
-- the sum of its operands.
cost :: CostFunction SymExpr Int
cost (SLit _) = 1
cost (SCol _ _) = 1
cost (SUnaryOp _ _ _ argCost) = 1 + argCost
cost (SBinaryOp _ _ _ _ leftCost rightCost) = leftCost + rightCost + 2
-- | Pattern matching a binary operator node with the given name, operand
-- type tags and result type tag, whose operands match the two sub-patterns.
pBinOp :: String -> SomeTypeRep -> SomeTypeRep -> SomeTypeRep
       -> Pattern SymExpr -> Pattern SymExpr -> Pattern SymExpr
pBinOp name leftTy rightTy resTy lhs rhs =
  pat $ SBinaryOp name leftTy rightTy resTy lhs rhs
-- | Pattern matching exactly the given 'Double' literal.
pLitDouble :: Double -> Pattern SymExpr
pLitDouble = pat . SLit . SomeVal

-- | Pattern matching exactly the given 'Int' literal.
pLitInt :: Int -> Pattern SymExpr
pLitInt = pat . SLit . SomeVal
-- | Rewrite rules for arithmetic at a single numeric type @a@, selected with
-- a type application (e.g. @mkNumericRewrites \@Double@).
--
-- The re-association rule @(a * b) / c => a * (b / c)@ is produced for every
-- @a@. The identity rules need a literal 1 of type @a@, so they are produced
-- only when @a@ is 'Int', 'Integer', 'Double' or 'Float' -- the same numeric
-- types the constant folder evaluates. For any other type they are omitted.
--
-- The unit must be a literal pattern. With @OverloadedStrings@ a string such
-- as @"one"@ in pattern position is a pattern /variable/: @x * "one"@ would
-- match every product, and a right-hand side of @"one"@ would refer to a
-- variable the left-hand side never binds.
--
-- NOTE(review): @x / x => 1@ does not hold when @x@ is zero; it is kept
-- because the hand-written Double rule set contains the same rule.
mkNumericRewrites :: forall a anl . Typeable a => [Rewrite anl SymExpr]
mkNumericRewrites = reassociate : maybe [] identityRules unit
  where
    t :: SomeTypeRep
    t = SomeTypeRep (typeRep @a)

    -- A binary operator whose operands and result all have type @a@.
    binOp :: String -> Pattern SymExpr -> Pattern SymExpr -> Pattern SymExpr
    binOp name = pBinOp name t t t

    -- (a * b) / c => a * (b / c)
    reassociate :: Rewrite anl SymExpr
    reassociate =
      binOp "divide" (binOp "mult" "a" "b") "c"
        := binOp "mult" "a" (binOp "divide" "b" "c")

    -- Rules that mention the multiplicative unit.
    identityRules :: Pattern SymExpr -> [Rewrite anl SymExpr]
    identityRules one =
      [ -- x / x => 1
        binOp "divide" "x" "x" := one
      , -- x * 1 => x
        binOp "mult" "x" one := "x"
      , -- 1 * x => x
        binOp "mult" one "x" := "x"
      ]

    -- The literal 1 at type @a@, when @a@ is a known numeric type.
    unit :: Maybe (Pattern SymExpr)
    unit
      | t == SomeTypeRep (typeRep @Int) = Just (pat (SLit (SomeVal (1 :: Int))))
      | t == SomeTypeRep (typeRep @Integer) = Just (pat (SLit (SomeVal (1 :: Integer))))
      | t == SomeTypeRep (typeRep @Double) = Just (pat (SLit (SomeVal (1 :: Double))))
      | t == SomeTypeRep (typeRep @Float) = Just (pat (SLit (SomeVal (1 :: Float))))
      | otherwise = Nothing
-- | The rule set used by 'main': re-association of multiply/divide and
-- @x / x => 1.0@ for 'Double', plus multiplication-by-one elimination for
-- 'Double' and 'Int'.
rewrites :: [Rewrite anl SymExpr]
rewrites =
  [ dbl "divide" (dbl "mult" "a" "b") "c"
      := dbl "mult" "a" (dbl "divide" "b" "c")
  , dbl "divide" "x" "x" := oneD
  , dbl "mult" "x" oneD := "x"
  , dbl "mult" oneD "x" := "x"
  , int "mult" "x" oneI := "x"
  , int "mult" oneI "x" := "x"
  ]
  where
    -- Binary operator over Doubles (operands and result).
    dbl :: String -> Pattern SymExpr -> Pattern SymExpr -> Pattern SymExpr
    dbl name = pBinOp name tDouble tDouble tDouble

    -- Binary operator over Ints (operands and result).
    int :: String -> Pattern SymExpr -> Pattern SymExpr -> Pattern SymExpr
    int name = pBinOp name tInt tInt tInt

    tDouble, tInt :: SomeTypeRep
    tDouble = SomeTypeRep (typeRep @Double)
    tInt = SomeTypeRep (typeRep @Int)

    oneD, oneI :: Pattern SymExpr
    oneD = pLitDouble 1.0
    oneI = pLitInt 1
-- | @(x * 2) / 2@ over Doubles.
e1 :: TExpr Double
e1 = (col "x" `mult` lit 2) `divide` lit 2

-- | @(count * 5) + 10@ over Ints.
e2 :: TExpr Int
e2 = (col "count" `mult` lit 5) `add` lit 10

-- | @intToDouble count + 1.0@: an Int column converted into a Double sum.
e3 :: TExpr Double
e3 = intToDouble (col "count") `add` lit 1.0

-- | @intToDouble x > 10.0@: a comparison producing a Bool.
e4 :: TExpr Bool
e4 = intToDouble (col "x") `gtOp` lit 10.0

-- | @(bignum * 100) + 50@ over Integers.
e5 :: TExpr Integer
e5 = (col "bignum" `mult` lit (100 :: Integer)) `add` lit (50 :: Integer)

-- | @(a * 2) + 3@ over Ints.
intExample :: TExpr Int
intExample = (col "a" `mult` lit 2) `add` lit 3

-- | @(a * 2.0) + 3.0@ over Doubles.
doubleExample :: TExpr Double
doubleExample = (col "a" `mult` lit 2.0) `add` lit 3.0
-- | Print the sample expressions, running equality saturation on e1, e2 and
-- on two all-constant expressions used as constant-folding checks.
main :: IO ()
main = do
  putStrLn "=== Original expression e1: (x * 2.0) / 2.0 (Double) ==="
  print (unTExpr e1)
  putStrLn "\n=== After equality saturation ==="
  print (saturate e1)

  putStrLn "\n=== Integer expression e2: (count * 5) + 10 ==="
  print (unTExpr e2)
  print (saturate e2)

  putStrLn "\n=== Heterogeneous expression e3: intToDouble(count) + 1.0 ==="
  print (unTExpr e3)

  putStrLn "\n=== Comparison expression e4: intToDouble(x) > 10.0 ==="
  print (unTExpr e4)

  putStrLn "\n=== Constant folding test: (2 * 3) + 5 (Int) ==="
  print (saturate intConst)
  putStrLn "Expected: 11"

  putStrLn "\n=== Constant folding test: (2.0 * 3.0) + 5.0 (Double) ==="
  print (saturate dblConst)
  putStrLn "Expected: 11.0"
  where
    -- Run saturation with the constant-folding analysis and keep only the
    -- extracted expression, discarding the e-graph.
    saturate :: TExpr t -> Fix SymExpr
    saturate (TExpr expr) =
      fst (equalitySaturation @(Maybe SomeVal) expr rewrites cost)

    intConst = add (mult (lit @Int 2) (lit 3)) (lit 5)
    dblConst = add (mult (lit @Double 2.0) (lit 3.0)) (lit 5.0)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment