Skip to content

Instantly share code, notes, and snippets.

@tyfkda
Last active May 6, 2022 05:46
Show Gist options
  • Save tyfkda/9bdd79004661fe5561c326c01cdd8c44 to your computer and use it in GitHub Desktop.
Save tyfkda/9bdd79004661fe5561c326c01cdd8c44 to your computer and use it in GitHub Desktop.
TensorFlow.js 上で6x6リバーシ対戦
// Offsets [dx, dy] of the eight neighbouring squares; every move is scanned along each of them.
const kDirections = [[1, 0], [1, 1], [0, 1], [-1, 1], [-1, 0], [-1, -1], [0, -1], [1, -1]]

// A 6x6 Reversi position stored from the point of view of the side to move.
// Boards are flat 36-element arrays of 0/1 indexed as x + y * 6; action 36 means "pass".
class State {
  // pieces: stones of the side to move, enemyPieces: stones of the opponent,
  // depth: number of plies played so far (defaults to 0).
  constructor(pieces, enemyPieces, depth) {
    // Use the standard opening position when either board is missing
    if (pieces == null || enemyPieces == null) {
      pieces = new Array(36).fill(0)
      enemyPieces = new Array(36).fill(0)
      pieces[14] = 1
      pieces[21] = 1
      enemyPieces[15] = 1
      enemyPieces[20] = 1
    }
    // Set by next() when a pass is answered by a position that can only pass again
    this.passEnd = false
    // Stone placement
    this.pieces = pieces
    this.enemyPieces = enemyPieces
    this.depth = depth || 0
  }

  // Number of stones (cells equal to 1) on the given board array
  pieceCount(pieces) {
    return pieces.filter((cell) => cell === 1).length
  }

  // True when the game is over and the side to move has fewer stones
  isLose() {
    if (!this.isDone())
      return false
    return this.pieceCount(this.pieces) < this.pieceCount(this.enemyPieces)
  }

  // True when the game is over and both sides have the same number of stones
  isDraw() {
    if (!this.isDone())
      return false
    return this.pieceCount(this.pieces) === this.pieceCount(this.enemyPieces)
  }

  // True when the board is full or the game ended by consecutive passes
  isDone() {
    const stonesOnBoard = this.pieceCount(this.pieces) + this.pieceCount(this.enemyPieces)
    return stonesOnBoard === 36 || this.passEnd
  }

  // Returns the position reached by playing `action`; this state is left untouched
  next(action) {
    const isPass = action === 36
    const successor = new State([...this.pieces], [...this.enemyPieces], this.depth + 1)
    if (!isPass)
      successor.isLegalActionXy(action % 6, (action / 6) | 0, true)
    // Swap the boards so the successor is seen from the opponent's side
    const moverStones = successor.pieces
    successor.pieces = successor.enemyPieces
    successor.enemyPieces = moverStones
    // Two passes in a row: the opponent of a passing player can only pass as well
    if (isPass) {
      const replies = successor.legalActions()
      successor.passEnd = replies.length === 1 && replies[0] === 36
    }
    return successor
  }

  // Lists the legal actions in ascending index order, or [36] (pass) when there are none
  legalActions() {
    const found = []
    for (let idx = 0; idx < 36; ++idx) {
      const col = idx % 6
      const row = (idx - col) / 6
      if (this.isLegalActionXy(col, row, false))
        found.push(idx)
    }
    return found.length === 0 ? [36] : found
  }

  // Tells whether the side to move may play at (x, y).
  // With flip set, the stone is also written to the board and every captured
  // line is turned over (the stone is written even if nothing is captured).
  isLegalActionXy(x, y, flip) {
    const onBoard = (cx, cy) => cx >= 0 && cx <= 5 && cy >= 0 && cy <= 5

    // Occupied squares are never legal
    const origin = x + y * 6
    if (this.pieces[origin] === 1 || this.enemyPieces[origin] === 1)
      return false
    // Put the stone down
    if (flip)
      this.pieces[origin] = 1

    // Checks one direction: a run of opponent stones closed by one of our own
    const captureAlong = (dx, dy) => {
      let cx = x + dx
      let cy = y + dy
      // The adjacent square must hold an opponent stone
      if (!onBoard(cx, cy) || this.enemyPieces[cx + cy * 6] !== 1)
        return false
      const bracketed = []
      for (let step = 0; step < 6; ++step) {
        if (!onBoard(cx, cy))
          return false
        const idx = cx + cy * 6
        // An empty square breaks the line
        if (this.enemyPieces[idx] === 0 && this.pieces[idx] === 0)
          return false
        // One of our own stones closes the line
        if (this.pieces[idx] === 1) {
          if (flip) {
            for (const cell of bracketed) {
              this.pieces[cell] = 1
              this.enemyPieces[cell] = 0
            }
          }
          return true
        }
        // Another opponent stone: remember it and keep walking
        bracketed.push(idx)
        cx += dx
        cy += dy
      }
      return false
    }

    // Every direction is evaluated (no short-circuit) so that all lines get flipped
    const outcomes = kDirections.map(([dx, dy]) => captureAlong(dx, dy))
    return outcomes.includes(true)
  }

  // True when the first player is the side to move (even depth)
  isFirstPlayer() {
    return this.depth % 2 === 0
  }

  // Renders the board as six newline-terminated rows; the first player is 'o', the second 'x'
  toString() {
    const firstToMove = this.isFirstPlayer()
    const mine = firstToMove ? 'o' : 'x'
    const theirs = firstToMove ? 'x' : 'o'
    const rows = []
    for (let row = 0; row < 6; ++row) {
      let text = ''
      for (let col = 0; col < 6; ++col) {
        const idx = row * 6 + col
        if (this.pieces[idx] === 1)
          text += mine
        else if (this.enemyPieces[idx] === 1)
          text += theirs
        else
          text += '-'
      }
      rows.push(text + '\n')
    }
    return rows.join('')
  }
}
module.exports = {
State,
}
'use strict'
const {State} = require('./game')
const {pvMctsAction} = require('./pv_mcts')
// Console driver for one human-vs-AI game of 6x6 Reversi.
// Ported from a Python/tkinter UI: the canvas is not created here, so the board
// is shown through onDraw(), which prints State#toString() to the console.
class Game {
  // model: the network handed to pvMctsAction() to build the AI's move picker
  constructor(model) {
    // Create the game state
    this.state = new State()
    // Function that chooses an action with PV MCTS (temperature 0)
    this.nextAction = pvMctsAction(model, 0.0)
    // NOTE: no canvas is attached (this.c stays undefined), see drawPiece()
  }

  // Human turn: plays `action` (cell index 0-35, or 36 for pass) if it is legal.
  // When the only legal action is a pass, the pass is played whatever was requested.
  // Illegal requests are ignored and leave the state unchanged.
  turnOfHuman(action) {
    const legalActions = this.state.legalActions()
    const mustPass = legalActions.length === 1 && legalActions[0] === 36
    const chosen = mustPass ? 36 : action
    // Reject anything that is not legal, including a pass while moves are available
    if (!legalActions.includes(chosen))
      return
    // Advance to the next state
    this.state = this.state.next(chosen)
    this.onDraw()
  }

  // AI turn: does nothing once the game is over, otherwise plays the action
  // returned by this.nextAction
  turnOfAi() {
    // Game already finished
    if (this.state.isDone())
      return
    // Pick the action
    const action = this.nextAction(this.state)
    // Advance to the next state
    this.state = this.state.next(action)
    this.onDraw()
  }

  // Draws one stone on the canvas, if there is one.
  // index: cell index 0-35, first_player: true for the first player's colour.
  // NOTE(review): assumes a tkinter-like canvas exposing
  // create_oval(x0, y0, x1, y1, options); nothing in this file creates it, so
  // without a canvas this is a no-op.
  drawPiece(index, first_player) {
    if (!this.c)
      return
    // 40px cells with a 5px margin around a 30px disc
    const x = (index % 6) * 40 + 5
    const y = Math.floor(index / 6) * 40 + 5
    const fill = first_player ? '#C2272D' : '#FFFFFF'
    this.c.create_oval(x, y, x + 30, y + 30, {width: 1.0, outline: '#000000', fill})
  }

  // Refreshes the display by printing the board to the console
  onDraw() {
    console.log(this.state.toString())
  }
}
module.exports = {
Game,
}
'use strict'
// import
const tf = require('@tensorflow/tfjs')
// require('@tensorflow/tfjs-node') // Requiring this fails with: TypeError: backend.reshape is not a function
tf.env().set('IS_NODE', false) // Suppress warning.
const {Game} = require('./human_play')
const readline = require('readline')
// Prompts on stdout with `msg` and resolves with the line read from stdin.
// A single readline interface is created once and shared by every call.
const getsAsync = (() => {
  const io = {input: process.stdin, output: process.stdout}
  const readInterface = readline.createInterface(io)
  return (msg) => new Promise((resolve) => {
    readInterface.question(msg, (inputString) => {
      resolve(inputString)
    })
  })
})()
// Plays one console game: the AI moves when the first player is to move, the
// human is prompted otherwise. Prints the result from the human's point of view
// and exits the process.
async function main() {
  // load model
  const path = 'http://localhost:8080/tfjsmodel/model.json'
  const model = await tf.loadLayersModel(path)

  const game = new Game(model)
  while (!game.state.isDone()) {
    if (game.state.isFirstPlayer()) {
      game.turnOfAi()
      continue
    }
    const acts = game.state.legalActions()
    console.log(`legalActions: ${acts.join(', ')}`)
    // Keep asking until a legal action is typed
    for (;;) {
      const line = (await getsAsync('action? ')).trim()
      // Only plain non-negative integers count; an empty line must not be read as action 0
      if (!/^\d+$/.test(line))
        continue
      const action = Number(line)
      if (acts.includes(action)) {
        game.turnOfHuman(action)
        break
      }
    }
  }

  if (game.state.isDraw()) {
    console.log('Draw')
  } else {
    // isLose() refers to the side to move. The human plays second, so the human
    // lost when exactly one of "first player to move" / "side to move lost" holds.
    const lose = game.state.isFirstPlayer() !== game.state.isLose()
    console.log(lose ? 'Lose' : 'Win!')
  }
  process.exit(0)
}

// Report failures (e.g. the model could not be loaded) instead of leaving the
// promise rejection unhandled
main().catch((err) => {
  console.error(err)
  process.exit(1)
})
{
"name": "tfjstest",
"version": "1.0.0",
"description": "",
"main": "index.js",
"scripts": {
"test": "echo \"Error: no test specified\" && exit 1"
},
"author": "",
"license": "ISC",
"dependencies": {
"@tensorflow/tfjs": "<2.0",
"@tensorflow/tfjs-node": "^3.16.0"
},
"devDependencies": {
"http-server": "^14.1.0"
}
}
'use strict'
const tf = require('@tensorflow/tfjs')
// Parameters
const DN_INPUT_SHAPE = [6, 6, 2] // Input shape of the network
const PV_EVALUATE_COUNT = 50 // Number of simulations per inference (the original uses 1600)
// Picks one element of `values` at random, weighted by `ratios`.
// values: candidate items. ratios: non-negative weights aligned with `values`
// (a plain array, or a tensor-like object exposing arraySync()).
// Returns values[0] when no weight is positive.
function randomChoice(values, ratios) {
  const weights = typeof ratios?.arraySync === 'function' ? ratios.arraySync() : (ratios ?? [])
  let total = 0
  for (let i = 0; i < values.length; ++i) {
    const weight = weights[i] || 0
    if (weight > 0)
      total += weight
  }
  if (!(total > 0))
    return values[0]

  // Walk the cumulative weights until the random threshold is crossed
  let r = Math.random() * total
  let fallback = values[0]
  for (let i = 0; i < values.length; ++i) {
    const weight = weights[i] || 0
    if (!(weight > 0))
      continue
    r -= weight
    if (r < 0)
      return values[i]
    fallback = values[i]
  }
  // Only reachable through floating-point rounding: use the last weighted item
  return fallback
}
// Runs the network on `state`.
// Returns {policies, value}: `policies` is a plain array with one entry per
// legal action (in state.legalActions() order), rescaled to sum to 1 when the
// raw sum is positive; `value` is the scalar read from the second output.
function predict(model, state) {
  const [a, b, c] = DN_INPUT_SHAPE
  let policyRow
  let value
  // tidy() frees the input, the intermediates and the outputs once the numbers are read
  tf.tidy(() => {
    // Reshape the two flat boards into the network's input layout
    const x = tf.tensor([state.pieces, state.enemyPieces])
      .reshape([c, a, b]).transpose([1, 2, 0]).reshape([1, a, b, c])
    // Inference
    const y = model.predict(x, {batchSize: 1})
    policyRow = y[0].arraySync()[0]
    value = y[1].arraySync()[0][0]
  })

  // Keep the policy entries of the legal actions only
  let policies = state.legalActions().map((action) => policyRow[action])
  const s = policies.reduce((acc, p) => acc + p, 0)
  if (s > 0) {
    policies = policies.map((p) => p / s) // Turn into a distribution summing to 1
  }
  return {policies, value}
}
// Converts a list of nodes into the list of their visit counts (`n`)
function nodesToScores(nodes) {
  const visitCounts = []
  for (const node of nodes)
    visitCounts.push(node.n)
  return visitCounts
}
// Node of the Monte Carlo search tree
class Node {
  // state: position at this node, p: prior probability assigned by the policy
  constructor(state, p) {
    this.state = state // position
    this.p = p // prior (policy)
    this.w = 0 // accumulated value
    this.n = 0 // visit count
    this.childNodes = null // children, created on first expansion
  }

  // Runs one simulation through this node and returns the resulting value from
  // the point of view of the side to move in this.state.
  evaluate(model) {
    // Game over: score from the result
    // NOTE(review): a finished game that is not lost scores 0, so a win and a draw
    // are not told apart here - confirm this is intended.
    if (this.state.isDone()) {
      const value = this.state.isLose() ? -1 : 0
      this.w += value
      this.n += 1
      return value
    }

    // Leaf: ask the network, then expand
    if (!this.childNodes) {
      const {policies, value} = predict(model, this.state)
      // Accept the priors as a plain array or as a tensor-like object
      const priors = typeof policies?.arraySync === 'function' ? policies.arraySync() : policies
      this.w += value
      this.n += 1
      const acts = this.state.legalActions()
      this.childNodes = acts.map((action, i) => new Node(this.state.next(action), priors[i]))
      return value
    }

    // Inner node: descend into the best child; its value is negated because the
    // child is scored from the opponent's point of view
    const value = -this.nextChildNode().evaluate(model)
    this.w += value
    this.n += 1
    return value
  }

  // Returns the child with the highest PUCT score (the first one on ties)
  nextChildNode() {
    const C_PUCT = 1.0
    // Total number of visits over all children
    let t = 0
    for (const childNode of this.childNodes)
      t += childNode.n
    const sqrtT = Math.sqrt(t)

    let best = this.childNodes[0]
    let bestScore = -Infinity
    for (const childNode of this.childNodes) {
      // Mean value seen from this node (hence the sign flip) plus exploration bonus
      const exploit = childNode.n ? -childNode.w / childNode.n : 0.0
      const explore = C_PUCT * childNode.p * sqrtT / (1 + childNode.n)
      const score = exploit + explore
      if (score > bestScore) {
        bestScore = score
        best = childNode
      }
    }
    return best
  }
}
// Runs the tree search from `state` and returns one score per legal action as a
// plain array: a one-hot vector on the most visited child when temperature is 0
// (first child on ties), otherwise the visit counts passed through boltzman().
function pvMctsScores(model, state, temperature) {
  // Root node for the current position
  const rootNode = new Node(state, 0)
  // Run the simulations
  for (let i = 0; i < PV_EVALUATE_COUNT; ++i)
    rootNode.evaluate(model)

  // Visit counts of the children, one per legal action
  const visits = nodesToScores(rootNode.childNodes)
  if (temperature === 0) { // Only the maximum gets 1
    let best = 0
    for (let i = 1; i < visits.length; ++i) {
      if (visits[i] > visits[best])
        best = i
    }
    return visits.map((_, i) => (i === best ? 1 : 0))
  }
  // Add variation through a Boltzmann distribution
  return boltzman(visits, temperature)
}
// Builds a move picker based on the tree search.
// Returns a function(state) that scores the legal actions with pvMctsScores()
// and draws one of them with randomChoice(). A falsy temperature is treated as 0.
function pvMctsAction(model, temperature) {
  const effectiveTemperature = temperature || 0
  return function actionFunc(state) {
    const distribution = pvMctsScores(model, state, effectiveTemperature)
    const candidates = state.legalActions()
    return randomChoice(candidates, distribution)
  }
}
// Boltzmann distribution: raises every score to the power 1 / temperature and
// rescales the results to sum to 1. `xs` is a plain array of non-negative
// scores; a uniform distribution is returned when every powered score is 0.
function boltzman(xs, temperature) {
  const powered = xs.map((x) => x ** (1 / temperature))
  const s = powered.reduce((acc, x) => acc + x, 0)
  if (s === 0)
    return powered.map(() => 1 / powered.length)
  return powered.map((x) => x / s)
}
module.exports = {
pvMctsAction,
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment