Created
December 21, 2025 03:45
-
-
Save mchav/f60487db99282fc96449a57dd5e67be7 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| {-# LANGUAGE OverloadedStrings #-} | |
| {-# LANGUAGE TemplateHaskell #-} | |
| {-# LANGUAGE TypeApplications #-} | |
| {-# LANGUAGE DeriveFunctor #-} | |
| {-# LANGUAGE DeriveFoldable #-} | |
| {-# LANGUAGE DeriveTraversable #-} | |
| {-# LANGUAGE FlexibleContexts #-} | |
| {-# LANGUAGE ExplicitNamespaces #-} | |
| {-# LANGUAGE FlexibleInstances #-} | |
| {-# LANGUAGE GADTs #-} | |
| {-# LANGUAGE ScopedTypeVariables #-} | |
| -- Fix | |
| {-# LANGUAGE LambdaCase #-} | |
| -- Fix | |
| {-# LANGUAGE MultiParamTypeClasses #-} | |
| module Main where | |
| import Control.Applicative | |
| import Data.Equality.Saturation | |
| -- Fix this | |
| import Data.Equality.Analysis | |
| -- Fix this | |
| import Data.Equality.Graph | |
| import Data.Equality.Graph.Lens ((^.), _class, _data) | |
| import Data.Equality.Matching | |
| import qualified DataFrame as D | |
| import DataFrame.Synthesis | |
| import DataFrame.Internal.Column | |
| import DataFrame.Internal.Expression | |
| import qualified Data.Text as T | |
| import Data.Type.Equality | |
| import Type.Reflection | |
-- We need a second type: a symbolic expression that we can
-- use e-graphs on. It should be polymorphic in its element type,
-- and we should be able to convert it to and from Expr.
-- Our current Expr can't serve this role because it is a GADT
-- with a bunch of Columnable constraints.
-- We can't make it as general as Expr; it seems we have to enumerate
-- the supported cases and functions.
--
-- Actually, we can; we just have to litter the code with reflection again.
--
-- I'm not sure how to handle functions of type a -> b (as opposed to a -> a);
-- the types get complicated quickly.
-- So we'll be content with Doubles for now and generalize later.
-- | One layer of a symbolic expression that can be stored in an e-graph.
--
-- @lit@ is the type of literal values and @r@ is the type of each child
-- position (for a complete tree this is @Fix (SymExpr lit)@). Operators are
-- identified only by their name; what a name means is decided by the
-- conversion and analysis code, not by this type.
data SymExpr lit r
  = SLit lit             -- ^ a literal constant
  | SCol String          -- ^ a column reference, by column name
  | SUnaryOp String r    -- ^ a named unary operator applied to one child
  | SBinaryOp String r r -- ^ a named binary operator applied to two children
  deriving (Show, Eq, Ord, Functor, Foldable, Traversable)
-- | Symbolic form of @(x * 2) / 2@, where @x@ is a column reference.
e1 :: Fix (SymExpr Double)
e1 = Fix (SBinaryOp "divide" xTimesTwo two)
  where
    two :: Fix (SymExpr Double)
    two = Fix (SLit 2)

    xTimesTwo :: Fix (SymExpr Double)
    xTimesTwo = Fix (SBinaryOp "mult" (Fix (SCol "x")) two)
-- | Executable form of @(x * 2) / 2@; the same expression as 'e1'.
e1' :: Expr Double
e1' = BinaryOp "divide" (/) xTimesTwo (Lit 2)
  where
    xTimesTwo :: Expr Double
    xTimesTwo = BinaryOp "mult" (*) (Col "x") (Lit 2)
-- | Convert a symbolic term back into an executable 'Expr'.
--
-- Literals and column references convert at any 'Columnable' element type.
-- The supported operators are @"cos"@ (unary) and @"add"@, @"divide"@ and
-- @"mult"@ (binary), each only at element type 'Double'.
--
-- An unsupported operator name fails with @error "UNIMPLEMENTED"@. The name
-- is checked first, so this happens regardless of the element type. A
-- supported operator at a non-'Double' element type fails with
-- @error "type mismatch"@.
toExpr :: forall a . Columnable a => Fix (SymExpr a) -> Expr a
toExpr (Fix (SLit v)) = Lit v
toExpr (Fix (SCol colName)) = Col (T.pack colName)
toExpr (Fix (SUnaryOp op arg))
  | op /= "cos" = error "UNIMPLEMENTED"
  | otherwise = case testEquality (typeRep @a) (typeRep @Double) of
      Nothing -> error "type mismatch"
      -- Matching Refl brings a ~ Double into scope, which is what allows
      -- the Double-only 'cos' to be used at type a.
      Just Refl -> UnaryOp (T.pack op) cos (toExpr arg)
toExpr (Fix (SBinaryOp op lhs rhs)) = case lookup op doubleBinaryOps of
  Nothing -> error "UNIMPLEMENTED"
  Just fn -> case testEquality (typeRep @a) (typeRep @Double) of
    Nothing -> error "type mismatch"
    Just Refl -> BinaryOp (T.pack op) fn (toExpr lhs) (toExpr rhs)
  where
    -- The Double implementation behind each supported binary operator name.
    doubleBinaryOps :: [(String, Double -> Double -> Double)]
    doubleBinaryOps = [("add", (+)), ("divide", (/)), ("mult", (*))]
-- | Convert an executable 'Expr' into a symbolic term that can be inserted
-- into an e-graph.
--
-- 'SymExpr' is homogeneous: every sub-term has the same element type as the
-- whole term. Each operator's children are therefore checked, via
-- 'Typeable' reflection, to have the same element type @a@ as the result.
-- Operator functions are discarded and only their names are kept. Only the
-- names that 'toExpr' can turn back into functions are accepted: @"cos"@
-- for 'UnaryOp', and @"add"@, @"divide"@ and @"mult"@ for 'BinaryOp'.
--
-- Fails with @error "UNIMPLEMENTED"@ for any other operator name or
-- expression form. Fails with @error "type mismatch"@ when a child's element
-- type differs from @a@.
toSymExpr :: forall a . Expr a -> Fix (SymExpr a)
toSymExpr (Lit value) = Fix (SLit value)
toSymExpr (Col colName) = Fix (SCol (T.unpack colName))
toSymExpr (UnaryOp name _ (arg :: Expr b))
  -- Explicit fallback: an unsupported unary operator reports UNIMPLEMENTED
  -- like every other branch, instead of a non-exhaustive pattern failure.
  | name /= "cos" = error "UNIMPLEMENTED"
  | otherwise = case testEquality (typeRep @a) (typeRep @b) of
      Just Refl -> Fix (SUnaryOp (T.unpack name) (toSymExpr arg))
      Nothing -> error "type mismatch"
toSymExpr (BinaryOp name _ (lhs :: Expr b) (rhs :: Expr c))
  | name `notElem` ["add", "divide", "mult"] = error "UNIMPLEMENTED"
  | otherwise =
      -- Both children must share the result's element type; a failure in
      -- either check is reported the same way.
      case (testEquality (typeRep @a) (typeRep @b), testEquality (typeRep @a) (typeRep @c)) of
        (Just Refl, Just Refl) ->
          Fix (SBinaryOp (T.unpack name) (toSymExpr lhs) (toSymExpr rhs))
        _ -> error "type mismatch"
toSymExpr _ = error "UNIMPLEMENTED"
-- | Constant-folding analysis. An e-class carries @Just v@ when its nodes
-- fold to the constant @v@, and 'Nothing' when nothing is known.
--
-- NOTE(review): the original carried an unexplained "fix this" marker here;
-- the method signatures below follow the original and assume the installed
-- hegg version's 'Analysis' class matches them. Confirm against that
-- version's API.
instance Analysis (Maybe Double) (SymExpr Double) where
  -- Fold one node whose children's facts are already known. Every
  -- constructor and operator name is covered, so an unknown operator means
  -- "not a constant" rather than a pattern-match failure when the node is
  -- added to the e-graph.
  makeA node = dropNaN $ case node of
      SLit x -> Just x
      SCol _ -> Nothing
      SUnaryOp "cos" x -> cos <$> x
      SUnaryOp _ _ -> Nothing
      SBinaryOp "add" x y -> (+) <$> x <*> y
      SBinaryOp "divide" x y -> (/) <$> x <*> y
      SBinaryOp "mult" x y -> (*) <$> x <*> y
      SBinaryOp _ _ _ -> Nothing
    where
      -- Keep NaN out of the analysis. Because NaN /= NaN, 'joinA' would
      -- reject even merging a NaN class with its own literal. NaN is also
      -- unordered, which presumably breaks the e-graph's lookup of the
      -- @SLit NaN@ node that 'modifyA' adds (NOTE(review): confirm).
      dropNaN result = result >>= \v -> if isNaN v then Nothing else Just v

  -- Combine the facts of two e-classes being merged. Two known constants
  -- must agree; if they do not, the rewrites equated terms with different
  -- values.
  -- NOTE(review): this uses exact (==) on Double, so a rewrite that
  -- reassociates constant arithmetic, e.g. (a*b)/c -> a*(b/c), can hit the
  -- error through rounding alone.
  joinA left right = case (left, right) of
    (Just l, Just r)
      | l == r -> Just l
      | otherwise ->
          error ("joinA: merging e-classes with different constants "
                 ++ show l ++ " and " ++ show r)
    _ -> left <|> right

  -- When e-class c is known to be constant, add a literal node for that
  -- value and merge it into c, so extraction can choose the literal.
  modifyA c egr = case egr ^. _class c . _data of
    Nothing -> egr
    Just v ->
      let (litClass, egr') = represent (Fix (SLit v)) egr
       in snd (merge c litClass egr')
-- | Extraction cost of one node, given the costs of its children:
-- - literals and column references cost 1;
-- - a unary operator costs its child's cost plus 1;
-- - a binary operator costs the sum of its children's costs plus 2.
cost :: CostFunction (SymExpr Double) Int
cost node = case node of
  SBinaryOp _ lhs rhs -> 2 + lhs + rhs
  SUnaryOp _ sub -> sub + 1
  _ -> 1
| -- Fix | |
-- | Algebraic rewrite rules applied during equality saturation:
--
-- * @(a * b) / c  ==>  a * (b / c)@, moving a division inward;
-- * @x / x        ==>  1@;
-- * @x * 1        ==>  x@.
--
-- NOTE(review): @x / x ==> 1@ is unsound when @x@ is 0, and reassociation is
-- not exact in floating point. Both are accepted simplifications here.
rewrites :: [Rewrite anl (SymExpr Double)]
rewrites =
  [ divide (mult "a" "b") "c" := mult "a" (divide "b" "c")
  , divide "x" "x" := lit 1
  , mult "x" (lit 1) := "x"
  ]
  where
    -- Small pattern builders so that the rules read as algebra.
    divide l r = pat (SBinaryOp "divide" l r)
    mult l r = pat (SBinaryOp "mult" l r)
    lit v = pat (SLit v)
-- | Demo driver. It prints, in order:
--
--   1. the symbolic term 'e1' converted to an 'Expr';
--   2. the term extracted by equality saturation of 'e1'' under 'rewrites',
--      the constant-folding analysis and 'cost';
--   3. a round trip of @cos 10@ through 'toSymExpr' and 'toExpr'.
main :: IO ()
main = do
  print (toExpr e1)
  -- TODO: the original draft flagged this step with "Fix"; intent not recorded.
  let (optimised, _egraph) =
        equalitySaturation @(Maybe Double) (toSymExpr e1') rewrites cost
  print (toExpr optimised)
  print (toExpr (toSymExpr cosOfTen))
  where
    cosOfTen :: Expr Double
    cosOfTen = UnaryOp "cos" cos (Lit 10)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment