Skip to content

Instantly share code, notes, and snippets.

@sshark
Created October 27, 2024 07:16
Show Gist options
  • Save sshark/3144d8e1a16ff5a652697e8ec6617129 to your computer and use it in GitHub Desktop.
Save sshark/3144d8e1a16ff5a652697e8ec6617129 to your computer and use it in GitHub Desktop.
Sudoku Runtime Performance: Java vs. Scala
package pnorvig.ipynb;
import java.io.BufferedReader;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.stream.IntStream;
//////////////////////////////// Solve Sudoku Puzzles ////////////////////////////////
//////////////////////////////// @author Peter Norvig ////////////////////////////////

/**
 * Sudoku solver combining constraint propagation with depth-first search.
 *
 * <p>There are two representations of puzzles that we will use:
 *
 * <ol>
 *   <li>A gridstring is 81 chars, with characters '0' or '.' for blank and '1' to '9' for digits.
 *       Any other character (spaces, '|', ...) is ignored when parsing.
 *   <li>A puzzle grid is an int[81] with a digit d (1-9) represented by the integer
 *       {@code 1 << (d - 1)}; that is, a bit pattern that has a single 1 bit representing the
 *       digit. A blank is represented by the OR of all the digits 1-9, meaning that any digit is
 *       possible. While solving the puzzle, some of these digits are eliminated, leaving fewer
 *       possibilities. The puzzle is solved when every square has only a single possibility.
 * </ol>
 *
 * <p>Search for a solution with {@code search}:
 *
 * <ul>
 *   <li>Fill an empty square with a guessed digit and do constraint propagation.
 *   <li>If the guess is consistent, search deeper; if not, try a different guess for the square.
 *   <li>If all guesses fail, back up to the previous level.
 *   <li>In selecting an empty square, we pick one that has the minimum number of possible digits.
 *   <li>To be able to back up, we need to keep the grid from the previous recursive level. But we
 *       only need to keep one grid for each level, so to save garbage collection, we pre-allocate
 *       one grid per level (there are 81 levels) in a {@code gridpool}.
 * </ul>
 *
 * <p>Constraint propagation is done with {@code arc_consistent}, {@code dual_consistent}, and
 * {@code naked_pairs}.
 */
public class Sudoku {
  //////////////////////////////// main; command line options //////////////////////////////

  static final String usage = String.join("\n",
      "usage: java Sudoku -(no)[fghnprstuv] | -[RT]<number> | <filename> ...",
      "E.g., -v turns verify flag on, -nov turns it off. -R and -T require a number. The options:\n",
      " -f(ile) Print summary stats for each file (default on)",
      " -g(rid) Print each puzzle grid and solution grid (default off)",
      " -h(elp) Print this usage message",
      " -n(aked) Run naked pairs (default off)",
      " -p(uzzle) Print summary stats for each puzzle (default off)",
      " -r(everse) Solve the reverse of each puzzle as well as each puzzle itself (default off)",
      " -s(earch) Run search (default on, but some puzzles can be solved with CSP methods alone)",
      " -t(hread) Print summary stats for each thread (default off)",
      " -u(nitTest)Run a suite of unit tests (default off)",
      " -v(erify) Verify each solution is valid (default on)",
      " -T<number> Concurrently run <number> threads (default 26)",
      " -R<number> Repeat each puzzle <number> times (default 1)",
      " <filename> Solve all puzzles in filename, which has one puzzle per line");

  boolean printFileStats = true; // -f
  boolean printGrid = false; // -g
  boolean runNakedPairs = false; // -n (off by default)
  boolean printPuzzleStats = false; // -p
  boolean reversePuzzle = false; // -r
  boolean runSearch = true; // -s
  boolean printThreadStats = false; // -t
  boolean verifySolution = true; // -v
  int nThreads = 26; // -T
  int repeat = 1; // -R
  // Total backtracks since the last printStats call. It is incremented without synchronization by
  // every worker thread started in solveListThreaded, so multi-threaded counts are approximate.
  int backtracks = 0;

  /**
   * Parse command line args and solve puzzles in files.
   */
  public static void main(String[] args) throws IOException {
    mainWith(new Sudoku(), args);
  }

  /**
   * Apply each command-line argument, in order, to solver {@code s}. Options change flags on
   * {@code s}; any argument not starting with '-' is treated as a puzzle file and solved
   * immediately with the flags set so far.
   *
   * @param s the solver whose settings are updated and which solves the files
   * @param args options and filenames, processed left to right
   * @throws IOException if a puzzle file cannot be read
   * @throws NumberFormatException if -T or -R is not followed by an integer
   */
  public static void mainWith(Sudoku s, String[] args) throws IOException {
    for (String arg : args) {
      if (!arg.startsWith("-")) {
        s.solveFile(arg);
        continue;
      }
      boolean value = !arg.startsWith("-no");
      int optionIndex = value ? 1 : 3; // position of the option letter after "-" or "-no"
      if (arg.length() <= optionIndex) {
        // A bare "-" or "-no" has no option letter; report it rather than throwing.
        System.out.println("Unrecognized option: " + arg + "\n" + usage);
        continue;
      }
      switch (arg.charAt(optionIndex)) {
        case 'f':
          s.printFileStats = value;
          break;
        case 'g':
          s.printGrid = value;
          break;
        case 'h':
          System.out.println(usage);
          break;
        case 'n':
          s.runNakedPairs = value;
          break;
        case 'p':
          s.printPuzzleStats = value;
          break;
        case 'r':
          s.reversePuzzle = value;
          break;
        case 's':
          s.runSearch = value;
          break;
        case 't':
          s.printThreadStats = value;
          break;
        case 'u':
          s.runUnitTests();
          break;
        case 'v':
          s.verifySolution = value;
          break;
        case 'T':
          s.nThreads = Integer.parseInt(arg.substring(2));
          break;
        case 'R':
          s.repeat = Integer.parseInt(arg.substring(2));
          break;
        default:
          System.out.println("Unrecognized option: " + arg + "\n" + usage);
      }
    }
  }

  //////////////////////////////// Handling Lists of Puzzles ////////////////////////////////

  /**
   * Solve all the puzzles in a file, single-threaded when nThreads == 1 and multi-threaded
   * otherwise. Print file timing statistics if -f is on.
   */
  public void solveFile(String filename) throws IOException {
    List<int[]> grids = readFile(filename);
    long startFileTime = System.nanoTime();
    if (nThreads == 1) {
      solveList(grids);
    } else {
      solveListThreaded(grids, nThreads);
    }
    if (printFileStats) printStats(grids.size() * repeat, startFileTime, filename);
  }

  /**
   * Solve a list of puzzles in a single thread.
   *
   * <p>Each puzzle is solved -R times. Its stats are printed if -p is on. Its grids are printed on
   * the first repetition if -g is on, or if -v is on and the solution fails verification.
   */
  void solveList(List<int[]> grids) {
    int[] puzzle = new int[N * N]; // Used to save a copy of the original grid
    int[][] gridpool = new int[N * N][N * N]; // Reuse grids during the search
    for (int g = 0; g < grids.size(); ++g) {
      int[] grid = grids.get(g);
      System.arraycopy(grid, 0, puzzle, 0, grid.length);
      for (int i = 0; i < repeat; ++i) {
        long startTime = printPuzzleStats ? System.nanoTime() : 0;
        int[] solution = initialize(grid); // All the real work is
        if (runSearch) solution = search(solution, gridpool, 0); // on these 2 lines.
        if (printPuzzleStats) {
          printStats(1, startTime, rowString(grid, 0).replaceAll("[ |]", ""));
        }
        if (i == 0 && (printGrid || (verifySolution && !verify(solution, puzzle)))) {
          printGrids("Puzzle " + (g + 1), grid, solution);
        }
      }
    }
  }

  /**
   * Break a list of puzzles into nThreads sublists and solve each sublist in a separate thread.
   * The last thread also takes the remainder when the list does not divide evenly. This method
   * blocks until every worker has finished.
   */
  void solveListThreaded(List<int[]> grids, int nThreads) {
    try {
      int nGrids = grids.size();
      final CountDownLatch latch = new CountDownLatch(nThreads);
      int size = nGrids / nThreads;
      for (int c = 0; c < nThreads; ++c) {
        int end = c == nThreads - 1 ? nGrids : (c + 1) * size;
        final List<int[]> sublist = grids.subList(c * size, end);
        new Thread(() -> {
          final long startTime = System.nanoTime();
          try {
            solveList(sublist);
          } finally {
            // Always release the latch; otherwise a failing worker would hang the caller forever.
            latch.countDown();
          }
          if (printThreadStats) {
            printStats(repeat * sublist.size(), startTime, Thread.currentThread().getName());
          }
        }).start();
      }
      latch.await(); // Wait for all threads to finish
    } catch (InterruptedException e) {
      Thread.currentThread().interrupt(); // Preserve the interrupt for our caller.
      System.err.println("And you may ask yourself, 'Well, how did I get here?'");
    }
  }

  //////////////////////////////// Utility functions ////////////////////////////////

  /**
   * Return an array of all squares in the intersection of these rows and cols, row-major.
   */
  int[] cross(int[] rows, int[] cols) {
    int[] result = new int[rows.length * cols.length];
    int i = 0;
    for (int r : rows) {
      for (int c : cols) {
        result[i++] = N * r + c;
      }
    }
    return result;
  }

  /**
   * Return true iff item is an element of array.
   */
  boolean member(int item, int[] array) {
    return member(item, array, array.length);
  }

  /**
   * Return true iff item is an element of array[0:end].
   */
  boolean member(int item, int[] array, int end) {
    for (int i = 0; i < end; ++i) {
      if (array[i] == item) {
        return true;
      }
    }
    return false;
  }

  //////////////////////////////// Constants ////////////////////////////////

  final int N = 9; // Number of cells on a side of grid.
  final int[] DIGITS = {1 << 0, 1 << 1, 1 << 2, 1 << 3, 1 << 4, 1 << 5, 1 << 6, 1 << 7, 1 << 8};
  final int ALL_DIGITS = Integer.parseInt("111111111", 2);
  final int[] ROWS = IntStream.range(0, N).toArray();
  final int[] COLS = ROWS;
  final int[] SQUARES = IntStream.range(0, N * N).toArray();
  final int[][] BLOCKS = {{0, 1, 2}, {3, 4, 5}, {6, 7, 8}};
  final int[][] ALL_UNITS = new int[3 * N][];
  final int[][][] UNITS = new int[N * N][3][N];
  final int[][] PEERS = new int[N * N][20];
  final int[] NUM_DIGITS = new int[ALL_DIGITS + 1];
  final int[] HIGHEST_DIGIT = new int[ALL_DIGITS + 1];

  {
    // Initialize ALL_UNITS to be an array of the 27 units: rows, columns, and blocks
    int i = 0;
    for (int r : ROWS) {
      ALL_UNITS[i++] = cross(new int[] {r}, COLS);
    }
    for (int c : COLS) {
      ALL_UNITS[i++] = cross(ROWS, new int[] {c});
    }
    for (int[] rb : BLOCKS) {
      for (int[] cb : BLOCKS) {
        ALL_UNITS[i++] = cross(rb, cb);
      }
    }
    // Initialize each UNITS[s] to be an array of the 3 units for square s.
    for (int s : SQUARES) {
      i = 0;
      for (int[] u : ALL_UNITS) {
        if (member(s, u)) UNITS[s][i++] = u;
      }
    }
    // Initialize each PEERS[s] to be an array of the 20 squares that are peers of square s.
    for (int s : SQUARES) {
      i = 0;
      for (int[] u : UNITS[s]) {
        for (int s2 : u) {
          if (s2 != s && !member(s2, PEERS[s], i)) {
            PEERS[s][i++] = s2;
          }
        }
      }
    }
    // Initialize NUM_DIGITS[val] to be the number of 1 bits in the bitset val
    // and HIGHEST_DIGIT[val] to the highest bit set in the bitset val
    for (int val = 0; val <= ALL_DIGITS; val++) {
      NUM_DIGITS[val] = Integer.bitCount(val);
      HIGHEST_DIGIT[val] = Integer.highestOneBit(val);
    }
  }

  //////////////////////////////// Search algorithm ////////////////////////////////

  /**
   * Search for a solution to grid. If there is an unfilled square, select one and try (that is,
   * search recursively) every possible digit for the square.
   *
   * @param grid the current grid, or null if an earlier step hit a contradiction
   * @param gridpool one scratch grid per recursion level; gridpool[level] is overwritten
   * @param level current recursion depth
   * @return a grid with every square filled, or null if none is reachable from grid
   */
  int[] search(int[] grid, int[][] gridpool, int level) {
    if (grid == null) {
      return null;
    }
    int s = select_square(grid);
    if (s == -1) {
      return grid; // No squares to select means we are done!
    }
    for (int d : DIGITS) {
      // For each possible digit d that could fill square s, try it
      if ((d & grid[s]) > 0) {
        // Copy grid's contents into gridpool[level], and use that at the next level
        System.arraycopy(grid, 0, gridpool[level], 0, grid.length);
        int[] result = search(fill(gridpool[level], s, d), gridpool, level + 1);
        if (result != null) {
          return result;
        }
        backtracks += 1;
      }
    }
    return null;
  }

  /**
   * Verify that grid is a solution to the puzzle: every square holds one digit, no given of the
   * puzzle was changed, and every unit contains each digit 1-9. A null grid is never valid.
   */
  boolean verify(int[] grid, int[] puzzle) {
    if (grid == null) {
      return false;
    }
    // Check that all squares have a single digit, and
    // no filled square in the puzzle was changed in the solution.
    for (int s : SQUARES) {
      if ((NUM_DIGITS[grid[s]] != 1) || (NUM_DIGITS[puzzle[s]] == 1 && grid[s] != puzzle[s])) {
        return false;
      }
    }
    // Check that each unit is a permutation of digits
    for (int[] u : ALL_UNITS) {
      int unit_digits = 0; // All the digits in a unit.
      for (int s : u) {
        unit_digits |= grid[s];
      }
      if (unit_digits != ALL_DIGITS) {
        return false;
      }
    }
    return true;
  }

  /**
   * Choose an unfilled square with the minimum number of possible values (the first such square
   * in row-major order). If all squares are filled, return -1 (the puzzle is complete).
   */
  int select_square(int[] grid) {
    int square = -1;
    int min = N + 1;
    for (int s : SQUARES) {
      int c = NUM_DIGITS[grid[s]];
      if (c == 2) {
        return s; // Can't get fewer than 2 possible digits
      } else if (c > 1 && c < min) {
        square = s;
        min = c;
      }
    }
    return square;
  }

  /**
   * Fill grid[s] = d and eliminate d from every peer of s. Mutates grid in place.
   *
   * @return grid, or null if grid is null, d is not possible at s, or propagation contradicts
   */
  int[] fill(int[] grid, int s, int d) {
    if ((grid == null) || ((grid[s] & d) == 0)) {
      return null; // d not possible for grid[s]
    }
    grid[s] = d;
    for (int p : PEERS[s]) {
      if (!eliminate(grid, p, d)) { // If we can't eliminate d from all peers of s, then fail
        return null;
      }
    }
    return grid;
  }

  /**
   * Eliminate digit d as a possibility for grid[s] and run the 3 constraint propagation routines.
   * If constraint propagation detects a contradiction, return false.
   */
  boolean eliminate(int[] grid, int s, int d) {
    if ((grid[s] & d) == 0) {
      return true; // d already eliminated from grid[s]
    }
    grid[s] -= d;
    return arc_consistent(grid, s) && dual_consistent(grid, s, d) && naked_pairs(grid, s);
  }

  //////////////////////////////// Constraint Propagation ////////////////////////////////

  /**
   * Check if square s is consistent: that is, it has multiple possible values, or it has one
   * possible value which we can consistently fill.
   */
  boolean arc_consistent(int[] grid, int s) {
    int count = NUM_DIGITS[grid[s]];
    return count >= 2 || (count == 1 && (fill(grid, s, grid[s]) != null));
  }

  /**
   * After we eliminate d from possibilities for grid[s], check each unit of s and make sure there
   * is some position in the unit where d can go. If there is only one possible place for d, fill
   * it with d.
   */
  boolean dual_consistent(int[] grid, int s, int d) {
    for (int[] u : UNITS[s]) {
      int dPlaces = 0; // The number of possible places for d within unit u
      int dplace = -1; // Try to find a place in the unit where d can go
      for (int s2 : u) {
        if ((grid[s2] & d) > 0) { // s2 is a possible place for d
          dPlaces++;
          if (dPlaces > 1) break;
          dplace = s2;
        }
      }
      if (dPlaces == 0 || (dPlaces == 1 && (fill(grid, dplace, d) == null))) {
        return false;
      }
    }
    return true;
  }

  /**
   * Look for two squares in a unit with the same two possible values, and no other values. For
   * example, if s and s2 both have the possible values 8|9, then we know that 8 and 9 must go in
   * those two squares. We don't know which is which, but we can eliminate 8 and 9 from every other
   * square s3 in each unit the two squares share.
   *
   * @return false if an elimination leads to a contradiction; true otherwise (or if disabled)
   */
  boolean naked_pairs(int[] grid, int s) {
    if (!runNakedPairs) {
      return true;
    }
    int val = grid[s];
    if (NUM_DIGITS[val] != 2) {
      return true; // Doesn't apply
    }
    int d = HIGHEST_DIGIT[val];
    int d2 = val - d;
    for (int s2 : PEERS[s]) {
      if (grid[s2] == val) {
        // s and s2 are a naked pair; find what unit(s) they share
        for (int[] u : UNITS[s]) {
          if (member(s2, u)) {
            for (int s3 : u) { // s3 can't have either of the values in val (e.g. 8|9)
              // Keep going through every s3: stopping after the first would skip pruning.
              if (s3 != s && s3 != s2 && !(eliminate(grid, s3, d) && eliminate(grid, s3, d2))) {
                return false;
              }
            }
          }
        }
      }
    }
    return true;
  }

  //////////////////////////////// Input ////////////////////////////////

  /**
   * Read one puzzle per file line and return a List of puzzle grids. Blank lines are skipped. With
   * -r, each puzzle is followed by its reversed gridstring.
   */
  List<int[]> readFile(String filename) throws IOException {
    try (BufferedReader in = new BufferedReader(new FileReader(filename))) {
      List<int[]> grids = new ArrayList<>(1000);
      String gridstring;
      while ((gridstring = in.readLine()) != null) {
        if (gridstring.trim().isEmpty()) {
          continue; // e.g. a trailing empty line; it would otherwise parse to an all-zero grid
        }
        grids.add(parseGrid(gridstring));
        if (reversePuzzle) {
          grids.add(parseGrid(new StringBuilder(gridstring).reverse().toString()));
        }
      }
      return grids;
    }
  }

  /**
   * Parse a gridstring into a puzzle grid: an int[] with values DIGITS[0-8] or ALL_DIGITS.
   * Characters other than '0'-'9' and '.' are skipped.
   */
  int[] parseGrid(String gridstring) {
    int[] grid = new int[N * N];
    int s = 0;
    for (int i = 0; i < gridstring.length(); ++i) {
      char c = gridstring.charAt(i);
      if ('1' <= c && c <= '9') {
        grid[s++] = DIGITS[c - '1']; // A single-bit set to represent a digit
      } else if (c == '0' || c == '.') {
        grid[s++] = ALL_DIGITS; // Any digit is possible
      }
    }
    assert s == N * N;
    return grid;
  }

  /**
   * Initialize a grid from a puzzle. First initialize every square in the new grid to ALL_DIGITS,
   * meaning any value is possible. Then call {@code fill} on the puzzle's filled squares to
   * initiate constraint propagation.
   *
   * @return the propagated grid, or null if the puzzle's givens contradict each other
   */
  int[] initialize(int[] puzzle) {
    int[] grid = new int[N * N];
    Arrays.fill(grid, ALL_DIGITS);
    for (int s : SQUARES) {
      if (puzzle[s] != ALL_DIGITS && fill(grid, s, puzzle[s]) == null) {
        return null; // No solution can exist; don't hand a half-propagated grid to search.
      }
    }
    return grid;
  }

  //////////////////////////////// Output and Tests ////////////////////////////////

  boolean headerPrinted = false;

  /**
   * Print stats on puzzles solved, average time, frequency, threads used, and name, then reset the
   * backtrack counter.
   */
  void printStats(int nGrids, long startTime, String name) {
    double usecs = (System.nanoTime() - startTime) / 1000.;
    synchronized (this) { // So that printing from different threads doesn't get garbled
      // Read and reset backtracks under the same lock so one report doesn't eat another's count.
      String line = String.format("%7d %6.1f %7.3f %7d %10.1f %s",
          nGrids, usecs / nGrids, 1000 * nGrids / usecs, nThreads, backtracks * 1. / nGrids, name);
      if (!headerPrinted) {
        System.out.println("Puzzles μsec KHz Threads Backtracks Name\n"
            + "======= ====== ======= ======= ========== ====");
        headerPrinted = true;
      }
      System.out.println(line);
      backtracks = 0;
    }
  }

  /**
   * Print the original puzzle grid and the solution grid side by side. A null solution is shown
   * as all '?' and labeled FAILED.
   */
  void printGrids(String name, int[] puzzle, int[] solution) {
    String bar = "------+-------+------";
    String gap = " "; // Space between the puzzle grid and solution grid
    if (solution == null) solution = new int[N * N];
    synchronized (this) { // So that printing from different threads doesn't get garbled
      System.out.format("\n%-22s%s%s\n", name + ":", gap,
          (verify(solution, puzzle) ? "Solution:" : "FAILED:"));
      for (int r = 0; r < N; ++r) {
        System.out.println(rowString(puzzle, r) + gap + rowString(solution, r));
        if (r == 2 || r == 5) System.out.println(bar + gap + " " + bar);
      }
    }
  }

  /**
   * Return a String representing row r of this grid: '.' for an open square, '?' for a square
   * with 0 or 2-8 candidates, otherwise the digit, with " | " between blocks.
   */
  String rowString(int[] grid, int r) {
    StringBuilder row = new StringBuilder();
    for (int s = r * N; s < (r + 1) * N; ++s) {
      int count = NUM_DIGITS[grid[s]];
      char cell = count == N ? '.'
          : count != 1 ? '?'
          : (char) ('1' + Integer.numberOfTrailingZeros(grid[s]));
      row.append(cell).append(s % N == 2 || s % N == 5 ? " | " : " ");
    }
    return row.toString();
  }

  /**
   * Unit Tests. Just getting started with these. The checks use {@code assert}, so they only run
   * when the JVM is started with -ea.
   */
  void runUnitTests() {
    assert N == 9;
    assert SQUARES.length == 81;
    for (int s : SQUARES) {
      assert UNITS[s].length == 3;
      assert PEERS[s].length == 20;
    }
    assert Arrays.equals(PEERS[19],
        new int[] {18, 20, 21, 22, 23, 24, 25, 26, 1, 10, 28, 37, 46, 55, 64, 73, 0, 2, 9, 11});
    assert Arrays.deepToString(UNITS[19]).equals(
        "[[18, 19, 20, 21, 22, 23, 24, 25, 26], [1, 10, 19, 28, 37, 46, 55, 64, 73], [0, 1, 2, 9, 10, 11, 18, 19, 20]]");
    System.out.println("Unit tests pass.");
  }
}
package org.teckhooi
import java.util
import java.util.concurrent.CountDownLatch
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
import scala.io.Source
import scala.util.Using
object SudokuPort {
  val usage: String = List(
    "usage: java Sudoku -(no)[fghnprstuv] | -[RT]<number> | <filename> ...",
    "E.g., -v turns verify flag on, -nov turns it off. -R and -T require a number. The options:\n",
    " -f(ile) Print summary stats for each file (default on)",
    " -g(rid) Print each puzzle grid and solution grid (default off)",
    " -h(elp) Print this usage message",
    " -n(aked) Run naked pairs (default off)",
    " -p(uzzle) Print summary stats for each puzzle (default off)",
    " -r(everse) Solve the reverse of each puzzle as well as each puzzle itself (default off)",
    " -s(earch) Run search (default on, but some puzzles can be solved with CSP methods alone)",
    " -t(hread) Print summary stats for each thread (default off)",
    " -u(nitTest)Run a suite of unit tests (default off)",
    " -v(erify) Verify each solution is valid (default on)",
    " -T<number> Concurrently run <number> threads (default 26)",
    " -R<number> Repeat each puzzle <number> times (default 1)",
    " <filename> Solve all puzzles in filename, which has one puzzle per line"
  ).mkString("\n")

  // Default settings used by `main`; command-line options override them per run.
  val printFileStats = true // -f
  val printGrid = false // -g
  val runNakedPairs = false // -n (off by default)
  val printPuzzleStats = false // -p
  val reversePuzzle = false // -r
  val runSearch = true // -s
  val printThreadStats = false // -t
  val verifySolution = true // -v
  val nThreads = 26 // -T
  val repeat = 1 // -R

  /** Parse command line args and solve puzzles in files. */
  def main(args: Array[String]): Unit =
    mainWith(
      new SudokuPort(
        printFileStats,
        printGrid,
        runNakedPairs,
        printPuzzleStats,
        reversePuzzle,
        runSearch,
        printThreadStats,
        verifySolution,
        nThreads,
        repeat
      ),
      args
    )

  /** Apply each argument to `s` in order. Options change its settings; any argument not starting
    * with '-' is treated as a puzzle file and solved immediately with the settings so far.
    * Throws NumberFormatException if -T or -R is not followed by an integer.
    */
  def mainWith(s: SudokuPort, args: Array[String]): Unit =
    args.foreach { arg =>
      if (!arg.startsWith("-")) {
        s.solveFile(arg)
      } else {
        val value = !arg.startsWith("-no")
        val optionIndex = if (value) 1 else 3 // position of the letter after "-" or "-no"
        if (arg.length <= optionIndex) {
          // A bare "-" or "-no" has no option letter; report it rather than throwing.
          println(s"Unrecognized option: $arg\n$usage")
        } else {
          arg.charAt(optionIndex) match {
            case 'f' => s.printFileStats = value
            case 'g' => s.printGrid = value
            case 'h' => println(usage)
            case 'n' => s.runNakedPairs = value
            case 'p' => s.printPuzzleStats = value
            case 'r' => s.reversePuzzle = value
            case 's' => s.runSearch = value
            case 't' => s.printThreadStats = value
            case 'u' => s.runUnitTests()
            case 'v' => s.verifySolution = value
            case 'T' => s.nThreads = Integer.parseInt(arg.substring(2))
            case 'R' => s.repeat = Integer.parseInt(arg.substring(2))
            case _   => println(s"Unrecognized option: $arg\n$usage")
          }
        }
      }
    }
}
class SudokuPort(
    var printFileStats: Boolean,
    var printGrid: Boolean,
    var runNakedPairs: Boolean,
    var printPuzzleStats: Boolean,
    var reversePuzzle: Boolean,
    var runSearch: Boolean,
    var printThreadStats: Boolean,
    var verifySolution: Boolean,
    var nThreads: Int,
    var repeat: Int
) {
  // Total backtracks since the last printStats call. It is incremented without synchronization by
  // every worker thread started in solveListThreaded, so multi-threaded counts are approximate.
  var backtracks = 0

  /** Sanity-check the precomputed tables using Scala `assert`, then print a success line. */
  def runUnitTests(): Unit = {
    assert(N == 9)
    assert(SQUARES.length == 81)
    for (s <- SQUARES) {
      assert(UNITS(s).length == 3)
      assert(PEERS(s).length == 20)
    }
    assert(
      PEERS(19).sameElements(
        Array[Int](18, 20, 21, 22, 23, 24, 25, 26, 1, 10, 28, 37, 46, 55, 64, 73, 0, 2, 9, 11)
      )
    )
    assert(
      UNITS(19).map(_.mkString("[", ", ", "]")).mkString("[", ", ", "]") ==
        "[[18, 19, 20, 21, 22, 23, 24, 25, 26], [1, 10, 19, 28, 37, 46, 55, 64, 73], [0, 1, 2, 9, 10, 11, 18, 19, 20]]"
    )
    println("Unit tests pass.")
  }

  //////////////////////////////// Handling Lists of Puzzles ////////////////////////////////

  /** Solve all the puzzles in a file, single-threaded when nThreads == 1 and multi-threaded
    * otherwise. Print file timing statistics if -f is on.
    */
  def solveFile(filename: String): Unit = {
    val grids: Array[Array[Int]] = readFile(filename)
    val startTime: Long = System.nanoTime()
    if (nThreads == 1) solveList(grids)
    else solveListThreaded(grids, nThreads)
    if (printFileStats) printStats(grids.length * repeat, startTime, filename)
  }

  /** Solve a list of puzzles in a single thread. Each puzzle is solved -R times. Its stats are
    * printed if -p is on. Its grids are printed on the first repetition if -g is on, or if -v is
    * on and the solution fails verification.
    */
  def solveList(grids: Array[Array[Int]]): Unit = {
    val gridpool: Array[Array[Int]] = Array.ofDim(N * N, N * N) // one scratch grid per level
    grids.zipWithIndex.foreach { case (grid, ndx) =>
      (0 until repeat).foreach { g =>
        val startTime: Long = if (printPuzzleStats) System.nanoTime() else 0L
        var solution: Array[Int] = initialize(grid) // All the real work is
        if (runSearch) solution = search(solution, gridpool, 0) // on these 2 lines.
        if (printPuzzleStats) {
          printStats(1, startTime, rowString(grid, 0).replaceAll("[ |]", ""))
        }
        if (g == 0 && (printGrid || (verifySolution && !verify(solution, grid)))) {
          printGrids("Puzzle " + (ndx + 1), grid, solution)
        }
      }
    }
  }

  /** Break a list of puzzles into nThreads sublists and solve each sublist in a separate thread.
    * The last thread also takes the remainder. Blocks until every worker has finished.
    */
  def solveListThreaded(grids: Array[Array[Int]], nThreads: Int): Unit = {
    val nGrids: Int = grids.length
    val latch: CountDownLatch = new CountDownLatch(nThreads)
    val size: Int = nGrids / nThreads
    (0 until nThreads).foreach { c =>
      val end = if (c == nThreads - 1) nGrids else (c + 1) * size
      val sublist: Array[Array[Int]] = grids.slice(c * size, end)
      new Thread(() => {
        val startTime: Long = System.nanoTime
        try {
          solveList(sublist)
        } finally {
          // Always release the latch; otherwise a failing worker would hang the caller forever.
          latch.countDown()
        }
        if (printThreadStats) {
          printStats(repeat * sublist.length, startTime, c.toString)
        }
      }).start()
    }
    latch.await() // Wait for all threads to finish
  }

  //////////////////////////////// Utility functions ////////////////////////////////

  /** Return an array of all squares in the intersection of these rows and cols, row-major. */
  def cross(rows: Array[Int], cols: Array[Int]): Array[Int] =
    for (r <- rows; c <- cols) yield N * r + c

  /** Return true iff item is an element of array. */
  def member(item: Int, array: Array[Int]): Boolean =
    member(item, array, array.length)

  /** Return true iff item is an element of array[0:end]. */
  def member(item: Int, array: Array[Int], end: Int): Boolean =
    (0 until end).exists(i => array(i) == item)

  //////////////////////////////// Constants ////////////////////////////////
  // Order matters: each table below is built from the ones defined before it.

  val N: Int = 9 // Number of cells on a side of grid.
  val DIGITS: Array[Int] =
    Array(1 << 0, 1 << 1, 1 << 2, 1 << 3, 1 << 4, 1 << 5, 1 << 6, 1 << 7, 1 << 8)
  val ALL_DIGITS: Int = Integer.parseInt("111111111", 2)
  val ROWS: Array[Int] = (0 until N).toArray
  val COLS: Array[Int] = ROWS
  val SQUARES: Array[Int] = (0 until N * N).toArray
  val BLOCKS: Array[Array[Int]] = Array(Array(0, 1, 2), Array(3, 4, 5), Array(6, 7, 8))

  // The 27 units: the 9 rows, then the 9 columns, then the 9 blocks.
  val ALL_UNITS: Array[Array[Int]] =
    ROWS.map(r => cross(Array(r), COLS)) ++
      COLS.map(c => cross(ROWS, Array(c))) ++
      BLOCKS.flatMap(rb => BLOCKS.map(cb => cross(rb, cb)))

  // UNITS(s): the 3 units containing square s, in ALL_UNITS order.
  val UNITS: Array[Array[Array[Int]]] =
    SQUARES.map(s => ALL_UNITS.filter(u => member(s, u)))

  // PEERS(s): the 20 other squares sharing a unit with s, in first-seen order.
  val PEERS: Array[Array[Int]] =
    SQUARES.map(s => UNITS(s).flatten.filter(_ != s).distinct)

  // NUM_DIGITS(v): number of 1 bits in bitset v; HIGHEST_DIGIT(v): highest bit set in v.
  val NUM_DIGITS: Array[Int] = Array.tabulate(ALL_DIGITS + 1)(v => Integer.bitCount(v))
  val HIGHEST_DIGIT: Array[Int] = Array.tabulate(ALL_DIGITS + 1)(v => Integer.highestOneBit(v))

  //////////////////////////////// Search algorithm ////////////////////////////////

  /** Search for a solution to grid. If there is an unfilled square, select one and try (that is,
    * search recursively) every possible digit for the square. gridpool(level) is overwritten.
    * Returns a fully-filled grid, or null if none is reachable.
    */
  def search(grid: Array[Int], gridpool: Array[Array[Int]], level: Int): Array[Int] =
    if (grid == null) null
    else {
      val s = select_square(grid)
      if (s == -1) grid // No squares to select means we are done!
      else {
        val none: Array[Int] = null
        DIGITS.foldLeft(none) { (found, d) =>
          if (found != null) found
          // For each possible digit d that could fill square s, try it
          else if ((d & grid(s)) > 0) {
            // Copy grid's contents into gridpool(level), and use that at the next level
            System.arraycopy(grid, 0, gridpool(level), 0, grid.length)
            val result = search(fill(gridpool(level), s, d), gridpool, level + 1)
            if (result == null) backtracks += 1
            result
          } else null
        }
      }
    }

  /** Verify that grid is a solution to the puzzle: every square holds one digit, no given of the
    * puzzle was changed, and every unit contains each digit 1-9. A null grid is never valid.
    */
  def verify(grid: Array[Int], puzzle: Array[Int]): Boolean =
    grid != null &&
      SQUARES.forall(s =>
        NUM_DIGITS(grid(s)) == 1 && (NUM_DIGITS(puzzle(s)) != 1 || grid(s) == puzzle(s))
      ) &&
      ALL_UNITS.forall(u => u.foldLeft(0)((digits, sq) => digits | grid(sq)) == ALL_DIGITS)

  /** Choose the first unfilled square (row-major) with the minimum number of possible values,
    * stopping early at a 2-candidate square as the Java version does. Return -1 if all squares
    * are filled (the puzzle is complete).
    */
  def select_square(grid: Array[Int]): Int = {
    var square = -1
    var min = N + 1
    var k = 0
    while (k < SQUARES.length && min > 2) { // Can't get fewer than 2 possible digits
      val s = SQUARES(k)
      val c = NUM_DIGITS(grid(s))
      if (c > 1 && c < min) {
        square = s
        min = c
      }
      k += 1
    }
    square
  }

  /** Fill grid(s) = d and eliminate d from every peer of s, mutating grid. Return grid, or null if
    * grid is null, d is not possible at s, or propagation hits a contradiction.
    */
  def fill(grid: Array[Int], s: Int, d: Int): Array[Int] =
    if ((grid == null) || ((grid(s) & d) == 0)) null // d not possible for grid(s)
    else {
      grid(s) = d
      if (PEERS(s).forall(p => eliminate(grid, p, d))) grid else null
    }

  /** Eliminate digit d as a possibility for grid(s) and run the 3 constraint propagation routines.
    * If constraint propagation detects a contradiction, return false.
    */
  def eliminate(grid: Array[Int], s: Int, d: Int): Boolean =
    if ((grid(s) & d) == 0) true // d already eliminated from grid(s)
    else {
      grid(s) -= d
      arc_consistent(grid, s) && dual_consistent(grid, s, d) && naked_pairs(grid, s)
    }

  //////////////////////////////// Constraint Propagation ////////////////////////////////

  /** Check if square s is consistent: that is, it has multiple possible values, or it has one
    * possible value which we can consistently fill.
    */
  def arc_consistent(grid: Array[Int], s: Int): Boolean = {
    val count = NUM_DIGITS(grid(s))
    count >= 2 || (count == 1 && (fill(grid, s, grid(s)) != null))
  }

  /** After we eliminate d from possibilities for grid(s), check each unit of s and make sure there
    * is some position in the unit where d can go. If there is only one possible place for d, fill
    * it with d.
    */
  def dual_consistent(grid: Array[Int], s: Int, d: Int): Boolean =
    UNITS(s).forall { u =>
      var dplace = -1 // A square of u where d can still go
      // Count the possible places for d within u, stopping once there are 2.
      val dPlaces = u.foldLeft(0) { (count, s2) =>
        if (count > 1) count
        else if ((grid(s2) & d) > 0) {
          dplace = s2
          count + 1
        } else count
      }
      dPlaces != 0 && (dPlaces != 1 || fill(grid, dplace, d) != null)
    }

  /** Look for two squares in a unit with the same two possible values, and no other values. For
    * example, if s and s2 both have the possible values 8|9, then 8 and 9 must go in those two
    * squares, so we eliminate 8 and 9 from every other square s3 in each unit they share.
    */
  def naked_pairs(grid: Array[Int], s: Int): Boolean =
    if (!runNakedPairs) true
    else {
      val pair = grid(s)
      if (NUM_DIGITS(pair) != 2) true // Doesn't apply
      else {
        val d = HIGHEST_DIGIT(pair)
        val d2 = pair - d
        PEERS(s).forall(s2 =>
          grid(s2) != pair ||
            UNITS(s).forall(u =>
              !member(s2, u) ||
                u.forall(s3 =>
                  s3 == s || s3 == s2 || (eliminate(grid, s3, d) && eliminate(grid, s3, d2))
                )
            )
        )
      }
    }

  //////////////////////////////// Input ////////////////////////////////

  /** Read one puzzle per file line and return the puzzle grids. Blank lines are skipped. With -r,
    * each puzzle is followed by its reversed gridstring.
    */
  def readFile(filename: String): Array[Array[Int]] =
    Using(Source.fromFile(filename)) { in =>
      val grids: ArrayBuffer[Array[Int]] = ArrayBuffer()
      for (gridString <- in.getLines() if gridString.trim.nonEmpty) {
        grids += parseGrid(gridString)
        if (reversePuzzle) grids += parseGrid(gridString.reverse)
      }
      grids.toArray
    }.get

  /** Parse a gridstring into a puzzle grid with values DIGITS(0-8) or ALL_DIGITS. Only '1'-'9',
    * '0' and '.' advance to the next square; any other character (spaces, '|') is skipped.
    */
  def parseGrid(gridString: String): Array[Int] = {
    val grid: Array[Int] = Array.ofDim(N * N)
    var s = 0
    for (c <- gridString) {
      if ('1' <= c && c <= '9') {
        grid(s) = DIGITS(c - '1') // A single-bit set to represent a digit
        s += 1
      } else if (c == '0' || c == '.') {
        grid(s) = ALL_DIGITS // Any digit is possible
        s += 1
      }
    }
    grid
  }

  /** Initialize a grid from a puzzle: start with ALL_DIGITS everywhere, then `fill` the puzzle's
    * given squares to initiate constraint propagation. Returns null if the givens contradict.
    */
  def initialize(puzzle: Array[Int]): Array[Int] = {
    val grid: Array[Int] = Array.fill(N * N)(ALL_DIGITS)
    val consistent =
      SQUARES.forall(s => puzzle(s) == ALL_DIGITS || fill(grid, s, puzzle(s)) != null)
    if (consistent) grid else null
  }

  //////////////////////////////// Output and Tests ////////////////////////////////

  var headerPrinted = false

  /** Print stats on puzzles solved, average time, frequency, threads used, and name, then reset
    * the backtrack counter.
    */
  def printStats(nGrids: Int, startTime: Long, name: String): Unit = {
    val usecs: Double = (System.nanoTime() - startTime) / 1000.0
    synchronized {
      // Read and reset backtracks under the same lock so one report doesn't eat another's count.
      val line: String = "%7d %6.1f %7.3f %7d %10.1f %s".format(
        nGrids,
        usecs / nGrids,
        1000 * nGrids / usecs,
        nThreads,
        backtracks * 1d / nGrids,
        name
      )
      if (!headerPrinted) {
        println(
          "Puzzles μsec KHz Threads Backtracks Name\n"
            + "======= ====== ======= ======= ========== ===="
        )
        headerPrinted = true
      }
      println(line)
      backtracks = 0
    }
  }

  /** Print the original puzzle grid and the solution grid side by side. A null solution is shown
    * as all '?' and labeled FAILED.
    */
  def printGrids(name: String, puzzle: Array[Int], solution: Array[Int]): Unit = {
    val bar = "------+-------+------"
    val gap = " " // Space between the puzzle grid and solution grid
    val shown: Array[Int] = if (solution == null) Array.ofDim(N * N) else solution
    synchronized {
      System.out.format(
        "\n%-22s%s%s\n",
        name + ":",
        gap,
        if (verify(shown, puzzle)) "Solution:" else "FAILED:"
      )
      (0 until N).foreach { r =>
        println(rowString(puzzle, r) + gap + rowString(shown, r))
        if (r == 2 || r == 5) println(bar + gap + " " + bar)
      }
    }
  }

  /** Return a String representing row r of this grid: '.' for an open square, '?' for a square
    * with 0 or 2-8 candidates, otherwise the digit, with " | " between blocks.
    */
  def rowString(grid: Array[Int], r: Int): String = {
    val row = new StringBuilder
    for (s <- r * N until (r + 1) * N) {
      val count = NUM_DIGITS(grid(s))
      val cell =
        if (count == N) '.'
        else if (count != 1) '?'
        else ('1' + Integer.numberOfTrailingZeros(grid(s))).toChar
      row.append(cell).append(if (s % N == 2 || s % N == 5) " | " else " ")
    }
    row.toString
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment