@fsndzomga
Created September 8, 2023 00:03
Q learning basics
import numpy as np
import pandas as pd
import time


class QLearningAgent:
    def __init__(self, n_states, actions, epsilon, alpha, gamma, max_episodes, fresh_time):
        self.n_states = n_states          # number of cells in the 1-D world
        self.actions = actions            # available actions, e.g. ['left', 'right']
        self.epsilon = epsilon            # probability of acting greedily
        self.alpha = alpha                # learning rate
        self.gamma = gamma                # discount factor
        self.max_episodes = max_episodes  # number of training episodes
        self.fresh_time = fresh_time      # pause between environment refreshes (seconds)
        self.q_table = self.build_q_table()

    def build_q_table(self):
        # One row per state, one column per action, all Q-values initialised to zero.
        table = pd.DataFrame(
            np.zeros((self.n_states, len(self.actions))),
            columns=self.actions,
        )
        return table

    def choose_action(self, state):
        # Epsilon-greedy selection: act randomly with probability 1 - epsilon,
        # or whenever the state is still unexplored (all Q-values are zero).
        state_actions = self.q_table.iloc[state, :]
        if (np.random.uniform() > self.epsilon) or ((state_actions == 0).all()):
            return np.random.choice(self.actions)
        else:
            return state_actions.idxmax()

    def get_env_feedback(self, S, A):
        # Moving right from the second-to-last cell reaches the treasure (reward 1);
        # every other transition yields reward 0.
        if A == 'right':
            if S == self.n_states - 2:
                return 'terminal', 1
            else:
                return S + 1, 0
        else:
            if S == 0:
                return S, 0  # bumping into the left wall keeps the agent in place
            else:
                return S - 1, 0

    def update_env(self, S, episode, step_counter):
        # Render the 1-D world as text, e.g. '--o--T', where 'o' is the agent
        # and 'T' is the treasure in the rightmost cell.
        env_list = ['-'] * (self.n_states - 1) + ['T']
        if S == 'terminal':
            interaction = f'Episode {episode + 1}: total_steps = {step_counter}'
            print(f'\r{interaction}', end='')
            time.sleep(2)
            print('\r                                ', end='')
        else:
            env_list[S] = 'o'
            interaction = ''.join(env_list)
            print(f'\r{interaction}', end='')
            time.sleep(self.fresh_time)

    def run(self):
        for episode in range(self.max_episodes):
            step_counter = 0
            S = 0
            is_terminated = False
            self.update_env(S, episode, step_counter)
            while not is_terminated:
                A = self.choose_action(S)
                S_, R = self.get_env_feedback(S, A)
                q_predict = self.q_table.loc[S, A]
                if S_ != 'terminal':
                    # Q-learning target: reward plus discounted best value of the next state.
                    q_target = R + self.gamma * self.q_table.iloc[S_, :].max()
                else:
                    q_target = R
                    is_terminated = True
                # Nudge the current estimate towards the target by a fraction alpha.
                self.q_table.loc[S, A] += self.alpha * (q_target - q_predict)
                S = S_
                self.update_env(S, episode, step_counter + 1)
                step_counter += 1
        print('\r\nQ-table:\n')
        print(self.q_table)


if __name__ == "__main__":
    np.random.seed(2)
    agent = QLearningAgent(n_states=6,
                           actions=['left', 'right'],
                           epsilon=0.9,
                           alpha=0.1,
                           gamma=0.9,
                           max_episodes=13,
                           fresh_time=0.3)
    agent.run()
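
A quick way to read the result (not part of the original gist): after training, the greedy policy is simply the action with the highest Q-value in each row of the table. The snippet below is a minimal sketch assuming the `agent` object from the `__main__` block above has already finished `agent.run()`.

# Hypothetical follow-up, assuming `agent` has already been trained via agent.run().
# DataFrame.idxmax(axis=1) returns, for each state (row), the column label of the
# largest Q-value, i.e. the greedy action the agent would take in that state.
greedy_policy = agent.q_table.idxmax(axis=1)
print(greedy_policy)
# With enough episodes the learned policy should favour 'right' in states 0-4,
# since only moving right leads to the treasure. State 5 is the terminal cell,
# its row is never updated, so it stays at zero and idxmax just returns the
# first column ('left') there by tie-breaking.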