Skip to content

Instantly share code, notes, and snippets.

@mchav
Created December 21, 2025 03:45
Show Gist options
  • Select an option

  • Save mchav/f60487db99282fc96449a57dd5e67be7 to your computer and use it in GitHub Desktop.

Select an option

Save mchav/f60487db99282fc96449a57dd5e67be7 to your computer and use it in GitHub Desktop.
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE DeriveFoldable #-}
{-# LANGUAGE DeriveTraversable #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE ExplicitNamespaces #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE ScopedTypeVariables #-}
-- Fix
{-# LANGUAGE LambdaCase #-}
-- Fix
{-# LANGUAGE MultiParamTypeClasses #-}
module Main where
import Control.Applicative
import Data.Equality.Saturation
-- Fix this
import Data.Equality.Analysis
-- Fix this
import Data.Equality.Graph
import Data.Equality.Graph.Lens ((^.), _class, _data)
import Data.Equality.Matching
import qualified DataFrame as D
import DataFrame.Synthesis
import DataFrame.Internal.Column
import DataFrame.Internal.Expression
import qualified Data.Text as T
import Data.Type.Equality
import Type.Reflection
-- We need a second type: a symbolic expression that we can
-- use egraphs on. This should be polymorphic in its type
-- and we should be able to convert to and from Expr.
-- Our current Expr isn't because it is a GADT with
-- a bunch of columnable constraints.
-- We can't make it as general as Expr. It seems like we have to enumerate
-- the cases and supported functions.
--
-- Actually we can; we just have to litter the code with reflection again.
--
-- I'm not sure how you'd deal with functions from a -> b (as opposed to a -> a)
-- The type gets pretty complicated.
-- So we can be content with doubles now and generalize later.
-- | One layer of a symbolic expression tree.
--
-- @b@ is the type of literal values and @a@ is the type stored in the
-- child positions; tying the knot with 'Fix' (as 'e1' does) yields a full
-- expression tree. Operators are identified by name only, so the functor,
-- foldable and traversable instances can be derived.
data SymExpr b a
  = SLit b
    -- ^ A literal value.
  | SCol String
    -- ^ A reference to a column, by name.
  | SUnaryOp String a
    -- ^ A named unary operator applied to one child.
  | SBinaryOp String a a
    -- ^ A named binary operator applied to a left and a right child.
  deriving (Eq, Ord, Show, Functor, Foldable, Traversable)
-- | Sample symbolic expression: @(x * 2) / 2@.
e1 :: Fix (SymExpr Double)
e1 = binary "divide" (binary "mult" column two) two
  where
    -- Wrap a binary operator node in 'Fix'.
    binary op lhs rhs = Fix (SBinaryOp op lhs rhs)
    column = Fix (SCol "x")
    two = Fix (SLit 2)
-- | The same sample expression as 'e1', @(x * 2) / 2@, written directly
-- as a DataFrame 'Expr'.
e1' :: Expr Double
e1' = BinaryOp "divide" (/) doubled two
  where
    two :: Expr Double
    two = Lit 2

    doubled :: Expr Double
    doubled = BinaryOp "mult" (*) (Col "x") two
-- | Convert a symbolic expression back into a DataFrame 'Expr'.
--
-- Operators are stored by name only, so the function to attach is looked
-- up from the name. Every supported operator is implemented for 'Double'
-- only, which is checked at runtime through type reflection.
--
-- Calls 'error' with @"UNIMPLEMENTED"@ for an unknown operator name and
-- with @"type mismatch"@ when a known operator is used at a type other
-- than 'Double'. The name is checked before the type.
toExpr :: forall a . Columnable a => Fix (SymExpr a) -> Expr a
toExpr (Fix node) = case node of
  SLit value -> Lit value
  SCol name -> Col (T.pack name)
  SUnaryOp name arg
    | name == "cos" -> case isDouble of
        Just Refl -> UnaryOp (T.pack name) cos (toExpr arg)
        Nothing -> error "type mismatch"
    | otherwise -> error "UNIMPLEMENTED"
  SBinaryOp name lhs rhs -> case lookup name binaryOps of
    Nothing -> error "UNIMPLEMENTED"
    Just op -> case isDouble of
      Just Refl -> BinaryOp (T.pack name) op (toExpr lhs) (toExpr rhs)
      Nothing -> error "type mismatch"
  where
    -- Evidence that the expression type is Double, if it is.
    isDouble = testEquality (typeRep @a) (typeRep @Double)

    -- Supported binary operators and their Double implementations.
    binaryOps :: [(String, Double -> Double -> Double)]
    binaryOps =
      [ ("add", (+))
      , ("divide", (/))
      , ("mult", (*))
      ]
-- | Convert a DataFrame 'Expr' into its symbolic form, dropping the
-- functions attached to operators and keeping only their names.
--
-- Only literals, columns, the unary operator @"cos"@ and the binary
-- operators @"add"@, @"divide"@ and @"mult"@ are supported, and operands
-- must have the same type as the result (checked at runtime through type
-- reflection).
--
-- Calls 'error' with @"UNIMPLEMENTED"@ for any other operator name or
-- constructor, and with @"type mismatch"@ when an operand type differs
-- from the result type.
toSymExpr :: forall a . Expr a -> Fix (SymExpr a)
toSymExpr (Lit value) = Fix (SLit value)
toSymExpr (Col value) = Fix (SCol (T.unpack value))
toSymExpr (UnaryOp name _ (value :: Expr b)) = case name of
  "cos" -> case testEquality (typeRep @a) (typeRep @b) of
    Just Refl -> Fix (SUnaryOp (T.unpack name) (toSymExpr value))
    Nothing -> error "type mismatch"
  -- Unknown unary operators are reported the same way as unknown binary
  -- ones instead of failing with a non-exhaustive pattern match.
  _ -> error "UNIMPLEMENTED"
toSymExpr (BinaryOp name _ (left :: Expr b) (right :: Expr c))
  | name `elem` ["add", "divide", "mult"] =
      -- Both operands must share the result type before they can be
      -- stored in the homogeneous symbolic tree.
      case testEquality (typeRep @a) (typeRep @b) of
        Just Refl -> case testEquality (typeRep @a) (typeRep @c) of
          Just Refl -> Fix (SBinaryOp (T.unpack name) (toSymExpr left) (toSymExpr right))
          Nothing -> error "type mismatch"
        Nothing -> error "type mismatch"
  | otherwise = error "UNIMPLEMENTED"
toSymExpr _ = error "UNIMPLEMENTED"
-- fix this
-- | Constant-folding analysis: each e-class carries @Just v@ when it is
-- known to evaluate to the constant @v@, and @Nothing@ otherwise.
instance Analysis (Maybe Double) (SymExpr Double) where
  -- Fold a node whose children are all known constants. Columns, unknown
  -- operator names and nodes with an unknown child yield Nothing, so the
  -- function is total over every node shape.
  makeA = \case
    SLit x -> Just x
    SCol _ -> Nothing
    SUnaryOp "cos" x -> cos <$> x
    SUnaryOp _ _ -> Nothing
    SBinaryOp "add" x y -> (+) <$> x <*> y
    SBinaryOp "divide" x y -> (/) <$> x <*> y
    SBinaryOp "mult" x y -> (*) <$> x <*> y
    SBinaryOp _ _ _ -> Nothing

  -- Merge the data of two classes being unified: keep whichever constant
  -- is known. Two different constants abort with an error.
  -- NOTE(review): NaN never compares equal to itself, so two NaN constants
  -- would also land in the error branch -- confirm whether that can occur.
  joinA left right = case (left, right) of
    (Just l, Just r)
      | l == r -> Just l
      | otherwise -> error "weird"
    _ -> left <|> right

  -- When a class has a known constant, add the literal for it and merge
  -- the literal's class into this one.
  modifyA c egr = case egr ^. _class c . _data of
    Nothing -> egr
    Just i ->
      let (c', egr') = represent (Fix (SLit i)) egr
       in snd (merge c c' egr')
-- | Extraction cost of a node whose children have already been replaced
-- by their own costs: the node's weight plus the sum of the child costs.
-- Binary operators weigh 2; every other node weighs 1.
cost :: CostFunction (SymExpr Double) Int
cost node = weight + sum node
  where
    weight = case node of
      SBinaryOp {} -> 2
      _ -> 1
-- Fix
-- | Rewrite rules used during equality saturation:
--
-- 1. @(a * b) / c  ==>  a * (b / c)@
-- 2. @x / x        ==>  1@
-- 3. @x * 1        ==>  x@
--
-- NOTE(review): rule 2 carries no guard against @x@ being zero.
rewrites :: [Rewrite anl (SymExpr Double)]
rewrites =
  [ binary "divide" (binary "mult" "a" "b") "c" := binary "mult" "a" (binary "divide" "b" "c")
  , binary "divide" "x" "x" := pat (SLit 1)
  , binary "mult" "x" (pat (SLit 1)) := "x"
  ]
  where
    -- Pattern for a named binary operator over two sub-patterns.
    binary op lhs rhs = pat (SBinaryOp op lhs rhs)
-- | Print three expressions, each converted back to an 'Expr':
-- the hand-written 'e1', the result of running equality saturation on
-- 'e1'', and a @cos@ expression round-tripped through the symbolic form.
main :: IO ()
main = do
  let
    -- Fix
    simplified = fst (equalitySaturation @(Maybe Double) (toSymExpr e1') rewrites cost)
    roundTripped = toSymExpr (UnaryOp "cos" cos (Lit (10 :: Double)))
  mapM_ (print . toExpr) [e1, simplified, roundTripped]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment