Created
October 13, 2015 14:26
-
-
Save LukaHorvat/c13f73f4c91bacf5072b to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{-# LANGUAGE MultiWayIf, MultiParamTypeClasses, FunctionalDependencies #-} | |
module AStar where | |
import Control.Monad.State | |
import Data.Set (Set) | |
import qualified Data.Set as Set | |
import PriorityQueue (PriorityQueue) | |
import qualified PriorityQueue as Prio | |
-- | A state space that the best-first search in this module can explore.
--
-- @s@ is the state type and @m@ the move type (determined by @s@). 'Ord' is
-- required because expanded states are remembered in a 'Set'.
class Ord s => AStarState s m | s -> m where
  -- | Successor states, each paired with the move that produces it.
  branch :: s -> [(m, s)]
  -- | Filter applied to successors; only states for which this is 'True'
  -- become children in the search tree.
  valid :: s -> Bool
  -- | 'True' when the state is a solution and the search may stop.
  goal :: s -> Bool
  -- | Heuristic estimate used (together with path length) to rank states;
  -- lower values are explored first.
  dist :: s -> Double
-- | A node of the (lazily built, possibly infinite) search tree.
data SearchTree s m
  = SearchNode
      s                -- ^ the state at this node
      [m]              -- ^ moves taken to reach it, most recent first
      [SearchTree s m] -- ^ child nodes, one per valid successor
-- | A trivial heuristic that ignores its argument and carries no information.
noHeuristic :: s -> ()
noHeuristic _ = ()
-- | Build the search tree rooted at @initial@.
--
-- @path@ holds the moves already taken to reach @initial@, most recent
-- first; every child conses its own move onto the front. Only successors
-- accepted by 'valid' become children, in the order 'branch' lists them.
fromInitialWithPath :: AStarState s m => [m] -> s -> SearchTree s m
fromInitialWithPath path initial =
  SearchNode initial path
    [ fromInitialWithPath (move : path) next
    | (move, next) <- branch initial
    , valid next
    ]
-- | Search tree rooted at @initial@, starting with an empty move history.
fromInitial :: AStarState s m => s -> SearchTree s m
fromInitial initial = fromInitialWithPath [] initial
-- | State threaded through the search loop by the 'State' monad.
data TraverseState s m = TraverseState
  { visited :: Set s
    -- ^ States that have already been expanded.
  , queue   :: PriorityQueue Double (SearchTree s m)
    -- ^ Frontier of nodes waiting to be examined, smallest priority first.
  }
-- | Priority of a search node; the queue pops the smallest value first.
--
-- The value is the heuristic 'dist' plus /half/ the number of moves taken so
-- far, plus a constant 1 that shifts every priority equally. Weighting the
-- path length by 1/2 biases the search toward the heuristic, so the first
-- goal reached is not guaranteed to be a shortest solution.
nodePriority :: AStarState s m => SearchTree s m -> Double
nodePriority (SearchNode s path _) = estimate + travelled + 1
  where
    estimate  = dist s
    travelled = fromIntegral (length path) / 2
-- | Search from @initial@ and return the moves leading to a goal state, in
-- the order they must be applied, or 'Nothing' if every reachable state is
-- exhausted without finding a goal.
--
-- Search nodes store their path most-recent-move-first (each child conses
-- its move onto its parent's path), so the path is reversed here before it
-- is returned. Because of the weighting in 'nodePriority', the result is not
-- guaranteed to be a shortest solution.
solve :: AStarState s m => s -> Maybe [m]
solve initial = fmap reverse (evalState search (TraverseState Set.empty initialNode))
  where
    initialNode = Prio.singleton nodePriority (fromInitial initial)
-- | Main search loop: finish with 'Nothing' once the frontier is empty,
-- otherwise let 'tryNext' examine the best node (it recurses back here).
search :: AStarState s m => State (TraverseState s m) (Maybe [m])
search = do
  exhausted <- gets (Prio.null . queue)
  if exhausted
    then return Nothing
    else tryNext
-- | Pop the node with the smallest priority and act on it:
--
-- * if its state is a goal, return its path (most recent move first);
-- * if its state has not been expanded yet, mark it visited, enqueue all of
--   its children and continue with 'search';
-- * otherwise (a duplicate of an already-expanded state) drop it and
--   continue with 'search'.
--
-- An empty queue yields 'Nothing' instead of reaching the partial
-- 'Prio.minView', so this is safe to call directly, not only from 'search'.
tryNext :: AStarState s m => State (TraverseState s m) (Maybe [m])
tryNext = do
  q <- gets queue
  if Prio.null q
    then return Nothing
    else do
      vis <- gets visited
      let ((_, SearchNode current path branches), q') = Prio.minView q
      modify (\s -> s { queue = q' })
      if | goal current -> return (Just path)
         | Set.notMember current vis -> do
             modify (\s -> s { visited = Set.insert current vis })
             -- Children are enqueued without a visited check; duplicates are
             -- discarded by the 'otherwise' branch when they are popped.
             forM_ branches $ \b ->
               modify $ \s -> s { queue = Prio.insert b (queue s) }
             search
         | otherwise -> search
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
module PriorityQueue where | |
import Data.Map (Map) | |
import qualified Data.Map as Map | |
-- | A priority queue of @v@ values ranked by a priority @k@ (smallest first).
--
-- Entries are stored under the key @(priority, insertion counter)@, so values
-- with equal priority never collide and come out in insertion order.
data PriorityQueue k v
  = PriorityQueue
      (v -> k)         -- ^ computes the priority of a value
      Int              -- ^ counter assigned to the next inserted value
      (Map (k, Int) v) -- ^ entries keyed by (priority, counter)
-- | Debug rendering: a header line, then one tab-indented @value -> priority@
-- line per entry in ascending (pop) order.
--
-- Rows are joined with 'concatMap' (right-nested appends). A left-
-- accumulating fold would re-copy the growing string for every entry and
-- become quadratic in the queue size.
instance (Ord k, Show k, Show v) => Show (PriorityQueue k v) where
  show (PriorityQueue _ _ m) = "PriorityQueue\n" ++ concatMap row (Map.toAscList m)
    where
      row ((k, _), v) = "\t" ++ show v ++ " -> " ++ show k ++ "\n"
-- | An empty queue that will rank values with @priorityOf@.
empty :: Ord k => (v -> k) -> PriorityQueue k v
empty priorityOf = PriorityQueue priorityOf 0 Map.empty
-- | Add a value, ranked by the queue's priority function. The current
-- insertion counter becomes part of the key (so equal priorities stay in
-- insertion order) and is then incremented.
insert :: Ord k => v -> PriorityQueue k v -> PriorityQueue k v
insert v (PriorityQueue priorityOf counter entries) =
  PriorityQueue priorityOf (counter + 1) (Map.insert key v entries)
  where
    key = (priorityOf v, counter)
-- | The entry with the smallest priority, as @(priority, value)@.
--
-- Partial: 'Map.findMin' fails on an empty queue. The pair is built lazily,
-- so that failure only surfaces when one of its components is demanded.
findMin :: Ord k => PriorityQueue k v -> (k, v)
findMin (PriorityQueue _ _ entries) = (priority, value)
  where
    ((priority, _), value) = Map.findMin entries
-- | Remove the entry with the smallest priority (an empty queue stays empty).
-- The insertion counter is left untouched, so later inserts still get fresh
-- keys.
deleteMin :: Ord k => PriorityQueue k v -> PriorityQueue k v
deleteMin (PriorityQueue priorityOf counter entries) =
  PriorityQueue priorityOf counter (Map.deleteMin entries)
-- | Split off the smallest entry, i.e. @(findMin q, deleteMin q)@. The
-- result shares 'findMin's partiality: on an empty queue, demanding the
-- entry fails.
minView :: Ord k => PriorityQueue k v -> ((k, v), PriorityQueue k v)
minView = (,) <$> findMin <*> deleteMin
-- | Number of entries currently held in the queue.
size :: Ord k => PriorityQueue k v -> Int
size pq = case pq of
  PriorityQueue _ _ entries -> Map.size entries
-- | 'True' when the queue holds no entries.
null :: Ord k => PriorityQueue k v -> Bool
null pq = case pq of
  PriorityQueue _ _ entries -> Map.null entries
-- | A queue containing just @v@, ranked with @priorityOf@.
singleton :: Ord k => (v -> k) -> v -> PriorityQueue k v
singleton priorityOf v = insert v (empty priorityOf)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{-# LANGUAGE FlexibleInstances, MultiParamTypeClasses #-} | |
module SliderPuzzle where | |
import Prelude hiding (Either(..)) | |
import Data.Vector (Vector, (//), (!?)) | |
import qualified Data.Vector as Vec | |
import Control.Arrow (second, (***)) | |
import AStar | |
-- | Board dimensions, in tiles.
boardWidth, boardHeight :: Int
boardWidth = 4
boardHeight = 4
-- | Direction in which the tile next to the blank slides into the blank.
-- Constructor order is significant for the derived 'Ord' instance.
data Move
  = Left
  | Right
  | Up
  | Down
  deriving (Eq, Ord, Show, Read)
-- | Tile values in row-major order (cell @(x, y)@ lives at index
-- @y * boardWidth + x@); the value @0@ marks the blank.
newtype Board = Board (Vector Int)
  deriving (Eq, Ord, Show, Read)
-- | Row-major vector index of column @x@, row @y@. No bounds checking.
idx :: (Int, Int) -> Int
idx (x, y) = x + boardWidth * y
-- | Inverse of 'idx' for non-negative indices: the @(column, row)@ of a
-- vector index.
pos :: Int -> (Int, Int)
pos i = (i `mod` boardWidth, i `div` boardWidth)
-- | Tile at column @x@, row @y@, or 'Nothing' when the cell is off the board.
--
-- The explicit range check is required: an out-of-range column can still map
-- through 'idx' to a valid index in a neighbouring row.
getAt :: (Int, Int) -> Vector Int -> Maybe Int
getAt a@(x, y) vec
  | onBoard   = vec !? idx a
  | otherwise = Nothing
  where
    onBoard = x >= 0 && y >= 0 && x < boardWidth && y < boardHeight
-- | Exchange the tiles at two cells. Returns 'Nothing' if either cell is off
-- the board; otherwise the updated vector is forced to WHNF before being
-- wrapped in 'Just'.
swap :: (Int, Int) -> (Int, Int) -> Vector Int -> Maybe (Vector Int)
swap a b vec =
  case (getAt a vec, getAt b vec) of
    (Just atA, Just atB) -> Just $! vec // [(idx a, atB), (idx b, atA)]
    _                    -> Nothing
-- | Manhattan (taxicab) distance between two cells, as a 'Double'.
manh :: (Int, Int) -> (Int, Int) -> Double
manh (x1, y1) (x2, y2) = fromIntegral (dx + dy)
  where
    dx = abs (x1 - x2)
    dy = abs (y1 - y2)
-- | The solved board: tiles @1 .. cells-1@ in row-major order, blank last.
solution :: Board
solution = Board (Vec.generate cells (\i -> (i + 1) `mod` cells))
  where
    -- Index i holds i + 1, wrapping the final cell around to the blank (0).
    cells = boardWidth * boardHeight
-- | Slider-puzzle search space: a move slides a tile adjacent to the blank
-- into the blank.
instance AStarState Board Move where
  -- Successors of a board, one per neighbour of the blank that lies on the
  -- board. A board without a blank (no 0 tile) has no successors, rather
  -- than crashing on a failed pattern match.
  branch (Board vec) = case Vec.findIndex (== 0) vec of
    Nothing   -> []
    Just zero ->
      let (zx, zy) = pos zero
          -- Neighbour offsets from the blank, tagged with the direction that
          -- neighbour moves, e.g. the tile at (+1, 0) slides Left.
          neighborPos = map (second $ (+ zx) *** (+ zy))
            [(Left, (1, 0)), (Up, (0, 1)), (Right, (-1, 0)), (Down, (0, -1))]
          -- Off-board neighbours make 'swap' return Nothing; they are dropped
          -- by the pattern in the comprehension below.
          swaps = map (second $ \p -> swap (zx, zy) p vec) neighborPos
      in [(m, Board v) | (m, Just v) <- swaps]

  -- Every generated board is accepted.
  valid _ = True

  -- Sum over all cells of the Manhattan distance from the cell to the place
  -- its value occupies in 'solution'. The blank (0) is counted too, so a
  -- single move can lower the sum by 2; the estimate can therefore exceed
  -- the true number of remaining moves.
  dist (Board vec) = Vec.sum . Vec.imap penalty $ vec
    where
      penalty i n = pos i `manh` target n
      target 0 = (boardWidth - 1, boardHeight - 1)
      target n = pos (n - 1)

  goal = (== solution)
-- | Run 'solve' on one fixed scrambled board and print the result.
main :: IO ()
main = print (solve start)
  where
    start = Board (Vec.fromList [11,6,8,7,5,15,4,3,9,2,12,13,1,14,10,0])
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment