Last active
September 22, 2019 23:48
-
-
Save nowl/4752490 to your computer and use it in GitHub Desktop.
Monte Carlo Tree Search implemented in Common Lisp to play Tic-Tac-Toe. Adapted from the example Python code found on www.mcts.ai.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
;; adapted to lisp from the python code found here: | |
;; http://www.mcts.ai/?q=code/python | |
;;; Game-state protocol.  The search code in this file talks to a game
;;; state through these generic functions; the descriptions below reflect
;;; how the OXO-STATE methods in this file implement them.

(defgeneric clone (state)
  (:documentation "Return a copy of STATE."))

(defgeneric get-moves (state)
  (:documentation "Return a list of the moves available in STATE."))

(defgeneric get-random-move (state)
  (:documentation
   "Return one move available in STATE, or NIL when there is none."))

(defgeneric do-move (state move)
  (:documentation "Apply MOVE to STATE, modifying STATE in place."))

(defgeneric player-just-moved (state)
  (:documentation "Return the player who made the latest move in STATE."))

(defgeneric get-result (state player-just-moved)
  (:documentation
   "Return the outcome of STATE from the viewpoint of PLAYER-JUST-MOVED."))

(defgeneric print-state (state)
  (:documentation "Print STATE to standard output."))
(defstruct oxo-state
  "A tic-tac-toe position."
  ;; Number of the player who made the latest move.  It starts at 2, and
  ;; DO-MOVE flips it with (- 3 n), so player 1 is the first to move.
  (player-just-moved 2 :type fixnum)
  ;; Two-dimensional grid, 3x3 by default.  0 marks an empty cell; DO-MOVE
  ;; stores the moving player's number in the cell it takes.
  (board (make-array (list 3 3) :initial-element 0) :type array))
(defmethod player-just-moved ((state oxo-state))
  "Return the value of STATE's PLAYER-JUST-MOVED slot."
  (with-accessors ((mover oxo-state-player-just-moved)) state
    mover))
(defmethod clone ((state oxo-state))
  "Return a fresh OXO-STATE holding a copy of STATE's board and the same
PLAYER-JUST-MOVED value.  The copy shares no storage with STATE, so moves
played on one do not affect the other."
  (let* ((source (oxo-state-board state))
         ;; Allocate the copy from the source's own dimensions rather than
         ;; relying on the default 3x3 board of MAKE-OXO-STATE, so boards
         ;; of any shape are cloned faithfully.
         (copy (make-array (array-dimensions source)
                           :element-type (array-element-type source))))
    (dotimes (k (array-total-size source))
      (setf (row-major-aref copy k) (row-major-aref source k)))
    (make-oxo-state :player-just-moved (oxo-state-player-just-moved state)
                    :board copy)))
(defmethod print-state ((state oxo-state))
  "Print STATE's board to standard output, one text line per board row.
The first array index is treated as the row (y) and the second as the
column (x), matching how GET-MOVES and DO-MOVE address cells with
(aref board y x).  Returns NIL."
  (let ((board (oxo-state-board state)))
    (destructuring-bind (h w) (array-dimensions board)
      (dotimes (y h)
        ;; All cells of row Y on a single line, left to right.
        (dotimes (x w)
          (format t "~a" (aref board y x)))
        (format t "~%")))))
(defmethod get-moves ((state oxo-state))
  "Return a fresh list of (X Y) moves, one for every board cell holding 0.
X is the second array index and Y the first.  The list runs from the last
cell in row-major order back to the first."
  (let ((board (oxo-state-board state)))
    (destructuring-bind (h w) (array-dimensions board)
      ;; Walk the grid backwards and collect in that order, so the highest
      ;; row/column comes first in the result.
      (loop for y from (1- h) downto 0
            nconc (loop for x from (1- w) downto 0
                        when (zerop (aref board y x))
                          collect (list x y))))))
(defmethod get-random-move ((state oxo-state))
  "Return a move (X Y) chosen uniformly at random among the empty cells of
STATE's board, or NIL when the board has no empty cell.  X is the second
array index and Y the first, as in GET-MOVES.  No list of moves is built."
  (let* ((board (oxo-state-board state))
         (cell-count (array-total-size board))
         (empty-count (loop for k below cell-count
                            count (zerop (row-major-aref board k)))))
    (when (plusp empty-count)
      ;; Pick the index of the winner among the empty cells first, then
      ;; walk to it.  Scanning from a random start position instead would
      ;; favour empty cells that follow a run of occupied ones.
      (let ((skip (mt19937:random empty-count))
            (w (array-dimension board 1)))
        (dotimes (k cell-count)
          (when (zerop (row-major-aref board k))
            (when (zerop skip)
              ;; Row-major index K = Y * W + X on a two-dimensional array.
              (multiple-value-bind (y x) (floor k w)
                (return-from get-random-move (list x y))))
            (decf skip)))))))
(defmethod get-result ((state oxo-state) player-just-moved)
  "Score STATE from the viewpoint of PLAYER-JUST-MOVED.

Returns 1.0 when the first completed line found belongs to
PLAYER-JUST-MOVED and 0 when it belongs to the other player.  Lines are
examined in the order rows, columns, diagonals.  When no line is complete,
returns 0.5 if no moves remain and 0 otherwise.

The cell indices 0, 1 and 2 are hard-coded, so the board is expected to be
3x3."
  (let ((board (oxo-state-board state)))
    (flet ((score-line (a b c)
             ;; Three equal cells only form a line when a player owns them.
             ;; Three empty cells (all 0) must not count, otherwise an
             ;; empty line would end the scan early and hide a real win
             ;; elsewhere on the board.
             (when (and (/= a 0) (= a b c))
               (return-from get-result
                 (if (= player-just-moved a) 1.0 0)))))
      (destructuring-bind (h w) (array-dimensions board)
        ;; rows: first index fixed
        (dotimes (y h)
          (score-line (aref board y 0) (aref board y 1) (aref board y 2)))
        ;; columns: second index fixed
        (dotimes (x w)
          (score-line (aref board 0 x) (aref board 1 x) (aref board 2 x))))
      ;; the two diagonals
      (score-line (aref board 0 0) (aref board 1 1) (aref board 2 2))
      (score-line (aref board 0 2) (aref board 1 1) (aref board 2 0))))
  ;; No completed line: a full board is a draw, anything else is undecided.
  (if (null (get-moves state))
      0.5
      0))
(defmethod do-move ((state oxo-state) move)
  "Play MOVE, a list (X Y), on STATE in place.  The player to move is the
one who did not move last, computed as (- 3 player-just-moved).  That
player becomes PLAYER-JUST-MOVED and their number is written to the board
cell at row Y, column X.  Asserts that the cell is empty (0) beforehand.
Returns the number of the player who moved."
  (destructuring-bind (x y) move
    (let ((board (oxo-state-board state)))
      (assert (= (aref board y x) 0))
      (let ((mover (- 3 (oxo-state-player-just-moved state))))
        ;; Update the turn marker first, then claim the cell.
        (setf (oxo-state-player-just-moved state) mover)
        (setf (aref board y x) mover)))))
(defun random-choice (list)
  "Return an element of LIST at a position drawn with MT19937:RANDOM, or
NIL when LIST is empty."
  (unless (null list)
    (nth (mt19937:random (length list)) list)))
(defstruct node
  "One node of the search tree built by UCT."
  ;; Move that led from the parent to this node (NIL at the root).
  (move nil)
  ;; Parent node, NIL at the root.
  (parent-node nil)
  ;; Children created so far.
  (child-nodes '() :type list)
  ;; Sum of the results passed to UPDATE-NODE for this node.
  (wins 0)
  ;; Number of times UPDATE-NODE was called on this node.
  (visits 0 :type fixnum)
  ;; Player who had just moved in the state this node was created from.
  (player-just-moved nil)
  ;; Moves from this node that have not been expanded into children yet.
  (remaining-moves nil))
(defun create-node (move parent state)
  "Make a search-tree node for the position STATE, reached from PARENT by
playing MOVE.  Pass NIL for both MOVE and PARENT to create a root node.
The node records who just moved in STATE and the moves still open there."
  (make-node :move move
             :parent-node parent
             ;; Go through the generic PLAYER-JUST-MOVED rather than the
             ;; OXO-STATE slot accessor, so any state type implementing
             ;; the protocol can be searched.
             :player-just-moved (player-just-moved state)
             :remaining-moves (get-moves state)))
(defun update-node (node result)
  "Record one more visit on NODE and add RESULT to its win total.
Returns the new win total."
  (with-accessors ((visits node-visits)
                   (wins node-wins))
      node
    (setf visits (1+ visits)
          wins (+ wins result))))
(defparameter *uctk* 1
  "Multiplier applied to the square-root term in UCB1-FORMULA.")
(defun ucb1-formula (node)
  "Return the UCB1 score of NODE as a child of its parent:

    wins/visits + *UCTK* * sqrt(2 * ln(parent visits) / visits)

NODE must have been visited at least once, because the score divides by
its visit count."
  (let* ((wins (node-wins node))
         (visits (node-visits node))
         (parent (node-parent-node node))
         ;; UCB1 takes the logarithm of the total number of plays, which
         ;; is the parent's visit count, not the child's own.  A node with
         ;; no parent falls back to its own count.
         (total-visits (if parent (node-visits parent) visits)))
    (+ (/ wins visits)
       (* *uctk* (sqrt (* 2 (/ (log total-visits) visits)))))))
(defun uct-select-child (node)
  "Return the child of NODE with the highest UCB1-FORMULA score, or NIL
when NODE has no children.  Every child must have been visited at least
once, since UCB1-FORMULA divides by the visit count."
  ;; One linear pass that scores each child exactly once.  This avoids
  ;; destructively sorting the whole child list just to read its head.
  (let ((best nil)
        (best-score nil))
    (dolist (child (node-child-nodes node) best)
      (let ((score (ucb1-formula child)))
        (when (or (null best-score) (> score best-score))
          (setf best child
                best-score score))))))
(defun uct (root-state itermax)
  "Run ITERMAX iterations of UCT search from ROOT-STATE and return the move
of the most-visited child of the root.  Returns NIL when the root has no
children, i.e. when ITERMAX is zero or ROOT-STATE offers no moves.
ROOT-STATE itself is never modified; each iteration works on a clone."
  (let ((root-node (create-node nil nil root-state)))
    (loop repeat itermax do
      (let ((node root-node)
            (state (clone root-state)))
        ;; select: descend while the node is fully expanded and has children
        (loop while (and (node-child-nodes node)
                         (null (node-remaining-moves node)))
              do (setf node (uct-select-child node))
                 (do-move state (node-move node)))
        ;; expand: turn one untried move into a new child
        (let ((move (random-choice (node-remaining-moves node))))
          (when move
            (do-move state move)
            (let ((child (create-node move node state)))
              (push child (node-child-nodes node))
              (setf (node-remaining-moves node)
                    (delete move (node-remaining-moves node) :test #'equal))
              (setf node child))))
        ;; rollout: play random moves until none is left
        (loop for move = (get-random-move state)
              while move
              do (do-move state move))
        ;; backpropagate: score the final state from each node's viewpoint
        (loop while node
              do (update-node node
                              (get-result state (node-player-just-moved node)))
                 (setf node (node-parent-node node)))))
    #|
    (loop for child in (node-child-nodes root-node) do
      (format t "win percent: ~a%, ~a visits, ~a move~%"
              (* 100 (/ (node-wins child)
                        (node-visits child)))
              (node-visits child)
              (node-move child)))
    |#
    ;; Pick the most-visited child with a linear scan; guard against a
    ;; root without children instead of calling NODE-MOVE on NIL.
    (let ((best nil))
      (dolist (child (node-child-nodes root-node))
        (when (or (null best)
                  (> (node-visits child) (node-visits best)))
          (setf best child)))
      (and best (node-move best)))))
(defparameter *root-state*
  (make-oxo-state :board (make-array '(3 3) :initial-element 0))
  "Starting position: a 3x3 board with every cell set to 0.")
(defun play (state player-1-iters player-2-iters)
  "Play one game starting from a clone of STATE, choosing every move with
UCT.  Player 1 searches with PLAYER-1-ITERS iterations and player 2 with
PLAYER-2-ITERS.  Prints the outcome and returns 1 or 2 for the winning
player, or 0 for a tie."
  (declare (optimize (debug 3)))
  (let ((game (clone state)))
    (flet ((undecided-p ()
             ;; The game goes on while moves remain and GET-RESULT reports
             ;; 0 for both players.
             (and (get-moves game)
                  (= 0 (get-result game 1))
                  (= 0 (get-result game 2)))))
      (loop while (undecided-p)
            do (do-move game
                        ;; Player 1 is to move when player 2 moved last.
                        (uct game
                             (if (= (player-just-moved game) 2)
                                 player-1-iters
                                 player-2-iters)))))
    (cond ((= 1 (get-result game 1))
           (format t "player 1 wins~%")
           1)
          ((= 1 (get-result game 2))
           (format t "player 2 wins~%")
           2)
          (t
           (format t "tie~%")
           0))))
(defun play-times (state times player-1-iters player-2-iters)
  "Play TIMES games from STATE with PLAY, then print the percentage of games
won by each player and the percentage of ties.  PLAYER-1-ITERS and
PLAYER-2-ITERS are passed through to PLAY.  Returns NIL."
  (let ((player-1-wins 0)
        (player-2-wins 0)
        (ties 0))
    (loop repeat times do
      (ecase (play state player-1-iters player-2-iters)
        (1 (incf player-1-wins))
        (2 (incf player-2-wins))
        (0 (incf ties))))
    (flet ((percentage (tally)
             ;; With no games played there is nothing to divide by, so
             ;; report 0 rather than signalling a division by zero.
             (if (plusp times)
                 (* 100 (/ (float tally) times))
                 0.0)))
      (format t "player 1 win percentage ~a%~%" (percentage player-1-wins))
      (format t "player 2 win percentage ~a%~%" (percentage player-2-wins))
      ;; Ties were counted but never shown; report them as well.
      (format t "tie percentage ~a%~%" (percentage ties)))))
;; test with something like: | |
;; (play-times *root-state* 100 10 5000) | |
;; or to give the first player an advantage | |
;; (play-times *root-state* 100 5000 10) |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment