""" | |
Solving FrozenLake8x8 environment using Policy iteration. | |
Author : Moustafa Alzantot ([email protected]) | |
""" | |
import numpy as np
import gym
def run_episode(env, policy, gamma=1.0, render=False):
    """ Runs an episode and returns the total (discounted) reward. """
    obs = env.reset()
    total_reward = 0
    step_idx = 0
    while True:
        if render:
            env.render()
        # Take the policy's action for the current observation.
        obs, reward, done, _ = env.step(int(policy[obs]))
        total_reward += (gamma ** step_idx * reward)
        step_idx += 1
        if done:
            break
    return total_reward
def evaluate_policy(env, policy, gamma=1.0, n=100):
    """ Returns the average total reward of the policy over n episodes. """
    scores = [run_episode(env, policy, gamma, False) for _ in range(n)]
    return np.mean(scores)
def extract_policy(env, v, gamma=1.0):
    """ Extract the greedy policy for a given value function. """
    policy = np.zeros(env.nS)
    for s in range(env.nS):
        # One-step lookahead: Q(s, a) = sum over s' of p(s'|s,a) * (r + gamma * v[s'])
        q_sa = np.zeros(env.nA)
        for a in range(env.nA):
            q_sa[a] = sum([p * (r + gamma * v[s_]) for p, s_, r, _ in env.P[s][a]])
        policy[s] = np.argmax(q_sa)
    return policy
def compute_policy_v(env, policy, gamma=1.0):
    """ Iteratively evaluate the value function under a fixed policy.
    Alternatively, we could formulate a set of linear equations in terms of v[s]
    and solve them directly to find the value function (see the sketch below).
    """
    v = np.zeros(env.nS)
    eps = 1e-10
    while True:
        prev_v = np.copy(v)
        for s in range(env.nS):
            policy_a = int(policy[s])
            v[s] = sum([p * (r + gamma * prev_v[s_]) for p, s_, r, _ in env.P[s][policy_a]])
        if (np.sum((np.fabs(prev_v - v))) <= eps):
            # Value estimates converged.
            break
    return v
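

# The docstring above mentions solving a set of linear equations instead of
# iterating. A minimal sketch of that alternative follows; the name
# compute_policy_v_direct and the np.linalg.solve approach are illustrative
# additions, not part of the original gist. Note that the system
# (I - gamma * P_pi) v = R_pi is only guaranteed to be non-singular for
# gamma < 1, since FrozenLake's terminal states are absorbing.
def compute_policy_v_direct(env, policy, gamma=0.99):
    """ Evaluate a fixed policy by solving the Bellman equations directly. """
    P_pi = np.zeros((env.nS, env.nS))  # state-transition matrix under the policy
    R_pi = np.zeros(env.nS)            # expected immediate reward under the policy
    for s in range(env.nS):
        for p, s_, r, _ in env.P[s][int(policy[s])]:
            P_pi[s, s_] += p
            R_pi[s] += p * r
    # v = R_pi + gamma * P_pi v  <=>  (I - gamma * P_pi) v = R_pi
    return np.linalg.solve(np.eye(env.nS) - gamma * P_pi, R_pi)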
def policy_iteration(env, gamma=1.0):
    """ Policy-Iteration algorithm """
    policy = np.random.choice(env.nA, size=(env.nS))  # initialize with a random policy
    max_iterations = 200000
    for i in range(max_iterations):
        old_policy_v = compute_policy_v(env, policy, gamma)
        new_policy = extract_policy(env, old_policy_v, gamma)
        if (np.all(policy == new_policy)):
            print('Policy-Iteration converged at step %d.' % (i+1))
            break
        policy = new_policy
    return policy
if __name__ == '__main__':
    env_name = 'FrozenLake8x8-v0'
    env = gym.make(env_name)
    optimal_policy = policy_iteration(env, gamma=1.0)
    # evaluate_policy already averages over episodes, so no further np.mean is needed.
    score = evaluate_policy(env, optimal_policy, gamma=1.0)
    print('Average score = ', score)