Skip to content

Instantly share code, notes, and snippets.

@abhijat
Created April 18, 2020 15:34
Show Gist options
  • Save abhijat/4ae62061eb761c1e1adb4b27c95efdc4 to your computer and use it in GitHub Desktop.
Save abhijat/4ae62061eb761c1e1adb4b27c95efdc4 to your computer and use it in GitHub Desktop.
from functools import reduce
from operator import mul
class Solution(object):
    """A square grid that is filled one cell at a time in row-major order.

    Empty cells hold ``None``. ``add`` places a value in the next empty cell
    and ``pop`` clears the most recently filled one. This lets the grid serve
    as the mutable state of a backtracking search.
    """

    def __init__(self, size=3):
        """Create an empty ``size`` x ``size`` grid (3 x 3 by default)."""
        # Build every row separately so the rows are not aliases of one list.
        self.solution = [[None] * size for _ in range(size)]
        # Cursor of the next cell to fill. Right after a row is filled,
        # ``_column`` equals the row width; ``add`` wraps it to the next row
        # lazily.
        self._row = self._column = 0

    @staticmethod
    def _product(cells):
        """Return (product of the non-None cells, whether any cell is None)."""
        return reduce(mul, [k for k in cells if k is not None], 1), None in cells

    def row_product(self, i):
        """Return (product of filled cells in row ``i``, row still incomplete)."""
        return self._product(self.solution[i])

    def column_product(self, i):
        """Return (product of filled cells in column ``i``, column still incomplete)."""
        return self._product([row[i] for row in self.solution])

    def complete(self):
        """Return True when no cell of the grid is ``None``."""
        return all(None not in row for row in self.solution)

    def add(self, n):
        """Place ``n`` in the next empty cell (row-major order).

        Raises:
            IndexError: if the grid is already full. The cursor is left
                untouched in that case.
        """
        width = len(self.solution[0])
        row, column = self._row, self._column
        if column == width:
            row, column = row + 1, 0
        if row >= len(self.solution):
            raise IndexError('cannot add to a full grid')
        self.solution[row][column] = n
        self._row, self._column = row, column + 1

    def pop(self):
        """Clear the most recently filled cell.

        Raises:
            IndexError: if the grid is empty. Without this guard the cursor
                would wrap to row -1 and clear a cell of the last row.
        """
        if self._row == 0 and self._column == 0:
            raise IndexError('pop from an empty grid')
        if self._column == 0:
            # Step back to the last cell of the previous row.
            self._row -= 1
            self._column = len(self.solution[0]) - 1
        else:
            self._column -= 1
        self.solution[self._row][self._column] = None
class Constraints(object):
    """Required row and column products for a grid-filling puzzle.

    ``rows[i]`` is the product that row ``i`` of a finished grid must equal,
    and ``columns[j]`` is the product that column ``j`` must equal.
    """

    def __init__(self, rows, columns):
        self.columns = columns
        self.rows = rows
        # Row index -> the digits 1-9 that divide that row's target. Distinct
        # loop names keep the set comprehension from shadowing the dict key.
        self.factors = {
            row_index: {digit for digit in range(1, 10) if target % digit == 0}
            for row_index, target in enumerate(self.rows)
        }

    @staticmethod
    def _line_ok(product, incomplete, target):
        """Return whether a row or column can still meet ``target``.

        ``product`` is the product of the line's filled cells, and
        ``incomplete`` says whether any of its cells is still empty. A finished
        line must hit the target exactly. An unfinished line must not exceed
        the target. When its product is non-zero, that product must also
        divide the target, because the remaining cells only multiply the
        current product further.
        """
        if not incomplete:
            return product == target
        if product > target:
            return False
        # product == 0 would make the modulo raise; keep the old "not above
        # target" acceptance for it.
        return product == 0 or target % product == 0

    def can_continue(self, solution: 'Solution'):
        """Return False if some row or column can no longer reach its target.

        This is a necessary, not sufficient, condition for completing the
        grid. Index ``i`` is used for both row ``i`` and column ``i``, so a
        square grid is assumed.
        """
        for i in range(len(solution.solution)):
            if not self._line_ok(*solution.row_product(i), self.rows[i]):
                return False
            if not self._line_ok(*solution.column_product(i), self.columns[i]):
                return False
        return True

    def is_valid(self, solution: 'Solution'):
        """Return True if ``solution`` is fully filled and every row and
        column product equals its target."""
        if not solution.complete():
            return False
        rows_ok = all(
            solution.row_product(i)[0] == target
            for i, target in enumerate(self.rows)
        )
        return rows_ok and all(
            solution.column_product(j)[0] == target
            for j, target in enumerate(self.columns)
        )
def backtracking_search(constraints, solution, solution_space, index=0):
    """Depth-first search for a grid filling that satisfies ``constraints``.

    Cells of ``solution`` are filled in row-major order. ``index`` is the
    number of cells already filled; it must agree with the grid's own fill
    cursor, which holds when starting from an empty grid with ``index=0``.
    ``solution_space`` is the set of values still unused, so each value is
    placed at most once.

    Returns the mutated ``solution`` object once it satisfies the constraints.
    Returns None if no completion exists; in that case every cell this call
    filled has been cleared again.
    """
    width = len(solution.solution[0])
    cell_count = len(solution.solution) * width
    if constraints.is_valid(solution):
        return solution
    # Give up when the grid is full (and therefore invalid) or when a row or
    # column can already no longer reach its target.
    if index >= cell_count or not constraints.can_continue(solution):
        return None
    row, column = divmod(index, width)
    # Candidates must be unused and drawn from the row's factor set. Sorting
    # makes the exploration order explicit instead of depending on set
    # iteration order.
    for candidate in sorted(solution_space & constraints.factors[row]):
        # A cell value must also divide its column's target.
        if constraints.columns[column] % candidate:
            continue
        solution.add(candidate)
        result = backtracking_search(
            constraints, solution, solution_space - {candidate}, index + 1)
        # Compare with None explicitly rather than relying on truthiness.
        if result is not None:
            return result
        solution.pop()
    return None
if __name__ == '__main__':
    # Example puzzle: place the distinct digits 1-9 in a 3 x 3 grid so that
    # every row and every column multiplies to the given target.
    result = backtracking_search(
        Constraints(rows=[54, 120, 56], columns=[96, 180, 21]),
        Solution(),
        set(range(1, 10)),
    )
    if result is not None:
        print(result.solution)
    else:
        # Report failure instead of exiting silently.
        print('No solution found')
@pranjalchaubey
Copy link

So cool!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment