Last active
April 8, 2024 08:32
-
-
Save masterdezign/47b3c6172dd1624bb9a7ef23cbc79c8c to your computer and use it in GitHub Desktop.
Recurrent replay buffer
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from copy import deepcopy | |
from typing import Any, Dict, Generator, List, Optional, Union | |
from typing import NamedTuple, Tuple | |
from gymnasium import spaces | |
import numpy as np | |
import torch as th | |
from stable_baselines3.common.buffers import BaseBuffer | |
from stable_baselines3.common.vec_env import VecNormalize | |
class RecurrentReplayBufferSamples(NamedTuple):
    """
    A batch of episode chunks sampled from ``RecurrentReplayBuffer``.

    Field order matters: instances are built positionally from
    ``(observations, actions, rewards, dones, mask)``.
    """

    # Observations, one more timestep than the other fields (final next obs)
    observations: th.Tensor
    # Actions taken at each stored timestep
    actions: th.Tensor
    # Rewards received at each stored timestep
    rewards: th.Tensor
    # Episode-end flags at each stored timestep
    dones: th.Tensor
    # 1 for a recorded timestep, 0 for padding
    mask: th.Tensor
class RecurrentReplayBuffer(BaseBuffer):
    """
    Replay buffer storing fixed-length chunks of episodes for recurrent
    policies, following the sequence-replay scheme of [1].

    Row ``i`` of every storage array is one chunk. A chunk is closed when it
    holds ``chunk_len`` transitions or when the episode ends. A chunk closed
    because it is full (episode still running) seeds the next chunk with its
    last ``overlap`` transitions. Steps that were never written are zero and
    have mask 0.

    [1] Kapturowski, Steven, et al. "Recurrent experience replay in distributed
        reinforcement learning." International Conference on Learning
        Representations. 2019.
    """

    def __init__(
        self,
        buffer_size: int,
        observation_space: spaces.Space,
        action_space: spaces.Space,
        chunk_len: int = 120,
        overlap: int = 40,
        n_envs: int = 1,
        **kwargs,
    ):
        """
        :param buffer_size: Max number of chunks stored in the buffer
        :param observation_space: Observation space; only ``shape[0]`` is used
            as the observation size, so a flat space is assumed
        :param action_space: Action space; only ``shape[0]`` is used
        :param chunk_len: total number of timesteps to store in each chunk:
            l + m, for example, l = 40 is the burn-in length and m = 80 is the
            "useful" length of the chunk [1]
        :param overlap: number of timesteps shared between consecutive chunks
            of the same episode [1]; must satisfy ``0 <= overlap < chunk_len``
        :param n_envs: number of parallel environments; only 1 is supported
        :param kwargs: forwarded to ``BaseBuffer.__init__``
        :raises ValueError: if ``n_envs != 1``, ``chunk_len < 1`` or
            ``overlap`` is outside ``[0, chunk_len)``

        [1] Kapturowski, Steven, et al. "Recurrent experience replay in distributed
        reinforcement learning." International Conference on Learning
        Representations. 2019.
        """
        # Multiple envs might be supported in the future; see
        # https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/blob/master/sb3_contrib/common/recurrent/buffers.py
        # for the reference.
        # Explicit raises instead of `assert`: asserts vanish under `python -O`.
        if n_envs != 1:
            raise ValueError("RecurrentReplayBuffer does not support multiple envs")
        if chunk_len < 1:
            raise ValueError(f"chunk_len must be >= 1, got {chunk_len}")
        if not 0 <= overlap < chunk_len:
            # overlap >= chunk_len would leave no room in the next chunk and
            # make the following `add` index past the end of the row.
            raise ValueError(
                f"overlap must be in [0, chunk_len={chunk_len}), got {overlap}"
            )
        super().__init__(
            buffer_size, observation_space, action_space, n_envs=n_envs, **kwargs
        )
        self.obs_dim = observation_space.shape[0]
        self.act_dim = action_space.shape[0]
        self.chunk_len = chunk_len
        self.overlap = overlap
        self.reset()

    def reset(self) -> None:
        """
        Reset the buffer: reallocate zeroed storage, rewind the timestep
        counter and call the parent ``reset`` (which handles ``pos``/``full``).
        """
        n, t = self.buffer_size, self.chunk_len
        # chunk_len + 1 observation slots: the last one keeps the next
        # observation of the final transition stored in the chunk.
        self.o = np.zeros((n, t + 1, self.obs_dim), dtype=np.float32)
        self.a = np.zeros((n, t, self.act_dim), dtype=np.float32)
        self.r = np.zeros((n, t, 1), dtype=np.float32)
        self.d = np.zeros((n, t, 1), dtype=np.float32)
        # Mask: Valid step = 1, no record = 0
        self.m = np.zeros((n, t, 1), dtype=np.float32)
        # self.pos (from the parent class) is the position of the episode
        # chunk in the buffer (a "row counter"); self.time_pos is the position
        # of the timestep in the chunk (a "column counter").
        self.time_pos = 0
        super().reset()

    def add(
        self,
        obs: np.ndarray,
        next_obs: np.ndarray,
        action: np.ndarray,
        reward: np.ndarray,
        done: np.ndarray,
        infos: Optional[List[Dict[str, Any]]] = None,
    ) -> None:
        """
        Record one transition in the current chunk and, when the chunk is full
        or the episode ended, move on to a new chunk.

        :param obs: observation before the action
        :param next_obs: observation after the action
        :param action: action taken
        :param reward: reward received
        :param done: whether the episode ended with this transition
        :param infos: unused; accepted so calls shaped like SB3's
            ``ReplayBuffer.add(obs, next_obs, action, reward, done, infos)``
            also work
        """
        row, t = self.pos, self.time_pos
        # Assigning into the preallocated arrays copies the values, so later
        # in-place changes to the caller's arrays cannot alter the buffer.
        self.o[row, t] = np.asarray(obs)
        self.o[row, t + 1] = np.asarray(next_obs)
        self.a[row, t] = np.asarray(action)
        self.r[row, t] = np.asarray(reward)
        self.d[row, t] = np.asarray(done)
        self.m[row, t] = 1
        self.time_pos += 1

        episode_done = bool(np.any(done))  # n_envs == 1
        end_of_chunk = self.time_pos == self.chunk_len
        if end_of_chunk or episode_done:
            # A finished episode must not leak into the next chunk, so the
            # overlap is only carried over while the episode continues.
            self._start_new_chunk(carry_overlap=not episode_done)

    def _start_new_chunk(self, carry_overlap: bool) -> None:
        """
        Close the current chunk and move the write position to the next row.

        :param carry_overlap: if True (previous chunk full, episode still
            running), seed the new row with the last ``overlap`` transitions of
            the previous chunk and continue writing right after them;
            otherwise the new row starts empty.
        """
        prev = self.pos
        if prev == self.buffer_size - 1:
            self.full = True
        self.pos = (prev + 1) % self.buffer_size
        row = self.pos

        if not carry_overlap:
            # After wrap-around the row still holds an older chunk; clear it
            # so its mask does not mark stale steps as valid.
            for storage in (self.o, self.a, self.r, self.d, self.m):
                storage[row] = 0
            self.time_pos = 0
            return

        # Only reached when the previous chunk is full. Slice from an explicit
        # start index: `x[-k:]` with k == 0 would select the whole axis.
        k = self.overlap
        start = self.chunk_len - k
        # With buffer_size == 1, `prev == row`; NumPy handles the overlapping
        # source/destination of these assignments.
        self.o[row, : k + 1] = self.o[prev, start:]
        self.a[row, :k] = self.a[prev, start:]
        self.r[row, :k] = self.r[prev, start:]
        self.d[row, :k] = self.d[prev, start:]
        self.m[row, :k] = 1
        # Clear the rest of the row (possibly an older chunk's data).
        self.o[row, k + 1 :] = 0
        for storage in (self.a, self.r, self.d, self.m):
            storage[row, k:] = 0
        self.time_pos = k

    def _get_samples(
        self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None
    ) -> RecurrentReplayBufferSamples:
        """
        Gather the chunks at ``batch_inds`` and convert them to tensors.

        :param batch_inds: indices of the buffer rows (chunks) to return
        :param env: optional ``VecNormalize`` wrapper; when given, observations
            and rewards are normalized with it (as SB3's ``ReplayBuffer`` does)
        :return: A batch of chunks of episodes; observations have
            ``chunk_len + 1`` timesteps, the other fields ``chunk_len``
        """
        observations = self._normalize_obs(self.o[batch_inds], env)
        rewards = self._normalize_reward(self.r[batch_inds], env)
        data = (
            observations,
            self.a[batch_inds],
            rewards,
            self.d[batch_inds],
            self.m[batch_inds],
        )
        return RecurrentReplayBufferSamples(*(self.to_torch(x) for x in data))
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import gymnasium as gym | |
import numpy as np | |
from buffers import RecurrentReplayBuffer | |
def random_sample(done=False, prev_obs=None):
    """
    Build a random transition shaped like the ones the tests feed the buffer.

    :param done: value returned as the ``done`` flag
    :param prev_obs: observation reused as ``obs`` (typically the previous
        transition's ``next_obs``); a fresh random one is drawn when ``None``
    :return: ``(obs, next_obs, action, reward, done)`` with a 3-dim float32
        observation, a 1-dim float32 action and a float reward
    """
    # Draw order matters for reproducibility under a fixed seed:
    # obs (only if needed), next_obs, action, reward.
    if prev_obs is None:
        obs = np.random.rand(3).astype(np.float32)
    else:
        obs = prev_obs
    next_obs = np.random.rand(3).astype(np.float32)
    action = np.random.rand(1).astype(np.float32)
    reward = np.random.rand()
    return obs, next_obs, action, reward, done
def _init_buffer(buffer_size=8, elements=0, chunk_len=4, overlap=1, env=None):
    """
    Build a ``RecurrentReplayBuffer`` and optionally pre-fill it.

    :param buffer_size: number of chunks the buffer can hold
    :param elements: number of transitions to add before returning
    :param chunk_len: chunk length passed to the buffer
    :param overlap: overlap passed to the buffer
    :param env: gymnasium env supplying the spaces and the transitions; when
        ``None``, Box spaces matching ``random_sample`` (3-dim observation,
        1-dim action) are used and the buffer is filled with chained
        ``random_sample`` transitions
    :return: the (possibly pre-filled) buffer
    """
    if env is None:
        # Mock spaces with the shapes produced by `random_sample`.
        observation_space = gym.spaces.Box(
            low=-np.inf, high=np.inf, shape=(3,), dtype=np.float32
        )
        action_space = gym.spaces.Box(low=0.0, high=1.0, shape=(1,), dtype=np.float32)
    else:
        observation_space, action_space = env.observation_space, env.action_space

    buffer = RecurrentReplayBuffer(
        buffer_size,
        observation_space,
        action_space,
        chunk_len=chunk_len,
        overlap=overlap,
    )

    if env is None:
        prev_obs = None
        for _ in range(elements):
            sample = random_sample(prev_obs=prev_obs)
            buffer.add(*sample)
            prev_obs = sample[1]
        return buffer

    obs, _ = env.reset()
    for _ in range(elements):
        action = env.action_space.sample()
        next_obs, reward, terminated, truncated, _ = env.step(action)
        episode_over = terminated or truncated
        # The buffer has no separate timeout flag; closing the chunk on either
        # episode end keeps a chunk from spanning two episodes.
        buffer.add(obs, next_obs, action, reward, episode_over)
        # Reset instead of stepping an episode that has already ended.
        obs = env.reset()[0] if episode_over else next_obs
    return buffer
def test_add():
    """
    Walk ``RecurrentReplayBuffer.add`` through recording steps, an episode
    end, automatic chunking with overlap, an end of chunk that is also an
    episode end, and filling the buffer until ``full`` is set.
    """
    buffer_size = 8
    chunk_len = 5
    overlap = 3
    env = gym.make("Pendulum-v1")
    try:
        buffer = _init_buffer(
            buffer_size=buffer_size, chunk_len=chunk_len, overlap=overlap, env=env
        )
    finally:
        # Nothing is pre-filled (elements=0); the env only supplies the spaces.
        env.close()
    assert np.abs(buffer.o).sum() == 0, "Buffer should be empty"
    assert np.abs(buffer.m).sum() == 0, "Buffer should be empty"
    assert buffer.pos == 0, "Buffer should be empty"
    assert buffer.time_pos == 0, "Buffer should be empty"

    sa = random_sample()
    buffer.add(*sa)
    assert buffer.pos == 0, "Position should have not changed"
    assert buffer.time_pos == 1, "Time position should have increased"
    assert np.allclose(buffer.o[0][0], sa[0]), "0 Observations should be recorded"
    assert np.allclose(buffer.o[0][1], sa[1]), "0 Next observations should be recorded"
    assert np.allclose(buffer.a[0][0], sa[2]), "0 Actions should be recorded"
    assert np.allclose(buffer.r[0][0], sa[3]), "0 Rewards should be recorded"
    assert np.allclose(buffer.d[0][0], False), "0 Dones should be recorded"
    assert buffer.m[0][0] == 1, "Mask should be updated"
    assert np.abs(buffer.m).sum() == 1, "Mask should be updated"

    sa = random_sample(prev_obs=sa[1])
    buffer.add(*sa)
    assert buffer.pos == 0, "Position should have not changed"
    assert buffer.time_pos == 2, "Time position should have increased"
    assert np.allclose(buffer.o[0][1], sa[0]), "1 Observations should be recorded"
    assert np.allclose(buffer.o[0][2], sa[1]), "1 Next observations should be recorded"
    assert np.allclose(buffer.a[0][1], sa[2]), "1 Actions should be recorded"
    assert np.allclose(buffer.r[0][1], sa[3]), "1 Rewards should be recorded"
    assert np.allclose(buffer.d[0][1], False), "1 Dones should be recorded"
    assert buffer.m[0][1] == 1, "Mask should be updated"
    assert np.abs(buffer.m).sum() == 2, "Mask should be updated"

    # Simulate done episode
    sa = random_sample(done=True, prev_obs=sa[1])
    buffer.add(*sa)
    assert buffer.pos == 1, "New chunk should have started"
    assert buffer.time_pos == 0, "Time position should have been reset"
    assert np.allclose(buffer.o[0][2], sa[0]), "2 Observations should be recorded"
    assert np.allclose(buffer.o[0][3], sa[1]), "2 Next observations should be recorded"
    assert np.allclose(buffer.a[0][2], sa[2]), "2 Actions should be recorded"
    assert np.allclose(buffer.r[0][2], sa[3]), "2 Rewards should be recorded"
    assert np.allclose(buffer.d[0][2], True), "2 Dones should be recorded"
    assert buffer.m[0][2] == 1, "Mask should be updated"
    assert np.abs(buffer.m).sum() == 3, "Mask should be updated"
    print("New chunk:\n", buffer.o[1])

    # Test automatic chunking
    for i in range(chunk_len - 1):
        sa = random_sample(prev_obs=sa[1])
        buffer.add(*sa)
        assert buffer.pos == 1, "Position should have not changed"
        assert buffer.time_pos == i + 1, "Time position should have increased"
        assert np.allclose(buffer.o[1][i], sa[0]), "Observations should be recorded"
        assert np.allclose(
            buffer.o[1][i + 1], sa[1]
        ), "Next observations should be recorded"
        assert np.allclose(buffer.a[1][i], sa[2]), "Actions should be recorded"
        assert np.allclose(buffer.r[1][i], sa[3]), "Rewards should be recorded"
        assert np.allclose(buffer.d[1][i], False), "Dones should be recorded"
        assert buffer.m[1][i] == 1, "Mask should be updated"
        assert np.abs(buffer.m).sum() == i + 4, "Mask should be updated"
    print("Current chunk:\n", buffer.o[1])

    # Here we should start a new chunk
    sa2 = random_sample(prev_obs=sa[1])
    buffer.add(*sa2)
    print("Prev obs", sa2[0])
    print("New obs", sa2[1])
    print("Current chunk:\n", buffer.o[1])
    assert not buffer.full, "Buffer should not be full"
    assert buffer.pos == 2, "New chunk should have started"
    assert buffer.time_pos == overlap, "Time position should have been moved to overlap"
    assert np.allclose(buffer.o[1][chunk_len - 1], sa2[0]), "End of previous chunk"
    assert np.allclose(
        buffer.o[1][chunk_len], sa2[1]
    ), "End of previous chunk - new obs"
    assert buffer.m[1][chunk_len - 1] == 1, "Mask should be updated"
    print("New chunk:\n", buffer.o[2])
    assert np.allclose(
        buffer.o[2][1], sa[0]
    ), "Overlap: Old observations should have been preserved"
    assert np.allclose(
        buffer.o[2][2], sa2[0]
    ), "Overlap: Current observations should be recorded"
    assert np.allclose(
        buffer.o[2][3], sa2[1]
    ), "Overlap: Next observations should be recorded"
    assert buffer.m[2][0] == 1, "Overlap: Mask should be updated"
    assert buffer.m[2][1] == 1, "Overlap: Mask should be updated"
    assert buffer.m[2][2] == 1, "Overlap: Mask should be updated"
    assert buffer.m[2][3] == 0, "Mask should remain 0"

    sa2 = random_sample(prev_obs=sa2[1])
    buffer.add(*sa2)
    print("Current chunk:\n", buffer.o[2])
    assert buffer.pos == 2, "Position should remain the same"
    assert buffer.time_pos == overlap + 1, "Time position should have been increased"
    assert np.allclose(
        buffer.o[2][3], sa2[0]
    ), "Overlap: Current observations should be recorded"
    assert np.allclose(
        buffer.o[2][4], sa2[1]
    ), "Overlap: Next observations should be recorded"

    # Edge case test: end of chunk and done
    sa2 = random_sample(prev_obs=sa2[1], done=True)
    buffer.add(*sa2)
    print("Current chunk:\n", buffer.o[2])
    assert buffer.pos == 3, "Position should have increased"
    assert buffer.time_pos == 0, "Time position should have been reset"
    assert not buffer.full, "Buffer should not be full"
    print("New empty chunk:\n", buffer.o[3])

    # Fill the buffer until it's full
    for test_pos in range(4, buffer_size):
        print("test_pos", test_pos)
        sa2 = random_sample(done=True)
        buffer.add(*sa2)
        assert np.allclose(
            buffer.o[test_pos - 1][0], sa2[0]
        ), "Overlap: Current observations should be recorded"
        assert np.allclose(
            buffer.o[test_pos - 1][1], sa2[1]
        ), "Overlap: Next observations should be recorded"
        assert buffer.pos == test_pos, "Position should have increased"
        assert buffer.time_pos == 0, "Time position should have been reset"
        assert not buffer.full, "Buffer should not be full"
    sa2 = random_sample(done=True)
    buffer.add(*sa2)
    assert buffer.full, "Buffer should be full"
def test_sample():
    """Check the tensor shapes of a batch sampled from a partially filled buffer."""
    batch_size = 32
    buffer_size = 200
    chunk_len = 40
    overlap = 5
    env = gym.make("Pendulum-v1")
    try:
        obs_dim = env.observation_space.shape[0]
        act_dim = env.action_space.shape[0]
        buffer = _init_buffer(
            elements=100,
            buffer_size=buffer_size,
            chunk_len=chunk_len,
            overlap=overlap,
            env=env,
        )
    finally:
        # The env is only needed to build and pre-fill the buffer.
        env.close()

    batch = buffer.sample(batch_size)
    print("observations.shape", batch.observations.shape)
    print("rewards.shape", batch.rewards.shape)
    assert len(batch.observations) == batch_size, (
        f"Batch should have {batch_size} elements"
    )
    assert batch.observations.shape == (batch_size, chunk_len + 1, obs_dim), (
        f"Observations shape should be ({batch_size}, {chunk_len} + 1, obs_dim)"
    )
    assert batch.actions.shape == (batch_size, chunk_len, act_dim), (
        f"Actions shape should be ({batch_size}, {chunk_len}, act_dim)"
    )
    assert batch.rewards.shape == (batch_size, chunk_len, 1), (
        f"Rewards shape should be ({batch_size}, {chunk_len}, 1)"
    )
    assert batch.dones.shape == (batch_size, chunk_len, 1), (
        f"Dones shape should be ({batch_size}, {chunk_len}, 1)"
    )
    assert batch.mask.shape == (batch_size, chunk_len, 1), (
        f"Mask shape should be ({batch_size}, {chunk_len}, 1)"
    )
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment