Last active
February 6, 2018 10:42
-
-
Save abhin4v/1f3ba367179e9f06fb6ef1552bf02ba6 to your computer and use it in GitHub Desktop.
Sudoku solver in Haskell
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import Control.Arrow ((&&&)) | |
import Control.Applicative ((<|>)) | |
import Control.DeepSeq (NFData(..)) | |
import Control.Monad (foldM) | |
import Control.Parallel.Strategies (withStrategy, rdeepseq, parBuffer) | |
import Data.Bits | |
import Data.Char (isDigit, digitToInt) | |
import Data.Function (on) | |
import Data.List (nub, foldl', group, sort) | |
import Data.Maybe (isJust) | |
import qualified Data.Set as S | |
import qualified Data.Vector as V | |
import qualified Data.Vector.Mutable as MV | |
import Data.Vector ((!)) | |
import Data.Word (Word16) | |
-- | Run a monadic step repeatedly, feeding each result back in, until the
-- step returns a value equal to its input; that value is the result.
-- A failing step (e.g. 'Nothing' in 'Maybe') stops the iteration.
fixM :: (Eq t, Monad m) => (t -> m t) -> t -> m t
fixM f = go
  where
    go x = do
      x' <- f x
      if x' == x then return x else go x'
-- | The nine single-digit masks: bit @n@ stands for digit @n@ (1..9).
digits :: S.Set Word16
digits = S.fromList (map bit [1 .. 9])
-- | OR every mask in the list into the starting mask.
setBits :: Word16 -> [Word16] -> Word16
setBits acc masks = acc .|. foldl' (.|.) zeroBits masks
-- | One square of the grid. Digits are stored as a bitmask in which bit @n@
-- (for @n@ in 1..9) represents digit @n@.
data Cell
  = Fixed Word16     -- ^ settled cell; the mask holds a single digit bit
  | Possible Word16  -- ^ open cell; the mask holds the remaining candidates
  deriving (Show, Eq)
-- | Fully evaluating a 'Cell' means forcing its bitmask.
instance NFData Cell where
  rnf cell = case cell of
    Fixed mask -> rnf mask
    Possible mask -> rnf mask
-- | The whole grid as a flat vector of cells. 'readBoard' builds it from
-- exactly 81 characters; 'fromXY' maps (row, column) to index row * 9 + column.
type Board = V.Vector Cell

-- | Indices of the cells forming one unit (a row, a column or a 3x3 block).
type CellIxs = [Int]
-- | True for a settled cell, False for a cell that still has candidates.
isFixed :: Cell -> Bool
isFixed cell = case cell of
  Fixed _ -> True
  Possible _ -> False
-- | An open cell in which every digit 1..9 is still a candidate
-- (bits 1 through 9 set, i.e. the mask 1022).
emptyCell :: Cell
emptyCell = Possible (foldl' (.|.) zeroBits [bit d | d <- [1 .. 9 :: Int]])
-- | Parse one puzzle character: '.' is an empty cell and '1'..'9' is a given
-- digit. Anything else, including '0', is rejected with 'Nothing'.
readCell :: Char -> Maybe Cell
readCell '.' = Just emptyCell
readCell c
  | c >= '1' && c <= '9' = Just (Fixed (bit (digitToInt c)))
  | otherwise = Nothing
-- | Wrap a candidate mask in the appropriate constructor. A mask with exactly
-- one digit bit becomes 'Fixed'; any other mask (including 0, which
-- 'isInvalidBoard' rejects) stays 'Possible'.
mkCell :: Word16 -> Cell
mkCell xs = (if S.member xs digits then Fixed else Possible) xs
-- | Parse an 81-character puzzle line (row by row) into a board. Returns
-- 'Nothing' on a wrong length or on any character 'readCell' rejects.
readBoard :: String -> Maybe Board
readBoard s
  | length s == 81 = V.fromList <$> traverse readCell s
  | otherwise = Nothing
-- | Flat board index of the cell at (row, column), both 0-based.
fromXY :: (Int, Int) -> Int
fromXY (row, col) = row * 9 + col
-- | Cell indices of the nine rows, nine columns and nine 3x3 blocks.
-- Block @n@ is the block in block-row @n `div` 3@ and block-column
-- @n `mod` 3@; its cells are listed row by row.
allRowIxs, allColIxs, allBlockIxs :: [CellIxs]
allRowIxs = [[row * 9 + col | col <- [0 .. 8]] | row <- [0 .. 8]]
allColIxs = [[row * 9 + col | row <- [0 .. 8]] | col <- [0 .. 8]]
allBlockIxs =
  [ [(3 * br + dr) * 9 + (3 * bc + dc) | dr <- [0 .. 2], dc <- [0 .. 2]]
  | n <- [0 .. 8]
  , let (br, bc) = n `divMod` 3
  ]
-- | Render the board as nine lines of space-separated cells. A fixed cell
-- shows its digit and an open cell shows ".".
showBoard :: Board -> String
showBoard b = unlines (map renderRow allRowIxs)
  where
    renderRow ixs = unwords [renderCell (b ! i) | i <- ixs]
    renderCell cell = case cell of
      Fixed x -> show (countTrailingZeros x)
      Possible _ -> "."
-- | A copy of the board with the cell at index @i@ replaced by @c@.
replaceCell :: Int -> Cell -> Board -> Board
replaceCell i c b = b V.// [(i, c)]
-- | Remove from an open cell every digit already claimed elsewhere in its unit.
--
-- A digit counts as claimed in two cases:
--
-- * it is the value of a 'Fixed' cell in @row@;
-- * it belongs to a "naked set": a mask with k < 4 candidates that appears
--   on exactly k open cells of @row@ (identical masks only).
--
-- Cells whose own mask is such a naked set are returned untouched. Returns
-- 'Nothing' when no candidate would remain. Otherwise returns the new cell,
-- together with whether any candidate was removed. 'Fixed' cells pass
-- through unchanged.
pruneCellFixeds :: [Cell] -> Cell -> Maybe (Bool, Cell)
pruneCellFixeds row cell@(Possible xs)
  | xs `elem` nakedSets = Just (False, cell)
  | remaining == 0 = Nothing
  | otherwise = Just (remaining /= xs, mkCell remaining)
  where
    remaining = xs .&. complement taken
    taken = setBits fixedMask nakedSets
    fixedMask = setBits zeroBits [v | Fixed v <- row]
    -- group never yields an empty list, so the (mask : _) match always succeeds
    nakedSets =
      [ mask
      | grp@(mask : _) <- group (sort [m | Possible m <- row, popCount m < 4])
      , popCount mask == length grp
      ]
pruneCellFixeds _ cell = Just (False, cell)
-- | Fix an open cell to a digit that no other open cell of its unit can take.
--
-- A digit is "unique" when exactly one 'Possible' mask in @row@ contains it.
-- If the cell's candidates include exactly one unique digit, the cell becomes
-- 'Fixed' on that digit and is reported as changed. Otherwise, including for
-- 'Fixed' cells, the cell is returned unchanged. The result is always 'Just'.
pruneCellUniqs :: [Cell] -> Cell -> Maybe (Bool, Cell)
pruneCellUniqs row cell
  | Possible xs <- cell, popCount (xs .&. uniqs) == 1 = Just (True, Fixed (xs .&. uniqs))
  | otherwise = Just (False, cell)
  where
    openMasks = [xs | Possible xs <- row]
    -- the list pattern stops after inspecting at most two matches
    inExactlyOne n = case filter (`testBit` n) openMasks of
      [_] -> True
      _ -> False
    uniqs = foldl' setBit zeroBits (filter inExactlyOne [1 .. 9])
-- | Prune every cell of one unit (a row, column or block) given by @cellIxs@.
--
-- Each cell is pruned against the unit's fixed values and naked sets
-- ('pruneCellFixeds'). It is then checked for a digit that only it can take
-- ('pruneCellUniqs'). Both steps look at the unit as it was on entry, so
-- cells of the same unit do not see each other's updates within one call.
-- Returns 'Nothing' as soon as any cell is left without candidates.
pruneCells :: Board -> CellIxs -> Maybe Board
pruneCells b cellIxs = do
  pruned <- mapM pruneAt cellIxs
  -- Apply all changes in one bulk update rather than copying the whole
  -- board once per changed cell.
  let updates = [(i, c) | (i, True, c) <- pruned]
  return $ if null updates then b else b V.// updates
  where
    row = map (b !) cellIxs
    pruneAt i = do
      (changed, c) <- pruneCell (b ! i)
      return (i, changed, c)
    pruneCell cell = do
      (changed, cell') <- pruneCellFixeds row cell
      (changed', cell'') <- pruneCellUniqs row cell'
      return (changed || changed', cell'')
-- | Sweep all rows, then all columns, then all blocks with 'pruneCells'.
-- Sweeps repeat until one leaves the board unchanged. 'Nothing' means some
-- unit became contradictory.
pruneBoard :: Board -> Maybe Board
pruneBoard = fixM sweep
  where
    sweep board = foldM pruneCells board (allRowIxs ++ allColIxs ++ allBlockIxs)
-- | A board is invalid when some row, column or block either repeats a fixed
-- digit or contains an open cell with no candidates left.
isInvalidBoard :: Board -> Bool
isInvalidBoard b = not (all unitIsValid (allRowIxs ++ allColIxs ++ allBlockIxs))
  where
    unitIsValid ixs =
      let cells = map (b !) ixs
          fixedVals = [v | Fixed v <- cells]
          noDuplicateFixeds = S.size (S.fromList fixedVals) == length fixedVals
          noEmptyCell = Possible 0 `notElem` cells
       in noDuplicateFixeds && noEmptyCell
-- | True when exactly 81 cells of the board are fixed.
isFinishedBoard :: Board -> Bool
isFinishedBoard b = V.foldl' countFixed (0 :: Int) b == 81
  where
    countFixed n cell = if isFixed cell then n + 1 else n
-- | Solve a board by depth-first search combined with constraint propagation.
--
-- Returns 'Nothing' for a contradictory board or one without a completion.
-- Otherwise returns the first solution found. When the board is neither
-- invalid nor finished, the search picks the open cell with the fewest
-- candidates and branches on it:
--
-- * first, the cell is fixed to its lowest candidate digit;
-- * if that fails, the cell keeps all other candidates instead.
--
-- Each branch goes through 'pruneBoard' before recursing.
solve :: Board -> Maybe Board
solve b
  | isInvalidBoard b = Nothing
  | isFinishedBoard b = Just b
  | otherwise = explore withGuess <|> explore withoutGuess
  where
    explore board = pruneBoard board >>= solve
    -- Open cells with their indices, in board order. Non-empty here because
    -- the board is not finished.
    openCells = V.fromList [(i, xs) | (i, Possible xs) <- V.toList (V.indexed b)]
    (ix, candidates) = V.minimumBy (compare `on` (popCount . snd)) openCells
    lowest = countTrailingZeros candidates
    withGuess = replaceCell ix (Fixed (bit lowest)) b
    withoutGuess = replaceCell ix (mkCell (clearBit candidates lowest)) b
-- | Read one puzzle per line from stdin and solve the puzzles in parallel.
-- Prints a single summary line of the form "<solved>/<total> solved".
main :: IO ()
main = do
  puzzles <- lines <$> getContents
  let results = inParallel solveLine puzzles
      solvedCount = length (filter isJust results)
  putStrLn $ show solvedCount ++ "/" ++ show (length puzzles) ++ " solved"
  where
    solveLine line = readBoard line >>= pruneBoard >>= solve
    -- buffer size handed to parBuffer: how many elements it evaluates ahead
    -- in parallel
    bufferSize = 1000
    inParallel f = withStrategy (parBuffer bufferSize rdeepseq) . map f
-- example run:
-- $ stack install
-- $ echo "......52..8.4......3...9...5.1...6..2..7........3.....6...1..........7.4.......3." | sudoku
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment