RLGlue for Python 3
class RLGlue:
    """RLGlue class

    Args:
        env_class (class): the Environment class to instantiate
        agent_class (class): the Agent class to instantiate
    """

    def __init__(self, env_class, agent_class):
        self.environment = env_class()
        self.agent = agent_class()

        self.total_reward = None
        self.last_action = None
        self.num_steps = None
        self.num_episodes = None

    def rl_init(self, agent_init_info={}, env_init_info={}):
        """Initial method called when RLGlue experiment is created"""
        self.environment.env_init(env_init_info)
        self.agent.agent_init(agent_init_info)
        self.total_reward = 0.0
        self.num_steps = 0
        self.num_episodes = 0

    def rl_start(self, agent_start_info={}, env_start_info={}):
        """Starts RLGlue experiment

        Returns:
            tuple: (state, action)
        """
        last_state = self.environment.env_start()
        self.last_action = self.agent.agent_start(last_state)
        observation = (last_state, self.last_action)
        return observation

    def rl_agent_start(self, observation):
        """Starts the agent.

        Args:
            observation: The first observation from the environment

        Returns:
            The action taken by the agent.
        """
        return self.agent.agent_start(observation)

    def rl_agent_step(self, reward, observation):
        """Step taken by the agent

        Args:
            reward (float): the last reward the agent received for taking the
                last action.
            observation: the state observation the agent receives from the
                environment.

        Returns:
            The action taken by the agent.
        """
        return self.agent.agent_step(reward, observation)

    def rl_agent_end(self, reward):
        """Run when the agent terminates

        Args:
            reward (float): the reward the agent received when terminating
        """
        self.agent.agent_end(reward)

    def rl_env_start(self):
        """Starts RL-Glue environment.

        Returns:
            state: the first state observation from the environment
        """
        self.total_reward = 0.0
        self.num_steps = 1
        this_observation = self.environment.env_start()
        return this_observation

    def rl_env_step(self, action):
        """Step taken by the environment based on action from agent

        Args:
            action: Action taken by agent.

        Returns:
            (float, state, Boolean): reward, state observation, boolean
                indicating termination.
        """
        ro = self.environment.env_step(action)
        (this_reward, _, terminal) = ro

        self.total_reward += this_reward

        if terminal:
            self.num_episodes += 1
        else:
            self.num_steps += 1

        return ro

    def rl_step(self):
        """Step taken by RLGlue, takes environment step and either step or
            end by agent.

        Returns:
            (float, state, action, Boolean): reward, last state observation,
                last action, boolean indicating termination
        """
        (reward, last_state, term) = self.environment.env_step(self.last_action)

        self.total_reward += reward

        if term:
            self.num_episodes += 1
            self.agent.agent_end(reward)
            roat = (reward, last_state, None, term)
        else:
            self.num_steps += 1
            self.last_action = self.agent.agent_step(reward, last_state)
            roat = (reward, last_state, self.last_action, term)

        return roat

    def rl_cleanup(self):
        """Cleanup done at end of experiment."""
        self.environment.env_cleanup()
        self.agent.agent_cleanup()

    def rl_agent_message(self, message):
        """Message passed to communicate with agent during experiment

        Args:
            message: the message (or question) to send to the agent

        Returns:
            The message back (or answer) from the agent
        """
        return self.agent.agent_message(message)

    def rl_env_message(self, message):
        """Message passed to communicate with environment during experiment

        Args:
            message: the message (or question) to send to the environment

        Returns:
            The message back (or answer) from the environment
        """
        return self.environment.env_message(message)

    def rl_episode(self, max_steps_this_episode):
        """Runs an RLGlue episode

        Args:
            max_steps_this_episode (Int): the maximum number of steps to run in
                an episode (0 means no step limit; the check is made against the
                cumulative step count self.num_steps)

        Returns:
            Boolean: True if the episode reached a terminal state, False if it
                was cut off by the step limit
        """
        is_terminal = False

        self.rl_start()

        while (not is_terminal) and ((max_steps_this_episode == 0) or
                                     (self.num_steps < max_steps_this_episode)):
            rl_step_result = self.rl_step()
            is_terminal = rl_step_result[3]

        return is_terminal

    def rl_return(self):
        """The total reward

        Returns:
            float: the total reward
        """
        return self.total_reward

    def rl_num_steps(self):
        """The total number of steps taken

        Returns:
            Int: the total number of steps taken
        """
        return self.num_steps

    def rl_num_episodes(self):
        """The number of episodes

        Returns:
            Int: the total number of episodes
        """
        return self.num_episodes
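
Below is a minimal usage sketch showing how this class is typically driven. The RandomWalkEnvironment and RandomAgent classes are hypothetical stand-ins, not part of this gist; any environment and agent exposing the env_*/agent_* methods that RLGlue calls above would work the same way.

# Minimal usage sketch. Assumption: RandomWalkEnvironment and RandomAgent are
# hypothetical example classes written here for illustration only.
import random


class RandomWalkEnvironment:
    """Five-state random walk: states 1-5, terminal on reaching 0 or 6."""

    def env_init(self, env_info={}):
        self.state = None

    def env_start(self):
        self.state = 3                      # start in the middle state
        return self.state

    def env_step(self, action):
        self.state += 1 if action == 1 else -1
        terminal = self.state in (0, 6)
        reward = 1.0 if self.state == 6 else 0.0
        return (reward, self.state, terminal)   # the tuple shape rl_step unpacks

    def env_cleanup(self):
        pass

    def env_message(self, message):
        return ""


class RandomAgent:
    """Picks action 0 (left) or 1 (right) uniformly at random."""

    def agent_init(self, agent_info={}):
        self.rng = random.Random(agent_info.get("seed", 0))

    def agent_start(self, observation):
        return self.rng.choice([0, 1])

    def agent_step(self, reward, observation):
        return self.rng.choice([0, 1])

    def agent_end(self, reward):
        pass

    def agent_cleanup(self):
        pass

    def agent_message(self, message):
        return ""


rl_glue = RLGlue(RandomWalkEnvironment, RandomAgent)
rl_glue.rl_init(agent_init_info={"seed": 42})
for _ in range(10):
    rl_glue.rl_episode(max_steps_this_episode=0)   # 0 = no step limit
rl_glue.rl_cleanup()

# total_reward and num_steps accumulate across episodes (rl_start does not reset them)
print(rl_glue.rl_num_episodes(), rl_glue.rl_num_steps(), rl_glue.rl_return())

RLGlue itself only mediates the agent-environment interface, so swapping in a different environment or learning agent requires no changes to the experiment loop above.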