This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
  | def muzero(config: MuZeroConfig): | |
| storage = SharedStorage() | |
| replay_buffer = ReplayBuffer(config) | |
| for _ in range(config.num_actors): | |
| launch_job(run_selfplay, config, storage, replay_buffer) | |
| train_network(config, storage, replay_buffer) | |
| return storage.latest_network() | 
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
  | class MuZeroConfig(object): | |
| def __init__(self, | |
| action_space_size: int, | |
| max_moves: int, | |
| discount: float, | |
| dirichlet_alpha: float, | |
| num_simulations: int, | |
| batch_size: int, | |
| td_steps: int, | 
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
  | class SharedStorage(object): | |
| def __init__(self): | |
| self._networks = {} | |
| def latest_network(self) -> Network: | |
| if self._networks: | |
| return self._networks[max(self._networks.keys())] | |
| else: | |
| # policy -> uniform, value -> 0, reward -> 0 | 
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
  | class ReplayBuffer(object): | |
| def __init__(self, config: MuZeroConfig): | |
| self.window_size = config.window_size | |
| self.batch_size = config.batch_size | |
| self.buffer = [] | |
| def save_game(self, game): | |
| if len(self.buffer) > self.window_size: | |
| self.buffer.pop(0) | 
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
  | # Each self-play job is independent of all others; it takes the latest network | |
| # snapshot, produces a game and makes it available to the training job by | |
| # writing it to a shared replay buffer. | |
| def run_selfplay(config: MuZeroConfig, storage: SharedStorage, | |
| replay_buffer: ReplayBuffer): | |
| while True: | |
| network = storage.latest_network() | |
| game = play_game(config, network) | |
| replay_buffer.save_game(game) | 
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
  | def train_network(config: MuZeroConfig, storage: SharedStorage, | |
| replay_buffer: ReplayBuffer): | |
| network = Network() | |
| learning_rate = config.lr_init * config.lr_decay_rate**( | |
| tf.train.get_global_step() / config.lr_decay_steps) | |
| optimizer = tf.train.MomentumOptimizer(learning_rate, config.momentum) | |
| for i in range(config.training_steps): | |
| if i % config.checkpoint_interval == 0: | |
| storage.save_network(i, network) | 
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
  | # Each game is produced by starting at the initial board position, then | |
| # repeatedly executing a Monte Carlo Tree Search to generate moves until the end | |
| # of the game is reached. | |
| def play_game(config: MuZeroConfig, network: Network) -> Game: | |
| game = config.new_game() | |
| while not game.terminal() and len(game.history) < config.max_moves: | |
| # At the root of the search tree we use the representation function to | |
| # obtain a hidden state given the current observation. | |
| root = Node(0) | 
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
  | class Node(object): | |
| def __init__(self, prior: float): | |
| self.visit_count = 0 | |
| self.to_play = -1 | |
| self.prior = prior | |
| self.value_sum = 0 | |
| self.children = {} | |
| self.hidden_state = None | |
| self.reward = 0 | 
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
  | class NetworkOutput(typing.NamedTuple): | |
| value: float | |
| reward: float | |
| policy_logits: Dict[Action, float] | |
| hidden_state: List[float] | |
| class Network(object): | |
| def initial_inference(self, image) -> NetworkOutput: | 
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
  | # We expand a node using the value, reward and policy prediction obtained from | |
| # the neural network. | |
| def expand_node(node: Node, to_play: Player, actions: List[Action], | |
| network_output: NetworkOutput): | |
| node.to_play = to_play | |
| node.hidden_state = network_output.hidden_state | |
| node.reward = network_output.reward | |
| policy = {a: math.exp(network_output.policy_logits[a]) for a in actions} | |
| policy_sum = sum(policy.values()) | |
| for action, p in policy.items(): | 
OlderNewer