Skip to content

Instantly share code, notes, and snippets.

@deliciouslytyped
Last active March 23, 2023 19:47
Show Gist options
  • Save deliciouslytyped/8318efa17f062b4c27e0890e450424ae to your computer and use it in GitHub Desktop.
Save deliciouslytyped/8318efa17f062b4c27e0890e450424ae to your computer and use it in GitHub Desktop.
import timeit
#from line_profiler_pycharm import profile
from collections import deque
from queue import PriorityQueue # minheap
from typing import *
import unittest
import timeit
# MyPy type aliases used throughout the search code.
# Fringe: the container a search pulls nodes from -- a deque (FIFO, used by bfs),
# a list (LIFO, used by dfs) or a PriorityQueue of (priority, node) pairs.
Fringe = Union[Deque["Node"], List["Node"], PriorityQueue["Node"]]
# Heuristic: maps a state to an integer priority (the PriorityQueue is a min-heap,
# so lower values are popped first).
Heuristic = Callable[["State"], int]
# Opaque placeholders; concrete puzzles define their own state/operator shapes.
_State = Any
Op = Any  # TODO
class Node:
    """A search-tree node wrapping one model state.

    Equality, ordering and hashing all delegate to ``state``. This lets nodes be
    deduplicated in sets and ordered inside a PriorityQueue when priorities tie.
    """

    def __init__(self, state, parent=None, step=None, path_cost=0, value=None):
        self.parent: "Node | None" = parent
        self.depth = parent.depth + 1 if parent else 0
        self.step = step  # operator that produced this node from parent (None for the root)
        self.state: "_State" = state
        self.path_cost = path_cost
        self.value = value
        # Root-to-self chain of nodes. A parent's _path always contains at least the
        # parent itself, so only the presence of a parent needs checking.
        self._path = (parent._path if parent else []) + [self]

    def __repr__(self):
        return "<Csúcs: %s>" % (self.state, )  # Inner type may be iterable

    def __eq__(self, other):
        # Let Python fall back to its default handling for foreign types.
        if not isinstance(other, Node):
            return NotImplemented
        return self.state == other.state

    def __lt__(self, other):
        # Returning NotImplemented makes `node < 5` raise TypeError instead of silently
        # answering False.
        if not isinstance(other, Node):
            return NotImplemented
        return self.state < other.state

    def __hash__(self):
        return hash(self.state)

    def as_solution(self):
        """Return the operators (step labels) along the path from the root; the root's is None."""
        return [x.step for x in self._path]

    def as_solution2(self):
        """Return the nodes along the path from the root to this node (inclusive)."""
        return self._path
class BaseModel:
    """Interface every search problem implements.

    Subclasses supply the goal test, the successor function, the step cost and,
    optionally, a heuristic value. ``children`` turns successors into Nodes.
    """

    def __init__(self, start: _State = None):
        # Accept either a raw state or an already-built Node as the start.
        # TODO I had a null check here, is it needed?
        self.start: Node = start if isinstance(start, Node) else Node(start)

    def is_goal(self, state: _State) -> bool:
        """Return True when state satisfies the problem's goal."""
        raise NotImplementedError

    def path_cost(self, c, state1, step, state2):  # TODO use
        """Return the cost of reaching state2 from state1 via step, given cost c so far."""
        raise NotImplementedError

    def succ(self, state: _State) -> Iterable[Tuple[Op, _State]]:
        """Return (operator, successor state) pairs reachable from state."""
        raise NotImplementedError

    def value(self, state: _State):
        """Computed node value used in heuristics."""
        raise NotImplementedError

    def children(self, node):
        """Yield a child Node for every successor not already on node's root path."""
        successors = self.succ(node.state)
        ancestors = [ancestor.state for ancestor in node._path]  # TODO ???
        for op, child_state in successors:
            if child_state in ancestors:
                continue  # would loop back to an ancestor
            cost = self.path_cost(node.path_cost, node.state, op, child_state)
            yield Node(child_state, node, op, cost)
# noinspection PyAbstractClass
class DefaultModel(BaseModel):
    """BaseModel with unit step costs: every move adds exactly 1 to the path cost."""

    def path_cost(self, c, state1, step, state2):  # TODO use
        # The step and the states are ignored; only the accumulated cost c matters.
        return c + 1
class Search:
def __init__(self, model):
self.model: BaseModel = model
self.visited_nodes = None
self.total_nodes = None
self.extensions = None
self.reset_stats()
def reset_stats(self):
self.visited_nodes = 0
self.total_nodes = 1
self.extensions = 0
def search(self, fringe: Fringe, heuristic: Heuristic = None):
raise NotImplementedError
@staticmethod
def pop(ds: Fringe):
match ds:
case deque():
r = ds.popleft()
case list():
r = ds.pop()
case PriorityQueue():
_, r = ds.get_nowait()
case _:
raise NotImplementedError
return r
@staticmethod
def extend(ds: Fringe, items, heuristic: Heuristic = None):
match ds:
case PriorityQueue():
for i in items:
ds.put_nowait((heuristic(i.state), i))
case list() | deque():
ds.extend(items)
case _:
raise NotImplementedError
def dfs(self, fringe: Fringe = None):
return self.search(fringe if fringe else [self.model.start])
def bfs(self, fringe: Fringe = None):
return self.search(fringe if fringe else deque([self.model.start]))
def best_first(self, heuristic: Heuristic = None, fringe: Fringe = None):
q = PriorityQueue()
q.put_nowait((0, self.model.start))
return self.search(fringe if fringe else q, heuristic)
def astar(self, h: Heuristic = None, fringe: Fringe = None):
def f(n: Node):
return n.path_cost + (h(n) if h else self.model.value(n.state))
return self.best_first(f, fringe)
def print_stats(self):
print(f"visited: {self.visited_nodes}, total: {self.total_nodes}, extensions: {self.extensions}")
class Tree(Search):
    """Tree search: every generated child is enqueued; there is no global closed set."""

    def search(self, fringe: Fringe, heuristic: Heuristic = None):
        """Yield goal nodes in the order the fringe produces them."""
        while fringe:
            node = self.pop(fringe)
            if not node:
                break
            self.visited_nodes += 1
            if self.model.is_goal(node.state):
                yield node
                continue
            kids = list(self.model.children(node))
            self.extend(fringe, kids, heuristic)
            self.extensions += 1
            self.total_nodes += len(kids)
class Graph(Search):
    """Graph search: a child whose state was already generated is not enqueued again."""

    def search(self, fringe: Fringe, heuristic: Heuristic = None):
        """Yield goal nodes in the order the fringe produces them.

        ``closed`` holds every node generated so far. Node equality and hashing are by
        state. Nodes are added when they are generated, not when they are expanded.
        """
        closed = set()
        while fringe and (n := self.pop(fringe)):
            self.visited_nodes += 1
            if self.model.is_goal(n.state):
                yield n
                continue
            children = list(self.model.children(n))  # list()-ed so it can be len()-ed below
            # dict.fromkeys drops duplicate states but, unlike set(), keeps the model's
            # successor order, so the expansion order follows succ().
            new = dict.fromkeys(c for c in children if c not in closed)
            self.extend(fringe, new, heuristic)
            closed.update(new)
            self.extensions += 1
            self.total_nodes += len(children)  # total calculated nodes
class Numbers(DefaultModel):
    """Digit puzzle: change one digit by +/-1 per step until the number reads 777.

    State: (digits, prevpos), where prevpos is the position changed by the previous
    step. That position may not be changed in the next step. Digits stay within
    0..8, and the digit tuples 666 and 667 may never be produced.
    NOTE(review): the default start has prevpos=0, which forbids changing digit 0 on
    the very first move -- confirm this is intended.
    """
    State = Tuple[Tuple[int, int, int], int]

    def __init__(self, start=((5, 7, 6), 0)):
        super().__init__(start)

    def succ(self, state: State) -> Iterable[Tuple[Op, _State]]:
        num, prevpos = state
        # Forbidden digit combinations. The check must run on the digits alone; the
        # old code compared the whole (digits, pos) state, so it never matched.
        disallowed = {(6, 6, 6), (6, 6, 7)}

        def moddigit(n, i, v):
            return n[:i] + (v,) + n[i + 1:]

        moves = []
        for i, d in enumerate(num):
            if i == prevpos:
                continue
            candidates = []
            if d + 1 < 9:
                candidates.append((f"+{i}", d + 1))
            if 0 <= d - 1:  # note 0 is an allowed digit for the first position
                candidates.append((f"-{i}", d - 1))
            for op, nv in candidates:
                digits = moddigit(num, i, nv)
                if digits not in disallowed:
                    moves.append((op, (digits, i)))
        return moves

    def is_goal(self, state: State) -> bool:
        return state[0] == (7, 7, 7)
class Jugs(DefaultModel):
    """Water-jug puzzle: jugs of capacity 8, 5 and 3; the goal is any jug holding 4."""
    State = Tuple[int, int, int]

    # (source jug, target jug, operator label), in the order moves are generated.
    _POURS = (
        (0, 1, "k1-ből k2-be"),
        (0, 2, "k1-ből k3-ba"),
        (1, 2, "k2-ből k3-ba"),
        (1, 0, "k2-ből k1-be"),
        (2, 0, "k3-ből k1-be"),
        (2, 1, "k3-ből k2-be"),
    )

    def __init__(self, ke=(8, 0, 0)):
        super().__init__(ke)
        self.K1 = 8
        self.K2 = 5
        self.K3 = 3

    def is_goal(self, state):
        return 4 in state

    def succ(self, state):
        """Pour from one jug into another until the source empties or the target fills."""
        k1, k2, k3 = state
        levels = (k1, k2, k3)
        caps = (self.K1, self.K2, self.K3)
        result = []
        for src, dst, label in self._POURS:
            if levels[src] <= 0 or levels[dst] >= caps[dst]:
                continue
            amount = min(levels[src], caps[dst] - levels[dst])
            after = list(levels)
            after[src] -= amount
            after[dst] += amount
            result.append((label, tuple(after)))
        return result

    def __str__(self):
        return 'Kancsó:' + str(self.start)

    def value(self, state: _State):  # TODO
        return None
class Fruit(DefaultModel):
    """Fruit-exchange puzzle: trade two different fruits for two of the third kind.

    State: (apple, pear, peach) counts. The goal is reached when only one kind is left.
    """
    State = Tuple[int, int, int]

    def __init__(self, start=(13, 46, 59)):
        super().__init__(start)

    def is_goal(self, state: State) -> bool:
        return sum(1 for x in state if x != 0) == 1

    def succ(self, state: State) -> Iterable[Tuple[Op, State]]:
        apple, pear, peach = state
        steps = list()
        if apple > 0 and pear > 0:
            steps.append(("apple and pear for peach", (apple-1, pear-1, peach+2)))
        if pear > 0 and peach > 0:
            steps.append(("pear and peach for apple", (apple+2, pear-1, peach-1)))
        if peach > 0 and apple > 0:
            steps.append(("peach and apple for pear", (apple-1, pear+2, peach-1)))
        return steps

    def value(self, state: State):
        """Heuristic: number of fruits not of the most common kind.

        Every trade is -1 -1 +2, so sum(state) is invariant. It equals 118 for the
        default start, so this matches the old hard-coded `118 - max(state)` there,
        and it stays correct for other starts.
        """
        return sum(state) - max(state)
class Bishops(DefaultModel):
    """Bishop-swap puzzle on a w x h board.

    The two players' bishops must trade places. No bishop may ever land on a
    diagonal shared with an enemy bishop.
    """
    State = Tuple[int, Tuple[Tuple[Tuple[int, int]]]]

    def __init__(self, start=None):
        # see succ for explanations of the state space
        super().__init__(start if start else
                         (1, (((0, 0), (1, 1), (2, 2), (3, 3)),
                              ((-4, 4), (-3, 5), (-2, 6), (-1, 7)))))
        self.pieces = 4
        self.next = {1: 2, 2: 1}
        self.w = 4
        self.h = 5

    def is_goal(self, state: State) -> bool:
        _, (pl1_2, pl2_2) = state
        _, (pl1, pl2) = self.start.state
        # A player's bishops are interchangeable, so compare the occupied squares
        # regardless of which piece index sits on which square.
        return sorted(pl1_2) == sorted(pl2) and sorted(pl2_2) == sorted(pl1)

    # - A state is (player to move, (player 1's coordinates, player 2's coordinates)).
    # - Coordinates are diagonal coordinates: x = col - row, y = col + row (see the
    #   inverse transform in print_state). Each step along a diagonal changes exactly
    #   one coordinate by 2, and a threat check is an equality test on x or y.
    # - Bounds are checked with two abs() functions: y >= |x| bounds the bottom, and a
    #   translated abs() bounds the top (the board drawn tilted by 45 degrees).
    # - A move is skipped if the moved bishop lands on an x or y diagonal of any enemy
    #   bishop. Blocking pieces are not considered for threats.
    # The problem could be factored into two subproblems, because a bishop can't
    # change the colour of the squares it traverses - but we don't do this (yet? TODO)
    #@profile
    def succ(self, state: State) -> Iterable[Tuple[Op, State]]:
        # constants for the bounds check geometry
        assert self.h > self.w
        skew = -(self.h - self.w)
        diag = 2 * (self.h - 1)
        pl, players = state
        pl -= 1
        moves = list()
        directions = ((2, 0), (-2, 0), (0, 2), (0, -2))
        occupied = sum(players, tuple())
        # The opponents and their diagonals do not depend on which of our pieces
        # moves, so compute them once.
        othera, otherb = players[:pl], players[pl + 1:]
        enemies = sum(othera + otherb, tuple())
        otherx = {c[0] for c in enemies}
        othery = {c[1] for c in enemies}
        # Longest possible slide. The bounds/collision checks below end each ray earlier.
        max_dist = max(self.w, self.h)
        for i, (_x, _y) in enumerate(players[pl]):  # for each of our pieces
            for dx, dy in directions:  # for each possible movement direction
                # Ray-trace outwards one square at a time until we leave the board or
                # hit a piece; easier than deducing the maximum range up front.
                for dist in range(1, max_dist):
                    x = _x + dist * dx
                    y = _y + dist * dy
                    if not abs(x) <= y <= -abs(x - skew) + diag + skew:
                        break
                    if (x, y) in occupied:
                        break  # bishops cannot jump over pieces
                    # Landing on a diagonal shared with an enemy bishop means a mutual
                    # threat; skip that square but keep sliding past it.
                    # TODO this should be safe by induction?
                    if x in otherx or y in othery:
                        continue
                    newcoords: Tuple[Tuple[int, int]] = players[pl][:i] + ((x, y),) + players[pl][i + 1:]
                    moves.append((f"[{_x},{_y}] to [{x},{y}]",
                                  (self.next[pl + 1], othera + (newcoords,) + otherb)))
        return moves

    def value(self, state: State):  # TODO
        return super().value(state)

    def print_state(self, state: State):
        pl, players = state
        t = [[0] * self.w for _ in range(self.h)]
        for i, coords in enumerate(players):
            for x, y in coords:
                # Linear transform back to row/column: rotate by 45 deg and scale, i.e.
                # 2 / sqrt(2) * [ cos(a) sin(a) ; -sin(a) cos(a) ] with a=45deg == [ 1 1 ; -1 1 ].
                # Note which index/coord is row and which is column.
                t[(-x + y) // 2][(x + y) // 2] = i + 1
        print(f"{pl} to move")
        for lst in reversed(t):
            print(" ".join(map(str, lst)))
class Queens(DefaultModel):
    """N-queens, one column at a time: state[c] is the row of the queen in column c.

    ``n`` defaults to 8 (the classic puzzle).
    """
    State = Tuple[int]

    def __init__(self, start=None, n: int = 8):
        self.n = n
        super().__init__(start if start else tuple())

    def is_goal(self, state: State) -> bool:
        return len(state) == self.n

    #TODO use permutations instead?
    def succ(self, state: State) -> Iterable[Tuple[Op, _State]]:
        col = len(state)  # 0-based index of the column being filled next
        # Each placed queen occupies one "difference" and one "sum" diagonal.
        diag_diff = {c - r for c, r in enumerate(state)}
        diag_sum = {c + r for c, r in enumerate(state)}
        steps = []
        for row in range(self.n):
            if col - row in diag_diff or col + row in diag_sum or row in state:
                continue
            steps.append((f"queen at {row}", state + (row,)))
        return steps

    def value(self, state: State):  # TODO
        pass

    def print_state(self, state: State):
        lines = [["_"] * self.n for _ in range(self.n)]
        for col, row in enumerate(state):
            lines[row][col] = "Q"
        for line in lines:
            print(" ".join(line))
        print()
class Hanoi(DefaultModel):
    """Towers of Hanoi with three discs and rods 1-3.

    State: state[d] is the rod that disc d sits on. A lower index is a smaller disc,
    so the top of a rod is its lowest-indexed disc.
    """
    State = Tuple[int, int, int]

    def __init__(self, start=(1, 1, 1)):
        super().__init__(start)

    def is_goal(self, state: State) -> bool:
        return state == (3, 3, 3)

    def succ(self, state: State) -> Iterable[Tuple[Op, State]]:
        def top(rod):
            """Index of the smallest disc on rod, or -1 if the rod is empty."""
            return next((d for d, r in enumerate(state) if r == rod), -1)

        tops = {rod: top(rod) for rod in (1, 2, 3)}
        moves = []
        # All six ordered (source, target) rod pairs. The old code listed (2, 2)
        # instead of (3, 2), so moves from rod 3 to rod 2 were never generated.
        for src, dst in ((1, 2), (1, 3), (2, 1), (2, 3), (3, 1), (3, 2)):
            disc, target_top = tops[src], tops[dst]
            # Move the top disc only onto an empty rod or onto a larger disc.
            if disc > -1 and (target_top == -1 or disc < target_top):
                moves.append((f"{disc} to {dst}", state[:disc] + (dst,) + state[disc + 1:]))
        return moves

    def value(self, state: State):
        pass
class Test(unittest.TestCase):
    """Smoke tests: run each puzzle through the searches and print solutions/stats."""

    @staticmethod
    def test_soln_fruits():
        f = Fruit()
        s = Graph(f)
        r = list(s.bfs())
        for soln in r:
            assert set(soln.state) == {118, 0}
        s.print_stats()
        s.reset_stats()
        r = list(s.dfs())
        for soln in r:
            assert set(soln.state) == {118, 0}
        s.print_stats()
        s = Tree(f)
        soln = next(s.best_first(f.value))
        assert set(soln.state) == {118, 0}
        s.print_stats()

    @staticmethod
    def test_soln_bishops():
        b = Bishops()
        search = Graph(b)
        search.reset_stats()
        s: Node = next(search.bfs())
        # as_solution() returns only the operator labels. The nodes (which carry the
        # states to print) come from as_solution2().
        for p in s.as_solution2():
            b.print_state(p.state)
        search.print_stats()

    @staticmethod
    def test_soln_queens():
        q = Queens()
        s = Graph(q)
        l = list(s.bfs())
        for e in l:
            q.print_state(e.state)
        print("\n")
        s.print_stats()

    @staticmethod
    def test_soln_hanoi():
        q = Hanoi()
        s = Graph(q)
        l = list(s.bfs())
        print(l)
        s.print_stats()

    @staticmethod
    def test_soln_numbers():
        n = Numbers()
        s = Graph(n)
        l = list(s.bfs())
        print(l)
        s.print_stats()
        for i in l:
            print(i.as_solution())

    def test_soln_husbands(self):
        h = Husbands()
        s = Graph(h)
        r = next(s.bfs())
        print(r.as_solution())
        print(r.as_solution2())
        s.print_stats()
        s.reset_stats()
        r = next(s.best_first(s.model.value))
        print(r.as_solution())
        print(r.as_solution2())
        s.print_stats()
        s.reset_stats()
from itertools import product
class Husbands(DefaultModel):
    """River crossing with six people and a boat that carries one or two of them.

    State: indices 0-5 are people and index 6 is the boat. Each entry is -1 (start
    bank) or 1 (far bank). Couples are (0, 1), (2, 3), (4, 5), with the husband at
    the even index.
    NOTE(review): the couple layout is inferred from the ``valid`` constraint.
    """
    State = Tuple[int, int, int, int, int, int, int]

    def __init__(self, start=None):
        super().__init__(start if start else (-1, -1, -1, -1, -1, -1, -1))

    def is_goal(self, state: State) -> bool:
        # Technically we don't care where the boat ends up,
        # though it's impossible for it to end up on the other side
        return state[:-1] == (1, 1, 1, 1, 1, 1)

    def succ(self, state: State) -> Iterable[Tuple[Op, State]]:
        side = state[-1]

        def move(state: Husbands.State, i: int, j: int, side: int) -> Tuple[Op, Husbands.State]:
            """Carry person i (and j, if not None) plus the boat to bank `side`."""
            def _move(_state, x, _side):
                return _state[:x] + (_side,) + _state[x+1:]
            basemove = _move(_move(state, i, side), 6, side)
            if j is None:
                return (f"move {i} to {side}", basemove)
            return (f"move {i},{j} to {side}", _move(basemove, j, side))

        def valid(candidate: Tuple[Op, Husbands.State]) -> bool:
            """No wife may be with another husband unless her own husband is present."""
            new_state = candidate[1]

            def implies(a, b):
                return not a or b

            husbands = (0, 2, 4)
            # Husbands are at the even indices and each wife at husband + 1. The old
            # code used {1, 3, 5}, so for i == 5 it read index 6, the boat.
            return all(implies(new_state[h] != new_state[h + 1],
                               new_state[h + 1] != new_state[other])
                       for h in husbands
                       for other in husbands if other != h)

        moves = []
        for i in range(6):
            if state[i] != side:
                continue
            # j is the optional second passenger. Restricting it to j > i (and to the
            # people 0-5, never the boat at index 6) generates each crew exactly once,
            # in a deterministic order.
            for j in (None, *range(i + 1, 6)):
                if j is not None and state[j] != side:
                    continue
                m = move(state, i, j, -side)
                if valid(m):
                    moves.append(m)
        return moves

    def value(self, state: State):
        # TODO heuristic is probably bad because the problem solution is probably not monotonic
        # Counts only odd-indexed people (the wives) on the far bank.
        return 6 - len([x for i, x in enumerate(state) if x == 1 and i % 2 == 1])
if __name__ == "__main__":
    # Only the husbands run is timed by default. The other puzzles can be timed the
    # same way via Test().test_soln_fruits / _bishops / _queens / _hanoi / _numbers.
    elapsed = timeit.timeit(lambda: Test().test_soln_husbands(), number=1)
    print(elapsed)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment