Created
January 4, 2022 22:04
-
-
Save LeventErkok/4fddbac1c731575b0596c8f4799eecce to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
-- Solution to:
-- https://stackoverflow.com/questions/70565942/how-to-find-3-triangles-passing-through-every-dot-of-a-5x5-grid-in-z3
{-
Sample output (the solver may find a different, equally valid set of
triangles depending on the SBV/solver version). Each triangle is drawn on its
own 5x5 grid, top row first, followed by a blank line:

*Main> main
1....
11...
1.1..
1..1.
11111

.2...
.22..
.2.2.
.2222
.....

..333
...33
....3
.....
.....

-}
{-# LANGUAGE DeriveAnyClass #-} | |
{-# LANGUAGE DeriveDataTypeable #-} | |
{-# LANGUAGE StandaloneDeriving #-} | |
{-# LANGUAGE TemplateHaskell #-} | |
import Control.Monad | |
import Data.Array hiding(inRange) | |
import Data.SBV | |
import Data.SBV.Control | |
-- | Direction of a single step between neighbouring grid nodes: the eight
-- compass directions, plus 'Bad' for "no valid direction" (see 'dir' for the
-- step each constructor denotes).
data Dir
  = E
  | NE
  | N
  | NW
  | W
  | SW
  | S
  | SE
  | Bad

-- Template Haskell: make 'Dir' usable symbolically. The rest of the file
-- relies on the generated 'SDir' type and the 'sE', 'sNE', ..., 'sBad' values.
mkSymbolicEnumeration ''Dir
-- | Three sides, expected to join end to end (enforced by 'triangle').
type Triangle = (Line, Line, Line)

-- | A candidate side of a triangle: each node is paired with a flag that says
-- whether the node actually belongs to the side.
type Line = [(SBool, Node)]

-- | A symbolic grid point @(x, y)@.
type Node = (SInteger, SInteger)
-- | Direction of the unit step from the first node to the second.
--
-- @y@ grows northwards (N is @(x, y+1)@). Returns 'sBad' when the second
-- node is not one of the eight neighbours of the first, including when the
-- two nodes are equal.
dir :: Node -> Node -> SDir
dir (x, y) second = match (x + 1, y    ) sE
                  $ match (x + 1, y + 1) sNE
                  $ match (x    , y + 1) sN
                    -- Previously (x-1, y-1), which duplicated SW, so a real
                    -- north-west step was reported as Bad.
                  $ match (x - 1, y + 1) sNW
                  $ match (x - 1, y    ) sW
                  $ match (x - 1, y - 1) sSW
                  $ match (x    , y - 1) sS
                  $ match (x + 1, y - 1) sSE
                    sBad
  where
    -- Choose @d@ if @p@ is the second node, otherwise fall through.
    match p d rest = ite (p .== second) d rest
-- | The single direction a line travels in, or 'sBad' if it has none.
--
-- The first two nodes must both be active (otherwise 'sBad'). Their step
-- fixes the heading. Each further pair of consecutive active nodes must step
-- in that same heading, or the result is 'sBad'. The walk stops at the first
-- pair that is not fully active and returns the heading.
lineDirection :: Line -> SDir
lineDirection ((aOn, a) : (bOn, b) : more) =
    ite (aOn .&& bOn) (walk ((bOn, b) : more)) sBad
  where
    heading = dir a b

    walk ((pOn, p) : (qOn, q) : others) =
      ite (pOn .&& qOn)
          (ite (heading .== dir p q) (walk ((qOn, q) : others)) sBad)
          heading
    walk _ = heading
lineDirection _ = sBad
-- | Well-formedness of a triangle.
--
-- The last active node of each side must equal the first node of the next
-- side (cyclically). The three sides must each have a direction, those
-- directions must be pairwise distinct, and none may be 'sBad'.
triangle :: Triangle -> SBool
triangle (l1, l2, l3) = joined .&& turns
  where
    joined =     lastActive l1 .== head l2
             .&& lastActive l2 .== head l3
             .&& lastActive l3 .== head l1

    turns = distinct (map lineDirection [l1, l2, l3] ++ [sBad])

    -- Scan from the end of the line for the first node whose flag is set.
    -- If no flag is set, fall back to the very first node.
    lastActive l = case reverse l of
      [] -> error "bad find!"
      r  -> foldr1 (\p@(on, _) rest -> ite on p rest) r
-- | Every point of the @n x n@ grid must be an active node of at least one
-- side of one of the three triangles.
valid :: Integer -> Triangle -> Triangle -> Triangle -> SBool
valid n t1 t2 t3 = sAll isCovered grid
  where
    grid = [(x, y) | x <- [0 .. n - 1], y <- [0 .. n - 1]]

    allNodes = concatMap flatten [t1, t2, t3]

    flatten (a, b, c) = concat [a, b, c]

    isCovered (x, y) = (sTrue, (literal x, literal y)) `sElem` allNodes
-- | Create a fresh symbolic triangle on an @n x n@ grid and constrain it to be
-- well formed (see 'triangle').
--
-- Each side is a list of @n@ candidate nodes. Every node carries an "active"
-- flag and coordinates constrained to @[0, n-1]@. The first two nodes of each
-- side must be active, and the active flags must form a prefix.
mkTriangle :: Integer -> Symbolic Triangle
mkTriangle n = do
    t <- (,,) <$> mkLine <*> mkLine <*> mkLine
    constrain $ triangle t
    pure t
  where
    mkLine :: Symbolic Line
    mkLine = do
      ns <- replicateM (fromInteger n) mkElt
      constrain $ goodLine (map fst ns)
      pure ns

    -- Named to avoid shadowing Data.Array's 'range'.
    coordRange = (0, literal (n - 1))

    mkElt = do
      v <- sBool_
      i <- sInteger_
      j <- sInteger_
      constrain $ inRange i coordRange
      constrain $ inRange j coordRange
      pure (v, (i, j))

    -- A side needs at least two active nodes, and once a node is inactive all
    -- later nodes must be inactive too. A list with fewer than two nodes
    -- (n < 2) can never form a side, so it is rejected rather than crashing
    -- on a non-exhaustive pattern.
    goodLine (a : b : rest) = a .&& b .&& prefixOk rest
    goodLine _              = sFalse

    prefixOk []       = sTrue
    prefixOk (f : fs) = ite f (prefixOk fs) (sAll sNot fs)
-- | Solve the puzzle on an @n x n@ grid and print the three triangles found.
--
-- Each triangle is printed as its own @n x n@ picture, top row (highest @y@)
-- first, with the triangle's number (1, 2 or 3) on the cells it passes
-- through and @.@ elsewhere, followed by a blank line. Calls 'error' if the
-- solver does not answer 'Sat'.
puzzle :: Integer -> IO ()
puzzle n = runSMT $ do
    t1 <- mkTriangle n
    t2 <- mkTriangle n
    t3 <- mkTriangle n
    constrain $ valid n t1 t2 t3
    query $ do
      cs <- checkSat
      case cs of
        Sat -> mapM_ showTriangle (zip [1 ..] [t1, t2, t3])
        _   -> error $ "Solver said: " ++ show cs
  where
    cells = [0 .. n - 1]

    -- A board with no triangle on it; 0 marks an empty cell.
    emptyBoard :: Array (Integer, Integer) Integer
    emptyBoard = array ((0, 0), (n - 1, n - 1))
                       [((i, j), 0) | i <- cells, j <- cells]

    -- Write @m@ on every node of the side's active prefix. Reading stops at
    -- the first inactive node.
    markLine :: Integer
             -> Array (Integer, Integer) Integer
             -> Line
             -> Query (Array (Integer, Integer) Integer)
    markLine _ board [] = pure board
    markLine m board ((b, (x, y)) : rest) = do
      active <- getValue b
      if active
        then do
          xv <- getValue x
          yv <- getValue y
          markLine m (board // [((xv, yv), m)]) rest
        else pure board

    -- Read triangle number @m@ out of the model and print it.
    showTriangle (m, (l1, l2, l3)) = do
      board <- foldM (markLine m) emptyBoard [l1, l2, l3]
      io $ do
        mapM_ (putStrLn . renderRow board) (reverse cells)
        putStrLn ""

    renderRow board y = concatMap (\x -> renderCell (board ! (x, y))) cells

    renderCell :: Integer -> String
    renderCell 0 = "."
    renderCell v = show v
-- | Run the puzzle on the 5x5 grid from the original question.
main :: IO ()
main = puzzle gridSize
  where
    gridSize = 5
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment