dtsaras / distributed_rl_training.py
Created September 19, 2024 01:49
Distributed training with broadcasting the updated weights
import os
import random
import time

import torch
import torch.distributed as dist
import torch.distributed.rpc as rpc
import tqdm
from tensordict import TensorDict
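
A minimal sketch of what "broadcasting the updated weights" can look like with torch.distributed, assuming a trainer on rank 0 and collector workers on the other ranks; the model, backend, and launcher environment variables are illustrative assumptions, not the gist's actual setup.

import torch
import torch.distributed as dist
import torch.nn as nn


def broadcast_weights(model: nn.Module, src: int = 0) -> None:
    # Broadcast every tensor in the state dict from rank `src`; the tensors
    # share storage with the parameters, so all ranks end up synchronized.
    for param in model.state_dict().values():
        dist.broadcast(param, src=src)


def main() -> None:
    # RANK / WORLD_SIZE / MASTER_ADDR / MASTER_PORT are assumed to be set by
    # the launcher (e.g. torchrun).
    dist.init_process_group(backend="gloo")
    policy = nn.Linear(4, 2)  # stand-in for the policy network

    if dist.get_rank() == 0:
        # Trainer updates the weights (the actual training step is omitted)...
        with torch.no_grad():
            for p in policy.parameters():
                p.add_(0.01)

    # ...then every rank receives the trainer's parameters.
    broadcast_weights(policy, src=0)
    dist.destroy_process_group()


if __name__ == "__main__":
    main()
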
dtsaras / HER.py
Last active April 23, 2024 07:02
class HERSubGoalSampler(Transform):
    """Returns a TensorDict with a key `subgoal_idx` of shape [batch_size, num_samples] representing the subgoal index.

    Available strategies are: `last` and `future`. The `last` strategy assigns the last state as the subgoal.
    The `future` strategy samples up to `num_samples` subgoals from the future states.

    Args:
        num_samples (int): Number of subgoals to sample from each trajectory. Defaults to 4.
        out_keys (str): The key to store the subgoal index. Defaults to "subgoal_idx".
    """

    def __init__(
        self,
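
A minimal, standalone sketch of the sampling logic the docstring describes; it is not the gist's Transform, and the function name and `trajectory_len` argument are illustrative assumptions.

import torch
from tensordict import TensorDict


def sample_subgoal_idx(
    trajectory_len: int,
    batch_size: int,
    num_samples: int = 4,
    strategy: str = "future",
) -> torch.Tensor:
    # Return subgoal indices of shape [batch_size, num_samples].
    if strategy == "last":
        # Every subgoal is the final state of the trajectory.
        return torch.full((batch_size, num_samples), trajectory_len - 1)
    if strategy == "future":
        # Sample indices of future states; a real implementation would sample
        # relative to each transition's position, here we sample uniformly.
        return torch.randint(0, trajectory_len, (batch_size, num_samples))
    raise ValueError(f"Unknown strategy: {strategy}")


# The result can then be stored under the `subgoal_idx` key of a TensorDict:
subgoal_idx = sample_subgoal_idx(trajectory_len=50, batch_size=8)
td = TensorDict({"subgoal_idx": subgoal_idx}, batch_size=[8])
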
dtsaras / mcts_policy.py
Last active October 16, 2024 05:16
AlphaZero TorchRL Implementation
import copy
from abc import abstractmethod
from dataclasses import dataclass
from typing import Iterable, List, Optional

import torch
from torch.distributions.dirichlet import _Dirichlet
from tensordict import NestedKey, TensorDict, TensorDictBase
from tensordict.nn import TensorDictModule, TensorDictSequential
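
A minimal sketch of two AlphaZero-style ingredients the imports above hint at: mixing Dirichlet noise into the root prior and the PUCT score used for child selection. The function names and the alpha/epsilon/c_puct values are illustrative assumptions, not the gist's API, and the public `Dirichlet` distribution is used here instead of the private `_Dirichlet`.

import torch
from torch.distributions.dirichlet import Dirichlet


def add_root_dirichlet_noise(
    prior: torch.Tensor, alpha: float = 0.3, epsilon: float = 0.25
) -> torch.Tensor:
    # Blend the network's prior over root actions with Dirichlet noise,
    # as done in AlphaZero to encourage exploration at the root.
    noise = Dirichlet(torch.full_like(prior, alpha)).sample()
    return (1.0 - epsilon) * prior + epsilon * noise


def puct_score(
    q_value: torch.Tensor,      # mean value estimate of each child
    prior: torch.Tensor,        # prior probability of each action
    visit_count: torch.Tensor,  # visit count of each child
    c_puct: float = 1.25,
) -> torch.Tensor:
    # Q(s, a) + c_puct * P(s, a) * sqrt(sum_b N(s, b)) / (1 + N(s, a))
    total_visits = visit_count.sum().to(prior.dtype)
    exploration = c_puct * prior * total_visits.sqrt() / (1.0 + visit_count)
    return q_value + exploration
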