Skip to content

Instantly share code, notes, and snippets.

@solaris33
Last active September 29, 2019 13:28
Show Gist options
  • Save solaris33/6bdf7a5e8ef736f4599d6a2833f7ec2c to your computer and use it in GitHub Desktop.
Save solaris33/6bdf7a5e8ef736f4599d6a2833f7ec2c to your computer and use it in GitHub Desktop.
#######################################################################
# Copyright (C) #
# 2016-2018 Shangtong Zhang([email protected]) #
# 2016 Kenta Shimada([email protected]) #
# Permission given to modify the code as long as you keep this #
# declaration at the top #
#######################################################################
# Refactoring : solaris33
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.table import Table
# Select the 'Agg' backend so figures are rendered off-screen (to image
# files) rather than in a GUI window.
matplotlib.use('Agg')
# Side length of the square grid; states are [x, y] pairs with both
# coordinates in range(WORLD_SIZE).
WORLD_SIZE = 4

# Displacements that are added to a state to obtain its neighbour.
# Order as labelled in the original source: up, left, down, right.
ACTIONS = [
    np.array(offset)
    for offset in ([0, -1], [-1, 0], [0, 1], [1, 0])
]

# Weight given to each of the four actions when averaging their outcomes.
ACTION_PROB = 0.25
def is_terminal(state):
    """Tell whether ``state`` is one of the two terminal corner cells.

    ``state`` is an ``(x, y)`` pair. The terminal cells are ``(0, 0)`` and
    ``(WORLD_SIZE - 1, WORLD_SIZE - 1)``.
    """
    x, y = state
    far_corner = WORLD_SIZE - 1
    at_origin = x == 0 and y == 0
    return at_origin or (x == far_corner and y == far_corner)
def step(state, action):
    """Apply ``action`` to ``state`` and return ``(next_state, reward)``.

    * A terminal state is absorbing: it is returned unchanged with reward 0.
    * Otherwise the reward is always -1. A move that would leave the grid
      keeps the agent where it is (the original ``state`` object is
      returned); a legal move returns the new position as a list.
    """
    if is_terminal(state):
        return state, 0

    candidate = (np.array(state) + action).tolist()
    x, y = candidate
    inside_grid = 0 <= x < WORLD_SIZE and 0 <= y < WORLD_SIZE
    return (candidate if inside_grid else state), -1
def draw_image(image):
    """Draw a 2-D array as a table of cells on a freshly created figure.

    Each array element becomes one white cell showing its value. Row
    numbers (1-based) are placed in an extra borderless column to the left
    and column numbers in a half-height borderless row on top.

    The new figure is left open as pyplot's current figure; nothing is
    returned.

    Args:
        image: 2-D numpy array of values to display. It does not need to
            be square.
    """
    _, ax = plt.subplots()
    ax.set_axis_off()
    tb = Table(ax, bbox=[0, 0, 1, 1])

    nrows, ncols = image.shape
    width, height = 1.0 / ncols, 1.0 / nrows

    # Value cells.
    for (i, j), val in np.ndenumerate(image):
        tb.add_cell(i, j, width, height, text=val,
                    loc='center', facecolor='white')

    # Row labels: one per row, in the pseudo-column -1.
    for i in range(nrows):
        tb.add_cell(i, -1, width, height, text=i + 1, loc='right',
                    edgecolor='none', facecolor='none')

    # Column labels: one per *column* (previously looped over the row
    # count, which mislabelled non-square arrays), in the pseudo-row -1.
    for j in range(ncols):
        tb.add_cell(-1, j, width, height / 2, text=j + 1, loc='center',
                    edgecolor='none', facecolor='none')

    ax.add_table(tb)
def draw_image_progress(image, iteration, save_figure_name=None):
    """Draw ``image`` as a table titled with the iteration number.

    The table itself is produced by ``draw_image`` (which creates a new
    figure and leaves it current); this function adds the title,
    optionally saves the figure, then shows and closes it.

    Args:
        image: 2-D numpy array of values to display.
        iteration: Value shown in the figure title as
            ``'iteration : <iteration>'``.
        save_figure_name: If not None, path passed to ``plt.savefig``.
    """
    # Reuse the shared table-drawing code instead of duplicating it here.
    draw_image(image)
    # draw_image's plt.subplots() made its figure the current one.
    plt.gcf().suptitle('iteration : ' + str(iteration), fontsize=11)

    if save_figure_name is not None:
        plt.savefig(save_figure_name)
    plt.show()
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close()
def compute_state_value(in_place=True, discount=1.0):
    """Run iterative policy evaluation until the values stop changing.

    Every sweep visits the cells in row-major order and replaces each
    value with the ACTION_PROB-weighted sum, over all ACTIONS, of
    ``reward + discount * value(next_state)``.

    Args:
        in_place: When True, a sweep reads from the very array it is
            writing, so cells updated earlier in the sweep are seen by
            later ones. When False, a sweep reads only the values from
            before the sweep started.
        discount: Factor applied to the successor state's value.

    Returns:
        ``(values, iteration)``: the final WORLD_SIZE x WORLD_SIZE array
        and the number of sweeps performed before the sweep whose largest
        absolute change fell below 1e-4 (that last sweep is not counted).
    """
    values = np.zeros((WORLD_SIZE, WORLD_SIZE))
    sweeps = 0
    converged = False

    while not converged:
        # Values as they were before this sweep; used for the convergence
        # test and, in the synchronous case, as the read-only source.
        before = values.copy()
        source = values if in_place else before

        for row, col in np.ndindex(WORLD_SIZE, WORLD_SIZE):
            expected = 0
            for move in ACTIONS:
                (next_row, next_col), reward = step([row, col], move)
                expected += ACTION_PROB * (
                    reward + discount * source[next_row, next_col]
                )
            values[row, col] = expected

        converged = abs(before - values).max() < 1e-4
        if not converged:
            sweeps += 1

    return values, sweeps
def compute_state_value_progress(in_place=True, discount=1.0, max_iteration=1):
    """Run at most ``max_iteration`` sweeps of iterative policy evaluation.

    Performs the same sweep as ``compute_state_value`` but stops as soon as
    ``max_iteration`` sweeps have been counted, which makes it possible to
    inspect intermediate value tables.

    Args:
        in_place: When True, a sweep reads from the array it is writing;
            when False, it reads only the values from before the sweep.
        discount: Factor applied to the successor state's value.
        max_iteration: Upper bound on the number of counted sweeps. With
            0 the all-zero initial table is returned untouched.

    Returns:
        ``(values, iteration)``: the value table after the last sweep and
        the number of counted sweeps. A sweep whose largest absolute
        change is below 1e-4 ends the loop and is not counted.
    """
    values = np.zeros((WORLD_SIZE, WORLD_SIZE))
    sweeps = 0

    while sweeps < max_iteration:
        # Values as they were before this sweep; used for the convergence
        # test and, in the synchronous case, as the read-only source.
        before = values.copy()
        source = values if in_place else before

        for row, col in np.ndindex(WORLD_SIZE, WORLD_SIZE):
            expected = 0
            for move in ACTIONS:
                (next_row, next_col), reward = step([row, col], move)
                expected += ACTION_PROB * (
                    reward + discount * source[next_row, next_col]
                )
            values[row, col] = expected

        if abs(before - values).max() < 1e-4:
            break
        sweeps += 1

    return values, sweeps
def figure_4_1():
    """Draw the value table at several sweep counts, then at convergence.

    Intermediate tables are drawn for sweep limits 0, 1, 2, 3 and 10 using
    the synchronous (out-of-place) update. The converged synchronous table
    is drawn and saved as 'figure_4_1.png', and the sweep counts of the
    in-place and synchronous runs are printed.
    """
    for sweep_limit in (0, 1, 2, 3, 10):
        table, sweeps_done = compute_state_value_progress(False, 1.0, sweep_limit)
        draw_image_progress(np.round(table, decimals=2), sweeps_done)

    # While the author suggests using in-place iterative policy evaluation,
    # Figure 4.1 actually uses the out-of-place version.
    _, in_place_sweeps = compute_state_value(in_place=True)
    final_table, sync_sweeps = compute_state_value(in_place=False)
    draw_image_progress(
        np.round(final_table, decimals=2), sync_sweeps, 'figure_4_1.png'
    )

    print('In-place: {} iterations'.format(in_place_sweeps))
    print('Synchronous: {} iterations'.format(sync_sweeps))
# Generate the figures only when this file is run as a script, not when it
# is imported as a module.
if __name__ == '__main__':
    figure_4_1()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment