gamma = 0.8  # discount factor
# initialize V
new_V = np.zeros((nb_states, 1))
# loop until V converges to the optimal value function
while True:
    old_V = new_V
    # V(s) <- \max_{a}( R(s, a) + gamma * \sum_{s'} P(s, a, s')*V(s') )
    new_V = np.max(R_SA + gamma*np.squeeze(np.dot(P, old_V)), axis=1, keepdims=True)
    # if the changes are small, we consider that we have found V_*
    if np.max(np.abs(new_V - old_V)) < 0.01:
        break
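The update above can be run end to end on the four-state chain used in the other snippets. The following is a minimal sketch, not the author's full script: the function name `value_iteration` and its `tol` parameter are assumptions introduced for illustration, and `P` and `R_SA` are rebuilt here so that the block is self-contained.

```python
import numpy as np

nb_states, nb_actions = 4, 2

# Transition kernel P[state, action, next_state] of the 4-state chain.
P = np.zeros((nb_states, nb_actions, nb_states))
P[0, 0, 1] = P[1, 0, 2] = P[2, 0, 3] = 1  # go right
P[2, 1, 1] = P[1, 1, 0] = P[0, 1, 0] = 1  # go left
P[3, :, 3] = 1                            # state 3 is terminal

# Reward R_SA[state, action]: +1 for going right from state 2.
R_SA = np.zeros((nb_states, nb_actions))
R_SA[2, 0] = 1


def value_iteration(P, R_SA, gamma=0.8, tol=0.01):
    """Hypothetical helper: repeat the Bellman optimality backup until V stops changing."""
    new_V = np.zeros((P.shape[0], 1))
    while True:
        old_V = new_V
        # np.dot(P, old_V) has shape (S, A, 1); drop the last axis to get Q-like values.
        expected_next = np.dot(P, old_V)[..., 0]
        new_V = np.max(R_SA + gamma * expected_next, axis=1, keepdims=True)
        if np.max(np.abs(new_V - old_V)) < tol:
            return new_V


V_star = value_iteration(P, R_SA)
print(V_star.ravel())
```

On this chain the values should settle at roughly 0.64, 0.8, 1 and 0 for states 0 to 3, since each step away from the reward multiplies the value by gamma.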
# Defines the transition probability kernel
P = np.zeros((nb_states, nb_actions, nb_states))
# P[state, action, next_state]
P[0,0,1] = P[1,0,2] = P[2,0,3] = 1  # go right
P[2,1,1] = P[1,1,0] = P[0,1,0] = 1  # go left
P[3,:,3] = 1  # State 3 is terminal
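A transition kernel is only valid if, for every state and action, the probabilities over next states sum to 1. The sketch below checks this for the kernel defined above; the names `row_sums` and `is_stochastic` are introduced here for illustration only.

```python
import numpy as np

nb_states, nb_actions = 4, 2

P = np.zeros((nb_states, nb_actions, nb_states))
P[0, 0, 1] = P[1, 0, 2] = P[2, 0, 3] = 1  # go right
P[2, 1, 1] = P[1, 1, 0] = P[0, 1, 0] = 1  # go left (state 0 stays in place)
P[3, :, 3] = 1                            # state 3 is terminal

# Sum over next_state: one total per (state, action) pair.
row_sums = P.sum(axis=2)
is_stochastic = bool(np.allclose(row_sums, 1.0))
print(row_sums)
print(is_stochastic)
```

A forgotten entry (for example, no transition for "go left" in state 0) would show up as a 0 in `row_sums`.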
# Defines the reward R
R_SA = np.zeros((nb_states, nb_actions))
R_SA[2, 0] = 1  # +1 reward when Mario reaches state 3 (state 2, action = go right)
nb_states = 4
nb_actions = 2
# pi is the greedy policy w.r.t. V
# Choose the action that maximizes the expected future gain
pi = np.argmax(R_SA + gamma*np.squeeze(np.dot(P, V)), axis=1)
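To see what the greedy step returns, the sketch below applies it to the chain MDP with a hand-written value function. The values in `V` are an assumption: they are what value iteration should produce for gamma = 0.8 on this chain, written out by hand so that the block does not depend on any other snippet.

```python
import numpy as np

nb_states, nb_actions, gamma = 4, 2, 0.8

P = np.zeros((nb_states, nb_actions, nb_states))
P[0, 0, 1] = P[1, 0, 2] = P[2, 0, 3] = 1  # go right
P[2, 1, 1] = P[1, 1, 0] = P[0, 1, 0] = 1  # go left
P[3, :, 3] = 1                            # state 3 is terminal

R_SA = np.zeros((nb_states, nb_actions))
R_SA[2, 0] = 1

# Assumed value function (column vector), one entry per state.
V = np.array([[0.64], [0.8], [1.0], [0.0]])

# Q[s, a] = R(s, a) + gamma * sum_{s'} P(s, a, s') * V(s')
Q = R_SA + gamma * np.squeeze(np.dot(P, V))
# Greedy policy: best action in each state (ties go to the lowest action index).
pi = np.argmax(Q, axis=1)
print(pi)
```

Action 0 ("go right") wins in states 0 to 2. In the terminal state both actions have the same value, so `np.argmax` returns the first one.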
# Initialize the policy pi randomly
# pi(s) = 0 or 1
new_pi = np.random.randint(low=0, high=2, size=(nb_states), dtype=int)
# Loop until it reaches the optimal policy
while True:
    old_pi = new_pi
    # Compute the value function of pi
    V = value(old_pi)
    # Update the policy with the greedy policy w.r.t. V
    # greedy_pi(s) = argmax_a[ R(s, a) + gamma * \sum_{s'} P(s, a, s')V(s') ]
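The loop above stops at the greedy-update comment. One way it might be completed is sketched below, under stated assumptions: the stopping rule (stop when the policy no longer changes) is the usual one for policy iteration but is not shown in the source, and the helper `evaluate` swaps the iterative `value(pi)` for an exact linear solve of V = R_pi + gamma * P_pi V, which is a different technique chosen here to keep the example short.

```python
import numpy as np

nb_states, nb_actions, gamma = 4, 2, 0.8

P = np.zeros((nb_states, nb_actions, nb_states))
P[0, 0, 1] = P[1, 0, 2] = P[2, 0, 3] = 1  # go right
P[2, 1, 1] = P[1, 1, 0] = P[0, 1, 0] = 1  # go left
P[3, :, 3] = 1                            # state 3 is terminal

R_SA = np.zeros((nb_states, nb_actions))
R_SA[2, 0] = 1


def evaluate(pi):
    """Hypothetical stand-in for value(pi): solve (I - gamma * P_pi) V = R_pi exactly."""
    idx = np.arange(nb_states)
    P_pi = P[idx, pi]     # P_pi[s, s'] = P(s, pi(s), s')
    R_pi = R_SA[idx, pi]  # R_pi[s] = R(s, pi(s))
    return np.linalg.solve(np.eye(nb_states) - gamma * P_pi, R_pi)


# Random initial policy, seeded so the run is reproducible.
rng = np.random.default_rng(0)
new_pi = rng.integers(low=0, high=2, size=nb_states)
while True:
    old_pi = new_pi
    V = evaluate(old_pi)
    # Greedy improvement with respect to V.
    new_pi = np.argmax(R_SA + gamma * np.dot(P, V), axis=1)
    # Assumed stopping rule: the policy is stable under greedy improvement.
    if np.array_equal(new_pi, old_pi):
        break

print(new_pi, V)
```

Whatever the initial policy, the loop should end with "go right" in every non-terminal state.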
def value(pi):
    """Returns an approximation of the value function of pi"""
    # P_pi(s,s') = P(s,pi(s),s') is the probability of moving from s to s' following pi
    P_pi = np.array([[P[s][pi[s]][sp] for sp in states] for s in states])
    # R_pi(s) = R(s,pi(s)) is the immediate reward at every step, following pi
    R_pi = np.array([[R_SA[s, pi[s]]] for s in states])
    # Compute this until it converges
    new_V = np.zeros((nb_states, 1))
    while True:
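The body of the evaluation loop is cut off above. A plausible completion is the fixed-point iteration V <- R_pi + gamma * P_pi V, repeated until the change falls below a tolerance. The sketch below assumes exactly that. Unlike the original `value(pi)`, it takes `P`, `R_SA`, `gamma` and `tol` as explicit parameters (all assumptions made here) so that it runs on its own.

```python
import numpy as np


def value(pi, P, R_SA, gamma=0.8, tol=1e-6):
    """Iterative policy evaluation: returns an approximation of V_pi as a column vector."""
    states = np.arange(P.shape[0])
    P_pi = P[states, pi]               # P_pi[s, s'] = P(s, pi(s), s')
    R_pi = R_SA[states, pi][:, None]   # R_pi[s] = R(s, pi(s)), as a column
    V = np.zeros((len(states), 1))
    while True:
        # One Bellman expectation backup for the fixed policy pi.
        new_V = R_pi + gamma * (P_pi @ V)
        if np.max(np.abs(new_V - V)) < tol:
            return new_V
        V = new_V


# The 4-state chain from the other snippets.
nb_states, nb_actions = 4, 2
P = np.zeros((nb_states, nb_actions, nb_states))
P[0, 0, 1] = P[1, 0, 2] = P[2, 0, 3] = 1  # go right
P[2, 1, 1] = P[1, 1, 0] = P[0, 1, 0] = 1  # go left
P[3, :, 3] = 1                            # state 3 is terminal
R_SA = np.zeros((nb_states, nb_actions))
R_SA[2, 0] = 1

V_right = value([0, 0, 0, 0], P, R_SA)  # always go right
V_left = value([1, 1, 1, 1], P, R_SA)   # always go left: never reaches the reward
print(V_right.ravel(), V_left.ravel())
```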
import gym
# create the environment
env = gym.make('FrozenLake-v0')
# initial state
current_state = env.reset()
# loop until the game is over
while True:
# the gain from this state
# Gt = R_(t+1) + gamma*R_(t+2) + gamma^2*R_(t+3) + ...
Gt = compute_gain(rewards, t, gamma=0.9)
# the number of times this "state" has been encountered.
N[state] += 1
# update the value function in the direction of the error (Gt - V(St))
# V(St) = V(St) + 1/N(St) * (Gt - V(St))
V[state] += 1/N[state]*(Gt-V[state])
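The update above relies on `compute_gain`, whose body is not shown. The sketch below gives one possible implementation together with an every-visit Monte Carlo update over a whole episode. Two assumptions are made: `rewards[t]` holds R_(t+1), the reward received after leaving the state visited at step t, and the wrapper `mc_update` is a hypothetical name introduced here.

```python
from collections import defaultdict


def compute_gain(rewards, t, gamma=0.9):
    """Discounted return from step t: rewards[t] + gamma*rewards[t+1] + ..."""
    gain = 0.0
    for k, reward in enumerate(rewards[t:]):
        gain += (gamma ** k) * reward
    return gain


def mc_update(V, N, states, rewards, gamma=0.9):
    """Every-visit Monte Carlo: move V(St) toward Gt with step size 1/N(St)."""
    for t, state in enumerate(states):
        Gt = compute_gain(rewards, t, gamma)
        N[state] += 1
        # Incremental mean: V(St) <- V(St) + (Gt - V(St)) / N(St)
        V[state] += (Gt - V[state]) / N[state]


V = defaultdict(float)  # value estimates, 0 by default
N = defaultdict(int)    # visit counts, 0 by default

# First episode: states 0 -> 1 -> 2, reward 1 on the last transition.
mc_update(V, N, states=[0, 1, 2], rewards=[0.0, 0.0, 1.0])
# Second episode: a single visit to state 2 with no reward.
mc_update(V, N, states=[2], rewards=[0.0])
print(dict(V), dict(N))
```

Because the step size is 1/N(St), each V(St) equals the plain average of the gains observed from that state so far: after the second episode, V(2) is the mean of 1 and 0.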
def sample_episodes(env, pi, nb_episodes):
    """Plays the game <nb_episodes> times, following the policy pi.
    Returns the list of (states, rewards)"""
    # initialize the list of episodes
    episodes = []
    # play the game nb_episodes times
    for _ in range(nb_episodes):
        # done is True when the game is over
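The function above is cut off inside the episode loop. The sketch below shows how it might continue, under stated assumptions. `ChainEnv` is a hypothetical stand-in, not part of gym: it only mimics the classic interface in which `reset()` returns the initial state and `step(action)` returns `(state, reward, done, info)`. The `max_steps` cap is also an addition, so that a policy that never reaches the goal cannot loop forever.

```python
class ChainEnv:
    """Hypothetical 4-state chain with a classic gym-style reset/step interface."""

    def __init__(self, nb_states=4):
        self.nb_states = nb_states
        self.state = 0

    def reset(self):
        self.state = 0
        return self.state

    def step(self, action):
        # Action 0 moves right, action 1 moves left (state 0 stays in place).
        if action == 0:
            self.state += 1
        else:
            self.state = max(0, self.state - 1)
        done = self.state == self.nb_states - 1
        reward = 1.0 if done else 0.0
        return self.state, reward, done, {}


def sample_episodes(env, pi, nb_episodes, max_steps=100):
    """Plays the game nb_episodes times, following the policy pi.
    Returns the list of (states, rewards)."""
    episodes = []
    for _ in range(nb_episodes):
        states, rewards = [], []
        state = env.reset()
        done = False  # done is True when the game is over
        steps = 0
        while not done and steps < max_steps:
            states.append(state)
            state, reward, done, _ = env.step(pi[state])
            rewards.append(reward)  # rewards[t] is received after leaving states[t]
            steps += 1
        episodes.append((states, rewards))
    return episodes


episodes = sample_episodes(ChainEnv(), pi=[0, 0, 0, 0], nb_episodes=3)
print(episodes[0])
```

With the "always go right" policy every episode visits states 0, 1 and 2 and collects the reward on the last transition, so the pair `(states, rewards)` has the shape that a Monte Carlo value update would consume.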