# Inspired by https://medium.com/@tuzzer/cart-pole-balancing-with-q-learning-b54c6068d947
import gym
import numpy as np
import math
from collections import deque


class QCartPoleSolver():
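    """Tabular Q-learning agent for OpenAI Gym's CartPole-v0.

    The continuous 4-dimensional observation (cart position, cart velocity,
    pole angle, pole angular velocity) is discretized into a small grid of
    buckets, so the Q-function fits in a plain numpy array.
    """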

    def __init__(self, buckets=(1, 1, 6, 12,), n_episodes=1000, n_win_ticks=195, min_alpha=0.1, min_epsilon=0.1, gamma=1.0, ada_divisor=25, max_env_steps=None, quiet=False, monitor=False):
        self.buckets = buckets  # number of discrete buckets per observation dimension
        self.n_episodes = n_episodes  # training episodes
        self.n_win_ticks = n_win_ticks  # average ticks over 100 episodes required for win
        self.min_alpha = min_alpha  # lower bound on the learning rate
        self.min_epsilon = min_epsilon  # lower bound on the exploration rate
        self.gamma = gamma  # discount factor
        self.ada_divisor = ada_divisor  # controls how quickly alpha and epsilon decay
        self.quiet = quiet

        self.env = gym.make('CartPole-v0')
        if max_env_steps is not None:
            self.env._max_episode_steps = max_env_steps
        if monitor:
            self.env = gym.wrappers.Monitor(self.env, 'tmp/cartpole-1', force=True)  # record results for upload

        self.Q = np.zeros(self.buckets + (self.env.action_space.n,))

    def discretize(self, obs):
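        """Map a continuous observation onto a tuple of bucket indices.

        Cart velocity and pole angular velocity are unbounded in the
        environment's observation space, so they are clipped to hand-picked
        ranges ([-0.5, 0.5] and +/-50 degrees respectively) before binning.
        """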
        upper_bounds = [self.env.observation_space.high[0], 0.5, self.env.observation_space.high[2], math.radians(50)]
        lower_bounds = [self.env.observation_space.low[0], -0.5, self.env.observation_space.low[2], -math.radians(50)]
        ratios = [(obs[i] + abs(lower_bounds[i])) / (upper_bounds[i] - lower_bounds[i]) for i in range(len(obs))]
        new_obs = [int(round((self.buckets[i] - 1) * ratios[i])) for i in range(len(obs))]
        new_obs = [min(self.buckets[i] - 1, max(0, new_obs[i])) for i in range(len(obs))]
        return tuple(new_obs)

    def choose_action(self, state, epsilon):
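        """Epsilon-greedy policy: a random action with probability epsilon,
        otherwise the greedy action from the Q-table."""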
        return self.env.action_space.sample() if (np.random.random() <= epsilon) else np.argmax(self.Q[state])

    def update_q(self, state_old, action, reward, state_new, alpha):
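        """Standard Q-learning update:
        Q(s, a) += alpha * (r + gamma * max_a' Q(s', a') - Q(s, a))
        """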
        self.Q[state_old][action] += alpha * (reward + self.gamma * np.max(self.Q[state_new]) - self.Q[state_old][action])

    def get_epsilon(self, t):
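        # Exploration rate decays logarithmically from 1.0 towards min_epsilon;
        # ada_divisor sets how many episodes pass before the decay kicks in.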
        return max(self.min_epsilon, min(1, 1.0 - math.log10((t + 1) / self.ada_divisor)))

    def get_alpha(self, t):
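        # Learning rate follows the same logarithmic schedule, floored at min_alpha.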
        return max(self.min_alpha, min(1.0, 1.0 - math.log10((t + 1) / self.ada_divisor)))

    def run(self):
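        """Train for up to n_episodes episodes.

        The task counts as solved once the mean survival time over the last
        100 episodes reaches n_win_ticks; the method then returns the number
        of trials it took (episode index minus the 100-episode window).
        If the task is never solved, the last episode index is returned.
        """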
        scores = deque(maxlen=100)

        for e in range(self.n_episodes):
            current_state = self.discretize(self.env.reset())

            alpha = self.get_alpha(e)
            epsilon = self.get_epsilon(e)
            done = False
            i = 0

            while not done:
                # self.env.render()
                action = self.choose_action(current_state, epsilon)
                obs, reward, done, _ = self.env.step(action)
                new_state = self.discretize(obs)
                self.update_q(current_state, action, reward, new_state, alpha)
                current_state = new_state
                i += 1

            scores.append(i)
            mean_score = np.mean(scores)
            if mean_score >= self.n_win_ticks and e >= 100:
                if not self.quiet:
                    print('Ran {} episodes. Solved after {} trials ✔'.format(e, e - 100))
                return e - 100
            if e % 100 == 0 and not self.quiet:
                print('[Episode {}] - Mean survival time over last 100 episodes was {} ticks.'.format(e, mean_score))

        if not self.quiet:
            print('Did not solve after {} episodes 😞'.format(e))
        return e


if __name__ == "__main__":
    solver = QCartPoleSolver()
    solver.run()
    # gym.upload('tmp/cartpole-1', api_key='')