@davidADSP
Created November 29, 2019 14:13
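import tensorflow as tf  # TF 1.x API; under TF 2.x these calls live in tf.compat.v1.train


def train_network(config: MuZeroConfig, storage: SharedStorage,
                  replay_buffer: ReplayBuffer):
  network = Network()
  # Exponentially decayed learning rate, driven by the global training step.
  learning_rate = config.lr_init * config.lr_decay_rate**(
      tf.train.get_global_step() / config.lr_decay_steps)
  optimizer = tf.train.MomentumOptimizer(learning_rate, config.momentum)

  for i in range(config.training_steps):
    # Save a snapshot of the network every checkpoint_interval steps.
    if i % config.checkpoint_interval == 0:
      storage.save_network(i, network)
    # Sample a training batch and apply one optimizer update.
    batch = replay_buffer.sample_batch(config.num_unroll_steps, config.td_steps)
    update_weights(optimizer, network, batch, config.weight_decay)
  # Save the final network once training finishes.
  storage.save_network(config.training_steps, network)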