# MCTS tic-tac-toe: watch two MCTS players play each other, or play against one yourself.
# Originally published as GitHub gist kastnerkyle/9db1e88569c4358f11304dcdce05c9ab
# (last active August 16, 2019).
# Based on tutorial from https://jeffbradberry.com/posts/2015/09/intro-to-monte-carlo-tree-search/
# Author: Kyle Kastner
# License: BSD 3-Clause
from __future__ import print_function
import random
import copy
import numpy as np
import time
import argparse
import sys
# Module-wide RNG with a fixed seed. It is shared by the MCTS rollouts and the
# player-order coin flip, so a run is repeatable for the same sequence of calls.
global_random = np.random.RandomState(seed=1989)
class Board(object):
    """Tic-tac-toe rules over a 9-cell board stored as a tuple.

    Cells are indexed 0-8 and hold either " " (empty) or one of
    ``player_symbols``. Player 1 plays ``player_symbols[0]`` and moves
    whenever both symbols appear equally often, so player 1 moves first.
    Player 2 plays ``player_symbols[1]``.
    """

    # Every row, column and diagonal that wins the game. The order is rows,
    # then columns, then diagonals; the first matching line decides the winner.
    _WIN_LINES = (
        (0, 1, 2), (3, 4, 5), (6, 7, 8),
        (0, 3, 6), (1, 4, 7), (2, 5, 8),
        (0, 4, 8), (2, 4, 6),
    )

    def __init__(self):
        self.player_symbols = ["X", "O"]

    def start(self):
        """Return the empty starting board: a tuple of 9 blank cells."""
        return tuple([" "] * 9)

    def current_player(self, board):
        """Return 1 if player 1 is to move on ``board``, otherwise 2."""
        first = sum(1 for cell in board if cell == self.player_symbols[0])
        second = sum(1 for cell in board if cell == self.player_symbols[1])
        return 1 if first == second else 2

    def is_available(self, board, move):
        """Return True if cell ``move`` (0-8) of ``board`` is empty."""
        return board[move] == " "

    def legal_moves(self, board):
        """Return the ascending list of empty cell indices on ``board``."""
        return [move for move in range(9) if self.is_available(board, move)]

    def next_state(self, board, move):
        """Return a new board tuple with the current player's symbol placed at ``move``.

        ``board`` itself is not modified.
        """
        new_board = list(board)
        new_board[move] = self.player_symbols[self.current_player(board) - 1]
        return tuple(new_board)

    def is_complete(self, board_history):
        """Classify the last board in ``board_history``.

        Returns:
            -1 tie (board full, nobody has three in a row)
             0 game continues
             1 player1 win
             2 player2 win
        """
        board = board_history[-1]
        for a, b, c in self._WIN_LINES:
            if board[a] == board[b] == board[c] and board[a] != " ":
                return 1 if board[a] == self.player_symbols[0] else 2
        if " " in board:
            return 0
        return -1

    def draw(self, board):
        """Print ``board`` as a 3x3 grid: top row is cells 6-8, bottom row is 0-2.

        Empty cells show their index in parentheses so a human can pick a move.
        """
        cells = ["({})".format(i) if cell == " " else " {} ".format(cell)
                 for i, cell in enumerate(board)]
        for row_number, (left, middle, right) in enumerate(((6, 7, 8), (3, 4, 5), (0, 1, 2))):
            if row_number > 0:
                print('----------------')
            print(' | |')
            print(' ' + cells[left] + ' | ' + cells[middle] + ' | ' + cells[right])
            print(' | |')
        print("")
class MCTS(object):
    """Monte Carlo tree search player driven by a ``Board``-like rules object.

    Statistics are keyed by ``(player, state)``, where ``player`` is the
    player who moved *into* ``state``:

    * ``plays[(player, state)]`` - simulations that passed through the node.
    * ``rewards[(player, state)]`` - simulations through it that ``player``
      won, plus ``tie_reward`` for each simulation that ended in a tie.

    Both tables persist across calls to ``estimate``, so search effort from
    earlier moves is reused as ``update`` advances the game.
    """

    def __init__(self, board, state_history=None, runtime_s=10, horizon=100,
                 ucb_weight=1.4,
                 verbose=False, tie_reward=1):
        """
        board: rules object providing ``start``, ``current_player``,
            ``legal_moves``, ``next_state``, ``is_complete`` and
            ``player_symbols`` (e.g. ``Board``).
        state_history: list of states played so far. Defaults to
            ``[board.start()]``. A supplied list is stored, not copied, and
            ``update`` appends to it.
        runtime_s: wall-clock seconds of simulation per ``estimate`` call.
        horizon: a single rollout plays at most ``horizon + 1`` moves.
        ucb_weight: exploration constant in the UCB1 selection score.
        verbose: print per-move search statistics after each ``estimate``.
        tie_reward: reward credited to every tracked node of a simulation
            that ends in a tie. The default of 1 scores a tie like a win.
        """
        self.board = board
        self.runtime_s = runtime_s
        self.horizon = horizon
        self.ucb_weight = ucb_weight
        self.verbose = verbose
        self.tie_reward = tie_reward
        self.plays = {}
        self.rewards = {}
        # Deepest rollout step at which a new node was expanded during the
        # latest search. It is reset by estimate(), and initialised here so the
        # rollout methods can also be called on their own.
        self.max_depth = 0
        if state_history is None:
            state_history = [board.start()]
        self.state_history = state_history

    def update(self, state):
        """Record ``state`` as the newest position of the real game."""
        self.state_history.append(state)

    def estimate(self, policy="uct"):
        """Search from the current position and return the chosen move.

        policy: "uct" (UCB1 tree policy) or "random" (uniform rollouts).
        Runs rollouts for ``runtime_s`` seconds, then picks the move whose
        resulting node has the best reward/play ratio. Unvisited nodes count
        as 0/1, and ties go to the higher move index. Returns None when there
        are no legal moves. Returns the only legal move without searching.
        Raises ValueError for an unknown ``policy``.
        """
        rollouts = {"uct": self.rollout_uct, "random": self.rollout_random}
        if policy not in rollouts:
            raise ValueError("Unknown value policy={}".format(policy))
        rollout = rollouts[policy]
        self.max_depth = 0
        state = self.state_history[-1]
        player = self.board.current_player(state)
        moves = self.board.legal_moves(state)
        if len(moves) == 0:
            return None
        if len(moves) == 1:
            return moves[0]
        games = 0
        start_time = time.time()
        last_t = start_time
        print("Player {}({})'s turn".format(player, self.board.player_symbols[player - 1]))
        while time.time() - start_time < self.runtime_s:
            this_t = time.time()
            # Progress heartbeat, at most about once per second.
            if this_t - last_t > 1:
                last_t = this_t
                print("Calculating...")
            rollout()
            games += 1
        print("")
        end_time = time.time() - start_time
        moves_states = [(m, self.board.next_state(state, m)) for m in moves]
        if self.verbose:
            print("Number of sim games {}, total time {}".format(games, end_time))
        _, move = max(
            (self.rewards.get((player, S), 0) / float(self.plays.get((player, S), 1)), p)
            for p, S in moves_states)
        if self.verbose:
            table = []
            for p, S in moves_states:
                wins = self.rewards.get((player, S), 0)
                n_plays = self.plays.get((player, S), 1)
                table.append((100 * wins / float(n_plays), n_plays, wins, p))
            for x in sorted(table, reverse=True):
                print("{3}: {0:.2f}% ({2} / {1})".format(*x))
        return move

    def rollout_uct(self):
        """Run one simulation using UCB1 selection where every child has statistics.

        Moves are chosen at random elsewhere.
        """
        self._rollout(self._choose_uct)

    def rollout_random(self):
        """Run one simulation using uniformly random moves throughout."""
        self._rollout(self._choose_random)

    def _choose_uct(self, player, state, moves):
        """Return the child state of ``state`` with the highest UCB1 score.

        Falls back to a random child if any child has no play count yet.
        """
        children = [(move, self.board.next_state(state, move)) for move in moves]
        plays, rewards = self.plays, self.rewards
        if all(plays.get((player, child)) for _, child in children):
            log_total = np.log(sum(plays[(player, child)] for _, child in children))
            scored = [(rewards[(player, child)] / float(plays[(player, child)])
                       + self.ucb_weight * np.sqrt(log_total / float(plays[(player, child)])),
                       move, child)
                      for move, child in children]
            return max(scored)[2]
        return self._choose_random(player, state, moves)

    def _choose_random(self, player, state, moves):
        """Return the state reached by a uniformly random move from ``moves``.

        Shuffles ``moves`` in place. ``player`` is unused; it is kept so both
        choosers share one signature.
        """
        global_random.shuffle(moves)
        return self.board.next_state(state, moves[0])

    def _rollout(self, choose_state):
        """Play one simulated game from the current position and record it.

        ``choose_state(player, state, moves)`` returns the state reached by
        the move it picks. Only the first previously unseen
        ``(player, state)`` node is added to the tables in each rollout.
        Deeper unseen nodes are visited but not tracked.
        """
        history = list(self.state_history)
        player = self.board.current_player(history[-1])
        visited = set()
        expand = True
        complete = 0  # stays 0 if the horizon ends the rollout before the game does
        for t in range(self.horizon + 1):
            state = choose_state(player, history[-1], self.board.legal_moves(history[-1]))
            history.append(state)
            key = (player, state)
            if expand and key not in self.plays:
                expand = False
                self.plays[key] = 0
                self.rewards[key] = 0
                if t > self.max_depth:
                    self.max_depth = t
            visited.add(key)
            player = self.board.current_player(state)
            complete = self.board.is_complete(history)
            if complete != 0:
                break
        self._backpropagate(visited, complete)

    def _backpropagate(self, visited, complete):
        """Credit one finished simulation to every tracked node in ``visited``.

        ``complete`` is the ``is_complete`` result: the winning player number,
        -1 for a tie, or 0 if the rollout stopped before the game ended.
        """
        for key in visited:
            if key not in self.plays:
                # Visited past the single expansion point; not part of the tree.
                continue
            self.plays[key] += 1
            if key[0] == complete:
                self.rewards[key] += 1
            if complete == -1:
                self.rewards[key] += self.tie_reward
if __name__ == "__main__":
    # raw_input only exists on Python 2; Python 3 renamed it to input.
    try:
        read_line = raw_input
    except NameError:
        read_line = input

    parser = argparse.ArgumentParser(description="Demo of player/MCTS for TIC TAC TOE. Use -i to play against the machine, or -a to make the machine play itself.")
    parser.add_argument("-i", "--interactive", action="store_true",
                        default=False,
                        help="Play against the computer yourself. For a beatable computer, try -r .1 or lower")
    parser.add_argument("-a", "--automatic", action="store_true",
                        default=False,
                        help="Play the computer against itself.")
    parser.add_argument("-r", "--roundtime", type=float, default=3,
                        help="Seconds of search per computer move (default: 3).")
    # store_false: passing -n sets args.no_verbose to False. Despite its name,
    # args.no_verbose holds the *verbose* setting.
    parser.add_argument("-n", "--no_verbose", action="store_false", default=True,
                        help="Do not print per-move search statistics.")
    args = parser.parse_args()
    automatic = args.automatic
    interactive = args.interactive
    if not automatic and not interactive:
        parser.print_help()
        sys.exit(1)
    if automatic and interactive:
        print("Must choose either -i or -a, not both!")
        sys.exit(1)

    move_opts = ["0", "1", "2", "3", "4", "5", "6", "7", "8"]
    while True:
        board = Board()
        roundtime = args.roundtime
        verbose = args.no_verbose
        mcts1 = MCTS(board, runtime_s=roundtime, verbose=verbose)
        mcts2 = MCTS(board, runtime_s=roundtime, verbose=verbose)
        board_history = [board.start()]
        if not interactive:
            # randomly switch player order (still called 1, 2 but symbols swap)
            if global_random.randint(18888) % 2:
                board.player_symbols = ["O", "X"]
        print("Player {} ({}), Player {}, ({})".format(1, board.player_symbols[0], 2, board.player_symbols[1]))
        # inner game loop
        while True:
            complete = board.is_complete(board_history)
            board.draw(board_history[-1])
            if complete != 0:
                # evaluate
                print("Game over!")
                if complete == 1:
                    print("Player 1 wins.")
                elif complete == 2:
                    print("Player 2 wins.")
                elif complete == -1:
                    print("Tie game.")
                break
            player = board.current_player(board_history[-1])
            if player == 1 and interactive:
                move = None
                while move is None:
                    print("Human player {}, next move? (0-8)".format(1))
                    answer = read_line().strip()
                    # Validate the text before int()/indexing so non-numeric or
                    # out-of-range input re-prompts instead of crashing.
                    if answer in move_opts and board.is_available(board_history[-1], int(answer)):
                        move = int(answer)
            elif player == 1:
                move = mcts1.estimate(policy="uct")
            else:
                move = mcts2.estimate(policy="uct")
            next_board = board.next_state(board_history[-1], move)
            board_history.append(next_board)
            # Both searchers track the real game so their trees stay in sync.
            mcts1.update(next_board)
            mcts2.update(next_board)
        print("Play again (y/n)?")
        choice = read_line()
        if choice.strip().lower() != "y":
            print("Thanks for playing!")
            break
# End of gist.