Last active
May 6, 2022 05:46
-
-
Save tyfkda/9bdd79004661fe5561c326c01cdd8c44 to your computer and use it in GitHub Desktop.
TensorFlow.js 上で6x6リバーシ対戦
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// The eight neighbouring directions as [dx, dy] pairs; a move must flank opponent stones
// along at least one of them.
const kDirections = [[1, 0], [1, 1], [0, 1], [-1, 1], [-1, 0], [-1, -1], [0, -1], [1, -1]]

// True when (x, y) lies outside the 6x6 board.
const isOffBoard = (x, y) => y < 0 || 5 < y || x < 0 || 5 < x

// Builds a 36-square board of zeros with a stone (1) on each of the given indices.
const boardWith = (indices) => {
  const board = Array(36).fill(0)
  for (const index of indices)
    board[index] = 1
  return board
}

/**
 * A 6x6 Reversi position, always stored from the viewpoint of the player to move.
 *
 * Boards are 36-element arrays indexed by `x + y * 6` holding 1 where a stone is and 0
 * elsewhere. Actions 0-35 place a stone on that square; action 36 means "pass".
 */
class State {
  /**
   * @param {number[]} [pieces] stones of the player to move
   * @param {number[]} [enemyPieces] stones of the opponent
   *   (when either board is missing, both are replaced by the standard opening position)
   * @param {number} [depth=0] plies played so far; an even depth means the first player moves
   */
  constructor(pieces, enemyPieces, depth) {
    const fresh = pieces == null || enemyPieces == null
    // Set when two passes happen in a row (see next()).
    this.passEnd = false
    // Opening position: four stones crossed in the centre.
    this.pieces = fresh ? boardWith([14, 21]) : pieces
    this.enemyPieces = fresh ? boardWith([15, 20]) : enemyPieces
    this.depth = depth || 0
  }

  /** Number of cells equal to 1 on the given board. */
  pieceCount(pieces) {
    return pieces.reduce((total, cell) => (cell === 1 ? total + 1 : total), 0)
  }

  /** True when the game is over and the player to move has fewer stones. */
  isLose() {
    if (!this.isDone())
      return false
    return this.pieceCount(this.pieces) < this.pieceCount(this.enemyPieces)
  }

  /** True when the game is over and both players have the same number of stones. */
  isDraw() {
    if (!this.isDone())
      return false
    return this.pieceCount(this.pieces) === this.pieceCount(this.enemyPieces)
  }

  /** True when the board is full or the game ended by two consecutive passes. */
  isDone() {
    if (this.passEnd)
      return true
    const stones = this.pieceCount(this.pieces) + this.pieceCount(this.enemyPieces)
    return stones === 36
  }

  /**
   * Returns the state after `action` (0-35 = place a stone, 36 = pass); `this` is not modified.
   * The result is seen from the opponent's side: boards are swapped and depth is incremented.
   * The action is assumed to be legal; it is not validated here.
   */
  next(action) {
    const mine = [...this.pieces]
    const theirs = [...this.enemyPieces]
    const child = new State(mine, theirs, this.depth + 1)
    const isPass = action === 36
    if (!isPass) {
      // Place the stone and flip captured stones on the copied boards.
      child.isLegalActionXy(action % 6, (action / 6) | 0, true)
    }
    // Hand the move over to the opponent.
    child.pieces = theirs
    child.enemyPieces = mine
    // After a pass, if the opponent can only pass as well, the game ends.
    if (isPass) {
      const replies = child.legalActions()
      if (replies.length === 1 && replies[0] === 36)
        child.passEnd = true
    }
    return child
  }

  /** Legal actions of the player to move in ascending order, or `[36]` (pass) if there are none. */
  legalActions() {
    const moves = [...Array(36).keys()]
      .filter((cell) => this.isLegalActionXy(cell % 6, Math.floor(cell / 6), false))
    return moves.length > 0 ? moves : [36]
  }

  /**
   * Whether placing a stone at (x, y) is legal for the player to move.
   * With `flip` true the move is also played on this state: the stone is placed and every
   * flanked opponent stone is turned over (this mutates `pieces` and `enemyPieces`).
   */
  isLegalActionXy(x, y, flip) {
    const origin = x + y * 6
    // Occupied squares are never legal.
    if (this.pieces[origin] === 1 || this.enemyPieces[origin] === 1)
      return false
    if (flip)
      this.pieces[origin] = 1

    // Walks back from (fromX, fromY) toward the origin, turning stones over until an own stone.
    const turnOverBack = (fromX, fromY, dx, dy) => {
      let bx = fromX
      let by = fromY
      for (let k = 0; k < 6; ++k) {
        bx -= dx
        by -= dy
        const cell = bx + by * 6
        if (this.pieces[cell] === 1)
          return
        this.pieces[cell] = 1
        this.enemyPieces[cell] = 0
      }
    }

    // Does the ray from (x, y) in direction (dx, dy) close a run of opponent stones?
    const capturesToward = (dx, dy) => {
      let cx = x + dx
      let cy = y + dy
      // The neighbouring square must hold an opponent stone.
      if (isOffBoard(cx, cy) || this.enemyPieces[cx + cy * 6] !== 1)
        return false
      for (let step = 0; step < 6; ++step) {
        if (isOffBoard(cx, cy))
          return false
        const cell = cx + cy * 6
        // Reached an empty square: the run is not closed off.
        if (this.enemyPieces[cell] === 0 && this.pieces[cell] === 0)
          return false
        // Reached an own stone: the run is captured.
        if (this.pieces[cell] === 1) {
          if (flip)
            turnOverBack(cx, cy, dx, dy)
          return true
        }
        // Still on an opponent stone: keep walking.
        cx += dx
        cy += dy
      }
      return false
    }

    // Every direction is checked (no short-circuit) so that all captured lines get flipped.
    const capturing = kDirections.filter(([dx, dy]) => capturesToward(dx, dy))
    return capturing.length > 0
  }

  /** True when the first player is to move (even depth). */
  isFirstPlayer() {
    return this.depth % 2 === 0
  }

  /** Board as six text rows: 'o' = first player's stones, 'x' = second player's, '-' = empty. */
  toString() {
    const [own, other] = this.isFirstPlayer() ? ['o', 'x'] : ['x', 'o']
    const rows = []
    for (let y = 0; y < 6; ++y) {
      let row = ''
      for (let x = 0; x < 6; ++x) {
        const cell = x + y * 6
        if (this.pieces[cell] === 1)
          row += own
        else if (this.enemyPieces[cell] === 1)
          row += other
        else
          row += '-'
      }
      rows.push(`${row}\n`)
    }
    return rows.join('')
  }
}
// Public API of this module.
module.exports = {State}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| 'use strict' | |
| const {State} = require('./game') | |
| const {pvMctsAction} = require('./pv_mcts') | |
/**
 * Human-vs-AI session of 6x6 Reversi played on the console.
 * The board is printed to stdout after every move (see `onDraw`).
 */
class Game {
  /**
   * @param {*} model trained network handed to `pvMctsAction` for the AI's move selection
   */
  constructor(model) {
    // Current game state, starting from the opening position.
    this.state = new State()
    // AI move selector: PV-MCTS with temperature 0.0 (greedy choice).
    this.nextAction = pvMctsAction(model, 0.0)
  }

  /**
   * Apply the human player's move and print the resulting board.
   *
   * When passing is the only legal action, any input is treated as a pass (36).
   * Otherwise illegal input, including a pass while real moves exist, is ignored.
   * @param {number} action board index 0-35 (x + y * 6), or 36 to pass
   */
  turnOfHuman(action) {
    const legalActions = this.state.legalActions()
    const mustPass = legalActions.length === 1 && legalActions[0] === 36
    const move = mustPass ? 36 : action
    if (!legalActions.includes(move))
      return
    this.state = this.state.next(move)
    this.onDraw()
  }

  /** Let the AI choose and play its move, then print the board. Does nothing once the game is over. */
  turnOfAi() {
    if (this.state.isDone())
      return
    const action = this.nextAction(this.state)
    this.state = this.state.next(action)
    this.onDraw()
  }

  /**
   * Draw one stone on the canvas `this.c`, if one is attached.
   * Squares are 40px wide; the stone is a 30px circle inset by 5px.
   * @param {number} index board index 0-35 (x + y * 6)
   * @param {boolean} firstPlayer true for the first player's (#C2272D) stone, false for white
   */
  drawPiece(index, firstPlayer) {
    // NOTE(review): no canvas is ever created in this console port, so this is a no-op unless
    // `this.c` is assigned; the options-object call shape is a placeholder for that backend.
    if (!this.c)
      return
    const x = (index % 6) * 40 + 5
    const y = Math.floor(index / 6) * 40 + 5
    const fill = firstPlayer ? '#C2272D' : '#FFFFFF'
    this.c.create_oval(x, y, x + 30, y + 30, {width: 1.0, outline: '#000000', fill})
  }

  /** Print the current board to stdout. */
  onDraw() {
    console.log(this.state.toString())
  }
}
// Public API of this module.
module.exports = {Game}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
'use strict'
// Console game of 6x6 Reversi: AI (first player) vs. human (second player).
const tf = require('@tensorflow/tfjs')
// NOTE: require('@tensorflow/tfjs-node') fails with "TypeError: backend.reshape is not a function".
// NOTE(review): possibly caused by package.json pinning @tensorflow/tfjs "<2.0" while
// @tensorflow/tfjs-node is "^3.16.0" — TODO confirm.
tf.env().set('IS_NODE', false) // Suppress warning.
const {Game} = require('./human_play')
const readline = require('readline')
/**
 * Ask `msg` on stdout and resolve with the line the user types on stdin.
 * One readline interface is created when the module loads and is reused for every prompt.
 * @type {(msg: string) => Promise<string>}
 */
const getsAsync = (() => {
  const rl = readline.createInterface({
    input: process.stdin,
    output: process.stdout,
  })
  return (msg) => new Promise((resolve) => {
    rl.question(msg, resolve)
  })
})()
/**
 * Prompt on stdin until the user types one of `legalActions`.
 * @param {number[]} legalActions actions the human may choose from
 * @returns {Promise<number>} the chosen action
 */
async function readHumanAction(legalActions) {
  for (;;) {
    const input = (await getsAsync('action? ')).trim()
    // Number('') is 0, so blank input must be rejected explicitly; non-integers ("1.5", "abc") too.
    const action = input === '' ? NaN : Number(input)
    if (Number.isInteger(action) && legalActions.includes(action))
      return action
  }
}

/**
 * Load the trained model and play one game on the console: the AI moves first, the human second.
 * Finally prints the result from the human's point of view and exits the process.
 */
async function main() {
  // NOTE(review): assumes the converted model is served over HTTP on localhost:8080 — TODO confirm.
  const modelUrl = 'http://localhost:8080/tfjsmodel/model.json'
  const model = await tf.loadLayersModel(modelUrl)
  const game = new Game(model)
  while (!game.state.isDone()) {
    if (game.state.isFirstPlayer()) {
      game.turnOfAi()
      continue
    }
    const acts = game.state.legalActions()
    console.log(`legalActions: ${acts.join(', ')}`)
    game.turnOfHuman(await readHumanAction(acts))
  }
  if (game.state.isDraw()) {
    console.log('Draw')
  } else {
    // isLose() is from the side to move; when that side is the AI (first player), the human's
    // result is the opposite.
    const humanLost = game.state.isFirstPlayer() !== game.state.isLose()
    console.log(humanLost ? 'Lose' : 'Win!')
  }
  // The open readline interface keeps the event loop alive, so exit explicitly.
  process.exit(0)
}

main().catch((err) => {
  console.error(err)
  process.exit(1)
})
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| { | |
| "name": "tfjstest", | |
| "version": "1.0.0", | |
| "description": "", | |
| "main": "index.js", | |
| "scripts": { | |
| "test": "echo \"Error: no test specified\" && exit 1" | |
| }, | |
| "author": "", | |
| "license": "ISC", | |
| "dependencies": { | |
| "@tensorflow/tfjs": "<2.0", | |
| "@tensorflow/tfjs-node": "^3.16.0" | |
| }, | |
| "devDependencies": { | |
| "http-server": "^14.1.0" | |
| } | |
| } |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| 'use strict' | |
| const tf = require('@tensorflow/tfjs') | |
// Hyper-parameters.
// Network input shape [rows, cols, planes]: the 6x6 board as two planes
// (stones of the player to move, stones of the opponent), see predict().
const DN_INPUT_SHAPE = [6, 6, 2]
// MCTS simulations per move (the original AlphaZero uses 1600).
const PV_EVALUATE_COUNT = 50
/**
 * Pick one element of `values` at random with probability proportional to `ratios`
 * (JS counterpart of numpy's `np.random.choice(values, p=ratios)`; weights need not sum to 1).
 *
 * @param {Array} values candidates
 * @param {number[]|{arraySync: Function}} ratios non-negative weights, one per value
 *   (a 1-D tf.Tensor is accepted too)
 * @param {() => number} [random=Math.random] uniform [0, 1) generator, injectable for tests
 * @returns {*} the chosen value; a uniformly random one when every weight is 0
 */
function randomChoice(values, ratios, random = Math.random) {
  // Work on plain numbers: arithmetic on a tf.Tensor object silently yields NaN.
  const weights = typeof ratios.arraySync === 'function' ? ratios.arraySync() : Array.from(ratios)
  const total = weights.reduce((acc, w) => acc + w, 0)
  if (!(total > 0))
    return values[Math.floor(random() * values.length)]
  let r = random() * total
  // Fallback for floating-point round-off that leaves r at ~0 after the last weight.
  let lastPositive = values[values.length - 1]
  for (let i = 0; i < values.length; ++i) {
    if (weights[i] > 0)
      lastPositive = values[i]
    r -= weights[i]
    if (r < 0)
      return values[i]
  }
  return lastPositive
}
/**
 * Run the dual (policy + value) network on `state`.
 *
 * @param {tf.LayersModel} model network with two outputs, assumed to be [policy, value]
 *   as indexed below — TODO confirm against the trained model
 * @param {State} state position to evaluate, from the side to move
 * @returns {{policies: number[], value: number}} `policies[i]` is the prior of
 *   `state.legalActions()[i]`, renormalised over legal actions to sum to 1 (left as-is if the
 *   sum is 0); `value` is the network's value estimate
 */
function predict(model, state) {
  const [a, b, c] = DN_INPUT_SHAPE
  // [2, 36] -> planes [c, a, b] -> channels-last [a, b, c] -> batch of one [1, a, b, c].
  const x = tf.tidy(() => tf.tensor([state.pieces, state.enemyPieces])
    .reshape([c, a, b])
    .transpose([1, 2, 0])
    .reshape([1, a, b, c]))
  let outputs = []
  let policyOverAll
  let value
  try {
    outputs = model.predict(x, {batchSize: 1})
    policyOverAll = outputs[0].arraySync()[0]
    value = outputs[1].arraySync()[0][0]
  } finally {
    // Tensors hold backend memory until disposed explicitly.
    tf.dispose([x, outputs])
  }
  // Keep only legal moves and renormalise with plain numbers (a tf.Tensor cannot be indexed).
  const legal = state.legalActions().map((action) => policyOverAll[action])
  const total = legal.reduce((acc, p) => acc + p, 0)
  const policies = total > 0 ? legal.map((p) => p / total) : legal
  return {policies, value}
}
/**
 * Collect the visit count `n` of each node.
 * @param {Node[]} nodes search-tree nodes
 * @returns {number[]} visit counts in the same order as `nodes`
 */
function nodesToScores(nodes) {
  const counts = []
  for (const node of nodes)
    counts.push(node.n)
  return counts
}
/**
 * A node of the PV-MCTS (policy/value Monte Carlo tree search) tree.
 */
class Node {
  /**
   * @param {State} state position this node represents, from the side to move
   * @param {number} p prior probability of the move that led here (policy head output)
   */
  constructor(state, p) {
    this.state = state // position
    this.p = p // prior probability
    this.w = 0 // accumulated value
    this.n = 0 // visit count
    this.childNodes = null // children; null until the node is expanded
  }

  /**
   * Run one simulation through this node and accumulate the result into `w` / `n`.
   *
   * @param {tf.LayersModel} model network used by `predict` at unexpanded nodes
   * @returns {number} value of this position for the player to move here
   */
  evaluate(model) {
    let value
    if (this.state.isDone()) {
      // Terminal: score the real result. A win must be +1, because in Reversi the side to move
      // at the end of the game may be the one with more stones.
      if (this.state.isLose())
        value = -1
      else if (this.state.isDraw())
        value = 0
      else
        value = 1
    } else if (!this.childNodes) {
      // Unexpanded leaf: take the network's value estimate, then create one child per legal move
      // carrying the network's prior for that move.
      const prediction = predict(model, this.state)
      value = prediction.value
      this.childNodes = this.state.legalActions()
        .map((action, i) => new Node(this.state.next(action), prediction.policies[i]))
    } else {
      // Interior node: descend into the child with the best PUCB score. The child's value is from
      // the opponent's viewpoint, hence the negation.
      value = -this.nextChildNode().evaluate(model)
    }
    this.w += value
    this.n += 1
    return value
  }

  /**
   * Child maximising the PUCB score Q + C_PUCT * P * sqrt(N) / (1 + n), where Q is the child's
   * mean value from this node's viewpoint (0 when unvisited) and N is the total child visit
   * count. Ties go to the earliest child.
   * @returns {Node}
   */
  nextChildNode() {
    const C_PUCT = 1.0
    // Plain-number arithmetic: Math.sqrt of a tf.Tensor is NaN.
    const totalVisits = nodesToScores(this.childNodes).reduce((acc, n) => acc + n, 0)
    const sqrtTotal = Math.sqrt(totalVisits)
    let best = this.childNodes[0]
    let bestScore = -Infinity
    for (const child of this.childNodes) {
      const q = child.n ? -child.w / child.n : 0.0
      const u = C_PUCT * child.p * sqrtTotal / (1 + child.n)
      if (q + u > bestScore) {
        best = child
        bestScore = q + u
      }
    }
    return best
  }
}
/**
 * Run PV-MCTS from `state` and return a probability distribution over `state.legalActions()`.
 *
 * @param {tf.LayersModel} model network passed through to the search nodes
 * @param {State} state root position; must not be terminal (the root would never be expanded
 *   and reading its children would throw)
 * @param {number} temperature 0 = one-hot on the most-visited move (first one on ties);
 *   otherwise the visit counts are reshaped with `boltzman`
 * @param {number} [evaluateCount=PV_EVALUATE_COUNT] number of simulations (at least 1)
 * @returns {number[]} one probability per legal action, in `legalActions()` order
 */
function pvMctsScores(model, state, temperature, evaluateCount = PV_EVALUATE_COUNT) {
  const rootNode = new Node(state, 0)
  for (let i = 0; i < evaluateCount; ++i)
    rootNode.evaluate(model)
  const scores = nodesToScores(rootNode.childNodes)
  if (temperature === 0) {
    // Plain one-hot array; a tf.Tensor cannot be written by index.
    const best = scores.indexOf(Math.max(...scores))
    return scores.map((_, i) => (i === best ? 1 : 0))
  }
  return boltzman(scores, temperature)
}
/**
 * Build an action-selection function backed by PV-MCTS.
 *
 * @param {tf.LayersModel} model network used by the search
 * @param {number} [temperature] sampling temperature; any falsy value means 0 (greedy)
 * @returns {(state: State) => number} picks an action of `state.legalActions()` at random,
 *   weighted by the search scores
 */
function pvMctsAction(model, temperature) {
  const temp = temperature || 0
  const actionFunc = (state) => {
    const scores = pvMctsScores(model, state, temp)
    return randomChoice(state.legalActions(), scores)
  }
  return actionFunc
}
/**
 * Boltzmann-style reshaping of visit counts: raise each to the power 1/temperature and
 * normalise so the result sums to 1. Low temperatures sharpen the distribution.
 *
 * @param {number[]} xs non-negative scores (visit counts)
 * @param {number} temperature positive temperature
 * @returns {number[]} probabilities in the same order as `xs`; uniform when every score is 0
 */
function boltzman(xs, temperature) {
  const powered = xs.map((x) => x ** (1 / temperature))
  // Arrays have no `.sum()`; add the values up explicitly.
  const total = powered.reduce((acc, x) => acc + x, 0)
  // Degenerate input (all zeros): fall back to uniform instead of 0 / 0 = NaN.
  if (!(total > 0))
    return powered.map(() => 1 / powered.length)
  return powered.map((x) => x / total)
}
// Public API of this module.
module.exports = {pvMctsAction}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Placeholder file; it exists only to set the gist's name (gist名指定用).
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment