Skip to content

Instantly share code, notes, and snippets.

@qxcv
Last active August 21, 2024 11:54
Show Gist options
  • Save qxcv/e8641342c102c2aa714c9caeca724101 to your computer and use it in GitHub Desktop.
Save qxcv/e8641342c102c2aa714c9caeca724101 to your computer and use it in GitHub Desktop.
Gymnasium envs with Dreamer v3, along with a Minigrid example
# from_gym.py adapted to work with Gymnasium. Differences:
#
# - gym.* -> gymnasium.*
# - Deals with .step() returning a tuple of (obs, reward, terminated, truncated,
# info) rather than (obs, reward, done, info).
# - Also deals with .reset() returning a tuple of (obs, info) rather than just
# obs.
# - Passes render_mode='rgb_array' to gymnasium.make() rather than .render().
# - A bunch of minor/irrelevant type checking changes that stopped pyright from
# complaining (these have no functional purpose, I'm just a completionist who
# doesn't like red squiggles).
import functools
from typing import Any, Generic, TypeVar, Union, cast, Dict
import embodied
import gymnasium
import numpy as np
# Type parameters forwarded to gymnasium.Env[U, V] by FromGymnasium below.
# V is the action type: actions are cast to V before being handed to
# env.step(). U is presumably the observation type -- confirm against the
# generic parameter order of gymnasium.Env.
U = TypeVar('U')
V = TypeVar('V')
class FromGymnasium(embodied.Env, Generic[U, V]):
  """Adapter that drives a Gymnasium environment through the embodied.Env API.

  Observations and actions are exchanged as flat dicts. Nested dicts are
  flattened into '/'-joined keys, and non-dict observations/actions are stored
  under `obs_key` / `act_key`. Every observation dict additionally carries
  'reward', 'is_first', 'is_last' and 'is_terminal', and every action dict is
  expected to carry a boolean 'reset' entry.
  """

  def __init__(
      self, env: Union[str, gymnasium.Env[U, V]], obs_key='image',
      act_key='action', **kwargs):
    """Wrap `env`.

    Args:
      env: Either an environment id, which is built with
        gymnasium.make(env, render_mode="rgb_array", **kwargs), or an already
        constructed Gymnasium env whose render_mode is "rgb_array".
      obs_key: Key used for the observation when the observation space has no
        `.spaces` attribute (i.e. is not dict-like).
      act_key: Key used for the action when the action space has no `.spaces`
        attribute.
      **kwargs: Forwarded to gymnasium.make(); only allowed when `env` is a
        string.

    Raises:
      AssertionError: If kwargs are given together with an env instance, or if
        the env instance's render_mode is not "rgb_array".
    """
    if isinstance(env, str):
      self._env: gymnasium.Env[U, V] = gymnasium.make(
          env, render_mode="rgb_array", **kwargs)
    else:
      assert not kwargs, kwargs
      # The message must read the mode from the `env` argument: self._env has
      # not been assigned yet at this point, so referencing it here would raise
      # AttributeError and hide the real problem.
      assert env.render_mode == "rgb_array", (
          f"render_mode must be rgb_array, got {env.render_mode}")
      self._env = env
    # A `.spaces` attribute is used as a proxy for "this is a Dict space".
    self._obs_dict = hasattr(self._env.observation_space, 'spaces')
    self._act_dict = hasattr(self._env.action_space, 'spaces')
    self._obs_key = obs_key
    self._act_key = act_key
    # Starting in the "done" state forces a reset on the first step() call.
    self._done = True
    self._info = None

  @property
  def info(self):
    """Info dict returned by the most recent env.step(); None before that."""
    return self._info

  @functools.cached_property
  def obs_space(self):
    """Flat dict of embodied.Space objects describing each observation key."""
    if self._obs_dict:
      # cast is here to stop type checkers from complaining (we already check
      # that .spaces attr exists in __init__ as a proxy for the type check)
      obs_space = cast(gymnasium.spaces.Dict, self._env.observation_space)
      # Flatten nested Dict spaces so the keys match what _obs() produces for
      # the observations themselves.
      spaces = self._flatten(obs_space.spaces)
    else:
      spaces = {self._obs_key: self._env.observation_space}
    spaces = {k: self._convert(v) for k, v in spaces.items()}
    return {
        **spaces,
        'reward': embodied.Space(np.float32),
        'is_first': embodied.Space(bool),
        'is_last': embodied.Space(bool),
        'is_terminal': embodied.Space(bool),
    }

  @functools.cached_property
  def act_space(self):
    """Flat dict of embodied.Space objects for each action key plus 'reset'."""
    if self._act_dict:
      act_space = cast(gymnasium.spaces.Dict, self._env.action_space)
      # Flatten nested Dict spaces so the '/'-joined keys are the ones that
      # step() later splits apart again with _unflatten().
      spaces = self._flatten(act_space.spaces)
    else:
      spaces = {self._act_key: self._env.action_space}
    spaces = {k: self._convert(v) for k, v in spaces.items()}
    spaces['reset'] = embodied.Space(bool)
    return spaces

  def step(self, action):
    """Advance the environment by one step, or reset it.

    Args:
      action: Flat action dict. Must contain 'reset'; otherwise holds either
        the '/'-joined keys of a dict action or the single `act_key` entry.

    Returns:
      Flat observation dict as built by _obs(). When a reset happens (because
      action['reset'] is truthy or the previous step ended the episode) the
      action is ignored, the reward is 0.0 and 'is_first' is True.
    """
    if action['reset'] or self._done:
      self._done = False
      # we don't bother setting ._info here because it gets set below, once we
      # take the next .step()
      obs, _ = self._env.reset()
      return self._obs(obs, 0.0, is_first=True)
    if self._act_dict:
      # NOTE(review): the unflattened dict still contains the 'reset' entry,
      # so it is forwarded to the wrapped env along with the real action keys.
      gymnasium_action = cast(V, self._unflatten(action))
    else:
      gymnasium_action = cast(V, action[self._act_key])
    obs, reward, terminated, truncated, self._info = self._env.step(
        gymnasium_action)
    self._done = terminated or truncated
    # The episode is over on either signal, but only `terminated` means a true
    # terminal state was reached. A truncation (e.g. a time limit) must not be
    # reported as terminal unless the env says so explicitly through
    # info['is_terminal'].
    return self._obs(
        obs, reward,
        is_last=bool(self._done),
        is_terminal=bool(self._info.get('is_terminal', terminated)))

  def _obs(
      self, obs, reward, is_first=False, is_last=False, is_terminal=False):
    """Build the flat observation dict returned by step().

    Non-dict observations are stored under `obs_key`; dict observations are
    flattened. All observation values are converted with np.asarray and the
    reward is stored as np.float32.
    """
    if not self._obs_dict:
      obs = {self._obs_key: obs}
    obs = self._flatten(obs)
    np_obs: Dict[str, Any] = {k: np.asarray(v) for k, v in obs.items()}
    np_obs.update(
        reward=np.float32(reward),
        is_first=is_first,
        is_last=is_last,
        is_terminal=is_terminal)
    return np_obs

  def render(self):
    """Return the wrapped env's render() output, asserting it is not None."""
    image = self._env.render()
    assert image is not None
    return image

  def close(self):
    """Close the wrapped env, deliberately ignoring any error it raises."""
    try:
      self._env.close()
    except Exception:
      # Best-effort shutdown: a failure while closing is intentionally
      # swallowed so that it cannot mask whatever triggered the close.
      pass

  def _flatten(self, nest, prefix=None):
    """Flatten a nested mapping into a single dict with '/'-joined keys.

    Values that are gymnasium.spaces.Dict are descended into via `.spaces`,
    so this works for nested spaces as well as nested observations.
    """
    result = {}
    for key, value in nest.items():
      key = prefix + '/' + key if prefix else key
      if isinstance(value, gymnasium.spaces.Dict):
        value = value.spaces
      if isinstance(value, dict):
        result.update(self._flatten(value, key))
      else:
        result[key] = value
    return result

  def _unflatten(self, flat):
    """Inverse of _flatten: rebuild nested dicts from '/'-joined keys."""
    result = {}
    for key, value in flat.items():
      parts = key.split('/')
      node = result
      # Walk (creating as needed) the intermediate dicts, then set the leaf.
      for part in parts[:-1]:
        if part not in node:
          node[part] = {}
        node = node[part]
      node[parts[-1]] = value
    return result

  def _convert(self, space):
    """Translate a Gymnasium space into an embodied.Space.

    Spaces exposing `.n` become a scalar np.int32 space built from 0 and
    `space.n`; all others are built from the space's dtype, shape, low and
    high.
    """
    if hasattr(space, 'n'):
      return embodied.Space(np.int32, (), 0, space.n)
    return embodied.Space(space.dtype, space.shape, space.low, space.high)
from typing import cast
import gymnasium
from minigrid.wrappers import FullyObsWrapper, ObservationWrapper
from dreamerv3.embodied.envs.from_gymnasium import FromGymnasium
class HideMission(ObservationWrapper):
  """Remove the 'mission' string from the observation."""

  def __init__(self, env):
    """Wrap `env` and advertise an observation space without 'mission'."""
    super().__init__(env)
    obs_space = cast(gymnasium.spaces.Dict, self.observation_space)
    # Build a new Dict space instead of popping from the existing one: the
    # space object seen here may be shared with the wrapped env, and popping
    # in place would make that env advertise a space that no longer matches
    # the observations it emits.
    self.observation_space = gymnasium.spaces.Dict(
        {k: v for k, v in obs_space.spaces.items() if k != 'mission'})

  def observation(self, observation: dict):
    """Return a copy of `observation` without its 'mission' entry.

    The incoming dict is left untouched, so any other holder of it still sees
    the full observation. A missing 'mission' key is tolerated.
    """
    return {k: v for k, v in observation.items() if k != 'mission'}
class Minigrid(FromGymnasium):
  """FromGymnasium adapter around a `MiniGrid-{task}-v0` environment."""

  def __init__(self, task: str, fully_observable: bool, hide_mission: bool):
    """Create the MiniGrid env, apply the requested wrappers, and adapt it.

    Args:
      task: Name inserted into the env id `MiniGrid-{task}-v0`.
      fully_observable: If truthy, wrap the env in FullyObsWrapper.
      hide_mission: If truthy, wrap the env in HideMission (applied after
        FullyObsWrapper).
    """
    wrapped = gymnasium.make(f"MiniGrid-{task}-v0", render_mode="rgb_array")
    # Wrappers are applied in this order, each only when its flag is set.
    optional_wrappers = (
        (fully_observable, FullyObsWrapper),
        (hide_mission, HideMission),
    )
    for enabled, wrapper_cls in optional_wrappers:
      if enabled:
        wrapped = wrapper_cls(wrapped)
    super().__init__(env=wrapped)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment