Skip to content

Instantly share code, notes, and snippets.

@Mononofu
Last active July 14, 2024 11:02
Show Gist options
  • Save Mononofu/6c2d27ea1b3a9b3c1a293ebabed062ed to your computer and use it in GitHub Desktop.
Save Mononofu/6c2d27ea1b3a9b3c1a293ebabed062ed to your computer and use it in GitHub Desktop.
MuZero pseudocode
# Copyright 2022 DeepMind Technologies Limited.
# Licensed under the Apache License, Version 2.0 and CC BY 4.0.
# You may not use this file except in compliance with these licenses.
# Copies of the licenses can be found at https://www.apache.org/licenses/LICENSE-2.0
# and https://creativecommons.org/licenses/by/4.0/legalcode.
"""Pseudocode description of the MuZero algorithm."""
# pylint: disable=unused-argument
# pylint: disable=missing-docstring
# pylint: disable=g-explicit-length-test
import collections
import math
import typing
from typing import Any, Dict, List, Optional
import numpy
import tensorflow.compat.v1 as tf
##########################
####### Helpers ##########
MAXIMUM_FLOAT_VALUE = float('inf')
KnownBounds = collections.namedtuple('KnownBounds', ['min', 'max'])
class MinMaxStats(object):
"""A class that holds the min-max values of the tree."""
def __init__(self, known_bounds: Optional[KnownBounds]):
self.maximum = known_bounds.max if known_bounds else -MAXIMUM_FLOAT_VALUE
self.minimum = known_bounds.min if known_bounds else MAXIMUM_FLOAT_VALUE
def update(self, value: float):
self.maximum = max(self.maximum, value)
self.minimum = min(self.minimum, value)
def normalize(self, value: float) -> float:
if self.maximum > self.minimum:
# We normalize only when we have set the maximum and minimum values.
return (value - self.minimum) / (self.maximum - self.minimum)
return value
class MuZeroConfig(object):
def __init__(self,
action_space_size: int,
max_moves: int,
discount: float,
dirichlet_alpha: float,
num_simulations: int,
batch_size: int,
td_steps: int,
num_actors: int,
lr_init: float,
lr_decay_steps: float,
visit_softmax_temperature_fn,
known_bounds: Optional[KnownBounds] = None):
### Self-Play
self.action_space_size = action_space_size
self.num_actors = num_actors
self.visit_softmax_temperature_fn = visit_softmax_temperature_fn
self.max_moves = max_moves
self.num_simulations = num_simulations
self.discount = discount
# Root prior exploration noise.
self.root_dirichlet_alpha = dirichlet_alpha
self.root_exploration_fraction = 0.25
# UCB formula
self.pb_c_base = 19652
self.pb_c_init = 1.25
# If we already have some information about which values occur in the
# environment, we can use them to initialize the rescaling.
# This is not strictly necessary, but establishes identical behaviour to
# AlphaZero in board games.
self.known_bounds = known_bounds
### Training
self.training_steps = int(1000e3)
self.checkpoint_interval = int(1e3)
self.window_size = int(1e6)
self.batch_size = batch_size
self.num_unroll_steps = 5
self.td_steps = td_steps
self.weight_decay = 1e-4
self.momentum = 0.9
# Exponential learning rate schedule
self.lr_init = lr_init
self.lr_decay_rate = 0.1
self.lr_decay_steps = lr_decay_steps
def new_game(self):
return Game(self.action_space_size, self.discount)
def make_board_game_config(action_space_size: int, max_moves: int,
dirichlet_alpha: float,
lr_init: float) -> MuZeroConfig:
def visit_softmax_temperature(num_moves, training_steps):
if num_moves < 30:
return 1.0
else:
return 0.0 # Play according to the max.
return MuZeroConfig(
action_space_size=action_space_size,
max_moves=max_moves,
discount=1.0,
dirichlet_alpha=dirichlet_alpha,
num_simulations=800,
batch_size=2048,
td_steps=max_moves, # Always use Monte Carlo return.
num_actors=3000,
lr_init=lr_init,
lr_decay_steps=400e3,
visit_softmax_temperature_fn=visit_softmax_temperature,
known_bounds=KnownBounds(-1, 1))
def make_go_config() -> MuZeroConfig:
return make_board_game_config(
action_space_size=362, max_moves=722, dirichlet_alpha=0.03, lr_init=0.01)
def make_chess_config() -> MuZeroConfig:
return make_board_game_config(
action_space_size=4672, max_moves=512, dirichlet_alpha=0.3, lr_init=0.1)
def make_shogi_config() -> MuZeroConfig:
return make_board_game_config(
action_space_size=11259, max_moves=512, dirichlet_alpha=0.15, lr_init=0.1)
def make_atari_config() -> MuZeroConfig:
def visit_softmax_temperature(num_moves, training_steps):
if training_steps < 500e3:
return 1.0
elif training_steps < 750e3:
return 0.5
else:
return 0.25
return MuZeroConfig(
action_space_size=18,
max_moves=27000, # Half an hour at action repeat 4.
discount=0.997,
dirichlet_alpha=0.25,
num_simulations=50,
batch_size=1024,
td_steps=10,
num_actors=350,
lr_init=0.05,
lr_decay_steps=350e3,
visit_softmax_temperature_fn=visit_softmax_temperature)
class Action(object):
def __init__(self, index: int):
self.index = index
def __hash__(self):
return self.index
def __eq__(self, other):
return self.index == other.index
def __gt__(self, other):
return self.index > other.index
class Player(object):
pass
class Node(object):
def __init__(self, prior: float):
self.visit_count = 0
self.to_play = -1
self.prior = prior
self.value_sum = 0
self.children = {}
self.hidden_state = None
self.reward = 0
def expanded(self) -> bool:
return len(self.children) > 0
def value(self) -> float:
if self.visit_count == 0:
return 0
return self.value_sum / self.visit_count
class ActionHistory(object):
"""Simple history container used inside the search.
Only used to keep track of the actions executed.
"""
def __init__(self, history: List[Action], action_space_size: int):
self.history = list(history)
self.action_space_size = action_space_size
def clone(self):
return ActionHistory(self.history, self.action_space_size)
def add_action(self, action: Action):
self.history.append(action)
def last_action(self) -> Action:
return self.history[-1]
def action_space(self) -> List[Action]:
return [Action(i) for i in range(self.action_space_size)]
def to_play(self) -> Player:
return Player()
class Environment(object):
"""The environment MuZero is interacting with."""
def step(self, action):
pass
class Game(object):
"""A single episode of interaction with the environment."""
def __init__(self, action_space_size: int, discount: float):
self.environment = Environment() # Game specific environment.
self.history = []
self.rewards = []
self.child_visits = []
self.root_values = []
self.action_space_size = action_space_size
self.discount = discount
def terminal(self) -> bool:
# Game specific termination rules.
pass
def legal_actions(self) -> List[Action]:
# Game specific calculation of legal actions.
return []
def apply(self, action: Action):
reward = self.environment.step(action)
self.rewards.append(reward)
self.history.append(action)
def store_search_statistics(self, root: Node):
sum_visits = sum(child.visit_count for child in root.children.values())
action_space = (Action(index) for index in range(self.action_space_size))
self.child_visits.append([
root.children[a].visit_count / sum_visits if a in root.children else 0
for a in action_space
])
self.root_values.append(root.value())
def make_image(self, state_index: int):
# Game specific feature planes.
return []
def make_target(self, state_index: int, num_unroll_steps: int, td_steps: int,
to_play: Player):
# The value target is the discounted root value of the search tree N steps
# into the future, plus the discounted sum of all rewards until then.
targets = []
for current_index in range(state_index, state_index + num_unroll_steps + 1):
bootstrap_index = current_index + td_steps
if bootstrap_index < len(self.root_values):
value = self.root_values[bootstrap_index] * self.discount**td_steps
else:
value = 0
for i, reward in enumerate(self.rewards[current_index:bootstrap_index]):
value += reward * self.discount**i # pytype: disable=unsupported-operands
if current_index > 0 and current_index <= len(self.rewards):
last_reward = self.rewards[current_index - 1]
else:
last_reward = None
if current_index < len(self.root_values):
targets.append((value, last_reward, self.child_visits[current_index]))
else:
# States past the end of games are treated as absorbing states.
targets.append((0, last_reward, []))
return targets
def to_play(self) -> Player:
return Player()
def action_history(self) -> ActionHistory:
return ActionHistory(self.history, self.action_space_size)
class ReplayBuffer(object):
def __init__(self, config: MuZeroConfig):
self.window_size = config.window_size
self.batch_size = config.batch_size
self.buffer = []
def save_game(self, game):
if len(self.buffer) > self.window_size:
self.buffer.pop(0)
self.buffer.append(game)
def sample_batch(self, num_unroll_steps: int, td_steps: int):
games = [self.sample_game() for _ in range(self.batch_size)]
game_pos = [(g, self.sample_position(g)) for g in games]
return [(g.make_image(i), g.history[i:i + num_unroll_steps],
g.make_target(i, num_unroll_steps, td_steps, g.to_play()))
for (g, i) in game_pos]
def sample_game(self) -> Game:
# Sample game from buffer either uniformly or according to some priority.
return self.buffer[0]
def sample_position(self, game) -> int:
# Sample position from game either uniformly or according to some priority.
return -1
class NetworkOutput(typing.NamedTuple):
value: float
reward: float
policy_logits: Dict[Action, float]
hidden_state: List[float]
class Network(object):
def initial_inference(self, image) -> NetworkOutput:
# representation + prediction function
return NetworkOutput(0, 0, {}, [])
def recurrent_inference(self, hidden_state, action) -> NetworkOutput:
# dynamics + prediction function
return NetworkOutput(0, 0, {}, [])
def get_weights(self):
# Returns the weights of this network.
return []
def training_steps(self) -> int:
# How many steps / batches the network has been trained for.
return 0
class SharedStorage(object):
def __init__(self):
self._networks = {}
def latest_network(self) -> Network:
if self._networks:
return self._networks[max(self._networks.keys())]
else:
# policy -> uniform, value -> 0, reward -> 0
return make_uniform_network()
def save_network(self, step: int, network: Network):
self._networks[step] = network
##### End Helpers ########
##########################
# MuZero training is split into two independent parts: Network training and
# self-play data generation.
# These two parts only communicate by transferring the latest network checkpoint
# from the training to the self-play, and the finished games from the self-play
# to the training.
def muzero(config: MuZeroConfig):
storage = SharedStorage()
replay_buffer = ReplayBuffer(config)
for _ in range(config.num_actors):
launch_job(run_selfplay, config, storage, replay_buffer)
train_network(config, storage, replay_buffer)
return storage.latest_network()
##################################
####### Part 1: Self-Play ########
# Each self-play job is independent of all others; it takes the latest network
# snapshot, produces a game and makes it available to the training job by
# writing it to a shared replay buffer.
def run_selfplay(config: MuZeroConfig, storage: SharedStorage,
replay_buffer: ReplayBuffer):
while True:
network = storage.latest_network()
game = play_game(config, network)
replay_buffer.save_game(game)
# Each game is produced by starting at the initial board position, then
# repeatedly executing a Monte Carlo Tree Search to generate moves until the end
# of the game is reached.
def play_game(config: MuZeroConfig, network: Network) -> Game:
game = config.new_game()
while not game.terminal() and len(game.history) < config.max_moves:
min_max_stats = MinMaxStats(config.known_bounds)
# At the root of the search tree we use the representation function to
# obtain a hidden state given the current observation.
root = Node(0)
current_observation = game.make_image(-1)
network_output = network.initial_inference(current_observation)
expand_node(root, game.to_play(), game.legal_actions(), network_output)
backpropagate([root], network_output.value, game.to_play(), config.discount,
min_max_stats)
add_exploration_noise(config, root)
# We then run a Monte Carlo Tree Search using only action sequences and the
# model learned by the network.
run_mcts(config, root, game.action_history(), network, min_max_stats)
action = select_action(config, len(game.history), root, network)
game.apply(action)
game.store_search_statistics(root)
return game
# Core Monte Carlo Tree Search algorithm.
# To decide on an action, we run N simulations, always starting at the root of
# the search tree and traversing the tree according to the UCB formula until we
# reach a leaf node.
def run_mcts(config: MuZeroConfig, root: Node, action_history: ActionHistory,
network: Network, min_max_stats: MinMaxStats):
for _ in range(config.num_simulations):
history = action_history.clone()
node = root
search_path = [node]
while node.expanded():
action, node = select_child(config, node, min_max_stats)
history.add_action(action)
search_path.append(node)
# Inside the search tree we use the dynamics function to obtain the next
# hidden state given an action and the previous hidden state.
parent = search_path[-2]
network_output = network.recurrent_inference(parent.hidden_state,
history.last_action())
expand_node(node, history.to_play(), history.action_space(), network_output)
backpropagate(search_path, network_output.value, history.to_play(),
config.discount, min_max_stats)
def select_action(config: MuZeroConfig, num_moves: int, node: Node,
network: Network):
visit_counts = [
(child.visit_count, action) for action, child in node.children.items()
]
t = config.visit_softmax_temperature_fn(
num_moves=num_moves, training_steps=network.training_steps())
_, action = softmax_sample(visit_counts, t)
return action
# Select the child with the highest UCB score.
def select_child(config: MuZeroConfig, node: Node,
min_max_stats: MinMaxStats):
_, action, child = max(
(ucb_score(config, node, child, min_max_stats), action,
child) for action, child in node.children.items())
return action, child
# The score for a node is based on its value, plus an exploration bonus based on
# the prior.
def ucb_score(config: MuZeroConfig, parent: Node, child: Node,
min_max_stats: MinMaxStats) -> float:
pb_c = math.log((parent.visit_count + config.pb_c_base + 1) /
config.pb_c_base) + config.pb_c_init
pb_c *= math.sqrt(parent.visit_count) / (child.visit_count + 1)
prior_score = pb_c * child.prior
if child.visit_count > 0:
value_score = min_max_stats.normalize(child.reward +
config.discount * child.value())
else:
value_score = 0
return prior_score + value_score
# We expand a node using the value, reward and policy prediction obtained from
# the neural network.
def expand_node(node: Node, to_play: Player, actions: List[Action],
network_output: NetworkOutput):
node.to_play = to_play
node.hidden_state = network_output.hidden_state
node.reward = network_output.reward
policy = {a: math.exp(network_output.policy_logits[a]) for a in actions}
policy_sum = sum(policy.values())
for action, p in policy.items():
node.children[action] = Node(p / policy_sum)
# At the end of a simulation, we propagate the evaluation all the way up the
# tree to the root.
def backpropagate(search_path: List[Node], value: float, to_play: Player,
discount: float, min_max_stats: MinMaxStats):
for node in reversed(search_path):
node.value_sum += value if node.to_play == to_play else -value
node.visit_count += 1
min_max_stats.update(node.value())
value = node.reward + discount * value
# At the start of each search, we add dirichlet noise to the prior of the root
# to encourage the search to explore new actions.
def add_exploration_noise(config: MuZeroConfig, node: Node):
actions = list(node.children.keys())
noise = numpy.random.dirichlet([config.root_dirichlet_alpha] * len(actions))
frac = config.root_exploration_fraction
for a, n in zip(actions, noise):
node.children[a].prior = node.children[a].prior * (1 - frac) + n * frac
######### End Self-Play ##########
##################################
##################################
####### Part 2: Training #########
def train_network(config: MuZeroConfig, storage: SharedStorage,
replay_buffer: ReplayBuffer):
network = Network()
learning_rate = config.lr_init * config.lr_decay_rate**(
tf.train.get_global_step() / config.lr_decay_steps)
optimizer = tf.train.MomentumOptimizer(learning_rate, config.momentum)
for i in range(config.training_steps):
if i % config.checkpoint_interval == 0:
storage.save_network(i, network)
batch = replay_buffer.sample_batch(config.num_unroll_steps, config.td_steps)
update_weights(optimizer, network, batch, config.weight_decay)
storage.save_network(config.training_steps, network)
def scale_gradient(tensor: Any, scale):
"""Scales the gradient for the backward pass."""
return tensor * scale + tf.stop_gradient(tensor) * (1 - scale)
def update_weights(optimizer: tf.train.Optimizer, network: Network, batch,
weight_decay: float):
loss = 0
for image, actions, targets in batch:
# Initial step, from the real observation.
network_output = network.initial_inference(image)
hidden_state = network_output.hidden_state
predictions = [(1.0, network_output)]
# Recurrent steps, from action and previous hidden state.
for action in actions:
network_output = network.recurrent_inference(hidden_state, action)
hidden_state = network_output.hidden_state
predictions.append((1.0 / len(actions), network_output))
hidden_state = scale_gradient(hidden_state, 0.5)
for k, (prediction, target) in enumerate(zip(predictions, targets)):
gradient_scale, network_output = prediction
target_value, target_reward, target_policy = target
l = tf.nn.softmax_cross_entropy_with_logits(
logits=network_output.policy_logits, labels=target_policy)
l += scalar_loss(network_output.value, target_value)
if k > 0:
l += scalar_loss(network_output.reward, target_reward)
loss += scale_gradient(l, gradient_scale)
loss /= len(batch)
for weights in network.get_weights():
loss += weight_decay * tf.nn.l2_loss(weights)
optimizer.minimize(loss)
def scalar_loss(prediction, target) -> float:
# MSE in board games, cross entropy between categorical values in Atari.
return -1
######### End Training ###########
##################################
################################################################################
############################# End of pseudocode ################################
################################################################################
# Stubs to make the typechecker happy.
def softmax_sample(distribution, temperature: float):
return 0, 0
def launch_job(f, *args):
f(*args)
def make_uniform_network():
return Network()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment