AlphaZero TorchRL Implementation
from abc import abstractmethod
import copy
from typing import List, Optional, Iterable
import torch
from torch.distributions.dirichlet import _Dirichlet
from tensordict import TensorDictBase, TensorDict, NestedKey
from tensordict.nn import TensorDictModule, TensorDictSequential
# noinspection PyProtectedMember
from tensordict.nn.common import TensorDictModuleBase
from torchrl.envs import EnvBase
from torchrl.envs.utils import exploration_type, ExplorationType, set_exploration_type
from torchrl.objectives.value import ValueEstimatorBase, TDLambdaEstimator
from torchrl.objectives.value.functional import reward2go
from .tensordict_map import MCTS_node, TensorDictMap
class ActionExplorationModule(TensorDictModuleBase):
def __init__(
self,
action_key: str = "action",
action_count_key: str = "children_visits",
action_value_under_uncertainty_key: str = "scores",
):
self.in_keys = [action_value_under_uncertainty_key, action_count_key]
self.out_keys = [action_key]
super().__init__()
self.action_value_key = action_value_under_uncertainty_key
self.action_key = action_key
self.action_count_key = action_count_key
def forward(self, tensordict: TensorDictBase, node: MCTS_node) -> TensorDictBase:
tensordict = tensordict.clone(False)
if exploration_type() == ExplorationType.RANDOM or exploration_type() is None:
tensordict[self.action_key] = self.explore_action(node)
elif exploration_type() == ExplorationType.MODE:
tensordict[self.action_key] = self.get_greedy_action(node)
return tensordict
def get_greedy_action(self, node: MCTS_node) -> torch.Tensor:
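        # greedy (MODE) action: the most-visited child, as in AlphaZero at evaluation time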
action = torch.argmax(node[self.action_count_key])
# return torch.nn.functional.one_hot(action, node[action_cnt_key].shape[-1])
return action
def explore_action(self, node: MCTS_node) -> torch.Tensor:
action_score = node[self.action_value_key]
max_value = torch.max(action_score)
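        # break ties uniformly at random among the actions that achieve the maximum score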
action = torch.argmax(
torch.rand_like(action_score) * (action_score == max_value)
)
# return torch.nn.functional.one_hot(action, action_value.shape[-1])
return action
class UpdateTreeStrategy:
"""
The strategy to update tree after each rollout. This class uses the given value estimator
to compute a target value after each roll out and compute the mean of target values in the tree.
It also updates the number of time nodes get visited in tree.
Args:
tree: A TensorDictMap that store stats of the tree.
value_estimator: A ValueEstimatorBase that compute target value.
action_key: A key in the rollout TensorDict to store the selected action.
action_value_key: A key in the tree nodes that stores the mean of Q(s, a).
action_count_key: A key in the tree nodes that stores the number of times nodes get visited.
"""
# noinspection PyTypeChecker
def __init__(
self,
value_network: TensorDictModuleBase,
action_key: NestedKey = "action",
use_value_network: bool = True,
# value_estimator: Optional[ValueEstimatorBase] = None,
):
self.action_key = action_key
self.value_network = value_network
self.root: MCTS_node
self.use_value_network = use_value_network
# self.value_estimator = value_estimator or self.get_default_value_network(root)
def update(self, rollout: TensorDictBase) -> None:
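        # Build the backup target: only the final step carries a value (the terminal
        # reward, or a value-network/zero bootstrap on truncation); `reward2go` then
        # discounts it backwards so step t receives gamma^(T - t) * v_T, which is
        # backed up along the path of visited nodes.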
target_value = torch.zeros(rollout.batch_size[-1]+1, dtype=torch.float32)
done = torch.zeros_like(target_value, dtype=torch.bool)
done[-1] = True
if rollout[("next", "done")][-1]:
target_value[-1] = rollout[("next", "reward")][-1]
else:
if self.use_value_network:
target_value[-1] = self.value_network(rollout[-1]["next"])["state_value"]
else:
target_value[-1] = 0
target_value = reward2go(target_value, done, gamma=0.99, time_dim=-1)
node = self.root
for idx in range(rollout.batch_size[-1]):
action = rollout[self.action_key][idx]
node = node.get_child(action)
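            # incremental running mean of the node value:
            # Q(s, a) <- (N(s, a) * Q(s, a) + G_t) / (N(s, a) + 1)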
node.value = (
node.value * node.visits + target_value[idx]
) / (node.visits + 1)
node.visits += 1
def start_simulation(self, device=None) -> None:
self.root = MCTS_node.root(device)
class ExpansionStrategy(TensorDictModuleBase):
"""
The rollout policy in expanding tree.
This policy will use to initialize a node when it gets expanded at the first time.
"""
def __init__(
self,
out_keys: List[str],
in_keys: Optional[List[str]] = None,
):
self.in_keys = in_keys #type: ignore
self.out_keys = out_keys #type: ignore
super().__init__()
def forward(self, node: MCTS_node) -> TensorDictBase:
"""
The node to be expanded. The output Tensordict will be used in future
to select action.
Args:
tensordict: The state that need to be explored
Returns:
A initialized statistics to select actions in the future.
"""
if not node.expanded:
self.expand(node)
return node
@abstractmethod
def expand(self, node: MCTS_node) -> None:
pass
def set_node(self, node: MCTS_node) -> None:
self.node = node
class AlphaZeroExpansionStrategy(ExpansionStrategy):
"""
An implementation of Alpha Zero to initialize a node at its first time.
Args:
value_module: a TensorDictModule to initialize a prior for Q(s, a)
module_action_value_key: a key in the output of value_module that contains Q(s, a) values
"""
def __init__(
self,
policy_module: TensorDictModule,
action_count_key: str = "children_visits",
action_value_key: str = "children_values",
module_action_value_key: str = "action_value",
prior_action_value_key: str = "children_priors",
):
        super().__init__(
            in_keys=policy_module.in_keys,
            out_keys=[
                action_value_key,
                prior_action_value_key,
                action_count_key,
                module_action_value_key,
            ],
        )
assert module_action_value_key in policy_module.out_keys
self.policy_module = policy_module
self.action_value_key = module_action_value_key
self.q_sa_key = action_value_key
self.p_sa_key = prior_action_value_key
self.n_sa_key = action_count_key
def expand(self, node: MCTS_node) -> None:
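        # Evaluate the policy network on this node's state to obtain the prior P(s, a),
        # then zero-initialize Q(s, a) and N(s, a) for every action.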
        policy_network_td = node["state"].select(*self.policy_module.in_keys)
        policy_network_td = self.policy_module(policy_network_td)
        p_sa = policy_network_td[self.action_value_key]
node.set(self.p_sa_key, p_sa) # prior_action_value
node.set(self.q_sa_key, torch.zeros_like(p_sa)) # action_value
node.set(self.n_sa_key, torch.zeros_like(p_sa)) # action_count
class PuctSelectionPolicy(TensorDictModuleBase):
"""
The optimism under uncertainty estimation computed by the PUCT formula in AlphaZero paper:
https://discovery.ucl.ac.uk/id/eprint/10069050/1/alphazero_preprint.pdf
Args:
cpuct: A constant to control exploration
action_value_key: an input key, representing the mean of Q(s, a) for every action `a` at state `s`.
prior_action_value_key: an input key, representing the prior of Q(s, a) for every action `a` at state `s`.
action_count_key: an input key, representing the number of times action `a` is selected at state `s`.
action_value_under_uncertainty_key: an output key, representing the output estimate value using PUCT
"""
def __init__(
self,
cpuct: float = 1.0,
action_value_under_uncertainty_key: str = "scores",
action_value_key: str = "children_values",
prior_action_value_key: str = "children_priors",
action_count_key: str = "children_visits",
):
self.in_keys = [action_value_key, action_count_key, prior_action_value_key]
self.out_keys = [action_value_under_uncertainty_key]
super().__init__()
self.cpuct = cpuct
self.action_value_key = action_value_key
self.prior_action_value_key = prior_action_value_key
self.action_count_key = action_count_key
self.action_value_under_uncertainty_key = action_value_under_uncertainty_key
self.node: MCTS_node
def forward(self, node: MCTS_node) -> TensorDictBase:
# we will always add 1, to avoid zero U values in the first visit of the node. See:
# https://ai.stackexchange.com/questions/25451/how-does-alphazeros-mcts-work-when-starting-from-the-root-node
# for a discussion on this topic.
# TODO: investigate MuZero paper, AlphaZero paper and Bandit based monte-carlo planning to understand what
# is the right implementation. Also check this discussion:
# https://groups.google.com/g/computer-go-archive/c/K9XHb64JSqU
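        # PUCT: score(s, a) = Q(s, a) + c_puct * P(s, a) * sqrt(sum_b N(s, b)) / (1 + N(s, a)),
        # with the total visit count incremented by 1 as discussed above.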
n = torch.sum(node[self.action_count_key], dim=-1) + 1
        u_sa = (
            self.cpuct
            * node[self.prior_action_value_key]
            * torch.sqrt(n)
            / (1 + node[self.action_count_key])
        )
optimism_estimation = node[self.action_value_key] + u_sa
node[self.action_value_under_uncertainty_key] = optimism_estimation
return node
def set_node(self, node: MCTS_node) -> None:
self.node = node
class DirichletNoiseModule(TensorDictModuleBase):
def __init__(
self,
alpha: float = 0.3,
epsilon: float = 0.25,
prior_action_value_key: str = "children_priors",
):
self.in_keys = [prior_action_value_key]
self.out_keys = [prior_action_value_key]
super().__init__()
self.alpha = alpha
self.epsilon = epsilon
self.prior_action_value_key = prior_action_value_key
def forward(self, node: MCTS_node) -> TensorDictBase:
p_sa = node[self.prior_action_value_key]
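        # AlphaZero root-noise mix: P'(s, a) = (1 - epsilon) * P(s, a) + epsilon * Dir(alpha).
        # Dirichlet sampling is not supported on MPS here, so the noise is sampled on CPU
        # and moved back to the original device.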
if p_sa.device.type == "mps":
device = p_sa.device
noise = _Dirichlet.apply(self.alpha * torch.ones_like(p_sa).cpu())
noise = noise.to(device) #type: ignore
else:
noise = _Dirichlet.apply(self.alpha * torch.ones_like(p_sa))
p_sa = (1 - self.epsilon) * p_sa + self.epsilon * noise #type: ignore
node[self.prior_action_value_key] = p_sa
return node
class MctsPolicy(TensorDictSequential):
"""
An implementation of MCTS algorithm.
Args:
expansion_strategy: a policy to initialize stats of a node at its first visit.
selection_strategy: a policy to select action in each state
exploration_strategy: a policy to exploration vs exploitation
"""
def __init__(
self,
expansion_strategy: ExpansionStrategy,
selection_strategy: TensorDictModuleBase = PuctSelectionPolicy(),
exploration_strategy: ActionExplorationModule = ActionExplorationModule(),
batch_size: int = 1,
):
super().__init__()
self.expansion_strategy = expansion_strategy
self.selection_strategy = selection_strategy
self.exploration_strategy = exploration_strategy
self.node: MCTS_node
self.batch_size = batch_size
def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
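        # One in-tree search step: expand the node on its first visit, score its children
        # with the selection strategy (e.g. PUCT), pick an action with the exploration
        # strategy, then descend to the corresponding child (per batch element if batched).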
if not self.node.expanded:
self.node["state"] = tensordict
self.expansion_strategy.forward(self.node)
self.selection_strategy.forward(self.node)
tensordict = self.exploration_strategy.forward(self.node)
batched_nodes = []
if self.batch_size > 1:
for i in range(self.batch_size):
node = self.node[i]
if not tensordict[i]["terminated"]:
node = node.get_child(tensordict[i]["action"])
batched_nodes.append(node)
self.set_node(torch.stack(batched_nodes)) #type: ignore
else:
self.set_node(self.node.get_child(tensordict["action"]))
return tensordict
def set_node(self, node: MCTS_node) -> None:
self.node = node
class SimulatedSearchPolicy(TensorDictModuleBase):
"""
A simulated search policy. In each step, it simulates `n` rollout of maximum steps of `max_simulation_steps`
using the given policy and then choose the best action given the simulation results.
Args:
policy: a policy to select action in each simulation rollout.
env: an environment to simulate a rollout
num_simulation: the number of simulation
max_simulation_steps: the max steps of each simulated rollout
max_steps: the max steps performed by SimulatedSearchPolicy
noise_module: a module to inject noise in the root node for exploration
"""
def __init__(
self,
policy: MctsPolicy,
tree_updater: UpdateTreeStrategy,
env: EnvBase,
num_simulations: int,
simulation_max_steps: int,
max_steps: int,
noise_module: Optional[DirichletNoiseModule] = DirichletNoiseModule(),
):
self.in_keys = policy.in_keys
self.out_keys = policy.out_keys
super().__init__()
self.policy = policy
self.tree_updater = tree_updater
self.env = env
self.num_simulations = num_simulations
self.simulation_max_steps = simulation_max_steps
self.noise_module = noise_module
self.root_list = []
self.init_state: TensorDict
def forward(self, tensordict: TensorDictBase):
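        # Run `num_simulations` simulated rollouts from the current state (RANDOM
        # exploration inside the tree), then pick the root action greedily (MODE)
        # from the accumulated visit counts.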
tensordict = tensordict.clone(False)
with torch.no_grad():
self.start_simulation(tensordict)
with set_exploration_type(ExplorationType.RANDOM):
for i in range(self.num_simulations):
self.simulate()
with set_exploration_type(ExplorationType.MODE):
root = self.tree_updater.root
tensordict = self.policy.exploration_strategy(root)
self.root_list.append(root)
return tensordict
def simulate(self) -> None:
self.reset_simulation()
rollout = self.env.rollout(
max_steps=self.simulation_max_steps,
policy=self.policy,
return_contiguous=False,
)
        # reset the environment to the original state
self.env.set_state(self.init_state.clone(True)) # type: ignore
# update the nodes visited during the simulation
self.tree_updater.update(rollout) #type: ignore
def start_simulation(self, tensordict) -> None:
# creates new root node for the MCTS tree
self.tree_updater.start_simulation(tensordict.device)
# make a copy of the initial state
self.init_state = self.env.copy_state()
# initialize and expand the root
self.tree_updater.root["state"] = tensordict
self.policy.expansion_strategy(self.tree_updater.root)
# inject dirichlet noise for exploration
if self.noise_module is not None:
self.noise_module(self.tree_updater.root)
def reset_simulation(self) -> None:
# reset the root's priors
# self.tree_updater.root["children_priors"] = self.root_priors
# reset the policy node to the root
if self.policy.batch_size > 1:
self.policy.set_node(self.tree_updater.root.expand(self.policy.batch_size))
else:
self.policy.set_node(self.tree_updater.root)
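# ---------------------------------------------------------------------------
# Usage sketch (not part of the original gist). Assumes `MyGameEnv` is a
# user-provided EnvBase subclass implementing the non-standard `copy_state()` /
# `set_state()` hooks used by SimulatedSearchPolicy, and that `policy_net` /
# `value_net` are TensorDictModules writing "action_value" and "state_value":
#
#   env = MyGameEnv()
#   mcts_policy = MctsPolicy(
#       expansion_strategy=AlphaZeroExpansionStrategy(policy_module=policy_net),
#   )
#   tree_updater = UpdateTreeStrategy(value_network=value_net)
#   search_policy = SimulatedSearchPolicy(
#       policy=mcts_policy,
#       tree_updater=tree_updater,
#       env=env,
#       num_simulations=100,
#       simulation_max_steps=50,
#       max_steps=200,
#   )
#   rollout = env.rollout(max_steps=200, policy=search_policy)
# ---------------------------------------------------------------------------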
# ---------------------------------------------------------------------------
# tensordict_map.py (the module imported above via
# ``from .tensordict_map import MCTS_node, TensorDictMap``)
# ---------------------------------------------------------------------------
from __future__ import annotations
from collections.abc import MutableMapping
from typing import Mapping, TypeVar, Union, List, Sequence
import torch
from tensordict import TensorDict, TensorDictBase, NestedKey, tensorclass
class MCTS_node(TensorDict):
def __init__(self, action: int | torch.Tensor, parent: MCTS_node | None, device = None):
super().__init__(
{
"children_values": torch.tensor([]),
"children_priors": torch.tensor([]),
"children_visits": torch.tensor([]),
"score": torch.tensor([]),
"children": TensorDict({}, batch_size=[], device=device),
"state": TensorDict({}, batch_size=[], device=device),
"truncated": torch.tensor([False]),
}, # type: ignore
batch_size=[],
device=device,
)
self.prior_action: int | torch.Tensor = action
self.parent: MCTS_node | None = parent
    @property
    def visits(self) -> torch.Tensor:
        assert self.parent is not None
        return self.parent["children_visits"][self.prior_action]
    @visits.setter
    def visits(self, x) -> None:
        assert self.parent is not None
        self.parent["children_visits"][self.prior_action] = x
    @property
    def value(self) -> torch.Tensor:
        assert self.parent is not None
        return self.parent["children_values"][self.prior_action]
    @value.setter
    def value(self, x) -> None:
        assert self.parent is not None
        self.parent["children_values"][self.prior_action] = x
@property
def expanded(self) -> bool:
return self["children_priors"].numel() > 0
def get_child(self, action: int | torch.Tensor) -> MCTS_node:
action_str = str(torch.sym_int(action))
if action_str not in self["children"].keys(leaves_only=True):
self["children"][action_str] = MCTS_node(action, self, self.device)
return self["children"][action_str]
@classmethod
def root(cls, device=None) -> MCTS_node:
return cls(-1, None, device)
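# ---------------------------------------------------------------------------
# Structure sketch (not part of the original gist): each MCTS_node stores its
# children lazily under string action keys, while its own value/visit counters
# live in the parent's "children_values"/"children_visits" tensors. Assuming the
# node has already been expanded (so those tensors are populated, e.g. by
# AlphaZeroExpansionStrategy.expand):
#
#   root = MCTS_node.root()
#   # ... expand `root` so "children_values"/"children_visits" have one entry per action ...
#   child = root.get_child(torch.tensor(0))  # lazily created on first access
#   child.visits += 1                        # updates root["children_visits"][0]
#   child.value = torch.tensor(0.5)          # updates root["children_values"][0]
# ---------------------------------------------------------------------------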