Last active
August 21, 2024 11:54
-
-
Save qxcv/e8641342c102c2aa714c9caeca724101 to your computer and use it in GitHub Desktop.
Gymnasium envs with Dreamer v3, along with a Minigrid example
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# from_gym.py adapted to work with Gymnasium. Differences:
#
# - gym.* -> gymnasium.*
# - Deals with .step() returning a tuple of (obs, reward, terminated, truncated,
#   info) rather than (obs, reward, done, info).
# - Also deals with .reset() returning a tuple of (obs, info) rather than just
#   obs.
# - Passes render_mode='rgb_array' to gymnasium.make() rather than .render().
# - A bunch of minor/irrelevant type checking changes that stopped pyright from
#   complaining (these have no functional purpose, I'm just a completionist who
#   doesn't like red squiggles).
import functools | |
from typing import Any, Generic, TypeVar, Union, cast, Dict | |
import embodied | |
import gymnasium | |
import numpy as np | |
# Type parameters forwarded to gymnasium.Env[U, V] for the wrapped environment.
U = TypeVar('U')
V = TypeVar('V')


class FromGymnasium(embodied.Env, Generic[U, V]):
    """Adapter exposing a Gymnasium environment through the embodied.Env API.

    Observations and actions are exchanged as flat dicts. Non-dict
    observation/action spaces are wrapped under `obs_key` / `act_key`;
    nested dict spaces are flattened into '/'-joined keys.
    """

    def __init__(
        self,
        env: Union[str, gymnasium.Env[U, V]],
        obs_key='image',
        act_key='action',
        **kwargs,
    ):
        """Wrap `env`, creating it with gymnasium.make() when given an ID.

        Args:
            env: Either an environment ID string, or an already constructed
                Gymnasium environment whose render_mode is "rgb_array".
            obs_key: Key used for the observation when the observation space
                is not a dict space.
            act_key: Key used for the action when the action space is not a
                dict space.
            **kwargs: Forwarded to gymnasium.make(); only allowed when `env`
                is a string.

        Raises:
            AssertionError: If kwargs are passed together with an environment
                instance, or the instance's render_mode is not "rgb_array".
        """
        if isinstance(env, str):
            self._env: gymnasium.Env[U, V] = gymnasium.make(
                env, render_mode="rgb_array", **kwargs)
        else:
            assert not kwargs, kwargs
            # The message must read from `env`: self._env is not assigned yet,
            # so referencing it here would raise AttributeError instead of the
            # intended AssertionError.
            assert env.render_mode == "rgb_array", (
                f"render_mode must be rgb_array, got {env.render_mode}")
            self._env = env
        # The presence of a `.spaces` attribute is used as a proxy for "this is
        # a dict space".
        self._obs_dict = hasattr(self._env.observation_space, 'spaces')
        self._act_dict = hasattr(self._env.action_space, 'spaces')
        self._obs_key = obs_key
        self._act_key = act_key
        # Starting in the "done" state forces a reset on the first step().
        self._done = True
        self._info: Union[Dict[str, Any], None] = None

    @property
    def info(self):
        """Info dict from the most recent env.step(), or None before any step."""
        return self._info

    @functools.cached_property
    def obs_space(self):
        """Dict of embodied.Space entries describing the observations.

        Contains the converted environment spaces plus the 'reward',
        'is_first', 'is_last' and 'is_terminal' entries added by _obs().
        """
        if self._obs_dict:
            # cast is here to stop type checkers from complaining (we already
            # check that .spaces attr exists in __init__ as a proxy for the
            # type check)
            obs_space = cast(gymnasium.spaces.Dict, self._env.observation_space)
            # Flatten so the keys agree with the flattened observations
            # produced by _obs(); a flat dict space is returned unchanged.
            spaces = self._flatten(obs_space.spaces)
        else:
            spaces = {self._obs_key: self._env.observation_space}
        spaces = {k: self._convert(v) for k, v in spaces.items()}
        return {
            **spaces,
            'reward': embodied.Space(np.float32),
            'is_first': embodied.Space(bool),
            'is_last': embodied.Space(bool),
            'is_terminal': embodied.Space(bool),
        }

    @functools.cached_property
    def act_space(self):
        """Dict of embodied.Space entries describing the actions.

        Contains the converted environment action spaces plus a boolean
        'reset' entry.
        """
        if self._act_dict:
            act_space = cast(gymnasium.spaces.Dict, self._env.action_space)
            # Flatten so the keys agree with what step() feeds to _unflatten().
            spaces = self._flatten(act_space.spaces)
        else:
            spaces = {self._act_key: self._env.action_space}
        spaces = {k: self._convert(v) for k, v in spaces.items()}
        spaces['reset'] = embodied.Space(bool)
        return spaces

    def step(self, action):
        """Advance the environment by one step, or reset it.

        Args:
            action: Dict with a 'reset' entry plus the action entries.

        Returns:
            The observation dict built by _obs(). When `action['reset']` is
            truthy or the previous step ended the episode, the environment is
            reset and the first observation is returned with zero reward and
            is_first=True.
        """
        if action['reset'] or self._done:
            self._done = False
            # we don't bother setting ._info here because it gets set below,
            # once we take the next .step()
            obs, _ = self._env.reset()
            return self._obs(obs, 0.0, is_first=True)
        if self._act_dict:
            gymnasium_action = cast(V, self._unflatten(action))
        else:
            gymnasium_action = cast(V, action[self._act_key])
        obs, reward, terminated, truncated, info = self._env.step(
            gymnasium_action)
        self._info = info
        self._done = bool(terminated or truncated)
        # Only a genuine termination counts as terminal; a truncated episode
        # (e.g. a time limit) ends the episode without being a terminal state.
        # The environment can still override this through info['is_terminal'].
        return self._obs(
            obs, reward,
            is_last=self._done,
            is_terminal=bool(info.get('is_terminal', terminated)))

    def _obs(
            self, obs, reward, is_first=False, is_last=False,
            is_terminal=False):
        """Build the flat observation dict, adding reward and episode flags."""
        if not self._obs_dict:
            obs = {self._obs_key: obs}
        obs = self._flatten(obs)
        np_obs: Dict[str, Any] = {k: np.asarray(v) for k, v in obs.items()}
        np_obs.update(
            reward=np.float32(reward),
            is_first=is_first,
            is_last=is_last,
            is_terminal=is_terminal)
        return np_obs

    def render(self):
        """Return the image produced by the wrapped environment's render()."""
        image = self._env.render()
        assert image is not None
        return image

    def close(self):
        """Close the wrapped environment, ignoring any error it raises."""
        try:
            self._env.close()
        except Exception:
            # Deliberately best-effort: errors while closing are suppressed.
            pass

    def _flatten(self, nest, prefix=None):
        """Flatten a nested mapping into a single dict with '/'-joined keys.

        Values that are gymnasium.spaces.Dict are descended into through
        their `.spaces` mapping, and plain dicts are descended into directly.
        """
        result = {}
        for key, value in nest.items():
            key = prefix + '/' + key if prefix else key
            if isinstance(value, gymnasium.spaces.Dict):
                value = value.spaces
            if isinstance(value, dict):
                result.update(self._flatten(value, key))
            else:
                result[key] = value
        return result

    def _unflatten(self, flat):
        """Rebuild a nested dict from a flat dict with '/'-joined keys."""
        result = {}
        for key, value in flat.items():
            parts = key.split('/')
            node = result
            for part in parts[:-1]:
                if part not in node:
                    node[part] = {}
                node = node[part]
            node[parts[-1]] = value
        return result

    def _convert(self, space):
        """Convert a Gymnasium space into an embodied.Space.

        A space exposing an `n` attribute becomes a scalar int32 space with
        bounds 0 and `space.n`; any other space is converted from its dtype,
        shape, low and high attributes.
        """
        if hasattr(space, 'n'):
            return embodied.Space(np.int32, (), 0, space.n)
        return embodied.Space(space.dtype, space.shape, space.low, space.high)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from typing import cast | |
import gymnasium | |
from minigrid.wrappers import FullyObsWrapper, ObservationWrapper | |
from dreamerv3.embodied.envs.from_gymnasium import FromGymnasium | |
class HideMission(ObservationWrapper):
    """Remove the 'mission' string from the observation."""

    def __init__(self, env):
        """Wrap `env` and expose an observation space without 'mission'."""
        super().__init__(env)
        obs_space = cast(gymnasium.spaces.Dict, self.observation_space)
        # Install a filtered copy on this wrapper instead of popping from the
        # existing mapping: that mapping may be shared with the wrapped env,
        # whose own observation space must keep describing what it emits.
        self.observation_space = gymnasium.spaces.Dict(
            {key: space for key, space in obs_space.spaces.items()
             if key != 'mission'})

    def observation(self, observation: dict):
        """Return a copy of `observation` without the 'mission' entry.

        The input dict is left untouched so that anything else holding a
        reference to it still sees the full observation.
        """
        return {key: value for key, value in observation.items()
                if key != 'mission'}
class Minigrid(FromGymnasium):
    """FromGymnasium environment built from a MiniGrid task name."""

    def __init__(self, task: str, fully_observable: bool, hide_mission: bool):
        """Create "MiniGrid-{task}-v0" and apply the requested wrappers.

        Args:
            task: Task name inserted into the MiniGrid environment ID.
            fully_observable: If true, wrap the env in FullyObsWrapper.
            hide_mission: If true, wrap the env in HideMission.
        """
        # Collect the wrappers in application order (innermost first).
        selected_wrappers = []
        if fully_observable:
            selected_wrappers.append(FullyObsWrapper)
        if hide_mission:
            selected_wrappers.append(HideMission)

        wrapped = gymnasium.make(f"MiniGrid-{task}-v0", render_mode="rgb_array")
        for wrapper_cls in selected_wrappers:
            wrapped = wrapper_cls(wrapped)
        super().__init__(env=wrapped)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment