Skip to content

Instantly share code, notes, and snippets.

@geffy
Created November 24, 2016 16:47
Show Gist options
  • Save geffy/b2d16d01cbca1ae9e13f11f678fa96fd to your computer and use it in GitHub Desktop.
Save geffy/b2d16d01cbca1ae9e13f11f678fa96fd to your computer and use it in GitHub Desktop.
# Solving the Taxi task as an MDP using the value iteration algorithm
import gym
import numpy as np
def iterate_value_function(v_inp, gamma, env):
    """Perform one synchronous sweep of value iteration.

    For every state the expected return of each action is computed from the
    transition table and the best one is kept (Bellman optimality backup).

    Args:
        v_inp: Current state-value estimates, indexable by state id.
        gamma: Discount factor applied to the successor state's value.
        env: Object exposing ``nS`` (number of states), ``nA`` (number of
            actions) and ``P[state][action]``, an iterable of
            ``(prob, dst_state, reward, is_final)`` transition tuples.

    Returns:
        A new NumPy array of length ``env.nS`` with the updated values;
        ``v_inp`` is not modified.
    """
    ret = np.zeros(env.nS)
    for sid in range(env.nS):
        temp_v = np.zeros(env.nA)
        for action in range(env.nA):
            for prob, dst_state, reward, is_final in env.P[sid][action]:
                # Do not bootstrap from the successor when the transition
                # ends the episode: nothing is collected after that step.
                temp_v[action] += prob * (
                    reward + gamma * v_inp[dst_state] * (not is_final)
                )
        ret[sid] = temp_v.max()
    return ret
def build_greedy_policy(v_inp, gamma, env):
    """Derive the greedy policy for a given state-value function.

    Args:
        v_inp: State-value estimates, indexable by state id.
        gamma: Discount factor applied to the successor state's value.
        env: Object exposing ``nS`` (number of states), ``nA`` (number of
            actions) and ``P[state][action]``, an iterable of
            ``(prob, dst_state, reward, is_final)`` transition tuples.

    Returns:
        A float NumPy array of length ``env.nS`` whose entry for each state
        is the index of the action with the highest expected return (ties
        resolve to the lowest action index, as with ``np.argmax``).
    """
    new_policy = np.zeros(env.nS)
    for state_id in range(env.nS):
        profits = np.zeros(env.nA)
        for action in range(env.nA):
            for prob, dst_state, reward, is_final in env.P[state_id][action]:
                # Use the v_inp argument (not a module-level global) and skip
                # the bootstrap term on episode-ending transitions, matching
                # the backup used in iterate_value_function.
                profits[action] += prob * (
                    reward + gamma * v_inp[dst_state] * (not is_final)
                )
        new_policy[state_id] = np.argmax(profits)
    return new_policy
env = gym.make('Taxi-v1')
gamma = 0.999999
cum_reward = 0
n_rounds = 500
env.monitor.start('/tmp/taxi-vi', force=True)
try:
    # Solve the MDP once, up front. The solve reads only env.nS, env.nA,
    # env.P and gamma, none of which this script changes between rounds, so
    # repeating it inside the episode loop would redo identical work.
    v = np.zeros(env.nS)
    for _ in range(100):
        v_old = v.copy()
        v = iterate_value_function(v, gamma, env)
        # Stop early once a sweep leaves every value unchanged.
        if np.all(v == v_old):
            break
    # np.int was removed in NumPy 1.24; the built-in int is the replacement.
    policy = build_greedy_policy(v, gamma, env).astype(int)

    for t_rounds in range(n_rounds):
        observation = env.reset()
        # Apply the policy, capping the episode at 1000 steps.
        for t in range(1000):
            action = policy[observation]
            observation, reward, done, info = env.step(action)
            cum_reward += reward
            if done:
                break
        # Report the running average reward per episode every 50 rounds.
        if t_rounds % 50 == 0 and t_rounds > 0:
            print(cum_reward * 1.0 / (t_rounds + 1))
finally:
    # Close the monitor even if a round raised.
    env.monitor.close()
@dc2032
Copy link

dc2032 commented May 23, 2023

Thanks, I'll give that a try.

@Taresin
Copy link

Taresin commented Feb 26, 2024

I also found this through the textbook. Sorry I'm a year late. Just started reading the book.

Here's my updated code to work with the Gymnasium library:
https://gist.github.com/Taresin/a090274fbaf092ad649e4e32e22ecaf4

Hoping that this might help others.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment