Simple example of q-learning
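The script reproduces the path-finding tutorial linked in the first comment: an agent wanders between six states and learns which moves lead to the goal state. Each visited transition is scored with the simplified tabular update Q(s, a) ← R(s, a) + g · max Q(s', ·), with discount g = 0.8 and no learning rate, which is exactly the update line inside updateQ below.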
# This gist reproduces the algorithm available at:
# http://mnemstudio.org/path-finding-q-learning-tutorial.htm

import numpy as np
import random
## Initialize the Q-table
Nstates = 6
Nactions = 6
Q = np.zeros((Nstates, Nactions))
## Set the reward matrix
## (-1 marks a forbidden transition, 0 an allowed move, 100 the reward for entering the goal state 5)
R = np.array([
    [-1, -1, -1, -1,  0,  -1],
    [-1, -1, -1,  0, -1, 100],
    [-1, -1, -1,  0, -1,  -1],
    [-1,  0,  0, -1,  0,  -1],
    [ 0, -1, -1,  0, -1, 100],
    [-1,  0, -1, -1,  0, 100]
])
def random_indices(mat):
    """ Returns a random position in the matrix
    """
    nrows, ncols = mat.shape
    indices = [random.randint(0, nrows-1), random.randint(0, ncols-1)]
    return indices
def accessible_states(current_state, R):
    """ Returns the indices of the accessible future states
    """
    is_accessible = (R[current_state, :] != -1)
    return np.where(is_accessible)
# def optimal_action(current_state, R):
#     """ Returns the optimal action
#     """
#     return np.argmax(R[current_state, :])  # TODO: return multiple coincidences
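# A possible completion of the stub above, not part of the original gist: it
# addresses the TODO by returning every action that ties for the maximal
# immediate reward (the name optimal_actions and the use of np.flatnonzero
# are choices made here, not taken from the gist).
def optimal_actions(current_state, R):
    """ Returns all actions with the maximal immediate reward
    """
    row = R[current_state, :]
    return np.flatnonzero(row == np.max(row))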
def updateQ(Q_current, R, goal_state, g=0.8, init='random'):
    """ Simulates a single episode
    """
    ## Choose a starting point
    if init == 'random':  # Randomly
        state_current, _ = random_indices(Q_current)
    else:  # Or provided as an input
        state_current = init

    Q_updated = Q_current.copy()  # Work on a copy, so the input table is not modified in place
    while state_current != goal_state:
        possibilities = accessible_states(state_current, R)[0]  # From the available transitions...
        state_next = possibilities[random.randint(0, len(possibilities)-1)]  # ... choose one randomly

        ## Q-learning update: immediate reward plus discounted best value of the next state
        Q_updated[state_current, state_next] = R[state_current, state_next] + g * np.max(Q_updated[state_next, :])
        state_current = state_next

    return Q_updated
## Train!
episodes = 500
for i in range(episodes):
    Q = updateQ(Q, R, goal_state=5)

print(Q/np.max(Q))  # Show the trained Q-table, normalized to its largest entry
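Once training finishes, the table can be read as a policy: from any state, repeatedly moving to the column with the largest Q value reaches the goal. A minimal sketch of that greedy walk, assuming the script above has just run (start_state and path are illustrative names, not part of the original gist):

## Follow the greedy policy from an arbitrary start state
start_state = 2
path = [start_state]
while path[-1] != 5:  # 5 is the goal state used during training
    path.append(int(np.argmax(Q[path[-1], :])))
print(path)  # typically [2, 3, 1, 5] once the table has converged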