Created
July 2, 2010 03:52
-
-
Save aflaxman/460904 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from pylab import * | |
import random | |
# every (row, column) position of the 9x9 board, in row-major order,
# each stored as a two-element list
index_set = [list(divmod(k, 9)) for k in range(81)]
def solve(T):
    """ Find a solution to T, if possible

    T is a 9x9 array with blank cells set to -1.  On success T is filled in
    place with the solution and 'success' is returned; otherwise each cell
    this search filled is put back to -1 and 'failure' is returned.

    Depth-first backtracking: branch on every candidate value of the blank
    cell that has the fewest candidates.

    Example
    -------
    >>> T = rand(70)
    >>> solve(T)
    'success'
    >>> any(T == -1)
    False
    """
    # every cell holds a positive value: nothing left to fill
    if (T > 0).all():
        return 'success'
    candidates = possibilities(T)
    row, col = most_constrained(T, candidates)
    for guess in candidates[(row, col)]:
        T[row, col] = guess
        if solve(T) == 'success':
            return 'success'
    # every candidate led to a dead end: undo this cell and report failure
    T[row, col] = -1
    return 'failure'
def count_solns(T, one_vs_many=False):
    """ How many unique solutions are there starting with T?

    T is a 9x9 array with blank cells set to -1.  Cells are filled in during
    the search, and each cell this call fills is reset to -1 before it
    returns.

    If one_vs_many is True, the search stops as soon as more than one
    solution has been found.  The result is then exact only when it is 0 or
    1; any value greater than 1 just means "many".

    Example
    -------
    >>> T = rand(81)
    >>> count_solns(T)
    1
    """
    # solve T recursively, by trying all values for the most constrained var
    pos = possibilities(T)
    # no blank cells left: T itself is a solution.  ('pos.keys() == []' is
    # always False on Python 3, because dict views never compare equal to
    # lists, so test the dict's emptiness directly.)
    if not pos:
        return 1
    i, j = most_constrained(T, pos)
    count = 0
    for val in pos[(i, j)]:
        T[i, j] = val
        count += count_solns(T, one_vs_many)
        if one_vs_many and count > 1:
            # more than one solution is all the caller asked about
            break
    # reset the most constrained cell before returning
    T[i, j] = -1
    return count
def rand(n, T=None):
    """ Create a random game, with n cells filled in

    Optionally start with an initialized board T (a 9x9 array with blank
    cells set to -1).  A supplied T is modified in place and also returned.

    The board is first solved.  It is then shuffled 5 times (blank all but
    20 cells, relabel the digits, re-solve) to approximate a uniformly
    random solution.  Finally all but n cells are blanked.

    Raises ValueError if the starting board cannot be solved.

    Example
    -------
    >>> sum(rand(70) != -1)
    70
    """
    # start with an empty board.  Use 'is None': 'T == None' on a numpy
    # array compares elementwise, and the resulting array has no single
    # truth value, so any call that passed a board used to raise.
    if T is None:
        T = -1*ones([9,9])
    # solve it to generate an initial solution.  Check explicitly rather
    # than with assert, so the check is not stripped under 'python -O'.
    if solve(T) != 'success':
        raise ValueError('starting board T has no solution')
    # do random shuffles to approximate uniformly random solution
    for _ in range(5):
        select_random_cells(T, 20)
        randomly_permute_labels(T)
        solve(T)
    # remove appropriate amount of labels
    select_random_cells(T, n)
    return T
def most_constrained(T, pos):
    """ Find blank cell which is most constrained by non-blank cells

    T is a 9x9 array in which negative cells are blank.  pos maps each blank
    cell (i, j) to its set of candidate values, as returned by
    possibilities(T).

    Returns the (i, j) tuple of the blank cell with the fewest candidates.
    Ties go to the first such cell in row-major order.

    Raises ValueError if T has no blank cells (previously this surfaced as
    an UnboundLocalError).
    """
    blanks = [(i, j) for i in range(9) for j in range(9) if T[i, j] < 0]
    if not blanks:
        raise ValueError('board has no blank cells')
    # min() returns the first minimal element, which matches the row-major
    # tie-breaking of a strict '<' scan
    return min(blanks, key=lambda cell: len(pos[cell]))
def possibilities(T):
    """ Find all possibilities for each empty cell of T

    T is a 9x9 array in which negative cells (normally -1) are blank.

    Returns a dict mapping each blank cell (i, j) to the set of values in
    1..9 that do not already appear in that cell's row, column, or 3x3 box.
    Filled cells are not keys, so an empty dict means T has no blank cells.

    Example
    -------
    >>> pos = possibilities(-1*ones([9,9]))
    >>> sorted(pos[(0,0)])
    [1, 2, 3, 4, 5, 6, 7, 8, 9]
    """
    pos = {}
    for i in range(9):
        for j in range(9):
            if T[i, j] < 0:
                # top-left corner of the 3x3 box holding (i, j).  Floor
                # division is required: '/' yields a float on Python 3,
                # and numpy rejects float slice bounds.
                bi, bj = 3 * (i // 3), 3 * (j // 3)
                used = (set(T[i, :]) | set(T[:, j])
                        | set(T[bi:bi + 3, bj:bj + 3].flatten()))
                pos[(i, j)] = set(range(1, 10)) - used
    return pos
def select_random_cells(T, n):
    """ Set all but n randomly chosen cells of T to -1 (blank), in place

    81-n distinct positions are drawn without replacement from index_set,
    and each one is overwritten with -1.
    """
    blanked = random.sample(index_set, 81 - n)
    for cell in blanked:
        T[cell[0], cell[1]] = -1
def randomly_permute_labels(T):
    """ Permute the positive values of T uniformly at random

    Draws a random permutation of the digits 1..9 and relabels every
    positive cell of T through it, in place.  Non-positive (blank) cells are
    left untouched.  Positive cells must hold whole numbers 1..9; any other
    positive value raises KeyError.
    """
    # range() is immutable on Python 3, so materialize it before the
    # in-place random.shuffle
    new_labels = list(range(1, 10))
    random.shuffle(new_labels)
    new_label_dict = dict(zip(range(1, 10), new_labels))
    for i in range(9):
        for j in range(9):
            if T[i, j] > 0:
                T[i, j] = new_label_dict[T[i, j]]
def draw(T, R=None):
    """ Use matplotlib to display 9x9 table T in a Sudoku style

    Filled (positive) cells of T are drawn in bold.  If R is given (e.g. a
    solved copy of T), its positive values are drawn in normal weight in
    the cells that are blank in T.  The current figure is cleared first.

    Example
    -------
    >>> T = rand(70)
    >>> R = copy(T)
    >>> solve(R)
    'success'
    >>> draw(T, R)
    """
    clf()
    # thick lines between the 3x3 boxes
    params = dict(linewidth=2)
    grid = [1./3., 2./3.]
    hlines(grid, 0, 1, **params)
    vlines(grid, 0, 1, **params)
    # thin lines between individual cells; the 1.1 stop makes arange
    # include the outer edge at 1.0
    params = dict(linewidth=1)
    grid = arange(0., 1.1, 1./9.)
    hlines(grid, 0, 1, **params)
    vlines(grid, 0, 1, **params)
    # shade the bottom-, top-, left- and right-middle boxes gray
    params = dict(facecolor='gray')
    mid_verts = [1/3., 2/3., 2/3., 1/3.]
    bot_verts = [0, 0, 1/3., 1/3.]
    top_verts = [1, 1, 2/3., 2/3.]
    fill(mid_verts, bot_verts, **params)
    fill(mid_verts, top_verts, **params)
    fill(bot_verts, mid_verts, **params)
    fill(top_verts, mid_verts, **params)
    # one centred label per cell; row 0 is drawn at the top
    params = dict(fontsize=20, ha='center', va='center')
    for i in range(9):
        for j in range(9):
            row_pos = 1. - (i/9. + 1/18.)
            col_pos = j/9. + 1/18.
            if T[i,j] > 0:
                text(col_pos, row_pos, '%d'%T[i,j], weight='bold', **params)
            # 'is not None': 'R != None' on a numpy array compares
            # elementwise and raises when used as a condition.  Cells that
            # R also leaves blank are skipped instead of printed as '-1'.
            elif R is not None and R[i,j] > 0:
                text(col_pos, row_pos, '%d'%R[i,j], **params)
    axis([0,1,0,1])
    xticks([])
    yticks([])
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment