Skip to content

Instantly share code, notes, and snippets.

@mbalunovic
Created September 11, 2017 06:23
Show Gist options
  • Save mbalunovic/fb7392e2c09b2c3895a354c3ad36497e to your computer and use it in GitHub Desktop.
Save mbalunovic/fb7392e2c09b2c3895a354c3ad36497e to your computer and use it in GitHub Desktop.
Q-Network, CartPole
import gym
import keras
import numpy as np
import random
from gym import wrappers
from keras.models import Sequential
from keras.layers import Dense
from keras.optimizers import Adam
from collections import deque
ACTIONS_DIM = 2
OBSERVATIONS_DIM = 4
MAX_ITERATIONS = 10**6
LEARNING_RATE = 0.001
NUM_EPOCHS = 50
GAMMA = 0.99
REPLAY_MEMORY_SIZE = 1000
NUM_EPISODES = 10000
TARGET_UPDATE_FREQ = 100
MINIBATCH_SIZE = 32
RANDOM_ACTION_DECAY = 0.99
INITIAL_RANDOM_ACTION = 1
class ReplayBuffer():
def __init__(self, max_size):
self.max_size = max_size
self.transitions = deque()
def add(self, observation, action, reward, observation2):
if len(self.transitions) > self.max_size:
self.transitions.popleft()
self.transitions.append((observation, action, reward, observation2))
def sample(self, count):
return random.sample(self.transitions, count)
def size(self):
return len(self.transitions)
def get_q(model, observation):
np_obs = np.reshape(observation, [-1, OBSERVATIONS_DIM])
return model.predict(np_obs)
def train(model, observations, targets):
# for i, observation in enumerate(observations):
# np_obs = np.reshape(observation, [-1, OBSERVATIONS_DIM])
# print "t: {}, p: {}".format(model.predict(np_obs),targets[i])
# exit(0)
np_obs = np.reshape(observations, [-1, OBSERVATIONS_DIM])
np_targets = np.reshape(targets, [-1, ACTIONS_DIM])
model.fit(np_obs, np_targets, epochs=1, verbose=0)
def predict(model, observation):
np_obs = np.reshape(observation, [-1, OBSERVATIONS_DIM])
return model.predict(np_obs)
def get_model():
model = Sequential()
model.add(Dense(16, input_shape=(OBSERVATIONS_DIM, ), activation='relu'))
model.add(Dense(16, input_shape=(OBSERVATIONS_DIM,), activation='relu'))
model.add(Dense(2, activation='linear'))
model.compile(
optimizer=Adam(lr=LEARNING_RATE),
loss='mse',
metrics=[],
)
return model
def update_action(action_model, target_model, sample_transitions):
random.shuffle(sample_transitions)
batch_observations = []
batch_targets = []
for sample_transition in sample_transitions:
old_observation, action, reward, observation = sample_transition
targets = np.reshape(get_q(action_model, old_observation), ACTIONS_DIM)
targets[action] = reward
if observation is not None:
predictions = predict(target_model, observation)
new_action = np.argmax(predictions)
targets[action] += GAMMA * predictions[0, new_action]
batch_observations.append(old_observation)
batch_targets.append(targets)
train(action_model, batch_observations, batch_targets)
def main():
steps_until_reset = TARGET_UPDATE_FREQ
random_action_probability = INITIAL_RANDOM_ACTION
# Initialize replay memory D to capacity N
replay = ReplayBuffer(REPLAY_MEMORY_SIZE)
# Initialize action-value model with random weights
action_model = get_model()
# Initialize target model with same weights
#target_model = get_model()
#target_model.set_weights(action_model.get_weights())
env = gym.make('CartPole-v0')
env = wrappers.Monitor(env, '/tmp/cartpole-experiment-1')
for episode in range(NUM_EPISODES):
observation = env.reset()
for iteration in range(MAX_ITERATIONS):
random_action_probability *= RANDOM_ACTION_DECAY
random_action_probability = max(random_action_probability, 0.1)
old_observation = observation
# if episode % 10 == 0:
# env.render()
if np.random.random() < random_action_probability:
action = np.random.choice(range(ACTIONS_DIM))
else:
q_values = get_q(action_model, observation)
action = np.argmax(q_values)
observation, reward, done, info = env.step(action)
if done:
print 'Episode {}, iterations: {}'.format(
episode,
iteration
)
# print action_model.get_weights()
# print target_model.get_weights()
#print 'Game finished after {} iterations'.format(iteration)
reward = -200
replay.add(old_observation, action, reward, None)
break
replay.add(old_observation, action, reward, observation)
if replay.size() >= MINIBATCH_SIZE:
sample_transitions = replay.sample(MINIBATCH_SIZE)
update_action(action_model, action_model, sample_transitions)
steps_until_reset -= 1
# if steps_until_reset == 0:
# target_model.set_weights(action_model.get_weights())
# steps_until_reset = TARGET_UPDATE_FREQ
if __name__ == "__main__":
main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment