Created
June 22, 2017 11:42
-
-
Save planaria/6ce92cd78c0974716c3db722fcd6a9e1 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#pragma once
#include <algorithm>
#include <functional>
#include <set>
#include <type_traits>
#include <unordered_set>
#include <utility>
#include <vector>
#include <boost/optional.hpp>
namespace astarpp | |
{ | |
template < | |
class State, | |
class Action, | |
class Enumerator, | |
class Reducer, | |
class Heuristic, | |
class Goal, | |
class Hash = std::hash<State>, | |
class Equal = std::equal_to<State> | |
> | |
class astar | |
{ | |
public: | |
typedef State state_type; | |
typedef Action action_type; | |
typedef Enumerator enumerator_type; | |
typedef Reducer reducer_type; | |
typedef Heuristic heuristic_type; | |
typedef Goal goal_type; | |
typedef Hash hash_type; | |
typedef Equal equal_type; | |
typedef typename std::result_of<heuristic_type(const state_type&)>::type cost_type; | |
typedef boost::optional<std::vector<std::pair<action_type, state_type>>> result_type; | |
astar( | |
enumerator_type enumerator, | |
reducer_type reducer, | |
heuristic_type heuristic, | |
goal_type goal, | |
hash_type hash = hash_type(), | |
equal_type equal = equal_type()) | |
: enumerator_(enumerator_) | |
, reducer_(reducer) | |
, heuristic_(heuristic) | |
, goal_(goal) | |
, nodes_(0, node_hash(hash), node_equal(equal)) | |
{ | |
} | |
void add(const state_type& initial_state) | |
{ | |
node initial_node; | |
initial_node.state = initial_state; | |
initial_node.cost_h = heuristic_(initial_state); | |
queue_.insert(&*nodes_.insert(initial_node).first); | |
} | |
bool step() | |
{ | |
if (queue_.empty()) | |
return false; | |
if (result_) | |
return false; | |
const auto& n = **queue_.begin(); | |
queue_.erase(queue_.begin()); | |
if (goal_(n.state)) | |
{ | |
std::vector<std::pair<action_type, state_type>> actions; | |
auto p = &n; | |
while (p->prev) | |
{ | |
actions.push_back(std::make_pair(p->action, p->state)); | |
p = p->prev; | |
} | |
std::reverse(actions.begin(), actions.end()); | |
result_ = std::move(actions); | |
return false; | |
} | |
enumerator_(n.state, [&](const action_type& action, cost_type cost) | |
{ | |
node new_node; | |
new_node.state = reducer_(n.state, action); | |
new_node.prev = &n; | |
new_node.action = action; | |
new_node.cost = n.cost + cost; | |
new_node.cost_h = new_node.cost + heuristic_(new_node.state); | |
auto it = nodes_.find(new_node); | |
if (it == nodes_.end()) | |
{ | |
queue_.insert(&*nodes_.insert(new_node).first); | |
} | |
else | |
{ | |
if (new_node.cost_h < it->cost_h) | |
{ | |
queue_.erase(&*it); | |
it->prev = new_node.prev; | |
it->action = new_node.action; | |
it->cost = new_node.cost; | |
it->cost_h = new_node.cost_h; | |
queue_.insert(&*it); | |
} | |
} | |
}); | |
return true; | |
} | |
const result_type& result() const | |
{ | |
return result_; | |
} | |
result_type detach_result() | |
{ | |
result_type result = std::move(result_); | |
result_ = boost::none; | |
return result; | |
} | |
private: | |
struct node | |
{ | |
state_type state; | |
mutable const node* prev = nullptr; | |
mutable action_type action = action_type(); | |
mutable cost_type cost = cost_type(); | |
mutable cost_type cost_h = cost_type(); | |
}; | |
class node_hash | |
{ | |
public: | |
explicit node_hash(hash_type hash_) | |
: hash_(hash_) | |
{ | |
} | |
std::size_t operator ()(const node& n) const | |
{ | |
return hash_(n.state); | |
} | |
private: | |
hash_type hash_; | |
}; | |
class node_equal | |
{ | |
public: | |
explicit node_equal(equal_type equal_) | |
: equal_(equal_) | |
{ | |
} | |
bool operator ()(const node& lhs, const node& rhs) const | |
{ | |
return equal_(lhs.state, rhs.state); | |
} | |
private: | |
equal_type equal_; | |
}; | |
struct node_compare | |
{ | |
bool operator ()(const node* lhs, const node* rhs) const | |
{ | |
if (lhs->cost_h < rhs->cost_h) | |
return true; | |
if (rhs->cost_h < lhs->cost_h) | |
return false; | |
return lhs < rhs; | |
} | |
}; | |
enumerator_type enumerator_; | |
reducer_type reducer_; | |
heuristic_type heuristic_; | |
goal_type goal_; | |
std::unordered_set<node, node_hash, node_equal> nodes_; | |
std::multiset<const node*, node_compare> queue_; | |
result_type result_; | |
}; | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment