Skip to content

Instantly share code, notes, and snippets.

@DRMacIver
Created November 21, 2022 15:31
Show Gist options
  • Save DRMacIver/32ac88797c42d1a485ee56ff1caff1cd to your computer and use it in GitHub Desktop.
Save DRMacIver/32ac88797c42d1a485ee56ff1caff1cd to your computer and use it in GitHub Desktop.
import attr
def solve(initial_assignments, restriction_function):
    """
    Given variables 0...n-1, such that variable i is allowed to
    take on values in initial_assignments[i], find an assignment
    of those variables which satisfies the restrictions enforced
    by `restriction_function`. Return this as a list of assigned
    values of length n, where n = len(initial_assignments), so that
    result[i] is the value chosen for variable i.

    If no such assignment is possible, raises Unsatisfiable.

    The way this works is that restriction_function is passed a
    sequence of partial assignments ((i, v), ...) indicating some
    subset of the variables that are assigned to some value. It then
    returns a list (or other iterable, e.g. a generator) of
    restrictions of the form (j, restriction). restriction is either
    a collection of values that j is allowed to be assigned to, or a
    Blacklist instance containing a collection of values that j may
    *not* be assigned to. The same index is allowed to appear in the
    output of restriction_function multiple times, and all the
    provided restrictions will apply.

    In addition, restriction_function may raise InvalidAssignment
    to indicate that this assignment is not allowed. It is not
    required to raise this if the assignment contains some violations
    of restrictions it returns - this will be handled automatically.

    restriction_function should have the property that it is
    *consistent* in the sense that enlarging the assignment should
    only produce stronger restrictions. In particular, if it claims
    that some subset of an assignment is invalid, the solver may never
    find that assignment even if the full assignment would have been
    accepted. As long as this consistency property is achieved, this
    function is guaranteed to return a result if one exists.
    """
    solver = Solver(initial_assignments, restriction_function)
    solver.solve()
    # Compute the final assignment once and reuse it for both the sanity
    # checks and the return value.
    assignment = solver.assignment()
    # Internal invariant: a successful solve fixes every variable to exactly
    # one value, and assignment() lists them in index order.
    assert len(assignment) == len(initial_assignments)
    for i, (j, _) in enumerate(assignment):
        assert i == j
    return [v for _, v in assignment]
class InvalidAssignment(Exception):
    """Signals that an assignment breaks some constraint, whether
    directly or as a consequence of the restrictions it implies."""
class Unsatisfiable(Exception):
    """Top-level exception raised when an assignment problem has no
    solution."""
@attr.s
class Blacklist:
    """An "anything except these" restriction: a variable restricted by
    a Blacklist may take none of the values in `values`."""

    # The forbidden values for the restricted variable.
    values = attr.ib()
class Solver:
    """Backtracking search over finite per-variable domains.

    ``current_values[i]`` is the (interned) frozenset of values variable
    ``i`` may still take. Domains are narrowed through ``restrict``;
    ``begin``/``rollback`` provide nested undo frames so that tentative
    narrowing can be reverted. An implication table remembers, for each
    ``(i, v)`` tried by ``refine_implications``, the domains other variables
    were narrowed to, so they can be applied as soon as ``i`` becomes fixed
    to ``v``.
    """

    def __init__(self, initial_values, restriction_function):
        # Interning table so that equal domains share a single frozenset.
        self.__canon_map = {}
        self.initial_values = tuple(map(self.__canonicalise, initial_values))
        self.current_values = list(self.initial_values)
        self.restriction_function = restriction_function
        # Stack of undo frames; each maps a variable index to the domain it
        # had before its first write inside that frame.
        self.__rollbacks = []
        # __implications[i][v] is a tuple of (j, domain) pairs recorded by
        # refine_implications for the trial assignment i = v.
        self.__implications = [{} for _ in self.current_values]

    def begin(self):
        """Begin a new commit, so that the next rollback() call returns
        self.current_values to its present state."""
        self.__rollbacks.append({})

    def rollback(self):
        """Restore self.current_values to the state it was in when begin
        was last called."""
        values = self.__rollbacks.pop()
        for i, v in values.items():
            self.current_values[i] = v

    def __setitem__(self, i, v):
        """self.current_values[i] = v, but recorded in the innermost undo
        frame (if any) so that rollback() can restore the old domain."""
        v = self.__canonicalise(v)
        if self.__rollbacks:
            record = self.__rollbacks[-1]
            # Only the first write within a frame records what to restore.
            record.setdefault(i, self.current_values[i])
        self.current_values[i] = v

    def __getitem__(self, i):
        return self.current_values[i]

    def __len__(self):
        return len(self.current_values)

    def implications(self, i, v):
        """Returns a tuple of (j, domain) restrictions that must hold true
        if variable i were to be set to value v (empty if none are known)."""
        return self.__implications[i].get(v, ())

    def restrict(self, i, values):
        """Restrict variable `i` to belong to `values`. If `values` is a
        Blacklist then `i` will be restricted to not belonging to it.

        This will propagate the restriction to any known implications of
        this assignment: whenever a variable is narrowed down to a single
        value, the recorded implications for that value are applied too.

        Raises InvalidAssignment if any variable is left with no values.
        """
        stack = [(i, values)]
        while stack:
            i, values = stack.pop()
            old = self[i]
            assert len(old) > 0
            if isinstance(values, Blacklist):
                # difference() accepts any iterable, so a Blacklist may hold
                # a list or tuple as well as a set.
                values = old.difference(values.values)
            values = self.__canonicalise(values)
            self[i] &= values
            new = self[i]
            if len(new) == 0:
                raise InvalidAssignment()
            if len(new) == 1 and len(old) > 1:
                # i has just become fixed: apply what is known to follow.
                stack.extend(self.implications(i, *new))

    def propagate(self):
        """Adjusts the set of allowable values for variables to take
        based on the current full assignment.

        Repeatedly passes the current assignment to restriction_function
        and applies the returned restrictions until the assignment stops
        changing. May raise InvalidAssignment.
        """
        prev = None
        while True:
            current = self.assignment()
            if current == prev:
                break
            prev = current
            for i, restriction in self.restriction_function(current):
                self.restrict(i, restriction)

    def refine_implications(self):
        """Updates the implication graph to take into account all
        known information about allowable values. NB can only be
        called at the top level with no branching.

        Every remaining (i, v) is tried inside an undo frame. If the trial
        fails, v is permanently removed from i's domain and the top-level
        state is re-propagated; otherwise the domains of the other variables
        touched by the trial are stored as the implications of (i, v).
        Repeats until a full pass changes nothing.

        Raises Unsatisfiable if the top-level state becomes invalid.
        """
        assert not self.__rollbacks
        changed = True
        while changed:
            changed = False
            for i in range(len(self)):
                # Iterate over a snapshot: the domain may shrink below.
                for v in self.current_values[i]:
                    if v not in self.current_values[i]:
                        # Already ruled out earlier in this pass.
                        continue
                    prev = self.implications(i, v)
                    self.begin()
                    result = None
                    try:
                        self.restrict(i, {v})
                        self.propagate()
                        result = tuple(
                            [(j, self.current_values[j]) for j in self.__dirty() if j != i]
                        )
                    except InvalidAssignment:
                        pass
                    self.rollback()
                    if result is None:
                        try:
                            self.restrict(i, Blacklist({v}))
                            # Keep the top-level state at a propagation
                            # fixpoint, as solve() established beforehand.
                            self.propagate()
                        except InvalidAssignment:
                            raise Unsatisfiable()
                        # Domains narrowed, so earlier implications may now
                        # be refinable: run another pass.
                        changed = True
                    elif result != prev:
                        changed = True
                        self.__implications[i][v] = result

    def solve(self):
        """Does back tracking until it either raises Unsatisfiable or
        self.assignment is full.

        Candidate (i, v) choices are tried in order of: fewest remaining
        values for i, then most values removed by the recorded implications
        of (i, v), then i, then v. A failed choice is undone and replaced by
        its negation (v blacklisted for i); if the negation also fails, the
        previous choice is undone in the same way.
        """
        if not all(self.current_values):
            # A variable with no permitted values can never be assigned.
            raise Unsatisfiable()
        try:
            self.propagate()
        except InvalidAssignment:
            raise Unsatisfiable()
        self.refine_implications()

        def score(iv):
            i, v = iv
            return (
                len(self[i]),
                -sum(len(self[j]) - len(u) for j, u in self.implications(i, v)),
                i,
                v,
            )

        assignments = [(i, v) for i in self.variables() for v in self[i]]
        assignments.sort(key=score)
        # stack holds the indices into `assignments` of the choices that
        # currently have an open undo frame, innermost last.
        stack = []
        start = 0
        while start < len(assignments):
            assert len(self.__rollbacks) == len(stack)
            i, v = assignments[start]
            if len(self[i]) == 1 or v not in self[i]:
                start += 1
            else:
                self.begin()
                stack.append(start)
                try:
                    # First try to perform the assignment.
                    self.restrict(i, {v})
                    self.propagate()
                    start += 1
                except InvalidAssignment:
                    # Now we're in rollback mode.
                    while stack:
                        assert len(self.__rollbacks) == len(stack)
                        self.rollback()
                        start = stack.pop()
                        i, v = assignments[start]
                        try:
                            self.restrict(i, Blacklist({v}))
                            self.propagate()
                            start += 1
                            break
                        except InvalidAssignment:
                            pass
                    else:
                        # We've now unwound the entire stack and
                        # neither assignments[0] nor its inverse
                        # are allowed. Therefore this is unsatisfiable!
                        raise Unsatisfiable()

    # Mostly internal functions

    def variables(self):
        return range(len(self))

    def assignment(self):
        """Tuple of (i, v) for every variable whose domain is exactly {v},
        in index order."""
        return tuple(
            [(i, *self.current_values[i]) for i in self.variables() if len(self.current_values[i]) == 1]
        )

    def __dirty(self):
        """Sorted indices written since the last begin(); () at top level."""
        if not self.__rollbacks:
            return ()
        return tuple(sorted(self.__rollbacks[-1].keys()))

    def __canonicalise(self, s):
        """Return an interned frozenset equal to `s`."""
        s = frozenset(s)
        return self.__canon_map.setdefault(s, s)
from blackboxsolver import solve, Blacklist, InvalidAssignment, Unsatisfiable
from hypothesis import given, strategies as st, note
import pytest
def test_no_restrictions_only_single_values():
    """Singleton domains with no restrictions come back unchanged."""
    domains = [{1}, {0}]
    result = solve(domains, lambda *_: [])
    assert result == [1, 0]
def test_restrictions_all_in_function():
    """Every variable is pinned to 3 purely by the restriction function."""
    n = 10
    values = set(range(6))

    def pin_to_three(*_):
        return [(index, {3}) for index in range(n)]

    assert solve([values] * n, pin_to_three) == [3] * n
def test_can_find_unique_assignment():
    """Restrictions force all variables to take distinct values, so the
    result must be a permutation of range(n)."""
    n = 5
    values = frozenset(range(n))

    def restriction(assignment):
        taken = frozenset({value for _, value in assignment})
        assigned = {index for index, _ in assignment}
        if len(taken) < len(assigned):
            raise InvalidAssignment()
        unassigned = [index for index in range(n) if index not in assigned]
        return [(index, Blacklist(taken)) for index in unassigned]

    result = solve([values for _ in range(n)], restriction)
    assert sorted(result) == list(range(n))
def test_can_find_increasing_sequence():
    """Each assigned variable forces its successor to be strictly larger,
    so the only solution over range(n) is 0, 1, ..., n - 1."""
    n = 10
    values = frozenset(range(n))

    def restriction(assignment):
        for index, value in assignment:
            successor = index + 1
            if successor >= n:
                continue
            yield [successor, range(value + 1, n)]

    assert solve([values for _ in range(n)], restriction) == list(range(n))
def test_can_find_increasing_sequence_with_backtracking():
    """Like the increasing-sequence test over range(n + 1), but once every
    variable is assigned, variable 0 is forbidden from being 0."""
    n = 10
    values = frozenset(range(n + 1))

    def restriction(assignment):
        for index, value in assignment:
            if index + 1 >= n:
                continue
            yield [index + 1, range(value + 1, n + 1)]
        if len(assignment) == n:
            yield (0, Blacklist({0}))

    expected = list(range(1, n + 1))
    assert solve([values for _ in range(n)], restriction) == expected
def test_triggers_backtracking():
    """A complete assignment must use every value at least once. Choosing
    the same value everywhere is rejected only once all variables are
    assigned, so the solver has to undo earlier choices."""
    n = 5
    values = {0, 1, 2}

    def restriction(assignment):
        # Reject complete assignments that do not use every value.
        if len(assignment) == n and len({v for _, v in assignment}) < len(values):
            raise InvalidAssignment()
        return ()

    result = solve([values] * n, restriction)
    assert len(result) == n
    assert set(result) == values
@st.composite
def simple_assignment_problem(draw):
    """Draw (initial_values, restriction). The restriction function applies
    a randomly generated table mapping (variable, value) to a list of
    (index, allowed values) restrictions."""
    n = draw(st.integers(1, 10))
    values = draw(st.lists(st.integers(0, 10), unique=True, min_size=1))
    value_strategy = st.sampled_from(values)
    initial_values = [draw(st.frozensets(value_strategy, min_size=1)) for _ in range(n)]
    table_strategy = st.dictionaries(
        value_strategy,
        st.lists(st.tuples(st.integers(0, n - 1), st.frozensets(value_strategy))),
    )
    implications = [draw(table_strategy) for _ in initial_values]

    def restriction(assignment):
        for i, v in assignment:
            yield from implications[i].get(v, ())

    # Give the function a readable repr so Hypothesis failure output shows
    # the generated table.
    label = 'implied_by(%r)' % (implications,)
    restriction.implications = implications
    restriction.__name__ = label
    restriction.__qualname__ = label
    return (initial_values, restriction)
def test_raises_unsatisfiable_when_exhausted():
    """If every complete assignment is rejected, solve must raise
    Unsatisfiable."""
    n = 3
    values = {0, 1, 2}

    def reject(assignment):
        if len(assignment) != n:
            return ()
        raise InvalidAssignment()

    with pytest.raises(Unsatisfiable):
        solve([values] * n, reject)
def test_can_exhaustively_enumerate():
    """Exactly one complete assignment is accepted; solve must find it."""
    n = 3
    values = {0, 1, 2}
    target = [(0, 2), (1, 1), (2, 0)]

    def reject_most(assignment):
        # Leftover debug print removed: it only spammed captured output.
        if len(assignment) == n and sorted(assignment) != target:
            raise InvalidAssignment()
        return ()

    assert solve([values] * n, reject_most) == [2, 1, 0]
@given(simple_assignment_problem())
def test_can_solve_simple_assignment_problem(prob):
    """Any solution returned must respect the initial domains and every
    restriction the problem imposes on the full assignment."""
    values, f = prob
    try:
        result = solve(values, f)
    except Unsatisfiable:
        return
    assert len(result) == len(values)
    for chosen, allowed in zip(result, values):
        assert chosen in allowed
    # Re-check the returned assignment against the restriction function.
    full_assignment = tuple(enumerate(result))
    for j, allowed in f(full_assignment):
        assert result[j] in allowed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment