Created
April 14, 2023 04:18
-
-
Save DuaneNielsen/534eecb6602828f4ed0f755cd7e1f385 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from typing import Optional | |
import torch | |
from torch import tensor | |
from tensordict import TensorDict | |
from torchrl.data import CompositeSpec, BoundedTensorSpec, UnboundedContinuousTensorSpec, UnboundedDiscreteTensorSpec, DiscreteTensorSpec, \ | |
UnboundedContinuousTensorSpec | |
from torchrl.envs import ( | |
EnvBase, | |
Transform, | |
TransformedEnv, | |
) | |
from torchrl.envs.utils import check_env_specs, step_mdp | |
from torchrl.envs.transforms.transforms import _apply_to_composite | |
from torch.nn.functional import interpolate | |
from torchvision.utils import make_grid | |
from math import prod | |
""" | |
A minimal stateless vectorized gridworld in pytorch rl | |
Action space: (0, 1, 2, 3) -> N, E, S, W | |
Features | |
walls | |
1 time pickup rewards or penalties | |
done tiles | |
outputs a fully observable RGB image | |
look at the gen_params function to setup the world | |
example of configuring and performing a rollout at bottom | |
After implementing this, I think that pytorch RL doesn't really support | |
vectorized stateless environments. Which is OK | |
Maybe I will write an even simpler stateless non-vectorized version of this gridworld | |
""" | |
# Displacement vectors for actions (0, 1, 2, 3) -> N, E, S, W.
# N and S look swapped because the image y-axis grows downwards.
action_vec = torch.stack([
    tensor([0, -1]),   # N
    tensor([1, 0]),    # E
    tensor([0, 1]),    # S
    tensor([-1, 0]),   # W
])

# RGB palette used when rendering the grid (uint8 channel values).
yellow, red, green, pink, violet, white = (
    tensor(rgb, dtype=torch.uint8)
    for rgb in (
        [255, 255, 0],
        [255, 0, 0],
        [0, 255, 0],
        [255, 0, 255],
        [226, 43, 138],
        [255, 255, 255],
    )
)
def _step(state):
    """Advance the vectorized gridworld by one step.

    ``state`` is a TensorDict holding, per batch element, 'player_pos'
    (int64 [2]), 'wall_tiles' (uint8 [H, W]), 'reward_tiles'
    (float32 [H, W]), 'done_tiles' (bool [H, W]) and 'action'
    (int64 [1]).  Mutates 'reward_tiles' in place (one-time pickups)
    and returns a TensorDict with a 'next' entry containing the updated
    state plus 'reward' and 'done'.
    """
    # make our life easier by creating a view with a single leading dim
    state_flat = state.view(prod(state.shape))
    batch_range = torch.arange(state_flat.size(0))
    # move player position checking for collisions
    next_player_pos = state_flat['player_pos'] + action_vec[state_flat['action'][:, 0]].to(state.device)
    # boolean grid with exactly one True per batch element: the proposed tile
    next_player_grid = torch.zeros_like(state_flat['wall_tiles'], dtype=torch.bool, device=state.device)
    next_player_grid[batch_range, next_player_pos[:, 0], next_player_pos[:, 1]] = True
    # a move is rejected when the proposed tile is a wall
    collide_wall = torch.logical_and(next_player_grid, state_flat['wall_tiles'] == 1).any(-1).any(-1)
    player_pos = torch.where(collide_wall[..., None], state_flat['player_pos'], next_player_pos)
    player_pos_mask = torch.zeros_like(state_flat['wall_tiles'], dtype=torch.bool, device=state.device)
    player_pos_mask[batch_range, player_pos[:, 0], player_pos[:, 1]] = True
    # restore the original (possibly multi-dim) batch shape
    player_pos = player_pos.reshape(state['player_pos'].shape)
    player_pos_mask = player_pos_mask.reshape(state['wall_tiles'].shape)
    # pickup any rewards; the tile is zeroed so each reward is collected once.
    # NOTE(review): boolean-mask indexing yields a flat [N] tensor (one entry
    # per batch element); reward_spec declares [*batch, 1] — confirm the
    # shapes are reconciled downstream.
    reward = state['reward_tiles'][player_pos_mask]
    state['reward_tiles'][player_pos_mask] = 0.
    # set done flag if hit done tile
    done = state['done_tiles'][player_pos_mask]
    next = {
        'player_pos': player_pos,
        'wall_tiles': state['wall_tiles'],
        'reward_tiles': state['reward_tiles'],
        'done_tiles': state['done_tiles'],
        'reward': reward,
        'done': done
    }
    return TensorDict({'next': next}, state.shape)
def _reset(self, state=None): | |
batch_size = state.shape if state is not None else [] | |
if state is None or state.is_empty(): | |
state = self.gen_params(batch_size) | |
return state | |
def gen_params(batch_size=None):
    """Build the initial world state as a TensorDict.

    Edit the arrays below to reconfigure the world: wall tiles block
    movement, reward tiles are one-time pickups/penalties, done tiles
    terminate the episode, and player_pos is the starting cell.  When
    ``batch_size`` is truthy the single world is expanded (and made
    contiguous) to that batch shape.
    """
    layout = {
        "player_pos": tensor([2, 2], dtype=torch.int64),
        "wall_tiles": tensor([
            [1, 1, 1, 1, 1],
            [1, 0, 0, 0, 1],
            [1, 0, 0, 0, 1],
            [1, 0, 0, 0, 1],
            [1, 1, 1, 1, 1],
        ], dtype=torch.uint8),
        "reward_tiles": tensor([
            [0, 0, 0, 0, 0],
            [0, 1, 1, -1, 0],
            [0, 1, 0, 1, 0],
            [0, -1, 1, 1, 0],
            [0, 0, 0, 0, 0],
        ], dtype=torch.float32),
        "done_tiles": tensor([
            [0, 0, 0, 0, 0],
            [0, 0, 0, 1, 0],
            [0, 0, 0, 0, 0],
            [0, 1, 0, 0, 0],
            [0, 0, 0, 0, 0],
        ], dtype=torch.bool),
    }
    td = TensorDict(layout, batch_size=[])
    return td.expand(batch_size).contiguous() if batch_size else td
def _make_spec(self, td_params):
    """Declare observation/input/action/reward specs matching gen_params.

    ``td_params`` only supplies the batch shape; the 5x5 grid size and
    dtypes mirror the tensors produced by gen_params.
    """
    batch_size = td_params.shape
    tile_shape = torch.Size((*batch_size, 5, 5))
    self.observation_spec = CompositeSpec(
        wall_tiles=BoundedTensorSpec(
            minimum=0, maximum=1, shape=tile_shape, dtype=torch.uint8,
        ),
        reward_tiles=UnboundedContinuousTensorSpec(
            shape=tile_shape, dtype=torch.float32,
        ),
        done_tiles=BoundedTensorSpec(
            minimum=0, maximum=1, shape=tile_shape, dtype=torch.bool,
        ),
        player_pos=UnboundedDiscreteTensorSpec(
            shape=torch.Size((*batch_size, 2)), dtype=torch.int64,
        ),
        shape=torch.Size(batch_size),
    )
    # stateless env: the full observed state doubles as the step input
    self.input_spec = self.observation_spec.clone()
    self.action_spec = DiscreteTensorSpec(4, shape=torch.Size((*batch_size, 1)))
    self.reward_spec = UnboundedContinuousTensorSpec(shape=torch.Size((*batch_size, 1)))
def _set_seed(self, seed: Optional[int]): | |
rng = torch.manual_seed(seed) | |
self.rng = rng | |
class Gridworld(EnvBase):
    """Stateless, batchable gridworld environment for torchrl.

    The actual behaviour lives in the module-level functions
    (gen_params, _make_spec, _reset, _step, _set_seed), which are bound
    onto the class below.
    """
    metadata = {
        "render_modes": ["human", ""],
        "render_fps": 30
    }
    # stateless env: accepts input tensordicts of any batch shape
    batch_locked = False
    def __init__(self, td_params=None, device="cpu", batch_size=None):
        # td_params is only used to derive the specs' batch shape
        if td_params is None:
            td_params = self.gen_params(batch_size)
        batch_size = [] if batch_size is None else batch_size
        super().__init__(device=device, batch_size=batch_size)
        self._make_spec(td_params)
        self.shape = batch_size
    # bind the module-level implementations; gen_params and _step take no
    # self, so they are attached as staticmethods
    gen_params = staticmethod(gen_params)
    _make_spec = _make_spec
    _reset = _reset
    _step = staticmethod(_step)
    _set_seed = _set_seed
class RGBFullObsTransform(Transform):
    """Render the fully observable grid state into a (3, 64, 64) uint8 RGB image.

    Walls are drawn white, positive reward tiles green, negative reward
    tiles red, and the player yellow (colors are module-level constants).
    """
    def forward(self, tensordict):
        return self._call(tensordict)
    def _call(self, td):
        # flatten batch dims so a single arange can index every element
        td_flat = td.view(prod(td.batch_size))
        batch_range = torch.arange(td_flat.size(0))
        player_pos = td_flat['player_pos']
        walls = td_flat['wall_tiles']
        rewards = td_flat['reward_tiles']
        # HWC uint8 canvas, one image per flattened batch element
        grid = TensorDict({'image': torch.zeros(*walls.shape, 3, dtype=torch.uint8)}, batch_size=td_flat.batch_size)
        x, y = player_pos[:, 0], player_pos[:, 1]
        # paint order matters: later assignments overwrite earlier ones,
        # so the player is always drawn on top
        grid['image'][walls == 1] = white
        grid['image'][rewards > 0] = green
        grid['image'][rewards < 0] = red
        grid['image'][batch_range, x, y, :] = yellow
        # HWC -> CHW, as interpolate expects channel-first input
        grid['image'] = grid['image'].permute(0, 3, 1, 2)
        # upscale to 64x64; NOTE(review): squeeze(0) only takes effect when
        # the flattened batch has size 1 — confirm this is intended
        observation = interpolate(grid['image'], size=[64, 64]).squeeze(0)
        return TensorDict({
            "observation": observation,
            **td
        }, batch_size=td.batch_size)
    @_apply_to_composite
    def transform_observation_spec(self, observation_spec):
        # each in_key's spec is replaced with the rendered-image spec
        return BoundedTensorSpec(
            minimum=0,
            maximum=255,
            shape=torch.Size((3, 64, 64)),
            dtype=torch.uint8,
            device=observation_spec.device
        )
if __name__ == '__main__':
    from matplotlib import pyplot as plt

    def simple_rollout(steps=100, batch_size=None):
        """Roll the (transformed) env forward ``steps`` steps with random actions.

        Returns a TensorDict of shape [steps, *batch_size] holding every
        transition, suitable for offline inspection/rendering.
        """
        batch_size = [1] if batch_size is None else batch_size
        # preallocate:
        data = TensorDict({}, [steps, *batch_size])
        # reset
        _data = env.gen_params(batch_size=batch_size)
        _data = env.reset(_data)
        for i in range(steps):
            _data["action"] = env.action_spec.rand(shape=batch_size)
            _data = env.step(_data)
            data[i] = _data
            # promote 'next' into the root so it becomes the next step's input
            _data = step_mdp(_data, keep_other=True)
        return data

    env = Gridworld()
    check_env_specs(env)
    env = TransformedEnv(
        env,
        # BUG FIX: the state key is 'wall_tiles' (see gen_params/_make_spec),
        # not 'walls' — the composite-spec transform would otherwise target a
        # nonexistent observation key
        RGBFullObsTransform(in_keys=['player_pos', 'wall_tiles'], out_keys=['observation'])
    )
    check_env_specs(env)
    data = simple_rollout(batch_size=[64])
    # animate the rollout: one frame per timestep, 64 envs tiled per frame
    fig, ax = plt.subplots(1)
    img_plt = ax.imshow(make_grid(data[0]['observation']).permute(1, 2, 0))
    for timestep in data:
        x = make_grid(timestep['observation']).permute(1, 2, 0)
        img_plt.set_data(x)
        plt.pause(1.0)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment