# Q-learning agent with a Keras Q-network
# Gist by @SnowyPainter (https://gist.github.com/SnowyPainter/28edba6390856e57b948334239c28b50)
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Dropout
from tensorflow.keras.optimizers import RMSprop
import numpy as np
import random

class QLearningAgent:
    def __init__(self, env, learning_rate=0.1, discount_factor=0.99, epsilon=0.1):
        self.env = env
        self.step_size = learning_rate          # alpha: how far each TD update moves the Q-value
        self.discount_factor = discount_factor  # gamma: weight given to future rewards
        self.epsilon = epsilon                  # exploration rate for epsilon-greedy action selection
        self.epsilon_min = 0.1
        self.model = self._build_model()
    def _build_model(self):
        # Q-network: maps a (lags, n_features) state window to Q-values.
        # Without a Flatten layer the output has shape (lags, 2), i.e. one
        # pair of action values per lag step; the rest of the code indexes
        # into this array accordingly.
        model = Sequential()
        model.add(Dense(48, input_shape=(self.env.lags, self.env.n_features), activation='relu'))
        model.add(Dropout(0.2, seed=100))
        model.add(Dense(24, activation='relu'))
        model.add(Dropout(0.2, seed=100))
        model.add(Dense(2, activation='linear'))
        model.compile(loss='mse', optimizer=RMSprop(learning_rate=0.001))
        return model
    def get_action(self, state):
        # Epsilon-greedy: explore with probability epsilon, otherwise act greedily
        # on the Q-values predicted for the first lag step.
        if random.random() <= self.epsilon:
            return self.env.action_space.sample()
        else:
            q_table = self.model.predict(np.array([state]), verbose=0)[0, 0]
            return np.argmax(q_table)
    def update_q_values(self, state, action, reward, next_state):
        # TD(0) update on the network's predicted Q-values:
        #   target = r + gamma * max_a' Q(s', a')
        #   Q(s, a) <- Q(s, a) + alpha * (target - Q(s, a))
        self.q_table = self.model.predict(np.array([state]), verbose=0)
        next_q_table = self.model.predict(np.array([next_state]), verbose=0)
        best_next_action = np.argmax(next_q_table)
        td_target = reward + self.discount_factor * next_q_table.flatten()[best_next_action]
        td_error = td_target - self.q_table.flatten()[action]
        self.q_table[np.unravel_index(action, self.q_table.shape)] += self.step_size * td_error
        # Refit the network toward the corrected Q-values for this state.
        self.model.fit(np.array([state]), self.q_table, verbose=0)
    def learn(self, episodes, batch_size):
        batch_states, batch_q_tables = [], []
        for episode in range(episodes):
            state = self.env.reset()
            done = False
            while not done:
                action = self.get_action(state)
                next_state, reward, done, _ = self.env.step(action)
                self.update_q_values(state, action, reward, next_state)
                # Collect the state and its (post-update) predicted Q-values for batch refitting.
                batch_states.append(state)
                batch_q_tables.append(self.model.predict(np.array([state]), verbose=0)[0])
                state = next_state
                if len(batch_states) >= batch_size:
                    self.model.fit(np.array(batch_states), np.array(batch_q_tables), verbose=0)
                    batch_states, batch_q_tables = [], []
            # Decay exploration once per episode, down to epsilon_min.
            if self.epsilon > self.epsilon_min:
                self.epsilon *= 0.99
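

# Usage sketch (not part of the original gist). The agent expects a gym-style
# environment exposing `lags`, `n_features`, an `action_space` with `sample()`,
# `reset()` returning a (lags, n_features) observation, and `step(action)`
# returning the old 4-tuple (next_state, reward, done, info). `RandomEnv` below
# is a hypothetical stand-in used only to show how the pieces fit together.
class RandomEnv:
    class _ActionSpace:
        def sample(self):
            return random.randint(0, 1)

    def __init__(self, lags=5, n_features=3, max_steps=50):
        self.lags, self.n_features, self.max_steps = lags, n_features, max_steps
        self.action_space = self._ActionSpace()

    def reset(self):
        self._t = 0
        return np.random.rand(self.lags, self.n_features)

    def step(self, action):
        self._t += 1
        next_state = np.random.rand(self.lags, self.n_features)
        reward = random.random()
        done = self._t >= self.max_steps
        return next_state, reward, done, {}


if __name__ == '__main__':
    env = RandomEnv()
    agent = QLearningAgent(env, learning_rate=0.1, discount_factor=0.99, epsilon=0.1)
    agent.learn(episodes=5, batch_size=16)  # refit on a batch every 16 collected steps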