Skip to content

Instantly share code, notes, and snippets.

@tomviner
Forked from teh/sudoku.py
Last active September 20, 2016 20:26
Show Gist options
Save tomviner/f3024ace90da09726b84 to your computer and use it in GitHub Desktop.
8 Queens problem with a SAT solver
import pycosat
import numpy
import itertools
# Indices into the last (state) axis of the variable array: each square is
# either blank or holds a queen.
blank_idx, queen_idx = 0, 1
def get_cnf():
    """Build the CNF encoding of the N-queens puzzle for pycosat.

    The board shape comes from the module-level globals ``H`` (rows), ``W``
    (columns) and ``D`` (states per square: blank or queen). The driver loop
    sets these before each call. ``queen_idx`` selects the "queen" state.
    Every row AND every column must hold a queen, so the formula is only
    satisfiable when ``H == W``.

    Variable ``x[i, j, d]`` is true when square (row ``i``, column ``j``) is
    in state ``d``.

    :return: list of clauses; each clause is a list of non-zero ints, a
        negative int meaning a negated variable.
    """
    # * add one because 0 terminates clauses in PicoSAT, so it can't be a variable
    # * object dtype so the literals are plain Python ints, which pycosat expects
    # * H*W*D vars x[i, j, d] where (i, j) == (row, col), d == square state
    x = (numpy.arange(H * W * D).reshape(H, W, D) + 1).astype('object')
    queens = x[:, :, queen_idx]
    cnf = []

    def at_most_one(literals):
        # Pairwise encoding: in every pair, at least one must be false.
        cnf.extend(itertools.combinations(-literals.ravel(), 2))

    # At least one state per square
    for i in range(H):
        for j in range(W):
            cnf.append(x[i, j, :].tolist())
    # Only one state per square
    for i in range(H):
        for j in range(W):
            at_most_one(x[i, j, :])

    # The puzzle needs no fewer than H queens on the board. Writing that as an
    # explicit cardinality constraint is impractical: either every
    # (H*W - H + 1)-subset of blanks or every H-subset of queens explodes
    # combinatorially. It is also unnecessary. Requiring a queen in each of
    # the H disjoint rows already forces at least H queens.
    for i in range(H):
        cnf.append(queens[i, :].tolist())
    for j in range(W):
        cnf.append(queens[:, j].tolist())
    # No more than one queen per row or column
    for i in range(H):
        at_most_one(queens[i, :])
    for j in range(W):
        at_most_one(queens[:, j])

    # No more than one queen per diagonal. Offsets 1-H .. W-1 enumerate every
    # top-left-to-bottom-right diagonal of the board. The same offsets on the
    # left-right mirrored board enumerate every anti-diagonal.
    mirrored = queens[:, ::-1]
    for k in range(1 - H, W):
        at_most_one(numpy.diagonal(queens, k))
        at_most_one(numpy.diagonal(mirrored, k))

    return [list(clause) for clause in cnf]
def print_solution(solution):
    """Print an N-queens board: 'Q' marks a queen, '-' an empty square.

    Reads the board shape from the module globals ``H``, ``W`` and ``D``,
    and the queen state index from ``queen_idx``.

    :param solution: result of ``pycosat.solve`` on the formula from
        ``get_cnf``. This is either a list of H*W*D signed literals, or the
        string ``'UNSAT'`` / ``'UNKNOWN'`` when the solver found no model.
    :raises ValueError: if ``solution`` is not a model.
    """
    # An explicit raise (rather than assert) still fires under ``python -O``.
    # It also catches pycosat's 'UNKNOWN' result, not just 'UNSAT'.
    if isinstance(solution, str):
        raise ValueError(solution)
    # True where a literal is positive. Axes are (row, column, state),
    # matching get_cnf's x[i, j, d] indexing.
    board = numpy.array(solution).reshape(H, W, D) > 0
    for row in board:
        print(' '.join('Q' if d == queen_idx else '-'
                       for square in row
                       for d in numpy.flatnonzero(square)))
# Solve boards of size 8, 13, ..., 38, then size 8 once more.
for n in list(range(8, 40, 5)) + [8]:
    # get_cnf and print_solution read the board shape from these globals
    W = H = n
    D = 2
    cnf = get_cnf()
    print('--- %d' % n)
    print_solution(pycosat.solve(cnf))
import pycosat
import numpy
import itertools
# Puzzle rows, top to bottom, one character per square; '.' marks a blank.
WORLDS_HARDEST_RIDDLE_ACCORDING_TO_TELEGRAPH = "\n".join([
    "8........",
    "..36.....",
    ".7..9.2..",
    ".5...7...",
    "....457..",
    "...1...3.",
    "..1....68",
    "..85...1.",
    ".9....4..",
])
def get_cnf(riddle):
    """Build the CNF encoding of a 9x9 Sudoku puzzle for pycosat.

    :param riddle: the puzzle as whitespace-separated rows, one character per
        square. '1'-'9' is a given digit; '.' or '0' is an empty square.
    :return: list of clauses; each clause is a list of non-zero ints, a
        negative int meaning a negated variable.
    """
    # * add one because 0 terminates clauses in PicoSAT, so it can't be a variable
    # * object dtype so the literals are plain Python ints, which pycosat expects
    # * 9^3 vars x[i, j, d] where (i, j) == (row, col), d == digit - 1
    x = (numpy.arange(9 * 9 * 9).reshape(9, 9, 9) + 1).astype('object')
    cnf = []

    def at_most_one(literals):
        # Pairwise encoding: in every pair, at least one must be false.
        cnf.extend(itertools.combinations(-literals.ravel(), 2))

    # At least one digit per square
    for i in range(9):
        for j in range(9):
            cnf.append(x[i, j, :].tolist())
    # Only one digit per square
    for i in range(9):
        for j in range(9):
            at_most_one(x[i, j, :])
    # Each 3x3 block must contain 9 different digits
    for i in range(0, 9, 3):
        for j in range(0, 9, 3):
            for d in range(9):
                at_most_one(x[i:i + 3, j:j + 3, d])
    # Each row and each column must contain 9 different digits
    for i in range(9):
        for d in range(9):
            at_most_one(x[i, :, d])
            at_most_one(x[:, i, d])
    # Transform the riddle's given digits into unit clauses. '0' is treated as
    # blank too. Otherwise int('0') - 1 == -1 would silently index digit 9.
    for i, row in enumerate(riddle.split()):
        for j, cell in enumerate(row):
            if cell in '.0':
                continue
            cnf.append([x[i, j, int(cell) - 1]])
    return [list(clause) for clause in cnf]
def print_solution(solution):
    """Print a solved Sudoku grid, one row per line, digits separated by spaces.

    :param solution: result of ``pycosat.solve`` on a formula from
        ``get_cnf``. This is either a list of 729 signed literals, or the
        string ``'UNSAT'`` / ``'UNKNOWN'`` when the solver found no model.
    :raises ValueError: if ``solution`` is not a model.
    """
    # Report a missing model clearly instead of failing inside reshape().
    if isinstance(solution, str):
        raise ValueError('no solution: %s' % solution)
    # True where a literal is positive. Axes are (row, column, digit - 1).
    grid = numpy.array(solution).reshape(9, 9, 9) > 0
    for row in grid:
        print(' '.join(str(d + 1)
                       for square in row
                       for d in numpy.flatnonzero(square)))
# Solve the Telegraph puzzle and print the completed grid.
print_solution(
    pycosat.solve(
        get_cnf(WORLDS_HARDEST_RIDDLE_ACCORDING_TO_TELEGRAPH)
    )
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment