This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import random | |
import time | |
import torch | |
import torch.distributed.rpc as rpc | |
import tqdm | |
from tensordict import TensorDict | |
import torch.distributed as dist |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class HERSubGoalSampler(Transform): | |
"""Returns a TensorDict with a key `subgoal_idx` of shape [batch_size, num_samples] represebting the subgoal index. | |
Available strategies are: `last` and `future`. The `last` strategy assigns the last state as the subgoal. The `future` strategy samples up to `num_samples` subgoal from the future states. | |
Args: | |
num_samples (int): Number of subgoals to sample from each trajectory. Defaults to 4. | |
out_keys (str): The key to store the subgoal index. Defaults to "subgoal_idx". | |
""" | |
def __init__( | |
self, |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from abc import abstractmethod | |
import copy | |
from dataclasses import dataclass | |
from hmac import new | |
from typing import List, Optional, Iterable | |
import torch | |
from torch.distributions.dirichlet import _Dirichlet | |
from tensordict import TensorDictBase, TensorDict, NestedKey | |
from tensordict.nn import TensorDictModule, TensorDictSequential |