@MCarlomagno · Created October 17, 2020 17:52
import gym
import numpy as np

EPISODES = 1000  # not defined in the gist; 1000 is a typical choice

env = gym.make('CartPole-v1')
state_size = env.observation_space.shape[0]
action_size = env.action_space.n
agent = DQNAgent(state_size, action_size)  # class not included in the gist; see the sketch below
# agent.load("./save/cartpole-ddqn.h5")
done = False
batch_size = 32

for e in range(EPISODES):
    state = env.reset()
    state = np.reshape(state, [1, state_size])
    for time in range(500):
        # env.render()
        action = agent.act(state)
        next_state, reward, done, _ = env.step(action)
        # reward = reward if not done else -10
        # Shaped reward: r1 rewards keeping the cart near the center,
        # r2 rewards keeping the pole near vertical.
        x, x_dot, theta, theta_dot = next_state
        r1 = (env.x_threshold - abs(x)) / env.x_threshold - 0.8
        r2 = (env.theta_threshold_radians - abs(theta)) / env.theta_threshold_radians - 0.5
        reward = r1 + r2
        next_state = np.reshape(next_state, [1, state_size])
        agent.memorize(state, action, reward, next_state, done)
        state = next_state
        if done:
            # Sync the target network at the end of each episode.
            agent.update_target_model()
            print("episode: {}/{}, score: {}, e: {:.2}"
                  .format(e, EPISODES, time, agent.epsilon))
            break
        if len(agent.memory) > batch_size:
            agent.replay(batch_size)
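
The DQNAgent class itself is not part of this gist. Below is a minimal sketch of what the training loop above assumes, modeled on the common Keras double-DQN pattern (as in keon/deep-q-learning): layer sizes, hyperparameters, and the use of tensorflow.keras here are illustrative assumptions, not the author's exact code.

import random
from collections import deque

import numpy as np
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from tensorflow.keras.optimizers import Adam


class DQNAgent:
    def __init__(self, state_size, action_size):
        self.state_size = state_size
        self.action_size = action_size
        self.memory = deque(maxlen=2000)        # replay buffer
        self.gamma = 0.95                       # discount factor
        self.epsilon = 1.0                      # exploration rate
        self.epsilon_min = 0.01
        self.epsilon_decay = 0.995
        self.learning_rate = 0.001
        self.model = self._build_model()         # online network
        self.target_model = self._build_model()  # target network
        self.update_target_model()

    def _build_model(self):
        # Simple MLP mapping a state to one Q-value per action.
        model = Sequential()
        model.add(Dense(24, input_dim=self.state_size, activation='relu'))
        model.add(Dense(24, activation='relu'))
        model.add(Dense(self.action_size, activation='linear'))
        model.compile(loss='mse',
                      optimizer=Adam(learning_rate=self.learning_rate))
        return model

    def update_target_model(self):
        # Copy weights from the online network to the target network.
        self.target_model.set_weights(self.model.get_weights())

    def memorize(self, state, action, reward, next_state, done):
        self.memory.append((state, action, reward, next_state, done))

    def act(self, state):
        # Epsilon-greedy action selection.
        if np.random.rand() <= self.epsilon:
            return random.randrange(self.action_size)
        q_values = self.model.predict(state, verbose=0)
        return np.argmax(q_values[0])

    def replay(self, batch_size):
        minibatch = random.sample(self.memory, batch_size)
        for state, action, reward, next_state, done in minibatch:
            target = self.model.predict(state, verbose=0)
            if done:
                target[0][action] = reward
            else:
                # Double DQN: the online network selects the next action,
                # the target network evaluates it.
                best_action = np.argmax(
                    self.model.predict(next_state, verbose=0)[0])
                target[0][action] = reward + self.gamma * \
                    self.target_model.predict(next_state, verbose=0)[0][best_action]
            self.model.fit(state, target, epochs=1, verbose=0)
        if self.epsilon > self.epsilon_min:
            self.epsilon *= self.epsilon_decay

    def load(self, name):
        self.model.load_weights(name)

    def save(self, name):
        self.model.save_weights(name)

Decoupling action selection (online network) from action evaluation (target network) in replay is what distinguishes double DQN from vanilla DQN and reduces Q-value overestimation; the periodic update_target_model call in the training loop keeps the target network a stable, slightly stale copy of the online one.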