Skip to content

Instantly share code, notes, and snippets.

View davidADSP's full-sized avatar

David Foster davidADSP

View GitHub Profile
class ReplayBuffer(object):
    """Bounded FIFO store of finished self-play games.

    At most ``window_size`` games are retained; when a new game would exceed
    that limit, the oldest games are evicted first.
    """

    def __init__(self, config: 'MuZeroConfig'):
        """Read the buffer limits from *config*.

        Args:
            config: object exposing ``window_size`` (maximum number of games
                retained) and ``batch_size`` (stored here; not read by the
                methods shown — presumably used by sampling code, confirm).
        """
        # String annotation: MuZeroConfig may be defined later in the module,
        # and a bare name would be evaluated (and fail) at def time.
        self.window_size = config.window_size
        self.batch_size = config.batch_size
        # Ordered oldest -> newest.
        self.buffer = []

    def save_game(self, game):
        """Store *game*, evicting the oldest games beyond ``window_size``.

        The previous version only popped and never appended, and its ``>``
        check would have let the buffer hold ``window_size + 1`` games.
        """
        self.buffer.append(game)
        overflow = len(self.buffer) - self.window_size
        if overflow > 0:
            # One slice delete instead of repeated O(n) pop(0) calls.
            del self.buffer[:overflow]
class SharedStorage(object):
    """Holds saved networks keyed by a sortable step and serves the newest one."""

    def __init__(self):
        # step -> network; latest_network() picks the entry with the largest key.
        self._networks = {}

    def latest_network(self) -> 'Network':
        """Return the network stored under the largest key.

        If no network has been saved yet, fall back to a default network
        (policy -> uniform, value -> 0, reward -> 0).
        """
        # String annotation: Network is not defined in this excerpt, and a bare
        # name would be evaluated (and fail) at def time.
        if self._networks:
            return self._networks[max(self._networks)]
        # policy -> uniform, value -> 0, reward -> 0
        # NOTE(review): the original else-branch contained no statement (a
        # syntax error). make_uniform_network() is not defined in this excerpt;
        # it is assumed to be provided elsewhere in the module — confirm.
        return make_uniform_network()

    def save_network(self, step, network):
        """Store *network* under *step*, replacing any network already there."""
        self._networks[step] = network
# NOTE(review): this excerpt of MuZeroConfig is truncated — the __init__
# parameter list below is never closed and the method body is missing, so the
# class does not parse as shown. Other code in this file reads
# config.window_size and config.num_actors, neither of which appears among the
# visible parameters; presumably the missing part of __init__ sets them.
# Restore from the full source — confirm.
class MuZeroConfig(object):
    # Appears to bundle run hyperparameters (only part of the signature is visible).
    def __init__(self,
                 action_space_size: int,
                 max_moves: int,
                 discount: float,
                 dirichlet_alpha: float,
                 num_simulations: int,
                 batch_size: int,
                 td_steps: int,
def muzero(config: MuZeroConfig):
    """Wire up the self-play actors and the trainer, then return the newest network.

    One ``SharedStorage`` and one ``ReplayBuffer`` are created and handed to
    every job: ``config.num_actors`` self-play jobs are started through
    ``launch_job(run_selfplay, ...)``, then ``train_network`` is called
    directly with the same shared objects.

    Returns:
        Whatever ``SharedStorage.latest_network()`` yields once training returns.
    """
    shared_storage = SharedStorage()
    game_buffer = ReplayBuffer(config)

    # Fan out the actors; all of them share the same storage and buffer.
    for _actor_index in range(config.num_actors):
        launch_job(run_selfplay, config, shared_storage, game_buffer)

    train_network(config, shared_storage, game_buffer)

    final_network = shared_storage.latest_network()
    return final_network