@znxkznxk1030 · Last active March 24, 2025
rl-001.py
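# Policy iteration on an n x n gridworld: alternate iterative policy
# evaluation (Bellman expectation backups) with greedy policy improvement
# until the policy is stable. Every move costs -1; four cells are terminal.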
import torch
directs = [(1, 0), (-1, 0), (0, 1), (0, -1)] # [down, up, right, left]
inf = int(1e9)
def initialize_policy(width, height, terminals):
    """Return a (height, width, 4) tensor holding a uniform policy over
    the in-grid moves of every non-terminal state."""
    policy = torch.full((height, width, 4), 0.0)
    for y in range(height):
        for x in range(width):
            if (y, x) in terminals:
                continue
            # Count the moves that stay on the grid.
            d = 0
            for direct in directs:
                ny = y + direct[0]
                nx = x + direct[1]
                if 0 <= ny < height and 0 <= nx < width:
                    d += 1
            # Spread probability mass uniformly over those moves.
            for i, direct in enumerate(directs):
                ny = y + direct[0]
                nx = x + direct[1]
                if 0 <= ny < height and 0 <= nx < width:
                    policy[y, x, i] += 1 / d
    return policy
def policy_evaluation(policy, v, threshold, reward, discount_rate, height, width, terminals):
    """Sweep value updates until the largest per-state change falls below threshold."""
    diff = inf
    k = 1
    while diff > threshold:
        diff = v_update(policy, v, reward, discount_rate, height, width, terminals)
        k += 1
        # print('--- k = ' + str(k) + ' ---')
        # print(v)
    return v
def v_update(policy, v, reward, discount_rate, height, width, terminals):
    """One Bellman expectation backup over all non-terminal states:
    v(s) <- sum_a pi(a|s) * (reward + discount_rate * v_prev(s')).
    Returns the largest absolute change across states."""
    diff = 0
    prev_v = v.clone()
    for y in range(height):
        for x in range(width):
            if (y, x) in terminals:
                continue
            temp = prev_v[y, x]
            new_value = 0
            for i, direct in enumerate(directs):
                ny = y + direct[0]
                nx = x + direct[1]
                # Off-grid moves carry zero probability under this policy,
                # so p only guards the prev_v lookup.
                p = 1 if (0 <= ny < height and 0 <= nx < width) else 0
                prev_value = prev_v[ny, nx] if p else 0
                new_value += policy[y, x, i] * p * (reward + discount_rate * prev_value)
            v[y, x] = new_value
            diff = max(diff, abs(temp - v[y, x]).item())
    return diff
def policy_improvement(policy, v, height, width, terminals):
    """Make the policy greedy with respect to v, splitting probability
    evenly across tied best moves. Returns (policy_stable, policy)."""
    policy_stable = True
    for y in range(height):
        for x in range(width):
            if (y, x) in terminals:
                continue
            # Find the best successor value and how many moves attain it.
            max_value = -inf
            vn = 0
            for i, direct in enumerate(directs):
                ny = y + direct[0]
                nx = x + direct[1]
                if 0 <= ny < height and 0 <= nx < width:
                    value = v[ny, nx]
                    if max_value < value:
                        max_value = value
                        vn = 1
                    elif max_value == value:
                        vn += 1
            # Reassign probability 1/vn to each maximizing move.
            prev_policy = policy[y, x].clone()
            for i, direct in enumerate(directs):
                policy[y, x, i] = 0.0
                ny = y + direct[0]
                nx = x + direct[1]
                if 0 <= ny < height and 0 <= nx < width:
                    value = v[ny, nx]
                    if value == max_value:
                        policy[y, x, i] = 1 / vn
            if not torch.equal(prev_policy, policy[y, x]):
                policy_stable = False
    return policy_stable, policy
def print_policy(policy, terminals):
    """Render the policy as a grid of arrows; terminal states print as 'X'."""
    # print(policy)
    h = len(policy)
    w = len(policy[0])
    arrows = '↓↑→←'  # same order as directs
    print('\t======= policy =======')
    print('-----' * w)
    for y in range(h):
        print('', end='|')
        for x in range(w):
            if (y, x) in terminals:
                print(' X  ', end='|')
                continue
            ways = ''
            for i in range(len(directs)):
                ways += ' ' if policy[y, x, i] == 0.0 else arrows[i]
            print(ways, end='|')
        print()
    print('-----' * w)
def main():
    n = 10
    threshold = 0.0001
    discount_rate = 1
    reward = -1  # every move costs -1
    terminals = ((0, 0), (1, 2), (3, 4), (n - 1, n - 1))
    policy = initialize_policy(n, n, terminals)
    initial_v_grid = torch.full((n, n), 0.0)
    print_policy(policy, terminals)
    t = 1
    while True:
        # Warm-start each evaluation from the previous value grid.
        v_grid = policy_evaluation(policy, initial_v_grid, threshold, reward, discount_rate, n, n, terminals)
        policy_stable, policy = policy_improvement(policy, v_grid, n, n, terminals)
        print(t)
        t += 1
        if policy_stable:
            break
    print_policy(policy, terminals)


if __name__ == "__main__":
    main()
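# Usage (assuming PyTorch is installed): `python rl-001.py` prints the
# initial uniform policy grid, a counter per evaluation/improvement round,
# and the final greedy policy as arrows.
#
# A minimal sanity-check sketch (not part of the original gist): with no
# terminal states, every state's action probabilities should sum to 1.
#
#   policy = initialize_policy(5, 5, ())
#   assert torch.allclose(policy.sum(-1), torch.ones(5, 5))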