Last active
March 23, 2023 19:47
-
-
Save deliciouslytyped/8318efa17f062b4c27e0890e450424ae to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import timeit | |
#from line_profiler_pycharm import profile | |
from collections import deque | |
from queue import PriorityQueue # minheap | |
from typing import * | |
import unittest | |
import timeit | |
# MyPy type aliases.
# Node is defined further down, hence the quoted forward references.
_State = Any  # a puzzle state; each model narrows it with its own `State` alias
Op = Any  # TODO: the step label returned by succ() (strings in this file)
Heuristic = Callable[["State"], int]  # scores a state; PriorityQueue pops the lowest first
Fringe = Union[Deque["Node"], List["Node"], PriorityQueue["Node"]]
class Node:
    """A search-tree node wrapping one problem state.

    Equality, ordering and hashing all delegate to ``state``. This lets nodes
    be deduplicated in sets and tie-broken inside ``(priority, node)`` tuples
    in a PriorityQueue.
    """

    def __init__(self, state, parent=None, step=None, path_cost=0, value=None):
        self.parent: Optional[Node] = parent
        self.depth = parent.depth + 1 if parent else 0
        self.step = step  # operator/label that produced this node from its parent
        self.state: _State = state
        self.path_cost = path_cost
        self.value = value
        # Root-to-self chain. Copying the parent's list keeps each node's path independent.
        self._path = (parent._path if parent and parent._path else []) + [self]

    def __repr__(self):
        # "Csúcs" is Hungarian for "node"; the 1-tuple guards against iterable states.
        return "<Csúcs: %s>" % (self.state, )

    def __eq__(self, other):
        # Returning NotImplemented (instead of False) lets Python try the reflected operation.
        if not isinstance(other, Node):
            return NotImplemented
        return self.state == other.state

    def __lt__(self, other):
        # Used by PriorityQueue to break ties between equal priorities.
        if not isinstance(other, Node):
            return NotImplemented
        return self.state < other.state

    def __hash__(self):
        return hash(self.state)

    def as_solution(self):
        """Return the step labels from the root to this node (the root's step is None)."""
        return [x.step for x in self._path]

    def as_solution2(self):
        """Return the Node objects from the root to this node."""
        return self._path
class BaseModel:
    """Abstract search problem.

    Subclasses provide the goal test, the successor function, the step cost
    and (optionally) a heuristic value. `children` turns successors into Nodes.
    """

    def __init__(self, start: _State = None):
        # Accept either a raw state or an already-built Node as the root.
        # TODO I had a null check here, is it needed?
        self.start: Node = start if isinstance(start, Node) else Node(start)

    def is_goal(self, state: _State) -> bool:
        raise NotImplementedError

    def path_cost(self, c, state1, step, state2):  # TODO use
        raise NotImplementedError

    def succ(self, state: _State) -> Iterable[Tuple[Op, _State]]:
        raise NotImplementedError

    def value(self, state: _State):
        """Computed node value used by heuristics."""
        raise NotImplementedError

    def children(self, node):
        """Yield a child Node for each successor of `node` not already on its path."""
        for step, next_state in self.succ(node.state):
            # Skip states already on the root-to-node path (prevents trivial cycles).
            # TODO ???
            on_path = [ancestor.state for ancestor in node._path]
            if next_state in on_path:
                continue
            cost = self.path_cost(node.path_cost, node.state, step, next_state)
            yield Node(next_state, node, step, cost)
# noinspection PyAbstractClass
class DefaultModel(BaseModel):
    """BaseModel with unit step costs: each step adds one to the path cost."""

    def path_cost(self, c, state1, step, state2):  # TODO use
        unit_cost = 1
        return c + unit_cost
class Search: | |
def __init__(self, model): | |
self.model: BaseModel = model | |
self.visited_nodes = None | |
self.total_nodes = None | |
self.extensions = None | |
self.reset_stats() | |
def reset_stats(self): | |
self.visited_nodes = 0 | |
self.total_nodes = 1 | |
self.extensions = 0 | |
def search(self, fringe: Fringe, heuristic: Heuristic = None): | |
raise NotImplementedError | |
@staticmethod | |
def pop(ds: Fringe): | |
match ds: | |
case deque(): | |
r = ds.popleft() | |
case list(): | |
r = ds.pop() | |
case PriorityQueue(): | |
_, r = ds.get_nowait() | |
case _: | |
raise NotImplementedError | |
return r | |
@staticmethod | |
def extend(ds: Fringe, items, heuristic: Heuristic = None): | |
match ds: | |
case PriorityQueue(): | |
for i in items: | |
ds.put_nowait((heuristic(i.state), i)) | |
case list() | deque(): | |
ds.extend(items) | |
case _: | |
raise NotImplementedError | |
def dfs(self, fringe: Fringe = None): | |
return self.search(fringe if fringe else [self.model.start]) | |
def bfs(self, fringe: Fringe = None): | |
return self.search(fringe if fringe else deque([self.model.start])) | |
def best_first(self, heuristic: Heuristic = None, fringe: Fringe = None): | |
q = PriorityQueue() | |
q.put_nowait((0, self.model.start)) | |
return self.search(fringe if fringe else q, heuristic) | |
def astar(self, h: Heuristic = None, fringe: Fringe = None): | |
def f(n: Node): | |
return n.path_cost + (h(n) if h else self.model.value(n.state)) | |
return self.best_first(f, fringe) | |
def print_stats(self): | |
print(f"visited: {self.visited_nodes}, total: {self.total_nodes}, extensions: {self.extensions}") | |
class Tree(Search):
    """Tree search: expanded states are not remembered.

    A state reachable by several paths may be expanded more than once. Only
    states already on the current path are pruned, by BaseModel.children.
    """

    def search(self, fringe: Fringe, heuristic: Heuristic = None):
        while fringe:
            current = self.pop(fringe)
            if not current:
                break
            self.visited_nodes += 1
            if self.model.is_goal(current.state):
                yield current
                continue
            successors = list(self.model.children(current))
            self.extend(fringe, successors, heuristic)
            self.extensions += 1
            self.total_nodes += len(successors)
class Graph(Search):
    """Graph search: a `closed` set of every generated state.

    The set ensures each state is put on the fringe at most once.
    """

    def search(self, fringe: Fringe, heuristic: Heuristic = None):
        closed = set()
        while fringe and (n := self.pop(fringe)):
            self.visited_nodes += 1
            if self.model.is_goal(n.state):
                yield n
                continue
            c = list(self.model.children(n))  # list()-ed so it can be len()-ed below
            # dict.fromkeys dedups (Nodes compare by state) but, unlike set(),
            # keeps succ()'s generation order, so DFS/BFS expand successors in a
            # predictable order instead of hash order.
            new = list(dict.fromkeys(x for x in c if x not in closed))
            self.extend(fringe, new, heuristic)
            closed.update(new)
            self.extensions += 1
            self.total_nodes += len(c)  # total calculated nodes
class Numbers(DefaultModel):
    """Combination-lock puzzle: turn the digits (5, 7, 6) into (7, 7, 7).

    A state is ``(digits, last_changed_position)``. The rules are:
      - each step moves one digit up or down by one;
      - the position changed last may not be changed again immediately;
      - the lock may never show 666 or 667.
    """
    State = Tuple[Tuple[int, int, int], int]

    # NOTE(review): the default start records position 0 as "last changed", so the
    # first digit cannot be changed on the very first move. Confirm that is intended.
    def __init__(self, start=((5,7,6),0)):
        super().__init__(start)

    def succ(self, state: State) -> Iterable[Tuple[Op, _State]]:
        num, prevpos = state
        # Digit combinations the lock may never show. These are compared against the
        # digits only; the old check compared them to (digits, pos) pairs and never matched.
        disallowed = {(6, 6, 6), (6, 6, 7)}

        def moddigit(n, i, v):
            return n[:i] + (v,) + n[i+1:]

        moves = list()
        for i, d in enumerate(num):
            if prevpos == i:
                continue
            # NOTE(review): digits are capped at 8 (`< 9`); confirm 9 is really excluded.
            if d + 1 < 9:
                digits = moddigit(num, i, d + 1)
                if digits not in disallowed:
                    moves.append((f"+{i}", (digits, i)))
            if 0 <= d - 1:  # note 0 is an allowed digit for the first position
                digits = moddigit(num, i, d - 1)
                if digits not in disallowed:
                    moves.append((f"-{i}", (digits, i)))
        return moves

    def is_goal(self, state: State) -> bool:
        return state[0] == (7, 7, 7)
class Jugs(DefaultModel):
    """Three-jug decanting puzzle.

    The jugs hold 8, 5 and 3 units. The default start is (8, 0, 0), and the
    goal is any jug holding exactly 4. "Kancsó" is Hungarian for "jug"; the
    step labels read "from kX into kY".
    """
    State = Tuple[int, int, int]

    def __init__(self, ke=(8, 0, 0)):
        super().__init__(ke)
        self.K1 = 8
        self.K2 = 5
        self.K3 = 3

    def is_goal(self, state):
        return 4 in state

    def succ(self, state):
        k1, k2, k3 = state
        levels = (k1, k2, k3)
        capacities = (self.K1, self.K2, self.K3)
        # (source jug, target jug, label), in the order the moves are reported.
        pours = (
            (0, 1, "k1-ből k2-be"),
            (0, 2, "k1-ből k3-ba"),
            (1, 2, "k2-ből k3-ba"),
            (1, 0, "k2-ből k1-be"),
            (2, 0, "k3-ből k1-be"),
            (2, 1, "k3-ből k2-be"),
        )
        steps = []
        for src, dst, label in pours:
            if levels[src] <= 0 or levels[dst] >= capacities[dst]:
                continue
            # Pour until the source is empty or the target is full.
            amount = min(levels[src], capacities[dst] - levels[dst])
            after = list(levels)
            after[src] -= amount
            after[dst] += amount
            steps.append((label, tuple(after)))
        return steps

    def __str__(self):
        return 'Kancsó:' + str(self.start)

    def value(self, state: _State):  # TODO
        return None
class Fruit(DefaultModel):
    """Fruit-trading puzzle.

    Each trade gives up one each of two kinds for two of the third kind. The
    goal is for only one kind of fruit to remain. Every trade is (-1, -1, +2),
    so the total number of fruit never changes.
    """
    State = Tuple[int, int, int]

    def __init__(self, start=(13, 46, 59)):
        super().__init__(start)

    def is_goal(self, state: State) -> bool:
        return sum(1 for x in state if x != 0) == 1

    def succ(self, state: State) -> Iterable[Tuple[Op, State]]:
        apple, pear, peach = state
        steps = list()
        if apple > 0 and pear > 0:
            steps.append(("apple and pear for peach", (apple-1, pear-1, peach+2)))
        if pear > 0 and peach > 0:
            steps.append(("pear and peach for apple", (apple+2, pear-1, peach-1)))
        if peach > 0 and apple > 0:
            steps.append(("peach and apple for pear", (apple-1, pear+2, peach-1)))
        return steps

    def value(self, state: State):
        """Heuristic: count of fruit not of the most plentiful kind (0 at a goal).

        sum(state) equals the starting total, because trades preserve it. The
        former hard-coded 118 was only correct for the default start.
        """
        return sum(state) - max(state)
class Bishops(DefaultModel):
    """Bishop-exchange puzzle on a w x h board (4 x 5 by default).

    Player 1's bishops start on the bottom row and player 2's on the top row.
    The goal is for the two armies to trade places. Players move alternately,
    and no move may leave the moved bishop on a diagonal occupied by an enemy
    bishop.

    A state is ``(player_to_move, (player1_coords, player2_coords))``.
    Coordinates use the rotated "diagonal" system described in `succ`.
    """
    State = Tuple[int, Tuple[Tuple[Tuple[int, int], ...], ...]]

    def __init__(self, start=None):
        # see succ for explanations of the state space
        super().__init__(start if start else
                         (1, (((0, 0), (1, 1), (2, 2), (3, 3)),
                              ((-4, 4), (-3, 5), (-2, 6), (-1, 7)))))
        self.pieces = 4
        self.next = {1: 2, 2: 1}  # whose turn follows whose
        self.w = 4
        self.h = 5

    def is_goal(self, state: State) -> bool:
        """True once each army stands on the squares the other army started on.

        Bishops of one side are interchangeable, so the squares are compared
        as sorted collections rather than index by index.
        """
        _, (pl1_now, pl2_now) = state
        _, (pl1_start, pl2_start) = self.start.state
        return sorted(pl1_now) == sorted(pl2_start) and sorted(pl2_now) == sorted(pl1_start)

    # @profile
    def succ(self, state: State) -> Iterable[Tuple[Op, State]]:
        """Generate every legal move for the player whose turn it is.

        State-space notes:
        - Coordinates are diagonal coordinates: x = col - row and y = col + row.
          A one-square diagonal step changes exactly one of them by 2, and two
          bishops threaten each other exactly when they share an x or a y.
        - Bounds are checked with two abs() functions: the board is a rectangle
          tilted by 45 degrees in these coordinates.
        - The moved piece's new square is checked against every enemy diagonal.
          Attacks are treated as unblockable, so by induction only the moved
          piece can create a new threat.
        - Each player's coordinates are kept sorted. Permutations of identical
          bishops therefore collapse into one state.
        - TODO: the problem could be split in two, since a bishop never changes
          the colour of the squares it travels on.
        """
        # constants for the bounds-check geometry
        assert self.h > self.w
        skew = -(self.h - self.w)
        diag = 2 * (self.h - 1)
        # The longest diagonal has min(w, h) squares, i.e. min(w, h) - 1 steps.
        max_dist = min(self.w, self.h)

        pl, players = state
        pl -= 1
        moves = list()
        directions = ((2, 0), (-2, 0), (0, 2), (0, -2))
        occupied = set(sum(players, tuple()))
        othera, otherb = players[:pl], players[pl + 1:]
        enemies = sum(othera + otherb, tuple())
        enemy_xs = {x for x, _ in enemies}
        enemy_ys = {y for _, y in enemies}
        ours = players[pl]

        for i, (_x, _y) in enumerate(ours):  # for each of our pieces
            for dx, dy in directions:  # for each diagonal direction
                # Walk outwards like a ray until leaving the board or hitting a piece.
                for dist in range(1, max_dist):
                    x = _x + dist * dx
                    y = _y + dist * dy
                    if not abs(x) <= y <= -abs(x - skew) + diag + skew:
                        break  # off the board
                    if (x, y) in occupied:
                        break  # bishops cannot jump over pieces
                    if x in enemy_xs or y in enemy_ys:
                        continue  # may pass through, but not stop on, a threatened square
                    newcoords = tuple(sorted(ours[:i] + ((x, y),) + ours[i + 1:]))
                    moves.append((f"[{_x},{_y}] to [{x},{y}]",
                                  (self.next[pl+1], othera + (newcoords,) + otherb)))
        return moves

    def value(self, state: State):  # TODO
        return super().value(state)

    def print_state(self, state: State):
        pl, players = state
        t = sum(([[0]*self.w] for _ in range(self.h)), [])
        for i, coords in enumerate(players):
            for x, y in coords:
                # Linear transform: rotate and scale with the 45-degree rotation matrix
                # 2 / sqrt(2) * [ cos(a) sin(a) ; -sin(a) cos(a) ] with a = 45 deg == [ 1 1 ; -1 1 ]
                # to map diagonal coordinates back to (row, column).
                t[(-x+y)//2][(x+y)//2] = i+1
        print(f"{pl} to move")
        for lst in reversed(t):
            print(" ".join(map(str, lst)))
class Queens(DefaultModel):
    """N-queens, filled in one column at a time (N = `size`, default 8).

    A state is a tuple of row indices: ``state[c]`` is the row of the queen in
    column c.
    """
    State = Tuple[int, ...]

    def __init__(self, start=None, size: int = 8):
        self.size = size
        super().__init__(start if start else tuple())

    def is_goal(self, state: State) -> bool:
        return len(state) == self.size

    # TODO use permutations instead?
    def succ(self, state: State) -> Iterable[Tuple[Op, _State]]:
        col = len(state)  # 0-based index of the column being filled next
        used_rows = set(state)
        # Queens share a diagonal iff they share col - row or col + row.
        diffs = {c - r for c, r in enumerate(state)}
        sums = {c + r for c, r in enumerate(state)}
        steps = list()
        for row in range(self.size):
            if col - row in diffs or col + row in sums or row in used_rows:
                continue
            steps.append((f"queen at {row}", state + (row,)))
        return steps

    def value(self, state: State):  # TODO
        pass

    def print_state(self, state: State):
        lines = [["_"] * self.size for _ in range(self.size)]
        for col, row in enumerate(state):
            lines[row][col] = "Q"
        for line in lines:
            print(" ".join(line))
        print()
class Hanoi(DefaultModel):
    """Towers of Hanoi on three rods numbered 1-3.

    ``state[d]`` is the rod holding disk d, and lower indices are smaller
    disks. A disk may move only onto an empty rod or onto a rod whose top
    disk has a larger index. The goal is every disk on rod 3.
    """
    State = Tuple[int, int, int]

    def __init__(self, start=(1, 1, 1)):
        super().__init__(start)

    def is_goal(self, state: State) -> bool:
        return state == (3,) * len(state)

    def succ(self, state: State) -> Iterable[Tuple[Op, State]]:
        def top(rod):
            """Index of the smallest (topmost) disk on `rod`, or -1 if the rod is empty."""
            return next((disk for disk, r in enumerate(state) if r == rod), -1)

        tops = {rod: top(rod) for rod in (1, 2, 3)}
        moves = list()
        # All six (source, target) rod pairs. The old code passed f(f2, f2, 2) for the
        # last pair, so the rod 3 -> rod 2 move was never generated.
        for src, dst in ((1, 2), (1, 3), (2, 1), (2, 3), (3, 1), (3, 2)):
            disk, blocker = tops[src], tops[dst]
            if disk != -1 and (blocker == -1 or disk < blocker):
                moves.append((f"{disk} to {dst}", state[:disk] + (dst,) + state[disk + 1:]))
        return moves

    def value(self, state: State):
        pass
class Test(unittest.TestCase):
    """Demo runs of each puzzle; most print results and statistics instead of asserting."""

    @staticmethod
    def test_soln_fruits():
        f = Fruit()
        s = Graph(f)
        r = list(s.bfs())
        for soln in r:
            assert set(soln.state) == {118, 0}
        s.print_stats()
        s.reset_stats()
        r = list(s.dfs())
        for soln in r:
            assert set(soln.state) == {118, 0}
        s.print_stats()
        s = Tree(f)
        soln = next(s.best_first(f.value))
        assert set(soln.state) == {118, 0}
        s.print_stats()

    @staticmethod
    def test_soln_bishops():
        b = Bishops()
        search = Graph(b)
        # s: Node = next(search.dfs())
        # for p in s.as_solution2():
        #     b.print_state(p.state)
        # search.print_stats()
        search.reset_stats()
        s: Node = next(search.bfs())
        # as_solution2() yields the Nodes on the path. as_solution() yields only
        # the step labels, which have no `.state`.
        for p in s.as_solution2():
            b.print_state(p.state)
        search.print_stats()

    @staticmethod
    def test_soln_queens():
        q = Queens()
        s = Graph(q)
        l = list(s.bfs())
        for e in l:
            q.print_state(e.state)
            print("\n")
        s.print_stats()

    @staticmethod
    def test_soln_hanoi():
        q = Hanoi()
        s = Graph(q)
        l = list(s.bfs())
        print(l)
        s.print_stats()

    @staticmethod
    def test_soln_numbers():
        n = Numbers()
        s = Graph(n)
        l = list(s.bfs())
        print(l)
        s.print_stats()
        for i in l:
            print(i.as_solution())

    def test_soln_husbands(self):
        h = Husbands()
        s = Graph(h)
        r = next(s.bfs())
        print(r.as_solution())
        print(r.as_solution2())
        s.print_stats()
        s.reset_stats()
        r = next(s.best_first(s.model.value))
        print(r.as_solution())
        print(r.as_solution2())
        s.print_stats()
        s.reset_stats()
from itertools import product | |
class Husbands(DefaultModel):
    """Jealous-husbands river crossing with three couples.

    Indices 0-5 are people and index 6 is the boat. Each entry is a bank:
    -1 is the start, 1 is the far side. Husbands sit at even indices, and
    index i + 1 is the wife of husband i. The boat carries one or two people.
    A wife apart from her husband may not share a bank with another husband.
    """
    State = Tuple[int, int, int, int, int, int, int]

    def __init__(self, start=None):
        super().__init__(start if start else (-1, -1, -1, -1, -1, -1, -1))

    def is_goal(self, state: State) -> bool:
        # Only the people matter. The boat necessarily ends on the far bank anyway,
        # because the last crossing delivers people there.
        return state[:-1] == (1, 1, 1, 1, 1, 1)

    def succ(self, state: State) -> Iterable[Tuple[Op, State]]:
        boat = state[-1]
        dest = -boat

        def moved(*who):
            # The listed people and the boat cross to `dest`.
            after = list(state)
            for p in who:
                after[p] = dest
            after[6] = dest
            return tuple(after)

        def valid(s):
            # The former check used i in {1, 3, 5}: it treated the boat (index 6) as a
            # wife and never constrained person 0.
            return all(s[h] == s[h + 1] or s[h + 1] != s[other]
                       for h in (0, 2, 4)
                       for other in (0, 2, 4) if other != h)

        # Deterministic order. Each unordered pair appears once, and the boat (index 6)
        # is never treated as a passenger.
        moves = list()
        for i in range(6):
            if state[i] != boat:
                continue
            candidates = [(f"move {i} to {dest}", moved(i))]
            candidates += [(f"move {i},{j} to {dest}", moved(i, j))
                           for j in range(i + 1, 6) if state[j] == boat]
            moves.extend(m for m in candidates if valid(m[1]))
        return moves

    def value(self, state: State):
        # TODO heuristic is probably bad because the problem solution is probably not monotonic
        # NOTE(review): only odd indices (the wives) are counted, so this never drops
        # below 3. Possibly it was meant to count all six people; confirm.
        return 6 - len([x for i, x in enumerate(state) if x == 1 and i % 2 == 1])
if __name__ == "__main__":
    # Time one run of a single puzzle demo. To benchmark another puzzle, swap the
    # method for test_soln_fruits, test_soln_bishops, test_soln_queens,
    # test_soln_hanoi or test_soln_numbers.
    elapsed = timeit.timeit(lambda: Test().test_soln_husbands(), number=1)
    print(elapsed)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment