Created
November 21, 2024 04:53
-
-
Save theY4Kman/7b83a1eac24bc01bc75ee1a7a97f338f to your computer and use it in GitHub Desktop.
Original, uncleaned, log-spewing solution for Codewars Minesweeper kata: https://www.codewars.com/kata/57ff9d3b8f7dda23130015fa/train/python
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from __future__ import annotations

from collections import defaultdict
from contextlib import contextmanager
from dataclasses import dataclass, field
from typing import Iterable, Self

from preloaded import open
@dataclass | |
class Board: | |
w: int | |
h: int | |
total_mines: int | |
cells: list[list[Cell]] | |
mines_remaining: int = 0 | |
@classmethod | |
def from_map(cls, map: str, total_mines: int) -> Self: | |
lines = map.splitlines() | |
cells = [ | |
[Cell(x, y, val) for x, val in enumerate(line.split())] | |
for y, line in enumerate(lines) | |
] | |
w = len(cells[0]) | |
h = len(cells) | |
return cls(w, h, total_mines, cells) | |
def __post_init__(self): | |
self.mines_remaining = self.total_mines | |
def __str__(self) -> str: | |
return '\n'.join(' '.join(map(str, row)) for row in self.cells) | |
#XXX################################################################# | |
def debug(self) -> str: | |
y_axis_w = len(str(self.h)) | |
x_axis_prefix = f'{"":>{y_axis_w}} ' | |
x_axis_w = len(str(self.w)) | |
x_axis_ticks = [f'{i:>{x_axis_w}}' for i in range(self.w)] | |
lines = [ | |
x_axis_prefix + ' '.join(tick[i] for tick in x_axis_ticks) | |
for i in range(x_axis_w) | |
] | |
lines.append(' ' + '-' * (len(lines[0]) - 3)) | |
lines.extend( | |
f'{n:>{y_axis_w}} | ' + line | |
for n, line in enumerate(str(self).splitlines()) | |
) | |
return '\n'.join(lines) | |
#XXX################################################################# | |
def neighbors(self, cell: Cell) -> Iterable[Cell]: | |
for dx, dy in [ | |
(-1, -1), (0, -1), (1, -1), | |
(-1, 0), (1, 0), | |
(-1, 1), (0, 1), (1, 1) | |
]: | |
x, y = cell.x + dx, cell.y + dy | |
if 0 <= x < self.w and 0 <= y < self.h: | |
yield self.cells[y][x] | |
@dataclass(order=True)
class Cell:
    """One square of the board: its coordinates and its current text value.

    ``val`` is a digit string for an opened cell and ``'x'`` for a flagged
    one; any other value makes ``number`` return ``None``.
    """

    x: int
    y: int
    val: str

    # ---- debug rendering ----

    def __repr__(self) -> str:
        return f'({self.x:>2},{self.y:>2} {self.val!r})'

    # -------------------------

    def __str__(self) -> str:
        return self.val

    def __hash__(self) -> int:
        # Hash on position only, so the hash stays stable while ``val`` changes.
        position = (self.x, self.y)
        return hash(position)

    @property
    def number(self) -> int | None:
        """The value as an int, or ``None`` when it is not numeric."""
        try:
            parsed = int(self.val)
        except ValueError:
            return None
        return parsed

    @number.setter
    def number(self, val: int) -> None:
        self.val = str(val)

    @property
    def is_flagged(self) -> bool:
        """True when the cell carries the flag marker ``'x'``."""
        return self.val == 'x'

    def flag(self) -> None:
        """Mark the cell as flagged."""
        self.val = 'x'
@dataclass
class Observation:
    """A constraint: exactly ``num_mines`` of the cells in ``cells`` are mines.

    ``origin`` is the cell the constraint was created for, or ``None`` when no
    origin was supplied.
    """

    cells: set[Cell]
    num_mines: int
    origin: Cell | None = None

    def __eq__(self, other) -> bool:
        """Compare by cell set and mine count; ``origin`` is ignored."""
        # Defer to the other operand instead of raising AttributeError when
        # compared with something that is not an Observation (e.g. None).
        if not isinstance(other, Observation):
            return NotImplemented
        return (self.cells, self.num_mines) == (other.cells, other.num_mines)

    def __hash__(self) -> int:
        # Identity hash: ``cells`` and ``num_mines`` are mutated in place, so a
        # content-based hash would change while the object sits in a set.
        return hash(id(self))

    def __repr__(self):
        cells = sorted(self.cells)
        return f'{self.__class__.__name__}[{self.num_mines}]{{{", ".join(map(repr, cells))}}}'
@dataclass
class Solver:
    """Constraint-propagation solver working on a ``Board``.

    ``observations`` holds the active constraints, ``obs_by_cell`` indexes
    them by member cell, and ``changed_cells`` collects cells that were
    clicked or flagged but not yet folded back into the constraints.
    """

    board: Board
    observations: set[Observation] = field(default_factory=set)
    obs_by_cell: defaultdict[Cell, set[Observation]] = field(default_factory=lambda: defaultdict(set))
    changed_cells: set[Cell] = field(default_factory=set)

    def __post_init__(self):
        self.init_observations()

    def init_observations(self) -> None:
        """Create one observation for every numbered cell on the board."""
        for row in self.board.cells:
            for cell in row:
                if cell.number is not None:
                    self.add_cell_obs(cell, origin=cell)

    def add_cell_obs(self, cell: Cell, origin: Cell | None = None) -> Observation | None:
        """Register the constraint implied by numbered ``cell``.

        Flagged neighbours reduce the mine count; numbered neighbours are
        ignored; all others become members of the observation.

        Returns the registered observation, or ``None`` when the cell has no
        undetermined neighbours and nothing was registered.
        """
        obs = Observation(cells=set(), num_mines=cell.number, origin=origin)
        for neighbor in self.board.neighbors(cell):
            if neighbor.number is not None:
                continue
            if neighbor.is_flagged:
                obs.num_mines -= 1
            else:
                obs.cells.add(neighbor)

        if not obs.cells:
            return None

        self.observations.add(obs)
        for member in obs.cells:
            self.obs_by_cell[member].add(obs)
        return obs

    def add_observation(self, obs: Observation) -> Observation | None:
        """Register ``obs`` unless one with the same cell set already exists.

        Returns ``obs`` when it was added, ``None`` when it was a duplicate.
        """
        if any(
            obs.cells == ov_obs.cells
            for cell in obs.cells
            for ov_obs in self.obs_by_cell[cell]
        ):
            return None

        for cell in obs.cells:
            self.obs_by_cell[cell].add(obs)
        self.observations.add(obs)
        return obs

    def remove_observation(self, obs: Observation) -> None:
        """Drop ``obs`` from the active set and from the per-cell index."""
        self.observations.discard(obs)
        for cell in obs.cells:
            self.obs_by_cell[cell].discard(obs)

    def apply_cell_changes(self):
        """Fold every pending clicked/flagged cell back into the constraints.

        Each changed cell is removed from the observations that contain it
        (a flag also lowers their mine count). A newly numbered cell gets its
        own observation; a newly flagged cell decrements
        ``board.mines_remaining``.
        """
        changed_cells = tuple(self.changed_cells)
        self.changed_cells.clear()

        for cell in changed_cells:
            for obs in tuple(self.obs_by_cell[cell]):
                if not obs.num_mines or not obs.cells:
                    # NOTE(review): this only unlinks the index entry; when
                    # num_mines is 0 the cell itself stays in obs.cells.
                    self.obs_by_cell[cell].discard(obs)
                elif cell.is_flagged or cell.number is not None:
                    obs.cells.discard(cell)
                    if cell.is_flagged:
                        obs.num_mines -= 1
                    if not obs.cells:
                        self.remove_observation(obs)

            if cell.number is not None:
                self.add_cell_obs(cell, origin=cell)
            elif cell.is_flagged:
                self.board.mines_remaining -= 1

    def clear_inferred_observations(self) -> None:
        """Remove observations that are empty or have no ``origin``."""
        for obs in tuple(self.observations):
            if not obs.cells or not obs.origin:
                self.remove_observation(obs)

    def simplify_observations(self) -> None:
        """Derive new observations from pairs of overlapping ones.

        * If another observation is a strict superset of ``obs``, the
          difference of the two is registered as its own observation.
        * Otherwise, if ``obs`` holds exactly one mine and shares more than
          one cell with the other observation, and the other's remaining
          cells must all be mines to reach its count, those cells are
          registered as an all-mines observation.
        """
        for obs in tuple(self.observations):
            if not obs.cells:
                self.remove_observation(obs)
                continue

            overlapping = {
                ov_obs
                for cell in obs.cells
                for ov_obs in self.obs_by_cell[cell]
                if ov_obs is not obs
            }
            for ov_obs in overlapping:
                if ov_obs.cells > obs.cells:
                    split_obs = self.add_observation(
                        Observation(
                            ov_obs.cells - obs.cells,
                            ov_obs.num_mines - obs.num_mines,
                        )
                    )
                    print('subset ', obs)
                    print('superset', ov_obs)
                    print('split ', split_obs)
                    print()
                else:
                    shared_cells = obs.cells & ov_obs.cells
                    if obs.num_mines == 1 and len(shared_cells) > 1:
                        ov_only_cells = ov_obs.cells - shared_cells
                        occluded_mines = ov_obs.num_mines - obs.num_mines
                        # Skip the degenerate "0 mines among 0 cells" case so
                        # no empty observation gets registered.
                        if ov_only_cells and occluded_mines == len(ov_only_cells):
                            self.add_observation(
                                Observation(ov_only_cells, occluded_mines)
                            )

    def click(self, cell: Cell) -> None:
        """Open ``cell`` through the imported ``open`` and record the result.

        If ``open`` raises, the cell is marked ``'*'``, the board is printed
        and the exception is re-raised.
        """
        try:
            cell.number = open(cell.y, cell.x)
        except Exception:
            cell.val = '*'
            print(self.board, '\n')
            raise
        self.changed_cells.add(cell)

    def right_click(self, cell: Cell) -> None:
        """Flag ``cell`` as a mine and queue it for ``apply_cell_changes``."""
        cell.flag()
        self.changed_cells.add(cell)

    def act(self) -> bool:
        """Act on every fully determined observation.

        Observations whose mine count equals their cell count get all cells
        flagged; observations with zero mines get all cells clicked. Cells
        already pending in ``changed_cells`` are skipped. Handled
        observations are removed.

        Returns True when at least one cell was flagged or clicked.
        """
        did_act = False
        self.actions = set()

        for obs in tuple(self.observations):
            if obs.num_mines == len(obs.cells):
                for cell in tuple(obs.cells):
                    if cell in self.changed_cells:
                        continue
                    with self.debug('flagging', repr(cell), '\n for', obs):
                        self.right_click(cell)
                    did_act = True
                self.remove_observation(obs)

            elif obs.num_mines == 0:
                for cell in tuple(obs.cells):
                    if cell in self.changed_cells:
                        continue
                    with self.debug('clicking', repr(cell), '\n for', obs):
                        self.click(cell)
                    did_act = True
                self.remove_observation(obs)

        return did_act

    # ---- debug helpers ---------------------------------------------------

    def print_debug(self, *print_message) -> None:
        """Print an optional message, the observation count and the board."""
        if print_message:
            print(*print_message)
        print(f'{len(self.observations)=}')
        print(self.board.debug())
        print()

    @contextmanager
    def debug(self, *print_message):
        """Context manager printing the board before and after the body.

        On exit (even when the body raises) the two renderings are printed
        side by side, followed by all current observations.
        """
        if print_message:
            print(*print_message)
        print(f'{len(self.observations)=}')
        before = self.board.debug()

        try:
            yield
        finally:
            after = self.board.debug()

            before_lines = before.splitlines()
            after_lines = after.splitlines()
            middle = len(before_lines) // 2

            # Same-width separators keep the right-hand board aligned on
            # every line, not only on the one carrying the arrow.
            arrow = ' -> '
            gap = ' ' * len(arrow)
            bna_lines = [
                b_line + (arrow if i == middle else gap) + a_line
                for i, (b_line, a_line) in enumerate(zip(before_lines, after_lines))
            ]
            print('\n'.join(bna_lines), '\n')

            for obs in sorted(self.observations, key=lambda obs: (min((c.x, c.y) for c in obs.cells)) if obs.cells else (-1, -1)):
                print(obs)
            print('---\n')

    # ----------------------------------------------------------------------
def solve_mine(map, n):
    """Solve the minesweeper board described by ``map`` containing ``n`` mines.

    Returns the fully resolved board as text (``str(board)``), or ``'?'``
    when the solver runs out of deductions before all mines are located.
    """
    board = Board.from_map(map, n)
    solver = Solver(board)

    def unknown_cells():
        # Cells that are neither opened (numeric) nor flagged.
        return [
            cell
            for row in board.cells
            for cell in row
            if cell.number is None and not cell.is_flagged
        ]

    solver.act()
    while True:
        # Apply pending flags *before* testing the counter; otherwise the
        # pass after the final flag finds nothing to do and reports '?'.
        solver.apply_cell_changes()
        if board.mines_remaining <= 0:
            break

        solver.clear_inferred_observations()
        solver.simplify_observations()

        with solver.debug():
            did_act = solver.act()
        if did_act:
            continue

        # Global count rule: if exactly as many undetermined cells remain as
        # unflagged mines, every one of them is a mine.
        unknown = unknown_cells()
        if len(unknown) == board.mines_remaining:
            for cell in unknown:
                solver.right_click(cell)
            break

        print()
        for obs in sorted(
            solver.observations,
            # Guard against empty observations, matching Solver.debug.
            key=lambda obs: min((c.x, c.y) for c in obs.cells) if obs.cells else (-1, -1),
        ):
            print(obs)
        return '?'

    # Every mine is flagged, so any cell still undetermined is safe to open.
    for cell in unknown_cells():
        solver.click(cell)

    return str(board)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment