Skip to content

Instantly share code, notes, and snippets.

@junpeitsuji
Last active December 3, 2024 04:37
Show Gist options
  • Save junpeitsuji/e3b2b766e788b8afb0c8d9942b4c3bdd to your computer and use it in GitHub Desktop.
Save junpeitsuji/e3b2b766e788b8afb0c8d9942b4c3bdd to your computer and use it in GitHub Desktop.
強化学習のデモ。5x5のグリッド上で定義された簡単な迷路を解く。ただしQ学習の更新式が抜けているので228行目を修正する必要あり。
<!DOCTYPE html>
<html lang="ja">
<head>
<meta charset="UTF-8">
<title>強化学習</title>
</head>
<body>
<h1>強化学習で迷路を解く</h1>
<canvas id="canvas" width="512" height="512"></canvas>
<p>episode: <span id="episode">0</span></p>
<p>epsilon: <span id="epsilon">1.0</span></p>
<button onclick="start()">START</button>
<script>
const canvas = document.querySelector('#canvas');
const context = canvas.getContext('2d');
// エージェントの座標
let x = 0;
let y = 0;
// ゴールの座標
let goal_x = 4;
let goal_y = 4;
// 描画領域の左上座標を定義
let offset = 6;
let box_size = 100;
// 学習パラメータ
let alpha = 0.2; // 学習率
let gamma = 0.9; // 利得の割引率
let epsilon = 1; // ε-greedy戦略(確率εでランダム行動を選択する)
// エピソードの回数
let episode = 0;
let down = () => {
y += 1;
}
let up = () => {
y -= 1;
}
let left = () => {
x -= 1;
}
let right = () => {
x += 1;
}
let move = []; // 各xyにおける可能な行動一覧(関数のリスト) move[x][y]
let move_name = []; // 各xyにおける可能な行動名一覧(文字列のリスト) move_name[x][y]
let q = []; // Q学習で必要なQテーブル q[x][y][a]
let reward = []; // 各状態に対応する報酬 r[x][y]
for(let x=0; x<5; x++){
move.push([]);
move_name.push([]);
q.push([]);
reward.push([]);
for(let y=0; y<5; y++){
move[x].push([]);
move_name[x].push([]);
q[x].push([]);
if(x == goal_x && y == goal_y){
reward[x].push(10.0);
}
else{
reward[x].push(0.0);
}
if(y >= 1){
move[x][y].push(up);
move_name[x][y].push("up");
q[x][y].push(0.0);
}
if(x <= 3){
move[x][y].push(right);
move_name[x][y].push("right");
q[x][y].push(0.0);
}
if(y <= 3){
move[x][y].push(down);
move_name[x][y].push("down");
q[x][y].push(0.0);
}
if(x >= 1){
move[x][y].push(left);
move_name[x][y].push("left");
q[x][y].push(0.0);
}
}
}
// 描画
let draw = () => {
context.strokeStyle = 'black';
context.fillStyle = 'white';
// 背景を描画
context.fillRect(0, 0, 512, 512);
// 迷路の枠を表示
context.lineWidth = 4;
for(let i=0; i<5; i++){
for(let j=0; j<5; j++){
context.strokeRect(offset+i*box_size, offset+j*box_size, box_size, box_size);
}
}
// Qテーブルの状態を可視化
for(let x=0; x<5; x++){
for(let y=0; y<5; y++){
for(let a=0; a<move_name[x][y].length; a++){
let val = Math.floor(255 * (10.0 - q[x][y][a]) / 10.0);
if(move_name[x][y][a] == "up"){
context.fillStyle = "white";
context.fillRect(offset+x*box_size+0.4*box_size, offset+y*box_size-0.05*box_size, 0.2*box_size, 0.1*box_size);
context.fillStyle = "rgb(255, "+val+", "+val+")";
context.fillRect(offset+x*box_size+0.4*box_size, offset+y*box_size+0.05*box_size, 0.2*box_size, 0.3*box_size);
}
if(move_name[x][y][a] == "right"){
context.fillStyle = "white";
context.fillRect(offset+x*box_size+0.95*box_size, offset+y*box_size+0.4*box_size, 0.1*box_size, 0.2*box_size);
context.fillStyle = "rgb(255, "+val+", "+val+")";
context.fillRect(offset+x*box_size+0.65*box_size, offset+y*box_size+0.4*box_size, 0.3*box_size, 0.2*box_size);
}
if(move_name[x][y][a] == "down"){
context.fillStyle = "white";
context.fillRect(offset+x*box_size+0.4*box_size, offset+y*box_size+0.95*box_size, 0.2*box_size, 0.1*box_size);
context.fillStyle = "rgb(255, "+val+", "+val+")";
context.fillRect(offset+x*box_size+0.4*box_size, offset+y*box_size+0.65*box_size, 0.2*box_size, 0.3*box_size);
}
if(move_name[x][y][a] == "left"){
context.fillStyle = "white";
context.fillRect(offset+x*box_size-0.05*box_size, offset+y*box_size+0.4*box_size, 0.1*box_size, 0.2*box_size);
context.fillStyle = "rgb(255, "+val+", "+val+")";
context.fillRect(offset+x*box_size+0.05*box_size, offset+y*box_size+0.4*box_size, 0.3*box_size, 0.2*box_size);
}
}
}
}
// ゴールを描画
context.fillStyle = 'red';
context.beginPath();
context.arc(offset+0.5*box_size+goal_x*box_size, offset+0.5*box_size+goal_y*box_size, 0.4*box_size, 0, 2 * Math.PI);
context.closePath();
context.fill();
// エージェントを描画
context.fillStyle = 'black';
context.beginPath();
context.arc(offset+0.5*box_size+x*box_size, offset+0.5*box_size+y*box_size, 0.4*box_size, 0, 2 * Math.PI);
context.closePath();
context.fill();
}
let next = () => {
x = 0;
y = 0;
episode++;
epsilon *= 0.9;
if(epsilon < 0.05){
epsilon = 0.05;
}
let p_epsilon = document.querySelector('#epsilon');
p_epsilon.innerHTML = ""+epsilon;
let p_episode = document.querySelector('#episode');
p_episode.innerHTML = ""+episode;
}
// Qテーブルを更新
let learn = (x, y, nx, ny, a, r) => {
let max_q = 0;
for(let a2=0; a2<move[nx][ny].length; a2++){
if(max_q < q[nx][ny][a2]){
max_q = q[nx][ny][a2];
}
}
// ここに更新式を入れる(TODO)
// q[x][y][a] =
}
// 各ステップの更新処理
let update = () => {
if(x == goal_x && y == goal_y){
next();
}
else {
let rand = Math.random();
let cx = x;
let cy = y;
let r = 0;
let a = 0;
let num_of_actions = move[cx][cy].length;
if(rand < epsilon){
// ランダムに行動を決定
a = Math.floor( Math.random() * num_of_actions );
}
else {
// Greedy:
// Qテーブル値が最大になる行動を選択
let max = q[cx][cy][0];
for(let a2=0; a2<num_of_actions; a2++){
if( max < q[cx][cy][a2] ){
max = q[cx][cy][a2];
a = a2;
}
}
}
// 行動を実行
move[cx][cy][a]();
// 次の状態で報酬を受け取り
r = reward[x][y];
// Qテーブルを更新
learn(cx, cy, x, y, a, r);
}
draw();
}
draw();
let start = () => {
setInterval(() => {
update();
}, 50);
}
</script>
</body>
</html>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment