Skip to content

Instantly share code, notes, and snippets.

@RJ
Created March 19, 2025 10:29
Show Gist options
  • Save RJ/40f202e79e9b24efa4e47bba706326d2 to your computer and use it in GitHub Desktop.
Save RJ/40f202e79e9b24efa4e47bba706326d2 to your computer and use it in GitHub Desktop.
A test case from my bevy HTN crate
/// Plans a trip to the park from two different starting states and checks
/// which primitive tasks the HTN planner selects in each case.
#[test]
fn test_travel_htn() {
    // The App is thrown away immediately: it only exists so the log plugin is added
    // and the logger gets set up for this test.
    App::new().add_plugins(bevy::log::LogPlugin::default());

    // ---- OPERATORS (these are behaviour trees) ----
    // By default an operator is emitted as a trigger, i.e. Behave::trigger(WalkOperator).
    // Alternatively an operator can be a component that gets spawned, or HtnOperator can be
    // implemented by hand to supply a custom behaviour tree.
    #[derive(Reflect, Default, Clone, Debug, PartialEq, Eq, HtnOperator)]
    #[reflect(Default, HtnOperator)]
    struct WalkOperator;

    #[derive(Reflect, Default, Clone, Debug, PartialEq, Eq, HtnOperator)]
    #[reflect(Default, HtnOperator)]
    struct TaxiOperator;

    // Hand-written HtnOperator impl: this operator supplies its own behaviour tree.
    #[derive(Reflect, Default, Clone, Debug, PartialEq, Eq)]
    #[reflect(Default, HtnOperator)]
    struct RideTaxiOperator;

    impl HtnOperator for RideTaxiOperator {
        fn to_tree(&self) -> Option<Tree<Behave>> {
            Some(behave! { Behave::Wait(3.0) })
        }
    }

    // Component operator: would be spawned into an entity via Behave::spawn_named.
    #[derive(Reflect, Default, Clone, Debug, PartialEq, Eq, HtnOperator, Component)]
    #[reflect(Default, HtnOperator)]
    #[spawn_named = "Paying the taxi!"]
    struct PayTaxiOperator;

    // ---- PLANNER STATE ----
    #[derive(Reflect, Default, Clone, Debug, PartialEq, Eq)]
    #[reflect(Default)]
    enum Location {
        #[default]
        Home,
        Other,
        Park,
    }

    #[derive(Reflect, Resource, Clone, Debug, Default)]
    #[reflect(Default, Resource)]
    struct TravelState {
        cash: i32,
        distance_to_park: i32,
        happy: bool,
        my_location: Location,
        taxi_location: Location,
    }

    // ---- HTN DEFINITION (an asset loader can also read this from an .htn file) ----
    // With distance_to_park > 4 the planner first tries the Walk method of TravelToPark,
    // backtracks because Walk's precondition fails, and then succeeds with the Taxi method.
    // The raw string is deliberately left unindented so its contents stay exactly as written.
    let htn_source = r#"
schema {
version: 0.1.0
}
compound_task "TravelToPark" {
method {
subtasks: [ Walk ]
}
method {
subtasks: [ Taxi ]
}
}
primitive_task "Walk" {
operator: WalkOperator
preconditions: [distance_to_park <= 4, my_location != Location::Park, happy == false]
effects: [
my_location = Location::Park,
happy = true,
]
}
compound_task "Taxi" {
method {
subtasks: [CallTaxi, RideTaxi, PayTaxi]
}
}
primitive_task "CallTaxi" {
operator: TaxiOperator
preconditions: [taxi_location != Location::Park, cash >= 1]
effects: [taxi_location = Location::Home]
}
primitive_task "RideTaxi" {
operator: RideTaxiOperator
preconditions: [taxi_location == Location::Home, cash >= 1]
effects: [taxi_location = Location::Park, my_location = Location::Park, happy = true]
}
primitive_task "PayTaxi" {
operator: PayTaxiOperator
preconditions: [taxi_location == Location::Park, cash >= 1]
effects: [cash -= 1]
}
"#;

    // ---- TYPE REGISTRATION ----
    // Outside of a test this would go through app.register_type or Res<AppTypeRegistry>.
    let type_registry = AppTypeRegistry::default();
    {
        // Scoped so the write guard is released before the registry is used below.
        let mut registry = type_registry.write();
        registry.register::<TravelState>();
        registry.register::<Location>();
        registry.register::<WalkOperator>();
        registry.register::<TaxiOperator>();
        registry.register::<RideTaxiOperator>();
        registry.register::<PayTaxiOperator>();
    }

    let htn = parse_htn::<TravelState>(htn_source);

    // Reflection-based check that every type the HTN refers to has been registered.
    if let Err(e) = htn.verify_all(&TravelState::default(), &type_registry) {
        panic!("HTN type verification failed: {e:?}");
    }

    let mut planner = HtnPlanner::new(&htn, &type_registry);

    // Both scenarios start at home, unhappy, with 10 cash and the taxi somewhere else;
    // only the distance to the park differs between them.
    let start_from_home = |distance_to_park: i32| TravelState {
        cash: 10,
        distance_to_park,
        happy: false,
        my_location: Location::Home,
        taxi_location: Location::Other,
    };

    // Short distance: the Walk precondition holds, so the plan is a single Walk.
    {
        warn!("Testing walking state");
        let near_park = start_from_home(1);
        let walk_plan = planner.plan(&near_park);
        assert_eq!(walk_plan.task_names(), vec!["Walk"]);
    }

    // Long distance: Walk is rejected and the taxi sequence is planned instead.
    {
        warn!("Testing taxi state");
        let far_from_park = start_from_home(5);
        let taxi_plan = planner.plan(&far_from_park);
        assert_eq!(taxi_plan.task_names(), vec!["CallTaxi", "RideTaxi", "PayTaxi"]);
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment