Created
February 15, 2023 09:49
-
-
Save 123jimin/f0c20d53814dc7e25de3d578f67fa579 to your computer and use it in GitHub Desktop.
Snake Puzzle Solver
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from ortools.sat.python import cp_model | |
class SolverCallback(cp_model.CpSolverSolutionCallback):
    """Solution callback that counts, stores and prints the solutions found.

    ``vars`` is the object whose ``var_grid`` attribute (rows of cell
    variables) is read each time a solution is reported.
    """

    def __init__(self, vars):
        super().__init__()
        self.vars = vars
        # Every reported solution is counted, even those not kept or shown.
        self.num_solutions = 0
        # Rendered grids of the solutions kept so far.
        self.example_solutions = []
        # Caps on how many solutions are stored / printed; a cap that is
        # not positive disables the corresponding limit.
        self.max_keep_solutions = 100
        self.max_show_solutions = 100

    def OnSolutionCallback(self):
        """Record the current solution and print it while under the show cap."""
        self.num_solutions += 1

        kept = self.example_solutions
        keep_cap = self.max_keep_solutions
        if keep_cap > 0 and len(kept) >= keep_cap:
            return

        # One text line per grid row: '#' for a non-zero cell, '_' otherwise.
        rendered_rows = []
        for row in self.vars.var_grid:
            symbols = ("#" if self.Value(cell) != 0 else "_" for cell in row)
            rendered_rows.append(" ".join(symbols))
        solution_str = "\n".join(rendered_rows)
        kept.append(solution_str)

        kept_count = len(kept)
        show_cap = self.max_show_solutions
        show_is_capped = show_cap > 0
        if show_is_capped and kept_count > show_cap:
            return

        print(f"Solution #{kept_count} after {self.NumBranches()} branches:")
        print(solution_str)

        # Announce once, right after the last solution that gets printed.
        if show_is_capped and kept_count >= show_cap:
            print("(further solutions will be not shown)")
class Solver:
    """Build and solve a CP-SAT model of a snake puzzle with line clues.

    ``rows`` and ``cols`` are sequences of integer clues; a positive clue
    fixes how many cells of that row/column are filled, any other value
    leaves the line unconstrained.  The grid has ``len(rows)`` rows and
    ``len(cols)`` columns.

    The snake is modelled with one Boolean per cell (filled or not) plus an
    integer "index" per cell: 0 for empty cells, 1 for the start cell and
    a value greater than 1 for every other snake cell.
    """

    def __init__(self, rows, cols):
        self.model = cp_model.CpModel()
        self.rows = rows
        self.cols = cols
        # True once the variables and constraints were added to self.model.
        self._built = False

    def __repr__(self):
        return f"Solver(rows={repr(self.rows)}, cols={repr(self.cols)})"

    def _init_vars(self):
        """Create all variables and constraints on ``self.model``.

        The call is idempotent: building twice would add a second,
        independent copy of the whole puzzle to the same model, so any
        call after the first one is a no-op.
        """
        if getattr(self, "_built", False):
            return

        R = len(self.rows)
        C = len(self.cols)
        model = self.model

        # Cell variables
        self.var_grid = [
            [model.NewBoolVar(f"C[{r},{c}]") for c in range(C)]
            for r in range(R)
        ]
        self.var_ind = [
            [model.NewIntVar(0, R * C, f"I[{r},{c}]") for c in range(C)]
            for r in range(R)
        ]
        # Flattened positions (r * C + c) of the two snake ends.
        self.var_start = model.NewIntVar(0, R * C - 1, "START")
        self.var_end = model.NewIntVar(0, R * C - 1, "END")

        # Condition: snake
        # The start lies before the end in row-major order; presumably this
        # keeps each snake from being counted once per direction.
        model.Add(self.var_start < self.var_end)
        for r in range(R):
            for c in range(C):
                self._add_cell_constraints(r, c)

        # Condition: sum
        self._add_clue_constraints()

        self._built = True

    def _add_cell_constraints(self, r, c):
        """Add the snake-shape constraints for the cell at row r, column c."""
        model = self.model
        R = len(self.rows)
        C = len(self.cols)
        var_grid = self.var_grid
        var_ind = self.var_ind
        var_start = self.var_start
        var_end = self.var_end

        ind = r * C + c
        grid_cell = var_grid[r][c]
        grid_ind = var_ind[r][c]

        is_start = model.NewBoolVar(f"S[{r},{c}]")
        is_middle = model.NewBoolVar(f"M[{r},{c}]")
        is_end = model.NewBoolVar(f"E[{r},{c}]")
        is_none = grid_cell.Not()

        # Exactly one of four can be true
        model.Add(is_start + is_middle + is_end + is_none == 1)

        # enforce grid_ind
        model.Add(grid_ind == 0).OnlyEnforceIf(is_none)
        model.Add(grid_ind == 1).OnlyEnforceIf(is_start)
        model.Add(grid_ind > 1).OnlyEnforceIf(is_middle)
        model.Add(grid_ind > 1).OnlyEnforceIf(is_end)

        # set is_start: true exactly when this cell is the START position
        model.Add(ind == var_start).OnlyEnforceIf(is_start)
        model.Add(ind != var_start).OnlyEnforceIf(is_start.Not())

        # set is_end: true exactly when this cell is the END position
        model.Add(ind == var_end).OnlyEnforceIf(is_end)
        model.Add(ind != var_end).OnlyEnforceIf(is_end.Not())

        neighbor_cells = []
        neighbor_inds = []
        neighbor_prev_inds = []
        for (dr, dc) in ((-1, 0), (+1, 0), (0, -1), (0, +1)):
            nr = r + dr
            nc = c + dc
            if 0 <= nr < R and 0 <= nc < C:
                neighbor_cells.append(var_grid[nr][nc])
                neighbor_ind = var_ind[nr][nc]
                neighbor_inds.append(neighbor_ind)

                # True exactly when the neighbor's index is this cell's
                # index minus one.
                is_prev_ind = model.NewBoolVar("")
                model.Add(neighbor_ind + 1 == grid_ind).OnlyEnforceIf(
                    is_prev_ind)
                model.Add(neighbor_ind + 1 != grid_ind).OnlyEnforceIf(
                    is_prev_ind.Not())
                neighbor_prev_inds.append(is_prev_ind)

        # enforce neighbor_count: ends touch one filled cell, middles two
        filled_neighbors = sum(neighbor_cells)
        model.Add(filled_neighbors == 1).OnlyEnforceIf(is_start)
        model.Add(filled_neighbors == 2).OnlyEnforceIf(is_middle)
        model.Add(filled_neighbors == 1).OnlyEnforceIf(is_end)

        # enforce middle: empty neighbors contribute index 0, so the two
        # filled neighbors must sum to twice this cell's index; together
        # with "some neighbor is index - 1" the other one is index + 1.
        model.Add(sum(neighbor_inds) == 2 * grid_ind).OnlyEnforceIf(is_middle)
        model.AddBoolOr(neighbor_prev_inds).OnlyEnforceIf(is_middle)
        # The end cell only needs a predecessor.
        model.AddBoolOr(neighbor_prev_inds).OnlyEnforceIf(is_end)

    def _add_clue_constraints(self):
        """Fix the filled-cell count of every row/column with a positive clue."""
        model = self.model
        var_grid = self.var_grid
        R = len(self.rows)
        C = len(self.cols)

        for r in range(R):
            if self.rows[r] > 0:
                model.Add(sum(var_grid[r]) == self.rows[r])

        for c in range(C):
            if self.cols[c] > 0:
                model.Add(
                    sum(var_grid[r][c] for r in range(R)) == self.cols[c])

    def solve(self):
        """Enumerate all solutions, printing them and search statistics.

        Solutions are reported through a ``SolverCallback``.  The status
        name is printed only when the search ends in a state other than
        OPTIMAL or FEASIBLE.  Returns ``None``.
        """
        self._init_vars()

        solver = cp_model.CpSolver()
        callback = SolverCallback(self)
        status = solver.SearchForAllSolutions(self.model, callback)

        print(f"{solver.NumBooleans()} booleans, {solver.NumBranches()} branches, {solver.NumConflicts()} conflicts")
        print(f"{callback.num_solutions} solutions (time taken: wall {solver.WallTime():.3f}s, user {solver.UserTime():.3f}s)")

        if status not in (cp_model.OPTIMAL, cp_model.FEASIBLE):
            print(solver.StatusName(status))
if __name__ == "__main__":
    # Each prompt expects whitespace-separated integers on a single line.
    rows = [
        int(token)
        for token in input("Clues on left, from top to bottom, with 0 for empty: ").split()
    ]
    cols = [
        int(token)
        for token in input("Clues on top, from left to right, with 0 for empty: ").split()
    ]

    solver = Solver(rows, cols)
    print(solver)
    solver.solve()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment