Created
October 4, 2022 20:11
-
-
Save zephirefaith/bdb7ea164c23e6494582ea9a9d235f42 to your computer and use it in GitHub Desktop.
Python code to instantiate a simple 4-node HTN for an agent in a 2D grid. The agent is spawned at a random position in the world; the HTN encodes the decisions needed to get the agent to the top-right corner of the grid.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/bin/env python | |
from enum import Enum, unique

import networkx as nx
import numpy as np
@unique
class Actions(Enum):
    """Primitive moves the agent can execute in the 2D grid.

    Members are stored as the ``"action"`` attribute on task-graph edges
    and are what ``MRPEnv.step`` compares its argument against.
    ``@unique`` guards against two members accidentally sharing a value,
    which would turn one of them into a silent alias of the other.
    """

    MOVE_UP = 0  # pushes the agent along the vertical axis (x[1])
    MOVE_RIGHT = 1  # pushes the agent along the horizontal axis (x[0])
@unique
class States(Enum):
    """Quadrants of the 2D grid, used as the nodes of the task graph.

    The integer value of each member doubles as the node id in the graph,
    so values must stay distinct; ``@unique`` enforces that (an aliased
    member would be skipped when iterating over ``States``).
    """

    LEFT_BOTTOM = 0
    LEFT_TOP = 1
    RIGHT_BOTTOM = 2
    RIGHT_TOP = 3
class MRPEnv(object):
    """Toy 2D environment holding a single agent position.

    The agent lives in the square ``[0, size] x [0, size]`` and can only be
    nudged upwards or to the right by a random amount. The episode is done
    once the agent is close enough to the top-right corner on both axes.

    Attributes:
        size: Side length of the square world.
        goal_tolerance: Per-axis distance to ``size`` below which the
            agent counts as having reached the top-right corner.
        x: Agent position as a length-2 float array ``[x, y]``.
    """

    # Exclusive upper bounds of the random displacement applied per step.
    MAX_UP_STEP = 2.0
    MAX_RIGHT_STEP = 1.5

    def __init__(
        self, size: float = 10.0, goal_tolerance: float = 0.25
    ) -> None:
        """Create the world and spawn the agent at a random position.

        Args:
            size: Side length of the square world.
            goal_tolerance: Per-axis closeness to ``size`` required for
                ``is_done`` to report success.
        """
        self.size = size
        self.goal_tolerance = goal_tolerance
        self.x = None  # agent-state
        self.init_agent()

    def init_agent(self) -> None:
        """(Re)spawn the agent uniformly at random inside the world."""
        # Uses NumPy's global random state; samples lie in [0, size).
        self.x = self.size * np.random.random_sample(2)

    def step(self, action: "Actions") -> None:
        """Move the agent by a random amount in the requested direction.

        The displacement is drawn uniformly from ``[0, MAX_*_STEP)`` and the
        resulting coordinate is clipped to stay inside the world.

        Args:
            action: ``Actions.MOVE_UP`` or ``Actions.MOVE_RIGHT``. Any other
                value leaves the agent where it is.
        """
        if action == Actions.MOVE_UP:
            axis, max_step = 1, self.MAX_UP_STEP
        elif action == Actions.MOVE_RIGHT:
            axis, max_step = 0, self.MAX_RIGHT_STEP
        else:
            # Unknown actions are ignored rather than rejected.
            return
        moved = self.x[axis] + max_step * np.random.random_sample()
        self.x[axis] = np.clip(moved, 0, self.size)

    def is_done(self) -> bool:
        """Return True once both coordinates are within tolerance of ``size``."""
        return bool(np.all(self.size - self.x < self.goal_tolerance))
class TaskGraph(object):
    """Four-node task graph steering the agent towards the top-right quadrant.

    Each node is one quadrant of the grid (a ``States`` member, keyed by its
    integer value) and each directed edge carries the ``Actions`` member to
    execute from that quadrant.

    Attributes:
        GT: ``networkx.MultiDiGraph`` built from the node table ``_V`` and
            the edge table ``_E``.
    """

    # Coordinate that separates left/right and bottom/top quadrants.
    MIDPOINT = 5.0

    def __init__(self, rng=None) -> None:
        """Build the node and edge tables and the graph itself.

        Args:
            rng: Optional ``numpy.random.Generator`` used to pick among
                several admissible actions. A default generator is created
                once when omitted.
        """
        # One generator per graph instead of one per get_next_action call.
        self._rng = np.random.default_rng() if rng is None else rng

        self._V = [(state.value, {"state": state.name}) for state in States]

        left_bottom = States.LEFT_BOTTOM.value
        left_top = States.LEFT_TOP.value
        right_bottom = States.RIGHT_BOTTOM.value
        right_top = States.RIGHT_TOP.value
        self._E = [
            (left_bottom, left_top, {"action": Actions.MOVE_UP}),
            (left_bottom, right_bottom, {"action": Actions.MOVE_RIGHT}),
            (left_top, right_top, {"action": Actions.MOVE_RIGHT}),
            (right_bottom, right_top, {"action": Actions.MOVE_UP}),
            # Two parallel self-loops: inside the goal quadrant either move
            # is admissible, hence the multigraph.
            (right_top, right_top, {"action": Actions.MOVE_RIGHT}),
            (right_top, right_top, {"action": Actions.MOVE_UP}),
        ]

        self.GT = None
        self.init_graph()

    def init_graph(self) -> None:
        """(Re)create ``GT`` from the node and edge tables."""
        self.GT = nx.MultiDiGraph()
        self.GT.add_nodes_from(self._V)
        self.GT.add_edges_from(self._E)

    def get_current_node(self, x):
        """Return the ``States`` quadrant containing position ``x``.

        Args:
            x: Indexable position ``[x, y]``. A coordinate equal to the
                midpoint counts as left / bottom.
        """
        is_left = x[0] <= self.MIDPOINT
        is_bottom = x[1] <= self.MIDPOINT
        if is_left:
            return States.LEFT_BOTTOM if is_bottom else States.LEFT_TOP
        return States.RIGHT_BOTTOM if is_bottom else States.RIGHT_TOP

    def get_next_action(self, x):
        """Pick an action to execute from the quadrant containing ``x``.

        The choice is uniform over the outgoing edges of the current node.

        Args:
            x: Indexable position ``[x, y]``.

        Returns:
            The ``Actions`` member stored on the selected edge.

        Raises:
            ValueError: If the current node has no outgoing edge.
        """
        node_id = self.get_current_node(x).value
        actions = [
            data["action"]
            for _, _, data in self.GT.out_edges(node_id, data=True)
        ]
        if not actions:
            raise ValueError(f"no outgoing action from node {node_id}")
        # Draw an index rather than calling rng.choice on the edge tuples,
        # which would first coerce them into a NumPy object array.
        return actions[self._rng.integers(len(actions))]
def main() -> None:
    """Run one episode, printing the state and chosen action at every step.

    A fresh environment and task graph are created; the loop keeps asking
    the task graph for an action and applying it until the environment
    reports that the agent has reached the top-right corner.
    """
    # create env
    env = MRPEnv()
    task_graph = TaskGraph()
    while not env.is_done():
        curr_state = task_graph.get_current_node(env.x)
        print(f"Current state: {env.x}, {curr_state}")
        next_action = task_graph.get_next_action(env.x)
        env.step(next_action)
        # Printed after the step, so it follows the pre-step state line.
        print(f"Next action: {next_action}")


if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment