CartPole PG: a simple policy-gradient (REINFORCE) agent for CartPole-v1 in TensorFlow 1.x
import os
import random

import gym
import numpy as np
import tensorflow as tf

# Hyperparameters.
EPOCHS = 250         # maximum number of training epochs
NUM_GAMES = 100      # episodes collected per epoch (halved as the policy improves)
BATCH_SIZE = 128     # samples per gradient update
MAX_UPDATES = 10     # maximum gradient updates per epoch
GAMMA = 0.99         # discount factor for returns
N_PAST_STATES = 4    # number of recent observations stacked into one input
STATE_DIM = 4        # CartPole observation dimensionality
ACTION_DIM = 2       # CartPole action space size (push left / push right)

class Agent:
    """Small policy network trained with a REINFORCE-style loss."""

    def __init__(self, hidden_size=64, name='agent'):
        self._hidden_size = int(hidden_size)
        self.graph = tf.Graph()
        with self.graph.as_default(), \
             tf.name_scope(name):
            self._placeholders()
            self._build()
            self._loss()
            self.saver = tf.train.Saver()

    def _placeholders(self):
        with tf.name_scope('placeholders'):
            # Stacked past observations, taken actions and discounted returns.
            self.state_ph = tf.placeholder(tf.float32, [None, N_PAST_STATES * STATE_DIM])
            self.action_ph = tf.placeholder(tf.int32, [None])
            self.reward_ph = tf.placeholder(tf.float32, [None])

    def _build(self):
        with tf.name_scope('model'):
            # Single hidden layer followed by a softmax over the two actions.
            h = tf.layers.dense(self.state_ph, self._hidden_size, activation=tf.nn.relu)
            self._action_logits = tf.layers.dense(h, ACTION_DIM)
            self.action_p = tf.nn.softmax(self._action_logits)

    def _loss(self):
        # Select the probability of the action that was actually taken in each sample.
        gather_indices = tf.stack([
            tf.range(tf.shape(self.reward_ph)[0]),
            self.action_ph
        ], -1)
        actions = tf.gather_nd(self.action_p, gather_indices)
        # REINFORCE objective: maximize the log-probability of each action,
        # weighted by its (clipped, standardized) discounted return.
        self.loss = tf.reduce_mean(-tf.log(actions) * tf.clip_by_value(self.reward_ph, -10, 10))
        self.train_op = tf.train.AdamOptimizer(learning_rate=1e-3).minimize(self.loss)

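# Note (illustration, not part of the original gist): in Agent._loss, for a batch
# of three samples with actions [0, 1, 1], gather_indices evaluates to
#   [[0, 0], [1, 1], [2, 1]]
# so tf.gather_nd picks out the probability of the action taken in each row.
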
def discount_rewards(rewards):
    """Compute discounted returns and standardize them to zero mean, unit variance."""
    discounted, running_sum = np.zeros_like(rewards), 0
    for i in reversed(range(0, len(rewards))):
        running_sum = GAMMA * running_sum + rewards[i]
        discounted[i] = running_sum
    discounted -= np.mean(discounted)
    # The small epsilon guards against division by zero (and a NaN loss) for one-step episodes.
    discounted /= (np.std(discounted) + 1e-8)
    return discounted

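# Illustration (not part of the original gist): with GAMMA = 0.99,
# discount_rewards([1.0, 1.0, 1.0]) first computes the raw returns
#   [1 + 0.99 + 0.99**2, 1 + 0.99, 1] = [2.9701, 1.99, 1.0]
# and then standardizes them, so earlier steps get larger weights in the loss.
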
def collect_samples(sess, agent, num_games=NUM_GAMES):
    """Play `num_games` episodes with the current policy and return (state, action, return) samples."""
    env = gym.make('CartPole-v1')
    samples_buffer, rewards = [], []
    for _ in range(num_games):
        state = env.reset()
        game_samples, states, total_reward = [], [state], 0
        while True:
            # Stack the most recent observations and left-pad with zeros
            # at the start of an episode so the input size stays constant.
            state = np.concatenate(states[-N_PAST_STATES:])
            state = np.pad(state, [(N_PAST_STATES * env.observation_space.shape[0] - len(state), 0)], 'constant', constant_values=0)
            action_p = sess.run(agent.action_p, feed_dict={
                agent.state_ph: [state]
            })
            action_p = action_p[0]
            # Sample an action from the policy distribution (exploration).
            action = np.random.choice(ACTION_DIM, p=action_p)
            new_state, reward, done, _ = env.step(action)
            game_samples.append((state, action))
            states.append(new_state)
            total_reward += reward
            if done or total_reward >= 300:
                break
        # Every step earns a reward of 1 in CartPole, so the per-step returns
        # depend only on how long the episode lasted.
        game_discounted_rewards = discount_rewards([1.0] * len(game_samples))
        for i, (state, action) in enumerate(game_samples):
            game_samples[i] = (state, action, game_discounted_rewards[i])
        rewards.append(total_reward)
        samples_buffer.extend(game_samples)
    return samples_buffer, np.mean(rewards)

def run(sess, agent):
    """Play one rendered episode, acting greedily with the trained policy."""
    env = gym.make('CartPole-v1')
    state = env.reset()
    states, total_reward = [state], 0
    while True:
        env.render()
        state = np.concatenate(states[-N_PAST_STATES:])
        state = np.pad(state, [(N_PAST_STATES * env.observation_space.shape[0] - len(state), 0)], 'constant', constant_values=0)
        action_p = sess.run(agent.action_p, feed_dict={
            agent.state_ph: [state]
        })
        # Act greedily (argmax) at evaluation time instead of sampling.
        action = np.argmax(action_p[0])
        new_state, reward, done, _ = env.step(action)
        states.append(new_state)
        total_reward += reward
        if done:
            print('Total reward: %d' % total_reward)
            break

def main():
    agent = Agent()
    # tf.train.Saver.save() does not create missing directories, so make sure
    # the checkpoint directory exists up front.
    os.makedirs('./checkpoints', exist_ok=True)
    with tf.Session(graph=agent.graph) as sess:
        sess.run(tf.global_variables_initializer())
        num_games = NUM_GAMES
        for e in range(EPOCHS):
            samples, mean_reward = collect_samples(sess, agent, num_games=num_games)
            print("Average reward: %.2f" % mean_reward)
            random.shuffle(samples)
            # Collect fewer episodes per epoch as the policy improves,
            # since individual episodes get longer.
            if mean_reward > 50 and num_games > NUM_GAMES / 2.0:
                num_games = int(num_games / 2)
            if mean_reward > 100 and num_games > NUM_GAMES / 4.0:
                num_games = int(num_games / 2)
            if mean_reward > 150 and num_games > NUM_GAMES / 8.0:
                num_games = int(num_games / 2)
            if mean_reward > 250:
                break
            for i in range(min(int(len(samples) / BATCH_SIZE), MAX_UPDATES)):
                batch_samples = samples[i*BATCH_SIZE:i*BATCH_SIZE+BATCH_SIZE]
                states, actions, rewards = zip(*batch_samples)
                _, loss = sess.run([agent.train_op, agent.loss], feed_dict={
                    agent.state_ph: states,
                    agent.action_ph: actions,
                    agent.reward_ph: rewards,
                })
                if np.isnan(loss):
                    raise ValueError('NaN loss')
            if e % 25 == 0 and e > 0:
                agent.saver.save(sess, './checkpoints/agent-%d.ckpt' % e)
        agent.saver.save(sess, './checkpoints/agent-final.ckpt')
        run(sess, agent)


if __name__ == '__main__':
    main()
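
A minimal usage sketch (not part of the original gist): once training has written ./checkpoints/agent-final.ckpt, the trained weights can be restored into a fresh session for evaluation, assuming the Agent class and run() defined above are in scope.

agent = Agent()
with tf.Session(graph=agent.graph) as sess:
    # Restoring the checkpoint replaces variable initialization.
    agent.saver.restore(sess, './checkpoints/agent-final.ckpt')
    run(sess, agent)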