-- Gist: "Backtracking state with LogicT"
-- (ejconlon/d5ac58666d1c9fcaa6ced56f19f87f0f), created November 5, 2022 17:32.
-- | 'LogicT' is a great monad transformer for backtracking control,
-- but if you just layer it over a 'State' monad, the state itself is not
-- backtracked. This module backtracks part of the state: at every choice
-- point ('<|>' or 'interleave') we save the backtracking component and
-- restore it before trying the alternative branch.
module BacktrackingStateSearch | |
( TrackSt (..) | |
, Track | |
, observeManyTrack | |
, runManyTrack | |
) where | |
import Control.Applicative (Alternative (..)) | |
import Control.Monad.Logic (LogicT, MonadLogic (..), observeManyT) | |
import Control.Monad.State.Strict (MonadState (..), State, gets, modify', runState) | |
import Data.Bifunctor (second) | |
-- | State threaded through a 'Track' search, split into two components.
--
-- Only 'tsBwd' is saved and restored at choice points. 'tsFwd' is never
-- rewound by this module, so writes to it survive backtracking. Whenever the
-- rest of this module talks about "the state" it means the 'tsBwd' component.
-- Both fields are strict.
data TrackSt x y = TrackSt
  { tsFwd :: !x
  -- ^ Forward component: only ever moves forward, untouched by backtracking.
  , tsBwd :: !y
  -- ^ Backtracking component: rewound when an alternative branch is tried.
  }
  deriving stock (Eq, Show)
-- | Backtracking search monad: 'LogicT' layered over 'State' on a 'TrackSt'.
--
-- Do not export the constructor. With backtracking, the state left behind at
-- the end is simply that of whichever branch executed last. For the 'msplit'
-- law (@msplit m >>= reflect = m@) to hold, the same state has to be
-- observable at every exit point. The practical way to get that is to never
-- expose the state externally: the constructor stays private and the only
-- eliminator is 'observeManyTrack', which wraps the search in a reset of the
-- backtracking state.
newtype Track x y a = Track
  { unTrack :: LogicT (State (TrackSt x y)) a
  }
  deriving newtype
    ( Functor
    , Applicative
    , Monad
    , MonadState (TrackSt x y)
    )
-- | Runs the search, collecting at most @n@ results, and guarantees that the
-- backtracking component ('tsBwd') is back at its starting value afterwards.
--
-- The 'reset' wrapper only restores 'tsBwd' once every branch has been
-- exhausted, because its restore lives in a trailing alternative.
-- 'observeManyT' stops as soon as it has @n@ results and drops the rest of
-- the search, so a truncated run never reaches that trailing restore. To cover
-- that case the starting value is also captured here, in the underlying
-- 'State' monad, and written back once 'observeManyT' has returned.
--
-- The forward component ('tsFwd') is left exactly as the search wrote it.
observeManyTrack :: Int -> Track x y a -> State (TrackSt x y) [a]
observeManyTrack n search = do
  -- Snapshot taken outside the search, so it is valid however early we stop.
  saved <- gets tsBwd
  results <- observeManyT n (unTrack (reset search))
  -- Unconditional restore: a no-op after an exhaustive run, essential after a
  -- truncated one.
  modify' (\st -> st { tsBwd = saved })
  pure results
-- | Convenience runner: executes 'observeManyTrack' from the given initial
-- state and returns the collected results together with the final state.
--
-- The 'Int' is the result limit handed to 'observeManyTrack'.
runManyTrack :: Int -> Track x y a -> TrackSt x y -> ([a], TrackSt x y)
runManyTrack n m initial = runState (observeManyTrack n m) initial
-- | Overwrites the backtracking component with a previously saved value and
-- then continues with the given search.
--
-- Used wherever an alternative branch must start from the state that was
-- current when the choice point was created. The forward component is not
-- touched.
restore :: y -> Track x y a -> Track x y a
restore saved next = modify' rewind *> next
  where
    -- Record update that replaces only 'tsBwd'.
    rewind st = st { tsBwd = saved }
-- | Appends a final, result-less branch that puts the saved backtracking
-- state back once all results of the search have been enumerated.
--
-- The choice is made with the underlying 'LogicT' @\<|\>@ rather than the
-- 'Track' one, so no additional checkpoint is taken here; the extra branch
-- only performs the restore and then fails.
finalize :: y -> Track x y a -> Track x y a
finalize saved x = Track (unTrack x <|> unTrack cleanup)
  where
    -- Restores the state and yields nothing.
    cleanup = restore saved empty
-- | Captures the current backtracking state and arranges, via 'finalize', for
-- it to be put back after the search has been enumerated.
--
-- Applied around the whole search so that the backtracked state is not
-- observable from outside.
reset :: Track x y a -> Track x y a
reset x = gets tsBwd >>= \saved -> finalize saved x
-- | Choice with a checkpoint: the backtracking state at the choice point is
-- remembered and reinstated before the right-hand branch runs.
instance Alternative (Track x y) where
  -- Failure is just failure of the underlying 'LogicT'.
  empty = Track empty

  left <|> right =
    gets tsBwd >>= \checkpoint ->
      -- The right branch starts from the checkpoint, not from whatever state
      -- the left branch left behind.
      Track (unTrack left <|> unTrack (restore checkpoint right))
-- | 'MonadLogic' for 'Track'. The method that actually matters is
-- 'interleave', which needs the same checkpointing as @\<|\>@.
instance MonadLogic (Track x y) where
  -- Pure newtype plumbing: unwrap, split in 'LogicT', and re-wrap the tail.
  -- (The original author's note: no reset of the tail appears to be needed
  -- here, unless a case was missed.)
  msplit = Track . fmap (fmap (second Track)) . msplit . unTrack

  interleave left right =
    gets tsBwd >>= \checkpoint ->
      -- As with choice, the right branch is prefixed with a restore of the
      -- state seen at this point.
      Track (interleave (unTrack left) (unTrack (restore checkpoint right)))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment