Skip to content

Instantly share code, notes, and snippets.

@nowl
Last active September 22, 2019 23:48
Show Gist options
  • Save nowl/4752490 to your computer and use it in GitHub Desktop.
Save nowl/4752490 to your computer and use it in GitHub Desktop.
Monte Carlo Tree Search implemented in Common Lisp to play Tic-Tac-Toe. Adapted from some of the example python code found on www.mcts.ai
;; adapted to lisp from the python code found here:
;; http://www.mcts.ai/?q=code/python
;;; Generic interface implemented by game states (see the OXO-STATE
;;; methods below for the Tic-Tac-Toe implementation).
(defgeneric do-move (state move)
  (:documentation
   "Apply MOVE to STATE and switch the player who has just moved."))

(defgeneric get-random-move (state)
  (:documentation
   "Return one randomly chosen legal move for STATE, or NIL if there is none."))

(defgeneric get-moves (state)
  (:documentation
   "Return a list of all moves currently available in STATE."))

(defgeneric get-result (state player-just-moved)
  (:documentation
   "Return the game result of STATE from the viewpoint of PLAYER-JUST-MOVED."))

(defgeneric clone (state)
  (:documentation
   "Return a copy of STATE that can be modified independently."))

(defgeneric print-state (state)
  (:documentation
   "Print a textual representation of STATE to standard output."))

(defgeneric player-just-moved (state)
  (:documentation
   "Return the player who made the most recent move in STATE."))
;; State of a Tic-Tac-Toe (noughts and crosses) game.
(defstruct oxo-state
;; Player (1 or 2) who made the most recent move.  Defaults to 2, so
;; DO-MOVE makes player 1 the first to move.
(player-just-moved 2 :type fixnum)
;; Playing grid, accessed as (aref board y x).  A cell holds 0 while it
;; is empty and the number of the player who marked it otherwise.
(board (make-array '(3 3) :initial-element 0)
:type array))
(defmethod player-just-moved ((state oxo-state))
  "Return the number of the player who made the last move in STATE."
  (oxo-state-player-just-moved state))
(defmethod clone ((state oxo-state))
  "Return a deep copy of STATE.

The copy gets a freshly allocated board with the same dimensions and
contents as the original, so moves applied to the copy never affect
STATE.  The board is sized from the source board rather than from the
default 3x3 board of MAKE-OXO-STATE, so states with other board sizes
are copied correctly as well."
  (let* ((board (oxo-state-board state))
         (board-copy (make-array (array-dimensions board)
                                 :element-type (array-element-type board))))
    ;; Copy cell by cell in row-major order; this is independent of the
    ;; rank and dimensions of the board.
    (dotimes (index (array-total-size board))
      (setf (row-major-aref board-copy index)
            (row-major-aref board index)))
    (make-oxo-state :player-just-moved (oxo-state-player-just-moved state)
                    :board board-copy)))
(defmethod print-state ((state oxo-state))
  "Print the board of STATE to standard output, one board row per line.

Cells are printed with the same (aref board y x) orientation that
GET-MOVES and DO-MOVE use, so the cell for move (x y) appears in
column x of printed line y.  Returns NIL."
  (let ((board (oxo-state-board state)))
    (destructuring-bind (height width) (array-dimensions board)
      (dotimes (y height)
        (dotimes (x width)
          (format t "~a" (aref board y x)))
        (format t "~%")))))
(defmethod get-moves ((state oxo-state))
  "Return a list of (x y) moves, one for every empty (zero) cell of STATE.

The board is walked from the last row to the first and, within a row,
from the last column to the first, so the bottom-right cell comes first
in the result."
  (let* ((board (oxo-state-board state))
         (height (array-dimension board 0))
         (width (array-dimension board 1)))
    (loop for y from (1- height) downto 0
          nconc (loop for x from (1- width) downto 0
                      when (zerop (aref board y x))
                        collect (list x y)))))
(defmethod get-random-move ((state oxo-state))
  "Return a random (x y) move for an empty cell of STATE, or NIL if none.

Every empty cell is equally likely to be chosen: the empty cells are
counted, a random rank among them is drawn with MT19937:RANDOM, and the
cell with that rank (in row-major order) is returned.  No list of moves
is allocated."
  (let* ((board (oxo-state-board state))
         (size (array-total-size board))
         (width (array-dimension board 1))
         (empty-count (loop for index below size
                            count (zerop (row-major-aref board index)))))
    (when (plusp empty-count)
      (let ((target (mt19937:random empty-count)))
        (dotimes (index size)
          (when (zerop (row-major-aref board index))
            (when (zerop target)
              ;; Row-major index = y * width + x for a (height width) board.
              (return (multiple-value-bind (y x) (floor index width)
                        (list x y))))
            (decf target)))))))
(defmethod get-result ((state oxo-state) player-just-moved)
  "Return the result of STATE from the viewpoint of PLAYER-JUST-MOVED.

Returns 1.0 when PLAYER-JUST-MOVED owns a completed line, 0 when the
other player owns one, 0.5 when the board is full without any completed
line (a draw), and 0 when the game is still undecided.

A line only counts as completed when its three cells are equal AND
non-empty; three empty cells are not a line.  Lines are examined in the
order rows, columns, main diagonal, anti-diagonal, and the first
completed line found decides the result."
  (let ((board (oxo-state-board state)))
    (destructuring-bind (h w) (array-dimensions board)
      (flet ((line-owner (y0 x0 y1 x1 y2 x2)
               ;; Return the player owning the three given cells, or NIL
               ;; when they differ or are empty.
               (let ((mark (aref board y0 x0)))
                 (when (and (/= mark 0)
                            (= mark (aref board y1 x1) (aref board y2 x2)))
                   mark))))
        (let ((winner (or
                       ;; rows: fixed y, x = 0..2
                       (loop for y below h
                               thereis (line-owner y 0 y 1 y 2))
                       ;; columns: fixed x, y = 0..2
                       (loop for x below w
                               thereis (line-owner 0 x 1 x 2 x))
                       ;; the two diagonals
                       (line-owner 0 0 1 1 2 2)
                       (line-owner 0 2 1 1 2 0))))
          (cond (winner (if (= winner player-just-moved) 1.0 0))
                ((null (get-moves state)) 0.5)
                (t 0)))))))
(defmethod do-move ((state oxo-state) move)
  "Play MOVE, a list (x y), on STATE for the player whose turn it is.

The player to move is the opponent of the player who moved last
(players are numbered 1 and 2).  The target cell must be empty, which
is checked with ASSERT.  STATE is modified in place; the number of the
player who has now moved is returned."
  (destructuring-bind (x y) move
    (let ((board (oxo-state-board state))
          (mover (- 3 (oxo-state-player-just-moved state))))
      (assert (= (aref board y x) 0))
      (setf (oxo-state-player-just-moved state) mover
            (aref board y x) mover))))
(defun random-choice (list)
  "Return a randomly chosen element of LIST, or NIL when LIST is empty.
The index is drawn with MT19937:RANDOM."
  (unless (null list)
    (nth (mt19937:random (length list)) list)))
;; A node of the UCT search tree.
(defstruct node
;; Move that led from the parent to this node (NIL for the root made by UCT).
move
;; Parent node, or NIL for the root.
parent-node
;; Child nodes expanded so far.
(child-nodes nil :type list)
;; Sum of the results added by UPDATE-NODE.
(wins 0)
;; Number of times UPDATE-NODE has been applied to this node.
(visits 0 :type fixnum)
;; Player who had just moved in the state this node was created from.
player-just-moved
;; Moves of this node's state that have not been expanded into children yet.
remaining-moves)
(defun create-node (move parent state)
  "Create a search-tree node for STATE, reached from PARENT by playing MOVE.

The node records the player who has just moved in STATE and starts with
every move available in STATE still unexpanded.  The player is read
through the generic function PLAYER-JUST-MOVED, so any state type that
implements the state protocol can be used, not only OXO-STATE."
  (make-node :move move
             :parent-node parent
             :player-just-moved (player-just-moved state)
             :remaining-moves (get-moves state)))
(defun update-node (node result)
  "Record one more visit of NODE and add RESULT to its accumulated wins.
Returns the new win total."
  (with-accessors ((visits node-visits)
                   (wins node-wins))
      node
    (incf visits)
    (incf wins result)))
(defparameter *uctk* 1
  "Exploration constant that scales the exploration term of UCB1-FORMULA.")

(defun ucb1-formula (node)
  "Return the UCB1 score of NODE: wins/visits + *UCTK* * sqrt(2 ln(N) / visits).

N is the visit count of NODE's parent, i.e. the total number of times
the siblings were sampled, as the UCB1 rule requires.  For a node
without a parent, NODE's own visit count is used instead.  NODE must
have been visited at least once."
  (let* ((visits (node-visits node))
         (parent (node-parent-node node))
         ;; Clamp to 1 so that LOG is never applied to zero.
         (total-visits (max 1 (if parent (node-visits parent) visits))))
    (+ (/ (node-wins node) visits)
       (* *uctk* (sqrt (* 2 (/ (log total-visits) visits)))))))
(defun uct-select-child (node)
  "Return the child of NODE with the highest UCB1 score, or NIL if NODE
has no children.

The children are scanned once and each score is computed exactly once;
the child list of NODE is not reordered or otherwise modified."
  (let ((best nil)
        (best-score nil))
    (dolist (child (node-child-nodes node) best)
      (let ((score (ucb1-formula child)))
        (when (or (null best-score) (> score best-score))
          (setf best child
                best-score score))))))
(defun uct (root-state itermax)
  "Run ITERMAX iterations of UCT search from ROOT-STATE and return the
move of the most visited child of the root, or NIL when no child was
created (for example when ITERMAX is 0 or no move is available).

ROOT-STATE itself is never modified; every iteration works on a CLONE.
Each iteration selects a path with UCT-SELECT-CHILD, expands one
untried move, plays random moves to the end of the game and then
propagates the result back up to the root.

A state counts as finished when no move is left or when GET-RESULT is
non-zero for player 1 or player 2.  Finished states are neither
expanded nor played on, so a game that is already won is not simulated
past its end."
  (let ((root-node (create-node nil nil root-state)))
    (flet ((game-over-p (state)
             (or (null (get-moves state))
                 (/= 0 (get-result state 1))
                 (/= 0 (get-result state 2)))))
      (loop repeat itermax do
        (let ((node root-node)
              (state (clone root-state)))
          ;; select: descend while the node is fully expanded and has children
          (loop while (and (node-child-nodes node)
                           (null (node-remaining-moves node)))
                do (setf node (uct-select-child node))
                   (do-move state (node-move node)))
          ;; expand: add one child for a random untried move, unless the
          ;; game is already over in this state
          (let ((move (and (not (game-over-p state))
                           (random-choice (node-remaining-moves node)))))
            (when move
              (do-move state move)
              (let ((child (create-node move node state)))
                (push child (node-child-nodes node))
                (setf (node-remaining-moves node)
                      (delete move (node-remaining-moves node) :test #'equal))
                (setf node child))))
          ;; rollout: play random moves until the game is over
          (loop until (game-over-p state)
                do (do-move state (get-random-move state)))
          ;; backpropagate: credit every node on the path with the result
          ;; as seen by the player who moved into that node
          (loop while node
                do (update-node node
                                (get-result state (node-player-just-moved node)))
                   (setf node (node-parent-node node))))))
    #|
    (loop for child in (node-child-nodes root-node) do
      (format t "win percent: ~a%, ~a visits, ~a move~%"
              (* 100 (/ (node-wins child)
                        (node-visits child)))
              (node-visits child)
              (node-move child)))
    |#
    ;; SORT is destructive, so sort a copy; guard against a childless root.
    (let ((best (first (sort (copy-list (node-child-nodes root-node))
                             #'> :key #'node-visits))))
      (when best
        (node-move best)))))
(defparameter *root-state*
  (make-oxo-state :board (make-array '(3 3) :initial-element 0))
  "Empty 3x3 Tic-Tac-Toe position used as the starting state for games.")
(defun play (state player-1-iters player-2-iters)
  "Play one game between two UCT players, starting from a copy of STATE.

Player 1 searches with PLAYER-1-ITERS iterations per move and player 2
with PLAYER-2-ITERS.  The game runs until no move is left or GET-RESULT
is non-zero for either player.  The outcome is printed and returned as
1 (player 1 wins), 2 (player 2 wins) or 0 (tie).  STATE itself is not
modified."
  (declare (optimize (debug 3)))
  (let ((game (clone state)))
    (flet ((result-for (player)
             (get-result game player)))
      (loop until (or (null (get-moves game))
                      (/= 0 (result-for 1))
                      (/= 0 (result-for 2)))
            do (let ((iterations (if (= (player-just-moved game) 2)
                                     player-1-iters    ; player 1 is to move
                                     player-2-iters))) ; player 2 is to move
                 (do-move game (uct game iterations))
                 ;;(print-state game)
                 ))
      (cond ((= 1 (result-for 1))
             (format t "player 1 wins~%")
             1)
            ((= 1 (result-for 2))
             (format t "player 2 wins~%")
             2)
            (t
             (format t "tie~%")
             0)))))
(defun play-times (state times player-1-iters player-2-iters)
  "Play TIMES games with PLAY from STATE and print the outcome percentages.

PLAYER-1-ITERS and PLAYER-2-ITERS are the UCT iteration budgets handed
to PLAY for each player.  The win percentages of both players and the
percentage of ties are printed; when TIMES is not positive every
percentage is reported as 0.0 instead of dividing by zero.  Returns NIL."
  (let ((player-1-wins 0)
        (player-2-wins 0)
        (ties 0))
    (loop repeat times do
      (ecase (play state player-1-iters player-2-iters)
        (1 (incf player-1-wins))
        (2 (incf player-2-wins))
        (0 (incf ties))))
    (flet ((percentage (count)
             (if (plusp times)
                 (* 100 (/ (float count) times))
                 0.0)))
      (format t "player 1 win percentage ~a%~%" (percentage player-1-wins))
      (format t "player 2 win percentage ~a%~%" (percentage player-2-wins))
      (format t "tie percentage ~a%~%" (percentage ties)))))
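;; Test with something like the following, which plays 100 games where
;; player 1 gets 10 UCT iterations per move and player 2 gets 5000:
;; (play-times *root-state* 100 10 5000)
;; or, to give the first player the advantage instead:
;; (play-times *root-state* 100 5000 10)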
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment