Last active
August 17, 2017 14:28
-
-
Save drscotthawley/ca5e06767c77a5d7c09398dddcfe358a to your computer and use it in GitHub Desktop.
My implementation of a temporal difference policy method for playing Tic Tac Toe, after seeing Shlomo Bauer speak at the Brentwood A.I. meetup.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#! /usr/bin/env python3 | |
# Uses temporal difference policy method | |
# See, e.g., http://www.cs.dartmouth.edu/~lorenzo/teaching/cs134/Archive/Spring2009/final/PengTao/final_report.pdf | |
# In this code, X plays randomly whereas O 'learns'. (Feel free to change that) | |
# Thus we expect O to outperform X eventually | |
# Author: Scott Hawley http://drscotthawley.github.io | |
# Unlimited License: Feel free to use any or all of this code however you like. | |
import numpy as np | |
LENGTH = 3 # one dim of game board. board is LENGTH x LENGTH | |
EMPTY_VAL = 0 # value to represent empty cell | |
X_VAL = 1 # value to represent 'X' | |
O_VAL = 2 # value to represent 'O' | |
TIE_VAL = 3 # value for tie game | |
DEBUG = 0 # higher integer value means more messages | |
SECRET_BACKDOOR = "JOSHUA" # obligatory humor. does nothing. | |
def debug_msg(msg, level=1): # logging routine | |
if (level <= DEBUG) and (DEBUG != 0): | |
print(msg) | |
def enumerate_states(): | |
# generate all possible configurations | |
# from https://stackoverflow.com/questions/7466429/generate-a-list-of-all-unique-tic-tac-toe-boards | |
numstates = (O_VAL-EMPTY_VAL+1)**(LENGTH*LENGTH) | |
debug_msg("numstates = "+str(numstates)) | |
states = np.zeros( (numstates, LENGTH, LENGTH), dtype=np.int8 ) | |
for s in range(numstates): | |
c = s | |
for bi in range( LENGTH * LENGTH): # bi = board index | |
val = c % LENGTH | |
i = int(bi / LENGTH) | |
j = int(bi % LENGTH) | |
states[s][i][j] = val | |
c //= 3 | |
return states | |
def is_end_state(board): # board is a LENGTH x LENGTH array, i.e. one particular state | |
# returns 0 for not end state, X_VAL if X wins, and O_VAL if O wins, TIE_VAL for tie game | |
# check rows | |
for i in range(0,LENGTH): | |
win = True | |
for j in range(1,LENGTH): | |
if (board[i][j] != board[i][0] ): | |
win = False | |
break; | |
if win: | |
debug_msg(" row win along i = "+str(i)) | |
return board[i][0] | |
# check columns | |
for j in range(0,LENGTH): | |
win = True | |
for i in range(1,LENGTH): | |
if (board[i][j] != board[0][j] ): | |
win = False | |
break; | |
if win: | |
debug_msg("column win along j = "+str(j)) | |
return board[0][j] | |
# check 'forward' diagonal | |
win = True | |
for i in range(1,LENGTH): | |
if (board[i][i] != board[0,0]): | |
win = False | |
break; | |
if win: | |
debug_msg("forward diag win") | |
return board[0][0] | |
# check 'backward' diagonal | |
win = True | |
for i in range(1,LENGTH): | |
if (board[i][LENGTH-1-i] != board[0,LENGTH-1]): | |
win = False | |
break; | |
if win: | |
debug_msg("backward diag win") | |
return board[0][LENGTH-1] | |
# no winner | |
if (EMPTY_VAL in board): # board's not full, and no winner yet | |
return 0 | |
return TIE_VAL # no winner; tie game | |
def make_random_move(board, player_val): | |
free_moves = np.where( EMPTY_VAL == board ) | |
debug_msg("free moves = "+str(free_moves), level=4) | |
num_moves = free_moves[0].shape[0] | |
debug_msg("num_moves = "+str(num_moves), level=4) | |
if (num_moves) > 0: | |
choice_ind = np.random.randint(num_moves) | |
debug_msg("choice_ind = "+str(choice_ind)) | |
i = free_moves[0][choice_ind] | |
j = free_moves[1][choice_ind] | |
board[i][j] = player_val | |
else: | |
debug_msg("No moves possible",level=3) | |
return board | |
def make_smart_move(board, player_val, states, state_vals): | |
free_moves = np.where( EMPTY_VAL == board ) | |
debug_msg("free moves = "+str(free_moves), level=4) | |
num_moves = free_moves[0].shape[0] | |
debug_msg("num_moves = "+str(num_moves), level=4) | |
if (num_moves) > 0: | |
# get the state values of the resulting boards corresponding to each possible move | |
move_state_vals = [] | |
move_state_choices = [] | |
move_boards = [] | |
for choice_ind in range(num_moves): | |
new_board = np.copy(board) | |
i = free_moves[0][choice_ind] | |
j = free_moves[1][choice_ind] | |
new_board[i][j] = player_val | |
move_boards.append(new_board) | |
new_state_number = get_state_number(new_board) | |
move_state_choices.append( new_state_number ) | |
move_state_vals.append(state_vals[new_state_number]) | |
# now choose the state that has the highest value | |
debug_msg(" move_state_vals = "+str(move_state_vals), level=1) | |
choice_ind = np.argmax( move_state_vals ) | |
debug_msg("smart choice_ind = "+str(choice_ind), level=1) | |
#print("which corresponds to board ",move_boards[choice_ind]) | |
if (np.argmax( move_state_vals ) == np.argmin( move_state_vals )): # no clear choice | |
debug_msg(" so making random choice", level=1) | |
make_random_move(board, player_val) | |
else: | |
board = move_boards[choice_ind] | |
else: | |
debug_msg("No moves possible",level=3) | |
return board | |
def get_state_number(board): # turns the board into a hash using base-3 arithmetic | |
numcells = LENGTH*LENGTH | |
hash_loc = 0 | |
for i in range(numcells): | |
hash_loc += np.ndarray.flatten(board)[i] * 3**(numcells-1-i) # treat array as base-3 digits | |
return hash_loc | |
# the crux of the method. modify the values associate with states that appeared | |
# in the game's history, by essentially propagating the final score backwards in time. | |
def score_history(state_vals, end_state, player, history, alpha = 0.5): | |
if (end_state == player ): | |
end_score = 1 | |
elif (end_state == TIE_VAL): | |
end_score = 0 | |
else: | |
end_score = -1 | |
#print(" end_score = ",end_score) | |
histlen = len(history) | |
h = histlen - 1 | |
s = history[h] | |
state_vals[ s ] = end_score # score the game | |
#print("Scoring: h, s, state_vals[s] = ",h,s, state_vals[s]) | |
for h in range(histlen-2, -1, -1): # propagate backward in time | |
sprime = history[h+1] | |
s = history[h] | |
state_vals[s] = state_vals[s] + alpha * (state_vals[sprime] - state_vals[s]) | |
#print("Scoring: h, s, state_vals[s] = ",h,s, state_vals[s]) | |
return state_vals | |
############## MAIN CODE FOLLOWS ################### | |
np.random.seed(1) # uncomment to get the same outcomes each time | |
# Storage for states and values. Note: the indices in these two arrays may not refer to the same objects, but that's ok. | |
states = enumerate_states() | |
o_state_vals = np.zeros(3**(LENGTH*LENGTH)) | |
# x_state_vals = np.zeros(222222222) # You can have X learn too, but then you get TIE GAME almost all the time | |
alpha = 0.9 # learning rate, >= 0. setting this to zero means O plays randomly too | |
outerloops = 2 | |
for do_it_again in range(outerloops): # cuz this is fun! in the later loops, the system has already been trained | |
numgames = 10000 # https://www.youtube.com/watch?v=Et38wtr7ih4 | |
x_wins = 0 # keeping track of number of wins by x & o, and tie games | |
o_wins = 0 | |
ties = 0 | |
for game in range(numgames): | |
if (do_it_again == outerloops-1) and (game == numgames-10): # show last games | |
DEBUG = 1 | |
debug_msg("===================================",level=1) | |
print("Game number ",game," of ",numgames,": ", end="") | |
end_state = 0 | |
turn = 0 | |
board = np.copy(states[0]) # empty board | |
history = [] | |
player = X_VAL | |
while ((0 == end_state) and (turn < LENGTH**2)): # while game's not over | |
turn += 1 | |
debug_msg('----------------------') | |
debug_msg("turn = "+str(turn)+", player = "+str(player)) | |
if (X_VAL == player): | |
board = make_random_move(board,player) | |
#board = make_smart_move(board, player, states, x_state_vals) # if both X and O learn, you get TIE GAME all the time | |
else: | |
board = make_smart_move(board, player, states, o_state_vals) | |
state_num = get_state_number(board) | |
history.append(state_num) | |
debug_msg(" state_num = "+str(state_num)+", new board = ",level=1) | |
debug_msg(str(board),level=1) | |
debug_msg("",level=1) | |
if (X_VAL == player): # switch players for next time | |
player = O_VAL | |
else: | |
player = X_VAL | |
end_state = is_end_state(board) | |
# Game over, man! | |
debug_msg("end_state = "+str(end_state)) | |
if (X_VAL == end_state): | |
print("X WINS!") | |
x_wins += 1 | |
elif (O_VAL == end_state): | |
print(" O WINS!") | |
o_wins += 1 | |
else: | |
print(" TIE GAME") | |
ties += 1 | |
debug_msg("history = "+str(history)) | |
o_state_vals = score_history(o_state_vals, end_state, O_VAL, history, alpha=alpha) # adjust state vals | |
#x_state_vals = score_history(x_state_vals, end_state, X_VAL, history, alpha=alpha) # adjust state vals | |
print("Final stats: x_wins, o_wins, ties = ",x_wins, o_wins, ties,". O won ",o_wins*100.0/numgames,"% of the last ",numgames," games") |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment