@suragnair
Last active November 29, 2023 16:15
MCTS for Alpha Zero

from math import sqrt

def search(s, game, nnet):
    # Terminal state: return the reward from the previous player's perspective.
    if game.gameEnded(s): return -game.gameReward(s)

    # Leaf state: expand it, cache the network's policy prior, and back up its value estimate.
    if s not in visited:
        visited.add(s)
        P[s], v = nnet.predict(s)
        return -v

    # Selection: pick the valid action maximising the PUCT upper confidence bound.
    max_u, best_a = -float("inf"), -1
    for a in game.getValidActions(s):
        u = Q[s][a] + c_puct*P[s][a]*sqrt(sum(N[s]))/(1+N[s][a])
        if u > max_u:
            max_u = u
            best_a = a
    a = best_a

    # Recurse on the next state; the sign flips because the players alternate.
    sp = game.nextState(s, a)
    v = search(sp, game, nnet)

    # Back up: update the running mean Q and the visit count for the chosen action.
    Q[s][a] = (N[s][a]*Q[s][a] + v)/(N[s][a]+1)
    N[s][a] += 1
    return -v
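
The snippet relies on globals (visited, P, Q, N, c_puct) and on game/nnet objects defined elsewhere in the accompanying tutorial code. Below is a minimal sketch of one way to wire them up; the ACTION_SIZE constant, the get_policy helper, and the defaultdict choice are assumptions for illustration, not part of the original gist.

from collections import defaultdict
import numpy as np

ACTION_SIZE = 7                                 # assumption: e.g. 7 columns in Connect Four
visited = set()                                 # states that have already been expanded
P = {}                                          # policy priors returned by nnet.predict
Q = defaultdict(lambda: np.zeros(ACTION_SIZE))  # mean action values per state
N = defaultdict(lambda: np.zeros(ACTION_SIZE))  # visit counts per state
c_puct = 1.0                                    # exploration constant

def get_policy(s, game, nnet, num_sims=25, temperature=1.0):
    # Hypothetical helper: run several simulations from s, then turn the root's
    # visit counts into move probabilities (higher temperature = more exploration).
    for _ in range(num_sims):
        search(s, game, nnet)
    counts = N[s] ** (1.0 / temperature)
    return counts / counts.sum()

In practice this search state is usually kept inside an MCTS class rather than module-level globals, so that trees for different self-play games do not interfere.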
@Tanuj1209

What is the optimal value you are using for the c_puct hyperparameter in your code?

@suragnair (Author)

We use 1 as a default, but I don’t know how sensitive training is to this parameter. Best to try a handful of values.
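
To get a feel for what different values do, the toy calculation below plugs made-up numbers into the U term from the search code above; larger c_puct weights the prior-driven exploration term more heavily relative to Q.

from math import sqrt

# Made-up Q value, prior, and visit counts, purely to illustrate the scaling.
q, prior, n_s, n_sa = 0.1, 0.25, 100, 10
for c in (0.5, 1.0, 2.0, 4.0):
    u = q + c * prior * sqrt(n_s) / (1 + n_sa)
    print(f"c_puct={c}: U={u:.3f}")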
