Last active
July 27, 2016 11:50
-
-
Save chpatrick/db02b4c64c56d7716d97661d8871ea04 to your computer and use it in GitHub Desktop.
monad-dijkstra++
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{-# LANGUAGE LambdaCase #-} | |
{-# LANGUAGE DeriveFunctor #-} | |
{-# LANGUAGE GeneralizedNewtypeDeriving #-} | |
{-# LANGUAGE TupleSections #-} | |
module Lib where | |
import Control.Applicative | |
import Control.Monad.Identity | |
import Control.Monad.Trans.Free | |
import Control.Monad.Trans.Free.Church | |
import Control.Monad.Writer | |
import Data.Conduit | |
import Data.Foldable | |
import qualified Data.PQueue.Prio.Min as PQ | |
-- | One instruction of a search computation; @a@ is the continuation type.
data SearchF c a
  = -- | Dead end: this branch contributes no results.
    Abandon
  | -- | @Cost step estimate k@: pay @step@, record @estimate@ as a guess at
    -- the cost still remaining, then continue with @k@.
    Cost !c !c !a
  | -- | Choice between two continuations; both are explored.
    Branch !a !a
  deriving (Functor)
-- | A search computation with cost type @c@ over the base monad @m@,
-- represented as the Church-encoded free monad transformer 'FT' over
-- 'SearchF'. All class instances are inherited from 'FT'.
newtype SearchT c m a = SearchT (FT (SearchF c) m a)
  deriving
    ( Functor
    , Applicative
    , Monad
    , MonadTrans
    , MonadIO
    )

-- | A search with no underlying effects.
type Search c a = SearchT c Identity a
-- | @cost stepCost remaining@ charges @stepCost@ to the current branch and
-- attaches @remaining@, an estimate of the cost still to come. The search
-- adds the estimate to the running cost only to prioritise exploration; the
-- cost reported alongside a result is the sum of step costs alone.
cost :: c -> c -> SearchT c m ()
cost stepCost remaining = SearchT $ liftF $ Cost stepCost remaining ()
-- | Charge a step cost with a neutral ('mempty') remainder estimate, i.e.
-- without any heuristic guidance.
cost' :: Monoid c => c -> SearchT c m ()
cost' = (`cost` mempty)
-- | 'empty' abandons the current branch; '<|>' forks the computation into two
-- branches, both of which the search explores.
instance Alternative (SearchT c m) where
  empty = SearchT $ liftF Abandon
  (<|>) (SearchT l) (SearchT r) = SearchT $ wrap $ Branch l r
-- | A search-tree node in the unfolded (non-Church) 'FreeT' representation.
type Node c m a = FreeT (SearchF c) m a

-- | Everything found by expanding one node. Both fields are difference lists
-- ('Endo') so that combining sibling branches appends cheaply.
data StepResults c m a = StepResults
  { srNextSteps :: !(Endo [(c, c, Node c m a)])
    -- ^ Pending continuations reached through a 'Cost' layer, as
    -- (step cost, remainder estimate, continuation).
  , srResults :: !(Endo [a])
    -- ^ Values reached directly, without crossing a 'Cost' layer.
  }
-- | Combine the findings of two sibling branches. Because 'Endo' composes as
-- @f . g@, the left operand's next steps and results come before the right
-- operand's, preserving left-to-right branch order.
--
-- Since GHC 8.4 (base 4.11, the Semigroup-Monoid proposal) 'Semigroup' is a
-- superclass of 'Monoid', so this instance is required for the module to
-- compile on current compilers.
instance Semigroup (StepResults c m a) where
  StepResults leftSteps leftResults <> StepResults rightSteps rightResults =
    StepResults (leftSteps <> rightSteps) (leftResults <> rightResults)

-- | No next steps and no results. 'mappend' defaults to '(<>)'.
instance Monoid (StepResults c m a) where
  mempty = StepResults mempty mempty
-- | Run the base-monad effects of a node and expand its 'Branch' layers
-- (left branch first) until every path reaches 'Pure', 'Abandon' or 'Cost'.
-- 'Pure' values become results, 'Abandon' contributes nothing, and a 'Cost'
-- layer is recorded as a next step; its continuation is not run here.
findNeighbors :: Monad m => FreeT (SearchF c) m a -> m (StepResults c m a)
findNeighbors node = runFreeT node >>= \case
  Pure value ->
    pure (StepResults mempty (Endo (value :)))
  Free Abandon ->
    pure mempty
  Free (Cost stepCost estimate rest) ->
    pure (StepResults (Endo ((stepCost, estimate, rest) :)) mempty)
  Free (Branch l r) ->
    liftA2 mappend (findNeighbors l) (findNeighbors r)
-- | Best-first search, streamed as a conduit 'Source' of @(cost, result)@.
--
-- Pending nodes live in a min-priority queue. Each node is keyed by its
-- running cost plus the remainder estimate of the 'Cost' layer that produced
-- it (the root has key 'mempty'), and stores its running cost. Each step pops
-- the minimum, expands it with 'findNeighbors', yields its results at the
-- node's running cost, and pushes its successors. The stream ends when the
-- queue is empty.
runSearchTSource :: (Ord c, Monoid c, Monad m) => SearchT c m a -> Source m (c, a)
runSearchTSource (SearchT church) = loop (PQ.singleton mempty (mempty, fromFT church))
  where
    -- Stop once the queue drains; otherwise expand the cheapest node.
    -- The recursive call in 'expand' stays in tail position.
    loop queue = maybe (pure ()) expand (PQ.minView queue)

    expand ((runningCost, node), remaining) = do
      StepResults stepsEndo resultsEndo <- lift (findNeighbors node)
      mapM_ (\result -> yield (runningCost, result)) (appEndo resultsEndo [])
      loop (foldl' (enqueue runningCost) remaining (appEndo stepsEndo []))

    -- Priority is the estimated total; the payload keeps the real cost so far.
    enqueue runningCost queue (stepCost, estimate, next) =
      let total = runningCost <> stepCost
      in PQ.insert (total <> estimate) (total, next) queue
-- Run a pure search, streaming results in increasing order of cost.
-- | Run an effect-free search and collect its @(cost, result)@ pairs in the
-- order 'runSearchTSource' emits them. That order is increasing cost when costs
-- are non-negative and estimates never overstate the remaining cost.
-- NOTE(review): confirm that this is the intended contract.
runSearch :: (Ord c, Monoid c) => Search c a -> [(c, a)]
runSearch search = runIdentity (sourceToList (runSearchTSource search))
-------------------------------------------- | |
-- Example: results are returned in increasing order of total cost.
-- | Seven alternative branches, each paying a sequence of step costs before
-- returning a value; the search should emit them cheapest first.
test :: Monad m => SearchT (Sum Int) m Int
test = asum (fmap branch table)
  where
    -- (step costs paid in sequence, value returned), in branch order.
    table =
      [ ([1], 1)
      , ([3, 2], 5)
      , ([2], 2)
      , ([6], 6)
      , ([4], 4)
      , ([0], 0)
      , ([1], 1)
      ]
    branch (steps, value) = traverse_ (cost' . Sum) steps *> pure value
-- `runSearchTSource loopTest $$ awaitForever (liftIO . print)` prints "thunk" once for every solution demanded
-- | An infinite search in 'IO'. The left branch prints "thunk" and succeeds
-- with @()@; the right branch pays one unit of cost and recurses.
loopTest :: SearchT (Sum Int) IO ()
loopTest = announce <|> (cost' (Sum 1) >> loopTest)
  where
    announce = liftIO (putStrLn "thunk")
-- `runSearch loopTestPure` streams results lazily | |
-- | An infinite effect-free search. Each level either succeeds with "thunk"
-- or pays one unit of cost and descends a level.
loopTestPure :: Monad m => SearchT (Sum Int) m String
loopTestPure = go
  where
    go = pure "thunk" <|> (cost' (Sum 1) *> go)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment