Skip to content

Instantly share code, notes, and snippets.

@mchav
Created January 8, 2026 21:05
Show Gist options
  • Select an option

  • Save mchav/1d1f14262f8b2408045256147de0d8e2 to your computer and use it in GitHub Desktop.

Select an option

Save mchav/1d1f14262f8b2408045256147de0d8e2 to your computer and use it in GitHub Desktop.
{-# LANGUAGE BlockArguments #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE MultiWayIf #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE TypeSynonymInstances, FlexibleInstances #-}
{-# LANGUAGE TypeApplications #-}
module EGGP where
import Algorithm.EqSat.Egraph
import Algorithm.EqSat.Simplify
import Algorithm.EqSat.Build
import Algorithm.EqSat.Queries
import Algorithm.EqSat.Info
import Algorithm.EqSat.DB
import Algorithm.SRTree.Likelihoods
import Algorithm.SRTree.ModelSelection
import Algorithm.SRTree.Opt
import Control.Exception (throw)
import Control.Lens (element, makeLenses, over, (&), (+~), (-~), (.~), (^.))
import Control.Monad (foldM, forM_, forM, when, unless, filterM, (>=>), replicateM, replicateM_)
import Control.Monad.State.Strict
import qualified Data.IntMap.Strict as IM
import qualified Data.Map.Strict as Map
import qualified DataFrame.Internal.Column as DI
import qualified DataFrame.Internal.DataFrame as DI
import qualified DataFrame as D
import qualified DataFrame.Functions as F
import qualified Data.Vector as V
import qualified Data.Vector.Unboxed as VU
import Data.Massiv.Array as MA hiding (forM_, forM)
import Data.Maybe (fromJust, isNothing, isJust)
import Data.SRTree
import Data.SRTree.Datasets
import Data.SRTree.Eval
import Data.SRTree.Random (randomTree)
import Data.SRTree.Print
import System.Random
import qualified Data.HashSet as Set
import Data.List ( sort, maximumBy, intercalate, sortOn, intersperse, nub, transpose )
import Data.IntSet (IntSet)
import qualified Data.IntSet as IntSet
import qualified Data.Sequence as FingerTree
import Data.Function ( on )
import qualified Data.Foldable as Foldable
import qualified Data.IntMap as IntMap
import List.Shuffle ( shuffle )
import Algorithm.SRTree.NonlinearOpt
import Data.Binary ( encode, decode )
import qualified Data.ByteString.Lazy as BS
import Algorithm.EqSat (runEqSat,applySingleMergeOnlyEqSat)
import GHC.IO (unsafePerformIO)
import Control.Scheduler
import Control.Monad.IO.Unlift
import Data.SRTree (convertProtectedOps)
import Options.Applicative as Opt hiding (Const)
import Search
import Algorithm.EqSat.SearchSR
import Data.SRTree.Random
import Data.SRTree.Datasets
import qualified Data.Massiv.Array as M
import Text.Read (readMaybe)
-- | Command-line parser producing an 'Args' value.
--
-- Every flag is defined as its own @where@ binding and then fed to the
-- 'Args' constructor positionally, so the order of the applicative chain
-- below is significant: it fixes which parser fills which constructor
-- argument.
--
-- @--dataset@ and @--maxSize@ carry no default value; every other option
-- falls back to the default given in its definition.
opt :: Parser Args
opt =
  Args
    <$> datasetFlag
    <*> testFlag
    <*> generationsFlag
    <*> maxSizeFlag
    <*> foldsFlag
    <*> traceFlag
    <*> lossFlag
    <*> optIterFlag
    <*> optRetriesFlag
    <*> numberParamsFlag
    <*> nPopFlag
    <*> tournamentSizeFlag
    <*> pcFlag
    <*> pmFlag
    <*> nonTerminalsFlag
    <*> dumpToFlag
    <*> loadFromFlag
    <*> generationalFlag
    <*> simplifyFlag
    <*> maxTimeFlag
    <*> varnamesFlag
  where
    -- Input data -----------------------------------------------------------
    datasetFlag =
      strOption
        (long "dataset" <> short 'd' <> metavar "INPUT-FILE" <> help "CSV dataset.")
    testFlag =
      strOption
        (long "test" <> short 't' <> value "" <> showDefault <> help "test data")

    -- Search budget and expression size ------------------------------------
    generationsFlag =
      option auto
        (long "generations" <> short 'g' <> metavar "GENS" <> showDefault <> value 100 <> help "Number of generations." )
    maxSizeFlag =
      option auto
        (long "maxSize" <> short 's' <> help "max-size.")
    foldsFlag =
      option auto
        (long "folds" <> short 'k' <> value 1 <> showDefault <> help "number of folds to determine the ratio of training-validation")
    traceFlag =
      switch
        (long "trace" <> help "print all evaluated expressions.")

    -- Loss and parameter optimization --------------------------------------
    lossFlag =
      option auto
        (long "loss" <> value MSE <> showDefault <> help "loss function: MSE, Gaussian, Poisson, Bernoulli.")
    optIterFlag =
      option auto
        (long "opt-iter" <> value 30 <> showDefault <> help "number of iterations in parameter optimization.")
    optRetriesFlag =
      option auto
        (long "opt-retries" <> value 1 <> showDefault <> help "number of retries of parameter fitting.")
    numberParamsFlag =
      option auto
        (long "number-params" <> value (-1) <> showDefault <> help "maximum number of parameters in the model. If this argument is absent, the number is bounded by the maximum size of the expression and there will be no repeated parameter.")

    -- Evolutionary operators -----------------------------------------------
    nPopFlag =
      option auto
        (long "nPop" <> value 100 <> showDefault <> help "population size (Default: 100).")
    tournamentSizeFlag =
      option auto
        (long "tournament-size" <> value 2 <> showDefault <> help "tournament size.")
    pcFlag =
      option auto
        (long "pc" <> value 1.0 <> showDefault <> help "probability of crossover.")
    pmFlag =
      option auto
        (long "pm" <> value 0.3 <> showDefault <> help "probability of mutation.")
    nonTerminalsFlag =
      strOption
        (long "non-terminals" <> value "Add,Sub,Mul,Div,PowerAbs,Recip" <> showDefault <> help "set of non-terminals to use in the search.")

    -- E-graph persistence --------------------------------------------------
    dumpToFlag =
      strOption
        (long "dump-to" <> value "" <> showDefault <> help "dump final e-graph to a file.")
    loadFromFlag =
      strOption
        (long "load-from" <> value "" <> showDefault <> help "load initial e-graph from a file.")

    -- Behaviour switches, time limit and reporting -------------------------
    generationalFlag =
      switch
        (long "generational" <> help "replace the current population with the children instead of keeping the pareto front.")
    simplifyFlag =
      switch
        (long "simplify" <> help "simplify the expressions before displaying them.")
    maxTimeFlag =
      option auto
        (long "max-time" <> value (-1) <> showDefault <> help "maximum allowed time budget (in seconds, -1 it will run for the number of generations)")
    varnamesFlag =
      strOption
        (long "varnames" <> value "" <> showDefault <> help "comma separated variable names.")
-- | Run the 'egraphGP' search on an in-memory dataset and return the
-- expressions it produces.
--
-- Arguments, in order:
--
-- * @x@            – feature matrix
-- * @target@       – target vector
-- * @gens@         – number of generations
-- * @maxSize@      – maximum expression size
-- * @nPop@         – population size
-- * @pc@           – probability of crossover
-- * @pm@           – probability of mutation
-- * @nonterminals@ – set of non-terminals, passed through as a 'String'
-- * @varnames@     – comma separated variable names
--
-- The remaining 'Args' fields are fixed here (positions follow the flag
-- order of 'opt'): empty dataset/test paths, 1 fold, no trace, 'MSE' loss,
-- 50 optimizer iterations, 2 retries, @-1@ for number-params, tournament
-- size 3, no dump/load file, non-generational, simplify enabled and
-- max-time @-1@.
--
-- The very same @(x, target, Nothing)@ triple is supplied for every dataset
-- slot of 'egraphGP'.
-- NOTE(review): presumably those slots are (train, validation) pairs and
-- test sets — confirm against 'egraphGP'.
--
-- Randomness comes from the global generator read via 'getStdGen'; the
-- search starts from 'emptyGraph'.
eggp
  :: M.Array S Ix2 Double
  -> M.Array S Ix1 Double
  -> Int
  -> Int
  -> Int
  -> Double
  -> Double
  -> String
  -> String
  -> IO [Fix SRTree]
eggp x target gens maxSize nPop pc pm nonterminals varnames = do
  g <- getStdGen
  let args = Args "" "" gens maxSize 1 False MSE 50 2 (-1) nPop 3 pc pm nonterminals "" "" False True (-1) varnames
      -- Bound once and reused below instead of repeating the tuple literal.
      dataset = (x, target, Nothing)
      -- Inner state layer: the e-graph; the outer layer (run below) is the RNG.
      search = evalStateT (egraphGP [(dataset, dataset)] [dataset] args) emptyGraph
  evalStateT search g
-- | Load a fixed Parquet dataset, run 'eggp' on it, print the results and
-- return them.
--
-- Steps:
--
-- 1. Read @../dataframe/data/mtcars.parquet@ and keep only the columns
--    selected by @D.hasElemType \@Double@.
-- 2. Build the feature matrix from every remaining column except @"mpg"@
--    (column lists are transposed into rows) and the target vector from the
--    @"mpg"@ column.
-- 3. Print both arrays, run 'eggp' with the given hyper-parameters, print
--    the resulting expressions rendered with 'showExpr', and return them.
--
-- The first 'String' argument is currently ignored; both the file path and
-- the target column name are hard-coded.
-- NOTE(review): presumably the ignored argument was meant to select the
-- dataset path or the target column — confirm with the author before wiring
-- it in, since existing callers may pass arbitrary values.
--
-- The other arguments (@gens@, @maxSize@, @nPop@, @pc@, @pm@,
-- @nonterminals@, @varnames@) are forwarded unchanged to 'eggp'.
eggp_df
  :: String
  -> Int
  -> Int
  -> Int
  -> Double
  -> Double
  -> String
  -> String
  -> IO [Fix SRTree]
eggp_df _ gens maxSize nPop pc pm nonterminals varnames = do
  df <- fmap (D.selectBy [D.byProperty (D.hasElemType @Double)]) (D.readParquet "../dataframe/data/mtcars.parquet")
  let x = fromLists' Seq (Data.List.transpose $ V.toList (V.map (DI.toList @Double) (DI.columns (D.exclude ["mpg"] df)))) :: Array S Ix2 Double
      target = fromLists' Seq (D.columnAsList (F.col "mpg") df) :: Array S Ix1 Double
  print x
  print target
  -- 'eggp' builds the same 'Args' and performs the same search setup that
  -- used to be duplicated here, so delegate instead of repeating it.
  res <- eggp x target gens maxSize nPop pc pm nonterminals varnames
  print (Prelude.map showExpr res)
  pure res
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment