Last active
January 17, 2019 19:13
-
-
Save ssylvan/fef75ea1decdb62fe78be2b8f786eecb to your computer and use it in GitHub Desktop.
Sudoku solver
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#if defined(_MSC_VER)
#include <intrin.h>
#endif
#include <assert.h>

#include <algorithm>
#include <chrono>
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <fstream>
#include <memory>
#include <string>
#include <utility>
#include <vector>
// Returns the zero-based index of the single set bit in x.
//
// Precondition: exactly one bit of x is set. This is asserted in debug builds;
// the result for any other input is unspecified.
int get_set_bit(uint16_t x) {
    assert(x != 0 && (x & (x - 1)) == 0);
#if defined(_MSC_VER)
    unsigned long index = 0;
    _BitScanForward(&index, x);
    return static_cast<int>(index);
#elif defined(__GNUC__)
    return __builtin_ctz(x);
#else
    // Generic fallback: shift right until the lowest bit is the set one.
    int index = 0;
    while (index < 15 && ((x >> index) & 1u) == 0) {
        ++index;
    }
    return index;
#endif
}
// Go through all 20 peers of the current square. The enclosing scope must provide
// int variables named x and y holding the current square; `body` is run once per
// peer with x and y shadowed by the peer's coordinates:
//   1. the 8 other squares with the same y,
//   2. the 8 other squares with the same x,
//   3. the 4 squares of the same 3x3 block that share neither x nor y with the
//      current square (the rest of the block was already covered by 1 and 2).
// The enclosing x and y are never modified. `body` may be a braced block or a
// single statement, with or without its trailing semicolon; it must not contain
// a top-level comma because it is a macro argument.
#define FOREACH_PEER_ELEM(body) { \
    const int currentx = x; const int currenty = y; \
    for (int x = 0; x < 9; ++x) { if (x != currentx) { body; } } \
    for (int y = 0; y < 9; ++y) { if (y != currenty) { body; } } \
    for (int x = (currentx / 3) * 3; x < (currentx / 3 + 1) * 3; ++x) { \
        for (int y = (currenty / 3) * 3; y < (currenty / 3 + 1) * 3; ++y) { \
            if (x != currentx && y != currenty) { body; } \
        } \
    } \
}
// Helper for FOREACH_SET_BIT: if `mask` has any bit set, stores the index of its
// lowest set bit in `index_out` and returns true; otherwise returns false.
static inline bool foreach_set_bit_next(unsigned long& index_out, uint16_t mask) {
#if defined(_MSC_VER)
    return _BitScanForward(&index_out, mask) != 0;
#else
    if (mask == 0) {
        return false;
    }
#if defined(__GNUC__)
    index_out = static_cast<unsigned long>(__builtin_ctz(mask));
#else
    index_out = 0;
    while (((mask >> index_out) & 1u) == 0) {
        ++index_out;
    }
#endif
    return true;
#endif
}

// Runs `body` once for every set bit of `digit_mask_input`, lowest bit first.
// The input expression is evaluated once, so changes `body` makes to the original
// variable do not affect the iteration. Inside `body`:
//   - `bit_index` (unsigned long) is the zero-based index of the current bit;
//   - `digit_mask` (uint16_t) holds the bits that have not been visited yet.
// `body` is pasted verbatim, so it must be complete statements (or a braced block)
// and must not contain a top-level comma.
#define FOREACH_SET_BIT(digit_mask_input, body) { \
    unsigned long bit_index = 0; uint16_t digit_mask = (digit_mask_input); \
    while (foreach_set_bit_next(bit_index, digit_mask)) { \
        digit_mask &= digit_mask - 1; body \
    } \
}
// Tells whether x has a single bit set.
// x must be non-zero; this is asserted in debug builds.
bool is_exactly_one_bit_set(uint16_t x) {
    assert(x != 0);
    // x - 1 flips the lowest set bit and everything below it, so the AND
    // strips that bit; nothing is left only if it was the sole set bit.
    const int without_lowest_bit = x & (x - 1);
    const bool single_bit = (without_lowest_bit == 0);
    return single_bit;
}
struct board { | |
// Contains a 9-bit mask per square, indicating possible digits in each square | |
uint16_t squares[9][9]; | |
// For each "unit" (rows, cols, blocks) and each digit (0..8), store a bit mask | |
// indicating where this digit is still possible. This could be reconstructed from | |
// the squares array above, but it's faster to maintain a redundant "reverse index". | |
uint16_t possible_digit_locations_blocks[3][3][9]; | |
uint16_t possible_digit_locations_rows[9][9]; | |
uint16_t possible_digit_locations_cols[9][9]; | |
board() { | |
std::fill_n((uint16_t*)squares, 81, (uint16_t)0x1ff); | |
std::fill_n((uint16_t*)possible_digit_locations_rows, 9 * 9, (uint16_t)0x1FF); | |
std::fill_n((uint16_t*)possible_digit_locations_cols, 9 * 9, (uint16_t)0x1FF); | |
std::fill_n((uint16_t*)possible_digit_locations_blocks, 9 * 9, (uint16_t)0x1FF); | |
} | |
uint8_t next_square_candidate_x = 0, next_square_candidate_y = 0; | |
bool eliminate(int x, int y, uint16_t digits_to_eliminate) { | |
digits_to_eliminate &= squares[x][y]; // Only eliminate digits that exists in square | |
if (digits_to_eliminate == 0) { | |
return true; // already eliminated. | |
} | |
squares[x][y] &= ~digits_to_eliminate; // clear digit | |
uint16_t remaining_digits = squares[x][y]; | |
if (remaining_digits == 0) { | |
return false; // contradiction found, no possible digits left. | |
} | |
int block_x = x / 3; | |
int block_y = y / 3; | |
int block_bit_index = (x%3) + 3*(y%3); | |
// Clear out the "reverse index" | |
FOREACH_SET_BIT(digits_to_eliminate, | |
assert(possible_digit_locations_cols[x][bit_index] & (1 << y)); | |
possible_digit_locations_cols[x][bit_index] &= ~(1 << y); | |
if (possible_digit_locations_cols[x][bit_index] == 0) { | |
return false; | |
} | |
assert(possible_digit_locations_rows[y][bit_index] & (1 << x)); | |
possible_digit_locations_rows[y][bit_index] &= ~(1 << x); | |
if (possible_digit_locations_rows[y][bit_index] == 0) { | |
return false; | |
} | |
assert(possible_digit_locations_blocks[block_x][block_y][bit_index] & (1 << block_bit_index)); | |
possible_digit_locations_blocks[block_x][block_y][bit_index] &= ~(1 << block_bit_index); | |
if (possible_digit_locations_blocks[block_x][block_y][bit_index] == 0) { | |
return false; | |
} | |
); | |
// If we've eliminated all but one digit, then we should eliminate that digit from all the peers. | |
if (is_exactly_one_bit_set(remaining_digits)) { | |
int remaining_digit_index = get_set_bit(remaining_digits); | |
// Get all the positions where this digit is set in the current row, column and block, | |
// and eliminate them from those squares. | |
// Start with the current row. | |
uint16_t remaining_pos_mask = possible_digit_locations_rows[y][remaining_digit_index]; | |
remaining_pos_mask &= ~(1 << x); // Don't eliminate from the current square. | |
FOREACH_SET_BIT(remaining_pos_mask, | |
if (!eliminate(bit_index,y, remaining_digits)) { | |
return false; | |
} | |
) | |
// Next eliminate it from the column. | |
remaining_pos_mask = possible_digit_locations_cols[x][remaining_digit_index]; | |
remaining_pos_mask &= ~(1 << y); // Don't eliminate from the current square | |
FOREACH_SET_BIT(remaining_pos_mask, | |
if (!eliminate(x, bit_index, remaining_digits)) { | |
return false; | |
} | |
) | |
// Next eliminate it from the current block | |
remaining_pos_mask = possible_digit_locations_blocks[block_x][block_y][remaining_digit_index]; | |
remaining_pos_mask &= ~(1 << block_bit_index); // Don't eliminate from the current square | |
FOREACH_SET_BIT(remaining_pos_mask, | |
int x_offset = bit_index % 3; | |
int y_offset = bit_index / 3; | |
if (!eliminate(block_x*3 + x_offset, block_y *3 + y_offset, remaining_digits)) { | |
return false; | |
} | |
) | |
} | |
// For each digit we just eliminated, find if it now only has one remaining posible location | |
// in either the row, column or block. If so, set the digit | |
FOREACH_SET_BIT(digits_to_eliminate, | |
// Check the row | |
if (is_exactly_one_bit_set(possible_digit_locations_rows[y][bit_index])) { | |
int digit_x = get_set_bit(possible_digit_locations_rows[y][bit_index]); | |
if (!set_digit(digit_x, y, bit_index)) { | |
return false; | |
} | |
} | |
// Column | |
if (is_exactly_one_bit_set(possible_digit_locations_cols[x][bit_index])) { | |
int digit_y = get_set_bit(possible_digit_locations_cols[x][bit_index]); | |
if (!set_digit(x, digit_y, bit_index)) { | |
return false; | |
} | |
} | |
// Block | |
if (is_exactly_one_bit_set(possible_digit_locations_blocks[block_x][block_y][bit_index])) { | |
int bit_with_digit = get_set_bit(possible_digit_locations_blocks[block_x][block_y][bit_index]); | |
int x_offset = bit_with_digit % 3; | |
int y_offset = bit_with_digit / 3; | |
if (!set_digit(block_x*3 + x_offset, block_y *3 + y_offset, bit_index)) { | |
return false; | |
} | |
} | |
) | |
// While we're here, check if the square has a popcnt of 2, this means it's has the minimum | |
// number of possibilities without being "done", making it an excellent candidate for the next | |
// trial assignment. | |
if (__popcnt16(squares[x][y]) == 2) { | |
next_square_candidate_x = x; | |
next_square_candidate_y = y; | |
} | |
return true; | |
} | |
bool set_digit(int x, int y, int d) { | |
int digit_mask = 1 << d; | |
assert(squares[x][y] & digit_mask); | |
if (!eliminate(x, y, ~digit_mask)) { | |
return false; | |
} | |
assert(squares[x][y] == digit_mask); | |
return true; | |
} | |
bool find_next_square_to_assign(int& x, int& y) const { | |
// Check most recent squares to have been found to have 2 digits, if it still | |
// does, just use that one since it would be a minimum. | |
if (__popcnt16(squares[next_square_candidate_x][next_square_candidate_y]) == 2) { | |
x = next_square_candidate_x; | |
y = next_square_candidate_y; | |
return true; | |
} | |
int min_count = 10; | |
for (int i = 0; i < 9; ++i) { | |
for (int j = 0; j < 9; ++j) { | |
int c = __popcnt16(squares[i][j]); | |
// Again, early out on popcnt == 2, since that would be a minimum | |
if (c == 2) { | |
x = i; | |
y = j; | |
return true; | |
} | |
// Else find smallest | |
if (c > 1 && c < min_count) { | |
x = i; | |
y = j; | |
min_count = c; | |
} | |
} | |
} | |
return min_count != 10; // Did we find a square to process? | |
} | |
void print() { | |
for (int y = 0; y < 9; ++y) { | |
for (int x = 0; x < 9; ++x) { | |
printf("["); | |
for (int i = 0; i < 9; ++i) { | |
if (squares[x][y] & (1 << i)) { | |
printf("%d", i + 1); | |
} | |
else { | |
printf(" "); | |
} | |
} | |
printf("] "); | |
} | |
printf("\n"); | |
} | |
} | |
bool is_solved() const { | |
for (int y = 0; y < 9; ++y) { | |
for (int x = 0; x < 9; ++x) { | |
if (!is_exactly_one_bit_set(squares[x][y])) | |
return false; | |
// Make sure this digit isn't in any of the peers | |
int d = get_set_bit(squares[x][y]); | |
FOREACH_PEER_ELEM({ | |
if (get_set_bit(squares[x][y]) == d) | |
return false; | |
}) | |
} | |
} | |
return true; | |
} | |
}; | |
bool search(const board& current_board, board& final_board) { | |
// Find the next square to do a trial assignment on | |
int x, y; | |
if (current_board.find_next_square_to_assign(x, y)) { | |
uint16_t digits = current_board.squares[x][y]; | |
// Then assign each possible digit to this square | |
FOREACH_SET_BIT(digits, { | |
board board_copy = current_board; | |
// If we can successfully set this digit, do search from here | |
if (board_copy.set_digit(x, y, bit_index)) { | |
if (search(board_copy, final_board)) { | |
return true; | |
} | |
} | |
}) | |
} | |
else { | |
// No more squares to assign, so we're done! | |
assert(current_board.is_solved()); | |
final_board = current_board; | |
return true; | |
} | |
return false; | |
} | |
int main() | |
{ | |
// Load sudoku grids | |
FILE* sudoku_file; | |
if (fopen_s(&sudoku_file, "sudoku17.txt", "r")) { | |
printf("Failed to open sudoku file"); | |
return -1; | |
} | |
std::vector<std::unique_ptr<const char[]>> lines; | |
while (!feof(sudoku_file)) { | |
char line_buf[82]; | |
fgets(line_buf, _countof(line_buf), sudoku_file); | |
size_t n = strlen(line_buf); | |
if (n < 81) | |
continue; | |
lines.emplace_back(_strdup(line_buf)); | |
} | |
fclose(sudoku_file); | |
for (int runs = 0; runs < 1; ++runs) { // Increase the loop count here for performance profiling | |
std::chrono::high_resolution_clock clock; | |
// Solve all the grids | |
auto start = clock.now(); | |
std::vector<board> boards; | |
for (size_t board_ix = 0; board_ix < lines.size(); ++board_ix) { | |
board b; | |
for (int row = 0; row < 9; ++row) { | |
for (int col = 0; col < 9; ++col) { | |
char c = lines[board_ix][row * 9 + col]; | |
if (c > '0' && c <= '9') { | |
bool res = b.set_digit(col, row, c - '1'); // we use zero-based digits, hence the '1' | |
assert(res); | |
} | |
} | |
} | |
boards.push_back(b); | |
} | |
for (int i = 0; i < boards.size(); ++i) { | |
board final_board; | |
if (!search(boards[i], final_board)) { | |
printf("Failed to solve grid %d\n", i); | |
boards[i].print(); | |
} | |
} | |
auto end = clock.now(); | |
uint64_t nanosecs = std::chrono::duration_cast<std::chrono::nanoseconds>(end - start).count(); | |
printf("Took %.1f milliseconds for %d puzzles, or %.2f microseconds average\n", (float)nanosecs / 1000000, (int)boards.size(), (float)nanosecs / (1000.0f*boards.size())); | |
} | |
return 0; | |
} | |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment