-
-
Save kdneal/c86ce4c00e547efa362854ed729f5e0f to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import keras\n", | |
"import numpy as np\n", | |
"import matplotlib.pyplot as plt\n", | |
"from keras.layers import Dense, Dropout, Activation\n", | |
"from keras.models import Sequential\n", | |
"import keras.backend as K\n", | |
"import tensorflow as tf\n", | |
"%matplotlib inline" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Define model & data" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"<matplotlib.legend.Legend at 0x1ee96adf888>" | |
] | |
}, | |
"execution_count": 9, | |
"metadata": {}, | |
"output_type": "execute_result" | |
}, | |
{ | |
"data": { | |
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAD4CAYAAAD8Zh1EAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAgAElEQVR4nO3deVyVZf7/8dd1WEUQFNBU4BzccgcVt9TMFtMWbTHT0LKxyJZplrJyKC1nmG/LTIuV42Bji0OaWU1WlrZoWbkrue+yu+ACLuyc6/cHyA/hIEc4cJ9z+DwfDx8PuM997vt9e/TDzXVfi9JaI4QQwvWZjA4ghBDCMaSgCyGEm5CCLoQQbkIKuhBCuAkp6EII4SY8jTpxSEiItlgsRp1eCCFc0ubNm09orUNtvWZYQbdYLGzatMmo0wshhEtSSqXW9Jo0uQghhJuQgi6EEG5CCroQQrgJw9rQhesrLi4mIyODgoICo6MIO/j6+hIWFoaXl5fRUUQDkYIu6iwjI4OAgAAsFgtKKaPjiEvQWnPy5EkyMjKIjIw0Oo5oILU2uSilFiiljiuldtTwulJKzVFKHVBKbVNK9XV8TOGMCgoKCA4OlmLuApRSBAcHy29Tbs6eNvT3gFGXeH000Ln8Txzwr/rHEq5CirnrkM/K/dVa0LXWPwGnLrHLWOADXWYdEKSUauuogFUlp+fw4td7kGl/hRCuxmrVJHy1i51ZuQ1yfEf0cmkPpFf6PqN8WzVKqTil1Cal1Kbs7Ow6nWx7Rg7zfjzIjswzdXq/cB85OTnMnTu3Tu+96aabyMnJueQ+M2fO5LvvvqvT8evLnnOvXr2aX3/9tZESCUdYf/gU89ccZt+xsw1yfEcUdFu/x9m8fdZaJ2qtY7TWMaGhNkeu1mpMdHt8vUws2phWp/cL93Gpgl5aWnrJ9y5fvpygoKBL7jN79myuv/76OuerD3vOLQXd9SzZlE6AryejezZMI4YjCnoGEF7p+zAgywHHtSmwmRc39WrLsuQs8opKGuo0wgU888wzHDx4kOjoaKZPn87q1asZMWIE99xzD7169QLgtttuo1+/fvTo0YPExMSK91osFk6cOEFKSgrdunXjwQcfpEePHowcOZL8/HwApkyZwtKlSyv2nzVrFn379qVXr17s2bMHgOzsbG644Qb69u3LQw89hNls5sSJE9Wy+vv788QTT9C3b1+uu+46LvyGmpyczKBBg+jduze33347p0+ftuvcKSkpzJs3j9dee43o6GjWrFnDxx9/TM+ePYmKiuLqq69uoL91UVe5+cUs336EMVHt8PXyaJBzOKLb4jLgMaXUYmAgkKu1PuKA49ZoQv8IPt2SyZfbjjA+Jrz2N4gG98IXO9mV5dhmsO7tWjDr1h41vv7iiy+yY8cOkpOTgbI71g0bNrBjx46KrnkLFiygVatW5Ofn079/f+68806Cg4MvOs7+/ftZtGgR8+fPZ/z48XzyySdMmjSp2vlCQkLYsmULc+fO5R//+AfvvPMOL7zwAtdeey0zZszgm2++ueiHRmXnz5+nb9++/POf/2T27Nm88MILvPXWW9x77728+eabDB8+nJkzZ/LCCy/w+uuv23XuadOm4e/vz5NPPglAr169WLFiBe3bt6+1OUk0vs+TMykssTKhf0SDncOebouLgLXAlUqpDKXUVKXUNKXUtPJdlgOHgAPAfOCRBktbrr+lJR1Dm/PRxvTadxZNyoABAy7qZz1nzhyioqIYNGgQ6enp7N+/v9p7IiMjiY6OBqBfv36kpKTYPPYdd9xRbZ+ff/6ZCRMmADBq1Chatmxp870mk4m7774bgEmTJvHzzz+Tm5tLTk4Ow4cPB+C+++7jp59+svvcVQ0ZMoQpU6Ywf/78WpucROPSWrNoQzo927egV1hgg52n1jt0rfXEWl7XwKMOS2QHpRQT+keQsHw3+46dpUubgMY8vbDhUnfSja
l58+YVX69evZrvvvuOtWvX4ufnxzXXXGOzH7aPj0/F1x4eHhVNLjXt5+HhQUlJWXNfXXtbXW4XQlvnrmrevHmsX7+er776iujoaJKTk6v9NiKMsS0jl91HzvC323o26Hlcdi6XO/q2x8tDsXiD3KU3VQEBAZw9W3NvgdzcXFq2bImfnx979uxh3bp1Ds8wdOhQlixZAsDKlSsr2sCrslqtFW3iH374IUOHDiUwMJCWLVuyZs0aABYuXFhxt26Pqtd/8OBBBg4cyOzZswkJCSE9Xf5vOIvFG9No5uXB2Oh2DXoely3owf4+jOx+BZ9uzaCwRH69bIqCg4MZMmQIPXv2ZPr06dVeHzVqFCUlJfTu3ZvnnnuOQYMGOTzDrFmzWLlyJX379uXrr7+mbdu2BARU/42xefPm7Ny5k379+vHDDz8wc+ZMAN5//32mT59O7969SU5Orthuj1tvvZXPPvus4qHo9OnT6dWrFz179uTqq68mKirKYdcp6u5cYQmfJ2dxS++2BPg27Dw6yqgBOjExMbq+C1z8tC+bexdsYM7EPoyJatiffKK63bt3061bN6NjGKqwsBAPDw88PT1Zu3YtDz/8cMVD2sr8/f05d+6cAQkvJp9Z41u0IY0Zn27nk4evop/Z9jOWy6GU2qy1jrH1mktPzjW0Uwjtg5rx0cY0KejCEGlpaYwfPx6r1Yq3tzfz5883OpJwMos3pNGljT99Iy497sERXLqgm0yKu/uH8+q3+0g9eR5zcPPa3ySEA3Xu3JmtW7fWup8z3J2Lxrcr6wy/ZeQy85bujTKXjsu2oV9wV0wYJlU2AksIIZzJ4o1peHuauKOvzdlQHM7lC3rbwGaMuLI1H2/KoKTUanQcIYQAIL+olM+2ZjK65xUE+Xk3yjldvqAD3N0/nONnC1m1t24TfgkhhKMt336EswUlDToytCq3KOjXdm1N6wAfFm+QCbuEEM5h8cY0IkOaM6hDq0Y7p1sUdE8PE+P6hbFq73GO5Noe5SfcT32mzwV4/fXXycvLq3W/1atXc8stt1xyn+TkZJYvX17nLMK9HDh+lo0pp7m7f3ijLiziFgUdYOKACDTIyFEnlpSUhMViwWQyYbFYSEpKqtfxGqug20MKuqhs8YZ0PE2KO/uGNep53aagh7fyY3iXUBZvTKNYHo46naSkJOLi4khNTUVrTWpqKnFxcfUq6lWnzwV45ZVX6N+/P71792bWrFlA2UyHN998M1FRUfTs2ZOPPvqIOXPmkJWVxYgRIxgxYkS1Y3/zzTd07dqVoUOH8umnn1Zs37BhA1dddRV9+vThqquuYu/evRQVFTFz5kw++ugjoqOj+eijj2zuJ5qGguJSPtmSwQ3d2xAa4FP7GxxJa23In379+mlHW7nzqDY//aX+enuWw48tqtu1a5fd+5rNZk3ZwicX/TGbzXU+/+HDh3WPHj0qvl+xYoV+8MEHtdVq1aWlpfrmm2/WP/74o166dKl+4IEHKvbLycmpyJSdnV3tuPn5+TosLEzv27dPW61Wfdddd+mbb75Za611bm6uLi4u1lpr/e233+o77rhDa631u+++qx999NGKY9S0n9Eu5zMTdfPJ5nRtfvpL/cv+6v+2HAHYpGuoqy49sKiqa7u2pl2gL0nr0xjVQCuCiLpJS7P9wLqm7XWxcuVKVq5cSZ8+fYCywTz79+9n2LBhPPnkkzz99NPccsstDBs27JLH2bNnD5GRkXTu3Bkom+72wjznubm53Hfffezfvx+lFMXFxTaPYe9+wv0sXJdKh9DmDO7Y+DNduk2TC4CHSTFxQARr9p/g8InzRscRlURE2O66VdP2utBaM2PGDJKTk0lOTubAgQNMnTqVLl26sHnzZnr16sWMGTOYPXt2rceq6UHWc889x4gRI9ixYwdffPGFzel4L2c/4V52ZuWyNS2H2IHmRn0YeoFbFXQo65PuaVJ8uD7V6CiikoSEBPz8/C7a5ufnR0JCQp2PWXX62BtvvJ
EFCxZUDLPPzMzk+PHjZGVl4efnx6RJk3jyySfZsmWLzfdf0LVrVw4fPszBgwcBWLRoUcVrubm5tG9fNurvvffeqzFLTfsJ9/bfdWn4epkY18gPQy9wu4LeuoUvI3u04ePNGRQUy7S6ziI2NpbExETM5rI7F7PZTGJiIrGxsXU+ZtXpc0eOHMk999zD4MGD6dWrF+PGjePs2bNs376dAQMGEB0dTUJCAs8++ywAcXFxjB49utpDUV9fXxITE7n55psZOnQoZrO54rWnnnqKGTNmMGTIkItWBRoxYgS7du2qeCha037CPSUlJRERYebFO6M4Mm8qX362xJAcLj19bk1+PXCCe95Zz6vjo7jDoJ+UTYFMxep65DNzvAs9uCp3gfXz86v3DUtNLjV9rtvdoQMM7hhMh5Dm/HedNLsIIRpWfHx8tfEMeXl5xMfHN3oWtyzoSinuGRjBlrQch69EL4QQlTVGDy57uWVBBxjXLwwfTxNJ8nC0QRnVZCcun3xWDaMxenDZy20LepCfN7dGteN/WzM5V2h7lXRRP76+vpw8eVIKhQvQWnPy5El8fX2NjuJ2nnnuBZTnxSNC69uDq67camBRVZMGmVm6OYPPtmYyeZC59jeIyxIWFkZGRgbZ2TJtsSvw9fUlLEw6CThaaYchtBr1GD7JSziSmUFERAQJCQkN8kC0Nm5d0KPCAunRrgVJ61KZNDDCkI7+7szLy4vIyEijYwhhmFKr5sP1aYwcM44Pv/iH0XHct8kFyh6OTh5kZs/Rs2w4fMroOEIIN7N673Eyc/KZ5CQtAG5d0AHGRrcnsJkX769NMTqKEMLNLFyXSusAH27o3sboKEATKOjNvD2YMCCcFTuPkZUji18IIRzjUPY5Vu/NJnagGS8P5yilzpGigU0eZEZrLQONhBAO88HaVLw8ysa8OIsmUdDDWvpxfbc2LN6YLvO7CCHq7WxBMUs3Z3Br73aNv4jFJTSJgg4w5SoLp84X8cVvWUZHEUK4uE82Z3CusIT7rrIYHeUiTaagD+4YTJc2/ry/NkUGwggh6sxq1XywNpU+EUFEhQcZHecidhV0pdQopdRepdQBpdQzNl6PUEqtUkptVUptU0rd5Pio9aOU4t7BFnZknmFL2mmj4wghXNRP+7M5dOI8U5zs7hzsKOhKKQ/gbWA00B2YqJTqXmW3Z4ElWus+wASg7kuxN6Db+7QnwNeTd39JMTqKEMJFvfdrCqEBPox2wmUu7blDHwAc0Fof0loXAYuBsVX20UCL8q8DAadsqG7u48n4mHC+2XGUY2dkSTAhxOU5fOI8q/dmM2mgGW9P52uxtidReyC90vcZ5dsqex6YpJTKAJYDv7d1IKVUnFJqk1Jqk1Hzf9w72Eyp1iRJF0YhxGV6/9cUvDwUEweGGx3FJnsKuq0JUKo+VZwIvKe1DgNuAhYqpaodW2udqLWO0VrHhIaGXn5aBzAHN+faK1vz4YY0CkukC6MQwj7nCktYujmDW3q3o3WAc85aaU9BzwAq/zgKo3qTylRgCYDWei3gC4Q4ImBDuO8qCyfOFbF8+xGjowghXMSFrorO+DD0AnsK+kags1IqUinlTdlDz2VV9kkDrgNQSnWjrKA77ZyqQzuF0CG0Oe/+Il0YhRCXlpSUhNlsZsqQDhxPnMqOn74yOlKNai3oWusS4DFgBbCbst4sO5VSs5VSY8p3ewJ4UCn1G7AImKKduFKaTIrfDYlkW0Yum1KlC6MQwrYLC0CXLSenyT99jLi4OJKSkoyOZpMyqu7GxMToTZs2GXJugPyiUga/+D0DI1vx78k2F9AWQjRxFouF1NTqHSjMZjMpKSmNHwhQSm3WWtssWs7X76aRNPP2YNJAMyt3HSP15Hmj4wghnJAzLQBtjyZb0KGsC6OnSclAIyGETc60ALQ9mnRBb93ClzFR7VmyKZ3cvGKj4wghnMzTzz3vNAtA26NJF3SAqUMjySsqZdFG5/wVSghhnPzwwQSPfoz2Ye
EopTCbzSQmJhqyALQ9mnxB796uBUM6BfPeLykUl1qNjiOEcBL5RaUsXJfK7eMmkJGehtVqJSUlxWmLOUhBB8ru0o+eKZCBRkKICp9uzeB0XjEPDOtgdBS7SUEHrunSmg6hzXlnzWEZaCSEwGrV/GfNYXqHBdLf0tLoOHaTgk7ZQKOpQyPZnpnLxhQZaCREU7dq73EOnTjPA8M6oJSt6ayckxT0cnf0CaOlnxfvrDlkdBQhhMHeWXOYdoG+jO55hdFRLosU9HLNvD2YNMjMt7uPkXJCBhoJ0VTtyMxl7aGTTBliwcvDtUqka6VtYJMHm/EymfjPz4eNjiKEMMh/fj5Mc28PJgxwzsFDlyIFvZLWAb7c3qdsoNGJc4VGxxFCNLLMnHy++C2Lu/tH0MLXy+g4l00KehVxwztQVGrl/V9TjI4ihGhkF56hPTAs0uAkdSMFvYqOof6M7N6GD9amcr6wxOg4QohGcvp8EYs3pDMmuh3tgpoZHadOpKDbMG14R3Lzi1m0QaYDEKKpeH9tCvnFpUwb3tHoKHUmBd2GPhEtueL4Rh4dMxiTyYTFYnHaCe2FEPWXV1TC+7+mcH231nRpE2B0nDrzNDqAM0pKSiJ50csUFeQDkJqaSlxcHIBTz+MghKibjzamczqv2KXvzkHu0G2Kj4+nsLyYX5CXl0d8fLxBiYQQDaW41Mo7aw7T39KSGEsro+PUixR0G1xtlRIhRN19uS2LzJx8l787BynoNrnaKiVCiLrRWjNv9SG6tPFnxJWtjY5Tb1LQbUhISMDPz++ibb7NmjntKiVCiLpZtfc4e4+dZdrwjphMrjMJV02koNsQGxtLYmIiZrMZpRRega0ZOHmGPBAVws3MW32I9kHNuDWqndFRHEIKeg1iY2NJSUnBarXyytKfSWnZl91HzhgdSwjhIJtTT7Eh5RRTh0a63CRcNXGPq2hgU66y4O/jyVurDhgdRQjhIG/+cIBWzb2ZMCDc6CgOIwXdDoF+Xtw72Mzy7Uc4cPys0XGEEPW0LSOH1XuzmTo0Ej9v9xmOIwXdTlOHRuLr6cHcVQeNjiKEqKe3fjhAC19P7h1sNjqKQ0lBt1Owvw+xAyP4/LcsUk/KAhhCuKrdR86wctcx7h8SSYALTpF7KVLQL0Pc1R3wMCn+tVru0oVwVW+tOoC/jyf3D7EYHcXhpKBfhtYtfJnQP5xPtmSQmZNf+xuEEE7lwPGzLN9+hMmDzQT5eRsdx+GkoF+mh8qHB//7R7lLF8LVzF11EF9PDx4Y6poLWNTGroKulBqllNqrlDqglHqmhn3GK6V2KaV2KqU+dGxM59E+qBl39g1j8cZ0jp8pMDqOEMJOqSfP8/lvWcQOjCDY38foOA2i1oKulPIA3gZGA92BiUqp7lX26QzMAIZorXsAf2yArE7jkWs6UWrVzC9frkoI4fzmrjqIh0kRd3UHo6M0GHvu0AcAB7TWh7TWRcBiYGyVfR4E3tZanwbQWh93bEznEhHsx9iodvx3XRonZTFpIZxexuk8PtmSwYT+4bRu4Wt0nAZjT0FvD6RX+j6jfFtlXYAuSqlflFLrlFKjbB1IKRWnlNqklNqUnZ1dt8RO4pERnSgoKWX+msNGRxFC1GLejwdRCreYIvdS7CnotqYg01W+9wQ6A9cAE4F3lFJB1d6kdaLWOkZrHRMaGnq5WZ1Kp9b+jIlqxwdrUzghd+lCOK3MnHyWbMxgXL8wl1382V72FPQMoPJkB2FAlo19PtdaF2utDwN7KSvwbu3x6zpTUFwqPV6EcGJvrzqARvPoiE5GR2lw9hT0jUBnpVSkUsobmAAsq7LP/4ARAEqpEMqaYNz+iWHHUH9u69OeD9amSo8XIZxQ+qk8lmxMZ0L/CMJa+tX+BhdXa0HXWpcAjwErgN3AEq31TqXUbKXUmPLdVgAnlVK7gFXAdK31yYYK7Uz+cF1nSqyauTJ6VAin8+YP+zGZVJO4O4eytu9aaa2XA8urbJtZ6WsN/L
n8T5NiDm7OuL5hfLghjYeGd6BtoHu30QnhKlJOnOeTLZncO9jMFYHu27OlMhkp6gCPXdsJrTVvy3zpQjiNOd/vx8tD8fA17t2zpTIp6A4Q3sqP8THhfLQxnYzTeUbHEaLJO3D8HP9LzuTewRZaBzSNu3OQgu4wj47ohELx1g9yly6E0V7/bh++Xh485MajQm2Rgu4g7YKaMXFAOB9vzpD50oUw0J6jZ/hq+xHuH2Jx2zlbaiIF3YEeGdEJT5Nizvdyly6EUV7/dj/+3p48OKxp3Z2DFHSHatPCl0mDzHy2NUPWHhXCANsycvhm51F+NzTSLec7r40UdAd75JqO+Hl78sqKvUZHEaLJefmbvbRq7s0Dw9xzvvPaSEF3sGB/Hx4c1oEVO4+xJe200XGEaDJ+3n+Cnw+c4NERndxurVB7SUFvAA8MiyTE35uXvt5D2ZgrIURDslo1L32zh/ZBzZg0KMLoOIaRgt4Amvt48vtrO/P9l5/QNiwCk8mExWIhKSnJ6GhCuKXlO46wPTOXP9/QBR9PD6PjGMauof+iDg7+zOkVb2EtLptaNzU1lbi4OABiY2ONTCaEWykutfKPFXu5sk0At/WpulRD0yJ36A1k1nPPVhTzC/Ly8oiPjzcokRDuacmmdFJO5jH9xivxMNlavqHpkILeQNLS0i5ruxDi8uUXlfLGd/uJMbfkum6tjY5jOCnoDSQiwvaDmZq2CyEu34JfDnP8bCFPj+6KUk377hykoDeYhIQE/PwunlC/mZ8fCQkJBiUSwr2cPl/EvB8Pcl3X1vS3tDI6jlOQgt5AYmNjSUxMxGw2o5TCo0Uot//+BXkgKkQ9JSUlYbFYaBXgy+7XJtMtb5vRkZyGFPQGFBsbS0pKClarlWn/+obNXj04kptvdCwhXFZSUhJxcXGkpqaC1pSeyWbW9MelS3A5KeiN5OlRXbFqZEoAIeohPj6evLyL1xyQ3mP/nxT0RhLeyo/7h1j4dEsm2zNyjY4jhEuS3mOXJgW9ET06ohOtmnvzt692yZQAQtRBeHi4ze3Se6yMFPRG1MLXiz9d35n1h0+xctcxo+MI4XJue/BJlOfFi1b4Se+xClLQG9nEARF0au3Pi1/voajEanQcIVxGflEp603dib7nKSIiIlBKYTabSUxMlN5j5aSgNzJPDxN/uakr23/8knZh4TJxlxB2mr/mEEfPFDDv+T+SmpqK1WolJSVFinklMjmXAY5s/o6clW9TWlQAyMRdQtTm+JkC5v14kNE9r2BApAwiqoncoRsgPj6+ophfIF2vhKjZS9/spbjUytOjuhodxalJQTeAdL0Swn6bU0/zyZYMHhjWAUtIc6PjODUp6AaoqYuVyWSStnQhKrFaNc8v20mbFj48NqKT0XGcnhR0A9iauAugtLSUyZMn88gjjxiQSgjns2RTOtszc/nLTd1o7iOP/GojBd0AFybu8vCovlSW1pp58+bJnbpo8nLzinl5xV4GWFoxJqqd0XFcghR0g8TGxmK12u6HrrWWB6SiyXvtu33k5BUxa0x3mevcTlLQDXSp4crygFQ0ZXuPnmXhulTuGRhBj3aBRsdxGXYVdKXUKKXUXqXUAaXUM5fYb5xSSiulYhwX0X0lJCTUeOchc1OIpkrrsgehAb6ePHHDlUbHcSm1FnSllAfwNjAa6A5MVEp1t7FfAPA4sN7RId1VbGws06ZNq1bUfXybydwUosn6YtsR1h46yZMjr6Rlc2+j47gUe+7QBwAHtNaHtNZFwGJgrI39/gq8DBTYeE3UYO7cuSxcuLBiZSOfoDZEjP0jd9w1wehoQjSaC6sQmUwm7hreh1ZH1jNxgPyWernsKejtgfRK32eUb6uglOoDhGutv7zUgZRScUqpTUqpTdnZ2Zcd1l1VXtlo9eadFFmGMOeH/UbHEqJRVF6FSGtNUe5xdn/8DxYv+tDoaC7HnoJuq5G3YjJvpZQJeA14orYDaa0TtdYxWuuY0NBQ+1M2IYM6BD
OuXxjzfzrE3qNnjY4jRIOztQpRQX6+9PSqA3sKegZQeVb5MCCr0vcBQE9gtVIqBRgELJMHo3X3l5u6EeDrSfxn27FaZSEM4d5kKgzHsaegbwQ6K6UilVLewARg2YUXtda5WusQrbVFa20B1gFjtNabGiRxE9CquTczburGptTTLNmUXvsbhHBhNfXokp5el6/Wgq61LgEeA1YAu4ElWuudSqnZSqkxDR2wqbqrXxgDIlvxf1/v4cS5QqPjCNFgnoyfhfKSVYgcwa5+6Frr5VrrLlrrjlrrhPJtM7XWy2zse43cndefUoq/396TvKIS/v7VbqPjCNFgtvv2os3Nj9M+LFxWIaonGSnqxDq1DuChqzvy6dZMfjlwwug4QjjcNzuOsmLnMZ77QxwZ6WmyClE9SUF3co9d24nIkOY88+k28opKjI4jhMPk5hXz3Oc76N62BQ8O62B0HLcgBd3J+Xp58NKdvUk/lc8rK/YaHUcIh/nrV7s4db6Il8f1xstDSpEjyN+iCxgQ2Yp7B5t579cUNqWcMjqOEPX2475slm7OYNrwDvRsL5NvOYoUdBfx1KiutAtsxlOfbKOguNToOELU2bnCEv7y6XY6hjbn99d2NjqOW5GC7iL8fTx58c5eHMo+zxvfy7QAwnW99PUesnLzeXlcFL5e1Rd5EXUnBd2FDOscyt0x4ST+dIhtGTlGxxHisq0/dJKF61KZcpWFfuaWRsdxO1LQXcxfbu5GiL83Ty3dRlGJ7RWPhHBG+UWlPPPpdsJbNWP6jTLPeUOQgu5iApt5kXBbL/YcPcscaXoRLuT/vt7N4RPneemO3vh5y4LPDUEKugu6vnsb7uoXxtzVB9icetroOELU6qd92XywNpX7h1i4qlOI0XHclhR0FzXz1u60C2rGn5ckc75QBhwJ55WbV8xTS7fRqbU/T4/qanQctyYF3UUF+Hrxz7uiSDuVR8JymetFOK/nPt/BiXOFvDY+Wnq1NDAp6C5sYIdg4oZ14MP1aazac9zoOEJU88VvWSz7LYvHr+tMrzAZQNTQpKC7uD+P7ELXKwKYvnQbp84XGR1HiArHzhTw7P92EB0exCPXdDQ6TpMgBd3F+Xh68Nrd0ZzJLyb+s+1oLSscCSiJeTUAABIRSURBVONZrZrpS7dRWFLKq+Oj8JS5WhqF/C27gW5tW/DEyC58veMoH22UFY6E8f7z82F+2pdN/M3d6RDqb3ScJkMKupt4cFgHhnYK4fkvdrLvmCwuLYyzLSOHl1fs4cYebZg0UJaRa0xS0N2EyaR49e4o/H08eezDLTKBlzDE2YJifr9oK6H+Prx0Z2+UUkZHalKkoLuR1gG+vDo+mn3HzjH7y11GxxFNjNaaZ/+3g/RTebwxsQ9Bft5GR2pypKC7mau7hDJteEc+XJ/GV9uOGB1HNCFLN2fweXIWf7y+C/0trYyO0yRJQXdDT4zsQnR4ENOef52w8AhMJhMWi4WkpCSjowk3dTD7HLOW7WRQh1Y8OqKT0XGaLJkhxw15eZi41msfy758A11cCEBqaipxcXEAsgCvcKi8ohIe/u9mfDxNvH53HzxM0m5uFLlDd1Ov/v2FimJ+QV5eHvHx8QYlEu5Ia80zn2xn//FzzJnYhysCfY2O1KRJQXdTaWlpl7VdiLr4YG0qy37L4okbujCsc6jRcZo8KehuKiLCdv/fmrYLcbk2p57mb1/t4rqurXnkGmk3dwZS0N1UQkICfn5+F20zefnw7KzZBiUS7uTEuUIeTdpC28BmvDo+GpO0mzsFKehuKjY2lsTERMxmM0oprmgfRsjo3/Or6obVKvO9iLorKbXy+KKtnM4r4l+T+hLo52V0JFFOCrobi42NJSUlBavVypGMdF555lF+2HOcN2TpOnGZkpKSsFgsmEwmQtqGsXLZUv52W096tJMpcZ2JFPQmZPIgM3f2DeON7/fLoCNht6SkJOLi4khNTUVrTW72Ec58+zZFe38yOpqoQgp6E6KU4u
939CTG3JInPk5me0au0ZGEC4iPjycvL++ibcWFBdIF1gnZVdCVUqOUUnuVUgeUUs/YeP3PSqldSqltSqnvlVJmx0cVjuDj6cG8yf0Ibu7DAx9s5NiZAqMjCScnXWBdR60FXSnlAbwNjAa6AxOVUt2r7LYViNFa9waWAi87OqhwnBB/H965L4azBSXEfbBJZmYUlxQWHm5zu3SBdT723KEPAA5orQ9prYuAxcDYyjtorVdprS/8TrYOCHNsTOFo3dq24I0JfdiWmcuTH/8mKx0Jm0pKrXQY9QDK0+ei7X5+fiQkJBiUStTEnoLeHqi8DE5G+baaTAW+tvWCUipOKbVJKbUpOzvb/pSiQdzQvQ1Pj+rKl9uO8MqKvUbHEU5Ga82sZTtJadmXh+JfrOgCazabSUxMlDmBnJA9k3PZGjFg83ZOKTUJiAGG23pda50IJALExMTILaETeOjqDqSdymPu6oNcEejLvYMtRkcSTuJfPx4kaX0a04Z35JnRN8PzfzQ6kqiFPQU9A6jciBYGZFXdSSl1PRAPDNdaF1Z9XTgnpRSzx/Tg+JlCZi3bSesAX0b1vMLoWMJg/9uaycvf7GVMVDueuvFKo+MIO9nT5LIR6KyUilRKeQMTgGWVd1BK9QH+DYzRWh93fEzRkDw9TLw5sQ/R4UE8vngrf309sWIQicyj3vT8evAE05f+xqAOrXjlrt4yrN+F1FrQtdYlwGPACmA3sERrvVMpNVspNaZ8t1cAf+BjpVSyUmpZDYcTTqqZtwf/ua8/nod+YdZTj1cMIklNTWXy5Mk88sgjRkcUjWB7Ri5xH2wmMqQ5/54cg4+nh9GRxGVQRvVuiImJ0Zs2bTLk3KJmYeERZGakV9uulGLhwoXyIMyN7T92lvH/XouftydLHx5M28BmRkcSNiilNmutY2y9JiNFxUWyMjNsbtday8hAN5Z+Ko9J/1mPp4eJpAcGSjF3UVLQxUUuNVhERga6p2NnCoh9Zz2FJVb+O3UglpDmRkcSdSQFXVwkISEBpWw/BJORge7nxLlCJr2znpPnCnn//gFceUWA0ZFEPUhBFxeJjY1l2rRp1Yq6h5cP1984Snq/uJHss4VMTFxH+uk83rmvP1HhQUZHEvUkBV1UM3fuXBYuXFgxMrB12zD8el3He+++d1Hvl7i4OCnqLqTynObhERFcM+2vZJzO590pAxjcMdjoeMIBpJeLsEubduEcP1L9ganZbCYlJaXxA4nLcmFO88rT4CovH2a99Aaz/vSQgcnE5bpULxcp6MIuJpPJ5gReSimsVqsBicTlsFgspKamVtsuP5Bdj3RbFPVW0wNReVDqGmRO86ZBCrqwS0JCAn5+fhdtM3n58PjTzxmUSFyOK9rZntFafiC7Fynowi6xsbEkJiZWPCht2z6ciLF/4vXv9tM2LFx6vjixXw+ewDRgIiYvmdPc3UlBF3aLjY0lJSUFq9VKVkYaf7q+C5lfvM7RzAzp+eKkvt5+hCkLNtJt6E3MmfsvmdPczclDUVFn8qDNeWmtmffjIV76Zg99IoJYcF9/Wjb3NjqWcIBLPRS1Zz50IWySB23OqajEyrP/286STRnc0rst/7grCl8vmTWxKZAmF1FnNT1Q82vZhty84kZOIwBy8oq4d8F6lmzK4PFrOzFnQh8p5k2IFHRRZ7Z6vnj7NqP50EncPvcXDmafMyhZ07QjM5db3/qZLak5vHZ3FH8eeaUsTtHESEEXdVa154vZbGbBO/P54vUZ5OQXM/atX/h6+xGjYzYJSzdncOe/fqW4RLP4oUHc3sd2N0Xh3uShqGgQmTn5PJq0heT0HO4fYmHG6G54e8r9g6MVlpQy+4tdJK1PY3CHYN68pw8h/j61v1G4LHkoKhpd+6BmLHloMC9+vYcFvxxmS1oOb9/Th7CWfrW/WdjlwPFz/GHxVnZmnWHa8I48ObILnh7yQ7Mpk09fNBhvTxMzb+3Ov2L7cuj4OUa/sYY/JbyJ2WyWgUj1oL
UmaX0qt7y5hqycfObfG8Mzo7tKMRdyhy4a3uhebenRLpBxT77MG/99EV1SCFAxEAmQAS52OnmukGc+3c63u44xrHMI/7writYtfI2OJZyEtKGLRmM2m232UZeBSLXTWvN5chazv9zFuYISnh7dlfuvskgvliZI2tCFU0hPT7e5XQYiXVpmTj7PfradVXuziQ4P4uVxvenSRpaKE9VJo5toNDUNRPJsEcp7vxym1GrMb4vOqqjEyjtrDjHy1R9Zd+gUM2/pzicPXyXFXNRICrpoNLYGIjVr1owBdz3C81/s4tY3f+bXAycMSudcVu05zqjXf+JvX+2mf2QrVv7pan43NBIPaWIRlyAFXTQaWwOR5s+fz5rEmbw5sQ+5+cXc8856pr63kQPHzxod1xB7j55lyrsbuP+9jaDg3Sn9ee/+AYS3ku6eonbyUFQ4jYLiUt77NYW3fzhAXnEp4/qG8ciIjpiDmxsdzaGSkpKIj48nLS2NiIgIEhISGHjDWF7/bh9fbT+Cv48nf7iuM/cOtshgLFGNrCkqXMqp80XM+X4/H25Io9SqGRvdjkdHdKJjqL/R0erN1mLNnt6+BN34KKHR13P/EAsPDutAkJ9MdStsk4IuXNLxMwX8+6dDJK1PpbDEysjubbjvKguDOwSjlGu2Jdc0h3xgaFsOHjpMsAzbF7WQgi5c2olzhfzn58Ms2pBGTl4xV7YJYPJgM7dGtSOwmZfR8ex28lwhIS2agY3/c0oprFarAamEq5GCLtxCQXEpy37L4v1fU9iZdQZvTxM3dGvDbX3aM7xLqFO2N+cXlfLjvmyW/ZbJt7uOcfitKZSeya62nwyuEvaSgUXCLfh6eTA+Jpy7+oWxLSOXz7Zm8sVvWXy1/QgBvp4M7xLK9d3aMLxLqKHLrZ08V8jPB06wYudRVu3JJr+4lFbNvbl3sAW/9v/HzOmPX9SGLos1C0exq6ArpUYBbwAewDta6xervO4DfAD0A04Cd2utUxwbVYgySimiwoOICg8i/uZurNmfzTc7jvLDnuN8ue0IJgW92gcSY2lFf0srYiwt7Z5SNikpiT/84Q+cPHkSgODgYN54440a55rRWnMkt4DtmblsOHyKXw6cYM/Rsi6XoQE+3NmvPaN7tmVgZKvyybO6c0Wgb7VeLjKXjXCEWptclFIewD7gBiAD2AhM1FrvqrTPI0BvrfU0pdQE4Hat9d2XOq40uQhHs1o12zJz+WH3MdYdPkVyeg5FJWXt0q0DfLjyigCubBNAh1B/2gb50i6wGVcE+tLC1xOlFElJSfzud7+jqKjoouN6eXnx5rz5XHfLHaSfziftZB4pJ89zKPs8OzJzOXm+bH9vTxP9LS25qmMIQzqF0Kt9oAwEEg5X3yaXAcABrfWh8oMtBsYCuyrtMxZ4vvzrpcBbSimljWqgF02SyaSIDg8iOjwIKFv8YUdmLltSc9hz9Cz7jp1l4bqyHjMXvU9Bc29P9s35c7ViDlBcXMyjf5pO2L6Qim3NvDywhDRnRNfW9A4LpGf7QLq3bSHrdwpD2VPQ2wOVZ1XKAAbWtI/WukQplQsEAxeN41ZKxQFxUPO8HkI4io+nB/3MrehnblWxrdSqOXamgCO5+WTlFHA0t4AzBcWcKyzh+dzqDysr3nf2BG9MiKZ9UDMigv0I9fdx2a6Twn3ZU9Bt/auteudtzz5orROBRChrcrHj3EI4lIdJ0S6oGe2CmtHPfPFr70ZE2OwjDmCOiGBsdPtGSChE3dnTzysDCK/0fRiQVdM+SilPIBA45YiAQjSWhIQEvL2r947x8vKSXijCJdhT0DcCnZVSkUopb2ACsKzKPsuA+8q/Hgf8IO3nwtXExsayYMECgoODK7YFBwfz7rvvSi8U4RLsGliklLoJeJ2ybosLtNYJSqnZwCat9TKllC+wEOhD2Z35hAsPUWsivVyEEOLy1XtgkdZ6ObC8yraZlb4uAO6qT0ghhBD143xjpYUQQtSJFHQhhHATUt
CFEMJNSEEXQgg3Ydj0uUqpbMD2KI7ahVBlFGoTINfcNMg1Nw31uWaz1jrU1guGFfT6UEptqqnbjruSa24a5Jqbhoa6ZmlyEUIINyEFXQgh3ISrFvREowMYQK65aZBrbhoa5Jpdsg1dCCFEda56hy6EEKIKKehCCOEmnLqgK6VGKaX2KqUOKKWesfG6j1Lqo/LX1yulLI2f0rHsuOY/K6V2KaW2KaW+V0qZbR3HldR2zZX2G6eU0kopl+/iZs81K6XGl3/WO5VSHzZ2Rkez4992hFJqlVJqa/m/75uMyOkoSqkFSqnjSqkdNbyulFJzyv8+timl+tb7pFprp/xD2VS9B4EOgDfwG9C9yj6PAPPKv54AfGR07ka45hGAX/nXDzeFay7fLwD4CVgHxBiduxE+587AVqBl+fetjc7dCNecCDxc/nV3IMXo3PW85quBvsCOGl6/CfiashXfBgHr63tOZ75Dr1icWmtdBFxYnLqyscD75V8vBa5Trr3QY63XrLVepbXOK/92HWUrSLkyez5ngL8CLwMFjRmugdhzzQ8Cb2utTwNorY83ckZHs+eaNdCi/OtAqq+M5lK01j9x6ZXbxgIf6DLrgCClVNv6nNOZC7qtxamrLup40eLUwIXFqV2VPddc2VTKfsK7slqvWSnVBwjXWn/ZmMEakD2fcxegi1LqF6XUOqXUqEZL1zDsuebngUlKqQzK1l/4feNEM8zl/n+vlV0LXBjEYYtTuxC7r0cpNQmIAYY3aKKGd8lrVkqZgNeAKY0VqBHY8zl7Utbscg1lv4WtUUr11FrnNHC2hmLPNU8E3tNa/1MpNRhYWH7N1oaPZwiH1y9nvkNviotT23PNKKWuB+KBMVrrwkbK1lBqu+YAoCewWimVQllb4zIXfzBq77/tz7XWxVrrw8Beygq8q7LnmqcCSwC01msBX8omsXJXdv1/vxzOXNCb4uLUtV5zefPDvykr5q7ergq1XLPWOldrHaK1tmitLZQ9NxijtXblBWnt+bf9P8oegKOUCqGsCeaS6/Q6OXuuOQ24DkAp1Y2ygp7dqCkb1zLg3vLeLoOAXK31kXod0egnwbU8Jb4J2EfZ0/H48m2zKfsPDWUf+MfAAWAD0MHozI1wzd8Bx4Dk8j/LjM7c0NdcZd/VuHgvFzs/ZwW8CuwCtlO28LrhuRv4mrsDv1DWAyYZGGl05npe7yLgCFBM2d34VGAaMK3SZ/x2+d/Hdkf8u5ah/0II4SacuclFCCHEZZCCLoQQbkIKuhBCuAkp6EII4SakoAshhJuQgi6EEG5CCroQQriJ/wfZiC9quTHD/QAAAABJRU5ErkJggg==\n", | |
"text/plain": [ | |
"<Figure size 432x288 with 1 Axes>" | |
] | |
}, | |
"metadata": { | |
"needs_background": "light" | |
}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"SEED = 0  # fixed seed so the random test points are identical on every Restart & Run All\n", | |
"\n", | |
"def f(x):\n", | |
"    \"\"\"Target function: the quadratic 4*(x - 0.5)**2, which is 0 at x = 0.5.\"\"\"\n", | |
"    return ((x - .5)**2) / .25\n", | |
"\n", | |
"def df(x):\n", | |
"    \"\"\"Exact derivative of f: 8*(x - 0.5), a straight line.\"\"\"\n", | |
"    return 2*(x - .5) / .25\n", | |
"\n", | |
"# training set: n regularly spaced samples on [0, 1]\n", | |
"n = 1000\n", | |
"x = np.linspace(0, 1, n)\n", | |
"y = f(x)\n", | |
"\n", | |
"# test set: 20 uniform random samples on [0, 1), sorted so plots of them run left to right\n", | |
"rng = np.random.default_rng(SEED)\n", | |
"x_test = np.sort(rng.random(20))\n", | |
"y_test = f(x_test)\n", | |
"\n", | |
"fig, ax = plt.subplots()\n", | |
"ax.plot(x, y, label='training points')\n", | |
"ax.plot(x_test, y_test, 'ko', label='test data')\n", | |
"ax.set(xlabel='x', ylabel='y = f(x)', title='Training and test data')\n", | |
"ax.legend();" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Define and run tiny nonlinear regression NN" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"<keras.callbacks.callbacks.History at 0x1ee9574e988>" | |
] | |
}, | |
"execution_count": 3, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"model = Sequential()\n", | |
"# 1-D input -> three ReLU hidden layers -> 1-D output.\n", | |
"# Dense(1) with no activation is already linear, which is what regression needs.\n", | |
"model.add(Dense(64, input_dim=1, activation='relu'))\n", | |
"model.add(Dense(32, activation='relu'))\n", | |
"model.add(Dense(32, activation='relu'))\n", | |
"model.add(Dense(1))\n", | |
"\n", | |
"# Regression: minimise MSE and track MAE; 'accuracy' is a classification metric\n", | |
"# and is meaningless for a continuous target.\n", | |
"model.compile(loss='mse', optimizer='adam', metrics=['mae'])\n", | |
"\n", | |
"# Keras expects inputs shaped (n_samples, n_features), so reshape the 1-D arrays.\n", | |
"history = model.fit(x.reshape(-1, 1), y,\n", | |
"                    batch_size=10,\n", | |
"                    epochs=25,\n", | |
"                    verbose=0,\n", | |
"                    validation_data=(x_test.reshape(-1, 1), y_test))" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Plot performance (not perfect, but good enough for an example model)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"<matplotlib.legend.Legend at 0x1ee96815f08>" | |
] | |
}, | |
"execution_count": 4, | |
"metadata": {}, | |
"output_type": "execute_result" | |
}, | |
{ | |
"data": { | |
"image/png": "\n", | |
"text/plain": [ | |
"<Figure size 432x288 with 1 Axes>" | |
] | |
}, | |
"metadata": { | |
"needs_background": "light" | |
}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"# Model predictions at the held-out test points (input shaped (n_samples, 1)).\n", | |
"y_result = model.predict(x_test.reshape(-1, 1))\n", | |
"\n", | |
"fig, ax = plt.subplots()\n", | |
"ax.plot(x, y, label='exact')\n", | |
"ax.plot(x_test, y_result, 'go', label='predictions')\n", | |
"ax.set(xlabel='x', ylabel='y', title='Model predictions on test data')\n", | |
"ax.legend();" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Derivatives\n", | |
"Extracting what the model thinks dy/dx is at a given point" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Symbolic derivative of the network output with respect to its input.\n", | |
"# K.gradients returns a list with one gradient tensor per input tensor;\n", | |
"# this model has a single input, so the last entry is the only one.\n", | |
"# NOTE(review): K.gradients / K.function rely on graph-mode execution (as in\n", | |
"# standalone Keras 2.x); under TF2 eager tf.keras, tf.GradientTape is the replacement.\n", | |
"input_gradient = K.gradients(model.output, model.input)[-1]\n", | |
"\n", | |
"# Compiled callable: pass [inputs shaped (n_samples, 1)] and receive dy/dx.\n", | |
"function_grad_x = K.function([model.input], input_gradient)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"<matplotlib.legend.Legend at 0x1ee96a5dd48>" | |
] | |
}, | |
"execution_count": 6, | |
"metadata": {}, | |
"output_type": "execute_result" | |
}, | |
{ | |
"data": { | |
"image/png": "\n", | |
"text/plain": [ | |
"<Figure size 432x288 with 1 Axes>" | |
] | |
}, | |
"metadata": { | |
"needs_background": "light" | |
}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"# Network dy/dx at the regularly spaced training points; np.ravel flattens the\n", | |
"# K.function result into a 1-D vector aligned with x.\n", | |
"grads_x = np.ravel(function_grad_x([x.reshape(-1, 1)]))\n", | |
"\n", | |
"fig, ax = plt.subplots()\n", | |
"ax.plot(x, df(x), label='dy/dx exact')\n", | |
"ax.plot(x, grads_x, 'g.', markersize=5, label='dy/dx Keras')\n", | |
"ax.set(xlabel='x', ylabel='dy/dx', title='Training Data')\n", | |
"ax.legend();" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"<matplotlib.legend.Legend at 0x1ee96ad7f08>" | |
] | |
}, | |
"execution_count": 7, | |
"metadata": {}, | |
"output_type": "execute_result" | |
}, | |
{ | |
"data": { | |
"image/png": "\n", | |
"text/plain": [ | |
"<Figure size 432x288 with 1 Axes>" | |
] | |
}, | |
"metadata": { | |
"needs_background": "light" | |
}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"# Network dy/dx at the randomly sampled test points, flattened to 1-D.\n", | |
"grads_xtest = np.ravel(function_grad_x([x_test.reshape(-1, 1)]))\n", | |
"\n", | |
"# x_test may be unsorted: sort it for the exact-derivative line so the line is not\n", | |
"# drawn back and forth. The Keras values are markers, so their order does not matter.\n", | |
"x_line = np.sort(x_test)\n", | |
"\n", | |
"fig, ax = plt.subplots()\n", | |
"ax.plot(x_line, df(x_line), label='dy/dx exact')\n", | |
"ax.plot(x_test, grads_xtest, 'go', label='dy/dx Keras')\n", | |
"ax.set(xlabel='x', ylabel='dy/dx', title='Testing Data')\n", | |
"ax.legend();" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Comments\n", | |
"Keras has estimated the gradient. Can the estimated gradient be improved? Why is gradient not smooth? I briefly investigated this.. seems Keras is using a \"computation graph\" within the *K.function* command to map *model.inputs* to *grads*. Alternatively, one may be able to train their own model to map *model.inputs* to *grads* with better accuracy. " | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.7.6" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
My solution to the example proposed by thearn. Note that I modified the "Define model & data" section so that x does not need to be scaled.
Original example: https://gist.github.com/thearn/be9f98c5c2a6f87068490808f374a07c