Created
November 21, 2022 15:31
-
-
Save DRMacIver/32ac88797c42d1a485ee56ff1caff1cd to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import attr | |
def solve(initial_assignments, restriction_function):
    """
    Given variables 0...n-1, such that variable i is allowed to
    take on values in initial_assignments[i], find an assignment
    of those variables which satisfies the restrictions enforced
    by `restriction_function`. Return this as a list of assigned
    values of length n.

    If no such assignment is possible, raises Unsatisfiable.

    The way this works is that restriction_function is passed a
    tuple of partial assignments ((i, v), ...), ordered by i, listing
    the variables that are currently fixed to a single value. It then
    returns an iterable of restrictions of the form (j, restriction).
    restriction is either a collection of values that j is allowed to
    be assigned to, or a Blacklist instance containing a collection of
    values that j may *not* be assigned to. The same index is allowed
    to appear in the output of restriction_function multiple times,
    and all the provided restrictions will apply.

    In addition, restriction_function may raise InvalidAssignment
    to indicate that this assignment is not allowed. It is not
    required to raise this if the assignment contains some violations
    of restrictions it returns - this will be handled automatically.

    restriction_function should have the property that it is
    *consistent* in the sense that enlarging the assignment should
    only produce stronger restrictions. In particular, if some partial
    assignment is claimed to be invalid, no full assignment extending
    it will be found. As long as this consistency property holds,
    this function is guaranteed to return a result if one exists.
    """
    solver = Solver(initial_assignments, restriction_function)
    solver.solve()
    assignment = solver.assignment()
    # Internal invariants: after Solver.solve() every variable is fixed,
    # and assignment() lists the fixed variables in index order.
    assert len(assignment) == len(initial_assignments)
    for i, (j, _) in enumerate(assignment):
        assert i == j
    # Reuse the assignment already validated above instead of recomputing it.
    return [v for _, v in assignment]
class InvalidAssignment(Exception):
    """Signals that a (partial) assignment breaks a constraint, either
    directly or as a consequence of other restrictions."""
class Unsatisfiable(Exception):
    """Top-level exception raised when an assignment problem has
    no solution."""
@attr.s
class Blacklist:
    """Restriction saying a variable may take any value *except*
    the ones contained in `values`."""

    # Collection of forbidden values.
    values = attr.ib()
class Solver:
    """Backtracking search state used by `solve`.

    `current_values[i]` holds the frozenset of values variable i may
    still take. Writes made through `self[i] = ...` are recorded in the
    innermost `begin()` frame so that `rollback()` can undo them.
    """

    def __init__(self, initial_values, restriction_function):
        # Interning table: equal domains share a single frozenset object.
        # Must exist before the map() below calls __canonicalise.
        self.__canon_map = {}
        self.initial_values = tuple(map(self.__canonicalise, initial_values))
        self.current_values = list(self.initial_values)
        self.restriction_function = restriction_function
        # One {index: value_before_first_write} dict per open begin().
        self.__rollbacks = []
        # __implications[i][v]: tuple of (j, values) restrictions found by
        # refine_implications() to follow from setting variable i to v.
        self.__implications = [{} for _ in self.current_values]

    def begin(self):
        """Begin a new commit, so that the next rollback() call returns
        self.current_values to its present state."""
        self.__rollbacks.append({})

    def rollback(self):
        """Restore self.current_values to the state it was in when begin
        was last called."""
        values = self.__rollbacks.pop()
        for i, v in values.items():
            self.current_values[i] = v

    def __setitem__(self, i, v):
        """self.current_values[i] = v, but to be rolled back if appropriate.

        Only the first write to i within a frame is recorded, so rollback
        restores the value i had when that frame was opened."""
        v = self.__canonicalise(v)
        if self.__rollbacks:
            record = self.__rollbacks[-1]
            record.setdefault(i, self.current_values[i])
        self.current_values[i] = v

    def __getitem__(self, i):
        return self.current_values[i]

    def __len__(self):
        return len(self.current_values)

    def implications(self, i, v):
        """Return the tuple of (j, values) restrictions known to hold if
        variable i were set to value v (empty if nothing is known)."""
        return self.__implications[i].get(v, ())

    def restrict(self, i, values):
        """Restrict variable `i` to belong to `values`. If `values` is a
        Blacklist then `i` will be restricted to not belonging to it.

        This will propagate the restriction to any known implications of
        this assignment.

        Raises InvalidAssignment if some variable is left with no
        allowed values.
        """
        stack = [(i, values)]
        while stack:
            i, values = stack.pop()
            old = self[i]
            assert len(old) > 0
            if isinstance(values, Blacklist):
                # difference() accepts any iterable. The `-` operator would
                # raise TypeError unless Blacklist.values is itself a set.
                values = old.difference(values.values)
            values = self.__canonicalise(values)
            self[i] &= values
            new = self[i]
            if not new:
                raise InvalidAssignment()
            if len(new) == 1 and len(old) > 1:
                # i has just become fixed: apply everything known to follow.
                stack.extend(self.implications(i, *new))

    def propagate(self):
        """Adjusts the set of allowable values for variables to take
        based on the current full assignment.

        Re-runs restriction_function until the set of fixed variables
        stops changing. Raises InvalidAssignment on a contradiction."""
        prev = None
        while True:
            current = self.assignment()
            if current == prev:
                break
            prev = current
            for i, restriction in self.restriction_function(current):
                self.restrict(i, restriction)

    def refine_implications(self):
        """Updates the implication graph to take into account all
        known information about allowable values. NB can only be
        called at the top level with no branching.

        Every candidate value that leads to a contradiction is removed
        permanently. Raises Unsatisfiable if a variable has no values left.
        """
        assert not self.__rollbacks
        changed = True
        while changed:
            changed = False
            for i in range(len(self)):
                # Iterates the frozenset as it was at loop start. Top-level
                # blacklisting replaces current_values[i] rather than
                # mutating it, so iteration stays safe.
                for v in self.current_values[i]:
                    prev = self.implications(i, v)
                    self.begin()
                    result = None
                    try:
                        self.restrict(i, {v})
                        self.propagate()
                        # Record the narrowed domain of every other variable
                        # touched while trying i = v.
                        result = tuple(
                            [
                                (j, self.current_values[j])
                                for j in self.__dirty()
                                if j != i
                            ]
                        )
                    except InvalidAssignment:
                        pass
                    self.rollback()
                    if result is None:
                        # i = v is impossible at the top level, so rule it out.
                        try:
                            self.restrict(i, Blacklist({v}))
                        except InvalidAssignment:
                            raise Unsatisfiable()
                    else:
                        if result != prev:
                            changed = True
                        self.__implications[i][v] = result

    def solve(self):
        """Does back tracking until it either raises Unsatisfiable or
        self.assignment is full."""
        try:
            self.propagate()
        except InvalidAssignment:
            raise Unsatisfiable()
        self.refine_implications()

        def score(iv):
            # Try variables with the fewest remaining values first. Among
            # those, prefer values whose known implications prune the most.
            i, v = iv
            return (
                len(self[i]),
                -sum(len(self[j]) - len(u) for j, u in self.implications(i, v)),
                i,
                v,
            )

        assignments = [(i, v) for i in self.variables() for v in self[i]]
        assignments.sort(key=score)
        # stack[k] is the index into `assignments` of the k-th open choice.
        # Each open choice has a matching begin() frame.
        stack = []
        start = 0
        while start < len(assignments):
            assert len(self.__rollbacks) == len(stack)
            i, v = assignments[start]
            if len(self[i]) == 1 or v not in self[i]:
                # Already decided, or no longer possible: nothing to try.
                start += 1
            else:
                self.begin()
                stack.append(start)
                try:
                    # First try to perform the assignment.
                    self.restrict(i, {v})
                    self.propagate()
                    start += 1
                except InvalidAssignment:
                    # Now we're in rollback mode: undo the most recent
                    # choice and try its negation, unwinding further on
                    # failure.
                    while stack:
                        assert len(self.__rollbacks) == len(stack)
                        self.rollback()
                        start = stack.pop()
                        i, v = assignments[start]
                        try:
                            self.restrict(i, Blacklist({v}))
                            self.propagate()
                            start += 1
                            break
                        except InvalidAssignment:
                            pass
                    else:
                        # We've now unwound the entire stack and
                        # neither assignments[0] nor its inverse
                        # are allowed. Therefore this is unsatisfiable!
                        raise Unsatisfiable()

    # Mostly internal functions

    def variables(self):
        return range(len(self))

    def assignment(self):
        """Return ((i, v), ...) for every variable fixed to one value,
        ordered by i."""
        return tuple(
            [
                (i, *self.current_values[i])
                for i in self.variables()
                if len(self.current_values[i]) == 1
            ]
        )

    def __dirty(self):
        """Sorted indices written to in the innermost begin() frame, or ()
        when no frame is open."""
        if not self.__rollbacks:
            return ()
        return tuple(sorted(self.__rollbacks[-1].keys()))

    def __canonicalise(self, s):
        """Convert s to a frozenset and return the shared interned copy."""
        s = frozenset(s)
        return self.__canon_map.setdefault(s, s)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from blackboxsolver import solve, Blacklist, InvalidAssignment, Unsatisfiable | |
from hypothesis import given, strategies as st, note | |
import pytest | |
def test_no_restrictions_only_single_values():
    # Each domain already has exactly one value and nothing is restricted.
    domains = [{1}, {0}]

    def no_restrictions(*_):
        return []

    assert solve(domains, no_restrictions) == [1, 0]
def test_restrictions_all_in_function():
    n = 10
    values = {0, 1, 2, 3, 4, 5}

    def everything_must_be_three(*_):
        # Restrictions come purely from the function, not the domains.
        return [(i, {3}) for i in range(n)]

    domains = [values for _ in range(n)]
    assert solve(domains, everything_must_be_three) == [3] * n
def test_can_find_unique_assignment():
    n = 5
    values = frozenset(range(n))

    def restriction(assignment):
        assigned = set()
        used = set()
        for i, v in assignment:
            assigned.add(i)
            used.add(v)
        # Two variables sharing a value means fewer distinct values.
        if len(used) < len(assigned):
            raise InvalidAssignment()
        taken = frozenset(used)
        return [(i, Blacklist(taken)) for i in range(n) if i not in assigned]

    result = solve([values for _ in range(n)], restriction)
    assert sorted(result) == list(range(n))
def test_can_find_increasing_sequence():
    n = 10
    values = frozenset(range(n))

    def restriction(assignment):
        # Each fixed variable forces its successor to be strictly larger.
        return [
            (i + 1, range(v + 1, n))
            for i, v in assignment
            if i + 1 < n
        ]

    domains = [values for _ in range(n)]
    assert solve(domains, restriction) == list(range(n))
def test_can_find_increasing_sequence_with_backtracking():
    n = 10
    values = frozenset(range(n + 1))

    def restriction(assignment):
        for i, v in assignment:
            if i + 1 >= n:
                continue
            yield (i + 1, range(v + 1, n + 1))
        if len(assignment) == n:
            # Once every variable is fixed, forbid x0 == 0, which rules
            # out the sequence 0..n-1.
            yield (0, Blacklist({0}))

    expected = list(range(1, n + 1))
    assert solve([values for _ in range(n)], restriction) == expected
def test_triggers_backtracking():
    """Five variables over {0, 1, 2}. A complete assignment is rejected
    unless it uses every value, so the solver has to back out of its
    early all-equal choices."""
    # Was `range(5)`, which made `len(assignment) == n` always False.
    n = 5
    values = {0, 1, 2}

    def restriction(assignment):
        if len(assignment) == n and len({v for _, v in assignment}) < len(values):
            # Solver contract: signal a bad assignment with
            # InvalidAssignment, not Unsatisfiable.
            raise InvalidAssignment()
        # The solver iterates the return value, so it must not be None.
        return ()

    result = solve([values] * n, restriction)
    assert len(result) == n
    assert set(result) == values
@st.composite
def simple_assignment_problem(draw):
    """Draw (initial_domains, restriction) where the restriction applies a
    randomly generated table of per-(variable, value) implications."""
    n = draw(st.integers(1, 10))
    values = draw(st.lists(st.integers(0, 10), unique=True, min_size=1))
    some_value = st.sampled_from(values)

    initial_values = [
        draw(st.frozensets(some_value, min_size=1)) for _ in range(n)
    ]

    implied_restrictions = st.lists(
        st.tuples(st.integers(0, n - 1), st.frozensets(some_value))
    )
    implications = [
        draw(st.dictionaries(some_value, implied_restrictions))
        for _ in range(n)
    ]

    def restriction(assignment):
        for i, v in assignment:
            yield from implications[i].get(v, ())

    # Give the function a readable name so hypothesis reports are useful.
    label = "implied_by({!r})".format(implications)
    restriction.implications = implications
    restriction.__name__ = label
    restriction.__qualname__ = label
    return initial_values, restriction
def test_raises_unsatisfiable_when_exhausted():
    n = 3
    values = {0, 1, 2}

    def reject(assignment):
        # Partial assignments pass; every complete one is rejected.
        if len(assignment) != n:
            return ()
        raise InvalidAssignment()

    with pytest.raises(Unsatisfiable):
        solve([values] * n, reject)
def test_can_exhaustively_enumerate():
    """Only one complete assignment is accepted, so the solver must keep
    searching until it reaches it."""
    n = 3
    values = {0, 1, 2}

    def reject_most(assignment):
        # Removed a leftover debug print that spammed output on every rejection.
        if len(assignment) == n and sorted(assignment) != [(0, 2), (1, 1), (2, 0)]:
            raise InvalidAssignment()
        return ()

    assert solve([values] * n, reject_most) == [2, 1, 0]
@given(simple_assignment_problem())
def test_can_solve_simple_assignment_problem(prob):
    """Any result the solver returns must be a genuine solution: the right
    length, within each variable's initial domain, and consistent with
    every restriction the full assignment implies."""
    values, f = prob
    try:
        result = solve(values, f)
    except Unsatisfiable:
        return
    assert len(result) == len(values)
    for value, domain in zip(result, values):
        assert value in domain
    # simple_assignment_problem only yields allow-lists (never Blacklist),
    # so membership is the right check.
    full_assignment = tuple(enumerate(result))
    for j, allowed in f(full_assignment):
        assert result[j] in allowed
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment