Skip to content

Instantly share code, notes, and snippets.

@planaria
Created June 22, 2017 11:42
Show Gist options
  • Save planaria/6ce92cd78c0974716c3db722fcd6a9e1 to your computer and use it in GitHub Desktop.
Save planaria/6ce92cd78c0974716c3db722fcd6a9e1 to your computer and use it in GitHub Desktop.
#pragma once
#include <algorithm>
#include <cstddef>
#include <functional>
#include <set>
#include <type_traits>
#include <unordered_set>
#include <utility>
#include <vector>
#include <boost/optional.hpp>
namespace astarpp
{
/// Incremental A* search over a user-supplied state space.
///
/// The search is driven one expansion at a time through step(), so the
/// caller decides how much work is done per call.
///
/// @tparam State      Search state. Must be default-constructible and
///                    copy-assignable (nodes are default-constructed and
///                    then filled in).
/// @tparam Action     Edge label stored along the found path. Must be
///                    default-constructible.
/// @tparam Enumerator Callable invoked as `enumerator(state, callback)`; it
///                    must call `callback(action, cost)` once for every
///                    action it wants explored from `state`.
/// @tparam Reducer    Callable invoked as `reducer(state, action)`; returns
///                    the successor state.
/// @tparam Heuristic  Callable invoked as `heuristic(state)`; its (decayed)
///                    return type is used as the cost type.
/// @tparam Goal       Callable invoked as `goal(state)`; a true result ends
///                    the search.
/// @tparam Hash       Hash functor over State.
/// @tparam Equal      Equality functor over State.
template <
    class State,
    class Action,
    class Enumerator,
    class Reducer,
    class Heuristic,
    class Goal,
    class Hash = std::hash<State>,
    class Equal = std::equal_to<State>
>
class astar
{
public:
    using state_type = State;
    using action_type = Action;
    using enumerator_type = Enumerator;
    using reducer_type = Reducer;
    using heuristic_type = Heuristic;
    using goal_type = Goal;
    using hash_type = Hash;
    using equal_type = Equal;

    /// Value type returned by the heuristic; also used for accumulated cost.
    /// Spelled with decltype instead of std::result_of (deprecated in C++17,
    /// removed in C++20) and decayed so it is always a storable value type.
    using cost_type = typename std::decay<
        decltype(std::declval<heuristic_type&>()(std::declval<const state_type&>()))
    >::type;

    /// Empty until a goal is reached; then the (action, resulting state)
    /// pairs leading from a start state to the goal, in order.
    using result_type = boost::optional<std::vector<std::pair<action_type, state_type>>>;

    /// Stores the callables that define the search problem.
    astar(
        enumerator_type enumerator,
        reducer_type reducer,
        heuristic_type heuristic,
        goal_type goal,
        hash_type hash = hash_type(),
        equal_type equal = equal_type())
        : enumerator_(std::move(enumerator)) // was `enumerator_(enumerator_)`: self-initialisation
        , reducer_(std::move(reducer))
        , heuristic_(std::move(heuristic))
        , goal_(std::move(goal))
        , nodes_(0, node_hash(std::move(hash)), node_equal(std::move(equal)))
    {
    }

    /// Registers a start state with zero accumulated cost.
    ///
    /// Adding a state that is already known never queues it twice; it is
    /// only re-queued (as a root, with no predecessor) when starting from it
    /// is cheaper than the best route found so far.
    void add(const state_type& initial_state)
    {
        node initial_node;
        initial_node.state = initial_state;
        initial_node.cost_h = heuristic_(initial_state);
        relax(initial_node);
    }

    /// Expands the queued node with the lowest estimated total cost.
    ///
    /// @return true if a node was expanded and the search can continue;
    ///         false if the queue is empty, a result is already stored, or
    ///         the popped node satisfied the goal (the result is then set).
    bool step()
    {
        if (queue_.empty() || result_)
            return false;

        const node& current = **queue_.begin();
        queue_.erase(queue_.begin());

        if (goal_(current.state))
        {
            // Walk the predecessor links back to a start node. Start nodes
            // have no predecessor and are not part of the reported path.
            std::vector<std::pair<action_type, state_type>> path;
            for (const node* p = &current; p->prev; p = p->prev)
                path.emplace_back(p->action, p->state);
            std::reverse(path.begin(), path.end());
            result_ = std::move(path);
            return false;
        }

        enumerator_(current.state, [&](const action_type& action, cost_type cost)
        {
            node successor;
            successor.state = reducer_(current.state, action);
            successor.prev = &current;
            successor.action = action;
            successor.cost = current.cost + cost;
            successor.cost_h = successor.cost + heuristic_(successor.state);
            relax(successor);
        });
        return true;
    }

    /// Returns the stored result; empty while no goal has been reached.
    const result_type& result() const
    {
        return result_;
    }

    /// Moves the stored result out and leaves this object without a result,
    /// so step() may be called again.
    result_type detach_result()
    {
        result_type result = std::move(result_);
        result_ = boost::none;
        return result;
    }

private:
    /// One discovered state plus the best known way to reach it.
    ///
    /// Nodes live in an unordered_set keyed on `state` only; elements of a
    /// set are const, so everything that is updated in place is `mutable`.
    struct node
    {
        state_type state;
        mutable const node* prev = nullptr;          // predecessor on the best known path
        mutable action_type action = action_type();  // action taken from prev to get here
        mutable cost_type cost = cost_type();        // accumulated cost from a start state
        mutable cost_type cost_h = cost_type();      // cost + heuristic(state)
    };

    /// Hashes a node by its state only.
    class node_hash
    {
    public:
        explicit node_hash(hash_type hash)
            : hash_(std::move(hash))
        {
        }

        std::size_t operator ()(const node& n) const
        {
            return hash_(n.state);
        }

    private:
        hash_type hash_;
    };

    /// Compares nodes by their state only.
    class node_equal
    {
    public:
        explicit node_equal(equal_type equal)
            : equal_(std::move(equal))
        {
        }

        bool operator ()(const node& lhs, const node& rhs) const
        {
            return equal_(lhs.state, rhs.state);
        }

    private:
        equal_type equal_;
    };

    /// Orders queued nodes by estimated total cost, ties broken by address.
    struct node_compare
    {
        bool operator ()(const node* lhs, const node* rhs) const
        {
            if (lhs->cost_h < rhs->cost_h)
                return true;
            if (rhs->cost_h < lhs->cost_h)
                return false;
            // std::less yields a total order over pointers; built-in `<` on
            // pointers to unrelated objects has an unspecified result.
            return std::less<const node*>()(lhs, rhs);
        }
    };

    /// Records `candidate` as a way to reach its state.
    ///
    /// A previously unseen state is stored and queued. A known state is
    /// updated and (re-)queued only when the candidate is strictly cheaper.
    void relax(const node& candidate)
    {
        const auto inserted = nodes_.insert(candidate);
        const node& stored = *inserted.first;
        if (inserted.second)
        {
            queue_.insert(&stored);
            return;
        }
        if (!(candidate.cost_h < stored.cost_h))
            return;

        // The queue is ordered by cost_h, so the entry has to be removed
        // while the old key is still in place and re-inserted afterwards.
        // The erase is a no-op when the node has already been expanded.
        queue_.erase(&stored);
        stored.prev = candidate.prev;
        stored.action = candidate.action;
        stored.cost = candidate.cost;
        stored.cost_h = candidate.cost_h;
        queue_.insert(&stored);
    }

    enumerator_type enumerator_;
    reducer_type reducer_;
    heuristic_type heuristic_;
    goal_type goal_;
    // Owns every discovered node. Pointers to unordered_set elements remain
    // valid across rehashing, which queue_ and node::prev depend on.
    std::unordered_set<node, node_hash, node_equal> nodes_;
    // Nodes waiting to be expanded, cheapest estimated total cost first.
    std::multiset<const node*, node_compare> queue_;
    result_type result_;
};
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment