Created
October 27, 2024 07:16
-
-
Save sshark/3144d8e1a16ff5a652697e8ec6617129 to your computer and use it in GitHub Desktop.
Sudoku Runtime Performance Java vs Scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package pnorvig.ipynb; | |
import java.io.BufferedReader; | |
import java.io.FileReader; | |
import java.io.IOException; | |
import java.util.ArrayList; | |
import java.util.Arrays; | |
import java.util.List; | |
import java.util.concurrent.CountDownLatch; | |
import java.util.stream.IntStream; | |
//////////////////////////////// Solve Sudoku Puzzles //////////////////////////////// | |
//////////////////////////////// @author Peter Norvig //////////////////////////////// | |
/** | |
* There are two representations of puzzles that we will use: | |
* * 1. A gridstring is 81 chars, with characters '0' or '.' for blank and '1' to '9' for digits. | |
* * 2. A puzzle grid is an int[81] with a digit d (1-9) represented by the integer (1 << (d - 1)); | |
* * that is, a bit pattern that has a single 1 bit representing the digit. | |
* * A blank is represented by the OR of all the digits 1-9, meaning that any digit is possible. | |
* * While solving the puzzle, some of these digits are eliminated, leaving fewer possibilities. | |
* * The puzzle is solved when every square has only a single possibility. | |
* * | |
* * Search for a solution with `search`: | |
* * - Fill an empty square with a guessed digit and do constraint propagation. | |
* * - If the guess is consistent, search deeper; if not, try a different guess for the square. | |
* * - If all guesses fail, back up to the previous level. | |
* * - In selecting an empty square, we pick one that has the minimum number of possible digits. | |
* * - To be able to back up, we need to keep the grid from the previous recursive level. | |
* * But we only need to keep one grid for each level, so to save garbage collection, | |
* * we pre-allocate one grid per level (there are 81 levels) in a `gridpool`. | |
* * Do constraint propagation with `arcConsistent`, `dualConsistent`, and `nakedPairs`. | |
**/ | |
public class Sudoku {

    //////////////////////////////// main; command line options //////////////////////////////

    /** Usage message, printed for -h and after an unrecognized option. */
    static final String usage = String.join("\n",
            "usage: java Sudoku -(no)[fghnprstuv] | -[RT]<number> | <filename> ...",
            "E.g., -v turns verify flag on, -nov turns it off. -R and -T require a number. The options:\n",
            " -f(ile) Print summary stats for each file (default on)",
            " -g(rid) Print each puzzle grid and solution grid (default off)",
            " -h(elp) Print this usage message",
            " -n(aked) Run naked pairs (default off)", // must agree with runNakedPairs below
            " -p(uzzle) Print summary stats for each puzzle (default off)",
            " -r(everse) Solve the reverse of each puzzle as well as each puzzle itself (default off)",
            " -s(earch) Run search (default on, but some puzzles can be solved with CSP methods alone)",
            " -t(hread) Print summary stats for each thread (default off)",
            " -u(nitTest)Run a suite of unit tests (default off)",
            " -v(erify) Verify each solution is valid (default on)",
            " -T<number> Concurrently run <number> threads (default 26)",
            " -R<number> Repeat each puzzle <number> times (default 1)",
            " <filename> Solve all puzzles in filename, which has one puzzle per line");

    boolean printFileStats = true;    // -f
    boolean printGrid = false;        // -g
    boolean runNakedPairs = false;    // -n
    boolean printPuzzleStats = false; // -p
    boolean reversePuzzle = false;    // -r
    boolean runSearch = true;         // -s
    boolean printThreadStats = false; // -t
    boolean verifySolution = true;    // -v
    int nThreads = 26;                // -T
    int repeat = 1;                   // -R

    // Count of total backtracks; printStats reports it and resets it to 0.
    // NOTE: with more than one solver thread, `search` increments this field without
    // synchronization, so the reported counts are approximate. The field only feeds statistics.
    int backtracks = 0;

    /**
     * Parse command line args and solve puzzles in files.
     **/
    public static void main(String[] args) throws IOException {
        mainWith(new Sudoku(), args);
    }

    /**
     * Apply each command-line argument to {@code s}, in order.
     * An option ("-x" or "-nox") sets the corresponding field. Any other argument is a filename,
     * and its puzzles are solved immediately with the options seen so far.
     *
     * @param s    the solver whose options are updated
     * @param args command-line arguments
     * @throws IOException if a named puzzle file cannot be read
     **/
    public static void mainWith(Sudoku s, String[] args) throws IOException {
        for (String arg : args) {
            if (!arg.startsWith("-")) {
                s.solveFile(arg);
                continue;
            }
            boolean value = !arg.startsWith("-no");
            int optionIndex = value ? 1 : 3;
            // A bare "-" or "-no" has no option letter; report it instead of throwing
            // StringIndexOutOfBoundsException.
            if (arg.length() <= optionIndex) {
                System.out.println("Unrecognized option: " + arg + "\n" + usage);
                continue;
            }
            try {
                switch (arg.charAt(optionIndex)) {
                    case 'f':
                        s.printFileStats = value;
                        break;
                    case 'g':
                        s.printGrid = value;
                        break;
                    case 'h':
                        System.out.println(usage);
                        break;
                    case 'n':
                        s.runNakedPairs = value;
                        break;
                    case 'p':
                        s.printPuzzleStats = value;
                        break;
                    case 'r':
                        s.reversePuzzle = value;
                        break;
                    case 's':
                        s.runSearch = value;
                        break;
                    case 't':
                        s.printThreadStats = value;
                        break;
                    case 'u':
                        s.runUnitTests();
                        break;
                    case 'v':
                        s.verifySolution = value;
                        break;
                    case 'T':
                        s.nThreads = Integer.parseInt(arg.substring(2));
                        break;
                    case 'R':
                        s.repeat = Integer.parseInt(arg.substring(2));
                        break;
                    default:
                        System.out.println("Unrecognized option: " + arg + "\n" + usage);
                }
            } catch (NumberFormatException e) {
                // -T and -R need an integer immediately after the letter, e.g. -T4.
                System.out.println("Option " + arg + " requires a number.\n" + usage);
            }
        }
    }

    //////////////////////////////// Handling Lists of Puzzles ////////////////////////////////

    /**
     * Solve all the puzzles in a file. Report timing statistics.
     **/
    public void solveFile(String filename) throws IOException {
        List<int[]> grids = readFile(filename);
        long startFileTime = System.nanoTime();
        // Fewer than 2 threads (including a nonsensical -T0 or a negative -T) runs in this thread.
        // Otherwise solveListThreaded would divide by zero or reject a negative latch count.
        if (nThreads <= 1) {
            solveList(grids);
        } else {
            solveListThreaded(grids, nThreads);
        }
        if (printFileStats) printStats(grids.size() * repeat, startFileTime, filename);
    }

    /**
     * Solve a list of puzzles in a single thread.
     * Repeat each puzzle -R<number> times; print each puzzle's stats if -p; print the grids if -g;
     * verify each solution if -v. A failed verification prints the grids even without -g.
     **/
    void solveList(List<int[]> grids) {
        int[] puzzle = new int[N * N]; // Used to save a copy of the original grid
        int[][] gridpool = new int[N * N][N * N]; // Reuse grids during the search
        for (int g = 0; g < grids.size(); ++g) {
            int[] grid = grids.get(g);
            System.arraycopy(grid, 0, puzzle, 0, grid.length);
            for (int i = 0; i < repeat; ++i) {
                long startTime = printPuzzleStats ? System.nanoTime() : 0;
                int[] solution = initialize(grid); // All the real work is
                if (runSearch) solution = search(solution, gridpool, 0); // on these 2 lines.
                if (printPuzzleStats) {
                    printStats(1, startTime, rowString(grid, 0).replaceAll("[ |]", ""));
                }
                if (i == 0 && (printGrid || (verifySolution && !verify(solution, puzzle)))) {
                    printGrids("Puzzle " + (g + 1), grid, solution);
                }
            }
        }
    }

    /**
     * Break a list of puzzles into nThreads sublists and solve each sublist in a separate thread.
     * Blocks until every worker has finished.
     * The last sublist also takes the remainder when the list does not divide evenly.
     *
     * @throws IllegalArgumentException if nThreads is not positive
     **/
    void solveListThreaded(List<int[]> grids, int nThreads) {
        if (nThreads < 1) {
            throw new IllegalArgumentException("nThreads must be positive: " + nThreads);
        }
        int nGrids = grids.size();
        final CountDownLatch latch = new CountDownLatch(nThreads);
        int size = nGrids / nThreads;
        for (int c = 0; c < nThreads; ++c) {
            int end = c == nThreads - 1 ? nGrids : (c + 1) * size;
            final List<int[]> sublist = grids.subList(c * size, end);
            new Thread(() -> {
                final long startTime = System.nanoTime();
                try {
                    solveList(sublist);
                } finally {
                    // Always release the waiting caller, even if solveList throws;
                    // otherwise latch.await() below would block forever.
                    latch.countDown();
                }
                if (printThreadStats) {
                    printStats(repeat * sublist.size(), startTime, Thread.currentThread().getName());
                }
            }).start();
        }
        try {
            latch.await(); // Wait for all threads to finish
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt(); // Preserve the interrupt status for our caller
            System.err.println("And you may ask yourself, 'Well, how did I get here?'");
        }
    }

    //////////////////////////////// Utility functions ////////////////////////////////

    /**
     * Return an array of all squares in the intersection of these rows and cols,
     * ordered row by row.
     **/
    int[] cross(int[] rows, int[] cols) {
        int[] result = new int[rows.length * cols.length];
        int i = 0;
        for (int r : rows) {
            for (int c : cols) {
                result[i++] = N * r + c;
            }
        }
        return result;
    }

    /**
     * Return true iff item is an element of array, or of array[0:end].
     **/
    boolean member(int item, int[] array) {
        return member(item, array, array.length);
    }

    boolean member(int item, int[] array, int end) {
        for (int i = 0; i < end; ++i) {
            if (array[i] == item) {
                return true;
            }
        }
        return false;
    }

    //////////////////////////////// Constants ////////////////////////////////

    final int N = 9; // Number of cells on a side of grid.
    // DIGITS[k] is the single-bit pattern for digit k + 1.
    final int[] DIGITS = {1 << 0, 1 << 1, 1 << 2, 1 << 3, 1 << 4, 1 << 5, 1 << 6, 1 << 7, 1 << 8};
    // All nine digit bits set: "any digit is possible".
    final int ALL_DIGITS = Integer.parseInt("111111111", 2);
    final int[] ROWS = IntStream.range(0, N).toArray();
    final int[] COLS = ROWS;
    final int[] SQUARES = IntStream.range(0, N * N).toArray();
    final int[][] BLOCKS = {{0, 1, 2}, {3, 4, 5}, {6, 7, 8}};
    final int[][] ALL_UNITS = new int[3 * N][];      // 9 rows, then 9 columns, then 9 blocks
    final int[][][] UNITS = new int[N * N][3][N];    // UNITS[s]: the row, column and block of s
    final int[][] PEERS = new int[N * N][20];        // PEERS[s]: squares sharing a unit with s
    final int[] NUM_DIGITS = new int[ALL_DIGITS + 1];    // popcount lookup table
    final int[] HIGHEST_DIGIT = new int[ALL_DIGITS + 1]; // highest-set-bit lookup table

    {
        // Initialize ALL_UNITS to be an array of the 27 units: rows, columns, and blocks
        int i = 0;
        for (int r : ROWS) {
            ALL_UNITS[i++] = cross(new int[]{r}, COLS);
        }
        for (int c : COLS) {
            ALL_UNITS[i++] = cross(ROWS, new int[]{c});
        }
        for (int[] rb : BLOCKS) {
            for (int[] cb : BLOCKS) {
                ALL_UNITS[i++] = cross(rb, cb);
            }
        }
        // Initialize each UNITS[s] to be an array of the 3 units for square s.
        for (int s : SQUARES) {
            i = 0;
            for (int[] u : ALL_UNITS) {
                if (member(s, u)) UNITS[s][i++] = u;
            }
        }
        // Initialize each PEERS[s] to be an array of the 20 squares that are peers of square s.
        for (int s : SQUARES) {
            i = 0;
            for (int[] u : UNITS[s]) {
                for (int s2 : u) {
                    if (s2 != s && !member(s2, PEERS[s], i)) {
                        PEERS[s][i++] = s2;
                    }
                }
            }
        }
        // Initialize NUM_DIGITS[val] to be the number of 1 bits in the bitset val
        // and HIGHEST_DIGIT[val] to the highest bit set in the bitset val
        for (int val = 0; val <= ALL_DIGITS; val++) {
            NUM_DIGITS[val] = Integer.bitCount(val);
            HIGHEST_DIGIT[val] = Integer.highestOneBit(val);
        }
    }

    //////////////////////////////// Search algorithm ////////////////////////////////

    /**
     * Search for a solution to grid. If there is an unfilled square, select one
     * and try--that is, search recursively--every possible digit for the square.
     * Return the solved grid, or null if no assignment works. Each failed digit counts
     * as one backtrack. {@code gridpool[level]} is overwritten as scratch space.
     **/
    int[] search(int[] grid, int[][] gridpool, int level) {
        if (grid == null) {
            return null;
        }
        int s = select_square(grid);
        if (s == -1) {
            return grid; // No squares to select means we are done!
        }
        for (int d : DIGITS) {
            // For each possible digit d that could fill square s, try it
            if ((d & grid[s]) > 0) {
                // Copy grid's contents into gridpool[level], and use that at the next level
                System.arraycopy(grid, 0, gridpool[level], 0, grid.length);
                int[] result = search(fill(gridpool[level], s, d), gridpool, level + 1);
                if (result != null) {
                    return result;
                }
                backtracks += 1;
            }
        }
        return null;
    }

    /**
     * Verify that grid is a solution to the puzzle.
     * Every square must hold exactly one digit and keep the puzzle's clue,
     * and every unit must contain all nine digits. A null grid is never valid.
     **/
    boolean verify(int[] grid, int[] puzzle) {
        if (grid == null) {
            return false;
        }
        // Check that all squares have a single digit, and
        // no filled square in the puzzle was changed in the solution.
        for (int s : SQUARES) {
            if ((NUM_DIGITS[grid[s]] != 1) || (NUM_DIGITS[puzzle[s]] == 1 && grid[s] != puzzle[s])) {
                return false;
            }
        }
        // Check that each unit is a permutation of digits
        for (int[] u : ALL_UNITS) {
            int unit_digits = 0; // All the digits in a unit.
            for (int s : u) {
                unit_digits |= grid[s];
            }
            if (unit_digits != ALL_DIGITS) {
                return false;
            }
        }
        return true;
    }

    /**
     * Choose an unfilled square with the minimum number of possible values
     * (the first such square in row-major order).
     * If all squares are filled, return -1 (which means the puzzle is complete).
     **/
    int select_square(int[] grid) {
        int square = -1;
        int min = N + 1;
        for (int s : SQUARES) {
            int c = NUM_DIGITS[grid[s]];
            if (c == 2) {
                return s; // Can't get fewer than 2 possible digits
            } else if (c > 1 && c < min) {
                square = s;
                min = c;
            }
        }
        return square;
    }

    /**
     * Fill grid[s] = d and eliminate d from every peer of s. Mutates and returns grid.
     * Return null if grid is null, d is not possible for s, or propagation finds a contradiction.
     **/
    int[] fill(int[] grid, int s, int d) {
        if ((grid == null) || ((grid[s] & d) == 0)) {
            return null; // d not possible for grid[s]
        }
        grid[s] = d;
        for (int p : PEERS[s]) {
            if (!eliminate(grid, p, d)) { // If we can't eliminate d from all peers of s, then fail
                return null;
            }
        }
        return grid;
    }

    /**
     * Eliminate digit d as a possibility for grid[s].
     * Run the 3 constraint propagation routines.
     * If constraint propagation detects a contradiction, return false.
     **/
    boolean eliminate(int[] grid, int s, int d) {
        if ((grid[s] & d) == 0) {
            return true; // d already eliminated from grid[s]
        }
        grid[s] -= d;
        return arc_consistent(grid, s) && dual_consistent(grid, s, d) && naked_pairs(grid, s);
    }

    //////////////////////////////// Constraint Propagation ////////////////////////////////

    /**
     * Check if square s is consistent: that is, it has multiple possible values, or it has
     * one possible value which we can consistently fill. Zero possible values is inconsistent.
     **/
    boolean arc_consistent(int[] grid, int s) {
        int count = NUM_DIGITS[grid[s]];
        return count >= 2 || (count == 1 && (fill(grid, s, grid[s]) != null));
    }

    /**
     * After we eliminate d from possibilities for grid[s], check each unit of s
     * and make sure there is some position in the unit where d can go.
     * If there is only one possible place for d, fill it with d.
     **/
    boolean dual_consistent(int[] grid, int s, int d) {
        for (int[] u : UNITS[s]) {
            int dPlaces = 0; // The number of possible places for d within unit u
            int dplace = -1; // Try to find a place in the unit where d can go
            for (int s2 : u) {
                if ((grid[s2] & d) > 0) { // s2 is a possible place for d
                    dPlaces++;
                    if (dPlaces > 1) break;
                    dplace = s2;
                }
            }
            if (dPlaces == 0 || (dPlaces == 1 && (fill(grid, dplace, d) == null))) {
                return false;
            }
        }
        return true;
    }

    /**
     * Look for two squares in a unit with the same two possible values, and no other values.
     * For example, if s and s2 both have the possible values 8|9, then we know that 8 and 9
     * must go in those two squares. We don't know which is which, but we can eliminate
     * 8 and 9 from EVERY other square s3 that is in the unit.
     * Return false if any of those eliminations leads to a contradiction.
     **/
    boolean naked_pairs(int[] grid, int s) {
        if (!runNakedPairs) {
            return true;
        }
        int val = grid[s];
        if (NUM_DIGITS[val] != 2) {
            return true; // Doesn't apply
        }
        int d = HIGHEST_DIGIT[val]; // The two digits of the pair
        int d2 = val - d;
        for (int s2 : PEERS[s]) {
            if (grid[s2] == val) {
                // s and s2 are a naked pair; find what unit(s) they share
                for (int[] u : UNITS[s]) {
                    if (member(s2, u)) {
                        for (int s3 : u) { // s3 can't have either of the values in val (e.g. 8|9)
                            // Keep going through the whole unit; only stop early on a contradiction.
                            // A plain `return` here would skip every s3 after the first.
                            if (s3 != s && s3 != s2
                                    && !(eliminate(grid, s3, d) && eliminate(grid, s3, d2))) {
                                return false;
                            }
                        }
                    }
                }
            }
        }
        return true;
    }

    //////////////////////////////// Input ////////////////////////////////

    /**
     * The method `readFile` reads one puzzle per file line and returns a List of puzzle grids.
     * Blank lines are skipped. With -r, each puzzle is followed by its reversal.
     **/
    List<int[]> readFile(String filename) throws IOException {
        try (BufferedReader in = new BufferedReader(new FileReader(filename))) {
            List<int[]> grids = new ArrayList<>(1000);
            String gridstring;
            while ((gridstring = in.readLine()) != null) {
                // A blank line (e.g. an extra empty line at the end of the file) is not a puzzle.
                // Parsing it would trip the assert in parseGrid, or yield an all-zero grid.
                if (gridstring.trim().isEmpty()) {
                    continue;
                }
                grids.add(parseGrid(gridstring));
                if (reversePuzzle) {
                    grids.add(parseGrid(new StringBuilder(gridstring).reverse().toString()));
                }
            }
            return grids;
        }
    }

    /**
     * Parse a gridstring into a puzzle grid: an int[] with values DIGITS[0-8] or ALL_DIGITS.
     * Characters other than '0'-'9' and '.' are ignored.
     **/
    int[] parseGrid(String gridstring) {
        int[] grid = new int[N * N];
        int s = 0;
        for (int i = 0; i < gridstring.length(); ++i) {
            char c = gridstring.charAt(i);
            if ('1' <= c && c <= '9') {
                grid[s++] = DIGITS[c - '1']; // A single-bit set to represent a digit
            } else if (c == '0' || c == '.') {
                grid[s++] = ALL_DIGITS; // Any digit is possible
            }
        }
        assert s == N * N;
        return grid;
    }

    /**
     * Initialize a grid from a puzzle.
     * First initialize every square in the new grid to ALL_DIGITS, meaning any value is possible.
     * Then, call `fill` on the puzzle's filled squares to initiate constraint propagation.
     * Return null if the clues are contradictory, since the puzzle then has no solution.
     **/
    int[] initialize(int[] puzzle) {
        int[] grid = new int[N * N];
        Arrays.fill(grid, ALL_DIGITS);
        for (int s : SQUARES) {
            // Don't ignore a failed fill: continuing would hand a half-propagated,
            // inconsistent grid to the search.
            if (puzzle[s] != ALL_DIGITS && fill(grid, s, puzzle[s]) == null) {
                return null;
            }
        }
        return grid;
    }

    //////////////////////////////// Output and Tests ////////////////////////////////

    boolean headerPrinted = false; // Guarded by `this`

    /**
     * Print stats on puzzles solved, average time, frequency, threads used, and name.
     * Resets the backtrack counter after reporting it.
     **/
    void printStats(int nGrids, long startTime, String name) {
        double usecs = (System.nanoTime() - startTime) / 1000.;
        String line = String.format("%7d %6.1f %7.3f %7d %10.1f %s",
                nGrids, usecs / nGrids, 1000 * nGrids / usecs, nThreads, backtracks * 1. / nGrids, name);
        synchronized (this) { // So that printing from different threads doesn't get garbled
            if (!headerPrinted) {
                System.out.println("Puzzles μsec KHz Threads Backtracks Name\n"
                        + "======= ====== ======= ======= ========== ====");
                headerPrinted = true;
            }
            System.out.println(line);
            backtracks = 0;
        }
    }

    /**
     * Print the original puzzle grid and the solution grid side by side.
     * A null solution is shown as an all-'?' grid labelled FAILED.
     **/
    void printGrids(String name, int[] puzzle, int[] solution) {
        String bar = "------+-------+------";
        String gap = " "; // Space between the puzzle grid and solution grid
        if (solution == null) solution = new int[N * N];
        synchronized (this) { // So that printing from different threads doesn't get garbled
            System.out.format("\n%-22s%s%s\n", name + ":", gap,
                    (verify(solution, puzzle) ? "Solution:" : "FAILED:"));
            for (int r = 0; r < N; ++r) {
                System.out.println(rowString(puzzle, r) + gap + rowString(solution, r));
                if (r == 2 || r == 5) System.out.println(bar + gap + " " + bar);
            }
        }
    }

    /**
     * Return a String representing row r of this grid:
     * '.' for an unconstrained square, '?' for a partially constrained one, else the digit.
     **/
    String rowString(int[] grid, int r) {
        StringBuilder row = new StringBuilder();
        for (int s = r * N; s < (r + 1) * N; ++s) {
            row.append((char) ((NUM_DIGITS[grid[s]] == 9) ? '.' : (NUM_DIGITS[grid[s]] != 1) ? '?' :
                    ('1' + Integer.numberOfTrailingZeros(grid[s])))).append(s % 9 == 2 || s % 9 == 5 ? " | " : " ");
        }
        return row.toString();
    }

    /**
     * Unit Tests. Just getting started with these.
     * The checks are Java `assert`s, so they only run under `java -ea`.
     * Without it, this just prints the success message.
     **/
    void runUnitTests() {
        assert N == 9;
        assert SQUARES.length == 81;
        for (int s : SQUARES) {
            assert UNITS[s].length == 3;
            assert PEERS[s].length == 20;
        }
        assert Arrays.equals(PEERS[19],
                new int[]{18, 20, 21, 22, 23, 24, 25, 26, 1, 10, 28, 37, 46, 55, 64, 73, 0, 2, 9, 11});
        assert Arrays.deepToString(UNITS[19]).equals(
                "[[18, 19, 20, 21, 22, 23, 24, 25, 26], [1, 10, 19, 28, 37, 46, 55, 64, 73], [0, 1, 2, 9, 10, 11, 18, 19, 20]]");
        System.out.println("Unit tests pass.");
    }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package org.teckhooi | |
import java.util | |
import java.util.concurrent.CountDownLatch | |
import scala.collection.mutable | |
import scala.collection.mutable.ArrayBuffer | |
import scala.io.Source | |
import scala.util.Using | |
object SudokuPort {

  /** Usage message, printed for -h and after an unrecognized option. */
  val usage: String = List(
    "usage: java Sudoku -(no)[fghnprstuv] | -[RT]<number> | <filename> ...",
    "E.g., -v turns verify flag on, -nov turns it off. -R and -T require a number. The options:\n",
    " -f(ile) Print summary stats for each file (default on)",
    " -g(rid) Print each puzzle grid and solution grid (default off)",
    " -h(elp) Print this usage message",
    " -n(aked) Run naked pairs (default off)", // must agree with runNakedPairs below
    " -p(uzzle) Print summary stats for each puzzle (default off)",
    " -r(everse) Solve the reverse of each puzzle as well as each puzzle itself (default off)",
    " -s(earch) Run search (default on, but some puzzles can be solved with CSP methods alone)",
    " -t(hread) Print summary stats for each thread (default off)",
    " -u(nitTest)Run a suite of unit tests (default off)",
    " -v(erify) Verify each solution is valid (default on)",
    " -T<number> Concurrently run <number> threads (default 26)",
    " -R<number> Repeat each puzzle <number> times (default 1)",
    " <filename> Solve all puzzles in filename, which has one puzzle per line"
  ).mkString("\n")

  // Default option values used by `main`; command-line flags then override them.
  val printFileStats = true // -f
  val printGrid = false // -g
  val runNakedPairs = false // -n
  val printPuzzleStats = false // -p
  val reversePuzzle = false // -r
  val runSearch = true // -s
  val printThreadStats = false // -t
  val verifySolution = true // -v
  val nThreads = 26 // -T
  val repeat = 1 // -R

  /** Parse command line args and solve puzzles in files. */
  def main(args: Array[String]): Unit =
    mainWith(
      new SudokuPort(
        printFileStats,
        printGrid,
        runNakedPairs,
        printPuzzleStats,
        reversePuzzle,
        runSearch,
        printThreadStats,
        verifySolution,
        nThreads,
        repeat
      ),
      args
    )

  /** Apply each command-line argument to `s`, in order.
    * An option ("-x" or "-nox") updates the corresponding field. Any other argument is a
    * filename, and its puzzles are solved immediately with the options seen so far.
    */
  def mainWith(s: SudokuPort, args: Array[String]): Unit =
    args.foreach { arg =>
      if (!arg.startsWith("-")) {
        s.solveFile(arg)
      } else {
        val value = !arg.startsWith("-no")
        val optionIndex = if (value) 1 else 3
        // A bare "-" or "-no" has no option letter; report it instead of throwing.
        if (arg.length <= optionIndex) {
          println(s"Unrecognized option: $arg\n$usage")
        } else {
          arg.charAt(optionIndex) match {
            case 'f' =>
              s.printFileStats = value
            case 'g' =>
              s.printGrid = value
            case 'h' =>
              println(usage)
            case 'n' =>
              s.runNakedPairs = value
            case 'p' =>
              s.printPuzzleStats = value
            case 'r' =>
              s.reversePuzzle = value
            case 's' =>
              s.runSearch = value
            case 't' =>
              s.printThreadStats = value
            case 'u' =>
              // Previously commented out, so -u (listed in the usage) was reported as unrecognized.
              s.runUnitTests()
            case 'v' =>
              s.verifySolution = value
            case 'T' =>
              s.nThreads = Integer.parseInt(arg.substring(2))
            case 'R' =>
              s.repeat = Integer.parseInt(arg.substring(2))
            case _ =>
              println(s"Unrecognized option: $arg\n$usage")
          }
        }
      }
    }
}
class SudokuPort( | |
var printFileStats: Boolean, | |
var printGrid: Boolean, | |
var runNakedPairs: Boolean, | |
var printPuzzleStats: Boolean, | |
var reversePuzzle: Boolean, | |
var runSearch: Boolean, | |
var printThreadStats: Boolean, | |
var verifySolution: Boolean, | |
var nThreads: Int, | |
var repeat: Int | |
) { | |
var backtracks = 0 // count total backtracks | |
/** Sanity-check the precomputed tables.
  * Checks the grid size, that every square has 3 units and 20 peers, and the exact peers and
  * units of square 19. Prints "Unit tests pass." when every assertion holds.
  */
def runUnitTests(): Unit = {
  assert(N == 9)
  assert(SQUARES.length == 81)
  SQUARES.foreach { sq =>
    assert(UNITS(sq).length == 3)
    assert(PEERS(sq).length == 20)
  }
  val expectedPeersOf19 =
    Array[Int](18, 20, 21, 22, 23, 24, 25, 26, 1, 10, 28, 37, 46, 55, 64, 73, 0, 2, 9, 11)
  assert(PEERS(19).sameElements(expectedPeersOf19))
  val renderedUnitsOf19 =
    UNITS(19).map(_.mkString("[", ", ", "]")).mkString("[", ", ", "]")
  assert(
    renderedUnitsOf19 ==
      "[[18, 19, 20, 21, 22, 23, 24, 25, 26], [1, 10, 19, 28, 37, 46, 55, 64, 73], [0, 1, 2, 9, 10, 11, 18, 19, 20]]"
  )
  println("Unit tests pass.")
}
//////////////////////////////// Handling Lists of Puzzles ////////////////////////////////

/** Solve all the puzzles in a file. Report timing statistics.
  */
def solveFile(filename: String): Unit = {
  val grids: Array[Array[Int]] = readFile(filename)
  val startTime: Long = System.nanoTime()
  // Fewer than 2 threads (including a nonsensical -T0 or a negative -T) runs in this thread.
  // Otherwise solveListThreaded would divide by zero or reject a negative latch count.
  if (nThreads <= 1) solveList(grids)
  else solveListThreaded(grids, nThreads)
  if (printFileStats) printStats(grids.length * repeat, startTime, filename)
}
/** Solve a list of puzzles in a single thread.
  * Repeat each puzzle -R<number> times; print each puzzle's stats if -p; print the grids if -g;
  * verify each solution if -v. A failed verification prints the grids even without -g.
  */
def solveList(grids: Array[Array[Int]]): Unit = {
  // Scratch grids for `search`, one per recursion level, reused across puzzles.
  val gridpool: Array[Array[Int]] = Array.ofDim(N * N, N * N)
  grids.zipWithIndex.foreach { case (grid, ndx) =>
    // Snapshot of the original clues, used for verification as in the Java version.
    // The check then does not depend on whether a solver step touches `grid`.
    // (Previously the snapshot was taken but never used.)
    val puzzle: Array[Int] = Array.copyOf(grid, grid.length)
    (0 until repeat).foreach { g =>
      val startTime: Long = if (printPuzzleStats) System.nanoTime() else 0
      var solution: Array[Int] = initialize(grid) // All the real work is
      if (runSearch) solution = search(solution, gridpool, 0) // on these 2 lines.
      if (printPuzzleStats) {
        printStats(1, startTime, rowString(grid, 0).replaceAll("[ |]", ""))
      }
      if (g == 0 && (printGrid || (verifySolution && !verify(solution, puzzle)))) {
        printGrids("Puzzle " + (ndx + 1), grid, solution)
      }
    }
  }
}
/** Break a list of puzzles into nThreads sublists and solve each sublist in a separate thread.
  * Blocks until every worker has finished. The last sublist also takes the remainder when the
  * list does not divide evenly. Throws IllegalArgumentException if nThreads is not positive.
  */
def solveListThreaded(grids: Array[Array[Int]], nThreads: Int): Unit = {
  require(nThreads > 0, s"nThreads must be positive: $nThreads")
  val nGrids: Int = grids.length
  val latch: CountDownLatch = new CountDownLatch(nThreads)
  val size: Int = nGrids / nThreads
  (0 until nThreads).foreach { c =>
    val end = if (c == nThreads - 1) nGrids else (c + 1) * size
    val sublist: Array[Array[Int]] = grids.slice(c * size, end)
    new Thread(() => {
      val startTime: Long = System.nanoTime
      try {
        solveList(sublist)
      } finally {
        // Always release the waiting caller, even if solveList throws;
        // otherwise latch.await() below would block forever.
        latch.countDown()
      }
      if (printThreadStats) {
        printStats(repeat * sublist.length, startTime, c.toString)
      }
    }).start()
  }
  latch.await() // Wait for all threads to finish
}
//////////////////////////////// Utility functions ////////////////////////////////

/** Return the square indices (N * row + col) at the intersection of the given rows and cols.
  * The result is ordered row by row, and within a row by column.
  */
def cross(rows: Array[Int], cols: Array[Int]): Array[Int] =
  for {
    row <- rows
    col <- cols
  } yield N * row + col
/** Return true iff item is an element of array (first overload), or of the prefix
  * array(0 until end) (second overload).
  */
def member(item: Int, array: Array[Int]): Boolean =
  member(item, array, array.length)
def member(item: Int, array: Array[Int], end: Int): Boolean = {
  var idx = 0
  var found = false
  while (!found && idx < end) {
    found = array(idx) == item
    idx += 1
  }
  found
}
//////////////////////////////// Constants ////////////////////////////////
val N: Int = 9 // Number of cells on a side of grid.
// DIGITS(k) is the single-bit pattern representing digit k + 1.
val DIGITS: Array[Int] =
  Array(1 << 0, 1 << 1, 1 << 2, 1 << 3, 1 << 4, 1 << 5, 1 << 6, 1 << 7, 1 << 8)
// All nine digit bits set (binary 111111111): "any digit is possible".
val ALL_DIGITS: Int = Integer.parseInt("111111111", 2)
val ROWS: Array[Int] = (0 until N).toArray
val COLS: Array[Int] = ROWS
val SQUARES: Array[Int] = (0 until N * N).toArray
// Row (or column) indices of each 3-wide band of blocks.
val BLOCKS: Array[Array[Int]] = Array(Array(0, 1, 2), Array(3, 4, 5), Array(6, 7, 8))
// All 27 units, filled below: 9 rows, then 9 columns, then 9 blocks.
// NOTE(review): never reassigned after construction; could be a `val`.
var ALL_UNITS: Array[Array[Int]] = Array.ofDim(3 * N, 9)
// UNITS(s): the 3 units containing square s, in ALL_UNITS order (row, column, block).
val UNITS: Array[Array[Array[Int]]] = Array.ofDim(N * N, 3, N)
// PEERS(s): the 20 distinct squares other than s that share a unit with s.
val PEERS: Array[Array[Int]] = Array.ofDim(N * N, 20)
// Lookup tables indexed by a candidate bitset: number of set bits, and highest set bit.
val NUM_DIGITS: Array[Int] = Array.ofDim(ALL_DIGITS + 1)
val HIGHEST_DIGIT: Array[Int] = Array.ofDim(ALL_DIGITS + 1)
// Initialize ALL_UNITS to be an array of the 27 units: rows, columns, and blocks.
// NOTE(review): being in the class body, this `var i` is a class member, not a local.
// It only serves as the write cursor into ALL_UNITS here; the later loops declare their own `i`.
var i = 0
for (r <- ROWS) {
  ALL_UNITS(i) = cross(Array(r), COLS)
  i = i + 1
}
for (c <- COLS) {
  ALL_UNITS(i) = cross(ROWS, Array(c))
  i = i + 1
}
for {
  rb <- BLOCKS
  cb <- BLOCKS
} {
  ALL_UNITS(i) = cross(rb, cb)
  i = i + 1
}
// Initialize each UNITS[s] to be an array of the 3 units for square s.
// (Scala 3 braceless syntax: the indented lines form the body of the `for`.)
for (s <- SQUARES)
  var i = 0
  for (u <- ALL_UNITS)
    if (member(s, u)) {
      UNITS(s)(i) = u
      i = i + 1
    }
// Initialize each PEERS[s] to be an array of the 20 squares that are peers of square s.
// `member(s2, PEERS(s), i)` looks only at the i entries filled so far, so a square shared by
// two units (e.g. a row-mate inside the same block) is recorded once.
for (s <- SQUARES) {
  var i = 0
  for (u <- UNITS(s))
    for (s2 <- u)
      if (s2 != s && !member(s2, PEERS(s), i)) {
        PEERS(s)(i) = s2
        i = i + 1
      }
}
// Initialize NUM_DIGITS(v) to be the number of 1 bits in the bitset v
// and HIGHEST_DIGIT(v) to the highest bit set in the bitset v, for every v in 0..ALL_DIGITS.
(0 to ALL_DIGITS).foreach { i =>
  NUM_DIGITS(i) = Integer.bitCount(i)
  HIGHEST_DIGIT(i) = Integer.highestOneBit(i)
}
//////////////////////////////// Search algorithm ////////////////////////////////

/** Search for a solution to grid: pick an unfilled square, then try each of its candidate
  * digits in turn, recursing on a filled copy.
  * Returns the solved grid, or null if no digit works (or grid is null). Every digit that fails
  * adds one to `backtracks`. gridpool(level) is overwritten as scratch space for this level.
  */
def search(grid: Array[Int], gridpool: Array[Array[Int]], level: Int): Array[Int] = {
  if (grid == null) return null
  val square = select_square(grid)
  if (square == -1) return grid // No squares to select means we are done!
  var answer: Array[Int] = null
  var k = 0
  while (answer == null && k < DIGITS.length) {
    val digit = DIGITS(k)
    if ((digit & grid(square)) > 0) {
      // Copy grid's contents into gridpool[level], and use that at the next level
      System.arraycopy(grid, 0, gridpool(level), 0, grid.length)
      answer = search(fill(gridpool(level), square, digit), gridpool, level + 1)
      if (answer == null) backtracks += 1
    }
    k += 1
  }
  answer
}
/** Verify that grid is a solution to the puzzle. Three things must hold:
  *   - every square holds exactly one digit,
  *   - every clue (single-digit square) of the puzzle is unchanged,
  *   - every unit contains all nine digits.
  * A null grid (no solution found) is never valid.
  */
def verify(grid: Array[Int], puzzle: Array[Int]): Boolean =
  if (grid == null) false
  else {
    // Both conditions must hold for every square (hence `&&`): a filled square whose digit
    // differs from the puzzle's clue is invalid. The previous `||` accepted any filled square.
    val squaresOk = SQUARES.forall { s =>
      NUM_DIGITS(grid(s)) == 1 && (NUM_DIGITS(puzzle(s)) != 1 || grid(s) == puzzle(s))
    }
    // Each unit must be a permutation of the digits (OR of its squares is ALL_DIGITS).
    squaresOk && ALL_UNITS.forall(u => u.foldLeft(0)((acc, s) => acc | grid(s)) == ALL_DIGITS)
  }
/** Choose an unfilled square with the minimum number of possible values.
  * Ties go to the first such square in SQUARES order, matching the Java version.
  * Returns -1 if all squares are filled (the puzzle is complete).
  */
def select_square(grid: Array[Int]): Int = {
  var square = -1
  var min = N + 1
  var k = 0
  // Stop at the first two-candidate square: no unfilled square can have fewer. The previous
  // fold kept overwriting on every later 2-candidate square (returning the last one) and always
  // scanned all 81 squares, so it searched a different tree than the Java version.
  while (min > 2 && k < SQUARES.length) {
    val s = SQUARES(k)
    val c = NUM_DIGITS(grid(s))
    if (c > 1 && c < min) {
      square = s
      min = c
    }
    k += 1
  }
  square
}
/** Fill grid(s) = d, then eliminate d from every peer of s.
  * Mutates and returns grid. Returns null if grid is null, d is not a candidate for s, or an
  * elimination hits a contradiction; no further peers are processed after a failure.
  */
def fill(grid: Array[Int], s: Int, d: Int): Array[Int] = {
  if (grid == null || (grid(s) & d) == 0) return null // d not possible for grid[s]
  grid(s) = d
  val peers = PEERS(s)
  var ok = true
  var k = 0
  while (ok && k < peers.length) {
    ok = eliminate(grid, peers(k), d)
    k += 1
  }
  if (ok) grid else null
}
/** Eliminate digit d as a possibility for grid(s), then run the three constraint-propagation
  * checks in order (arc, dual, naked pairs), stopping at the first failure.
  * Returns false on a contradiction, and true immediately if d was already absent.
  */
def eliminate(grid: Array[Int], s: Int, d: Int): Boolean = {
  if ((grid(s) & d) == 0) return true // d already eliminated from grid[s]
  grid(s) -= d
  if (!arc_consistent(grid, s)) false
  else if (!dual_consistent(grid, s, d)) false
  else naked_pairs(grid, s)
}
//////////////////////////////// Constraint Propagation ////////////////////////////////

/** Check whether square s is consistent. It is consistent if it still has several candidate
  * digits, or exactly one candidate that can be filled without contradiction.
  * A square with no candidates is inconsistent.
  */
def arc_consistent(grid: Array[Int], s: Int): Boolean =
  NUM_DIGITS(grid(s)) match {
    case 0 => false
    case 1 => fill(grid, s, grid(s)) != null
    case _ => true
  }
/** After d has been eliminated from grid(s), check every unit of s for the remaining places
  * where d can go.
  *
  * If a unit has no such place, this is a contradiction and the result is false. If it has
  * exactly one place, that square is filled with d, and a failed fill is also a contradiction.
  * Units are examined in order and checking stops at the first failure.
  */
def dual_consistent(grid: Array[Int], s: Int, d: Int): Boolean =
  UNITS(s).forall { u =>
    // Count the squares of u that still allow d, stopping once two are found;
    // `where` remembers the most recent such square.
    var places = 0
    var where = -1
    val squares = u.iterator
    while (places < 2 && squares.hasNext) {
      val s2 = squares.next()
      if ((grid(s2) & d) > 0) { // s2 is a possible place for d
        places += 1
        where = s2
      }
    }
    places match {
      case 0 => false
      case 1 => fill(grid, where, d) != null
      case _ => true
    }
  }
/** Apply the naked-pairs rule around square s.
  *
  * If s has exactly two possible values (say 8|9) and a peer s2 has exactly the same two values,
  * then 8 and 9 must occupy s and s2 in some order. Both values can therefore be eliminated from
  * every other square s3 in each unit that s and s2 share. Returns true when the rule is disabled
  * or does not apply. Returns false as soon as an elimination reports a contradiction.
  */
def naked_pairs(grid: Array[Int], s: Int): Boolean =
  if (!runNakedPairs) true
  else {
    val pair = grid(s) // the candidate set of s, captured once on entry
    if (NUM_DIGITS(pair) != 2) true // Doesn't apply
    else
      PEERS(s).forall { s2 =>
        grid(s2) != pair || // not a naked partner of s
        UNITS(s).forall { u =>
          !member(s2, u) || // only the unit(s) that s and s2 share
          u.forall { s3 =>
            s3 == s || s3 == s2 || {
              // s3 can't have either of the values in pair (e.g. 8|9)
              val high = HIGHEST_DIGIT(pair)
              val low = pair - high
              eliminate(grid, s3, high) && eliminate(grid, s3, low)
            }
          }
        }
      }
  }
//////////////////////////////// Input //////////////////////////////// | |
/** Read one puzzle per line of `filename` and return the parsed puzzle grids as an array.
  *
  * When `reversePuzzle` is set, each line also contributes a second grid parsed from the reversed
  * line, placed immediately after the grid for the line itself. The source is closed by `Using`,
  * and any failure is rethrown by `.get`.
  */
def readFile(filename: String): Array[Array[Int]] =
  Using(Source.fromFile(filename)) { source =>
    source
      .getLines()
      .flatMap { gridString =>
        val forward = Iterator.single(parseGrid(gridString))
        if (reversePuzzle) forward ++ Iterator.single(parseGrid(gridString.reverse))
        else forward
      }
      .toArray
  }.get
/** Parse a gridstring into a puzzle grid: an int[] with values DIGITS[0-9] or ALL_DIGITS.
  *
  * Characters '1'-'9' become the corresponding single-bit digit, and '0' or '.' becomes
  * ALL_DIGITS (any digit is possible). Every other character, such as spaces or separators like
  * '|', is skipped WITHOUT consuming a square. The i-th accepted character therefore fills square
  * i. Squares with no corresponding character stay 0. More than N * N accepted characters
  * overruns the grid and throws ArrayIndexOutOfBoundsException.
  */
def parseGrid(gridString: String): Array[Int] = {
  val grid: Array[Int] = Array.ofDim(N * N)
  var s = 0 // next square to fill; advances only on accepted characters
  for (c <- gridString) {
    if ('1' <= c && c <= '9') {
      grid(s) = DIGITS(c - '1') // A single-bit set to represent a digit
      s += 1
    } else if (c == '0' || c == '.') {
      grid(s) = ALL_DIGITS // Any digit is possible
      s += 1
    }
  }
  grid
}
/** Initialize a grid from a puzzle.
  *
  * Every square of the new grid starts as ALL_DIGITS, meaning any value is possible. Then `fill`
  * is called for each square whose puzzle value is not ALL_DIGITS, which starts constraint
  * propagation. The puzzle array itself is not modified.
  */
def initialize(puzzle: Array[Int]): Array[Int] = {
  val grid: Array[Int] = Array.fill(N * N)(ALL_DIGITS)
  for (s <- SQUARES if puzzle(s) != ALL_DIGITS)
    fill(grid, s, puzzle(s)) // NOTE(review): a null (contradiction) result from fill is ignored here
  grid
}
//////////////////////////////// Output and Tests //////////////////////////////// | |
var headerPrinted = false // Set once the stats table header has been printed; updated inside `synchronized`.

/** Print stats on puzzles solved, average time, frequency, threads used, and name.
  *
  * The first call prints the table header. Each call prints one row with these columns: puzzle
  * count, average microseconds per puzzle, puzzles per millisecond (the KHz column), thread
  * count, average backtracks per puzzle, and `name`. `backtracks` is reset to 0 afterwards.
  *
  * @param nGrids    number of puzzles solved in this run
  * @param startTime value of System.nanoTime() when the run started
  * @param name      label printed in the last column
  */
def printStats(nGrids: Int, startTime: Long, name: String): Unit = {
  // Divide by a Double rather than a Float. Long / Float would convert the nanosecond count to
  // Float and lose precision.
  val usecs: Double = (System.nanoTime() - startTime) / 1000.0
  val line: String = String.format(
    "%7d %6.1f %7.3f %7d %10.1f %s",
    nGrids,
    usecs / nGrids,
    1000.0 * nGrids / usecs, // Double arithmetic: Int 1000 * nGrids overflows for large runs
    nThreads,
    backtracks * 1d / nGrids,
    name
  )
  synchronized {
    if (!headerPrinted) {
      println(
        "Puzzles μsec KHz Threads Backtracks Name\n"
          + "======= ====== ======= ======= ========== ===="
      )
      headerPrinted = true
    }
    println(line)
    backtracks = 0
  }
}
/** Print the original puzzle grid and the solution grid side by side.
  *
  * The heading shows "Solution:" when `verify` accepts the solution and "FAILED:" otherwise. A
  * null solution is displayed, and verified, as an all-zero grid. Output is written while holding
  * this object's monitor, so concurrent callers do not interleave their lines.
  */
def printGrids(name: String, puzzle: Array[Int], solution: Array[Int]): Unit = {
  val bar = "------+-------+------"
  val gap = " " // Space between the puzzle grid and solution grid
  val shown: Array[Int] = Option(solution).getOrElse(Array.ofDim[Int](N * N))
  synchronized {
    val verdict = if (verify(shown, puzzle)) "Solution:" else "FAILED:"
    System.out.format("\n%-22s%s%s\n", name + ":", gap, verdict)
    for (r <- 0 until N) {
      println(rowString(puzzle, r) + gap + rowString(shown, r))
      if (r == 2 || r == 5) println(bar + gap + " " + bar)
    }
  }
}
/** Render row r of `grid` as text.
  *
  * Each square is shown as '.' when all nine digits are still possible, as its digit when exactly
  * one is possible, and as '?' otherwise. Each character is followed by " | " after the 3rd and
  * 6th squares of the row, and by a single space after every other square.
  */
def rowString(grid: Array[Int], r: Int): String = {
  val out = new StringBuilder
  for (s <- r * 9 until (r + 1) * 9) {
    val cell = grid(s)
    val symbol: Char = NUM_DIGITS(cell) match {
      case 9 => '.'
      case 1 => ('1' + Integer.numberOfTrailingZeros(cell)).toChar
      case _ => '?'
    }
    out.append(symbol)
    out.append(if (s % 9 == 2 || s % 9 == 5) " | " else " ")
  }
  out.toString
}
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment