Last active
March 24, 2025 09:46
-
-
Save znxkznxk1030/10465ac9ffeacce31c51ed657e6c2760 to your computer and use it in GitHub Desktop.
rl-001.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
from torch import initial_seed | |
# Unit moves as (row offset, column offset). The position in this list is the
# action index used by the policy tensor: 0=down, 1=up, 2=right, 3=left.
directs = [
    (1, 0),   # down
    (-1, 0),  # up
    (0, 1),   # right
    (0, -1),  # left
]

# Large integer used as a stand-in for infinity.
inf = 10 ** 9
def initialize_policy(width, height, terminals):
    """Build the starting policy: equal probability for every in-grid move.

    Args:
        width: Number of grid columns.
        height: Number of grid rows.
        terminals: Collection of ``(row, col)`` cells that are skipped.

    Returns:
        A ``(height, width, 4)`` tensor where ``[row, col, a]`` is the
        probability of taking action ``a`` (indexed like ``directs``).
        Terminal cells and moves that would leave the grid stay at 0.
    """
    policy = torch.zeros(height, width, 4)
    for row in range(height):
        for col in range(width):
            if (row, col) in terminals:
                continue
            # Indices of the actions whose destination is still on the grid.
            legal = [
                action
                for action, (dy, dx) in enumerate(directs)
                if 0 <= row + dy < height and 0 <= col + dx < width
            ]
            for action in legal:
                policy[row, col, action] += 1 / len(legal)
    return policy
def policy_evaluation(policy, v, threshold, reward, discount_rate, height, width, terminals,
                      max_iterations=None):
    """Sweep ``v_update`` over the grid until the values stop changing.

    ``v`` is updated in place by ``v_update`` and the same object is returned.

    Args:
        policy: Action-probability tensor passed through to ``v_update``.
        v: Value grid to refine; modified in place.
        threshold: Stop once the largest per-cell change of a sweep is not
            greater than this value.
        reward: Per-step reward passed through to ``v_update``.
        discount_rate: Discount factor passed through to ``v_update``.
        height: Number of grid rows.
        width: Number of grid columns.
        terminals: Collection of ``(row, col)`` cells that are never updated.
        max_iterations: Optional upper bound on the number of sweeps. ``None``
            (the default) keeps the original unbounded behaviour; pass an
            integer to guarantee termination when the values cannot converge
            (for example an undiscounted policy that never reaches a terminal
            cell, or a negative ``threshold``).

    Returns:
        The value grid ``v`` after the final sweep.
    """
    diff = inf  # sentinel so the first sweep runs for any ordinary threshold
    sweeps = 0
    while diff > threshold:
        if max_iterations is not None and sweeps >= max_iterations:
            break
        diff = v_update(policy, v, reward, discount_rate, height, width, terminals)
        sweeps += 1
    return v
def v_update(policy, v, reward, discount_rate, height, width, terminals):
    """Perform one synchronous value sweep over the grid, writing into ``v``.

    Each non-terminal cell is replaced by the policy-weighted sum of
    ``reward + discount_rate * value`` over its moves, where the successor
    values are read from a copy of ``v`` taken before the sweep started.

    Args:
        policy: Tensor indexed as ``[row, col, action]`` with move probabilities.
        v: Value grid indexed as ``[row, col]``; modified in place.
        reward: Reward added for every move.
        discount_rate: Multiplier applied to the successor cell's value.
        height: Number of grid rows.
        width: Number of grid columns.
        terminals: Collection of ``(row, col)`` cells that are skipped.

    Returns:
        The largest absolute change of any cell during this sweep.
    """
    snapshot = v.clone()  # all reads come from the pre-sweep values
    largest_change = 0
    for row in range(height):
        for col in range(width):
            if (row, col) in terminals:
                continue
            expected = 0
            for action, (dy, dx) in enumerate(directs):
                nrow, ncol = row + dy, col + dx
                if 0 <= nrow < height and 0 <= ncol < width:
                    gate, successor = 1, snapshot[nrow, ncol]
                else:
                    # Off-grid move: contributes nothing to the expectation.
                    gate, successor = 0, 0
                expected += policy[row, col, action] * gate * (reward + discount_rate * successor)
            before = snapshot[row, col]
            v[row, col] = expected
            change = abs(before - v[row, col])
            if change > largest_change:
                largest_change = change
    return largest_change
def policy_improvement(policy, v, height, width, terminals):
    """Make the policy greedy with respect to the neighbouring values in ``v``.

    For every non-terminal cell the probability mass is split evenly among the
    in-grid moves whose destination value equals the best destination value;
    all other moves get 0. ``policy`` is modified in place.

    Args:
        policy: Tensor indexed as ``[row, col, action]``; overwritten in place.
        v: Value grid indexed as ``[row, col]``.
        height: Number of grid rows.
        width: Number of grid columns.
        terminals: Collection of ``(row, col)`` cells that are skipped.

    Returns:
        A pair ``(policy_stable, policy)`` where ``policy_stable`` is ``True``
        only if no cell's action distribution changed.
    """
    unchanged = True
    for row in range(height):
        for col in range(width):
            if (row, col) in terminals:
                continue
            # (action index, destination value) for every move that stays on the grid.
            reachable = [
                (action, v[row + dy, col + dx])
                for action, (dy, dx) in enumerate(directs)
                if 0 <= row + dy < height and 0 <= col + dx < width
            ]
            # Find the best destination value and how many moves share it.
            # NOTE: ties are detected with exact equality, as in the original.
            best = -inf
            ties = 0
            for _, value in reachable:
                if value > best:
                    best, ties = value, 1
                elif value == best:
                    ties += 1
            before = policy[row, col].clone()
            for action in range(len(directs)):
                policy[row, col, action] = 0.0
            for action, value in reachable:
                if value == best:
                    policy[row, col, action] = 1 / ties
            if not torch.equal(before, policy[row, col]):
                unchanged = False
    return unchanged, policy
def print_policy(policy, terminals):
    """Print the policy as a text grid with one arrow slot per action.

    Each non-terminal cell shows, in action order, an arrow for every action
    with non-zero probability and a space for every action with probability 0.
    Terminal cells are shown as ``X``. A dashed rule is printed above the grid
    and after every row.

    Args:
        policy: Policy indexed as ``policy[row][col][action]``.
        terminals: Collection of ``(row, col)`` cells drawn as terminal.
    """
    arrows = {0: '↓', 1: '↑', 2: '→', 3: '←'}
    rows = len(policy)
    cols = len(policy[0])
    rule = '-----' * cols

    print('\t======= policy =======')
    print(rule)
    for y in range(rows):
        cells = []
        for x in range(cols):
            if (y, x) in terminals:
                cells.append(' X ')
                continue
            cells.append(''.join(
                ' ' if policy[y][x][i] == 0.0 else arrows.get(i, '')
                for i in range(len(directs))
            ))
        # Leading bar, then every cell followed by its own closing bar.
        print('|' + ''.join(cell + '|' for cell in cells))
        print(rule)
def main():
    """Run policy iteration on a 10x10 grid world and print the result.

    Prints the initial policy, then the round number after each
    evaluation/improvement round, and finally the policy once an improvement
    round leaves it unchanged.
    """
    grid_size = 10
    terminals = ((0, 0), (1, 2), (3, 4), (grid_size - 1, grid_size - 1))
    threshold = 0.0001
    discount_rate = 1
    reward = -1

    policy = initialize_policy(grid_size, grid_size, terminals)
    # The same tensor is handed to every evaluation round.
    values = torch.full((grid_size, grid_size), 0.0)
    print_policy(policy, terminals)

    round_number = 0
    stable = False
    while not stable:
        round_number += 1
        evaluated = policy_evaluation(
            policy, values, threshold, reward, discount_rate,
            grid_size, grid_size, terminals,
        )
        stable, policy = policy_improvement(
            policy, evaluated, grid_size, grid_size, terminals,
        )
        print(round_number)
    print_policy(policy, terminals)
if __name__ == "__main__": | |
main() | |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment