@artificialsoph
Last active September 4, 2018 13:30
Reinforcement Learning Example
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"These examples are modifications of the official [Keras-RL](https://github.com/keras-rl/keras-rl/tree/master/examples) examples"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Populating the interactive namespace from numpy and matplotlib\n"
]
}
],
"source": [
"import gym\n",
"import pandas\n",
"%pylab inline"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"import keras\n",
"from rl.agents.dqn import DQNAgent\n",
"from rl.policy import BoltzmannQPolicy\n",
"from rl.memory import SequentialMemory"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[33mWARN: gym.spaces.Box autodetected dtype as <class 'numpy.float32'>. Please provide explicit dtype.\u001b[0m\n"
]
}
],
"source": [
"ENV_NAME = 'CartPole-v0'\n",
"\n",
"# Get the environment and extract the number of actions.\n",
"env = gym.make(ENV_NAME)\n",
"np.random.seed(123)\n",
"env.seed(123)\n",
"nb_actions = env.action_space.n"
]
},
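{
"cell_type": "markdown",
"metadata": {},
"source": [
"Before wiring up the agent, it can help to sanity-check the environment with a short random rollout. The next cell is a minimal sketch (not part of the original example) that assumes the classic gym API, where `env.step` returns `(observation, reward, done, info)`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sanity-check sketch (not in the original example): one episode of random actions.\n",
"# Assumes the classic gym step API: obs, reward, done, info = env.step(action).\n",
"obs = env.reset()\n",
"done = False\n",
"total_reward = 0.0\n",
"while not done:\n",
"    action = env.action_space.sample()  # uniform random action\n",
"    obs, reward, done, info = env.step(action)\n",
"    total_reward += reward\n",
"print('random-policy episode reward:', total_reward)"
]
},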
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([ 0.02078762, -0.01301236, -0.0209893 , -0.03935255])"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"env.reset()"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"_________________________________________________________________\n",
"Layer (type) Output Shape Param # \n",
"=================================================================\n",
"flatten_1 (Flatten) (None, 4) 0 \n",
"_________________________________________________________________\n",
"dense_1 (Dense) (None, 32) 160 \n",
"_________________________________________________________________\n",
"dense_2 (Dense) (None, 2) 66 \n",
"=================================================================\n",
"Total params: 226\n",
"Trainable params: 226\n",
"Non-trainable params: 0\n",
"_________________________________________________________________\n",
"None\n"
]
},
{
"ename": "KeyboardInterrupt",
"evalue": "",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-5-9cbad3bb6658>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 15\u001b[0m dqn = DQNAgent(model=model, nb_actions=nb_actions, memory=memory, nb_steps_warmup=2000,\n\u001b[1;32m 16\u001b[0m target_model_update=1e-2, policy=policy, batch_size=512)\n\u001b[0;32m---> 17\u001b[0;31m \u001b[0mdqn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcompile\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mkeras\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0moptimizers\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mNadam\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlr\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m1e-3\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmetrics\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'mae'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 18\u001b[0m \u001b[0mdqn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0menv\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnb_steps\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m50000\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvisualize\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mverbose\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlog_interval\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m1000\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/miniconda3/envs/tf/lib/python3.6/site-packages/rl/agents/dqn.py\u001b[0m in \u001b[0;36mcompile\u001b[0;34m(self, optimizer, metrics)\u001b[0m\n\u001b[1;32m 169\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 170\u001b[0m \u001b[0;31m# We never train the target model, hence we can set the optimizer and loss arbitrarily.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 171\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtarget_model\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mclone_model\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcustom_model_objects\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 172\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtarget_model\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcompile\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0moptimizer\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'sgd'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mloss\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'mse'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 173\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcompile\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0moptimizer\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'sgd'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mloss\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'mse'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/miniconda3/envs/tf/lib/python3.6/site-packages/rl/util.py\u001b[0m in \u001b[0;36mclone_model\u001b[0;34m(model, custom_objects)\u001b[0m\n\u001b[1;32m 13\u001b[0m }\n\u001b[1;32m 14\u001b[0m \u001b[0mclone\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodel_from_config\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mconfig\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcustom_objects\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mcustom_objects\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 15\u001b[0;31m \u001b[0mclone\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mset_weights\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_weights\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 16\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mclone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 17\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/miniconda3/envs/tf/lib/python3.6/site-packages/keras/engine/network.py\u001b[0m in \u001b[0;36mget_weights\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 496\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mlayer\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlayers\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 497\u001b[0m \u001b[0mweights\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0mlayer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mweights\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 498\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mK\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbatch_get_value\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mweights\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 499\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 500\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mset_weights\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mweights\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/miniconda3/envs/tf/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py\u001b[0m in \u001b[0;36mbatch_get_value\u001b[0;34m(ops)\u001b[0m\n\u001b[1;32m 2388\u001b[0m \"\"\"\n\u001b[1;32m 2389\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mops\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 2390\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mget_session\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mops\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2391\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2392\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/miniconda3/envs/tf/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py\u001b[0m in \u001b[0;36mget_session\u001b[0;34m()\u001b[0m\n\u001b[1;32m 182\u001b[0m config = tf.ConfigProto(intra_op_parallelism_threads=num_thread,\n\u001b[1;32m 183\u001b[0m allow_soft_placement=True)\n\u001b[0;32m--> 184\u001b[0;31m \u001b[0m_SESSION\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mSession\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mconfig\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mconfig\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 185\u001b[0m \u001b[0msession\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_SESSION\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 186\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0m_MANUAL_VAR_INIT\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/miniconda3/envs/tf/lib/python3.6/site-packages/tensorflow/python/client/session.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, target, graph, config)\u001b[0m\n\u001b[1;32m 1492\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1493\u001b[0m \"\"\"\n\u001b[0;32m-> 1494\u001b[0;31m \u001b[0msuper\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mSession\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__init__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtarget\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgraph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mconfig\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mconfig\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1495\u001b[0m \u001b[0;31m# NOTE(mrry): Create these on first `__enter__` to avoid a reference cycle.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1496\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_default_graph_context_manager\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/miniconda3/envs/tf/lib/python3.6/site-packages/tensorflow/python/client/session.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, target, graph, config)\u001b[0m\n\u001b[1;32m 624\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 625\u001b[0m \u001b[0;31m# pylint: disable=protected-access\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 626\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_session\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtf_session\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mTF_NewSession\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_graph\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_c_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mopts\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 627\u001b[0m \u001b[0;31m# pylint: enable=protected-access\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 628\u001b[0m \u001b[0;32mfinally\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mKeyboardInterrupt\u001b[0m: "
]
}
],
"source": [
"# Next, we build a very simple model.\n",
"model = keras.models.Sequential([\n",
" keras.layers.Flatten(input_shape=(1,) + env.observation_space.shape),\n",
" keras.layers.Dense(32, activation=\"relu\"),\n",
"# keras.layers.Dense(16, activation=\"relu\"),\n",
"# keras.layers.Dense(16, activation=\"relu\"),\n",
" keras.layers.Dense(nb_actions, activation=\"linear\"),\n",
"])\n",
"print(model.summary())\n",
"\n",
"# Finally, we configure and compile our agent. You can use every built-in Keras optimizer and\n",
"# even the metrics!\n",
"memory = SequentialMemory(limit=4000, window_length=1)\n",
"policy = BoltzmannQPolicy()\n",
"dqn = DQNAgent(model=model, nb_actions=nb_actions, memory=memory, nb_steps_warmup=2000,\n",
" target_model_update=1e-2, policy=policy, batch_size=512)\n",
"dqn.compile(keras.optimizers.Nadam(lr=1e-3), metrics=['mae'])\n",
"dqn.fit(env, nb_steps=50000, visualize=False, verbose=1, log_interval=1000)"
]
},
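{
"cell_type": "markdown",
"metadata": {},
"source": [
"If training completes (or is interrupted once the agent looks reasonable), the same `dqn` object can be evaluated and its weights saved. The cell below is a minimal sketch using `dqn.save_weights` and `dqn.test` from Keras-RL, mirroring the Atari example further down; the filename and episode count are illustrative choices, not part of the original notebook."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Minimal evaluation sketch (illustrative, not from the original example).\n",
"# Save the current weights, then run a few test episodes without rendering.\n",
"dqn.save_weights('dqn_{}_weights.h5f'.format(ENV_NAME), overwrite=True)\n",
"dqn.test(env, nb_episodes=5, visualize=False)"
]
},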
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"_________________________________________________________________\n",
"Layer (type) Output Shape Param # \n",
"=================================================================\n",
"permute_2 (Permute) (None, 84, 84, 4) 0 \n",
"_________________________________________________________________\n",
"conv2d_6 (Conv2D) (None, 42, 42, 32) 1184 \n",
"_________________________________________________________________\n",
"conv2d_7 (Conv2D) (None, 21, 21, 32) 9248 \n",
"_________________________________________________________________\n",
"conv2d_8 (Conv2D) (None, 11, 11, 64) 18496 \n",
"_________________________________________________________________\n",
"conv2d_9 (Conv2D) (None, 6, 6, 64) 36928 \n",
"_________________________________________________________________\n",
"conv2d_10 (Conv2D) (None, 3, 3, 64) 36928 \n",
"_________________________________________________________________\n",
"flatten_3 (Flatten) (None, 576) 0 \n",
"_________________________________________________________________\n",
"dense_5 (Dense) (None, 256) 147712 \n",
"_________________________________________________________________\n",
"activation_3 (Activation) (None, 256) 0 \n",
"_________________________________________________________________\n",
"dense_6 (Dense) (None, 4) 1028 \n",
"_________________________________________________________________\n",
"activation_4 (Activation) (None, 4) 0 \n",
"=================================================================\n",
"Total params: 251,524\n",
"Trainable params: 251,524\n",
"Non-trainable params: 0\n",
"_________________________________________________________________\n",
"None\n",
"Training for 1750000 steps ...\n",
"Interval 1 (0 steps performed)\n",
"10000/10000 [==============================] - 35s 4ms/step - reward: 0.0060\n",
"54 episodes - episode_reward: 1.111 [0.000, 4.000] - ale.lives: 3.000\n",
"\n",
"Interval 2 (10000 steps performed)\n",
"10000/10000 [==============================] - 35s 4ms/step - reward: 0.0062\n",
"53 episodes - episode_reward: 1.170 [0.000, 4.000] - ale.lives: 2.917\n",
"\n",
"Interval 3 (20000 steps performed)\n",
"10000/10000 [==============================] - 36s 4ms/step - reward: 0.0059\n",
"56 episodes - episode_reward: 1.054 [0.000, 3.000] - ale.lives: 2.890\n",
"\n",
"Interval 4 (30000 steps performed)\n",
"10000/10000 [==============================] - 35s 4ms/step - reward: 0.0068\n",
"53 episodes - episode_reward: 1.283 [0.000, 5.000] - ale.lives: 2.823\n",
"\n",
"Interval 5 (40000 steps performed)\n",
"10000/10000 [==============================] - 35s 4ms/step - reward: 0.0061\n",
"53 episodes - episode_reward: 1.132 [0.000, 4.000] - ale.lives: 2.860\n",
"\n",
"Interval 6 (50000 steps performed)\n",
"10000/10000 [==============================] - 117s 12ms/step - reward: 0.0083\n",
"52 episodes - episode_reward: 1.615 [0.000, 7.000] - loss: 0.003 - mean_absolute_error: 0.009 - mean_q: 0.009 - mean_eps: 0.723 - ale.lives: 2.981\n",
"\n",
"Interval 7 (60000 steps performed)\n",
"10000/10000 [==============================] - 114s 11ms/step - reward: 0.0091\n",
"48 episodes - episode_reward: 1.854 [0.000, 8.000] - loss: 0.003 - mean_absolute_error: 0.013 - mean_q: 0.016 - mean_eps: 0.709 - ale.lives: 2.982\n",
"\n",
"Interval 8 (70000 steps performed)\n",
"10000/10000 [==============================] - 114s 11ms/step - reward: 0.0103\n",
"45 episodes - episode_reward: 2.333 [0.000, 8.000] - loss: 0.003 - mean_absolute_error: 0.022 - mean_q: 0.028 - mean_eps: 0.695 - ale.lives: 3.025\n",
"\n",
"Interval 9 (80000 steps performed)\n",
"10000/10000 [==============================] - 114s 11ms/step - reward: 0.0084\n",
"50 episodes - episode_reward: 1.680 [0.000, 7.000] - loss: 0.004 - mean_absolute_error: 0.024 - mean_q: 0.031 - mean_eps: 0.681 - ale.lives: 2.953\n",
"\n",
"Interval 10 (90000 steps performed)\n",
"10000/10000 [==============================] - 114s 11ms/step - reward: 0.0095\n",
"48 episodes - episode_reward: 1.979 [0.000, 5.000] - loss: 0.004 - mean_absolute_error: 0.032 - mean_q: 0.042 - mean_eps: 0.667 - ale.lives: 2.982\n",
"\n",
"Interval 11 (100000 steps performed)\n",
"10000/10000 [==============================] - 114s 11ms/step - reward: 0.0084\n",
"50 episodes - episode_reward: 1.680 [0.000, 5.000] - loss: 0.004 - mean_absolute_error: 0.038 - mean_q: 0.049 - mean_eps: 0.653 - ale.lives: 2.911\n",
"\n",
"Interval 12 (110000 steps performed)\n",
"10000/10000 [==============================] - 114s 11ms/step - reward: 0.0098\n",
"47 episodes - episode_reward: 2.085 [0.000, 7.000] - loss: 0.004 - mean_absolute_error: 0.042 - mean_q: 0.055 - mean_eps: 0.639 - ale.lives: 2.959\n",
"\n",
"Interval 13 (120000 steps performed)\n",
" 6045/10000 [=================>............] - ETA: 44s - reward: 0.0109"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"IOPub message rate exceeded.\n",
"The notebook server will temporarily stop sending output\n",
"to the client in order to avoid crashing it.\n",
"To change this limit, set the config variable\n",
"`--NotebookApp.iopub_msg_rate_limit`.\n",
"\n",
"Current values:\n",
"NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
"NotebookApp.rate_limit_window=3.0 (secs)\n",
"\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"10000/10000 [==============================] - 113s 11ms/step - reward: 0.0120\n",
"41 episodes - episode_reward: 2.951 [0.000, 9.000] - loss: 0.004 - mean_absolute_error: 0.063 - mean_q: 0.081 - mean_eps: 0.597 - ale.lives: 3.002\n",
"\n",
"Interval 16 (150000 steps performed)\n",
" 9756/10000 [============================>.] - ETA: 2s - reward: 0.0093"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"IOPub message rate exceeded.\n",
"The notebook server will temporarily stop sending output\n",
"to the client in order to avoid crashing it.\n",
"To change this limit, set the config variable\n",
"`--NotebookApp.iopub_msg_rate_limit`.\n",
"\n",
"Current values:\n",
"NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
"NotebookApp.rate_limit_window=3.0 (secs)\n",
"\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"10000/10000 [==============================] - 114s 11ms/step - reward: 0.0095\n",
"47 episodes - episode_reward: 2.000 [0.000, 6.000] - loss: 0.004 - mean_absolute_error: 0.083 - mean_q: 0.108 - mean_eps: 0.555 - ale.lives: 3.022\n",
"\n",
"Interval 19 (180000 steps performed)\n",
"10000/10000 [==============================] - 114s 11ms/step - reward: 0.0121\n",
"41 episodes - episode_reward: 2.927 [0.000, 9.000] - loss: 0.004 - mean_absolute_error: 0.090 - mean_q: 0.118 - mean_eps: 0.541 - ale.lives: 3.087\n",
"\n",
"Interval 20 (190000 steps performed)\n",
" 4309/10000 [===========>..................] - ETA: 1:05 - reward: 0.0102"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"IOPub message rate exceeded.\n",
"The notebook server will temporarily stop sending output\n",
"to the client in order to avoid crashing it.\n",
"To change this limit, set the config variable\n",
"`--NotebookApp.iopub_msg_rate_limit`.\n",
"\n",
"Current values:\n",
"NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
"NotebookApp.rate_limit_window=3.0 (secs)\n",
"\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"10000/10000 [==============================] - 114s 11ms/step - reward: 0.0115\n",
"41 episodes - episode_reward: 2.829 [0.000, 8.000] - loss: 0.005 - mean_absolute_error: 0.106 - mean_q: 0.139 - mean_eps: 0.499 - ale.lives: 3.024\n",
"\n",
"Interval 23 (220000 steps performed)\n",
" 9289/10000 [==========================>...] - ETA: 8s - reward: 0.0100"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"IOPub message rate exceeded.\n",
"The notebook server will temporarily stop sending output\n",
"to the client in order to avoid crashing it.\n",
"To change this limit, set the config variable\n",
"`--NotebookApp.iopub_msg_rate_limit`.\n",
"\n",
"Current values:\n",
"NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
"NotebookApp.rate_limit_window=3.0 (secs)\n",
"\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"10000/10000 [==============================] - 113s 11ms/step - reward: 0.0098\n",
"45 episodes - episode_reward: 2.200 [0.000, 5.000] - loss: 0.002 - mean_absolute_error: 0.126 - mean_q: 0.170 - mean_eps: 0.443 - ale.lives: 2.948\n",
"\n",
"Interval 27 (260000 steps performed)\n",
" 6790/10000 [===================>..........] - ETA: 36s - reward: 0.0081"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"IOPub message rate exceeded.\n",
"The notebook server will temporarily stop sending output\n",
"to the client in order to avoid crashing it.\n",
"To change this limit, set the config variable\n",
"`--NotebookApp.iopub_msg_rate_limit`.\n",
"\n",
"Current values:\n",
"NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
"NotebookApp.rate_limit_window=3.0 (secs)\n",
"\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"10000/10000 [==============================] - 113s 11ms/step - reward: 0.0087\n",
"49 episodes - episode_reward: 1.776 [0.000, 6.000] - loss: 0.001 - mean_absolute_error: 0.139 - mean_q: 0.186 - mean_eps: 0.401 - ale.lives: 2.991\n",
"\n",
"Interval 30 (290000 steps performed)\n",
"10000/10000 [==============================] - 114s 11ms/step - reward: 0.0083\n",
"46 episodes - episode_reward: 1.783 [0.000, 6.000] - loss: 0.001 - mean_absolute_error: 0.144 - mean_q: 0.193 - mean_eps: 0.387 - ale.lives: 2.845\n",
"\n",
"Interval 31 (300000 steps performed)\n",
" 4433/10000 [============>.................] - ETA: 1:03 - reward: 0.0077"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"IOPub message rate exceeded.\n",
"The notebook server will temporarily stop sending output\n",
"to the client in order to avoid crashing it.\n",
"To change this limit, set the config variable\n",
"`--NotebookApp.iopub_msg_rate_limit`.\n",
"\n",
"Current values:\n",
"NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
"NotebookApp.rate_limit_window=3.0 (secs)\n",
"\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"10000/10000 [==============================] - 114s 11ms/step - reward: 0.0074\n",
"49 episodes - episode_reward: 1.469 [0.000, 5.000] - loss: 0.000 - mean_absolute_error: 0.158 - mean_q: 0.211 - mean_eps: 0.345 - ale.lives: 2.984\n",
"\n",
"Interval 34 (330000 steps performed)\n",
" 9557/10000 [===========================>..] - ETA: 5s - reward: 0.0077"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"IOPub message rate exceeded.\n",
"The notebook server will temporarily stop sending output\n",
"to the client in order to avoid crashing it.\n",
"To change this limit, set the config variable\n",
"`--NotebookApp.iopub_msg_rate_limit`.\n",
"\n",
"Current values:\n",
"NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
"NotebookApp.rate_limit_window=3.0 (secs)\n",
"\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"10000/10000 [==============================] - 113s 11ms/step - reward: 0.0100\n",
"48 episodes - episode_reward: 2.104 [0.000, 6.000] - loss: 0.000 - mean_absolute_error: 0.175 - mean_q: 0.234 - mean_eps: 0.303 - ale.lives: 3.005\n",
"\n",
"Interval 37 (360000 steps performed)\n",
"10000/10000 [==============================] - 114s 11ms/step - reward: 0.0094\n",
"49 episodes - episode_reward: 1.918 [0.000, 6.000] - loss: 0.000 - mean_absolute_error: 0.183 - mean_q: 0.245 - mean_eps: 0.289 - ale.lives: 3.061\n",
"\n",
"Interval 38 (370000 steps performed)\n",
" 5940/10000 [================>.............] - ETA: 46s - reward: 0.0094"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"IOPub message rate exceeded.\n",
"The notebook server will temporarily stop sending output\n",
"to the client in order to avoid crashing it.\n",
"To change this limit, set the config variable\n",
"`--NotebookApp.iopub_msg_rate_limit`.\n",
"\n",
"Current values:\n",
"NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
"NotebookApp.rate_limit_window=3.0 (secs)\n",
"\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"10000/10000 [==============================] - 114s 11ms/step - reward: 0.0084\n",
"51 episodes - episode_reward: 1.647 [0.000, 5.000] - loss: 0.000 - mean_absolute_error: 0.197 - mean_q: 0.263 - mean_eps: 0.247 - ale.lives: 2.915\n",
"\n",
"Interval 41 (400000 steps performed)\n",
" 7633/10000 [=====================>........] - ETA: 26s - reward: 0.0092"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"IOPub message rate exceeded.\n",
"The notebook server will temporarily stop sending output\n",
"to the client in order to avoid crashing it.\n",
"To change this limit, set the config variable\n",
"`--NotebookApp.iopub_msg_rate_limit`.\n",
"\n",
"Current values:\n",
"NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
"NotebookApp.rate_limit_window=3.0 (secs)\n",
"\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"10000/10000 [==============================] - 113s 11ms/step - reward: 0.0089\n",
"46 episodes - episode_reward: 1.978 [0.000, 7.000] - loss: 0.001 - mean_absolute_error: 0.216 - mean_q: 0.291 - mean_eps: 0.205 - ale.lives: 3.056\n",
"\n",
"Interval 44 (430000 steps performed)\n",
"10000/10000 [==============================] - 114s 11ms/step - reward: 0.0111\n",
"45 episodes - episode_reward: 2.511 [0.000, 7.000] - loss: 0.001 - mean_absolute_error: 0.216 - mean_q: 0.292 - mean_eps: 0.191 - ale.lives: 2.952\n",
"\n",
"Interval 45 (440000 steps performed)\n",
"10000/10000 [==============================] - 113s 11ms/step - reward: 0.0099\n",
"46 episodes - episode_reward: 2.152 [0.000, 5.000] - loss: 0.001 - mean_absolute_error: 0.220 - mean_q: 0.298 - mean_eps: 0.177 - ale.lives: 3.023\n",
"\n",
"Interval 46 (450000 steps performed)\n",
"10000/10000 [==============================] - 113s 11ms/step - reward: 0.0134\n",
"37 episodes - episode_reward: 3.595 [1.000, 8.000] - loss: 0.001 - mean_absolute_error: 0.230 - mean_q: 0.313 - mean_eps: 0.163 - ale.lives: 3.048\n",
"\n",
"Interval 47 (460000 steps performed)\n",
"10000/10000 [==============================] - 114s 11ms/step - reward: 0.0143\n",
"36 episodes - episode_reward: 4.000 [2.000, 7.000] - loss: 0.001 - mean_absolute_error: 0.237 - mean_q: 0.323 - mean_eps: 0.149 - ale.lives: 2.981\n",
"\n",
"Interval 48 (470000 steps performed)\n",
"10000/10000 [==============================] - 113s 11ms/step - reward: 0.0158\n",
"35 episodes - episode_reward: 4.429 [2.000, 11.000] - loss: 0.001 - mean_absolute_error: 0.247 - mean_q: 0.338 - mean_eps: 0.135 - ale.lives: 3.034\n",
"\n",
"Interval 49 (480000 steps performed)\n",
"10000/10000 [==============================] - 113s 11ms/step - reward: 0.0150\n",
"32 episodes - episode_reward: 4.750 [3.000, 9.000] - loss: 0.001 - mean_absolute_error: 0.258 - mean_q: 0.353 - mean_eps: 0.121 - ale.lives: 3.139\n",
"\n",
"Interval 50 (490000 steps performed)\n",
"10000/10000 [==============================] - 114s 11ms/step - reward: 0.0167\n",
"25 episodes - episode_reward: 6.560 [4.000, 11.000] - loss: 0.001 - mean_absolute_error: 0.269 - mean_q: 0.370 - mean_eps: 0.107 - ale.lives: 3.153\n",
"\n",
"Interval 51 (500000 steps performed)\n",
"10000/10000 [==============================] - 113s 11ms/step - reward: 0.0147\n",
"23 episodes - episode_reward: 6.435 [4.000, 8.000] - loss: 0.001 - mean_absolute_error: 0.279 - mean_q: 0.384 - mean_eps: 0.100 - ale.lives: 3.036\n",
"\n",
"Interval 52 (510000 steps performed)\n",
"10000/10000 [==============================] - 113s 11ms/step - reward: 0.0150\n",
"22 episodes - episode_reward: 6.682 [4.000, 11.000] - loss: 0.001 - mean_absolute_error: 0.297 - mean_q: 0.409 - mean_eps: 0.100 - ale.lives: 3.041\n",
"\n",
"Interval 53 (520000 steps performed)\n",
"10000/10000 [==============================] - 113s 11ms/step - reward: 0.0179\n",
"20 episodes - episode_reward: 9.100 [4.000, 16.000] - loss: 0.001 - mean_absolute_error: 0.306 - mean_q: 0.421 - mean_eps: 0.100 - ale.lives: 3.033\n",
"\n",
"Interval 54 (530000 steps performed)\n",
"10000/10000 [==============================] - 113s 11ms/step - reward: 0.0175\n",
"16 episodes - episode_reward: 10.500 [6.000, 18.000] - loss: 0.001 - mean_absolute_error: 0.325 - mean_q: 0.448 - mean_eps: 0.100 - ale.lives: 2.921\n",
"\n",
"Interval 55 (540000 steps performed)\n",
"10000/10000 [==============================] - 113s 11ms/step - reward: 0.0202\n",
"18 episodes - episode_reward: 11.556 [8.000, 18.000] - loss: 0.001 - mean_absolute_error: 0.347 - mean_q: 0.476 - mean_eps: 0.100 - ale.lives: 3.301\n",
"\n",
"Interval 56 (550000 steps performed)\n",
"10000/10000 [==============================] - 112s 11ms/step - reward: 0.0207\n",
"16 episodes - episode_reward: 12.688 [8.000, 20.000] - loss: 0.001 - mean_absolute_error: 0.367 - mean_q: 0.503 - mean_eps: 0.100 - ale.lives: 2.914\n",
"\n",
"Interval 57 (560000 steps performed)\n",
"10000/10000 [==============================] - 113s 11ms/step - reward: 0.0200\n",
"21 episodes - episode_reward: 9.857 [3.000, 16.000] - loss: 0.002 - mean_absolute_error: 0.387 - mean_q: 0.544 - mean_eps: 0.100 - ale.lives: 3.247\n",
"\n",
"Interval 58 (570000 steps performed)\n",
"10000/10000 [==============================] - 112s 11ms/step - reward: 0.0221\n",
"17 episodes - episode_reward: 12.588 [6.000, 21.000] - loss: 0.002 - mean_absolute_error: 0.423 - mean_q: 0.597 - mean_eps: 0.100 - ale.lives: 3.051\n",
"\n",
"Interval 59 (580000 steps performed)\n",
"10000/10000 [==============================] - 112s 11ms/step - reward: 0.0208\n",
"15 episodes - episode_reward: 14.000 [10.000, 19.000] - loss: 0.002 - mean_absolute_error: 0.457 - mean_q: 0.631 - mean_eps: 0.100 - ale.lives: 3.119\n",
"\n",
"Interval 60 (590000 steps performed)\n",
"10000/10000 [==============================] - 112s 11ms/step - reward: 0.0226\n",
"15 episodes - episode_reward: 14.067 [7.000, 21.000] - loss: 0.002 - mean_absolute_error: 0.469 - mean_q: 0.650 - mean_eps: 0.100 - ale.lives: 3.155\n",
"\n",
"Interval 61 (600000 steps performed)\n",
"10000/10000 [==============================] - 112s 11ms/step - reward: 0.0226\n",
"15 episodes - episode_reward: 15.800 [11.000, 23.000] - loss: 0.002 - mean_absolute_error: 0.490 - mean_q: 0.673 - mean_eps: 0.100 - ale.lives: 3.087\n",
"\n",
"Interval 62 (610000 steps performed)\n",
"10000/10000 [==============================] - 112s 11ms/step - reward: 0.0224\n",
"14 episodes - episode_reward: 15.286 [7.000, 22.000] - loss: 0.002 - mean_absolute_error: 0.510 - mean_q: 0.699 - mean_eps: 0.100 - ale.lives: 3.251\n",
"\n",
"Interval 63 (620000 steps performed)\n",
"10000/10000 [==============================] - 112s 11ms/step - reward: 0.0220\n",
"15 episodes - episode_reward: 15.533 [8.000, 30.000] - loss: 0.002 - mean_absolute_error: 0.532 - mean_q: 0.727 - mean_eps: 0.100 - ale.lives: 3.112\n",
"\n",
"Interval 64 (630000 steps performed)\n",
"10000/10000 [==============================] - 113s 11ms/step - reward: 0.0205\n",
"17 episodes - episode_reward: 12.176 [6.000, 19.000] - loss: 0.002 - mean_absolute_error: 0.555 - mean_q: 0.759 - mean_eps: 0.100 - ale.lives: 3.001\n",
"\n",
"Interval 65 (640000 steps performed)\n",
" 7107/10000 [====================>.........] - ETA: 32s - reward: 0.0224"
]
}
],
"source": [
"from PIL import Image\n",
"import numpy as np\n",
"import gym\n",
"\n",
"from keras.models import Sequential\n",
"from keras.layers import Dense, Activation, Flatten, Convolution2D, Permute\n",
"from keras.optimizers import Adam\n",
"import keras.backend as K\n",
"\n",
"from rl.agents.dqn import DQNAgent\n",
"from rl.policy import LinearAnnealedPolicy, BoltzmannQPolicy, EpsGreedyQPolicy\n",
"from rl.memory import SequentialMemory\n",
"from rl.core import Processor\n",
"from rl.callbacks import FileLogger, ModelIntervalCheckpoint\n",
"\n",
"\n",
"INPUT_SHAPE = (84, 84)\n",
"WINDOW_LENGTH = 4\n",
"\n",
"\n",
"class AtariProcessor(Processor):\n",
" def process_observation(self, observation):\n",
" assert observation.ndim == 3 # (height, width, channel)\n",
" img = Image.fromarray(observation)\n",
" img = img.resize(INPUT_SHAPE).convert('L') # resize and convert to grayscale\n",
" processed_observation = np.array(img)\n",
" assert processed_observation.shape == INPUT_SHAPE\n",
" return processed_observation.astype('uint8') # saves storage in experience memory\n",
"\n",
" def process_state_batch(self, batch):\n",
" # We could perform this processing step in `process_observation`. In this case, however,\n",
" # we would need to store a `float32` array instead, which is 4x more memory intensive than\n",
" # an `uint8` array. This matters if we store 1M observations.\n",
" processed_batch = batch.astype('float32') / 255.\n",
" return processed_batch\n",
"\n",
" def process_reward(self, reward):\n",
" return np.clip(reward, -1., 1.)\n",
"\n",
"\n",
"\n",
"mode = \"train\"\n",
"env_name = 'BreakoutDeterministic-v4'\n",
"\n",
"# Get the environment and extract the number of actions.\n",
"env = gym.make(env_name)\n",
"np.random.seed(123)\n",
"env.seed(123)\n",
"nb_actions = env.action_space.n\n",
"\n",
"# Next, we build our model. We use the same model that was described by Mnih et al. (2015).\n",
"input_shape = (WINDOW_LENGTH,) + INPUT_SHAPE\n",
"model = Sequential()\n",
"if K.image_dim_ordering() == 'tf':\n",
" # (width, height, channels)\n",
" model.add(Permute((2, 3, 1), input_shape=input_shape))\n",
"elif K.image_dim_ordering() == 'th':\n",
" # (channels, width, height)\n",
" model.add(Permute((1, 2, 3), input_shape=input_shape))\n",
"else:\n",
" raise RuntimeError('Unknown image_dim_ordering.')\n",
"model.add(Convolution2D(32, 3, strides=2, padding=\"same\", activation=\"relu\"))\n",
"model.add(Convolution2D(32, 3, strides=2, padding=\"same\", activation=\"relu\"))\n",
"model.add(Convolution2D(64, 3, strides=2, padding=\"same\", activation=\"relu\"))\n",
"model.add(Convolution2D(64, 3, strides=2, padding=\"same\", activation=\"relu\"))\n",
"model.add(Convolution2D(64, 3, strides=2, padding=\"same\", activation=\"relu\"))\n",
"model.add(Flatten())\n",
"model.add(Dense(256))\n",
"model.add(Activation('relu'))\n",
"model.add(Dense(nb_actions))\n",
"model.add(Activation('linear'))\n",
"print(model.summary())\n",
"\n",
"# Finally, we configure and compile our agent. You can use every built-in Keras optimizer and\n",
"# even the metrics!\n",
"memory = SequentialMemory(limit=1000000, window_length=WINDOW_LENGTH)\n",
"processor = AtariProcessor()\n",
"\n",
"# Select a policy. We use eps-greedy action selection, which means that a random action is selected\n",
"# with probability eps. We anneal eps from 1.0 to 0.1 over the course of 1M steps. This is done so that\n",
"# the agent initially explores the environment (high eps) and then gradually sticks to what it knows\n",
"# (low eps). We also set a dedicated eps value that is used during testing. Note that we set it to 0.05\n",
"# so that the agent still performs some random actions. This ensures that the agent cannot get stuck.\n",
"policy = LinearAnnealedPolicy(EpsGreedyQPolicy(), attr='eps', value_max=.8, value_min=.1, value_test=.05,\n",
" nb_steps=500000)\n",
"\n",
"# The trade-off between exploration and exploitation is difficult and an on-going research topic.\n",
"# If you want, you can experiment with the parameters or use a different policy. Another popular one\n",
"# is Boltzmann-style exploration:\n",
"# policy = BoltzmannQPolicy(tau=1.)\n",
"# Feel free to give it a try!\n",
"\n",
"dqn = DQNAgent(model=model, nb_actions=nb_actions, policy=policy, memory=memory,\n",
" processor=processor, nb_steps_warmup=50000, gamma=.99, target_model_update=10000,\n",
" train_interval=4, delta_clip=1.)\n",
"dqn.compile(Adam(lr=.00025), metrics=['mae'])\n",
"\n",
"if mode == 'train':\n",
" # Okay, now it's time to learn something! We capture the interrupt exception so that training\n",
" # can be prematurely aborted. Notice that you can the built-in Keras callbacks!\n",
" weights_filename = 'dqn_{}_weights.h5f'.format(env_name)\n",
" checkpoint_weights_filename = 'dqn_' + env_name + '_weights_{step}.h5f'\n",
" log_filename = 'dqn_{}_log.json'.format(env_name)\n",
" callbacks = [ModelIntervalCheckpoint(checkpoint_weights_filename, interval=250000)]\n",
" callbacks += [FileLogger(log_filename, interval=100)]\n",
" dqn.fit(env, callbacks=callbacks, nb_steps=1750000, log_interval=10000)\n",
"\n",
" # After training is done, we save the final weights one more time.\n",
" dqn.save_weights(weights_filename, overwrite=True)\n",
"\n",
" # Finally, evaluate our algorithm for 10 episodes.\n",
" dqn.test(env, nb_episodes=10, visualize=False)\n",
"elif args.mode == 'test':\n",
" weights_filename = 'dqn_{}_weights.h5f'.format(args.env_name)\n",
" if args.weights:\n",
" weights_filename = args.weights\n",
" dqn.load_weights(weights_filename)\n",
"dqn.test(env, nb_episodes=10, visualize=True)"
]
},
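{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `FileLogger` callback above writes per-episode training metrics to `dqn_BreakoutDeterministic-v4_log.json`. The cell below is a sketch of how that log could be inspected with pandas; it assumes the file is a JSON dict of per-episode lists that includes `episode` and `episode_reward` keys, so the column names may need adjusting for other keras-rl versions."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sketch: inspect training progress from the FileLogger output.\n",
"# Assumes the JSON log holds per-episode lists such as episode and episode_reward.\n",
"log = pandas.read_json('dqn_{}_log.json'.format(env_name))\n",
"log.plot(x='episode', y='episode_reward', figsize=(10, 4),\n",
"         title='Breakout DQN: episode reward during training');"
]
},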
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python [conda env:tf]",
"language": "python",
"name": "conda-env-tf-py"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.6"
}
},
"nbformat": 4,
"nbformat_minor": 2
}