Skip to content

Instantly share code, notes, and snippets.

@pwm
Forked from ChrisPenner/DynamicBFS.hs
Created February 22, 2023 10:04
Show Gist options
  • Save pwm/b37e8c54e24e57df218b3070a9327ece to your computer and use it in GitHub Desktop.
Save pwm/b37e8c54e24e57df218b3070a9327ece to your computer and use it in GitHub Desktop.
Effectful, lazy BFS using LogicT.
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
module BFS where
import Control.Applicative
import Control.Monad.Logic
import Control.Monad.Reader
import Control.Monad.Trans.Maybe
import Data.Foldable
import Data.Functor
import Data.Functor.Identity
import Data.Map (Map)
import qualified Data.Map as Map
import Data.Maybe
import Data.Traversable
-- | The example tree searched by the demos in this module, stored as an
-- adjacency map from each node to its children (left to right):
--
-- >        1
-- >       / \
-- >      2   3
-- >     / \   \
-- >    4   6   5
-- >   / \
-- >  7   8
--
-- Leaf nodes (5, 6, 7, 8) deliberately have no entry in the map.
tree :: Map Int [Int]
tree =
  Map.fromList
    [ (1, [2, 3])
    , (2, [4, 6])
    , (3, [5])
    , (4, [7, 8])
    ]
-- | Children of node @k@ in 'tree', or @[]@ when @k@ has no entry
-- (a leaf, or a node that is not in the tree at all).
getChild :: Int -> [Int]
getChild k = Map.findWithDefault [] k tree
-- | Breadth-first-style search from @root@ for the first node satisfying
-- @goal@, using the (possibly effectful) @getChildren@ to expand nodes.
--
-- Returns @Just path@ for the first goal node found, or @Nothing@ if the
-- search space is exhausted. The path is built by pushing each visited node
-- onto the front of a list, so it runs from the goal back up to @root@
-- (goal first). Every node is printed to stdout when it is visited, and
-- the search stops at the first result, so nodes beyond it are never
-- visited.
--
-- Each non-goal node first yields a placeholder 'Nothing' before its children
-- are requested. That lets 'interleaveAll' advance sibling subtrees one step
-- at a time instead of exhausting one subtree before starting the next. The
-- placeholders are filtered out before results are observed.
--
-- NOTE(review): the order is a fair round-robin across sibling subtrees;
-- for unbalanced trees this may not be strict level order. No visited set
-- is kept, so a node reachable along several paths is visited once per path.
--
-- >>> dagbfs (== 7) (pure . getChild) 1
-- 1
-- 2
-- 3
-- 4
-- 5
-- 6
-- 7
-- Just [7,4,2,1]
dagbfs ::
  forall a m.
  (MonadIO m, Show a) =>
  (a -> Bool) ->
  (a -> m [a]) ->
  a ->
  m (Maybe [a])
dagbfs goal getChildren root =
  listToMaybe <$> observeManyT 1 (runReaderT (visit root >>= maybe empty pure) [])
  where
    -- Visit one node with it pushed onto the path held in the reader
    -- environment; the expansion below runs inside that extended scope.
    visit :: a -> ReaderT [a] (LogicT m) (Maybe [a])
    visit node = local (node :) $ do
      liftIO (print node)
      if goal node
        then asks Just
        else pure Nothing <|> expand node

    -- Ask the base monad for the node's children and search them with fair
    -- interleaving.
    expand :: a -> ReaderT [a] (LogicT m) (Maybe [a])
    expand node = do
      kids <- lift (lift (getChildren node))
      interleaveAll (map visit kids)
-- | Fairly interleave a list of 'MonadLogic' computations in round-robin
-- order. The result is the first result of each computation (left to
-- right), then the second result of each, and so on.
--
-- A computation that runs out of results drops out of the rotation. The
-- combined computation fails once every computation is exhausted, and
-- fails immediately for an empty list.
--
-- Each round calls 'msplit' once per remaining computation, so effectful
-- computations are advanced one result at a time.
--
-- >>> observeAll $ interleaveAll [pure @Logic 1, pure 2 <|> pure 3 <|> pure 4, pure 5, pure 6 <|> pure 7]
-- [1,2,5,6,3,7,4]
interleaveAll :: forall f a. MonadLogic f => [f a] -> f a
interleaveAll [] = empty
interleaveAll xs = go xs []
  where
    -- @go pending suspended@ takes one result from each computation in
    -- @pending@ and collects the suspended remainders for the next round.
    -- Remainders are consed in reverse (O(1) per step, instead of an O(n)
    -- @ms ++ [m]@ append) and put back in order once per round, so the
    -- round-robin order is unchanged.
    go :: [f a] -> [f a] -> f a
    go [] suspended = interleaveAll (reverse suspended)
    go (x : rest) suspended =
      msplit x >>= \case
        Nothing -> go rest suspended
        Just (a, m) -> pure a <|> go rest (m : suspended)
-- | Demonstrate that 'interleaveAll' runs effects lazily and in round-robin
-- order. Each value is printed at the moment it is produced, so the printed
-- sequence matches the order of the returned list.
--
-- >>> test
-- 1
-- 2
-- 3
-- 4
-- 5
-- 6
-- 7
-- [1,2,3,4,5,6,7]
test :: IO [Int]
test =
  observeAllT $
    interleaveAll
      [ pp 1
      , pp 2 <|> pp 5 <|> pp 7
      , pp 3
      , pp 4 <|> pp 6
      ]
  where
    -- Print the value as a side effect, then yield it.
    pp :: Int -> LogicT IO Int
    pp x = liftIO (print x) $> x
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment