@geffy · Created November 23, 2016 18:58
import gym
import numpy as np

env = gym.make('FrozenLake8x8-v0')  # renamed FrozenLake8x8-v1 in newer gym releases
env.reset()


# collect the ids of all terminal states (holes and the goal)
def find_terminals(mdp_raw):
    terminals = set()
    for src_state, node in mdp_raw.items():
        for action, action_tuple in node.items():
            for (prob, dst_state, reward, is_final) in action_tuple:
                if is_final:
                    terminals.add(dst_state)
    return terminals
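
# For reference, gym's tabular envs expose the transition model as
# env.P[state][action] -> list of (probability, next_state, reward, done)
# tuples (a sketch; exact numbers depend on the slippery dynamics):
#   print(env.P[0][0])
#   # e.g. [(0.333, 0, 0.0, False), (0.333, 0, 0.0, False), (0.333, 8, 0.0, False)]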
def iterate_value_function(v_inp, gamma=0.98):
    """One synchronous value-iteration sweep over all 64 states."""
    ret = np.zeros(64)
    for sid in range(64):
        temp_v = np.zeros(4)
        for action in range(4):
            for (prob, dst_state, reward, is_final) in mdp_raw[sid][action]:
                # terminal successors contribute no future value
                temp_v[action] += prob * (reward + gamma * v_inp[dst_state] * (not is_final))
        ret[sid] = max(temp_v)
    return ret
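
# Each sweep above is one synchronous Bellman optimality backup:
#   V_{k+1}(s) = max_a sum_{s'} P(s'|s,a) * (r(s,a,s') + gamma * V_k(s') * [s' non-terminal])
# Zeroing the bootstrap term at terminal states pins holes and the goal at
# value 0, which is why no explicit `terminals` check is needed in the loop.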
def build_greedy_policy(v_inp, gamma=0.98):
    """Extract the policy that is greedy w.r.t. the value function v_inp."""
    new_policy = np.zeros(64)
    for state_id in range(64):
        profits = np.zeros(4)
        for action in range(4):
            for (prob, dst_state, reward, is_final) in mdp_raw[state_id][action]:
                # use the passed-in v_inp, not the global v
                profits[action] += prob * (reward + gamma * v_inp[dst_state])
        new_policy[state_id] = np.argmax(profits)
    return new_policy
gamma = 0.999999
v = np.zeros(64)

# copy the transition model out of the env
# (newer gym versions hide it behind env.unwrapped.P)
mdp_raw = env.P.copy()
terminals = find_terminals(mdp_raw)

# solve the MDP by value iteration
for _ in range(5000):
    v = iterate_value_function(v, gamma)
print(np.array_str(v.reshape(8, 8), precision=2, suppress_small=True))
policy = build_greedy_policy(v, gamma).astype(int)  # np.int is removed in modern NumPy
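
# Optional sanity check (a sketch; assumes gym's FrozenLake action encoding
# 0=LEFT, 1=DOWN, 2=RIGHT, 3=UP): render the greedy policy as arrows.
arrows = np.array(['<', 'v', '>', '^'])[policy].reshape(8, 8)
print('\n'.join(' '.join(row) for row in arrows))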
# run the environment under the greedy policy
# env.monitor.start('/tmp/frozenlake-vi', force=True)
cum_reward = 0
for t_rounds in range(100000):
    observation = env.reset()  # FrozenLake always starts in state 0
    for t in range(20000):
        action = policy[observation]
        observation, reward, done, info = env.step(action)
        if done:
            cum_reward += reward
            break
# env.monitor.close()
print(cum_reward)
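
# reward is 1 only when the agent reaches the goal, so the mean over the
# 100000 rollouts is the empirical success rate of the greedy policy
print('success rate: %.4f' % (cum_reward / 100000.0))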