This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# At the start of each search, Dirichlet noise is mixed into the priors of the
# root's children to encourage the search to explore new actions.
def add_exploration_noise(config: MuZeroConfig, node: Node):
    """Blend Dirichlet noise into the prior of every child of ``node``.

    One noise vector, with one entry per child, is drawn from a symmetric
    Dirichlet distribution with concentration
    ``config.root_dirichlet_alpha``. Each child's prior is then replaced by
    ``prior * (1 - fraction) + noise * fraction``, where ``fraction`` is
    ``config.root_exploration_fraction``.

    Args:
        config: Supplies ``root_dirichlet_alpha`` and
            ``root_exploration_fraction``.
        node: Node whose children's ``prior`` attributes are updated in place.
    """
    child_actions = list(node.children.keys())
    alphas = [config.root_dirichlet_alpha] * len(child_actions)
    samples = numpy.random.dirichlet(alphas)
    mix = config.root_exploration_fraction
    for index, action in enumerate(child_actions):
        child = node.children[action]
        child.prior = child.prior * (1 - mix) + samples[index] * mix
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Core Monte Carlo Tree Search algorithm.
# To decide on an action, we run N simulations, always starting at the root of
# the search tree and traversing the tree according to the UCB formula until we
# reach a leaf node.
# NOTE(review): this excerpt is truncated -- only the start of the simulation
# loop is shown; the selection, expansion and backup steps are not visible here.
def run_mcts(config: MuZeroConfig, root: Node, action_history: ActionHistory, | |
network: Network): | |
# Value-bounds tracker built from config.known_bounds; ucb_score calls
# min_max_stats.normalize on child values (MinMaxStats itself is not shown).
min_max_stats = MinMaxStats(config.known_bounds) | |
# N above is config.num_simulations.
for _ in range(config.num_simulations): | |
# Each simulation starts from a clone of the real action history
# (presumably an independent copy -- TODO confirm ActionHistory.clone).
history = action_history.clone() |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Select the child with the highest UCB score.
def select_child(config: MuZeroConfig, node: Node,
                 min_max_stats: MinMaxStats):
    """Return ``(action, child)`` for the child of ``node`` with the top UCB score.

    Candidates are compared as ``(score, action, child)`` tuples. A tie on
    score is therefore broken by comparing the actions themselves.

    Raises:
        ValueError: If ``node`` has no children (``max`` of an empty sequence).
    """
    scored_children = (
        (ucb_score(config, node, candidate, min_max_stats), move, candidate)
        for move, candidate in node.children.items()
    )
    _best_score, best_action, best_child = max(scored_children)
    return best_action, best_child
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# The score for a node is based on its value, plus an exploration bonus based on
# the prior.
# NOTE(review): this excerpt is truncated after value_score is computed; how
# prior_score and value_score are combined and returned is not shown here.
def ucb_score(config: MuZeroConfig, parent: Node, child: Node, | |
min_max_stats: MinMaxStats) -> float: | |
# Exploration coefficient: config.pb_c_init plus a term that grows
# logarithmically with the parent's visit count (scaled by config.pb_c_base).
pb_c = math.log((parent.visit_count + config.pb_c_base + 1) / | |
config.pb_c_base) + config.pb_c_init | |
# Scale by sqrt(parent visits) / (child visits + 1): the fewer times this
# child has been visited relative to its parent, the larger the bonus.
pb_c *= math.sqrt(parent.visit_count) / (child.visit_count + 1) | |
# Exploration bonus, weighted by the child's prior.
prior_score = pb_c * child.prior | |
# Value term: child.value() passed through min_max_stats.normalize
# (presumably rescales using the observed value bounds -- TODO confirm).
value_score = min_max_stats.normalize(child.value()) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# At the end of a simulation, we propagate the evaluation all the way up the
# tree, from the leaf back to the root.
def backpropagate(search_path: List[Node], value: float, to_play: Player,
                  discount: float, min_max_stats: MinMaxStats):
    """Back up a leaf evaluation along ``search_path``.

    ``search_path`` is walked from its last element to its first, i.e. from
    the leaf back up to the root. This assumes the path is ordered root-first,
    as built while descending during selection (TODO confirm against
    run_mcts).

    For each node:
    - ``value`` is added to ``value_sum``, negated when the node's
      ``to_play`` differs from ``to_play``;
    - the visit count is incremented;
    - ``min_max_stats`` is updated with the node's new value.

    ``value`` is then converted into the return seen by the node's parent:
    ``node.reward + discount * value``.

    Args:
        search_path: Nodes visited during the simulation, root first.
        value: Evaluation of the leaf node.
        to_play: Player from whose perspective ``value`` is expressed.
        discount: Per-step discount factor.
        min_max_stats: Running value-bounds tracker, updated in place.
    """
    # Walk leaf -> root. The bootstrapped return must be accumulated bottom-up
    # so that each parent receives its child's reward plus the discounted
    # value below it. Iterating root-first would credit the root with the raw
    # leaf value and discount rewards in the wrong direction.
    for node in reversed(search_path):
        node.value_sum += value if node.to_play == to_play else -value
        node.visit_count += 1
        min_max_stats.update(node.value())
        value = node.reward + discount * value
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def select_action(config: MuZeroConfig, num_moves: int, node: Node,
                  network: Network):
    """Pick an action among ``node``'s children based on their visit counts.

    The ``(visit_count, action)`` pairs for every child are handed to
    ``softmax_sample``, together with a temperature from
    ``config.visit_softmax_temperature_fn``. That function receives the
    number of moves played so far and the network's training step count.

    Returns:
        The action component of the pair returned by ``softmax_sample``.
    """
    children = node.children
    pairs = [(children[move].visit_count, move) for move in children]
    temperature = config.visit_softmax_temperature_fn(
        num_moves=num_moves,
        training_steps=network.training_steps(),
    )
    _visits, chosen = softmax_sample(pairs, temperature)
    return chosen
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Presumably computes a training loss over `batch` and applies `optimizer`.
# NOTE(review): this excerpt is truncated after the initial inference step;
# the recurrent unroll, loss terms, weight decay and optimizer step are not
# shown here.
def update_weights(optimizer: tf.train.Optimizer, network: Network, batch, | |
weight_decay: float): | |
loss = 0 | |
# Each batch entry is unpacked as (image, actions, targets).
for image, actions, targets in batch: | |
# Initial step, from the real observation.
value, reward, policy_logits, hidden_state = network.initial_inference( | |
image) | |
# The leading 1.0 is presumably a per-step gradient scale for the initial
# step -- TODO confirm against the loss computation.
predictions = [(1.0, value, reward, policy_logits)] | |
# Recurrent steps, from action and previous hidden state.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class Game(object): | |
"""A single episode of interaction with the environment.""" | |
# NOTE(review): this excerpt is truncated after __init__'s attribute setup;
# the rest of the class is not shown here.
def __init__(self, action_space_size: int, discount: float): | |
self.environment = Environment() # Game-specific environment.
# Presumably the actions taken so far in the episode -- TODO confirm.
self.history = [] | |
# Presumably the per-step rewards received -- TODO confirm.
self.rewards = [] | |
# Presumably the root visit-count statistics recorded at each step -- verify.
self.child_visits = [] | |
# Presumably the search value of the root at each step -- verify.
self.root_values = [] | |
# Number of possible actions, as supplied by the caller.
# NOTE(review): `discount` is accepted but not stored in the lines shown.
self.action_space_size = action_space_size |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Replay buffer that samples training batches; `self.buffer` presumably holds
# stored games -- TODO confirm.
# NOTE(review): this excerpt is truncated -- sample_batch's return expression
# is cut off, and sample_game / sample_position are not shown.
class ReplayBuffer(object): | |
def __init__(self, config: MuZeroConfig): | |
# Presumably the maximum number of games retained -- TODO confirm.
self.window_size = config.window_size | |
# Number of games drawn per call to sample_batch.
self.batch_size = config.batch_size | |
self.buffer = [] | |
def sample_batch(self, num_unroll_steps: int, td_steps: int): | |
# Draw batch_size games using sample_game (policy defined elsewhere).
games = [self.sample_game() for _ in range(self.batch_size)] | |
# Pair each sampled game with a position chosen by sample_position.
game_pos = [(g, self.sample_position(g)) for g in games] | |
# Each entry starts with g.make_image(i) and the num_unroll_steps history
# items following position i (presumably i comes from game_pos in the
# truncated comprehension -- verify).
return [(g.make_image(i), g.history[i:i + num_unroll_steps], |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import networkx as nx | |
# Sample data format (a list of (name, attribute-dict) tuples):
# nodes = [('tensorflow', {'count': 13}),
#          ('pytorch', {'count': 6}),
#          ('keras', {'count': 6}),
#          ('scikit', {'count': 2}),
#          ('opencv', {'count': 5}),
#          ('spark', {'count': 13}), ...]