Q Learning implementation
import numpy as np
import gym

# Define the environment
# (.env unwraps the TimeLimit wrapper; on recent gym releases the environment is registered as "Taxi-v3")
env = gym.make("Taxi-v2").env

# Initialize the Q-table with zero values
q_table = np.zeros([env.observation_space.n, env.action_space.n])

# Hyperparameters
alpha = 0.1    # learning rate
gamma = 0.7    # discount factor
epsilon = 0.1  # exploration vs. exploitation trade-off

# Random generator
rng = np.random.default_rng()

# Perform 100,000 episodes
for i in range(100_000):
    # Reset the environment
    state = env.reset()
    done = False

    # Loop as long as the game is not over, i.e. done is not True
    while not done:
        if rng.random() < epsilon:
            action = env.action_space.sample()   # Explore the action space
        else:
            action = np.argmax(q_table[state])   # Exploit learned values

        # Apply the action and see what happens
        next_state, reward, done, info = env.step(action)

        current_value = q_table[state, action]   # current Q-value for the state/action pair
        next_max = np.max(q_table[next_state])   # best Q-value reachable from the next state

        # Compute the new Q-value with the Bellman equation:
        # Q(s, a) <- (1 - alpha) * Q(s, a) + alpha * (reward + gamma * max_a' Q(s', a'))
        q_table[state, action] = (1 - alpha) * current_value + alpha * (reward + gamma * next_max)

        # Move on to the next state
        state = next_state
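
Once training is done, the learned q_table can be sanity-checked by running a few purely greedy episodes (no exploration). The snippet below is a minimal evaluation sketch, not part of the original gist; it assumes the same environment and the same gym reset/step API as the training loop above.

# Evaluate the learned policy over a few greedy episodes (illustrative sketch)
n_eval_episodes = 10
episode_rewards = []
for _ in range(n_eval_episodes):
    state = env.reset()
    done = False
    total_reward = 0
    while not done:
        action = np.argmax(q_table[state])   # always exploit the learned values
        state, reward, done, info = env.step(action)
        total_reward += reward
    episode_rewards.append(total_reward)
print(f"Average reward over {n_eval_episodes} greedy episodes: {np.mean(episode_rewards):.1f}")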