import torch
from tensordict import TensorDict, TensorDictBase, pad_sequence
from torchrl.envs.transforms import Transform


class HERSubGoalSampler(Transform):
    """Returns a TensorDict with a key `subgoal_idx` of shape [batch_size, num_samples] representing the subgoal indices.

    Available strategies are `last` and `future`. The `last` strategy assigns the
    last state of the trajectory as the subgoal. The `future` strategy samples up
    to `num_samples` subgoals from the future states of the same trajectory.

    Args:
        num_samples (int): Number of subgoals to sample from each trajectory. Defaults to 4.
        subgoal_idx_key (str): The key under which the subgoal index is stored. Defaults to "subgoal_idx".
        strategy (str): The subgoal sampling strategy, either "last" or "future". Defaults to "future".
    """
    def __init__(
        self,
        num_samples: int = 4,
        subgoal_idx_key: str = "subgoal_idx",
        strategy: str = "future",
    ):
        super().__init__(
            in_keys=None,
            in_keys_inv=None,
            out_keys_inv=None,
        )
        self.num_samples = num_samples
        self.subgoal_idx_key = subgoal_idx_key
        self.strategy = strategy
    def forward(self, trajectories: TensorDictBase) -> TensorDictBase:
        if len(trajectories.shape) == 1:
            trajectories = trajectories.unsqueeze(0)
        batch_size, trajectory_len = trajectories.shape

        if self.strategy == "last":
            # Index -1 points at the final state of each trajectory.
            return TensorDict(
                {self.subgoal_idx_key: torch.full((batch_size, 1), -1)},
                batch_size=[batch_size],
            )

        # "future" strategy: sample up to `num_samples` distinct intermediate
        # states (excluding the first and last steps) as subgoals.
        subgoal_idxs = []
        for i in range(batch_size):
            subgoal_idxs.append(
                TensorDict(
                    {self.subgoal_idx_key: (torch.randperm(trajectory_len - 2) + 1)[: self.num_samples]},
                    batch_size=torch.Size(),
                )
            )
        # Pad to a common length and return a validity mask alongside.
        return pad_sequence(subgoal_idxs, pad_dim=0, return_mask=True)
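
# Usage sketch (hypothetical, not part of the original gist): with a batch of
# three trajectories of length 10, the "future" strategy yields a padded
# TensorDict of subgoal indices together with a validity mask from pad_sequence.
def _example_subgoal_sampler() -> TensorDict:
    sampler = HERSubGoalSampler(num_samples=4, strategy="future")
    dummy_trajectories = TensorDict({}, batch_size=[3, 10])
    out = sampler(dummy_trajectories)
    # out["subgoal_idx"] has shape [3, 4] with indices in [1, 8];
    # out["masks", "subgoal_idx"] flags the non-padded entries.
    return out
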
class HERSubGoalAssigner(Transform):
    """Assigns a subgoal to each trajectory according to a given subgoal index.

    Args:
        achieved_goal_key (str): The key of the achieved goal whose observation is copied into the desired goal. Defaults to "achieved_goal".
        desired_goal_key (str): The key of the desired goal that is overwritten with the subgoal. Defaults to "desired_goal".
    """

    def __init__(
        self,
        achieved_goal_key: str = "achieved_goal",
        desired_goal_key: str = "desired_goal",
    ):
        super().__init__(
            in_keys=None,
            in_keys_inv=None,
            out_keys_inv=None,
        )
        self.achieved_goal_key = achieved_goal_key
        self.desired_goal_key = desired_goal_key
    def forward(self, trajectories: TensorDictBase, subgoals_idxs: torch.Tensor) -> TensorDictBase:
        batch_size, trajectory_len = trajectories.shape
        for i in range(batch_size):
            # Replace the desired goal of the whole trajectory with the goal
            # actually achieved at the sampled subgoal step.
            subgoal = trajectories[i][subgoals_idxs[i]][self.achieved_goal_key]
            desired_goal_shape = trajectories[i][self.desired_goal_key].shape
            trajectories[i][self.desired_goal_key] = subgoal.expand(desired_goal_shape)
            # Terminate the relabeled trajectory at the subgoal step.
            trajectories[i][subgoals_idxs[i]]["next", "done"] = True
            # trajectories[i][subgoals_idxs[i] + 1:]["truncated"] = True
        return trajectories
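
# Relabeling sketch (hypothetical values): given one sampled subgoal index per
# trajectory, the assigner overwrites "desired_goal" with the "achieved_goal"
# observed at that index and marks that step as done.
def _example_subgoal_assigner(trajectories: TensorDictBase) -> TensorDictBase:
    assigner = HERSubGoalAssigner()
    subgoal_idxs = torch.tensor([[3], [5]])  # one index per trajectory; batch_size == 2
    return assigner.forward(trajectories, subgoal_idxs)
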
class HERRewardTransform(Transform):
    """Reassigns the rewards of a trajectory according to the new subgoal.

    The base implementation is a no-op placeholder that returns the
    trajectories unchanged; subclass it to recompute rewards for the
    relabeled goals.
    """

    def __init__(self):
        super().__init__(
            in_keys=None,
            in_keys_inv=None,
            out_keys_inv=None,
        )

    def forward(self, trajectories: TensorDictBase) -> TensorDictBase:
        return trajectories
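
# The base reward transform is a no-op; a typical HER setup recomputes a sparse
# reward from the relabeled goals. A minimal sketch, assuming goals live under
# "achieved_goal"/"desired_goal" and the reward under ("next", "reward")
# (these key placements are assumptions, not taken from the gist):
class SparseHERRewardTransform(HERRewardTransform):
    """Yields reward 0.0 where the achieved goal matches the relabeled goal, else -1.0."""

    def forward(self, trajectories: TensorDictBase) -> TensorDictBase:
        achieved = trajectories["achieved_goal"]
        desired = trajectories["desired_goal"]
        # Success when the achieved goal is numerically close to the new goal.
        success = (achieved - desired).norm(dim=-1, keepdim=True) < 1e-3
        trajectories.set(("next", "reward"), success.float() - 1.0)
        return trajectories
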
class HindsightExperienceReplayTransform(Transform):
    """Hindsight Experience Replay (HER) is a technique that makes it possible to learn from failures by creating new experiences out of the failed ones.

    This module is a wrapper that applies the following modules in sequence:
        - SubGoalSampler: Creates new trajectories by sampling future subgoals from the same trajectory.
        - SubGoalAssigner: Assigns the subgoal to the trajectory according to a given subgoal index.
        - RewardTransform: Assigns the reward to the trajectory according to the new subgoal.

    Args:
        SubGoalSampler (Transform): Samples the subgoal indices for each trajectory. Defaults to HERSubGoalSampler().
        SubGoalAssigner (Transform): Relabels each trajectory with its sampled subgoal. Defaults to HERSubGoalAssigner().
        RewardTransform (Transform): Recomputes the rewards of the relabeled trajectories. Defaults to HERRewardTransform().
        assign_subgoal_idxs (bool): If True, stores the subgoal index in the augmented trajectories. Defaults to False.
    """

    def __init__(
        self,
        SubGoalSampler: Transform = HERSubGoalSampler(),
        SubGoalAssigner: Transform = HERSubGoalAssigner(),
        RewardTransform: Transform = HERRewardTransform(),
        assign_subgoal_idxs: bool = False,
    ):
        super().__init__(
            in_keys=None,
            in_keys_inv=None,
            out_keys_inv=None,
        )
        self.SubGoalSampler = SubGoalSampler
        self.SubGoalAssigner = SubGoalAssigner
        self.RewardTransform = RewardTransform
        self.assign_subgoal_idxs = assign_subgoal_idxs
    def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase:
        # Applied on write: append the HER-augmented trajectories to the batch.
        augmentation_td = self.her_augmentation(tensordict)
        return torch.cat([tensordict, augmentation_td], dim=0)

    def _inv_apply_transform(self, tensordict: TensorDictBase) -> torch.Tensor:
        return self.her_augmentation(tensordict)

    def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
        return tensordict

    def _call(self, tensordict: TensorDictBase) -> TensorDictBase:
        raise RuntimeError(
            "HindsightExperienceReplayTransform is a replay-buffer transform "
            "and cannot be used as an environment transform."
        )
    def her_augmentation(self, trajectories: TensorDictBase) -> TensorDictBase:
        if len(trajectories.shape) == 1:
            trajectories = trajectories.unsqueeze(0)
        batch_size, trajectory_length = trajectories.shape
        new_trajectories = trajectories.clone(recurse=True)

        # Sample subgoal indices
        subgoal_idxs = self.SubGoalSampler(new_trajectories)

        # Create one new trajectory per (trajectory, subgoal) pair
        augmented_trajectories = []
        list_idxs = []
        for i in range(batch_size):
            idxs = subgoal_idxs[i][self.SubGoalSampler.subgoal_idx_key]
            if "masks" in subgoal_idxs.keys():
                # Drop padded entries produced by pad_sequence.
                idxs = idxs[subgoal_idxs[i]["masks", self.SubGoalSampler.subgoal_idx_key]]
            list_idxs.append(idxs.unsqueeze(-1))
            new_traj = new_trajectories[i].expand((idxs.numel(), trajectory_length)).clone(recurse=True)
            if self.assign_subgoal_idxs:
                new_traj[self.SubGoalSampler.subgoal_idx_key] = idxs.unsqueeze(-1).repeat(1, trajectory_length)
            augmented_trajectories.append(new_traj)
        augmented_trajectories = torch.cat(augmented_trajectories, dim=0)
        associated_idxs = torch.cat(list_idxs, dim=0)

        # Assign subgoals to the new trajectories
        augmented_trajectories = self.SubGoalAssigner.forward(augmented_trajectories, associated_idxs)

        # Adjust the rewards based on the new subgoals
        augmented_trajectories = self.RewardTransform.forward(augmented_trajectories)

        return augmented_trajectories
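
# End-to-end sketch (hypothetical wiring): torchrl replay buffers apply a
# transform's inverse pass on write, so attaching this transform makes every
# stored batch of trajectories carry its HER relabelings alongside the originals.
def _example_replay_buffer() -> None:
    from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer

    rb = TensorDictReplayBuffer(
        storage=LazyTensorStorage(10_000),
        transform=HindsightExperienceReplayTransform(),
    )
    # `extend` triggers `_inv_call`, which concatenates the augmented
    # trajectories with the original batch before storage, e.g.:
    # rb.extend(trajectories)  # trajectories: TensorDict of shape [batch, time]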