AlphaZero TorchRL Implementation
from abc import abstractmethod
import copy
from typing import List, Optional, Iterable

import torch
from torch.distributions.dirichlet import _Dirichlet
from tensordict import TensorDictBase, TensorDict, NestedKey
from tensordict.nn import TensorDictModule, TensorDictSequential

# noinspection PyProtectedMember
from tensordict.nn.common import TensorDictModuleBase
from torchrl.envs import EnvBase
from torchrl.envs.utils import exploration_type, ExplorationType, set_exploration_type
from torchrl.objectives.value import ValueEstimatorBase, TDLambdaEstimator
from torchrl.objectives.value.functional import reward2go

from .tensordict_map import MCTS_node, TensorDictMap
class ActionExplorationModule(TensorDictModuleBase):
    def __init__(
        self,
        action_key: str = "action",
        action_count_key: str = "children_visits",
        action_value_under_uncertainty_key: str = "scores",
    ):
        self.in_keys = [action_value_under_uncertainty_key, action_count_key]
        self.out_keys = [action_key]
        super().__init__()
        self.action_value_key = action_value_under_uncertainty_key
        self.action_key = action_key
        self.action_count_key = action_count_key

    def forward(self, tensordict: TensorDictBase, node: MCTS_node) -> TensorDictBase:
        tensordict = tensordict.clone(False)
        if exploration_type() == ExplorationType.RANDOM or exploration_type() is None:
            tensordict[self.action_key] = self.explore_action(node)
        elif exploration_type() == ExplorationType.MODE:
            tensordict[self.action_key] = self.get_greedy_action(node)
        return tensordict

    def get_greedy_action(self, node: MCTS_node) -> torch.Tensor:
        # greedy action: the most visited child
        action = torch.argmax(node[self.action_count_key])
        # return torch.nn.functional.one_hot(action, node[action_cnt_key].shape[-1])
        return action

    def explore_action(self, node: MCTS_node) -> torch.Tensor:
        # pick (uniformly at random) one of the actions with the maximal score
        action_score = node[self.action_value_key]
        max_value = torch.max(action_score)
        action = torch.argmax(
            torch.rand_like(action_score) * (action_score == max_value)
        )
        # return torch.nn.functional.one_hot(action, action_value.shape[-1])
        return action
class UpdateTreeStrategy:
    """
    The strategy used to update the tree after each rollout. It computes a reward-to-go
    target from the final reward of the rollout (or from a value-network bootstrap if the
    rollout did not terminate), updates the running mean value of every node visited during
    the rollout, and increments their visit counts.

    Args:
        value_network: a TensorDictModule used to bootstrap the value of the last state
            when the rollout did not terminate.
        action_key: the key in the rollout TensorDict that stores the selected action.
        use_value_network: whether to bootstrap with ``value_network`` or use a value of 0.
    """

    # noinspection PyTypeChecker
    def __init__(
        self,
        value_network: TensorDictModuleBase,
        action_key: NestedKey = "action",
        use_value_network: bool = True,
        # value_estimator: Optional[ValueEstimatorBase] = None,
    ):
        self.action_key = action_key
        self.value_network = value_network
        self.root: MCTS_node
        self.use_value_network = use_value_network
        # self.value_estimator = value_estimator or self.get_default_value_network(root)

    def update(self, rollout: TensorDictBase) -> None:
        # one target per step, plus one extra slot for the terminal reward or bootstrap value
        target_value = torch.zeros(rollout.batch_size[-1] + 1, dtype=torch.float32)
        done = torch.zeros_like(target_value, dtype=torch.bool)
        done[-1] = True
        if rollout[("next", "done")][-1]:
            target_value[-1] = rollout[("next", "reward")][-1]
        else:
            if self.use_value_network:
                target_value[-1] = self.value_network(rollout[-1]["next"])["state_value"]
            else:
                target_value[-1] = 0
        target_value = reward2go(target_value, done, gamma=0.99, time_dim=-1)

        # walk down the tree along the taken actions and update the running statistics
        node = self.root
        for idx in range(rollout.batch_size[-1]):
            action = rollout[self.action_key][idx]
            node = node.get_child(action)
            node.value = (
                node.value * node.visits + target_value[idx]
            ) / (node.visits + 1)
            node.visits += 1

    def start_simulation(self, device=None) -> None:
        self.root = MCTS_node.root(device)
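# A small worked example of the backup above (an illustrative sketch, not part of the
# original file): for a 3-step rollout that ends in a terminal reward of 1.0, the buffer
# passed to reward2go is target_value = [0, 0, 0, 1] with done = [F, F, F, T], so with
# gamma = 0.99 the targets become
#   [0.99**3, 0.99**2, 0.99, 1.0] ≈ [0.9703, 0.9801, 0.99, 1.0]
# and the node visited after the action at step idx has its running mean updated with
# target_value[idx]. Intermediate rewards from the rollout are not used; only the final
# reward (or the value-network bootstrap) is propagated back through the tree.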
class ExpansionStrategy(TensorDictModuleBase):
    """
    The rollout policy used when expanding the tree.
    This policy is used to initialize a node the first time it is expanded.
    """

    def __init__(
        self,
        out_keys: List[str],
        in_keys: Optional[List[str]] = None,
    ):
        self.in_keys = in_keys  # type: ignore
        self.out_keys = out_keys  # type: ignore
        super().__init__()

    def forward(self, node: MCTS_node) -> TensorDictBase:
        """
        Expands the node if it has not been expanded yet.

        Args:
            node: the node (state) that needs to be expanded.

        Returns:
            The node, with its statistics initialized for future action selection.
        """
        if not node.expanded:
            self.expand(node)
        return node

    @abstractmethod
    def expand(self, node: MCTS_node) -> None:
        pass

    def set_node(self, node: MCTS_node) -> None:
        self.node = node
class AlphaZeroExpansionStrategy(ExpansionStrategy):
    """
    An AlphaZero-style expansion strategy that initializes a node at its first visit.

    Args:
        policy_module: a TensorDictModule that produces the prior over actions P(s, a).
        module_action_value_key: the key in the output of ``policy_module`` that contains
            the prior action values.
    """

    def __init__(
        self,
        policy_module: TensorDictModule,
        action_count_key: str = "children_visits",
        action_value_key: str = "children_values",
        module_action_value_key: str = "action_value",
        prior_action_value_key: str = "children_priors",
    ):
        super().__init__(
            in_keys=policy_module.in_keys,
            out_keys=[
                action_value_key,
                prior_action_value_key,
                action_count_key,
                module_action_value_key,
            ],
        )
        assert module_action_value_key in policy_module.out_keys
        self.policy_module = policy_module
        self.action_value_key = module_action_value_key
        self.q_sa_key = action_value_key
        self.p_sa_key = prior_action_value_key
        self.n_sa_key = action_count_key

    def expand(self, node: MCTS_node) -> None:
        policy_network_td = node["state"].select(*self.policy_module.in_keys)
        policy_network_td = self.policy_module(policy_network_td)
        p_sa = policy_network_td[self.action_value_key]
        node.set(self.p_sa_key, p_sa)  # prior_action_value
        node.set(self.q_sa_key, torch.zeros_like(p_sa))  # action_value
        node.set(self.n_sa_key, torch.zeros_like(p_sa))  # action_count
class PuctSelectionPolicy(TensorDictModuleBase):
    """
    Computes the optimism-under-uncertainty score of each action using the PUCT formula
    from the AlphaZero paper:
    https://discovery.ucl.ac.uk/id/eprint/10069050/1/alphazero_preprint.pdf

    Args:
        cpuct: a constant that controls the amount of exploration.
        action_value_key: an input key, the mean of Q(s, a) for every action `a` at state `s`.
        prior_action_value_key: an input key, the prior P(s, a) for every action `a` at state `s`.
        action_count_key: an input key, the number of times action `a` was selected at state `s`.
        action_value_under_uncertainty_key: an output key, the PUCT score of each action.
    """

    def __init__(
        self,
        cpuct: float = 1.0,
        action_value_under_uncertainty_key: str = "scores",
        action_value_key: str = "children_values",
        prior_action_value_key: str = "children_priors",
        action_count_key: str = "children_visits",
    ):
        self.in_keys = [action_value_key, action_count_key, prior_action_value_key]
        self.out_keys = [action_value_under_uncertainty_key]
        super().__init__()
        self.cpuct = cpuct
        self.action_value_key = action_value_key
        self.prior_action_value_key = prior_action_value_key
        self.action_count_key = action_count_key
        self.action_value_under_uncertainty_key = action_value_under_uncertainty_key
        self.node: MCTS_node

    def forward(self, node: MCTS_node) -> TensorDictBase:
        # We always add 1 to the total visit count to avoid zero U values on the first
        # visit of a node. See:
        # https://ai.stackexchange.com/questions/25451/how-does-alphazeros-mcts-work-when-starting-from-the-root-node
        # for a discussion on this topic.
        # TODO: investigate the MuZero paper, the AlphaZero paper and "Bandit based
        # Monte-Carlo planning" to understand what the right implementation is. Also check
        # this discussion:
        # https://groups.google.com/g/computer-go-archive/c/K9XHb64JSqU
        n = torch.sum(node[self.action_count_key], dim=-1) + 1
        u_sa = (
            self.cpuct
            * node[self.prior_action_value_key]
            * torch.sqrt(n)
            / (1 + node[self.action_count_key])
        )
        optimism_estimation = node[self.action_value_key] + u_sa
        node[self.action_value_under_uncertainty_key] = optimism_estimation
        return node

    def set_node(self, node: MCTS_node) -> None:
        self.node = node
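# A worked example of the PUCT score above (an illustrative sketch, not part of the
# original module): with priors P(s, .) = [0.5, 0.3, 0.2], visit counts N(s, .) = [3, 1, 0],
# mean values Q(s, .) = [0.2, 0.1, 0.0] and cpuct = 1.0, the total count is
# n = 3 + 1 + 0 + 1 = 5, so
#   U(s, a) = cpuct * P(s, a) * sqrt(n) / (1 + N(s, a)) ≈ [0.280, 0.335, 0.447]
# and the selection scores Q(s, a) + U(s, a) ≈ [0.480, 0.435, 0.447]: the most visited
# action still scores highest, but the unvisited one is close behind thanks to its
# exploration bonus.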
class DirichletNoiseModule(TensorDictModuleBase):
    """
    Mixes Dirichlet noise into the prior action probabilities of a node, as done for the
    root node in AlphaZero to encourage exploration:
    P(s, a) <- (1 - epsilon) * P(s, a) + epsilon * Dir(alpha).
    """

    def __init__(
        self,
        alpha: float = 0.3,
        epsilon: float = 0.25,
        prior_action_value_key: str = "children_priors",
    ):
        self.in_keys = [prior_action_value_key]
        self.out_keys = [prior_action_value_key]
        super().__init__()
        self.alpha = alpha
        self.epsilon = epsilon
        self.prior_action_value_key = prior_action_value_key

    def forward(self, node: MCTS_node) -> TensorDictBase:
        p_sa = node[self.prior_action_value_key]
        if p_sa.device.type == "mps":
            # work around MPS by sampling the Dirichlet noise on CPU and moving it back
            device = p_sa.device
            noise = _Dirichlet.apply(self.alpha * torch.ones_like(p_sa).cpu())
            noise = noise.to(device)  # type: ignore
        else:
            noise = _Dirichlet.apply(self.alpha * torch.ones_like(p_sa))
        p_sa = (1 - self.epsilon) * p_sa + self.epsilon * noise  # type: ignore
        node[self.prior_action_value_key] = p_sa
        return node
class MctsPolicy(TensorDictSequential):
    """
    An implementation of the MCTS search policy used inside each simulated rollout.

    Args:
        expansion_strategy: a policy that initializes the stats of a node at its first visit.
        selection_strategy: a policy that scores the actions of the current node.
        exploration_strategy: a policy that trades off exploration vs. exploitation when
            picking the action.
        batch_size: the number of trees searched in parallel.
    """

    def __init__(
        self,
        expansion_strategy: ExpansionStrategy,
        selection_strategy: TensorDictModuleBase = PuctSelectionPolicy(),
        exploration_strategy: ActionExplorationModule = ActionExplorationModule(),
        batch_size: int = 1,
    ):
        super().__init__()
        self.expansion_strategy = expansion_strategy
        self.selection_strategy = selection_strategy
        self.exploration_strategy = exploration_strategy
        self.node: MCTS_node
        self.batch_size = batch_size

    def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
        # expand the current node on its first visit, then score its actions and pick one
        if not self.node.expanded:
            self.node["state"] = tensordict
            self.expansion_strategy.forward(self.node)
        self.selection_strategy.forward(self.node)
        tensordict = self.exploration_strategy.forward(tensordict, self.node)

        # descend to the child that corresponds to the selected action
        batched_nodes = []
        if self.batch_size > 1:
            for i in range(self.batch_size):
                node = self.node[i]
                if not tensordict[i]["terminated"]:
                    node = node.get_child(tensordict[i]["action"])
                batched_nodes.append(node)
            self.set_node(torch.stack(batched_nodes))  # type: ignore
        else:
            self.set_node(self.node.get_child(tensordict["action"]))
        return tensordict

    def set_node(self, node: MCTS_node) -> None:
        self.node = node
class SimulatedSearchPolicy(TensorDictModuleBase):
    """
    A simulated search policy. In each step, it simulates ``num_simulations`` rollouts of at
    most ``simulation_max_steps`` steps using the given policy, and then chooses the best
    action given the simulation results.

    Args:
        policy: a policy used to select actions in each simulated rollout.
        tree_updater: the strategy used to update the tree after each rollout.
        env: an environment used to simulate a rollout.
        num_simulations: the number of simulations per step.
        simulation_max_steps: the maximum number of steps of each simulated rollout.
        max_steps: the maximum number of steps performed by SimulatedSearchPolicy.
        noise_module: a module that injects noise into the root node priors for exploration.
    """

    def __init__(
        self,
        policy: MctsPolicy,
        tree_updater: UpdateTreeStrategy,
        env: EnvBase,
        num_simulations: int,
        simulation_max_steps: int,
        max_steps: int,
        noise_module: Optional[DirichletNoiseModule] = DirichletNoiseModule(),
    ):
        self.in_keys = policy.in_keys
        self.out_keys = policy.out_keys
        super().__init__()
        self.policy = policy
        self.tree_updater = tree_updater
        self.env = env
        self.num_simulations = num_simulations
        self.simulation_max_steps = simulation_max_steps
        self.noise_module = noise_module
        self.root_list = []
        self.init_state: TensorDict

    def forward(self, tensordict: TensorDictBase):
        tensordict = tensordict.clone(False)
        with torch.no_grad():
            self.start_simulation(tensordict)

            # run the simulations with a stochastic (exploratory) tree policy
            with set_exploration_type(ExplorationType.RANDOM):
                for i in range(self.num_simulations):
                    self.simulate()

            # then pick the final action greedily from the root statistics
            with set_exploration_type(ExplorationType.MODE):
                root = self.tree_updater.root
                tensordict = self.policy.exploration_strategy(tensordict, root)

            self.root_list.append(root)
            return tensordict

    def simulate(self) -> None:
        self.reset_simulation()
        rollout = self.env.rollout(
            max_steps=self.simulation_max_steps,
            policy=self.policy,
            return_contiguous=False,
        )
        # reset the environment to the original state
        self.env.set_state(self.init_state.clone(True))  # type: ignore
        # update the nodes visited during the simulation
        self.tree_updater.update(rollout)  # type: ignore

    def start_simulation(self, tensordict) -> None:
        # create a new root node for the MCTS tree
        self.tree_updater.start_simulation(tensordict.device)
        # make a copy of the initial state
        self.init_state = self.env.copy_state()
        # initialize and expand the root
        self.tree_updater.root["state"] = tensordict
        self.policy.expansion_strategy(self.tree_updater.root)
        # inject Dirichlet noise into the root priors for exploration
        if self.noise_module is not None:
            self.noise_module(self.tree_updater.root)

    def reset_simulation(self) -> None:
        # reset the root's priors
        # self.tree_updater.root["children_priors"] = self.root_priors
        # reset the policy node to the root
        if self.policy.batch_size > 1:
            self.policy.set_node(self.tree_updater.root.expand(self.policy.batch_size))
        else:
            self.policy.set_node(self.tree_updater.root)
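A minimal wiring sketch for these components, assuming a hypothetical environment `env` that implements the custom `copy_state()`/`set_state()` methods relied on by `SimulatedSearchPolicy`, plus toy policy and value networks; `obs_dim`, `n_actions`, `policy_net` and `value_net` are illustrative names, not part of the gist:

import torch.nn as nn
from tensordict.nn import TensorDictModule

obs_dim, n_actions = 64, 4  # illustrative sizes

# prior over actions P(s, a), read by AlphaZeroExpansionStrategy via "action_value"
policy_net = TensorDictModule(
    nn.Sequential(nn.Linear(obs_dim, n_actions), nn.Softmax(dim=-1)),
    in_keys=["observation"],
    out_keys=["action_value"],
)
# state-value estimate, read by UpdateTreeStrategy via "state_value"
value_net = TensorDictModule(
    nn.Linear(obs_dim, 1),
    in_keys=["observation"],
    out_keys=["state_value"],
)

mcts_policy = MctsPolicy(
    expansion_strategy=AlphaZeroExpansionStrategy(policy_module=policy_net),
    selection_strategy=PuctSelectionPolicy(cpuct=1.0),
)
search_policy = SimulatedSearchPolicy(
    policy=mcts_policy,
    tree_updater=UpdateTreeStrategy(value_network=value_net),
    env=env,
    num_simulations=100,
    simulation_max_steps=50,
    max_steps=200,
)

# one environment step driven by the search
td = env.reset()
td = search_policy(td)
td = env.step(td)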
tensordict_map.py
from __future__ import annotations

import torch
from tensordict import TensorDict


class MCTS_node(TensorDict):
    def __init__(self, action: int | torch.Tensor, parent: MCTS_node | None, device=None):
        super().__init__(
            {
                "children_values": torch.tensor([]),
                "children_priors": torch.tensor([]),
                "children_visits": torch.tensor([]),
                "score": torch.tensor([]),
                "children": TensorDict({}, batch_size=[], device=device),
                "state": TensorDict({}, batch_size=[], device=device),
                "truncated": torch.tensor([False]),
            },  # type: ignore
            batch_size=[],
            device=device,
        )
        self.prior_action: int | torch.Tensor = action
        self.parent: MCTS_node | None = parent

    @property
    def visits(self) -> torch.Tensor:
        # visit counts are stored on the parent, indexed by the action that led here
        assert self.parent is not None
        return self.parent["children_visits"][self.prior_action]

    @visits.setter
    def visits(self, x) -> None:
        assert self.parent is not None
        self.parent["children_visits"][self.prior_action] = x

    @property
    def value(self) -> torch.Tensor:
        # mean action values are stored on the parent, indexed by the action that led here
        assert self.parent is not None
        return self.parent["children_values"][self.prior_action]

    @value.setter
    def value(self, x) -> None:
        assert self.parent is not None
        self.parent["children_values"][self.prior_action] = x

    @property
    def expanded(self) -> bool:
        return self["children_priors"].numel() > 0

    def get_child(self, action: int | torch.Tensor) -> MCTS_node:
        # children are stored in a nested TensorDict, keyed by the stringified action
        action_str = str(torch.sym_int(action))
        if action_str not in self["children"].keys(leaves_only=True):
            self["children"][action_str] = MCTS_node(action, self, self.device)
        return self["children"][action_str]

    @classmethod
    def root(cls, device=None) -> MCTS_node:
        return cls(-1, None, device)
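A short usage sketch of the tree node above (illustrative only; the two-action setup, priors and values here are made up and are not part of the original gist):

# build a tiny two-action tree by hand, using the MCTS_node class defined above
root = MCTS_node.root()
root.set("children_priors", torch.tensor([0.6, 0.4]))
root.set("children_values", torch.zeros(2))
root.set("children_visits", torch.zeros(2))

child = root.get_child(torch.tensor(0))  # lazily creates the node for action 0
child.visits += 1                        # writes root["children_visits"][0]
child.value = torch.tensor(0.5)          # writes root["children_values"][0]

print(root["children_visits"])           # tensor([1., 0.])
print(root.expanded, child.expanded)     # True False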