Created
January 8, 2026 21:05
-
-
Save mchav/1d1f14262f8b2408045256147de0d8e2 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| {-# LANGUAGE BlockArguments #-} | |
| {-# LANGUAGE TupleSections #-} | |
| {-# LANGUAGE MultiWayIf #-} | |
| {-# LANGUAGE OverloadedStrings #-} | |
| {-# LANGUAGE BangPatterns #-} | |
| {-# LANGUAGE TypeSynonymInstances, FlexibleInstances #-} | |
| {-# LANGUAGE TypeApplications #-} | |
| module EGGP where | |
| import Algorithm.EqSat.Egraph | |
| import Algorithm.EqSat.Simplify | |
| import Algorithm.EqSat.Build | |
| import Algorithm.EqSat.Queries | |
| import Algorithm.EqSat.Info | |
| import Algorithm.EqSat.DB | |
| import Algorithm.SRTree.Likelihoods | |
| import Algorithm.SRTree.ModelSelection | |
| import Algorithm.SRTree.Opt | |
| import Control.Exception (throw) | |
| import Control.Lens (element, makeLenses, over, (&), (+~), (-~), (.~), (^.)) | |
| import Control.Monad (foldM, forM_, forM, when, unless, filterM, (>=>), replicateM, replicateM_) | |
| import Control.Monad.State.Strict | |
| import qualified Data.IntMap.Strict as IM | |
| import qualified Data.Map.Strict as Map | |
| import qualified DataFrame.Internal.Column as DI | |
| import qualified DataFrame.Internal.DataFrame as DI | |
| import qualified DataFrame as D | |
| import qualified DataFrame.Functions as F | |
| import qualified Data.Vector as V | |
| import qualified Data.Vector.Unboxed as VU | |
| import Data.Massiv.Array as MA hiding (forM_, forM) | |
| import Data.Maybe (fromJust, isNothing, isJust) | |
| import Data.SRTree | |
| import Data.SRTree.Datasets | |
| import Data.SRTree.Eval | |
| import Data.SRTree.Random (randomTree) | |
| import Data.SRTree.Print | |
| import System.Random | |
| import qualified Data.HashSet as Set | |
| import Data.List ( sort, maximumBy, intercalate, sortOn, intersperse, nub, transpose ) | |
| import Data.IntSet (IntSet) | |
| import qualified Data.IntSet as IntSet | |
| import qualified Data.Sequence as FingerTree | |
| import Data.Function ( on ) | |
| import qualified Data.Foldable as Foldable | |
| import qualified Data.IntMap as IntMap | |
| import List.Shuffle ( shuffle ) | |
| import Algorithm.SRTree.NonlinearOpt | |
| import Data.Binary ( encode, decode ) | |
| import qualified Data.ByteString.Lazy as BS | |
| import Algorithm.EqSat (runEqSat,applySingleMergeOnlyEqSat) | |
| import GHC.IO (unsafePerformIO) | |
| import Control.Scheduler | |
| import Control.Monad.IO.Unlift | |
| import Data.SRTree (convertProtectedOps) | |
| import Options.Applicative as Opt hiding (Const) | |
| import Search | |
| import Algorithm.EqSat.SearchSR | |
| import Data.SRTree.Random | |
| import Data.SRTree.Datasets | |
| import qualified Data.Massiv.Array as M | |
| import Text.Read (readMaybe) | |
-- | Command-line parser for the search configuration.
--
-- The parsed values are handed to the 'Args' constructor positionally, so the
-- order of the options below is significant and must match the field order of
-- 'Args'.
opt :: Parser Args
opt =
  Args
    <$> strOption (long "dataset" <> short 'd' <> metavar "INPUT-FILE" <> help "CSV dataset.")
    <*> strOption (long "test" <> short 't' <> defaultTo "" <> help "test data")
    <*> option auto (long "generations" <> short 'g' <> metavar "GENS" <> defaultTo 100 <> help "Number of generations.")
    -- No default: this option is mandatory on the command line.
    <*> option auto (long "maxSize" <> short 's' <> help "max-size.")
    <*> option auto (long "folds" <> short 'k' <> defaultTo 1 <> help "number of folds to determine the ratio of training-validation")
    <*> switch (long "trace" <> help "print all evaluated expressions.")
    <*> option auto (long "loss" <> defaultTo MSE <> help "loss function: MSE, Gaussian, Poisson, Bernoulli.")
    <*> option auto (long "opt-iter" <> defaultTo 30 <> help "number of iterations in parameter optimization.")
    <*> option auto (long "opt-retries" <> defaultTo 1 <> help "number of retries of parameter fitting.")
    <*> option auto (long "number-params" <> defaultTo (-1) <> help "maximum number of parameters in the model. If this argument is absent, the number is bounded by the maximum size of the expression and there will be no repeated parameter.")
    <*> option auto (long "nPop" <> defaultTo 100 <> help "population size (Default: 100).")
    <*> option auto (long "tournament-size" <> defaultTo 2 <> help "tournament size.")
    <*> option auto (long "pc" <> defaultTo 1.0 <> help "probability of crossover.")
    <*> option auto (long "pm" <> defaultTo 0.3 <> help "probability of mutation.")
    <*> strOption (long "non-terminals" <> defaultTo "Add,Sub,Mul,Div,PowerAbs,Recip" <> help "set of non-terminals to use in the search.")
    <*> strOption (long "dump-to" <> defaultTo "" <> help "dump final e-graph to a file.")
    <*> strOption (long "load-from" <> defaultTo "" <> help "load initial e-graph from a file.")
    <*> switch (long "generational" <> help "replace the current population with the children instead of keeping the pareto front.")
    <*> switch (long "simplify" <> help "simplify the expressions before displaying them.")
    <*> option auto (long "max-time" <> defaultTo (-1) <> help "maximum allowed time budget (in seconds, -1 it will run for the number of generations)")
    <*> strOption (long "varnames" <> defaultTo "" <> help "comma separated variable names.")
  where
    -- A default value that is also displayed in the generated help text.
    defaultTo v = value v <> showDefault
-- | Run the e-graph genetic programming search ('egraphGP') on an in-memory
-- dataset and return the resulting expressions.
--
-- Arguments, in order: feature matrix @x@, target vector, number of
-- generations, maximum expression size, population size, crossover
-- probability, mutation probability, comma-separated non-terminals and
-- comma-separated variable names.
--
-- The same @(x, target, Nothing)@ triple is passed for every dataset slot of
-- 'egraphGP' (presumably train/validation pairs and test sets -- confirm
-- against 'egraphGP'). The search starts from 'emptyGraph' and is seeded with
-- the global generator obtained from 'getStdGen'.
eggp :: M.Array S Ix2 Double -> M.Array S Ix1 Double -> Int -> Int -> Int -> Double -> Double -> String -> String -> IO [Fix SRTree]
eggp x target gens maxSize nPop pc pm nonterminals varnames = do
  g <- getStdGen
  let -- 'Args' is built positionally, in the same field order as the 'opt'
      -- parser: dataset "", test "", generations, maxSize, folds 1,
      -- trace False, loss MSE, opt-iter 50, opt-retries 2, number-params -1,
      -- nPop, tournament-size 3, pc, pm, non-terminals, dump-to "",
      -- load-from "", generational False, simplify True, max-time -1,
      -- varnames. Note that opt-iter, opt-retries, tournament-size and
      -- simplify differ from the command-line defaults declared in 'opt'.
      args = Args "" "" gens maxSize 1 False MSE 50 2 (-1) nPop 3 pc pm nonterminals "" "" False True (-1) varnames
      -- Single dataset triple shared by every slot below.
      dataset = (x, target, Nothing)
      search = evalStateT (egraphGP [(dataset, dataset)] [dataset] args) emptyGraph
  evalStateT search g
-- | Run 'eggp' on data loaded from a parquet file and print the results.
--
-- The file @../dataframe/data/mtcars.parquet@ is read, columns are filtered
-- with @D.hasElemType \@Double@, the @mpg@ column is used as the target and
-- the remaining selected columns as features. The feature matrix, the target
-- vector and the rendered expressions ('showExpr') are printed to stdout; the
-- expressions are also returned.
--
-- NOTE(review): the first argument is ignored and both the file path and the
-- target column are hard-coded. It looks like the argument was meant to
-- select one of them -- confirm the intent before wiring it in.
--
-- The remaining arguments are forwarded unchanged to 'eggp'.
eggp_df :: String -> Int -> Int -> Int -> Double -> Double -> String -> String -> IO [Fix SRTree]
eggp_df _ gens maxSize nPop pc pm nonterminals varnames = do
  df <- D.selectBy [D.byProperty (D.hasElemType @Double)] <$> D.readParquet "../dataframe/data/mtcars.parquet"
  let -- One list per feature column; transposed below so that each inner
      -- list handed to 'fromLists'' is one row of the matrix.
      featureColumns = V.toList (V.map (DI.toList @Double) (DI.columns (D.exclude ["mpg"] df)))
      x = fromLists' Seq (Data.List.transpose featureColumns) :: Array S Ix2 Double
      target = fromLists' Seq (D.columnAsList (F.col "mpg") df) :: Array S Ix1 Double
  print x
  print target
  -- Delegate to 'eggp' rather than duplicating its 'Args' construction.
  res <- eggp x target gens maxSize nPop pc pm nonterminals varnames
  print (Prelude.map showExpr res)
  pure res
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment