-
Star
(189)
You must be signed in to star a gist -
Fork
(41)
You must be signed in to fork a gist
-
-
Save qpwo/c538c6f73727e254fdc7fab81024f6e1 to your computer and use it in GitHub Desktop.
""" | |
A minimal implementation of Monte Carlo tree search (MCTS) in Python 3 | |
Luke Harold Miles, July 2019, Public Domain Dedication | |
See also https://en.wikipedia.org/wiki/Monte_Carlo_tree_search | |
https://gist.github.com/qpwo/c538c6f73727e254fdc7fab81024f6e1 | |
""" | |
from abc import ABC, abstractmethod | |
from collections import defaultdict | |
import math | |
class MCTS:
    """Monte Carlo tree searcher.

    Grow the search tree with repeated calls to do_rollout(), then ask
    choose() for a move. Nodes are used directly as dictionary keys, so they
    must implement the Node interface, including __hash__ and __eq__.

    Reward convention: choose() maximises Q[child] / N[child] over the
    children of a node, so a node's statistics are read from the viewpoint of
    the player who moves *into* that node. Because consecutive plies belong to
    alternating players, rewards are flipped (r -> 1 - r) once per ply.
    """

    def __init__(self, exploration_weight=1):
        # Summed reward of every rollout that passed through a node.
        self.Q = defaultdict(int)
        # Number of rollouts that passed through a node.
        self.N = defaultdict(int)
        # Expanded node -> the set returned by its find_children(). Being a
        # key of this dict is what "explored" means throughout this class.
        self.children = dict()
        # UCT exploration constant; larger values favour rarely-visited children.
        self.exploration_weight = exploration_weight

    def choose(self, node):
        """Return the most promising successor of ``node`` (i.e. make a move).

        Picks the child with the best average reward Q/N. If ``node`` was
        never expanded there are no statistics, so a random child is returned.

        Raises:
            RuntimeError: if ``node`` is terminal.
        """
        if node.is_terminal():
            raise RuntimeError(f"choose called on terminal node {node}")
        if node not in self.children:
            return node.find_random_child()

        def average_reward(child):
            visits = self.N[child]
            # A child never visited has no average; rank it below all others.
            return self.Q[child] / visits if visits else float("-inf")

        return max(self.children[node], key=average_reward)

    def do_rollout(self, node):
        """Run one training iteration starting from ``node``.

        Walks down to an unexplored node, expands it, plays a random game from
        it and feeds the outcome back along the walked path.
        """
        path = self._select(node)
        self._expand(path[-1])
        self._backpropagate(path, self._simulate(path[-1]))

    def _select(self, node):
        """Return the path from ``node`` down to its first unexplored descendant.

        Descends through explored nodes using UCT. The walk stops at a node
        that is unexplored or has no children, or at an explored node that
        still has an unexplored child (that child is appended and ends the path).
        """
        path = [node]
        while node in self.children and self.children[node]:
            # Children of `node` that have not been expanded yet.
            fresh = self.children[node] - self.children.keys()
            if fresh:
                path.append(fresh.pop())
                break
            node = self._uct_select(node)  # go one layer deeper
            path.append(node)
        return path

    def _expand(self, node):
        "Record ``node``'s children in ``self.children`` unless already present."
        if node not in self.children:
            self.children[node] = node.find_children()

    def _simulate(self, node):
        """Play random moves from ``node`` to the end of the game.

        Returns the outcome scored for the player who moved into ``node``,
        assuming reward() scores a terminal position for the player whose
        turn it is there.
        """
        depth = 0  # random moves played after `node`
        while not node.is_terminal():
            node = node.find_random_child()
            depth += 1
        reward = node.reward()
        # After an even number of moves the player to move at the end is the
        # opponent of whoever moved into the start node, so flip the score.
        return 1 - reward if depth % 2 == 0 else reward

    def _backpropagate(self, path, reward):
        "Add one visit and the (per-ply flipped) reward to every node on ``path``."
        value = reward
        for node in path[::-1]:
            self.N[node] += 1
            self.Q[node] += value
            # Moving one ply up switches players: a win for one is a loss
            # for the other.
            value = 1 - value

    def _uct_select(self, node):
        """Pick a child of ``node`` by the UCT rule (exploitation + exploration)."""
        kids = self.children[node]
        # Every child must already be explored, otherwise N[child] could be 0.
        assert all(n in self.children for n in kids)
        log_parent_visits = math.log(self.N[node])

        def uct(n):
            "Upper confidence bound for trees"
            exploit = self.Q[n] / self.N[n]
            explore = self.exploration_weight * math.sqrt(
                log_parent_visits / self.N[n]
            )
            return exploit + explore

        return max(kids, key=uct)
class Node(ABC):
    """
    Abstract game state searched by MCTS.

    Each position the search visits is an instance of a subclass -- for
    example a chess or checkers board. MCTS stores its statistics in
    dictionaries keyed by nodes, so subclasses must supply consistent
    __hash__ and __eq__ implementations.
    """

    @abstractmethod
    def find_children(self):
        "Return the set of all states reachable in one move"
        return set()

    @abstractmethod
    def find_random_child(self):
        "Return one randomly chosen successor (cheaper than find_children for rollouts)"
        return None

    @abstractmethod
    def is_terminal(self):
        "Return True when this state has no successors"
        return True

    @abstractmethod
    def reward(self):
        "Score of a terminal state: 1 = win, 0 = loss, .5 = tie, etc."
        return 0

    @abstractmethod
    def __hash__(self):
        "Nodes are dictionary keys inside MCTS, so they must be hashable"
        return 123456789

    @abstractmethod
    def __eq__(node1, node2):
        "Nodes are dictionary keys inside MCTS, so they must be comparable"
        return True
""" | |
An example implementation of the abstract Node class for use in MCTS | |
If you run this file then you can play against the computer. | |
A tic-tac-toe board is represented as a tuple of 9 values, each either None, | |
True, or False, respectively meaning 'empty', 'X', and 'O'. | |
The board is indexed by row: | |
0 1 2 | |
3 4 5 | |
6 7 8 | |
For example, this game board | |
O - X | |
O X - | |
X - - | |
corresponds to this tuple:
(False, None, True, False, True, None, True, None, None) | |
""" | |
from collections import namedtuple | |
from random import choice | |
from monte_carlo_tree_search import MCTS, Node | |
_TTTB = namedtuple("TicTacToeBoard", "tup turn winner terminal") | |
# Inheriting from a namedtuple is convenient because it makes the class | |
# immutable and predefines __init__, __repr__, __hash__, __eq__, and others | |
class TicTacToeBoard(_TTTB, Node):
    """Immutable tic-tac-toe position usable as an MCTS ``Node``.

    Fields (from ``_TTTB``):
      tup      -- 9-tuple of None / True (X) / False (O), indexed row by row
      turn     -- True if X is to move, False if O is to move
      winner   -- True / False for the side with three in a row, else None
      terminal -- True once the game is over (a win or a full board)
    """

    def find_children(board):
        "Every board reachable with one move; empty once the game is over."
        if board.terminal:  # If the game is finished then no moves can be made
            return set()
        # Otherwise, you can make a move in each of the empty spots
        return {
            board.make_move(i) for i, value in enumerate(board.tup) if value is None
        }

    def find_random_child(board):
        "A uniformly random successor, or None once the game is over."
        if board.terminal:
            return None  # If the game is finished then no moves can be made
        empty_spots = [i for i, value in enumerate(board.tup) if value is None]
        return board.make_move(choice(empty_spots))

    def reward(board):
        """Score a finished game for the player whose turn it is.

        Returns 0.5 for a tie and 0 when the opponent (who made the last move)
        has won.

        Raises:
            RuntimeError: if the board is not terminal, if the player to move
                has already won (unreachable in real play), or if ``winner``
                is not True, False or None.
        """
        if not board.terminal:
            raise RuntimeError(f"reward called on nonterminal board {board}")
        # The tie check must come first: `not None` is True, so testing
        # `board.turn is (not board.winner)` earlier would score a tie as a
        # loss whenever it is X's turn.
        if board.winner is None:
            return 0.5  # Board is a tie
        if board.winner is not True and board.winner is not False:
            # The winner is neither True, False, nor None
            raise RuntimeError(f"board has unknown winner type {board.winner}")
        if board.winner is board.turn:
            # It's your turn and you've already won. Should be impossible.
            raise RuntimeError(f"reward called on unreachable board {board}")
        return 0  # Your opponent has just won. Bad.

    def is_terminal(board):
        "True once the game has been won or the board is full."
        return board.terminal

    def make_move(board, index):
        """Return the board after the player to move marks square ``index``.

        Does not check that the square is empty or that the game is still
        running; callers are expected to pass an empty square.
        """
        tup = board.tup[:index] + (board.turn,) + board.tup[index + 1 :]
        turn = not board.turn
        winner = _find_winner(tup)
        is_terminal = (winner is not None) or not any(v is None for v in tup)
        return TicTacToeBoard(tup, turn, winner, is_terminal)

    def to_pretty_string(board):
        "Render the board as a 3x3 grid with 1-based row and column labels."

        def to_char(v):
            return "X" if v is True else ("O" if v is False else " ")

        rows = [
            [to_char(board.tup[3 * row + col]) for col in range(3)] for row in range(3)
        ]
        return (
            "\n 1 2 3\n"
            + "\n".join(str(i + 1) + " " + " ".join(row) for i, row in enumerate(rows))
            + "\n"
        )
def play_game(rollouts_per_turn=50):
    """Play tic-tac-toe against the MCTS engine on the console.

    The human moves first (as X) by typing ``row,col`` with both values in
    1..3; the computer then replies. The loop ends as soon as either move
    finishes the game.

    Args:
        rollouts_per_turn: number of MCTS rollouts run before each computer
            move. Defaults to 50.

    Raises:
        RuntimeError: if the chosen square is off the board or already taken.
        ValueError: if the input is not two comma-separated integers.
    """
    tree = MCTS()
    board = new_tic_tac_toe_board()
    print(board.to_pretty_string())
    while True:
        row_col = input("enter row,col: ")
        row, col = map(int, row_col.split(","))
        # Without this range check, out-of-range coordinates can land on a
        # different, valid square (e.g. "1,4" -> index 3, "1,0" -> index -1,
        # which Python wraps to square 8) or fail with an opaque IndexError.
        if not (1 <= row <= 3 and 1 <= col <= 3):
            raise RuntimeError("Invalid move")
        index = 3 * (row - 1) + (col - 1)
        if board.tup[index] is not None:
            raise RuntimeError("Invalid move")
        board = board.make_move(index)
        print(board.to_pretty_string())
        if board.terminal:
            break
        # You can train as you go, or only at the beginning.
        # Here, we train as we go, doing a fixed number of rollouts each turn.
        for _ in range(rollouts_per_turn):
            tree.do_rollout(board)
        board = tree.choose(board)
        print(board.to_pretty_string())
        if board.terminal:
            break
def _winning_combos(): | |
for start in range(0, 9, 3): # three in a row | |
yield (start, start + 1, start + 2) | |
for start in range(3): # three in a column | |
yield (start, start + 3, start + 6) | |
yield (0, 4, 8) # down-right diagonal | |
yield (2, 4, 6) # down-left diagonal | |
def _find_winner(tup):
    """Return which player owns a complete line on board ``tup``.

    True if X has three in a row, False if O does, None if nobody does.
    Only the exact singletons True and False count as a player's mark.
    """
    for line in _winning_combos():
        first, second, third = (tup[i] for i in line)
        if first is second is third and (first is True or first is False):
            return first
    return None
def new_tic_tac_toe_board():
    "Return the empty starting board, with X (True) to move."
    empty_grid = tuple(None for _ in range(9))
    return TicTacToeBoard(empty_grid, True, None, False)
if __name__ == "__main__":
    # Start an interactive game only when this file is executed as a script,
    # not when it is imported as a module.
    play_game()
Add a field holding a pointer to the parent, set in the constructor and ignored by `__eq__`.
Thanks a lot. I'm a civil engineer, so this is a little beyond my knowledge. Could you recommend a reference where I can study pointers in Python? Thanks again!
Hi, thanks for the example code. I am not clear why `invert_reward` on line 67 is initially defined as `True`. The input node is `path[-1]`, and `path[-1]` will receive the reward later, so it seems it will receive an inverted reward?
The hardest thing with MCTS is the off-by-one bugs. Looking at this you might be right. I tested the algorithm and it seemed to work though. Is MCTS so good that you can give it reversed rewards and it still works?
Thank you for the reply. I found that if you change line 52 of `tictactoe.py` to return a reward of 1 for the current board and set `invert_reward` initially to `False`, it also works. It seems easier to understand that way.
Thanks to @qpwo for the excellent code—I’ve learned a lot from it! Regarding invert_reward, I chose not to modify the turn attribute in each BoardState. Instead, each BoardState (representing a Node in the Tree) simply reflects which player's turn it is and the board configuration as a tuple. In the simulation phase, if the starting node’s player matches the terminal node’s player, the starting node should receive a positive reward. Otherwise, it should receive a negative reward.
See more in my implementation: https://github.com/KnightZhang625/mcts_tic_tac
Oh nice, much less error-prone!
Thanks for your code. I have a question about how to store the nodes in the MCTS. For example, if I set up a leaf node, how can I index its parent node and children nodes?