Created
March 11, 2019 00:41
-
-
Save act65/8f2e4e8c996f76ce87c7fbf88bbc872b to your computer and use it in GitHub Desktop.
temporal difference value estimates
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# you might need to run this?\n", | |
"# !pip install jax jaxlib\n", | |
"\n", | |
"import jax\n", | |
"import jax.numpy as np\n", | |
"from jax import grad, jit, vmap\n", | |
"from jax.experimental.stax import serial, Dense, Relu, Softplus, Tanh, Softmax\n", | |
"from jax.experimental import optimizers\n", | |
"\n", | |
"import numpy.random as rnd\n", | |
"\n", | |
"import matplotlib.pyplot as plt" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def mse(x, y):\n", | |
"    # squared error summed over the last axis, then averaged over the remaining entries\n", | |
"    squared_error = (x - y) ** 2\n", | |
"    per_example = np.sum(squared_error, axis=-1)\n", | |
"    return np.mean(per_example)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"width = 128  # units in each hidden Dense layer\n", | |
"activation = Relu  # nonlinearity placed after each hidden Dense layer\n", | |
"n_inputs = 1  # dimensionality of the state fed to the value network\n", | |
"gamma = 0.9  # discount factor in the TD target r_t + gamma * V(s_tp1)\n", | |
"lr = 1e-3  # Adam step size\n", | |
"\n", | |
"# seed numpy's global RNG so the states sampled by data_generator are reproducible\n", | |
"SEED = 0\n", | |
"rnd.seed(SEED)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"/local/scratch/miniconda3/envs/venv/lib/python3.6/site-packages/jax/lib/xla_bridge.py:167: UserWarning: No GPU found, falling back to CPU.\n", | |
" warnings.warn('No GPU found, falling back to CPU.')\n" | |
] | |
} | |
], | |
"source": [ | |
"# a multi layer network: four hidden Dense(width) + activation layers feeding a Dense(1) output\n", | |
"hidden_layers = [layer for _ in range(4) for layer in (Dense(width), activation)]\n", | |
"init, fn = serial(*hidden_layers, Dense(1))\n", | |
"fn = jit(fn)\n", | |
"\n", | |
"# initialise parameters for inputs of shape (-1, n_inputs); -1 stands in for the batch dimension\n", | |
"out_shape, params = init((-1, n_inputs))\n", | |
"\n", | |
"@jit\n", | |
"def loss_fn(params, x_t, r_t, v_tp1):\n", | |
"    # the mean squared bellman error: regress V(x_t) onto the TD target r_t + gamma * v_tp1.\n", | |
"    # v_tp1 is a plain argument, so grad(loss_fn) w.r.t. params does not differentiate through the target.\n", | |
"    td_target = r_t + gamma * v_tp1\n", | |
"    v_estimate = fn(params, x_t)\n", | |
"    return mse(v_estimate, td_target)\n", | |
"\n", | |
"# gradient of the TD loss with respect to its first argument, the network parameters\n", | |
"dlossdparam = jit(grad(loss_fn))\n", | |
"\n", | |
"opt_init, opt_update = optimizers.adam(step_size=lr)\n", | |
"opt_state = opt_init(params)\n", | |
"\n", | |
"@jit\n", | |
"def step(i, opt_state, batch):\n", | |
"    # one Adam update (SGD-style) on a single (x_t, r_t, v_tp1) batch; returns the new optimizer state\n", | |
"    current_params = optimizers.get_params(opt_state)\n", | |
"    grads = dlossdparam(current_params, *batch)\n", | |
"    return opt_update(i, grads, opt_state)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def reward_fn(x):\n", | |
"    # a simple step fn, applied elementwise: 0.1 strictly inside (-t, t), 0 elsewhere\n", | |
"    t = 1e-1\n", | |
"    inside = np.logical_and(np.greater(x, -t), np.less(x, t))\n", | |
"    return 0.1 * inside.astype(np.float32)\n", | |
"\n", | |
"def transition_fn(s, a):\n", | |
"    # deterministic dynamics: the action is a displacement added to the state\n", | |
"    s_next = s + a\n", | |
"    return s_next\n", | |
"\n", | |
"def policy(s):\n", | |
"    # actions take us toward x=0 in steps of 0.01 (sign(0) == 0, so no move exactly at the origin).\n", | |
"    # alternative policies: a constant -0.01 (always go left) or 0.01 (always go right).\n", | |
"    direction = -np.sign(s)\n", | |
"    return direction / 100\n", | |
"\n", | |
"def data_generator(N):\n", | |
"    # yield N batches of 50 one-step transitions (s_t, r_t, s_tp1), each array shaped (50, 1).\n", | |
"    # start states are drawn uniformly from (-1, 0]; r_t is the reward for landing in s_tp1.\n", | |
"    # NOTE(review): only x <= 0 is ever trained on, while the value plot evaluates x in [-1, 1].\n", | |
"    for _batch in range(N):\n", | |
"        s_t = -rnd.random((50, 1))\n", | |
"        s_tp1 = transition_fn(s_t, policy(s_t))\n", | |
"        yield s_t, reward_fn(s_tp1), s_tp1" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"[<matplotlib.lines.Line2D at 0x7f59a44be5f8>]" | |
] | |
}, | |
"execution_count": 6, | |
"metadata": {}, | |
"output_type": "execute_result" | |
}, | |
{ | |
"data": { | |
"image/png": "\n", | |
"text/plain": [ | |
"<Figure size 432x288 with 1 Axes>" | |
] | |
}, | |
"metadata": { | |
"needs_background": "light" | |
}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"# N-point evaluation grid over x in [-1, 1]; x and N are also used by the value plot further down\n", | |
"N = 100\n", | |
"x = np.linspace(-1, 1, N)\n", | |
"\n", | |
"rewards = reward_fn(x)\n", | |
"plt.title('Reward fn')\n", | |
"plt.plot(x, rewards)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"i: 1999, Loss:0.00207" | |
] | |
} | |
], | |
"source": [ | |
"losses = []\n", | |
"\n", | |
"for i, batch in enumerate(data_generator(2000)):\n", | |
"    x_t, r_t, s_tp1 = batch\n", | |
"    # Read the parameters being trained from the optimizer state. The module-level `params`\n", | |
"    # returned by init() is never updated, so it must not be used here: it would freeze the\n", | |
"    # bootstrap target at the untrained network and log the untrained network's loss.\n", | |
"    current_params = optimizers.get_params(opt_state)\n", | |
"\n", | |
"    # semi-gradient TD(0): the target V(s_tp1) comes from the current network but is passed as data\n", | |
"    v_tp1 = fn(current_params, s_tp1)\n", | |
"    # loss on this batch under the parameters the gradient step is taken from\n", | |
"    L = loss_fn(current_params, x_t, r_t, v_tp1)\n", | |
"    opt_state = step(i, opt_state, (x_t, r_t, v_tp1))\n", | |
"\n", | |
"    losses.append(L)\n", | |
"    print(\"\\ri: {}, Loss:{:.5f}\".format(i, L), end='', flush=True)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def discount(rs, discount):\n", | |
"    # discounted return along the first axis: sum over k of discount**k * rs[k].\n", | |
"    # (the `discount` argument is the discount factor; it shadows this function's name inside the body)\n", | |
"    weighted = [reward * discount ** k for k, reward in enumerate(rs)]\n", | |
"    stacked = np.vstack(weighted)\n", | |
"    return np.sum(stacked, axis=0)\n", | |
"\n", | |
"def play_episode(x, N):\n", | |
"    # roll the policy forward from start state(s) x and return the visited states\n", | |
"    # [s_0, s_1, ..., s_{N-1}] as a list (N states, N-1 transitions; just [x] when N <= 1).\n", | |
"    states = [x]\n", | |
"    for _ in range(N - 1):\n", | |
"        current = states[-1]\n", | |
"        # choose an action with the policy based on the current state, then simulate the transition\n", | |
"        states.append(transition_fn(current, policy(current)))\n", | |
"    return states" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"Text(0.5, 1.0, 'Value')" | |
] | |
}, | |
"execution_count": 9, | |
"metadata": {}, | |
"output_type": "execute_result" | |
}, | |
{ | |
"data": { | |
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXcAAAEICAYAAACktLTqAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAAIABJREFUeJzt3Xl8VPW9//HXZyaTTPaQBJKQjR3ZZDGCSxXcENSKe6F6r/5qa621tb/+rrf2trW9XvurvXLrvdXa1vrzUmtdqVVULO6tV8SCCAgBQkCWACEbCdmXme/vjzMThpiQkMzMmZl8no9HHpk5c2bOJyfJe77zPd/zPWKMQSmlVGxx2F2AUkqp4NNwV0qpGKThrpRSMUjDXSmlYpCGu1JKxSANd6WUikEa7koFEJExImJEJM7uWpQaCg13FXNE5C8icl8vy5eISKUGtxoONNxVLPo9cJOISI/l/wD80RjTZUNNSoWVhruKRS8BWcB5/gUiMgK4AnhSRC4XkU9E5JiIHBCRn/T1QiKyV0QuDrj/ExF5KuD+WSKyVkTqRWSziCwIxQ+k1KnScFcxxxjTCjwP/GPA4huAHcaYzUCz77EM4HLgGyJy1aluR0TygdeA+4FM4J+AP4nIyKH9BEoNnYa7ilW/B64TEbfv/j/6lmGMec8Y86kxxmuM2QI8A8wfxDZuAlYbY1b7XutNYANwWRDqV2pINNxVTDLG/A9QA1wlIuOBucDTACIyT0TeFZFqEWkAbgeyB7GZYuB6X5dMvYjUA18A8oLzUyg1eDpqQMWyJ7Fa7JOBNcaYI77lTwOPAIuNMW0i8p/0He7NQFLA/dyA2weAPxhjvhbcspUaOm25q1j2JHAx8DV8XTI+qUCdL9jnAl8+yWtsApaKiEtESoDrAh57CviiiFwqIk4RcYvIAhEpCPLPodQp03BXMcsYsxdYCyQDqwIeugO4T0QagXuxDr725UfAeOAo8K/4unZ8r38AWAL8C1CN1ZK/G/2/UhFA9GIdSikVe7SFoZRSMUjDXSmlYpCGu1JKxSANd6WUikG2jXPPzs42Y8aMsWvzSikVlT7++OMaY0y/U1zYFu5jxoxhw4YNdm1eKaWikojsG8h62i2jlFIxSMNdKaVikIa7UkrFIJ04TCkVMp2dnVRUVNDW1mZ3KVHH7XZTUFCAy+Ua1PM13JVSIVNRUUFqaipjxozh81c9VH0xxlBbW0tFRQVjx44d1Gv02y0jIk+ISJWIbO3jcRGRX4pIuYhsEZE5g6pEKRVz2trayMrK0mA/RSJCVlbWkD7xDKTPfQWw6CSPLwYm+r5uA3496GqUUjFHg31whrrf+u2WMcb8TUTGnGSVJcCTxppecp2IZIhInjHm8JAqU2oIjDE8tW4f1Y3tYdvm6IxEls4tCtv2lDqZYPS552PNY+1X4Vv2uXAXkduwWvcUFek/gQqdA3Wt/OjlbQCEo+Honzn74qk5ZKckhH6DakDq6+t5+umnueOOO07peStWrGDhwoWMHj0aOH7SZXb2YK7GaI+wHlA1xjwGPAZQUlKiE8mrkKlqtPoqf/+Vucyf1O+Z2kP22pbDfPPpjVQ3tmu4R5D6+noeffTRz4V7V1cXcXF9x9+KFSuYPn16d7hHo2CE+0GgMOB+gW+ZUrapabK6Y7JT4sOyPf92/NtVkeGee+5h9+7dzJo1C5fLhdvtZsSIEezYsYM33niDK664gq1brbEiy5cvp6mpienTp7NhwwZuvPFGEhMT+fDDDwF4+OGHeeWVV+js7OSFF17gtNNOs/NH61cwwn0VcKeIPAvMAxq0v13ZrbqpA4CRYWpFZ6da29Fw79u/vrKN0kPHgvqaU0en8eMvTuvz8QceeICtW7eyadMm3nvvPS6//HK2bt3K2LFj2bt3b6/Pue6663jkkUdYvnw5JSUl3cuzs7PZuHEjjz76KMuXL+fxxx8P6s8SbAMZCvkM8C
EwWUQqRORWEbldRG73rbIa2AOUA7/Duj6lUraqaWxHBDKTw9Vy94V7Y0dYtqcGZ+7cuYMeN37NNdcAcMYZZ/T5xhBJBjJaZlk/jxvgm0GrSKkgqGlqZ0RSPHHO8MywkeaOI97p0Jb7SZyshR0uycnJ3bfj4uLwer3d9/sbU56QYL2BO51Ourq6QlNgEOncMiom1TS1h62/Hawxydkp8VRruEeU1NRUGhsbe30sJyeHqqoqamtraW9v59VXXx3Q86KFTj+gYlJNU0fYR61kpyZQ06TdMpEkKyuLc889l+nTp5OYmEhOTk73Yy6Xi3vvvZe5c+eSn59/wgHSW265hdtvv/2EA6rRRoyxZ0RiSUmJ0Yt1qFCZ/+C7zCzI4JfLZodtm19ZsZ7KhjZW33Ve2LYZ6bZv386UKVPsLiNq9bb/RORjY0xJH0/ppt0yKibV2DDePDslXvvcVcTQcFcxp7XDQ3OHh+zU8PW5gzVipra5A69Xz89T9tNwVzHH33oO1xh3v5GpCXi8hvrWzrBuV6neaLirmFPlmyzMf2JRuHSPddeuGRUBNNxVzLGr5X78RCYNd2U/DXcVc47PKxPubhmrj1/HuqtIoOGuYo5/CoCsMJ7EBIHdMjrWPRqtWLGCQ4cOdd//6le/Smlp6ZBfd+/evTz99NNDfp1TpeGuYk5NUzsZSS5cYZp6wC890YXLKdrnHqV6hvvjjz/O1KlTh/y6Gu5KBYk19UD451QXEbKSE8J69SfVv6eeeoq5c+cya9Ysvv71r+PxeLjllluYPn06M2bM4KGHHmLlypXd0/zOmjWL1tZWFixYgP9Ey5SUFO6++26mTZvGxRdfzN///ncWLFjAuHHjWLVqFWCF+HnnncecOXOYM2cOa9euBaxph99//31mzZrFQw89hMfj4e677+bMM8/k9NNP57e//W1Ifm6dfkDFnHDPKxMoO1VPZOrT6/dA5afBfc3cGbD4gT4f3r59O8899xwffPABLpeLO+64g/vvv5+DBw92z+NeX19PRkZGr9P8+jU3N3PhhRfy4IMPcvXVV/PDH/6QN998k9LSUm6++WauvPJKRo0axZtvvonb7WbXrl0sW7aMDRs28MADD7B8+fLuuWsee+wx0tPTWb9+Pe3t7Zx77rksXLhw0LNV9kXDXcWcmqYOpo1Os2Xb2SkJGu4R5O233+bjjz/mzDPPBKC1tZVFixaxZ88evvWtb3H55ZezcOHCfl8nPj6eRYsWATBjxgwSEhJwuVzMmDGje/rfzs5O7rzzTjZt2oTT6aSsrKzX13rjjTfYsmULK1euBKChoYFdu3ZpuCvVHzumHvDLTklgx+Honk0wZE7Swg4VYww333wzP/vZz05Y/tOf/pQ1a9bwm9/8hueff54nnnjipK/jcrkQ38V4HQ5H9/S/Doeje/rfhx56iJycHDZv3ozX68XtdvdZ08MPP8yll1461B/vpLTPXcWUtk4Pje1djAzzCUx+1hQE7dg1IZ860UUXXcTKlSupqqoCoK6ujn379uH1ern22mu5//772bhxIzD0aX4bGhrIy8vD4XDwhz/8AY/H0+vrXnrppfz617+ms9M6k7msrIzm5uZBb7cv2nJXMcV/MNO2PveUeDo9hobWTjKS7KlBHTd16lTuv/9+Fi5ciNfrxeVy8Ytf/IKrr766+0Id/lb9UKf5veOOO7j22mt58sknWbRoUfeFQU4//XScTiczZ87klltu4a677mLv3r3MmTMHYwwjR47kpZdeCt4P7aNT/qqY8sn+o1z96Fr+380lXDQlp/8nBNnLmw5y17ObeOu75zNhVGrYtx9pdMrfodEpf5Xy8Z9AZFe3jH/Kg2q9lqqymYa7iil2TT3g55+sTEfMKLtpuKuY4p+0K9xTD/jpzJCfpweXB2eo+03DXcWUmqZ20txxJMQ5bdl+RqILp0OnIPBzu93U1tZqwJ8iYwy1tbV9DqccCB0to2JKTVNH2OdxD+RwCFnJ8d2Tlw
13BQUFVFRUUF1dbXcpUcftdlNQUDDo52u4q5hSbdO8MoH0LNXjXC5X0M+8VAOj3TIqptQ0tYf9Ih09ZadquCv7abirmGJNPWDvyUPZKfE6p7uynYa7ihntXR6OtXXZ3i0zMsWa9lcPIio7DSjcRWSRiOwUkXIRuaeXx4tE5F0R+UREtojIZcEvVamT87eW7TygClafe4fHy7G2LlvrUMNbv+EuIk7gV8BiYCqwTER6Xp7kh8DzxpjZwFLg0WAXqlR//PPK2N3nPirNf5Zqm611qOFtIC33uUC5MWaPMaYDeBZY0mMdA/gn0E4HDqFUmFU2tAKQlzH4scHBkJtmbb+yQQ+qKvsMJNzzgQMB9yt8ywL9BLhJRCqA1cC3enshEblNRDaIyAYd96qC7VC91VLOS0+0tQ7/9g/53myUskOwDqguA1YYYwqAy4A/iMjnXtsY85gxpsQYUzJy5MggbVopS+WxNhLiHIxIctlaR0661S1T2aDdMso+Awn3g0BhwP0C37JAtwLPAxhjPgTcQHYwClRqoA7Vt5KX7u6+Yo5dEuKcZKfEc1hb7spGAwn39cBEERkrIvFYB0xX9VhnP3ARgIhMwQp37XdRYVXZ0GZ7l4xfXnoih7XlrmzUb7gbY7qAO4E1wHasUTHbROQ+EbnSt9r/Ab4mIpuBZ4BbjA7yVWF2uKGNvHR7D6b65aa7OVyv4a7sM6C5ZYwxq7EOlAYuuzfgdilwbnBLU2rgPF7DkWNtto+U8Rud7uajPbV2l6GGMT1DVcWEmqZ2uryG3AjplslNT+RYWxfN7Xoik7KHhruKCYfqrYOXoyOkW2a07xOEHlRVdtFwVzHBP+wwN0LC3X8ikx5UVXbRcFcx4ZAvREdHSLfM6AyrDj2oquyi4a5iQmVDKwlxDjJsPoHJzz+/jLbclV003FVMONTQxuiMRNtPYPKzTmRK0D53ZRsNdxUTKhvauvu5I0Veultb7so2Gu4qJhyub42YMe5+Vrhry13ZQ8NdRT2P13CksT1izk7105a7spOGu4p61Y3teLwmYuaV8cvLSKSxrYsmPZFJ2UDDXUU9f9dHJLbc4fhFRJQKJw13FfX8XR8R13L3X7RDx7orG2i4q6h3PNwjteWu4a7CT8NdRb3D9a24XZFzApNfjm9opl5uT9lBw11FvcPHrIt0RMoJTH7xcQ6yUxK05a5soeGuot5h3+X1ItHoDHf3vDdKhZOGu4p6kXR5vZ7y0t06WkbZQsNdRbVIPYHJLy89UWeGVLbQcFdRrfsEpgibesAvL91NY3sXjW2ddpeihhkNdxXVDkXoCUx+uTocUtlEw11FtQN1LQDkZyTZXEnvCkZYdR042mJzJWq40XBXUW1/rRWaxVmRGe7+uvbVarir8NJwV1FtX10LOWkJuF1Ou0vpVVZyPMnxTg13FXYa7iqq7attpjgz2e4y+iQiFGUls6+22e5S1DCj4a6i2r7aFooitEvGrzgziX112nJX4aXhrqJWa4eHqsZ2ijMjPNyzk6ioa8XjNXaXooYRDXcVtfb7WsPF2ZHbLQNQnJlMh8dL5TEdDqnCZ0DhLiKLRGSniJSLyD19rHODiJSKyDYReTq4ZSr1ef5+7IhvuXePmNF+dxU+/Ya7iDiBXwGLganAMhGZ2mOdicD3gXONMdOA74SgVqVO0N1yj/A+9yLfm89+HTGjwmggLfe5QLkxZo8xpgN4FljSY52vAb8yxhwFMMZUBbdMpT5vb20zae44MpLi7S7lpEZnJOJyCns13FUYDSTc84EDAfcrfMsCTQImicgHIrJORBb19kIicpuIbBCRDdXV1YOrWCmffbUtFGdFdn87gNMhFI5IYn+ddsuo8AnWAdU4YCKwAFgG/E5EMnquZIx5zBhTYowpGTlyZJA2rYar/XWRPwzSrygrSU9kUmE1kHA/CBQG3C/wLQtUAawyxnQaYz4DyrDCXqmQ6PJ4OXi0NeIPpvoVZyaxv7YFY3Q4pAqPgYT7em
CiiIwVkXhgKbCqxzovYbXaEZFsrG6aPUGsU6kTHKpvo8trGBMF3TIARVnJNLZ3cbRFp/5V4dFvuBtjuoA7gTXAduB5Y8w2EblPRK70rbYGqBWRUuBd4G5jTG2oilZqn6//Olq6ZfyfMHQ4pAqXuIGsZIxZDazusezegNsG+K7vS6mQ2xvhs0H2NCb7+OyQs4tG2FyNGg70DFUVlfbXNhMf5yAnNTIv0tFTwYgkRHTqXxU+Gu4qKu2rbaEoMwmHQ+wuZUDcLie5ae7u7iSlQk3DXUWl/XUtUTNSxq/IN2JGqXDQcFdRxxgTNScwBRqTlaxT/6qw0XBXUae6sZ3WTk/UHEz1K8pKorqxneb2LrtLUcOAhruKOv7Wb7QMg/Tzvxnt19a7CgMNdxV1Pqu2DkpGywlMfv56P6vRg6oq9DTcVdQpO9JIQpyjeyrdaDF+ZAoiVv1KhZqGu4o6ZVVNTBiVgjNKhkH6JcY7Kc5MYteRJrtLUcOAhruKOmWVjUzKSbW7jEGZmJPKTm25qzDQcFdRpaG1k8pjbVEb7pNzUtlb00x7l8fuUlSM03BXUWWXr9U7KSfF5koGZ2JOCl1eowdVVchpuKuoUubrr47alnuuVXeZ9rurENNwV1Gl7EgjSfFO8jMS7S5lUMZmJ+N0CGWV2u+uQkvDXUWVsiONTMxJjZoJw3pKiHMyJitJh0OqkNNwV1Gl7EgTk0ZFZ3+73+TcVHZVabeMCi0NdxU16po7qGlq7+63jlYTR6Wyt7aZtk4dMaNCR8NdRQ1/V8bEKD2Y6jc5NxVjoFxb7yqENNxV1PCH++QoD3f/ME7td1ehpOGuokbZkUZS3XHkpCXYXcqQFGclE+906HBIFVIa7ipqlFU2MTknFZHoHCnj53I6GDcyWVvuKqQ03FVUMMZQVtUY9f3tfpNyUjXcVUhpuKuoUN3YTn1LJ5OjdNqBniblpFBxtFWvyqRCRsNdRYVon3agJ/8nEB3vrkJFw11FhR2VxwCYFOVj3P1O8/0c2w8fs7kSFas03FVU2FzRQH5GItkp0T1Sxq8oM4k0dxxbKhrsLkXFKA13FRU2H6hnZmG63WUEjYgwszCDzQfq7S5FxagBhbuILBKRnSJSLiL3nGS9a0XEiEhJ8EpUw11dcwf761o4vSDD7lKCamZBBjuPNNLaodMQqODrN9xFxAn8ClgMTAWWicjUXtZLBe4CPgp2kWp421JhtW5nxlq4F2bg8RpKD2vXjAq+gbTc5wLlxpg9xpgO4FlgSS/r/Rvwc6AtiPUpxeYDDYjAjILY6ZYBmOn7eTYd0HBXwTeQcM8HDgTcr/At6yYic4BCY8xrJ3shEblNRDaIyIbq6upTLlYNT5sr6pkwMoWUhDi7SwmqUWlu8tLd2u+uQmLIB1RFxAH8Avg//a1rjHnMGFNijCkZOXLkUDethgFjDFsq6plZGFtdMn4zCzK6u52UCqaBhPtBoDDgfoFvmV8qMB14T0T2AmcBq/SgqgqGg/Wt1DR1dHdhxJrTC9PZW9tCfUuH3aWoGDOQcF8PTBSRsSISDywFVvkfNMY0GGOyjTFjjDFjgHXAlcaYDSGpWA0r/nHgsdpyn+U7SKzj3VWw9Rvuxpgu4E5gDbAdeN4Ys01E7hORK0NdoBreNh+oJ97p4LTcNLtLCYnpvk8k2u+ugm1AR6iMMauB1T2W3dvHuguGXpZSlk0H6pkyOo34uNg83y7N7WL8yGQ2a7+7CrLY/I9RMcHjNWw92BCz/e1+Mwsy2HSgAWOM3aWoGKLhriLW7uommjs8MXfyUk8zCzOoaWrncIOeIqKCR8NdRSx/P3SsHkz18/982u+ugknDXUWs9XvrSE90MS472e5SQmpKXioJcQ7W7z1qdykqlNqbYNPTsOIK2P1uyDcXW6f8qZhhjOGD8lrOHpeFwxHd10ztT0KckzPHZLJ2d43dpahg83ph3wdWqJe+DJ3NMGIsdDSHfNMa7ioiHahr5WB9K1+fP87uUs
Li7PFZPLhmJzVN7TEzZ/2wdnQvbHoGNj8N9fshIQ1mXAuzboTCeRCGi7xruKuI5G/FnjM+y+ZKwsP/c67bU8sVp4+2uRo1KB3NVut809Ow931AYNx8uPBHcNoVEJ8U1nI03FVEWru7llGpCYwfGRsXxO7PjPx0UhPiWLtbwz2qGAP7P4RP/gilL0FHk9XtcsEPYeaXIKPIttI03FXEMcawdnctX5iQhYTh42skiHM6mDcuk7Xl2u8eFeoPwOZnYdMf4ehnEJ8C066yul2Kzg5Lt0t/NNxVxNlV1URNUzvnjM+2u5SwOnt8Nm9tr+JgfSv5GYl2l6N66myF7a/Cpqdgz18BA2POg/n/DFOuhITI+pSp4a4ijr/1evYw6W/38/e7f7i7luvOKLC5GgVY3S4VG6xA3/oitB+zulrmfw9mLYMRY+yusE8a7irifLC7lqLMJAozw3sAym6Tc1LJSo5nbXmNhrvdGg7Cluesg6O1u8CVZLXOZ98IxV8AR+SfIqThriKKx2tYt6eWy2fk2V1K2Dkcwlnjs1i7uxZjzLA53hAxOpphx2tWoO95DzBW//m5d1n96Qmpdld4SjTcVUTZdqiBxrauYdcl43fO+Cxe23KYz2qaGTdMRgrZytNlDVv89AVrGGNHk6/b5Z9h5lLIjN7zLDTcVUT5oLwWGH797X7+g8gf7K7VcA8Vrwf2rYVtf4btq6C5GuJTYdrVMHOZ1VqPgm6X/mi4q4jy9vYjTMlLY1Sq2+5SbDEmK4mizCTe3n6Efzir2O5yYoc/0EtfgtJV0Fxl9aNPutQK9YkLwRVbI5Q03FXEqGps4+P9R/nORZPsLsU2IsKl03JYsXYvx9o6SXO77C4penm9cGCd1UIvfRmajkBcIkxaeDzQ42N3UjoNdxUx3iw9gjGwaHqu3aXYatH0XH73/me8u6OKJbPy7S4nung9cOAj2PaSL9Arh1WgB9JwVxHjL1srGZudzKSc4d3XPLtwBCNTE1izrVLDfSC6u1xetvrQm45AnBsmXgJTr4JJiyLuBKNw0HBXEaGhpZMPd9dy63ljh/0QQIfD6pr508cHaev04HY57S4p8vhHuZS+DNtfgZYaq4U+8RKYumTYBnogDXcVEd7ecYQur2HRtOHdJeO3aFoeT63bz9/Kqlmo+8Ti6YTP/uoL9FehtQ5cydZB0alXDqsul4HQcFcRYc22SnLT3DF/vdSBmjcuk/REF3/ZVjm8w72r3TqhqPRl6wSjtnpr2OLkRdYZoxMuDvtUutFCw13ZrqWji7+WVfOlksKYv+rSQLmcDi6aMoq3So/Q6fHickb/uOsB62iG8resIYtla6CjERLS4bTLrEAffyG4hudQ2VOh4a5s97eyato6vVw6nFuovVg0LZcXNx5k3Z5azps40u5yQqutwQry0peh/G3oaoXETOu0/6lLYOx8iIu3u8qoouGubPfqlsNkJLmYOzbT7lIiyvmTRpIU7+S1LYdjM9yba2Hna1YLfc974O2E1DyYfRNM+SIUnwtOjajB0j2nbFXX3MEb247w5XlFxA2nrocBcLucXDYjj1c2H+JHV0wlOSEG/l0bj8COV6xA3/s/YDzWXC7zvm610PNLYuLU/0gQA38tKpq9uLGCDo+XpXML7S4lIi2bW8jKjyt4ZfMhls6175JtQ1J/wBquuH0V7F8HGMiaCF/4jtWHnjczIq5cFGsGFO4isgj4L8AJPG6MeaDH498Fvgp0AdXAV4wx+4Jcq4oxxhieW3+AWYUZnJabZnc5EWlO0Qgmjkrh2fUHoivc6/ZYrfPSl+HQRmtZznRYcI8V6KOmaKCHWL/hLiJO4FfAJUAFsF5EVhljSgNW+wQoMca0iMg3gH8HvhSKglXs2Lj/KLuqmvj5tTPsLiViiQhL5xbxb6+WsqPyWGS/CVbvtMK8dBUc+dRaNno2XPRjq8sla7y99Q0zA2m5zwXKjTF7AETkWWAJ0B3uxph3A9ZfB9wUzCJVbHrm7w
dIjndyxemj7S4lol0zO5+fv76DZ/9+gJ9cOc3uck5UvdO6/FzpS1C9w1pWOA8u/b/WQdGMKPq0EWMGEu75wIGA+xXAvJOsfyvwem8PiMhtwG0ARUX6Sx/OjrV18uqWQ1w9uyA2DhSG0IjkeC6dnsuLGyu4Z/Fp9k9H0FQFn66ELc/C4c2AQPE5sPhBK9DTht9VtCJRUP+rROQmoASY39vjxpjHgMcASkpKTDC3raLLy5sO0dbpZZkeSB2QZWcW8srmQ7y+9TBXz7bh+qr+S9BteQ52v2uNcsmbBYsesGZbTNVzFCLNQML9IBD4H1jgW3YCEbkY+AEw3xjTHpzyVCzyeg2/X7uXaaPTmJGfbnc5UeGscVmMzU7mvz/Yy1Wz8sMzuVp7I+x6w+pH3/UmdLZAepF1TdGZS2Hk5NDXoAZtIOG+HpgoImOxQn0p8OXAFURkNvBbYJExpiroVaqY8pdtlZRXNfHwstnDfgbIgXI4hNvOH8f3X/yUv+2qYf6kEJ3U1NlqBfqnK63vXW2QkmNdfm7GdVB4lo5DjxL9hrsxpktE7gTWYA2FfMIYs01E7gM2GGNWAQ8CKcALvn/W/caYK0NYt4pSxhgefqeccSOTuWyG9s2eimvnFPDw27t4+O1dnD8xO3hvjJ4ua7bFT1da49E7GiF5FMz5R6vLpXAeOHTa4WgzoD53Y8xqYHWPZfcG3L44yHWpGPX29iq2Hz7GL26YiVMnCTsl8XEObl8wnntf3sa6PXVDu4i4MVCxAbautEa7NFdZk3NNWwIzrofiL+ip/1FOf3sqbKxW+y6KMpO4cqYOfxyMG0oKefidch5+Z9eph7vXCxXrYcerVj96/T5wJliXoJtxg+8i0TrbYqzQcFdh87ddNWyuaOCBa2boPDKD5HY5+fr547j/te18vK+OM4r7mWytrcGalGvXG9ZB0aYj4HDB2PNh/vdgyhXg1oPasUjDXYWF12t46M0yRqe7uWaODUP5YsiX5xXx6Hu7+Y83yvjjV+ed2Pfu9cDhTbD7HWvI4oGPwNtldbmMvwBOu8K6FF2iXhQl1mm4q7BYubGCTQfqWX79TOLjtNU+FEnxcXz7wgn85JVSXt9ayWUF7VaQ73kXPvsbtB61VswYKzSXAAAU00lEQVSbCed82wrzgrnahz7M6G9bhVxDSycPvL6DkuIRXDM73+5yol9zLf+Q9gkj057j9D/9b+CItTx1NEy+DMYtgHEXQEoMzgGvBkzDXYXc8jd2Ut/SwX1L5ull9AajvQn2f2j1nX/2V6jcihPDIlcK73gmsX38TVxy+Zcge5LOtKi6abirkNp6sIGnPtrHzWePYeroCJ7RMFJ4vVBbbk2Te+gTOLjRuu3tAme81b1ywb/A2Pk48+ew5sVSXt50kNdNPhM02FUADXcVMp0eLz/486dkJcfzvy+ZZHc5kam5Fg5+DAc3WOPOD26wRrgAuJIg93Q4+5vWNUSLzob4pBOefs/i01izrZIf/PlTnv7aWXrugOqm4a5CZvmanWyuaODRG+eQnuiyu5yh83rg2EForLQCuK3BOnjZUgcttdZcLA6HNdTQEWe1tJ0uiEvwfblBHHB0n9U6r9kJ9fut1xYHjJpqnRGaXwL5Z1jdLP0cBM1OSeBHV0zln1du4dF3y/nWRRPDsCNUNNBwVyHx7o4qfvu3Pdw4ryhypxnoaIaGg9BUaU1j21zt+14FzTVWX7fxWF0irUetIPZ09P5a7nRISLPeALxd1sWePV3W+p4e8+i5kqwLV+SXwJlftYI8bxYkpAzqx7j+jALWltfw0FtlzB2bybxxQzhzVcUMDXcVdIcbWvnu85s4LTeVH10x1b5CPF1wrMJqKdfvg7rPrMu/1e2BhgPHhwwGEickj7RGmsSnWi1vVyKkjbbGiGeOhbQCK8z9X0mZ1np9McYK+c5WK/gTM4M6+ZaIcP/VM9hc0cC3n/2E1d8+j6yUhKC9vopOGu4qqNo6PXz7mU9o7/LyyJ
fn2HdhieZaeOSMEwPcEQcZxZA5DgrOhPR8K6hTcyFllDVZVuKI4M96KHK8ayZEUhLieOTLs7n60bV857lNPHHLmbj0LOBhTcNdBY3Ha/jOs5tYv/cov1w2mwmjBtfNEBS71ljBfvG/WtfxHFFsBXkMn8gzbXQ69105jXte/JTv/WkLy6+bqUNPh7HY/UtXYWWM4YcvbeUv2yr50RVT7Z8YbOfrkJpnXVhiGA0RXDq3iCPH2nnorTKykuP5l8um6Jz5w5SGuxoyYwz/8UYZz/x9P9+8YDy3fmGsvQV1dVhzq8y4blgFu9+3L5pAXXM7v3v/M0Ykx3PHggl2l6RsoOGuhsTjNfzbq6WsWLuXZXML+aeFEXDptX3/Ax1NMGmx3ZXYQkT48RencbSlk3//y04a27q4e+Fk7aIZZjTc1aC1dXq469lPWLPtCF87byzfXxwhXQBla6wx5WPPt7sS2zgcwi9umEmKO45fv7ebw/Wt/Pt1OmnbcKLhrgblUH0rdz69kU8O1HPvFVP5it1dMX7GWP3t4xZ87mzO4SbO6eCnV00nPyORB9fspPJYG79cNptRqXpBjuFA38bVKVv96WEW/9f77Khs5FdfnhM5wQ5QvdMa0z7pUrsriQgiwjcvmMAvbpjJJ/vrWfyf7/PuDr2G/XCg4a4GrKGlk++t3MIdf9zImKwkXvv2eZF39mnZ69b3SYvsrSPCXDOngFe+9QVGpibwv1as58cvb6WpvcvuslQIabeM6pfHa3jm7/v5jzd2Ut/ayTcWjOe7l0yKzJNkytZYk22l6TVae5qUk8pL3zyXn/9lB//9wV5Wb63kny+dzLVzCvRgawzScFd98noNb5Qe4T/fKmNHZSPzxmZy7xenMm10hF5zs7nWuqzc+XfbXUnEcruc/PiL01gyK5+frNrG3Su38NS6fdx18UQumDwqMg6Iq6DQcFef09bpYdXmQ/zmr7vZU91McVYSj944h8XTcyP7n//TF8B4YcoX7a4k4s0qzODFb5zDy5sPsnxNGV9ZsYHTclO5ff54Fs/IJSHOpmkjVNCIMcaWDZeUlJgNGzbYsm31ecYYtlQ08MLHB3h50yEa27qYmpfGNxaM57IZeZE/T7gx8OtzrCGQt71rdzVRpdPj5ZXNh/j1e7vZVdVERpKLq2blc31JAVPz0iL7DX0YEpGPjTEl/a2nLfdhrKPLy4Z9dbxZeoQ3S49QcbSVhDgHi6fnckNJIWePz4qef+yKDVBVCl/8L7sriToup4Nr5hRw1ax83i+v4fkNB3j6o/2sWLuXMVlJXDI1h4un5DCneERkHmdRvdJwH0YaWjvZdrCBjfuPsm5PHRv21dHW6SU+zsF5E7L51oUTWDwjjzR3FF5Y4+MVEJ8C06+1u5Ko5XAI8yeNZP6kkdS3dPDqlsO8UXqEFWv38rv3PyMp3knJmEzOHpfF7KIMpo1OIzUa/1aGiQGFu4gsAv4LcAKPG2Me6PF4AvAkcAZQC3zJGLM3uKWqgWpo7eRAXQt7apopP9LIrqomth8+xt7alu51TstNZemZRZw9PovzJmaTFB/F7/Ntx2DbizDjekhItbuamJCRFM9NZxVz01nFNLZ18j+7avhwTy3r9tTy87/sAKxpe8ZmJzMlN40Jo1KYmJPC2OxkijKTNPQjQL//0SLiBH4FXAJUAOtFZJUxpjRgtVuBo8aYCSKyFPg58KVQFDxcdXR5qW/toKGlk/rWTmqbOqhtbqemsYOqxjYqG9qoPNZGxdFWGlo7u5/nECjOSmZybirXlxQyIz+d0wvSyUiKt/GnCbJPX4DOFjjjZrsriUmpbheLZ+Sx2HdOQ11zB1sq6vm0ooEtBxvYeqiB1VsPE3j4bkSSi/wRieSmuclNdzMq1U12SgJZKfFkJceTkeQiPTGe9ESXTokQIgNprs0Fyo0xewBE5FlgCRAY7kuAn/hurwQeERExIThae6ytk4aW4+EVuAWD6b4fuGF/GeaE9a11/cv8z+
3ttteA11jLvMbg9R5f5jUGj9f/3RoT7vEaurxe67vH0On1Wt89Xjp93zu6vLR3eWjv8tLe6aWty0Nbp4fWTi+tHV00t3to6eiiqb2LY21ddHR5+9wnI5Jc5KYnkpfuZlZhBsVZSRRlJlGclczY7GT7LpgRLh+vgNwZMHqO3ZUMC5nJ8SyYPIoFk0d1L2vr9LC7uol9tS3sr7O+Dte3UnG0lQ37jlIf8D/bU0Kcg1S3i1R3HEnxTt9XHG6Xg0SXE7fLSUKcgwSXk3ing/g4B67u70Kcw0GcQ4hzCk6Hdd/pEN8XOMS67RD/l9UF1X1bBBEQfN/7uo1/klHrOJR/mXX7+LGp4+tZz+1eHnD4KiPJFfJPNwMJ93zgQMD9CmBeX+sYY7pEpAHIAmqCUWSgZz7az89e3xHslw07h0BCnJMEl4N4p4PEeCfuOCdul4PkhDhGZ8STFO8kxR1HqjuO1IQ40pPiyUh0kZ7oIjM5npGpCYxIih8+LR+v17o8Xu0uqC6Dqm1wZBtUboHLlg/L6X0jhdvlZNro9D7Pgejo8lLX3EFNUzu1zR00tHbS0NJBfUtndwOmqb2ru2FT39JBW6eX1k6r0dPhOd4IsmmAX1Ddf9V0bjqrOKTbCGtHq4jcBtwGUFRUNKjXWDB51OeuDxn4L+1/t7WW9/6u6X+X9b/D9vbO7BBrPQFrGKDvHd4h4BRB/Lcd1m2nQ3D6v/taEXG+2y6nv2XhIMHX6oj4oYV28Hqsi1QfOwiNlXDsEDRUQN1uqN1tXfu0q+34+knZkDPNOmlp9j/YV7fqV3ycg9x0q4tmqDxe69Nve5eXLo+XLt99a/nxT81eL3h8n6xN9yds65O8xwR8+vYGfoq3lvf5yd5XQ2CnRJ+9BycsP9Hsoowh74f+DCTcDwKFAfcLfMt6W6dCROKAdKwDqycwxjwGPAbWOPfBFDw5N5XJuXrQLOp4PdBUZQX2sQpoOGiF+LFDvq+D0HjYuoB0IIfLuih15ngYfyFkTYDsiZA9ybruqRp2rAaUM/a7G4doIOG+HpgoImOxQnwp8OUe66wCbgY+BK4D3glFf7uKIF4vtNZBcw201AZ81UBL3YnLGo9A0xEwnhNfIy7RmgMmbTQUnwNp+b77+ZCWZ10mL3kkOPSfWKlT1W+4+/rQ7wTWYA2FfMIYs01E7gM2GGNWAf8P+IOIlAN1WG8AKpq1HbOmzq3fD0f3WX3d9futbpKmI1YrvGdY+8WnQFLW8a9RU62gTsvzBXc+pBdA4gjtJ1cqRAbU526MWQ2s7rHs3oDbbcD1wS1NhVxnm9WPXVvu+9ptfa/bbfV9B3IlQXohZBRC7nRIybG+krMhMROSMq0+8KQscOnFIJSyWxSfuaIGpHuESfnnv+oPcMKhnpQcq0978mLIHAcjxkBGsfWVlKmtbKWiiIZ7rGhvhJpd1lftLqgpgxpfKzxwhEl8KmSNh8J5MOtGK8yzJljL9OxOpWKGhns08XRBw34rtGv9QV5ufW+qPL6eOKxWd9ZEGH/B8REmWROtESbaAlcq5mm4RxpjrAOWJ3Sh+PvCPwNvwJl+7gxrSOCEi04M8MyxEJfQ9zaUUjFPw90ubQ2+0PYFd+2u40He0XR8PWeC1WWSPQlOu9wK8czxVpAnZWkrXCnVKw33UOpsg6Ofndj69n9vDrwCvUBGkRXchfOs1nfWeOt+eoGO81ZKnTIN96HydFrjv/3BXecP8T3WKJXA0SjJo6zAnnSpL7wnWvdHjNHhg0qpoNJwHwiv1zplvrvlvfv4fCf1+048Zd6dbnWbFJ0FWTf6ulB8XSnuNPt+BqXUsKLh7nfCgUz/Acw9xyes8rQfX9eVZIV17nSYdtXxfvCs8doPrpSKCMMr3I2x5j2pCziQWRcwEuWEA5nx1ok8meNh4sUnBnhqnga4UiqixWa4tzf2aH0HtMbb6o+vJ04YUWyFdvG5vgAfZwV4eqEeyFRKRa3oDffeRqL4g7zpyInrphVYgT
39Wut7pm8kyohicOq1HpVSsSf6wn3jk/DXB/sYiTIeJl5yvPvE3xJ3JdpWrlJK2SH6wj151PGRKIHdKO7eL++llFLDUfSF++RF1pdSSqk+DZMrKyul1PCi4a6UUjFIw10ppWKQhrtSSsUgDXellIpBGu5KKRWDNNyVUioGabgrpVQMEmNM/2uFYsMi1cC+QT49G6gJYjnBonWdGq3r1EVqbVrXqRlKXcXGmJH9rWRbuA+FiGwwxpTYXUdPWtep0bpOXaTWpnWdmnDUpd0ySikVgzTclVIqBkVruD9mdwF90LpOjdZ16iK1Nq3r1IS8rqjsc1dKKXVy0dpyV0opdRIa7kopFYMiNtxF5HoR2SYiXhHpc8iQiCwSkZ0iUi4i9wQsHysiH/mWPyci8UGqK1NE3hSRXb7vI3pZ5wIR2RTw1SYiV/keWyEinwU8NitcdfnW8wRse1XAcjv31ywR+dD3+94iIl8KeCyo+6uvv5eAxxN8P3+5b3+MCXjs+77lO0Xk0qHUMYi6visipb7987aIFAc81uvvNEx13SIi1QHb/2rAYzf7fu+7ROTmMNf1UEBNZSJSH/BYKPfXEyJSJSJb+3hcROSXvrq3iMicgMeCu7+MMRH5BUwBJgPvASV9rOMEdgPjgHhgMzDV99jzwFLf7d8A3whSXf8O3OO7fQ/w837WzwTqgCTf/RXAdSHYXwOqC2jqY7lt+wuYBEz03R4NHAYygr2/Tvb3ErDOHcBvfLeXAs/5bk/1rZ8AjPW9jjOMdV0Q8Df0DX9dJ/udhqmuW4BHenluJrDH932E7/aIcNXVY/1vAU+Een/5Xvt8YA6wtY/HLwNeBwQ4C/goVPsrYlvuxpjtxpid/aw2Fyg3xuwxxnQAzwJLRESAC4GVvvV+D1wVpNKW+F5voK97HfC6MaYlSNvvy6nW1c3u/WWMKTPG7PLdPgRUAf2egTcIvf69nKTelcBFvv2zBHjWGNNujPkMKPe9XljqMsa8G/A3tA4oCNK2h1TXSVwKvGmMqTPGHAXeBIJ1fcxTrWsZ8EyQtn1Sxpi/YTXm+rIEeNJY1gEZIpJHCPZXxIb7AOUDBwLuV/iWZQH1xpiuHsuDIccYc9h3uxLI6Wf9pXz+D+unvo9kD4lIQpjrcovIBhFZ5+8qIoL2l4jMxWqN7Q5YHKz91dffS6/r+PZHA9b+GchzQ1lXoFuxWn9+vf1Ow1nXtb7fz0oRKTzF54ayLnzdV2OBdwIWh2p/DURftQd9f9l6gWwReQvI7eWhHxhjXg53PX4nqyvwjjHGiEifY0l978gzgDUBi7+PFXLxWGNdvwfcF8a6io0xB0VkHPCOiHyKFWCDFuT99QfgZmOM17d40PsrFonITUAJMD9g8ed+p8aY3b2/QtC9AjxjjGkXka9jfeq5MEzbHoilwEpjjCdgmZ37K2xsDXdjzMVDfImDQGHA/QLfslqsjztxvtaXf/mQ6xKRIyKSZ4w57AujqpO81A3An40xnQGv7W/FtovIfwP/FM66jDEHfd/3iMh7wGzgT9i8v0QkDXgN6419XcBrD3p/9aKvv5fe1qkQkTggHevvaSDPDWVdiMjFWG+Y840x7f7lffxOgxFW/dZljKkNuPs41jEW/3MX9Hjue0GoaUB1BVgKfDNwQQj310D0VXvQ91e0d8usByaKNdIjHusXucpYRyjexervBrgZCNYngVW+1xvI636ur88XcP5+7quAXo+qh6IuERnh79YQkWzgXKDU7v3l+939GasvcmWPx4K5v3r9ezlJvdcB7/j2zypgqVijacYCE4G/D6GWU6pLRGYDvwWuNMZUBSzv9XcaxrryAu5eCWz33V4DLPTVNwJYyImfYENal6+207AOTn4YsCyU+2sgVgH/6Bs1cxbQ4GvABH9/BftocbC+gKux+p3agSPAGt/y0cDqgPUuA8qw3nl/ELB8HNY/XznwApAQpLqygLeBXcBbQKZveQnweMB6Y7DejR09nv8O8ClWSD
0FpISrLuAc37Y3+77fGgn7C7gJ6AQ2BXzNCsX+6u3vBaub50rfbbfv5y/37Y9xAc/9ge95O4HFQf5776+ut3z/B/79s6q/32mY6voZsM23/XeB0wKe+xXffiwH/lc46/Ld/wnwQI/nhXp/PYM12qsTK79uBW4Hbvc9LsCvfHV/SsBIwGDvL51+QCmlYlC0d8sopZTqhYa7UkrFIA13pZSKQRruSikVgzTclVIqBmm4K6VUDNJwV0qpGPT/AbxE03dyHnGBAAAAAElFTkSuQmCC\n", | |
"text/plain": [ | |
"<Figure size 432x288 with 1 Axes>" | |
] | |
}, | |
"metadata": { | |
"needs_background": "light" | |
}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"# evaluate the value of all x positions.\n", | |
"# simulate trajectories and (dis)count up the rewards\n", | |
"trajectories = np.vstack(play_episode(x, 200))  # row k holds s_k for every start position in x\n", | |
"# Match the training convention r_t = reward_fn(s_tp1): the reward for step k is earned on\n", | |
"# arriving at s_{k+1}, so the start state's own reward is not part of its value.\n", | |
"# The 200-step horizon truncates the infinite sum; with gamma = 0.9 the neglected tail is tiny.\n", | |
"rs = reward_fn(trajectories[1:])\n", | |
"vs = discount(rs, gamma)\n", | |
"plt.plot(x, vs, label='truth')\n", | |
"\n", | |
"# the learned fn, evaluated with the trained parameters held in opt_state\n", | |
"y = fn(optimizers.get_params(opt_state), x.reshape((N, 1)))\n", | |
"plt.plot(x, y, label='estimate')\n", | |
"\n", | |
"plt.legend()\n", | |
"plt.title('Value')" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"[<matplotlib.lines.Line2D at 0x7f59795404e0>]" | |
] | |
}, | |
"execution_count": 10, | |
"metadata": {}, | |
"output_type": "execute_result" | |
}, | |
{ | |
"data": { | |
"image/png": "\n", | |
"text/plain": [ | |
"<Figure size 432x288 with 1 Axes>" | |
] | |
}, | |
"metadata": { | |
"needs_background": "light" | |
}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"# training curve: one TD-loss value per batch, in training order\n", | |
"ax = plt.figure().gca()\n", | |
"ax.set_title('Loss')\n", | |
"ax.plot(losses)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python (venv)", | |
"language": "python", | |
"name": "venv" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.6.7" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment