import torch
from tensordict import TensorDict, TensorDictBase, pad_sequence
from torchrl.envs.transforms import Transform


class HERSubGoalSampler(Transform):
    """Samples subgoal indices for each trajectory.

    Returns a TensorDict with a key `subgoal_idx` of shape [batch_size, num_samples]
    representing the subgoal index. Available strategies are: `last` and `future`.
    The `last` strategy assigns the last state as the subgoal. The `future` strategy
    samples up to `num_samples` subgoals from the future states of the same trajectory.

    Args:
        num_samples (int): Number of subgoals to sample from each trajectory. Defaults to 4.
        subgoal_idx_key (str): The key under which to store the subgoal index. Defaults to "subgoal_idx".
        strategy (str): The subgoal sampling strategy, either "last" or "future". Defaults to "future".
    """
    def __init__(
        self,
        num_samples: int = 4,
        subgoal_idx_key: str = "subgoal_idx",
        strategy: str = "future",
    ):
        super().__init__(
            in_keys=None,
            in_keys_inv=None,
            out_keys_inv=None,
        )
        self.num_samples = num_samples
        self.subgoal_idx_key = subgoal_idx_key
        self.strategy = strategy
    def forward(self, trajectories: TensorDictBase) -> TensorDictBase:
        if len(trajectories.shape) == 1:
            trajectories = trajectories.unsqueeze(0)
        batch_size, trajectory_len = trajectories.shape

        if self.strategy == "last":
            # The last state of each trajectory is the single subgoal; -1 indexes it.
            return TensorDict(
                {"subgoal_idx": torch.full((batch_size, 1), -1)},
                batch_size=batch_size,
            )
        else:
            # "future" strategy: sample up to `num_samples` intermediate states
            # (excluding the first and last states) as subgoals.
            subgoal_idxs = []
            for i in range(batch_size):
                subgoal_idxs.append(
                    TensorDict(
                        {"subgoal_idx": (torch.randperm(trajectory_len - 2) + 1)[: self.num_samples]},
                        batch_size=torch.Size(),
                    )
                )
            # Pad to a common number of subgoals and return a mask flagging the valid entries.
            return pad_sequence(subgoal_idxs, pad_dim=0, return_mask=True)
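
# Illustrative note (not in the original gist): with the default "future" strategy and a
# batch of trajectories of shape [batch_size, trajectory_len], the sampler returns a
# TensorDict with a "subgoal_idx" entry of shape [batch_size, num_samples] (padded via
# pad_sequence, with the valid entries flagged under a nested "masks" entry), e.g.:
#
#     sampler = HERSubGoalSampler(num_samples=4)
#     idxs_td = sampler(trajectories)           # trajectories: TensorDict of shape [2, 10]
#     idxs_td["subgoal_idx"]                    # LongTensor of shape [2, 4], values in 1..8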


class HERSubGoalAssigner(Transform):
    """Assigns a subgoal to each trajectory according to a given subgoal index.

    The achieved goal at the subgoal index becomes the trajectory's new desired goal.

    Args:
        achieved_goal_key (str): The key of the achieved goal observation. Defaults to "achieved_goal".
        desired_goal_key (str): The key of the desired goal that is overwritten with the subgoal. Defaults to "desired_goal".
    """

    def __init__(
        self,
        achieved_goal_key: str = "achieved_goal",
        desired_goal_key: str = "desired_goal",
    ):
        super().__init__(
            in_keys=None,
            in_keys_inv=None,
            out_keys_inv=None,
        )
        self.achieved_goal_key = achieved_goal_key
        self.desired_goal_key = desired_goal_key
    def forward(self, trajectories: TensorDictBase, subgoals_idxs: torch.Tensor) -> TensorDictBase:
        batch_size, trajectory_len = trajectories.shape
        for i in range(batch_size):
            # The achieved goal at the sampled subgoal index becomes the new desired goal
            # for the whole trajectory. Index the tensors directly so the writes are
            # guaranteed to propagate to the parent tensordict.
            subgoal = trajectories[self.achieved_goal_key][i, subgoals_idxs[i]]
            desired_goal_shape = trajectories[self.desired_goal_key][i].shape
            trajectories[self.desired_goal_key][i] = subgoal.expand(desired_goal_shape)
            # Mark the relabelled subgoal transition as terminal.
            trajectories["next", "done"][i, subgoals_idxs[i]] = True
            # trajectories[i][subgoals_idxs[i]+1:]["truncated"] = True
        return trajectories


class HERRewardTransform(Transform):
    """Adjusts the reward of a trajectory according to the new subgoal.

    This default implementation is a no-op and returns the trajectories unchanged;
    subclass it (or pass a custom transform) when the reward has to be recomputed
    for the relabelled goals.
    """

    def __init__(self):
        super().__init__()

    def forward(self, trajectories: TensorDictBase) -> TensorDictBase:
        return trajectories
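

# A hedged sketch (not part of the original gist): HERRewardTransform above is a
# passthrough, so relabelled trajectories keep their original rewards. For sparse
# goal-reaching tasks one would typically recompute the reward from the relabelled
# goal. The key names ("achieved_goal", "desired_goal", ("next", "reward")) and the
# 0/1 success convention below are assumptions, not something the gist defines.
class HERSparseRewardTransform(HERRewardTransform):
    """Recomputes a sparse 0/1 reward from the relabelled desired goal."""

    def __init__(
        self,
        achieved_goal_key: str = "achieved_goal",
        desired_goal_key: str = "desired_goal",
        tolerance: float = 1e-3,
    ):
        super().__init__()
        self.achieved_goal_key = achieved_goal_key
        self.desired_goal_key = desired_goal_key
        self.tolerance = tolerance

    def forward(self, trajectories: TensorDictBase) -> TensorDictBase:
        achieved = trajectories[self.achieved_goal_key]
        desired = trajectories[self.desired_goal_key]
        # Success when the achieved goal lies within `tolerance` of the desired goal.
        success = (achieved - desired).norm(dim=-1, keepdim=True) < self.tolerance
        trajectories["next", "reward"] = success.to(torch.get_default_dtype())
        return trajectories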


class HindsightExperienceReplayTransform(Transform):
    """Hindsight Experience Replay (HER) augmentation.

    HER is a technique that allows learning from failed trajectories by creating new
    experiences out of them: a state actually reached along a trajectory is relabelled
    as the goal. This module is a wrapper that includes the following modules:
        - SubGoalSampler: Creates new trajectories by sampling future subgoals from the same trajectory.
        - SubGoalAssigner: Assigns the subgoal to the trajectory according to a given subgoal index.
        - RewardTransform: Assigns the reward to the trajectory according to the new subgoal.

    Args:
        SubGoalSampler (Transform): samples the subgoal indices for each trajectory.
        SubGoalAssigner (Transform): overwrites the desired goal with the sampled subgoal.
        RewardTransform (Transform): adjusts the reward of the relabelled trajectories.
        assign_subgoal_idxs (bool): if True, the sampled subgoal index is also written into
            the augmented trajectories under the sampler's `subgoal_idx_key`. Defaults to False.
    """

    def __init__(
        self,
        SubGoalSampler: Transform = HERSubGoalSampler(),
        SubGoalAssigner: Transform = HERSubGoalAssigner(),
        RewardTransform: Transform = HERRewardTransform(),
        assign_subgoal_idxs: bool = False,
    ):
        super().__init__(
            in_keys=None,
            in_keys_inv=None,
            out_keys_inv=None,
        )
        self.SubGoalSampler = SubGoalSampler
        self.SubGoalAssigner = SubGoalAssigner
        self.RewardTransform = RewardTransform
        self.assign_subgoal_idxs = assign_subgoal_idxs
    def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase:
        # Called when data is written into a replay buffer: the stored batch is the
        # original trajectories concatenated with their HER-relabelled copies.
        augmentation_td = self.her_augmentation(tensordict)
        return torch.cat([tensordict, augmentation_td], dim=0)

    def _inv_apply_transform(self, tensordict: TensorDictBase) -> TensorDictBase:
        return self.her_augmentation(tensordict)

    def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
        return tensordict

    def _call(self, tensordict: TensorDictBase) -> TensorDictBase:
        # This transform is meant to be used as an inverse transform (e.g. attached to
        # a replay buffer), not appended to an environment.
        raise ValueError(
            "HindsightExperienceReplayTransform can only be used as an inverse transform, "
            "not within an environment step."
        )
    def her_augmentation(self, trajectories: TensorDictBase):
        if len(trajectories.shape) == 1:
            trajectories = trajectories.unsqueeze(0)
        batch_size, trajectory_length = trajectories.shape
        new_trajectories = trajectories.clone(True)

        # Sample subgoal indices for every trajectory in the batch.
        subgoal_idxs = self.SubGoalSampler(new_trajectories)

        # Create one relabelled copy of the trajectory per sampled subgoal.
        augmented_trajectories = []
        list_idxs = []
        for i in range(batch_size):
            idxs = subgoal_idxs[i][self.SubGoalSampler.subgoal_idx_key]
            if "masks" in subgoal_idxs.keys():
                # pad_sequence may have padded the subgoal indices; keep only the valid ones.
                idxs = idxs[subgoal_idxs[i]["masks", self.SubGoalSampler.subgoal_idx_key]]
            list_idxs.append(idxs.unsqueeze(-1))
            new_traj = new_trajectories[i].expand((idxs.numel(), trajectory_length)).clone(True)
            if self.assign_subgoal_idxs:
                new_traj[self.SubGoalSampler.subgoal_idx_key] = idxs.unsqueeze(-1).repeat(1, trajectory_length)
            augmented_trajectories.append(new_traj)
        augmented_trajectories = torch.cat(augmented_trajectories, dim=0)
        associated_idxs = torch.cat(list_idxs, dim=0)

        # Assign the subgoals to the new trajectories.
        augmented_trajectories = self.SubGoalAssigner.forward(augmented_trajectories, associated_idxs)

        # Adjust the rewards based on the new subgoals.
        augmented_trajectories = self.RewardTransform.forward(augmented_trajectories)

        return augmented_trajectories
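

# Usage sketch (illustrative, not part of the original gist). HER is applied here as an
# *inverse* transform: the augmentation runs when trajectories are written into a
# replay buffer, not when they are sampled from it. Assuming trajectories are stored
# as TensorDicts of shape [batch, time] with "achieved_goal", "desired_goal" and
# ("next", "done") entries, a replay-buffer setup could look like:
#
#     from torchrl.data import LazyTensorStorage, ReplayBuffer
#
#     her = HindsightExperienceReplayTransform(
#         SubGoalSampler=HERSubGoalSampler(num_samples=4, strategy="future"),
#         RewardTransform=HERSparseRewardTransform(),  # sketch defined above
#     )
#     rb = ReplayBuffer(storage=LazyTensorStorage(100_000), transform=her)
#     rb.extend(trajectory_batch)  # stores the originals plus their HER relabels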