{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"deletable": true,
"editable": true
},
"source": [
"# What is an actor critic network?\n",
"\n",
"- So far we have learned a single output network that produces Q-values(Value Iteration) or an action policy (Policy Iteration)\n",
"- What if we can use both value functions and policy functions? That's how actor-critic methods were developed. It turns out if we use both, we can learn more complex systems. In this notebook, we will a simple policy gradient actor-critic methods\n",
"\n",
"# Structure of Actor Critic Networks\n",
"- There are two networks: an actor network and a critic network\n",
"![actor-critic](https://raw.githubusercontent.com/hunkim/ReinforcementZeroToAll/6cc4711f61ff585b5e101e43cd31bc909a32afad/assets/actor_critic.png)\n",
"- Actor network:\n",
" * This network chooses an action!\n",
" * It takes an input of game state and produce outputs an action policy (as in policy-gradient)\n",
"\n",
"- Critic network:\n",
" * This network is simply a value network\n",
" * It takes the same input as the actor network and produces a current state value"
]
},
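{
"cell_type": "markdown",
"metadata": {
"deletable": true,
"editable": true
},
"source": [
"Before the full implementation (the `ActorCriticNetwork` class later in this notebook), here is a minimal sketch of how the two heads can share a hidden trunk. The layer sizes and variable names below are made up for illustration only:\n",
"\n",
"```python\n",
"# Hypothetical sketch of a two-headed actor-critic network (not the class used below)\n",
"import tensorflow as tf\n",
"slim = tf.contrib.slim\n",
"\n",
"states = tf.placeholder(tf.float32, [None, 4])                         # game state\n",
"trunk = slim.fully_connected(states, 32)                               # shared hidden layer\n",
"policy = slim.fully_connected(trunk, 2, activation_fn=tf.nn.softmax)   # actor head: p(a | s)\n",
"value = slim.fully_connected(trunk, 1, activation_fn=None)             # critic head: V(s)\n",
"```"
]
},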
{
"cell_type": "markdown",
"metadata": {
"deletable": true,
"editable": true
},
"source": [
"# From pervious lectures\n",
"\n",
"* We used policy gradient methods that is to find a policy $\\pi$ that\n",
"$$\n",
"\\text{maximize } E\\left[\\ R\\ \\mid\\ \\pi\\ \\right]\n",
"$$\n",
"\n",
"$$\\text{where }R = r_0 + r_1 + \\dots + r_{\\tau - 1}$$\n",
"\n",
"* We use an gradient estimator that is\n",
"$$\n",
"\\hat{g} = \\nabla_\\theta \\log \\pi \\left(a_t \\mid s_t; \\theta \\right) \\cdot R_t \n",
"$$\n",
"\n",
"$$\\text{where }R_t = \\sum_{t'=t}^{T-1} \\left(discount\\ rate\\right)^{t'-t} \\cdot (reward)_t$$"
]
},
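{
"cell_type": "markdown",
"metadata": {
"deletable": true,
"editable": true
},
"source": [
"As a concrete example of the return $R_t$ above, here is a small sketch (made-up rewards, discount rate 0.99) of the discounted sum. The same computation appears later inside `Agent.train` as `discount_rewards`:\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"gamma = 0.99\n",
"rewards = [1.0, 1.0, 1.0]            # made-up rewards r_0, r_1, r_2\n",
"\n",
"returns = np.zeros(len(rewards))\n",
"running = 0.0\n",
"for t in reversed(range(len(rewards))):\n",
"    running = rewards[t] + gamma * running\n",
"    returns[t] = running\n",
"\n",
"print(returns)                       # roughly [2.9701, 1.99, 1.0]\n",
"```"
]
},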
{
"cell_type": "markdown",
"metadata": {
"deletable": true,
"editable": true
},
"source": [
"* The above gradient estimator simply means we boost the probability of an action that returns high rewards\n",
"\n",
"# Problems\n",
"\n",
"* The above method is however not stable because a step size of the gradients can be very large and once we overshoot, our agent will collect tracjectories based on a bad policy"
]
},
{
"cell_type": "markdown",
"metadata": {
"deletable": true,
"editable": true
},
"source": [
"# Solution\n",
"* In order to solve high variance problems, we will use $A_t$ instead of $R_t$ and this is called an advantage function\n",
"* What is an advantage function? We know a Q function and a Value function. \n",
" * The $Q$ maps a state $s$ to an action $a$ value which is how good action $a$ is\n",
" * The $V$ maps a state $s$ to a value that shows how good an input state $s$ is\n",
" \n",
"* Therefore, we can write two functions as following:\n",
"$$ Q(s, a) = V(s) + A(a) $$\n",
"\n",
"* Therefore, \n",
"\n",
"$$ A(a) = Q(s, a) - V(s) $$\n",
"\n",
"* That's the definition of an advatage function. We are trying to find how good action $a$ is by subtracting a value function\n",
"\n",
"* Hence, we need to change the gradient estimator $\\hat{g}$ to the following\n",
"$$\\hat{g} = \\nabla_\\theta \\log \\pi(a_t | s_t; \\theta) \\cdot A_t $$\n",
"\n",
"where\n",
"\n",
"\\begin{align*}\n",
"A_t &= Q(s_t, a') - V(s_t) \\\\\n",
" &= R_{t} - V(s_t)\n",
"\\end{align*}\n",
"\n",
"# Notes\n",
"- Its performance is still not great because it has a few flaws\n",
" - We have to learn $V(s)$ first and learning $V(s)$ can be very difficult (requires careful reward enginneering)\n",
" - Every trajectories is highly correlated\n",
"- In order to deal with these problems, we will later discuss various methods such as TRPO(Trust Region Policy Optimization) or A3C(Asynchronous Actor Critic Networks)"
]
},
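{
"cell_type": "markdown",
"metadata": {
"deletable": true,
"editable": true
},
"source": [
"As a small numeric sketch (made-up numbers): given discounted returns $R_t$ and the critic's estimates $V(s_t)$, the advantage $A_t = R_t - V(s_t)$ is positive exactly where the trajectory did better than the critic expected. `Agent.train` below additionally standardizes $A_t$, a common variance-reduction trick:\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"R = np.array([2.97, 1.99, 1.00])       # discounted returns (made-up)\n",
"V = np.array([2.50, 2.20, 0.80])       # critic estimates V(s_t) (made-up)\n",
"\n",
"A = R - V                              # advantage estimates\n",
"A = (A - A.mean()) / (A.std() + 1e-8)  # standardized, as in Agent.train below\n",
"print(A)\n",
"```"
]
},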
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"collapsed": true,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"import numpy as np\n",
"import gym\n",
"import tensorflow as tf"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"collapsed": true,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"slim = tf.contrib.slim"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"class ActorCriticNetwork:\n",
" \"\"\" Actor Critic Network\n",
" \n",
" - 3 placeholders for policy\n",
" - S : state (shared)\n",
" - A : action one hot\n",
" - ADV : advantage value\n",
" \n",
" - 2 placeholders for value\n",
" - S : state (shared)\n",
" - R : reward\n",
" \n",
" - 2 outputs\n",
" - P : action policy, p(a | s)\n",
" - V : V(s)\n",
" \n",
" Examples\n",
" ----------\n",
" >>> input_shape = [None, 4]\n",
" >>> action_n = 2\n",
" >>> hidden_dims = [32, 32]\n",
" >>> ac_network = ActorCriticNetwork(input_shape, action_n, hidden_dims)\n",
" \"\"\"\n",
" def __init__(self, input_shape, action_n, hidden_dims):\n",
" # Policy Input\n",
" self.S = tf.placeholder(tf.float32, shape=input_shape, name=\"state_input\")\n",
" self.A = tf.placeholder(tf.float32, shape=[None, action_n], name=\"action_one_hot_input\")\n",
" self.ADV = tf.placeholder(tf.float32, shape=[None], name=\"advantage_input\")\n",
" \n",
" # Value Input\n",
" self.R = tf.placeholder(tf.float32, shape=[None], name=\"reward_input\")\n",
" \n",
" self._create_network(hidden_dims, action_n)\n",
" \n",
" def _create_network(self, hidden_dims, action_n):\n",
" net = self.S\n",
" \n",
" for i, h_dim in enumerate(hidden_dims):\n",
" net = slim.fully_connected(net, h_dim, activation_fn=None, scope=f\"fc-{i}\")\n",
" net = tf.nn.relu(net)\n",
" \n",
" # Policy shape: [None, action_n]\n",
" self.P = slim.fully_connected(net, action_n, activation_fn=tf.nn.softmax, scope=\"policy_output\")\n",
"\n",
" # Value shape: [None, 1] -> [None]\n",
" _V = slim.fully_connected(net, 1, activation_fn=None, scope=\"value_output\")\n",
" self.V = tf.squeeze(_V)\n",
" \n",
" self._create_op()\n",
" \n",
" def _create_op(self):\n",
" # output shape: [None]\n",
" policy_gain = tf.reduce_sum(self.P * self.A, 1)\n",
"\n",
" # output shape: [None]\n",
" policy_gain = tf.log(policy_gain) * self.ADV\n",
" policy_gain = tf.reduce_sum(policy_gain, name=\"policy_gain\")\n",
"\n",
" entropy = - tf.reduce_sum(self.P * tf.log(self.P), 1)\n",
" entropy = tf.reduce_mean(entropy)\n",
" \n",
" value_loss = tf.losses.mean_squared_error(self.V, self.R, scope=\"value_loss\")\n",
" \n",
" # Becareful negative sign because we only can minimize\n",
" # we want to maximize policy gain and entropy (for exploration)\n",
" self.loss = - policy_gain + value_loss - entropy * 0.01\n",
" self.optimizer = tf.train.AdamOptimizer()\n",
" self.train_op = self.optimizer.minimize(self.loss)"
]
},
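{
"cell_type": "markdown",
"metadata": {
"deletable": true,
"editable": true
},
"source": [
"A quick sanity check of the output shapes (a sketch only, assuming the `ActorCriticNetwork` class above and a fresh default graph; it is left as markdown so it does not interfere with the training graph built later):\n",
"\n",
"```python\n",
"# Sketch: P should have shape [batch, action_n], V should have shape [batch]\n",
"tf.reset_default_graph()\n",
"net = ActorCriticNetwork([None, 4], 2, [32, 32])\n",
"\n",
"with tf.Session() as check_sess:\n",
"    check_sess.run(tf.global_variables_initializer())\n",
"    p, v = check_sess.run([net.P, net.V], feed_dict={net.S: np.zeros((3, 4))})\n",
"    print(p.shape, v.shape)   # expected: (3, 2) (3,)\n",
"```"
]
},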
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"collapsed": true,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"class Agent:\n",
" \"\"\" Agent class \"\"\"\n",
" \n",
" def __init__(self, env, network):\n",
" \"\"\" Constructor\n",
" \n",
" Parameters\n",
" ----------\n",
" env\n",
" Open ai gym environment \n",
" network\n",
" Actor Critic Network \n",
" \"\"\"\n",
" self.env = env\n",
" self.model = network\n",
" self.sess = tf.get_default_session()\n",
" self.action_n = env.action_space.n\n",
" \n",
" \n",
" def choose_an_action(self, state):\n",
" \"\"\" Returns an action (int) \"\"\"\n",
" \n",
" feed = {\n",
" self.model.S: state\n",
" }\n",
" \n",
" action_prob = self.sess.run(self.model.P, feed_dict=feed)[0]\n",
" \n",
" return np.random.choice(np.arange(self.action_n), p=action_prob)\n",
" \n",
" def train(self, S, A, R):\n",
" \"\"\" Train the actor critic networks\n",
" \n",
" (1) Compute discounted rewards R\n",
" (2) Compute advantage values A = R - V\n",
" (3) Perform gradients updates\n",
" \n",
" \"\"\"\n",
" \n",
" def discount_rewards(r, gamma=0.99):\n",
" \"\"\" take 1D float array of rewards and compute discounted reward \"\"\"\n",
" discounted_r = np.zeros_like(r, dtype=np.float32)\n",
" running_add = 0\n",
" \n",
" for t in reversed(range(len(r))):\n",
" running_add = running_add * gamma + r[t]\n",
" discounted_r[t] = running_add\n",
"\n",
" return discounted_r\n",
"\n",
" # 1. Get discounted `R`s\n",
" R = discount_rewards(R)\n",
" \n",
" # 2. Get `V`s\n",
" feed = {\n",
" self.model.S: S\n",
" }\n",
" V = self.sess.run(self.model.V, feed_dict=feed)\n",
" \n",
" # 3. Get Advantage values, A = R - V\n",
" ADV = R - V \n",
" ADV = (ADV - np.mean(ADV)) / (np.std(ADV) + 1e-8)\n",
" \n",
" # 4. Perform gradient descents\n",
" feed = {\n",
" self.model.S: S,\n",
" self.model.A: A,\n",
" self.model.ADV: ADV,\n",
" self.model.R: R\n",
" }\n",
"\n",
" self.sess.run(self.model.train_op, feed_dict=feed) "
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"[2017-04-08 21:10:41,639] Making new env: CartPole-v0\n",
"[2017-04-08 21:10:41,643] Clearing 26 monitor files from previous run (because force=True was provided)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"input_shape: [None, 4], action_n: 2\n"
]
}
],
"source": [
"# Tensorflow Reset\n",
"tf.reset_default_graph()\n",
"sess = tf.InteractiveSession()\n",
"\n",
"# Gym Environment Setup\n",
"env_name = \"CartPole-v0\"\n",
"env = gym.make(env_name)\n",
"env = gym.wrappers.Monitor(env, \"./gym-results/\", force=True)\n",
"\n",
"# Global parameters\n",
"input_shape = [None, env.observation_space.shape[0]]\n",
"action_n = env.action_space.n\n",
"\n",
"print(f\"input_shape: {input_shape}, action_n: {action_n}\")\n",
"\n",
"# Define A2C(Actor-Critic) and Agent\n",
"ac_network = ActorCriticNetwork(input_shape, action_n, [32, 32])\n",
"agent = Agent(env, ac_network)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"def preprocess_state(state_list):\n",
" \"\"\" Preprocess a state list\n",
" \n",
" Currently it's only used to reshape the value\n",
" When a single state is given, its shape is 1-d array,\n",
" which needs to be reshaped in 2-d array\n",
" \"\"\"\n",
" return np.reshape(state_list, [-1, *input_shape[1:]])\n",
"\n",
"def preprocess_action(action_list, n_actions):\n",
" \"\"\"Action -> 1-hot \"\"\"\n",
" N = len(action_list)\n",
" one_hot = np.zeros(shape=(N, n_actions))\n",
" one_hot[np.arange(N), action_list] = 1\n",
" \n",
" return one_hot\n",
"\n",
"# Test codes\n",
"tmp = np.zeros((32, *input_shape[1:]))\n",
"np.testing.assert_almost_equal(preprocess_state(tmp), np.zeros([32, *input_shape[1:]]))\n",
"tmp = np.zeros(*input_shape[1:])\n",
"np.testing.assert_almost_equal(preprocess_state(tmp), np.zeros([1, *input_shape[1:]]))\n",
"tmp = [0, 1]\n",
"np.testing.assert_almost_equal(preprocess_action(tmp, 2), np.eye(2))"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"collapsed": false,
"scrolled": false
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"[2017-04-08 21:10:42,119] Starting new video recorder writing to /home/kkweon/github/ReinforcementZeroToAll/gym-results/openaigym.video.0.12761.video000000.mp4\n",
"[2017-04-08 21:10:43,292] Starting new video recorder writing to /home/kkweon/github/ReinforcementZeroToAll/gym-results/openaigym.video.0.12761.video000001.mp4\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[Episode- 0] 16\r\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"[2017-04-08 21:10:43,998] Starting new video recorder writing to /home/kkweon/github/ReinforcementZeroToAll/gym-results/openaigym.video.0.12761.video000008.mp4\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[Episode- 1] 16\r",
"[Episode- 2] 22\r",
"[Episode- 3] 10\r",
"[Episode- 4] 45\r",
"[Episode- 5] 17\r",
"[Episode- 6] 16\r",
"[Episode- 7] 13\r"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"[2017-04-08 21:10:45,072] Starting new video recorder writing to /home/kkweon/github/ReinforcementZeroToAll/gym-results/openaigym.video.0.12761.video000027.mp4\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[Episode- 46] 22\r"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"[2017-04-08 21:10:46,212] Starting new video recorder writing to /home/kkweon/github/ReinforcementZeroToAll/gym-results/openaigym.video.0.12761.video000064.mp4\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[Episode- 107] 10\r"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"[2017-04-08 21:10:47,241] Starting new video recorder writing to /home/kkweon/github/ReinforcementZeroToAll/gym-results/openaigym.video.0.12761.video000125.mp4\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[Episode- 209] 19\r"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"[2017-04-08 21:10:48,925] Starting new video recorder writing to /home/kkweon/github/ReinforcementZeroToAll/gym-results/openaigym.video.0.12761.video000216.mp4\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[Episode- 337] 60\r"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"[2017-04-08 21:10:51,951] Starting new video recorder writing to /home/kkweon/github/ReinforcementZeroToAll/gym-results/openaigym.video.0.12761.video000343.mp4\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[Episode- 507] 31\r"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"[2017-04-08 21:11:00,967] Starting new video recorder writing to /home/kkweon/github/ReinforcementZeroToAll/gym-results/openaigym.video.0.12761.video000512.mp4\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[Episode- 722] 104\r"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"[2017-04-08 21:11:09,900] Starting new video recorder writing to /home/kkweon/github/ReinforcementZeroToAll/gym-results/openaigym.video.0.12761.video000729.mp4\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[Episode- 993] 130\r"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"[2017-04-08 21:11:25,444] Starting new video recorder writing to /home/kkweon/github/ReinforcementZeroToAll/gym-results/openaigym.video.0.12761.video001000.mp4\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[Episode- 1000] 200\n",
"[Episode- 1996] 26\r"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"[2017-04-08 21:12:26,066] Starting new video recorder writing to /home/kkweon/github/ReinforcementZeroToAll/gym-results/openaigym.video.0.12761.video002000.mp4\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[Episode- 2000] 146\n",
"[Episode- 2363] 200\n",
"Game cleared in 2363, average rewards: 195.15\n"
]
}
],
"source": [
"init = tf.global_variables_initializer()\n",
"sess.run(init)\n",
"\n",
"MAX_EPISODES = 5000\n",
"\n",
"# For checking if the game is cleared\n",
"EPISODE_100_REWARDS = []\n",
"CLEAR_REWARD = env.spec.reward_threshold\n",
"CLEAR_REWARD = CLEAR_REWARD if CLEAR_REWARD else 9999\n",
"\n",
"for episode in range(MAX_EPISODES):\n",
" s = env.reset() \n",
" done = False\n",
" \n",
" s_list = []\n",
" a_list = []\n",
" r_list = []\n",
" \n",
" episode_r = 0\n",
" \n",
" while not done:\n",
" \n",
" s = preprocess_state(s)\n",
" a = agent.choose_an_action(s)\n",
"\n",
" s2, r, done, info = env.step(a)\n",
" \n",
" s_list.append(s)\n",
" a_list.append(a)\n",
" r_list.append(r)\n",
" \n",
" s = s2\n",
" \n",
" episode_r += r\n",
" \n",
" a_list = preprocess_action(a_list, action_n)\n",
" \n",
" agent.train(np.vstack(s_list), a_list, r_list)\n",
" \n",
" print(f\"[Episode-{episode:>6}] {int(episode_r):>4}\", end=\"\\r\")\n",
" \n",
" # For line breaks\n",
" if episode % (MAX_EPISODES // 5) == 0:\n",
" print()\n",
" \n",
" EPISODE_100_REWARDS.append(episode_r)\n",
" \n",
" # Check if the game is cleared\n",
" if len(EPISODE_100_REWARDS) > 100:\n",
" EPISODE_100_REWARDS = EPISODE_100_REWARDS[1:]\n",
" \n",
" avg_rewards = np.mean(EPISODE_100_REWARDS)\n",
" \n",
" if avg_rewards > CLEAR_REWARD:\n",
" print()\n",
" print(f\"Game cleared in {episode}, average rewards: {avg_rewards}\")\n",
" break"
]
},
{
"cell_type": "markdown",
"metadata": {
"deletable": true,
"editable": true
},
"source": [
"# Test run\n"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[Episode-0] 198\n",
"[Episode-20] 200\n",
"[Episode-40] 200\n",
"[Episode-60] 200\n",
"[Episode-80] 200\n",
"[Episode-98] 200\r"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"[2017-04-08 21:13:16,119] Finished writing results. You can upload them to the scoreboard via gym.upload('/home/kkweon/github/ReinforcementZeroToAll/gym-results')\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[Episode-99] 200\r"
]
}
],
"source": [
"for episode in range(100):\n",
" s = env.reset() \n",
" done = False\n",
" \n",
" episode_r = 0\n",
" while not done:\n",
" if episode % 20 == 0:\n",
" env.render()\n",
" s = preprocess_state(s)\n",
" a = agent.choose_an_action(s)\n",
" s2, r, done, info = env.step(a)\n",
" \n",
" s = s2\n",
" episode_r += r \n",
" \n",
" print(f\"[Episode-{episode}] {int(episode_r)}\", end=\"\\r\")\n",
" \n",
" if episode % 20 == 0:\n",
" print()\n",
" \n",
"env.close()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.1"
}
},
"nbformat": 4,
"nbformat_minor": 2
}