Last active
December 3, 2024 04:37
-
-
Save junpeitsuji/e3b2b766e788b8afb0c8d9942b4c3bdd to your computer and use it in GitHub Desktop.
強化学習のデモ。5x5のグリッド上で定義された簡単な迷路を解く。ただしQ学習の更新式が抜けているので228行目を修正する必要あり。
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
<!DOCTYPE html> | |
<html lang="ja"> | |
<head> | |
<meta charset="UTF-8"> | |
<title>強化学習</title> | |
</head> | |
<body> | |
<h1>強化学習で迷路を解く</h1> | |
<canvas id="canvas" width="512" height="512"></canvas> | |
<p>episode: <span id="episode">0</span></p> | |
<p>epsilon: <span id="epsilon">1.0</span></p> | |
<button onclick="start()">START</button> | |
<script> | |
// Canvas element and the 2D context that every drawing call goes through
const canvas = document.getElementById('canvas');
const context = canvas.getContext('2d');

// Agent position in grid cells (x is the column, y is the row)
let x = 0;
let y = 0;

// Goal position in grid cells
const goal_x = 4;
const goal_y = 4;

// Layout: pixel offset of the grid's top-left corner, and pixel size of one cell
const offset = 6;
const box_size = 100;

// Learning parameters
let alpha = 0.2; // learning rate
let gamma = 0.9; // discount factor applied to future reward
let epsilon = 1; // epsilon-greedy: probability of picking a random action

// Episode counter
let episode = 0;
/** Action: move the agent one cell down by increasing the row index `y`. */
const down = () => {
  y = y + 1;
};
/** Action: move the agent one cell up by decreasing the row index `y`. */
const up = () => {
  y = y - 1;
};
/** Action: move the agent one cell left by decreasing the column index `x`. */
const left = () => {
  x = x - 1;
};
/** Action: move the agent one cell right by increasing the column index `x`. */
const right = () => {
  x = x + 1;
};
// Per-cell lookup tables for the 5x5 grid, all indexed [column][row]:
//   move[col][row]      - action functions that keep the agent on the grid
//   move_name[col][row] - names of those actions, in the same order
//   q[col][row][a]      - Q-value of the a-th available action, initially 0
//   reward[col][row]    - reward for entering the cell (10 at the goal, else 0)
let move = [];
let move_name = [];
let q = [];
let reward = [];

for (let col = 0; col < 5; col++) {
  move[col] = [];
  move_name[col] = [];
  q[col] = [];
  reward[col] = [];

  for (let row = 0; row < 5; row++) {
    // Candidate actions in the fixed order up, right, down, left; an action
    // is kept only when it would not leave the grid from this cell.
    const candidates = [
      [row >= 1, up, "up"],
      [col <= 3, right, "right"],
      [row <= 3, down, "down"],
      [col >= 1, left, "left"],
    ];
    const allowed = candidates.filter(([isAllowed]) => isAllowed);

    move[col][row] = allowed.map(([, action]) => action);
    move_name[col][row] = allowed.map(([, , name]) => name);
    q[col][row] = allowed.map(() => 0.0);

    const isGoal = col === goal_x && row === goal_y;
    reward[col][row] = isGoal ? 10.0 : 0.0;
  }
}
/**
 * Redraw the whole scene on the canvas: white background, the 5x5 grid,
 * one coloured bar per available action showing its Q-value, the goal
 * (red disc) and the agent (black disc).
 */
const draw = () => {
  // Rectangles for each action, in units of box_size relative to the cell's
  // top-left corner, as [dx, dy, width, height]:
  //   gap - white patch centred on the cell edge in the action's direction
  //   bar - Q-value bar just inside that edge
  const shapes = new Map([
    ["up", { gap: [0.4, -0.05, 0.2, 0.1], bar: [0.4, 0.05, 0.2, 0.3] }],
    ["right", { gap: [0.95, 0.4, 0.1, 0.2], bar: [0.65, 0.4, 0.3, 0.2] }],
    ["down", { gap: [0.4, 0.95, 0.2, 0.1], bar: [0.4, 0.65, 0.2, 0.3] }],
    ["left", { gap: [-0.05, 0.4, 0.1, 0.2], bar: [0.05, 0.4, 0.3, 0.2] }],
  ]);

  // Fill one rectangle given in box_size units relative to (cellLeft, cellTop).
  const fillScaled = (cellLeft, cellTop, [dx, dy, w, h]) => {
    context.fillRect(
      cellLeft + dx * box_size,
      cellTop + dy * box_size,
      w * box_size,
      h * box_size
    );
  };

  // Fill a disc of radius 0.4 * box_size centred in grid cell (cellX, cellY).
  const fillDisc = (color, cellX, cellY) => {
    context.fillStyle = color;
    context.beginPath();
    context.arc(
      offset + 0.5 * box_size + cellX * box_size,
      offset + 0.5 * box_size + cellY * box_size,
      0.4 * box_size,
      0,
      2 * Math.PI
    );
    context.closePath();
    context.fill();
  };

  context.strokeStyle = 'black';
  context.fillStyle = 'white';

  // Background
  context.fillRect(0, 0, 512, 512);

  // Grid outline, one stroked square per cell
  context.lineWidth = 4;
  for (let i = 0; i < 5; i++) {
    for (let j = 0; j < 5; j++) {
      context.strokeRect(offset + i * box_size, offset + j * box_size, box_size, box_size);
    }
  }

  // Q-table visualisation
  for (let col = 0; col < 5; col++) {
    for (let row = 0; row < 5; row++) {
      const cellLeft = offset + col * box_size;
      const cellTop = offset + row * box_size;
      const names = move_name[col][row];

      for (let a = 0; a < names.length; a++) {
        const shape = shapes.get(names[a]);
        if (shape === undefined) {
          continue;
        }
        // Green/blue channel: 255 when the Q-value is 0, 0 when it is 10,
        // so the bar goes from white to pure red as the value grows.
        const val = Math.floor(255 * (10.0 - q[col][row][a]) / 10.0);

        context.fillStyle = "white";
        fillScaled(cellLeft, cellTop, shape.gap);
        context.fillStyle = `rgb(255, ${val}, ${val})`;
        fillScaled(cellLeft, cellTop, shape.bar);
      }
    }
  }

  // Goal, then agent on top of it
  fillDisc('red', goal_x, goal_y);
  fillDisc('black', x, y);
};
/**
 * Begin a new episode: put the agent back at cell (0, 0), bump the episode
 * counter, decay epsilon by a factor of 0.9 (never below 0.05) and show both
 * numbers in the page.
 */
const next = () => {
  x = 0;
  y = 0;
  episode += 1;

  // Exploration decays geometrically but is floored so the agent keeps
  // taking an occasional random action.
  epsilon = Math.max(epsilon * 0.9, 0.05);

  document.querySelector('#epsilon').innerHTML = `${epsilon}`;
  document.querySelector('#episode').innerHTML = `${episode}`;
};
/**
 * Update the Q-table from one observed transition (Q-learning):
 *
 *   Q(s, a) <- Q(s, a) + alpha * (r + gamma * max_a' Q(s', a') - Q(s, a))
 *
 * @param {number} x  - column of the state the action was taken in
 * @param {number} y  - row of the state the action was taken in
 * @param {number} nx - column of the resulting state
 * @param {number} ny - row of the resulting state
 * @param {number} a  - index of the action taken, into q[x][y]
 * @param {number} r  - reward received for the transition
 */
const learn = (x, y, nx, ny, a, r) => {
  // Best value obtainable from the next state. The 0 floor matches the
  // zero-initialised table and also covers a state with no stored actions.
  const max_q = Math.max(0, ...q[nx][ny]);

  // Move the current estimate a fraction alpha towards the bootstrapped target.
  const target = r + gamma * max_q;
  q[x][y][a] += alpha * (target - q[x][y][a]);
};
/**
 * Run one simulation step and redraw.
 *
 * If the agent is standing on the goal, a new episode is started. Otherwise
 * an action is chosen epsilon-greedily for the current cell, executed, the
 * reward of the cell entered is collected, and the Q-table is updated.
 */
const update = () => {
  if (x === goal_x && y === goal_y) {
    next();
    draw();
    return;
  }

  // Drawn before anything else so the order of Math.random() calls is:
  // explore-or-exploit decision first, random action index second.
  const roll = Math.random();

  // Remember where the action starts; the action itself mutates x / y.
  const fromX = x;
  const fromY = y;
  const actions = move[fromX][fromY];
  const values = q[fromX][fromY];

  let chosen = 0;
  if (roll < epsilon) {
    // Explore: uniformly random action
    chosen = Math.floor(Math.random() * actions.length);
  } else {
    // Exploit: action with the largest Q-value; ties keep the earliest index
    for (let i = 1; i < actions.length; i++) {
      if (values[chosen] < values[i]) {
        chosen = i;
      }
    }
  }

  // Execute the action, then learn from the reward of the cell just entered
  actions[chosen]();
  learn(fromX, fromY, x, y, chosen, reward[x][y]);

  draw();
};
draw(); | |
/**
 * Start the simulation loop, calling update() every 50 ms.
 *
 * Safe to call repeatedly (for example by clicking START several times):
 * only the first call creates a timer, so the simulation never runs more
 * than one update per tick.
 */
const start = (() => {
  // Handle of the running interval; null until the loop has been started.
  let timerId = null;

  return () => {
    if (timerId !== null) {
      return;
    }
    timerId = setInterval(() => {
      update();
    }, 50);
  };
})();
</script> | |
</body> | |
</html> |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment