Last active
September 29, 2019 13:28
-
-
Save solaris33/6bdf7a5e8ef736f4599d6a2833f7ec2c to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#######################################################################
# Copyright (C)                                                       #
# 2016-2018 Shangtong Zhang([email protected])                        #
# 2016 Kenta Shimada([email protected])                               #
# Permission given to modify the code as long as you keep this        #
# declaration at the top                                              #
#######################################################################
# Refactoring : solaris33
import matplotlib
# Select the non-interactive Agg backend BEFORE importing pyplot:
# calling matplotlib.use() after pyplot has been imported is not
# guaranteed to take effect on older matplotlib versions.
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.table import Table
# Side length of the square grid world.
WORLD_SIZE = 4
# The four single-step moves as (row, col) offsets; their labelling
# (up/down/left/right) is irrelevant under the uniform random policy.
ACTIONS = [np.array(offset) for offset in ([0, -1], [-1, 0], [0, 1], [1, 0])]
# Equiprobable random policy: each of the 4 actions with probability 1/4.
ACTION_PROB = 0.25
def is_terminal(state):
    """Return True iff *state* is one of the two terminal corner cells.

    The terminals are the top-left (0, 0) and bottom-right
    (WORLD_SIZE-1, WORLD_SIZE-1) corners of the grid.
    """
    row, col = state
    last = WORLD_SIZE - 1
    return (row, col) in ((0, 0), (last, last))
def step(state, action):
    """Apply *action* in *state*; return (next_state, reward).

    Terminal states absorb: the agent stays put with reward 0.
    Every non-terminal transition costs -1; moves that would leave
    the grid bounce back to the current state (still costing -1).
    """
    if is_terminal(state):
        return state, 0

    candidate = (np.array(state) + action).tolist()
    row, col = candidate
    if 0 <= row < WORLD_SIZE and 0 <= col < WORLD_SIZE:
        return candidate, -1
    # Off-grid move: stay in place, but the step still costs -1.
    return state, -1
def draw_image(image):
    """Render the 2-D value array *image* as a matplotlib table.

    Builds the table on a fresh axes (axis frame hidden) with one white
    cell per array entry plus 1-based row/column labels; the figure is
    neither shown nor saved here.
    """
    _, ax = plt.subplots()
    ax.set_axis_off()
    table = Table(ax, bbox=[0, 0, 1, 1])

    n_rows, n_cols = image.shape
    cell_w = 1.0 / n_cols
    cell_h = 1.0 / n_rows

    # One cell per state value.
    for (row, col), cell_value in np.ndenumerate(image):
        table.add_cell(row, col, cell_w, cell_h, text=cell_value,
                       loc='center', facecolor='white')

    # 1-based labels along the left edge and the top edge.
    for idx in range(len(image)):
        table.add_cell(idx, -1, cell_w, cell_h, text=idx + 1, loc='right',
                       edgecolor='none', facecolor='none')
        table.add_cell(-1, idx, cell_w, cell_h / 2, text=idx + 1, loc='center',
                       edgecolor='none', facecolor='none')

    ax.add_table(table)
def draw_image_progress(image, iteration, save_figure_name=None):
    """Render *image* as a table titled with the sweep count, then display it.

    Like draw_image, but adds an 'iteration : N' suptitle, optionally
    saves the figure to *save_figure_name*, shows it, and closes it.
    """
    fig, ax = plt.subplots()
    fig.suptitle('iteration : ' + str(iteration), fontsize=11)
    ax.set_axis_off()
    table = Table(ax, bbox=[0, 0, 1, 1])

    n_rows, n_cols = image.shape
    cell_w = 1.0 / n_cols
    cell_h = 1.0 / n_rows

    # One cell per state value.
    for (row, col), cell_value in np.ndenumerate(image):
        table.add_cell(row, col, cell_w, cell_h, text=cell_value,
                       loc='center', facecolor='white')

    # 1-based labels along the left edge and the top edge.
    for idx in range(len(image)):
        table.add_cell(idx, -1, cell_w, cell_h, text=idx + 1, loc='right',
                       edgecolor='none', facecolor='none')
        table.add_cell(-1, idx, cell_w, cell_h / 2, text=idx + 1, loc='center',
                       edgecolor='none', facecolor='none')

    ax.add_table(table)

    if save_figure_name is not None:
        plt.savefig(save_figure_name)
    plt.show()
    plt.close()
def compute_state_value(in_place=True, discount=1.0):
    """Iterative policy evaluation for the equiprobable random policy.

    Sweeps the whole grid repeatedly until the largest per-state change
    falls below 1e-4.  With in_place=True each sweep reads from the
    array it is writing (Gauss-Seidel style); otherwise it reads from a
    frozen copy of the previous iterate (synchronous/Jacobi style).

    Returns (state_values, number_of_completed_sweeps).
    """
    new_state_values = np.zeros((WORLD_SIZE, WORLD_SIZE))
    iteration = 0
    while True:
        # Source array for this sweep: alias (in-place) or snapshot.
        state_values = new_state_values if in_place else new_state_values.copy()
        old_state_values = state_values.copy()

        for row in range(WORLD_SIZE):
            for col in range(WORLD_SIZE):
                expected_return = 0.0
                for action in ACTIONS:
                    (nr, nc), reward = step([row, col], action)
                    expected_return += ACTION_PROB * (
                        reward + discount * state_values[nr, nc])
                new_state_values[row, col] = expected_return

        # Converged once no state moved by 1e-4 or more this sweep.
        if abs(old_state_values - new_state_values).max() < 1e-4:
            break
        iteration += 1
    return new_state_values, iteration
def compute_state_value_progress(in_place=True, discount=1.0, max_iteration=1):
    """Run at most *max_iteration* sweeps of iterative policy evaluation.

    Identical update rule to compute_state_value, but stops early after
    *max_iteration* sweeps so intermediate value functions can be
    inspected; it also stops if the largest per-state change falls
    below 1e-4 first.

    Returns (state_values, number_of_completed_sweeps).
    """
    new_state_values = np.zeros((WORLD_SIZE, WORLD_SIZE))
    iteration = 0
    while iteration < max_iteration:
        # Source array for this sweep: alias (in-place) or snapshot.
        state_values = new_state_values if in_place else new_state_values.copy()
        old_state_values = state_values.copy()

        for row in range(WORLD_SIZE):
            for col in range(WORLD_SIZE):
                expected_return = 0.0
                for action in ACTIONS:
                    (nr, nc), reward = step([row, col], action)
                    expected_return += ACTION_PROB * (
                        reward + discount * state_values[nr, nc])
                new_state_values[row, col] = expected_return

        # Stop early if already converged below the 1e-4 threshold.
        if abs(old_state_values - new_state_values).max() < 1e-4:
            break
        iteration += 1
    return new_state_values, iteration
def figure_4_1():
    """Reproduce Figure 4.1: policy evaluation on the 4x4 grid world.

    First draws the value function after 0, 1, 2, 3 and 10 synchronous
    sweeps, then runs both evaluation variants to convergence, saves
    the final figure to 'figure_4_1.png', and prints the sweep counts.
    """
    # Snapshot the value function at selected sweep counts.
    # (The five copy-pasted call pairs are collapsed into one loop.)
    for n_sweeps in (0, 1, 2, 3, 10):
        values, reached = compute_state_value_progress(False, 1.0, n_sweeps)
        draw_image_progress(np.round(values, decimals=2), reached)

    # While the author suggests using in-place iterative policy evaluation,
    # Figure 4.1 actually uses out-of-place version.
    # (Typo fixed: 'asycn_iteration' -> 'async_iteration'.)
    _, async_iteration = compute_state_value(in_place=True)
    values, sync_iteration = compute_state_value(in_place=False)
    draw_image_progress(np.round(values, decimals=2), sync_iteration, 'figure_4_1.png')
    print('In-place: {} iterations'.format(async_iteration))
    print('Synchronous: {} iterations'.format(sync_iteration))
# Script entry point: generate the Figure 4.1 reproduction.
if __name__ == '__main__':
    figure_4_1()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment