Created
October 17, 2016 02:15
-
-
Save vzhong/f006894abc0d626c21394dfa943a4c42 to your computer and use it in GitHub Desktop.
Basic search algorithms
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
backtracking extended list informed | |
Depth First Search: y y n | |
Breadth First Search: n y n | |
Hill Climbing y y y | |
Beam Search: y y y | |
""" | |
import heapq
from collections import deque
class SearchAlgorithm:
    """Generic frontier-based search over states.

    The search loop lives in ``__call__``; concrete strategies customise it
    through hooks:

    * ``enqueue`` (required): where newly generated states go in the frontier.
    * ``dequeue``: which state is examined next (default: the front).
    * ``reorder_new_states``: reorder successors before they are enqueued.
    * ``prune_queue``: trim the frontier after every expansion.

    States must expose a hashable ``current`` attribute; the seen set (the
    "extended list") is keyed on it.
    """

    def __init__(self, get_actions, take_action, prune_seen_states=True):
        """
        Args:
            get_actions: callable ``state -> iterable of actions`` available
                in ``state``.
            take_action: callable ``(state, action) -> successor state``.
            prune_seen_states: when True, a generated state whose ``current``
                has already been seen is discarded instead of enqueued.
        """
        super().__init__()
        self.get_actions = get_actions
        self.take_action = take_action
        self.prune_seen_states = prune_seen_states

    def initialize_queue(self):
        """Return an empty frontier.

        A deque gives O(1) removal from the front (``list.pop(0)`` is O(n)),
        and it still supports the ``append``/``insert``/``clear`` calls that
        the built-in strategies make on the frontier.
        """
        return deque()

    def extend_state(self, state, seen):
        """Generate the successors of ``state``.

        Every kept successor's ``current`` is recorded in ``seen`` (mutated
        in place). When ``prune_seen_states`` is set, successors whose
        ``current`` is already in ``seen`` are dropped.

        Returns:
            list of successor states, in the order their actions were produced.
        """
        new_states = []
        for action in self.get_actions(state):
            new_state = self.take_action(state, action)
            if self.prune_seen_states and new_state.current in seen:
                continue
            seen.add(new_state.current)
            new_states.append(new_state)
        return new_states

    def enqueue(self, queue, new_states):
        """Add ``new_states`` to the frontier ``queue``; subclasses must implement."""
        raise NotImplementedError()

    def dequeue(self, queue):
        """Remove and return the state at the front of the frontier."""
        # A subclass may override initialize_queue() to return a plain list,
        # which has no popleft(); keep supporting that.
        if isinstance(queue, deque):
            return queue.popleft()
        return queue.pop(0)

    def reorder_new_states(self, new_states):
        """Hook to reorder successors before they are enqueued; identity by default."""
        return new_states

    def prune_queue(self, queue):
        """Hook to trim the frontier in place after each expansion; no-op by default."""

    def __call__(self, state, terminate, callback=None):
        """Search from ``state`` until a dequeued state satisfies ``terminate``.

        Args:
            state: the start state.
            terminate: predicate ``state -> bool`` identifying a goal.
            callback: optional ``state -> None``, invoked on every dequeued
                state before the goal test (e.g. to count steps).

        Returns:
            The first dequeued state for which ``terminate`` is truthy, or
            None if the frontier is exhausted first.
        """
        queue = self.initialize_queue()
        self.enqueue(queue, [state])
        # Count the start as seen; otherwise an action leading back to it
        # regenerates it, and the search expands it a second time.
        seen = {state.current}
        while queue:
            node = self.dequeue(queue)
            if callback is not None:
                callback(node)
            if terminate(node):
                return node
            successors = self.extend_state(node, seen)
            successors = self.reorder_new_states(successors)
            self.enqueue(queue, successors)
            self.prune_queue(queue)
        return None
class DepthFirstSearch(SearchAlgorithm):
    """Depth-first search: the frontier is used as a LIFO stack."""

    def enqueue(self, queue, new_states):
        """Push ``new_states`` onto the front of the frontier.

        States are inserted one at a time at index 0, so the last element of
        ``new_states`` ends up at the very front and is dequeued first.
        """
        for child in new_states:
            queue.insert(0, child)
class BreadthFirstSearch(SearchAlgorithm):
    """Breadth-first search: the frontier is used as a FIFO queue."""

    def enqueue(self, queue, new_states):
        """Append ``new_states`` to the back of the frontier, keeping their order."""
        queue.extend(new_states)
class HillClimbingSearch(DepthFirstSearch):
    """Depth-first search that expands the most promising successor first.

    Successors are sorted by ``heuristic`` in ascending order. Because
    DepthFirstSearch.enqueue inserts them one by one at the front of the
    frontier, the successor with the LARGEST heuristic value is dequeued
    first. ``heuristic`` should therefore return larger values for states
    closer to the goal (for example, a negated distance).
    """

    def __init__(self, get_actions, take_action, heuristic, prune_seen_states=True):
        super().__init__(get_actions, take_action, prune_seen_states=prune_seen_states)
        # callable ``state -> comparable score``; larger means more promising
        self.heuristic = heuristic

    def reorder_new_states(self, new_states):
        """Return a copy of ``new_states`` sorted by ascending heuristic value."""
        ranked = list(new_states)
        ranked.sort(key=self.heuristic)
        return ranked
class BeamSearch(BreadthFirstSearch):
    """Breadth-first search whose frontier is capped at ``beam_size`` states.

    After every expansion, the frontier is cut down to the ``beam_size``
    states with the largest ``heuristic`` values, best first. Larger
    heuristic values must therefore mean "closer to the goal".

    NOTE(review): pruning runs after every single expansion (not once per
    depth level, as in textbook beam search). Because the surviving states
    are re-ordered best-first, this behaves like a bounded best-first
    search. Confirm that this is intended.
    """

    def __init__(self, get_actions, take_action, heuristic, beam_size, prune_seen_states=True):
        """
        Args:
            get_actions, take_action, prune_seen_states: see SearchAlgorithm.
            heuristic: callable ``state -> comparable score``; larger is better.
            beam_size: maximum number of states kept in the frontier.

        Raises:
            ValueError: if ``beam_size`` is less than 1. A beam of 0 would
                empty the frontier after the first expansion.
        """
        if beam_size < 1:
            raise ValueError('beam_size must be at least 1, got {}'.format(beam_size))
        super().__init__(get_actions, take_action, prune_seen_states=prune_seen_states)
        self.heuristic = heuristic
        self.beam_size = beam_size

    def prune_queue(self, queue):
        """Keep only the ``beam_size`` best states in ``queue`` (in place), best first."""
        # nlargest evaluates the key once per state and, like
        # sorted(..., reverse=True)[:n], keeps tied states in their original order.
        best = heapq.nlargest(self.beam_size, queue, key=self.heuristic)
        queue.clear()
        queue.extend(best)
if __name__ == '__main__':
    import random
    import networkx as nx
    from collections import namedtuple, defaultdict

    # A search state: the node we are at, plus the nodes visited before it.
    State = namedtuple('State', ['current', 'history'])

    def get_dists_to_target(g, target, dists=None):
        """Return the shortest weighted distance from every node of ``g`` to ``target``.

        Args:
            g: networkx graph whose edges carry a 'weight' attribute (both
                callers here build an undirected nx.Graph).
            target: node to measure distances to.
            dists: optional mapping of existing distance estimates. It is
                updated in place with ``dists[target] + distance`` wherever
                that value is smaller. When omitted, a fresh mapping with
                ``dists[target] = 0`` is used.

        Returns:
            A mapping of node to distance. When the mapping is created here,
            it is a defaultdict, so nodes unreachable from ``target`` read as inf.
        """
        if dists is None:
            dists = defaultdict(lambda: float('inf'))
            dists[target] = 0
        offset = dists[target]
        # Dijkstra runs in polynomial time; the former recursive relaxation
        # could take exponential time on graphs with many alternative paths.
        lengths = nx.single_source_dijkstra_path_length(g, target, weight='weight')
        for node, length in lengths.items():
            dists[node] = min(dists[node], offset + length)
        return dists

    def toy_graph():
        """Build the small hand-checked example graph below.

                C-0-E
                |
                4
                |
                B       G
               /|      /
              5 4     5
             /  |    /
            S-3-A-3-D

        Returns:
            (graph, start 'S', goal 'G', heuristic), where the heuristic is
            the negated shortest distance to 'G', so larger means closer.
        """
        g = nx.Graph()
        g.add_edge('S', 'A', weight=3)
        g.add_edge('S', 'B', weight=5)
        g.add_edge('A', 'B', weight=4)
        g.add_edge('B', 'C', weight=4)
        g.add_edge('C', 'E', weight=0)
        g.add_edge('A', 'D', weight=3)
        g.add_edge('D', 'G', weight=5)

        # Hand-computed shortest distances to the goal G.
        dists = {'G': 0}
        dists['D'] = dists['G'] + 5
        dists['A'] = dists['D'] + 3
        dists['S'] = dists['A'] + 3
        dists['B'] = dists['A'] + 4
        dists['C'] = dists['B'] + 4
        dists['E'] = dists['C'] + 0
        # Cross-check the hand-computed values once, not on every heuristic call.
        other = get_dists_to_target(g, 'G')
        for k, v in dists.items():
            assert other[k] == v, 'differ for k: {}, expect {}, got {}'.format(k, v, other[k])

        def dist_to_target(state):
            return -dists[state.current]

        return g, 'S', 'G', dist_to_target

    def large_graph(num_nodes=50, num_edges=10):
        """Build a random connected graph on nodes ``0 .. num_nodes - 1``.

        A chain ``0-1-...-(num_nodes-1)`` with weight 5 guarantees that a path
        exists; ``num_edges`` random shortcut edges (weights 1-5) are added
        between pairs of distinct nodes.

        Returns:
            (graph, start 0, goal num_nodes - 1, heuristic), where the
            heuristic is the negated shortest distance to the goal.

        Raises:
            ValueError: if ``num_nodes`` < 2. Distinct endpoints cannot be
                drawn from fewer than two nodes.
        """
        if num_nodes < 2:
            raise ValueError('num_nodes must be at least 2, got {}'.format(num_nodes))
        g = nx.Graph()
        start, goal = 0, num_nodes - 1
        for _ in range(num_edges):
            u = v = random.randint(0, num_nodes - 1)
            while v == u:
                v = random.randint(0, num_nodes - 1)
            g.add_edge(u, v, weight=random.randint(1, 5))
        # ensure that there is at least 1 path
        for i in range(1, num_nodes):
            g.add_edge(i - 1, i, weight=5)
        # The heuristic must measure closeness to the goal, not to the start.
        dists = get_dists_to_target(g, goal)
        return g, start, goal, lambda state: -dists[state.current]

    def get_controllers(g, end):
        """Return the (get_actions, take_action, terminate) callbacks for finding ``end`` in ``g``."""
        def get_actions(state):
            # Only step to neighbours not already on this path, so paths never loop.
            return [n for n in g.neighbors(state.current) if n not in state.history]

        def take_action(state, action):
            return State(action, state.history + [state.current])

        def terminate(state):
            return state.current == end

        return get_actions, take_action, terminate

    def run_algorithms(get_graph):
        """Run every search algorithm on the graph from ``get_graph()`` and print the results."""
        g, start, end, heuristic = get_graph()
        get_actions, take_action, terminate = get_controllers(g, end)
        start_state = State(start, [])
        algs = [
            DepthFirstSearch(get_actions, take_action),
            BreadthFirstSearch(get_actions, take_action),
            HillClimbingSearch(get_actions, take_action, heuristic=heuristic),
            BeamSearch(get_actions, take_action, heuristic=heuristic, beam_size=3),
        ]
        print('heuristic ground truth: {}\n'.format(heuristic(State(end, []))))
        for alg in algs:
            steps = 0

            def callback(state):
                nonlocal steps
                steps += 1

            print('Using algorithm: {}'.format(type(alg).__name__))
            path = alg(start_state, terminate, callback=callback)
            print('done in {} steps'.format(steps))
            print(path)
            print()

    random.seed(0)
    print('toy problem')
    run_algorithms(toy_graph)
    print('large graph problem')
    run_algorithms(large_graph)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment