Created
June 17, 2015 15:32
-
-
Save mdamien/4cdb90b5ce9c4964a755 to your computer and use it in GitHub Desktop.
A* heuristic testing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from pprint import pprint as pp | |
from collections import Counter | |
import sys | |
class Node:
    """A state reached during the search, linked to the node it came from.

    Attributes:
        p: the state string this node represents.
        parent: the Node this one was derived from, or None for the root.
        descr: free-text note describing how this node was produced.

    Nodes compare equal and hash by ``p`` alone, so a set of nodes holds at
    most one node per state, whatever path led to it.
    """

    def __init__(self, p, parent=None, descr=""):
        self.p = p
        self.parent = parent
        self.descr = descr

    def print_tree(self, n=0):
        """Print the path from the root down to this node, one line each.

        ``n`` is how many steps this node lies above the node on which the
        outermost call was made (0 for that node itself).
        """
        # Ancestors are printed first so the output reads root -> leaf.
        if self.parent is not None:
            self.parent.print_tree(n + 1)
        print("->", n, self.p, "\t:", self.descr, "\th=", h(self.p))

    def __eq__(self, other):
        # Must agree with __hash__: without this, two nodes for the same
        # state are distinct set members and set-based dedup silently fails.
        if not isinstance(other, Node):
            return NotImplemented
        return self.p == other.p

    def __hash__(self):
        return hash(self.p)

    def __str__(self):
        return self.p
# Rewrite rules, written one per line as "pattern:replacement".
R = """
+-:-+
-+:+-
+-+:++
-+-:--
++:+++
--:---"""
# Parse the text into {pattern: replacement}, skipping blank lines.
R = dict(
    entry.strip().split(':')
    for entry in R.split('\n')
    if entry.strip()
)
def opp(x):
    """Return the opposite sign: '-' for '+', and '+' for any other value."""
    if x == '+':
        return '-'
    return '+'
def encercle(p, i, x):
    """Tell whether position ``i`` of ``p`` is enclosed by the opposite sign.

    Args:
        p: the state string.
        i: index of the symbol being examined.
        x: the symbol considered to sit at index ``i``.

    Returns:
        True when the opposite of ``x`` occurs somewhere strictly to the left
        of ``i`` and somewhere strictly to the right of ``i``.
    """
    other = opp(x)
    # p[:i] is everything left of i, including the immediate neighbour.
    # (p[:i-1] would skip that neighbour and wrap around to p[:-1] at i == 0.)
    return other in p[:i] and other in p[i + 1:]
def h(p):
    """Heuristic score of state ``p``.

    The score is the sum of three terms:
      * the count of the rarer sign in ``p``;
      * the count of the rarer sign among "exposed" positions, i.e. the two
        end positions plus every inner position for which ``encercle`` is
        false;
      * the length of ``p``.
    """
    totals = Counter(p)
    last = len(p) - 1
    # End positions always count as exposed; encercle is only consulted for
    # inner positions.
    exposed = Counter(
        sym
        for idx, sym in enumerate(p)
        if idx in (0, last) or not encercle(p, idx, sym)
    )
    rarer_overall = min(totals['+'], totals['-'])
    rarer_exposed = min(exposed['+'], exposed['-'])
    return rarer_overall + rarer_exposed + len(p)
def finished(p):
    """Return True when ``p`` does not contain both '+' and '-'."""
    has_both_signs = '+' in p and '-' in p
    return not has_both_signs
def apply_rules(node):
    """Yield one child Node for every single application of a rule in R.

    For each rule pattern and each position where that pattern occurs in
    ``node.p`` (overlapping occurrences included), a child is produced in
    which exactly that occurrence is replaced by the rule's right-hand side.

    Args:
        node: the Node whose state ``node.p`` is to be rewritten.

    Yields:
        Node objects with ``parent=node`` and a ``descr`` naming the rule.
    """
    p = node.p
    for pattern, replacement in R.items():
        pos = p.find(pattern)
        while pos != -1:
            # Rewrite only the occurrence at ``pos``; a bare str.replace on
            # the suffix would also rewrite every later occurrence.
            rewritten = p[:pos] + replacement + p[pos + len(pattern):]
            yield Node(rewritten, parent=node,
                       descr="Applied " + pattern + " -> " + replacement)
            # Advance by one, not len(pattern), so overlapping matches count.
            pos = p.find(pattern, pos + 1)
def search_a_star(start):
    """Heuristic-guided search from ``start``, printing what it finds.

    On every iteration the frontier is ordered by ``h(node.p)`` and the
    lowest-scoring node is expanded. Each time the expanded node satisfies
    ``finished``, its path and the number of explored nodes are printed. The
    function returns after the fourth such node, or when the frontier is
    empty.

    Args:
        start: the initial state string.
    """
    def score(candidate):
        return h(candidate.p)

    print()
    print("A* SEARCH")
    print("----------")
    solutions = 0
    frontier = [Node(start, descr="start node")]
    visited = set()
    while frontier:
        ranked = sorted(frontier, key=score)
        current = ranked[0]
        visited.add(current)
        if finished(current.p):
            solutions += 1
            current.print_tree()
            print("A* Explored:", len(visited))
            print()
            if solutions > 3:
                return
        # Finished nodes are expanded too; only the solution count stops
        # the search.
        successors = set(apply_rules(current))
        frontier = (set(ranked) | successors) - visited
def search_basic(start):
    """Layer-by-layer exhaustive search from ``start``, printing what it finds.

    Every node of the current layer is expanded with ``apply_rules``; the
    next layer is the union of all children minus nodes already explored.
    Each time a node satisfying ``finished`` is visited, its path and the
    number of explored nodes are printed. The function returns after the
    fourth such node, or when a layer comes up empty.

    Args:
        start: the initial state string.
    """
    print('BASIC SEARCH')
    print("------------")
    solutions = 0
    layer = {Node(start, descr="start node")}
    visited = set()
    while layer:
        upcoming = set()
        for current in layer:
            visited.add(current)
            if finished(current.p):
                solutions += 1
                current.print_tree()
                print("BASIC explored:", len(visited))
                print()
                if solutions > 3:
                    return
            upcoming |= set(apply_rules(current))
        layer = upcoming - visited
# Start state: the first command-line argument when given, else a default.
S = sys.argv[1] if len(sys.argv) >= 2 else '+++--'
print(S)
# Run both strategies on the same start state so their output can be compared.
search_basic(S)
search_a_star(S)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment