Created February 10, 2019 19:23
-
-
Save maymayw/66e00aca0a05190c570698f018592987 to your computer and use it in GitHub Desktop.
public secret
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
from rl.lib.envs.gridworld import GridworldEnv | |
# Default-constructed gridworld used by the demo at the bottom of the script.
# Its nS / nA / P / shape attributes come from GridworldEnv's defaults, which
# are defined outside this file (rl.lib.envs.gridworld).
env = GridworldEnv()
def valiter(env, dr=1.0, theta=0.0001):
    """Run value iteration on a tabular environment and extract a greedy policy.

    The environment is read through three attributes:

    * ``env.nS`` -- number of states, indexed ``0 .. nS - 1``;
    * ``env.nA`` -- number of actions, indexed ``0 .. nA - 1``;
    * ``env.P[s][a]`` -- an iterable of ``(prob, next_state, reward, done)``
      transition tuples for taking action ``a`` in state ``s``.  Any number
      of outcomes is supported (deterministic or stochastic dynamics).

    Sweeps update ``V`` in place, so later states in a sweep already see the
    new values of earlier ones.  Iteration stops once the largest absolute
    change of any state value within one sweep is below ``theta``.

    Args:
        env: Environment exposing ``nS``, ``nA`` and ``P`` as described above.
        dr: Discount rate applied to the value of the next state.
        theta: Convergence threshold on the per-sweep maximum value change.

    Returns:
        ``(policy, V)``: ``policy`` is an ``(nS, nA)`` array with a single
        ``1.0`` per row marking the greedy action (lowest index wins ties,
        as with ``np.argmax``); ``V`` is the ``(nS,)`` array of state values.
    """

    def onestepahead(s, V):
        """Return the ``(nA,)`` vector of one-step lookahead values for state ``s``."""
        q = np.zeros(env.nA)
        for a in range(env.nA):
            # Expectation over every possible outcome, so stochastic transitions
            # (several tuples per (s, a)) work, not only single-outcome ones.
            # NOTE(review): ``done`` is not used to cut off bootstrapping; this
            # assumes terminal states self-loop with zero reward and keep V == 0
            # -- confirm for environments other than GridworldEnv.
            q[a] = sum(
                prob * (reward + dr * V[nexts])
                for prob, nexts, reward, _done in env.P[s][a]
            )
        return q

    V = np.zeros(env.nS)
    while True:
        delta = 0.0
        for s in range(env.nS):
            v = np.max(onestepahead(s, V))
            delta = max(delta, abs(v - V[s]))
            V[s] = v
        if delta < theta:
            break

    # Greedy policy with respect to the converged values: one-hot per state.
    policy = np.zeros([env.nS, env.nA])
    for s in range(env.nS):
        policy[s, np.argmax(onestepahead(s, V))] = 1.0
    return policy, V
# Solve the gridworld and print the converged state values laid out on the grid.
pol, V = valiter(env)
print(V.reshape(env.shape))

# Arrow glyph per action index.  Presumably 0=up, 1=right, 2=down, 3=left to
# match GridworldEnv's action numbering -- TODO confirm against the env.
actions = [u'\u2191', u'\u2192', u'\u2193', u'\u2190']
# valiter returns one-hot policy rows, so argmax recovers each state's action
# index; indexing the glyph array with it gives one arrow per state, laid out
# on the grid.  This avoids float-equality filtering and nested one-item lists.
print(np.array(actions)[np.argmax(pol, axis=1)].reshape(env.shape))
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.