Last active
August 16, 2019 06:39
-
-
Save kastnerkyle/9db1e88569c4358f11304dcdce05c9ab to your computer and use it in GitHub Desktop.
MCTS tic-tac-toe: pit two MCTS players against each other, or play against one yourself
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Based on tutorial from https://jeffbradberry.com/posts/2015/09/intro-to-monte-carlo-tree-search/ | |
# Author: Kyle Kastner | |
# License: BSD 3-Clause | |
from __future__ import print_function

import argparse
import copy
import random
import sys
import time

import numpy as np
# Shared module-level RNG with a fixed seed: the sequence of random draws
# (rollout move shuffles, first-player choice) is the same on every run,
# although how many draws a time-limited search consumes can vary.
global_random = np.random.RandomState(seed=1989)
class Board(object):
    """Tic-tac-toe rules over an immutable tuple of 9 cell strings.

    A cell holds " " when empty, otherwise one of ``player_symbols``.
    Players are numbered 1 and 2: player 1 plays ``player_symbols[0]`` and
    is to move whenever both symbols appear equally often, so player 1 always
    moves first.
    """

    # Every row, column and diagonal, checked in this order by is_complete.
    WIN_LINES = (
        (0, 1, 2), (3, 4, 5), (6, 7, 8),  # horizontal
        (0, 3, 6), (1, 4, 7), (2, 5, 8),  # vertical
        (0, 4, 8), (2, 4, 6),             # diagonal
    )

    def __init__(self):
        self.player_symbols = ["X", "O"]

    def start(self):
        """Return the empty starting board as a 9-tuple of " "."""
        return (" ",) * 9

    def current_player(self, board):
        """Return the number (1 or 2) of the player whose turn it is on ``board``."""
        first_count = board.count(self.player_symbols[0])
        second_count = board.count(self.player_symbols[1])
        return 1 if first_count == second_count else 2

    def is_available(self, board, move):
        """Return True if cell index ``move`` is empty on ``board``."""
        return board[move] == " "

    def legal_moves(self, board):
        """Return the indices of all empty cells, in ascending order."""
        return [move for move in range(len(board)) if self.is_available(board, move)]

    def next_state(self, board, move):
        """Return a new board tuple with the current player's symbol at ``move``.

        ``board`` itself is not modified.
        """
        new_board = list(board)
        new_board[move] = self.player_symbols[self.current_player(board) - 1]
        return tuple(new_board)

    def is_complete(self, board_history):
        """Classify the most recent board in ``board_history``.

        Returns:
            1 if player 1 has three in a line, 2 if player 2 does,
            -1 for a tie (board full, no line), 0 if the game continues.
        """
        board = board_history[-1]
        for a, b, c in self.WIN_LINES:
            if board[a] != " " and board[a] == board[b] == board[c]:
                return 1 if board[a] == self.player_symbols[0] else 2
        # No winning line: a full board is a tie, otherwise play continues.
        return -1 if " " not in board else 0

    def draw(self, board):
        """Print ``board``, labelling each empty cell with its move index.

        Cells 6-8 form the top printed row and cells 0-2 the bottom row.
        """
        # Every cell renders 3 characters wide, e.g. "(4)" or " X ".
        cells = ["({})".format(i) if cell == " " else " {} ".format(cell)
                 for i, cell in enumerate(board)]
        # Pipes at columns 5 and 11, aligned with the " | " separators of the
        # cell rows (the scraped copy had collapsed this whitespace).
        spacer = "     |     |"
        divider = "-" * 16
        for row_start in (6, 3, 0):
            print(spacer)
            print(" " + " | ".join(cells[row_start:row_start + 3]))
            print(spacer)
            if row_start:
                print(divider)
        print("")
class MCTS(object):
    """Monte Carlo tree search player for a ``Board``-like rules object.

    Statistics live in ``plays`` and ``rewards``, keyed by ``(player, state)``.
    ``player`` is the player who made the move producing ``state``, so a
    node's reward measures outcomes good for the player choosing that move.
    """

    def __init__(self, board, state_history=None, runtime_s=10, horizon=100,
                 ucb_weight=1.4,
                 verbose=False, tie_reward=0.5):
        """
        Args:
            board: rules object providing ``start``, ``current_player``,
                ``legal_moves``, ``next_state``, ``is_complete`` and a
                ``player_symbols`` list.
            state_history: list of positions played so far. It is stored by
                reference, not copied. Defaults to ``[board.start()]``.
            runtime_s: wall-clock seconds of search per ``estimate`` call.
            horizon: a single rollout simulates at most ``horizon + 1`` moves.
            ucb_weight: exploration constant in the UCB1 score.
            verbose: print simulation statistics from ``estimate``.
            tie_reward: reward credited to both players when a rollout ends
                in a tie. 0.5 keeps a sure win strictly preferred over a sure
                tie; 1.0 reproduces the previous behaviour of scoring ties
                exactly like wins.
        """
        self.board = board
        self.runtime_s = runtime_s
        self.horizon = horizon
        self.ucb_weight = ucb_weight
        self.tie_reward = tie_reward
        self.plays = {}
        self.rewards = {}
        self.verbose = verbose
        # Deepest step at which a rollout expanded a node; reset by estimate().
        # Initialised here so rollouts may run before the first estimate().
        self.max_depth = 0
        if state_history is None:
            self.state_history = [board.start()]
        else:
            self.state_history = state_history

    def update(self, state):
        """Append ``state`` as the newest position actually played."""
        self.state_history.append(state)

    def estimate(self, policy="uct"):
        """Search for ``runtime_s`` seconds and return the best move for the player to move.

        Args:
            policy: "uct" (UCB1 selection inside rollouts) or "random".

        Returns:
            The move with the highest reward/plays ratio among the current
            position's children. If exactly one move is legal it is returned
            without searching; with no legal moves, None is returned.

        Raises:
            ValueError: if ``policy`` is not recognised.
        """
        rollouts = {"uct": self.rollout_uct, "random": self.rollout_random}
        if policy not in rollouts:
            raise ValueError("Unknown value policy={}".format(policy))
        rollout = rollouts[policy]

        self.max_depth = 0
        state = self.state_history[-1]
        player = self.board.current_player(state)
        moves = self.board.legal_moves(state)
        if not moves:
            return None
        if len(moves) == 1:
            return moves[0]

        games = 0
        start_time = time.time()
        last_t = start_time
        print("Player {}({})'s turn".format(player, self.board.player_symbols[player - 1]))
        while time.time() - start_time < self.runtime_s:
            this_t = time.time()
            if this_t - last_t > 1:
                last_t = this_t
                print("Calculating...")
            rollout()
            games += 1
        print("")
        elapsed = time.time() - start_time

        def win_rate(child):
            # Unvisited children score 0 (reward 0 over a default of 1 play).
            key = (player, child)
            return self.rewards.get(key, 0) / float(self.plays.get(key, 1))

        moves_states = [(m, self.board.next_state(state, m)) for m in moves]
        if self.verbose:
            print("Number of sim games {}, total time {}".format(games, elapsed))
        # Ties on win rate are broken by the larger move index.
        _, move = max((win_rate(S), m) for m, S in moves_states)
        if self.verbose:
            stats = sorted(
                ((100 * win_rate(S),
                  self.plays.get((player, S), 1),
                  self.rewards.get((player, S), 0),
                  m) for m, S in moves_states),
                reverse=True)
            for row in stats:
                print("{3}: {0:.2f}% ({2} / {1})".format(*row))
        return move

    def rollout_uct(self):
        """Run one rollout, picking by UCB1 wherever every child already has statistics."""
        self._rollout(use_uct=True)

    def rollout_random(self):
        """Run one rollout in which every move is chosen uniformly at random."""
        self._rollout(use_uct=False)

    def _select_uct(self, player, state, moves):
        """Return the UCB1-best child state of ``state``, or None if any child is unvisited."""
        plays, rewards = self.plays, self.rewards
        children = [(m, self.board.next_state(state, m)) for m in moves]
        if not all(plays.get((player, child)) for _, child in children):
            return None
        log_total = np.log(sum(plays[(player, child)] for _, child in children))
        scored = [
            (rewards[(player, child)] / float(plays[(player, child)])
             + self.ucb_weight * np.sqrt(log_total / float(plays[(player, child)])),
             m, child)
            for m, child in children
        ]
        return max(scored)[2]

    def _rollout(self, use_uct):
        """Simulate one game from the current position and update the statistics.

        At most one new node is added to the tree per rollout (the first
        unseen ``(player, state)`` reached). The simulation stops when the
        game ends, no legal move remains, or ``horizon + 1`` moves were played.
        """
        visited_states = {}  # insertion-ordered set of (player, state) keys
        states_copy = list(self.state_history)
        player = self.board.current_player(states_copy[-1])
        expand = True
        complete = 0  # stays 0 if the horizon cuts the simulation short
        for t in range(self.horizon + 1):
            moves = self.board.legal_moves(states_copy[-1])
            if not moves:
                break
            state = None
            if use_uct:
                state = self._select_uct(player, states_copy[-1], moves)
            if state is None:
                global_random.shuffle(moves)
                state = self.board.next_state(states_copy[-1], moves[0])
            states_copy.append(state)
            if expand and (player, state) not in self.plays:
                expand = False
                self.plays[(player, state)] = 0
                self.rewards[(player, state)] = 0
                if t > self.max_depth:
                    self.max_depth = t
            visited_states[(player, state)] = None
            player = self.board.current_player(states_copy[-1])
            complete = self.board.is_complete(states_copy)
            if complete != 0:
                break
        self._backpropagate(visited_states, complete)

    def _backpropagate(self, visited_states, complete):
        """Credit a rollout outcome to every tree node visited on its path.

        ``complete`` uses the ``Board.is_complete`` codes. A node whose mover
        won gets 1. A tie gives every node ``tie_reward``. A loss or an
        unfinished rollout (``complete == 0``) gives 0. Keys not yet in the
        tree are skipped.
        """
        for key in visited_states:
            if key not in self.plays:
                continue
            self.plays[key] += 1
            if key[0] == complete:
                self.rewards[key] += 1
            elif complete == -1:
                self.rewards[key] += self.tie_reward
def _read_line(prompt=None):
    """Print ``prompt`` (if given) and read one line from stdin.

    Uses ``raw_input`` on Python 2 and falls back to ``input`` on Python 3,
    where ``raw_input`` does not exist.
    """
    if prompt is not None:
        print(prompt)
    try:
        reader = raw_input  # noqa: F821 -- Python 2 only
    except NameError:
        reader = input
    return reader()


def _read_human_move(board, state):
    """Prompt until the human enters the index (0-8) of a free cell; return it as int.

    The text is checked against the allowed digits before ``int()`` is called,
    so non-numeric or out-of-range input re-prompts instead of crashing.
    """
    move_opts = ["0", "1", "2", "3", "4", "5", "6", "7", "8"]
    while True:
        move = _read_line("Human player {}, next move? (0-8)".format(1)).strip()
        if move in move_opts and board.is_available(state, int(move)):
            return int(move)


def _play_one_game(interactive, roundtime, verbose):
    """Play one game to completion, drawing the board after every move.

    In interactive mode the human is player 1 and MCTS is player 2; otherwise
    two MCTS players compete, and the symbols are randomly swapped.
    """
    board = Board()
    mcts1 = MCTS(board, runtime_s=roundtime, verbose=verbose)
    mcts2 = MCTS(board, runtime_s=roundtime, verbose=verbose)
    board_history = [board.start()]
    if not interactive:
        # Randomly switch player order (still called 1, 2 but symbols swap).
        if global_random.randint(18888) % 2:
            board.player_symbols = ["O", "X"]
        print("Player {} ({}), Player {}, ({})".format(1, board.player_symbols[0], 2, board.player_symbols[1]))

    while True:
        complete = board.is_complete(board_history)
        board.draw(board_history[-1])
        if complete != 0:
            break
        player = board.current_player(board_history[-1])
        if player == 1 and interactive:
            move = _read_human_move(board, board_history[-1])
        elif player == 1:
            move = mcts1.estimate(policy="uct")
        else:
            move = mcts2.estimate(policy="uct")
        next_board = board.next_state(board_history[-1], move)
        board_history.append(next_board)
        # Both searchers track the real game so their tree roots stay current.
        mcts1.update(next_board)
        mcts2.update(next_board)

    print("Game over!")
    if complete == 1:
        print("Player 1 wins.")
    elif complete == 2:
        print("Player 2 wins.")
    elif complete == -1:
        print("Tie game.")


def main():
    """Parse command-line flags and play games until the user declines another."""
    parser = argparse.ArgumentParser(description="Demo of player/MCTS for TIC TAC TOE. Use -i to play against the machine, or -a to make the machine play itself.")
    parser.add_argument("-i", "--interactive", action="store_true",
                        default=False,
                        help="Play against the computer yourself. For a beatable computer, try -r .1 or lower")
    parser.add_argument("-a", "--automatic", action="store_true",
                        default=False,
                        help="Play the computer against itself.")
    parser.add_argument("-r", "--roundtime", type=float, default=3,
                        help="Seconds of MCTS search per computer move.")
    parser.add_argument("-n", "--no_verbose", dest="verbose",
                        action="store_false", default=True,
                        help="Hide the per-move MCTS statistics.")
    args = parser.parse_args()
    if not args.automatic and not args.interactive:
        parser.print_help()
        sys.exit(1)
    if args.automatic and args.interactive:
        print("Must choose either -i or -a, not both!")
        sys.exit(1)

    while True:
        _play_one_game(args.interactive, args.roundtime, args.verbose)
        choice = _read_line("Play again (y/n)?")
        if str(choice) != "y":
            print("Thanks for playing!")
            break


if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment