Skip to content

Instantly share code, notes, and snippets.

@blackle
Created July 17, 2024 18:28
Show Gist options
  • Save blackle/39bcebcb718e1baed93f60ecd20b2cc0 to your computer and use it in GitHub Desktop.
Save blackle/39bcebcb718e1baed93f60ecd20b2cc0 to your computer and use it in GitHub Desktop.
Generate a Sudoku grid with CryptoMiniSat in Python
import pycryptosat as pysat
import itertools
# this is a helper class to save the results of "solve"
# and to give an easy way to make new variables
class Solver(pysat.Solver):
    """pycryptosat Solver that hands out fresh variables and remembers results.

    Attributes:
        n_vars: number of variables allocated so far by :meth:`make_vars`.
        sat: satisfiability result of the most recent :meth:`solve` call
            (``None`` until ``solve`` has been called).
        solution: model from the most recent :meth:`solve` call, indexed by
            variable number (``solution[v]`` is the value of variable ``v``);
            ``None`` until ``solve`` has been called (pycryptosat also
            returns ``None`` here when no model is found).
    """

    def __init__(self, **kwargs):
        # Any keyword options (e.g. threads=..., verbose=...) are forwarded
        # unchanged to pycryptosat.Solver; with no arguments this behaves
        # exactly like the plain constructor.
        super().__init__(**kwargs)
        self.n_vars = 0
        self.sat = None
        self.solution = None

    def make_vars(self, n):
        """Allocate *n* fresh variables and return their numbers as a list.

        Variables are numbered consecutively starting at 1, so successive
        calls never return overlapping numbers and never return 0 (which
        cannot be used as a literal).
        """
        first = self.n_vars + 1
        self.n_vars += n
        return list(range(first, first + n))

    def solve(self, assumptions=None):
        """Solve the formula, store the result, and return it.

        Args:
            assumptions: optional list of literals assumed true for this call
                only; forwarded to pycryptosat's ``solve``.

        Returns:
            The ``(sat, solution)`` pair from pycryptosat, also stored on
            ``self.sat`` and ``self.solution``.
        """
        if assumptions is None:
            result = super().solve()
        else:
            result = super().solve(assumptions)
        self.sat, self.solution = result
        return result
def at_least_one(solver, vs):
    """Require that at least one literal in *vs* is true.

    This is expressed as a single clause: the disjunction of all of *vs*.
    """
    disjunction = vs
    solver.add_clause(disjunction)
def at_most_one(solver, vs):
    """Require that at most one literal in *vs* is true.

    Uses the pairwise encoding: for every unordered pair of literals a
    clause ``(-a OR -b)`` forbids both from being true together. This adds
    one clause per pair, i.e. quadratically many in ``len(vs)``.
    """
    for pair in itertools.combinations(vs, 2):
        # negate both members: "a implies not b"
        solver.add_clause([-lit for lit in pair])
def exactly_one(solver, vs):
    """Require that exactly one literal in *vs* is true.

    Combines ``at_least_one`` and ``at_most_one`` over the same literals.

    Args:
        solver: object receiving the clauses (must provide ``add_clause``).
        vs: any iterable of literals.
    """
    # The literals are consumed twice (once per helper). Materialize them so
    # a one-shot iterator does not leave the at-most-one half with nothing.
    vs = list(vs)
    at_least_one(solver, vs)
    at_most_one(solver, vs)
class Digit:
    """An integer in ``range(max_size)`` encoded one-hot in SAT variables.

    One variable is allocated per possible value, and an exactly-one
    constraint ensures every model sets precisely one of them.
    """

    def __init__(self, solver, max_size):
        # self.vars[i] is true in a model iff this digit takes value i.
        self.vars = solver.make_vars(max_size)
        self.solver = solver
        exactly_one(self.solver, self.vars)

    def value(self):
        """Return this digit's 0-based value in the solver's latest model.

        Raises:
            RuntimeError: if the solver holds no model (``solve`` has not
                been called, or it found no solution).
        """
        solution = self.solver.solution
        if solution is None:
            raise RuntimeError(
                "Digit.value() called without a model; "
                "call solve() and check that it succeeded first"
            )
        return next(i for i, var in enumerate(self.vars) if solution[var])

    def get_var(self, n):
        """Return the variable that is true iff this digit equals *n*."""
        return self.vars[n]
def make_digits_unique(solver, digits, max_size=9):
    """Constrain each value ``0..max_size-1`` to appear in exactly one digit.

    For every value ``i``, exactly one of the digits' "equals i" variables
    may be true. If each digit is itself one-hot over *max_size* values and
    there are exactly *max_size* digits, the digits form a permutation of
    ``range(max_size)``. With fewer digits than *max_size* the constraints
    are unsatisfiable.

    Args:
        solver: object receiving the clauses.
        digits: iterable of Digit-like objects exposing ``get_var(i)``.
        max_size: number of possible values per digit (9 for standard sudoku).
    """
    # digits is walked once per value below, so a generator must be
    # materialized or every pass after the first would see no digits.
    digits = list(digits)
    for value in range(max_size):
        exactly_one(solver, [d.get_var(value) for d in digits])
solver = Solver()
# board[r][c] is the one-hot Digit for row r, column c.
board = [[Digit(solver, 9) for _col in range(9)] for _row in range(9)]

# make rows unique
for row in board:
    make_digits_unique(solver, row)

# make columns unique
for col_idx in range(9):
    col = [r[col_idx] for r in board]
    make_digits_unique(solver, col)

# make 3x3 blocks unique
for block_row in range(3):
    for block_col in range(3):
        block = [
            board[block_row * 3 + dr][block_col * 3 + dc]
            for dr in range(3)
            for dc in range(3)
        ]
        make_digits_unique(solver, block)

solver.solve()
# Without a model, Digit.value() has nothing to read; fail with a clear message.
if not solver.sat:
    raise SystemExit("no valid sudoku grid found")

# print the board (digits are 0-based internally, shown as 1..9)
for i, row in enumerate(board):
    for j, digit in enumerate(row):
        print(digit.value() + 1, end=" ")
        if j != 8 and j % 3 == 2:
            print("|", end=" ")
    print()
    if i != 8 and i % 3 == 2:
        print("-" * 21)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment