Skip to content

Instantly share code, notes, and snippets.

@mdamien
Created June 17, 2015 15:32
Show Gist options
  • Save mdamien/4cdb90b5ce9c4964a755 to your computer and use it in GitHub Desktop.
A* heuristic testing
import heapq
import itertools
import sys
from collections import Counter
from pprint import pprint as pp
class Node:
    """A search-tree node: a +/- state string plus a back-link to the node it came from.

    Nodes compare equal (and hash equally) when their state strings ``p`` are equal,
    so sets of nodes deduplicate states regardless of the path that produced them.
    """

    def __init__(self, p, parent=None, descr=""):
        self.p = p  # the state string
        self.parent = parent  # Node this one was derived from; None for the root
        self.descr = descr  # human-readable note on how this node was produced

    def print_tree(self, n=0):
        """Print the path from the root down to this node, root first.

        Each line shows the distance *n* back from this node, the state, its
        description, and the heuristic value ``h`` of the state.
        """
        if self.parent is not None:
            self.parent.print_tree(n + 1)
        print("->", n, self.p, "\t:", self.descr, "\th=", h(self.p))

    def __eq__(self, other):
        # Equality by state is what makes the explored/frontier sets work: with only
        # __hash__ defined, == falls back to identity and no revisited state ever matches.
        if not isinstance(other, Node):
            return NotImplemented
        return self.p == other.p

    def __hash__(self):
        # Must stay consistent with __eq__: hash on the state string only.
        return hash(self.p)

    def __str__(self):
        return self.p

    def __repr__(self):
        return "Node(%r)" % (self.p,)
# Rewrite rules, one "lhs:rhs" pair per line: an occurrence of lhs in a state may be
# rewritten to rhs. Parsed below into the ordered dict R = {lhs: rhs}; line order is
# kept because it decides the order in which successors are generated.
_RULES_SPEC = """
+-:-+
-+:+-
+-+:++
-+-:--
++:+++
--:---"""
R = dict(
    line.strip().split(':')
    for line in _RULES_SPEC.splitlines()
    if line.strip()
)
def opp(x):
    """Return the opposite sign of *x*: '-' for '+', and '+' for anything else."""
    if x == '+':
        return '-'
    return '+'
def encercle(p, i, x):
    """Return True if the opposite of sign *x* occurs both left and right of index *i* in *p*.

    The left side is ``p[:i]`` and the right side is ``p[i+1:]``, so the check is
    symmetric and excludes position *i* itself.
    """
    other = opp(x)
    # Left side must be p[:i]: slicing p[:i-1] skipped the immediate left neighbour
    # (always empty for i == 1), unlike the right side, which includes p[i+1].
    return other in p[:i] and other in p[i + 1:]
def h(p):
    """Heuristic score for state *p* (lower is explored first by the best-first search).

    Sum of three terms:
    * the count of the less frequent sign in *p*;
    * the smaller, per sign, of the number of characters that sit at either end
      of *p* or are not surrounded by the opposite sign (see ``encercle``);
    * the length of *p*.
    """
    totals = Counter(p)
    minority = min(totals['+'], totals['-'])
    last = len(p) - 1
    # End characters are counted without consulting encercle (short-circuit).
    exposed = Counter(
        ch
        for idx, ch in enumerate(p)
        if idx in (0, last) or not encercle(p, idx, ch)
    )
    return minority + min(exposed['+'], exposed['-']) + len(p)
def finished(p):
    """Return True when *p* does not contain both signs (it is uniform or empty)."""
    has_plus = '+' in p
    has_minus = '-' in p
    return not (has_plus and has_minus)
def apply_rules(node):
    """Yield every successor of *node* obtained by rewriting one rule occurrence.

    For each rule ``lhs -> rhs`` in ``R`` (in dict order) and each position where
    ``lhs`` occurs in ``node.p`` (overlapping matches included, left to right),
    yield a child Node in which only that single occurrence is replaced.
    """
    p = node.p
    for lhs, rhs in R.items():
        pos = p.find(lhs)
        while pos != -1:
            # Rewrite exactly this occurrence. str.replace without a count also rewrote
            # every later match, so "only the earlier one rewritten" successors were lost.
            child = p[:pos] + rhs + p[pos + len(lhs):]
            yield Node(child, parent=node, descr="Applied " + lhs + " -> " + rhs)
            # Resume one character later so overlapping matches are also found.
            pos = p.find(lhs, pos + 1)
def search_a_star(start):
    """Best-first search from *start*, always expanding the frontier state with lowest h.

    Each time a finished state is expanded, prints its path and the number of states
    explored so far. Stops after the fourth such state or when the frontier empties.
    The priority is ``h`` alone (no accumulated path cost), so this is greedy
    best-first rather than textbook A*.

    The frontier is a binary heap keyed by ``(h, state, insertion order)``. Ties are
    therefore broken reproducibly, rather than by set iteration order, which depends on
    string-hash randomisation. No full re-sort is needed per step.
    """
    print()
    print("A* SEARCH")
    print("----------")
    N = 0
    order = itertools.count()  # final tiebreaker; keeps Node objects out of comparisons
    root = Node(start, descr="start node")
    frontier = [(h(root.p), root.p, next(order), root)]
    explored = set()
    while frontier:
        *_, node = heapq.heappop(frontier)
        if node in explored:
            # Duplicate of a state already expanded (relies on Node comparing by state).
            continue
        explored.add(node)
        if finished(node.p):
            N += 1
            node.print_tree()
            print("A* Explored:", len(explored))
            print()
            if N > 3:
                return
        for child in apply_rules(node):
            if child not in explored:
                heapq.heappush(frontier, (h(child.p), child.p, next(order), child))
def search_basic(start):
    """Level-by-level (breadth-first) search from *start*.

    Each time a finished state is visited, prints its path and the number of states
    explored so far. Stops after the fourth such state or when no unexplored
    successors remain.
    """
    print('BASIC SEARCH')
    print("------------")
    solutions = 0
    frontier = {Node(start, descr="start node")}
    seen = set()
    while frontier:
        successors = set()
        for current in frontier:
            seen.add(current)
            if finished(current.p):
                solutions += 1
                current.print_tree()
                print("BASIC explored:", len(seen))
                print()
                if solutions > 3:
                    return
            successors.update(apply_rules(current))
        # Next level: everything generated from this level that was not yet explored.
        frontier = successors - seen
# Start state: first command-line argument, or '+++--' when none is given.
S = sys.argv[1] if len(sys.argv) >= 2 else '+++--'
print(S)
for _search in (search_basic, search_a_star):
    _search(S)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment