""" | |
Introduction to Monte Carlo Tree Search | |
http://jeffbradberry.com/posts/2015/09/intro-to-monte-carlo-tree-search/ | |
""" | |
from copy import deepcopy | |
import datetime | |
from math import log, sqrt | |
from random import choice | |
# Tic-tac-toe board; a state is a 3x3 nested list in which 0 = empty cell,
# 1 = player 1 (X) and 2 = player 2 (O)
class Board:
  # Flattens the 3x3 state into a list of 9 cells (sum concatenates the rows)
  def _flatten(self, state):
    return sum(state, [])
  # Rebuilds the 3x3 state from a flat list of 9 cells
  def _unflatten(self, unrolled_state):
    state = [[0, 0, 0], [0, 0, 0], [0, 0, 0]]
    for y in range(3):
      for x in range(3):
        state[y][x] = unrolled_state[3 * y + x]
    return state
  # Returns the starting state of the game
  def reset(self):
    return [[0, 0, 0], [0, 0, 0], [0, 0, 0]]
  # Finds the current player's number; player 1 moves first, so it is their turn
  # whenever both players have placed an equal number of pieces
  def current_player(self, state):
    unrolled_state = self._flatten(state)
    p1_count = sum(e == 1 for e in unrolled_state)
    p2_count = sum(e == 2 for e in unrolled_state)
    return 1 if p1_count == p2_count else 2
  # Makes a step in the environment, returning the successor state
  def step(self, state, action):
    current = self.current_player(state)
    new_state = self._flatten(state)  # _flatten builds a new list, so the input state is not mutated
    new_state[action] = current
    return self._unflatten(new_state)
  # Returns the list of legal actions (the empty cells, which are the same for either player)
  def legal_actions(self, state):
    return [i for i, e in enumerate(self._flatten(state)) if e == 0]
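  # Actions index the flattened board row-major, so 0 is the top-left cell,
  # 4 the centre and 8 the bottom-right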
  # Returns the winning player's number (1/2), 0 if ongoing, or -1 for a draw
  def winner(self, state):
    board = self._flatten(state)
    # The 8 winning lines as flat indices: 3 rows, 3 columns and 2 diagonals
    lines = ((0, 1, 2), (3, 4, 5), (6, 7, 8),
             (0, 3, 6), (1, 4, 7), (2, 5, 8),
             (0, 4, 8), (2, 4, 6))
    for a, b, c in lines:
      if board[a] != 0 and board[a] == board[b] == board[c]:
        return board[a]
    if any(e == 0 for e in board):
      return 0  # Empty cells remain, so the game is ongoing
    return -1  # Full board with no winner: a draw
  # Returns a unique string hash per unique state, usable as a dictionary key
  def hash(self, state):
    return ''.join(str(e) for e in self._flatten(state))
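  # For example, the state [[1, 0, 0], [0, 2, 0], [0, 0, 0]] hashes to '100020000'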
  # Converts a state element's encoding for pretty printing
  def _print_element(self, element):
    if element == 1:
      return 'X'
    elif element == 2:
      return 'O'
    else:
      return ' '
  # Pretty-prints a state
  def pretty_print(self, state):
    pretty_state = [self._print_element(e) for e in self._flatten(state)]
    print('Board:\n-----\n|%s%s%s|\n|%s%s%s|\n|%s%s%s|\n-----' % tuple(pretty_state))
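# MCTS overview, as implemented below: each call to run_simulation walks the four
# standard phases of Monte Carlo Tree Search: selection (UCT over nodes whose
# children all have statistics), expansion (statistics initialised for at most one
# new node per simulation), simulation (random playout until a terminal state or
# the move cap) and backpropagation (plays/wins updates for all visited tree nodes)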
# MCTS planner
class MCTS:
  # Initializes the game history (states only) and the statistics tables
  def __init__(self, board, **kwargs):
    self.board = board
    self.history = []
    self.c = sqrt(2)  # Exploration parameter; sqrt(2) is the standard UCT choice
    self.wins = {}  # Win counts, keyed on (player, state hash)
    self.plays = {}  # Play counts, keyed on (player, state hash)
    self.max_depth = 0  # Deepest tree expansion reached (also reset per get_action call)
    self.calculation_time = datetime.timedelta(seconds=kwargs.get('time', 10))  # Max amount of time per move calculation
    self.max_moves = kwargs.get('max_moves', 100)  # Max number of moves per rollout
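  # Example construction with hypothetical settings: MCTS(Board(), time=5, max_moves=50)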
  # Appends a game state to the history
  def update(self, state):
    self.history.append(state)
  # Calculates and returns the best move from the current game state
  def get_action(self):
    self.max_depth = 0
    state = self.history[-1]
    player = self.board.current_player(state)
    legal = self.board.legal_actions(state)
    # Stop early if there is no choice to be made
    if len(legal) == 0:
      return
    elif len(legal) == 1:
      return legal[0]
    # Run simulations repeatedly until the set time has elapsed
    games = 0
    begin = datetime.datetime.utcnow()
    while datetime.datetime.utcnow() - begin < self.calculation_time:
      self.run_simulation()
      games += 1
    # Hashed successor states, paired with the actions that reach them
    states_actions = [(self.board.hash(self.board.step(state, a)), a) for a in legal]
    # Display the number of calls to run_simulation and the time elapsed
    print('Current Player:', player, '| Simulated Games:', games, '| Search Time:', datetime.datetime.utcnow() - begin)
    # Pick the action with the highest percentage of wins; unlike selection during
    # search, the final choice uses no exploration bonus
    percent_wins, action = max(
      (self.wins.get((player, s), 0) / self.plays.get((player, s), 1), a)
      for s, a in states_actions)
    # Display the stats for each possible play
    print('Action: Win Rate (Wins / Plays)')
    for x in sorted(
        ((100 * self.wins.get((player, s), 0) / self.plays.get((player, s), 1),
          self.wins.get((player, s), 0), self.plays.get((player, s), 0), a)
         for s, a in states_actions),
        reverse=True):
      print('{3}: {0:.2f}% ({1} / {2})'.format(*x))
    print('Maximum Search Depth:', self.max_depth)
    return action
  # Plays out a pseudorandom game from the current position and updates the statistics tables
  def run_simulation(self):
    visited_states = set()
    history_copy = deepcopy(self.history)  # Keep a separate copy of the canonical game history
    state = history_copy[-1]
    player = self.board.current_player(state)
    expand = True  # At most one new node is added to the tree per simulation
    for t in range(1, self.max_moves + 1):
      legal = self.board.legal_actions(state)
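      # Selection rule (UCT): over the successor states s of the current node, pick
      # the one maximising wins(s)/plays(s) + c * sqrt(ln(total plays) / plays(s)).
      # For example, with 3 wins from 8 plays and 20 plays in total:
      # 3/8 + sqrt(2) * sqrt(ln(20)/8) ~= 0.375 + 0.865 = 1.24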
      # Candidate successor states for each legal action, and their play counts
      states_actions = [(self.board.step(state, a), a) for a in legal]
      play_counts = [self.plays.get((player, self.board.hash(s))) for s, a in states_actions]
      if all(play_counts):
        # If we have stats on all of the legal moves, use Upper Confidence Bound 1
        # applied to trees (UCT) to choose the next action
        log_total = log(sum(play_counts))
        value, action, state = max(
          (self.wins[(player, self.board.hash(s))] / n + self.c * sqrt(log_total / n), a, s)
          for (s, a), n in zip(states_actions, play_counts))
      else:
        # Otherwise choose a random move
        state, action = choice(states_actions)
      history_copy.append(state)
      hashable_state = self.board.hash(state)
      # Expansion: initialise stats for the moving player (if necessary)
      if expand and (player, hashable_state) not in self.plays:
        expand = False
        self.plays[(player, hashable_state)] = 0
        self.wins[(player, hashable_state)] = 0
        self.max_depth = max(t, self.max_depth)
      visited_states.add((player, hashable_state))
      player = self.board.current_player(state)  # The next player to move
      winner = self.board.winner(state)
      if winner:
        break  # Terminal state reached: a win for either player, or a draw (-1)
    # Backpropagation: update stats for every expanded node visited this simulation
    for v_player, v_state in visited_states:  # Contains hashed states
      if (v_player, v_state) not in self.plays:
        continue  # Skip states that were never expanded into the tree
      self.plays[(v_player, v_state)] += 1
      if v_player == winner:
        self.wins[(v_player, v_state)] += 1
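# Running this script plays one game of self-play, printing the board after every
# move; since perfect tic-tac-toe play draws, a well-searched self-play game should
# usually end in a draw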
# Play one game
if __name__ == '__main__':
  env = Board()
  mcts = MCTS(env)
  state, winner = env.reset(), 0
  env.pretty_print(state)
  while not winner:
    mcts.update(state)
    action = mcts.get_action()
    state = env.step(state, action)
    env.pretty_print(state)
    winner = env.winner(state)
  if winner == -1:
    print('Draw')
  else:
    print('Winner: Player', winner)