Skip to content

Instantly share code, notes, and snippets.

@muety
Last active June 23, 2020 21:52
Show Gist options
  • Save muety/af0b8476ae4106ec098fea1dfe57f578 to your computer and use it in GitHub Desktop.
Save muety/af0b8476ae4106ec098fea1dfe57f578 to your computer and use it in GitHub Desktop.
# Inspired by https://medium.com/@tuzzer/cart-pole-balancing-with-q-learning-b54c6068d947
import gym
import numpy as np
import math
from collections import deque
class QCartPoleSolver():
    """Tabular Q-learning agent for the gym ``CartPole-v0`` environment.

    Continuous observations are mapped onto a small grid of bins (see
    ``discretize``) that indexes a Q-table of shape ``buckets + (n_actions,)``.
    The learning rate and exploration rate both follow a logarithmic per-episode
    schedule (see ``_decay``) bounded below by ``min_alpha`` / ``min_epsilon``.
    """

    def __init__(self, buckets=(1, 1, 6, 12,), n_episodes=1000, n_win_ticks=195,
                 min_alpha=0.1, min_epsilon=0.1, gamma=1.0, ada_divisor=25,
                 max_env_steps=None, quiet=False, monitor=False):
        """Create the environment and a zero-initialised Q-table.

        Args:
            buckets: Number of discrete bins per observation dimension; any
                sequence is accepted and stored as a tuple.
            n_episodes: Maximum number of training episodes ``run`` will play.
            n_win_ticks: Mean ticks over the last 100 episodes needed to count
                as solved.
            min_alpha: Floor for the decaying learning rate.
            min_epsilon: Floor for the decaying exploration rate.
            gamma: Discount factor for the bootstrapped Q target.
            ada_divisor: Episode scale of the alpha/epsilon decay schedule.
            max_env_steps: If given, overrides the environment's episode step cap.
            quiet: Suppress progress printing in ``run``.
            monitor: Wrap the env in ``gym.wrappers.Monitor`` writing to
                ``tmp/cartpole-1``.
        """
        # The Q-table shape is built by tuple concatenation, so normalise here.
        self.buckets = tuple(buckets)  # down-scaling feature space to discrete range
        self.n_episodes = n_episodes  # training episodes
        self.n_win_ticks = n_win_ticks  # average ticks over 100 episodes required for win
        self.min_alpha = min_alpha  # learning-rate floor
        self.min_epsilon = min_epsilon  # exploration-rate floor
        self.gamma = gamma  # discount factor
        self.ada_divisor = ada_divisor  # decay-schedule scale, mainly for tuning
        self.quiet = quiet
        self.env = gym.make('CartPole-v0')
        if max_env_steps is not None:
            # Writes a private gym attribute (presumably the TimeLimit wrapper's
            # step cap) -- depends on gym internals.
            self.env._max_episode_steps = max_env_steps
        if monitor:
            self.env = gym.wrappers.Monitor(self.env, 'tmp/cartpole-1', force=True)  # record results for upload
        self.Q = np.zeros(self.buckets + (self.env.action_space.n,))

    def discretize(self, obs):
        """Map a continuous observation to a tuple of bin indices into ``self.Q``.

        Dimensions 0 and 2 use the environment's own observation bounds.
        Dimensions 1 and 3 use hand-picked bounds (+-0.5 and +-50 degrees in
        radians), presumably because the env's limits for them are too wide to
        bin usefully.

        Values are placed linearly between their bounds, rounded to the nearest
        bin with Python's round-half-to-even, and clamped to the valid index range.
        """
        high = self.env.observation_space.high
        low = self.env.observation_space.low
        upper_bounds = [high[0], 0.5, high[2], math.radians(50)]
        lower_bounds = [low[0], -0.5, low[2], -math.radians(50)]
        bins = []
        for value, lower, upper, n_bins in zip(obs, lower_bounds, upper_bounds, self.buckets):
            # Fractional position inside [lower, upper]; may fall outside [0, 1].
            # ``value - lower`` (not ``value + abs(lower)``) stays correct for
            # non-negative lower bounds too.
            ratio = (value - lower) / (upper - lower)
            index = int(round((n_bins - 1) * ratio))
            # Out-of-range observations collapse into the first/last bin.
            bins.append(min(n_bins - 1, max(0, index)))
        return tuple(bins)

    def choose_action(self, state, epsilon):
        """Epsilon-greedy action selection.

        With probability ``epsilon`` (``random() <= epsilon``) a random action
        is sampled. Otherwise the action with the highest Q value for ``state``
        is returned.
        """
        if np.random.random() <= epsilon:
            return self.env.action_space.sample()
        return np.argmax(self.Q[state])

    def update_q(self, state_old, action, reward, state_new, alpha):
        """Apply one Q-learning update to ``Q[state_old][action]``.

        The target is ``reward + gamma * max_a Q[state_new][a]``.

        NOTE(review): ``run`` also calls this on the final step of an episode,
        so it bootstraps from ``state_new`` even when the episode has ended.
        Standard Q-learning drops that term for true terminal states.
        """
        target = reward + self.gamma * np.max(self.Q[state_new])
        self.Q[state_old][action] += alpha * (target - self.Q[state_old][action])

    def _decay(self, t, floor):
        """Shared schedule: 1.0 while ``t + 1 <= ada_divisor``, then
        ``1 - log10((t + 1) / ada_divisor)``, never below ``floor``."""
        return max(floor, min(1.0, 1.0 - math.log10((t + 1) / self.ada_divisor)))

    def get_epsilon(self, t):
        """Exploration rate for episode ``t``."""
        return self._decay(t, self.min_epsilon)

    def get_alpha(self, t):
        """Learning rate for episode ``t``."""
        return self._decay(t, self.min_alpha)

    def _run_episode(self, alpha, epsilon):
        """Play one episode, updating Q after every step; return the number of steps taken."""
        current_state = self.discretize(self.env.reset())
        done = False
        ticks = 0
        while not done:
            # Uncomment to watch training: self.env.render()
            action = self.choose_action(current_state, epsilon)
            obs, reward, done, _ = self.env.step(action)
            new_state = self.discretize(obs)
            self.update_q(current_state, action, reward, new_state, alpha)
            current_state = new_state
            ticks += 1
        return ticks

    def run(self):
        """Train until solved or until ``n_episodes`` episodes have been played.

        The task counts as solved once the mean survival time over the last
        100 consecutive episodes is at least ``n_win_ticks``.

        Returns:
            If solved, the number of episodes that preceded the winning
            100-episode window. Otherwise, ``n_episodes``.
        """
        scores = deque(maxlen=100)
        for e in range(self.n_episodes):
            scores.append(self._run_episode(self.get_alpha(e), self.get_epsilon(e)))
            mean_score = np.mean(scores)
            # Judge only once the window holds a full 100 consecutive episodes.
            if len(scores) == scores.maxlen and mean_score >= self.n_win_ticks:
                episodes_before = e + 1 - scores.maxlen
                if not self.quiet: print('Ran {} episodes. Solved after {} trials ✔'.format(e + 1, episodes_before))
                return episodes_before
            if e % 100 == 0 and not self.quiet:
                print('[Episode {}] - Mean survival time over last 100 episodes was {} ticks.'.format(e, mean_score))
        # Reported from n_episodes so the count is right (and defined) even for 0 episodes.
        if not self.quiet: print('Did not solve after {} episodes 😞'.format(self.n_episodes))
        return self.n_episodes
if __name__ == "__main__":
    # Script entry point: train one solver with the default hyperparameters.
    # Pass monitor=True to the constructor to record episodes under tmp/cartpole-1.
    QCartPoleSolver().run()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment