Last active
November 14, 2023 22:29
-
-
Save skatenerd/fecf20ef9a4159fe16a915e14c854422 to your computer and use it in GitHub Desktop.
Sudoku solver
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import itertools | |
import functools | |
class Subgrid:
    """One 3x3 box of the board, addressed by its box row and box column."""

    def __init__(self, row, column):
        self.row = row
        self.column = column

    def indices(self):
        """Return the box's nine (row, column) board coordinates in row-major order."""
        top = 3 * self.row
        left = 3 * self.column
        return [
            (top + dr, left + dc)
            for dr in range(3)
            for dc in range(3)
        ]
class Row:
    """One horizontal line of the board, identified by its row number."""

    def __init__(self, index):
        self.index = index

    def indices(self):
        """Return the (row, column) coordinates of this row, left to right."""
        return list(zip([self.index] * 9, range(9)))
class Column:
    """One vertical line of the board, identified by its column number."""

    def __init__(self, index):
        self.index = index

    def indices(self):
        """Return the (row, column) coordinates of this column, top to bottom."""
        return list(zip(range(9), [self.index] * 9))
def all_entities():
    """Return all 27 constraint groups: 9 boxes, then 9 rows, then 9 columns.

    Boxes are listed row-major by (box row, box column); rows and columns
    are listed by index 0-8.
    """
    boxes = [Subgrid(box_row, box_col) for box_row in range(3) for box_col in range(3)]
    rows = list(map(Row, range(9)))
    columns = list(map(Column, range(9)))
    return boxes + rows + columns
class Observation:
    """A deduction that ``number`` must occupy one of ``squares``.

    ``owner`` is the entity (box, row or column) the deduction came from;
    ``squares`` is the sequence of (row, col) cells it is confined to.
    """

    def __init__(self, owner, number, squares):
        self.owner, self.number, self.squares = owner, number, squares
class Game:
    """Sudoku board solved by repeated candidate elimination.

    ``state[r][c]`` is the set of digits still possible for the cell at row
    ``r``, column ``c``. Solving never guesses: it only applies the deductions
    produced by ``observations_for_entity``.
    """

    def __init__(self, initial_values):
        """Build the board from the given clues.

        Args:
            initial_values: iterable of ``((row, col), value)`` pairs, one per
                clue; every other cell starts with candidates 1-9.
        """
        self.state = [[set(range(1, 10)) for _ in range(9)] for _ in range(9)]
        for (row, col), value in initial_values:
            self.state[row][col] = {value}

    def entities_containing(self, squares):
        """Return every box/row/column whose cells include all of ``squares``.

        Args:
            squares: a set of ``(row, col)`` tuples.
        """
        return [e for e in all_entities() if squares.issubset(e.indices())]

    def observations_for_entity(self, entity, cluster_size):
        """Yield deductions about where digits must go inside ``entity``.

        Observations are yielded lazily, so when the caller applies each one
        as it arrives (as ``iterate`` does), later ones are computed against
        the already-updated board. Two kinds are produced:

        * Naked subsets: for every combination of ``cluster_size`` digits drawn
          from the entity's candidates, if at least ``cluster_size`` cells have
          candidates entirely within that combination, each of those digits
          must lie in those cells.
        * Confined digits: for each digit 1-9 that is possible in between one
          and three of the entity's cells, the digit must lie in those cells.
        """
        cells = entity.indices()
        all_numbers = functools.reduce(set.union, [self.state[r][c] for (r, c) in cells])
        supersets = [set(combo) for combo in itertools.combinations(all_numbers, cluster_size)]
        for superset in supersets:
            # Cells whose every candidate falls inside this set of digits.
            hits = [(r, c) for (r, c) in cells if self.state[r][c].issubset(superset)]
            if len(hits) >= cluster_size:
                # More hits than cluster_size means the board is already
                # contradictory (N+1 cells cannot share N digits); only the
                # first cluster_size cells are reported in that case.
                for number in superset:
                    yield Observation(entity, number, hits[:cluster_size])
        for n in range(1, 10):
            hits = [(r, c) for (r, c) in cells if n in self.state[r][c]]
            # Only regions small enough to eliminate something elsewhere are interesting.
            if 1 <= len(hits) <= 3:
                yield Observation(entity, n, hits)

    def _candidate_count(self):
        """Return the total number of candidates across all 81 cells."""
        return sum(len(cell) for row in self.state for cell in row)

    def iterate(self, cluster_size):
        """Apply one elimination pass over every box, row and column.

        Returns:
            True if the total number of candidates on the board changed.
        """
        start_size = self._candidate_count()
        for entity in all_entities():
            for observation in self.observations_for_entity(entity, cluster_size):
                self.apply_observation(observation)
        return self._candidate_count() != start_size

    def apply_observation(self, observation):
        """Remove ``observation.number`` from every peer of its squares.

        Every entity that contains all of the observation's squares loses the
        number from its other cells, because the number must sit in one of
        those squares. A single-square observation also pins that cell to the
        number.
        """
        squares = set(observation.squares)
        for entity in self.entities_containing(squares):
            for row, col in set(entity.indices()) - squares:
                self.state[row][col] = self.state[row][col] - {observation.number}
        if len(observation.squares) == 1:
            [(row, col)] = observation.squares
            self.state[row][col] = {observation.number}

    def done(self):
        """Return True when no cell has more than one candidate left."""
        return not any(len(self.state[r][c]) > 1 for r in range(9) for c in range(9))

    def solve(self, cluster_size=1):
        """Propagate constraints until the board is done or no rule makes progress.

        After any productive pass the cluster size drops back to 1. After an
        unproductive pass it grows by one. With only nine digits, an
        unproductive pass at size >= 9 means nothing further can be deduced.
        The confined-digit rule does not depend on the size, so the method
        returns at that point and leaves the board partially solved. Callers
        should check ``done()``.

        This is implemented as a loop, not recursion. The previous recursive
        form grew the size without bound on stalled puzzles and hit
        ``RecursionError``.

        Args:
            cluster_size: naked-subset size to try first (default 1).
        """
        while not self.done():
            if self.iterate(cluster_size):
                cluster_size = 1
            elif cluster_size >= 9:
                return
            else:
                cluster_size += 1
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment