Created November 24, 2016 16:47
# Solving Taxi as an MDP using the Value Iteration algorithm
# Uses the old OpenAI Gym API (Taxi-v1, env.nS, env.nA, env.monitor).
import gym
import numpy as np

def iterate_value_function(v_inp, gamma, env):
    # One sweep of the Bellman optimality backup over all states.
    ret = np.zeros(env.nS)
    for sid in range(env.nS):
        temp_v = np.zeros(env.nA)
        for action in range(env.nA):
            # env.P[s][a] is a list of (probability, next_state, reward, done) transitions
            for (prob, dst_state, reward, is_final) in env.P[sid][action]:
                temp_v[action] += prob * (reward + gamma * v_inp[dst_state] * (not is_final))
        ret[sid] = max(temp_v)
    return ret

def build_greedy_policy(v_inp, gamma, env):
    # Extract the greedy policy with respect to the value function v_inp.
    new_policy = np.zeros(env.nS)
    for state_id in range(env.nS):
        profits = np.zeros(env.nA)
        for action in range(env.nA):
            for (prob, dst_state, reward, is_final) in env.P[state_id][action]:
                profits[action] += prob * (reward + gamma * v_inp[dst_state])
        new_policy[state_id] = np.argmax(profits)
    return new_policy

env = gym.make('Taxi-v1')
gamma = 0.999999
cum_reward = 0
n_rounds = 500
env.monitor.start('/tmp/taxi-vi', force=True)
for t_rounds in range(n_rounds):
    # init env and value function
    observation = env.reset()
    v = np.zeros(env.nS)

    # solve the MDP: iterate the value function until it stops changing
    for _ in range(100):
        v_old = v.copy()
        v = iterate_value_function(v, gamma, env)
        if np.all(v == v_old):
            break
    policy = build_greedy_policy(v, gamma, env).astype(np.int)

    # apply the greedy policy for one episode
    for t in range(1000):
        action = policy[observation]
        observation, reward, done, info = env.step(action)
        cum_reward += reward
        if done:
            break
    if t_rounds % 50 == 0 and t_rounds > 0:
        print(cum_reward * 1.0 / (t_rounds + 1))
env.monitor.close()
The code provided above is outdated. Replace env.nS with env.observation_space.n and env.nA with env.action_space.n. Use the code from the book.
Thanks, I'll give that a try.
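For anyone landing here later, here is roughly what that replacement looks like applied to the first function. This is a sketch against recent gym/Gymnasium builds, not the book's listing; the transition table is reached through env.unwrapped.P here, which may vary with the wrapper version you have installed.

import gym
import numpy as np

def iterate_value_function(v_inp, gamma, env):
    n_states = env.observation_space.n    # replaces env.nS
    n_actions = env.action_space.n        # replaces env.nA
    transitions = env.unwrapped.P         # the transition table lives on the unwrapped env
    ret = np.zeros(n_states)
    for sid in range(n_states):
        temp_v = np.zeros(n_actions)
        for action in range(n_actions):
            for prob, dst_state, reward, is_final in transitions[sid][action]:
                temp_v[action] += prob * (reward + gamma * v_inp[dst_state] * (not is_final))
        ret[sid] = max(temp_v)
    return ret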
I also found this through the textbook. Sorry I'm a year late. Just started reading the book.
Here's my updated code to work with the Gymnasium library:
https://gist.github.com/Taresin/a090274fbaf092ad649e4e32e22ecaf4
Hoping that this might help others.
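Not the same code as that gist, but as a rough, illustrative sketch of what typically changes with Gymnasium: the environment id is Taxi-v3, reset() returns (obs, info), step() returns a five-tuple, and the zero policy below is only a stand-in for the greedy policy computed by value iteration.

import gymnasium as gym
import numpy as np

env = gym.make("Taxi-v3")               # Taxi-v1 no longer exists
n_states = env.observation_space.n      # replaces env.nS
policy = np.zeros(n_states, dtype=int)  # stand-in: put the greedy policy from value iteration here

observation, info = env.reset()         # reset() now returns (obs, info)
total_reward = 0
for _ in range(1000):
    observation, reward, terminated, truncated, info = env.step(int(policy[observation]))
    total_reward += reward
    if terminated or truncated:         # the old single done flag is split in two
        break
env.close()
print(total_reward)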
Recently I purchased the book "Deep Reinforcement Learning" by Aske Plaat. He uses your taxi-vi.py code to illustrate how value iteration works. However, when I try to run it, I get the errors 'TaxiEnv' object has no attribute nS and nA. I've been through the gym documentation and various forums and can find no reference to these attributes. Can you tell me what they represent so I can put the right values into the code and get it working? Thanks.