@arbennett
Created November 9, 2021 22:18
{
"cells": [
{
"cell_type": "markdown",
"id": "a78ff981-3580-4741-93d2-592d4b47bac6",
"metadata": {},
"source": [
"# Putting numerics into a neural network\n",
"\n",
"This notebook highlights some quick experimentation I did with putting arbitrary numerical solvers into neural networks, particularly for encoding solutions to ODEs and PDEs as a network structure. Note that this is different from a [neural ODE](https://arxiv.org/pdf/1806.07366.pdf) in that a neural ODE parameterizes the entire network as an ODE, while here we want to write down a specific ODE and learn some of its parameters through the optimization.\n",
"\n",
"To get started I'll be using the `torch.autograd.functional` module, particularly the `jacobian` function. This does exactly what you think it does: it computes the Jacobian of any function that is autodifferentiable (i.e. implemented in pure PyTorch). Let's import some packages and start messing around."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "994c99b1-91f3-4bb4-a48b-b612ba35e0f6",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Populating the interactive namespace from numpy and matplotlib\n"
]
}
],
"source": [
"# %pylab inline pulls numpy (np) and matplotlib.pyplot (plt) into the namespace\n",
"%pylab inline\n",
"import xarray as xr\n",
"import pandas as pd\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"from torch.utils.data import DataLoader\n",
"from tqdm.notebook import tqdm\n",
"\n",
"from collections import OrderedDict\n",
"from typing import Optional, Tuple\n",
"from functools import partial\n",
"\n",
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
"\n",
"from torch.autograd.functional import jacobian\n",
"dtype = torch.float32\n",
"\n",
"# v will be our domain for testing functions\n",
"v = torch.tensor(np.arange(-2.0, 2.0, step=0.1), dtype=dtype)"
]
},
{
"cell_type": "markdown",
"id": "0232885b",
"metadata": {},
"source": [
"## Testing the `jacobian` on simple functions\n",
"Just to get a feel for how the `jacobian` function works, let's look at some examples where we have analytic solutions. Here I show that the autodiff derivatives of both ReLU and hyperbolic tangent match their analytic counterparts.\n",
"\n",
"### ReLU"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "05958579",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/bzq/miniconda3/envs/all/lib/python3.7/site-packages/ipykernel_launcher.py:3: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
" This is separate from the ipykernel package so we can avoid doing imports until\n"
]
}
],
"source": [
"def drelu(x):\n",
"    # Derivative of the ReLU function: 1 where x > 0, 0 elsewhere\n",
"    return (x > 0).to(torch.float32)\n",
"\n",
"torch_drelu_v = torch.hstack([jacobian(F.relu, vv) for vv in v])\n",
"truth_drelu_v = drelu(v)\n",
"\n",
"assert np.allclose(torch_drelu_v.numpy(), truth_drelu_v.numpy())"
]
},
{
"cell_type": "markdown",
"id": "5d23813a",
"metadata": {},
"source": [
"### Hyperbolic tangent"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "56d56c38",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/bzq/miniconda3/envs/all/lib/python3.7/site-packages/torch/nn/functional.py:1794: UserWarning: nn.functional.tanh is deprecated. Use torch.tanh instead.\n",
" warnings.warn(\"nn.functional.tanh is deprecated. Use torch.tanh instead.\")\n"
]
}
],
"source": [
"def dtanh(x):\n",
"    # Analytic derivative of tanh: 1 / cosh^2(x)\n",
"    return torch.tensor(1/np.power(np.cosh(x.numpy()), 2.0), dtype=dtype)\n",
"\n",
"truth_dtanh_v = dtanh(v)\n",
"# torch.tanh replaces the deprecated nn.functional.tanh\n",
"torch_dtanh_v = torch.hstack([jacobian(torch.tanh, vv) for vv in v])\n",
"assert np.allclose(truth_dtanh_v, torch_dtanh_v)"
]
},
{
"cell_type": "markdown",
"id": "1d9089b4-b5e7-4635-bfe9-910b567bd22d",
"metadata": {},
"source": [
"## Testing the jacobian on a simple NN layer\n",
"Great, those both work out of the box. Now can we take derivatives of simple neural networks? Let's find out. Here I'll define a basic feedforward network, AKA a multilayer perceptron (MLP). Rather than initializing the weights randomly from a distribution, I'm going to specify their values so that we know what the derivative should be. Our activation functions will be the ReLU and hyperbolic tangent that we already know `jacobian` handles, but you could expand this.\n",
"\n",
"For the first test I'll take the derivative of a single layer with a hyperbolic tangent activation. For the second I'll use a two-layer network (one neuron per layer) with a ReLU activation followed by a hyperbolic tangent activation."
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "3a09f68e",
"metadata": {},
"outputs": [],
"source": [
"class MLPLayer(nn.Module):\n",
"\n",
"    def __init__(self, width, activation=nn.ReLU, init=True):\n",
"        super().__init__()\n",
"        self.width = width\n",
"        self.layer = nn.Linear(width, width)\n",
"        self.activation = activation()\n",
"        # init may be a (weight, bias) tuple, or True for weight=1, bias=0\n",
"        if isinstance(init, tuple):\n",
"            self.init_parameters(*init)\n",
"        elif init:\n",
"            self.init_parameters(1, 0)\n",
"\n",
"    def forward(self, X):\n",
"        X = self.layer(X)\n",
"        X = self.activation(X)\n",
"        return X\n",
"\n",
"    def init_parameters(self, weight_value, bias_value=None):\n",
"        # Fill every weight and bias with a known constant\n",
"        if bias_value is None:\n",
"            bias_value = weight_value\n",
"        with torch.no_grad():\n",
"            self.layer.weight[:, :] = weight_value\n",
"            self.layer.bias[:] = bias_value"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "588464d5-5c0b-4581-a46e-b2b67ebd9e03",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Text(0.5, 1.0, 'Single neuron with tanh activation')"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"m = MLPLayer(1, activation=nn.Tanh)\n",
"torch_dmodel_v = np.hstack([\n",
" jacobian(m, torch.tensor([[vv]])).detach().numpy().flatten() for vv in v])\n",
"\n",
"truth_dtanh_v = dtanh(v)\n",
"\n",
"plt.plot(v, truth_dtanh_v, label='Analytic')\n",
"plt.plot(v, torch_dmodel_v, linestyle='--', label='Autograd')\n",
"plt.xlabel('Input (x)')\n",
"plt.ylabel(r'Derivative ($ \\frac{\\partial M}{\\partial x}$)')\n",
"plt.legend()\n",
"plt.title('Single neuron with tanh activation')"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "c8c47c95-1f6f-4a3d-b74b-5741949b615a",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/bzq/miniconda3/envs/all/lib/python3.7/site-packages/ipykernel_launcher.py:3: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
" This is separate from the ipykernel package so we can avoid doing imports until\n"
]
},
{
"data": {
"text/plain": [
"Text(0.5, 1.0, '2 layer (relu, tanh)')"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"m2 = nn.Sequential(MLPLayer(1, nn.ReLU),\n",
" MLPLayer(1, nn.Tanh))\n",
"torch_dmodel2_v = np.hstack([\n",
" jacobian(m2, torch.tensor([[vv]])).detach().numpy().flatten() for vv in v])\n",
"\n",
"truth_dmodel2_v = drelu(v) * dtanh(v)\n",
"\n",
"plt.plot(v, truth_dmodel2_v, label='Analytic')\n",
"plt.plot(v, torch_dmodel2_v, linestyle='--', label='Autograd')\n",
"plt.legend()\n",
"plt.xlabel('Input (x)')\n",
"plt.ylabel(r'Derivative ($ \\frac{\\partial M}{\\partial x}$)')\n",
"plt.title('2 layer (relu, tanh)')"
]
},
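{
"cell_type": "markdown",
"id": "note-chain-rule-check",
"metadata": {},
"source": [
"As a quick check on the analytic curves in the two plots above, here is the derivation written out. This is a sketch that assumes the default `init_parameters(1, 0)`, so every weight is 1 and every bias is 0. The single-layer model is then just $M(x) = \\tanh(x)$, and the two-layer model is $M_2(x) = \\tanh(\\mathrm{ReLU}(x))$:\n",
"\n",
"$$\\frac{\\partial M}{\\partial x} = \\frac{1}{\\cosh^2(x)}, \\qquad \\frac{\\partial M_2}{\\partial x} = \\frac{1}{\\cosh^2(\\mathrm{ReLU}(x))} \\cdot \\mathbb{1}[x > 0]$$\n",
"\n",
"For $x > 0$ we have $\\mathrm{ReLU}(x) = x$, and for $x \\le 0$ the indicator is zero, so the second expression equals `drelu(v) * dtanh(v)` everywhere on the test domain."
]
},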
{
"cell_type": "markdown",
"id": "79ffacf8-2e1b-4ff9-b4e4-1c1176d4770e",
"metadata": {},
"source": [
"## Onto the numerical solvers\n",
"\n",
"From the above, you should be getting more confident that we can autodiff through neural networks, but that's not entirely our goal here. Remember, we're trying to combine an ODE solver with a neural network, where we will represent an unknown portion of an ODE as its own neural network. To do that, we'll need to look at some simple numerical solutions to ODEs.\n",
"\n",
"To get started let's implement some [root finding algorithms](https://en.wikipedia.org/wiki/Root-finding_algorithms), which will be used to solve our ODE. Without any exposition, here's a Newton solver. Note that it takes an `fprime` argument. This is the derivative of the function which we are trying to find a root for. Now you may see where the `jacobian` method will fit into the ODE solver..."
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "75125582",
"metadata": {},
"outputs": [],
"source": [
"def newton_solve(f, fprime, x, tol=1e-5, max_iter=100, return_seq=False):\n",
"    # Newton's method: repeat x <- x - f(x)/f'(x) until |f(x)| <= tol\n",
"    f_test = f(x)\n",
"    x_list = [x]\n",
"    it = 0\n",
"    # as_tensor accepts floats and tensors alike, without the copy-construct warning\n",
"    while torch.abs(torch.as_tensor(f_test)) > tol and it < max_iter:\n",
"        x = x - (f_test / fprime(x))\n",
"        f_test = f(x)\n",
"        x_list.append(x)\n",
"        it += 1\n",
"    if return_seq:\n",
"        return x_list\n",
"    else:\n",
"        return x"
]
},
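{
"cell_type": "markdown",
"id": "note-newton-update",
"metadata": {},
"source": [
"In equation form, as a summary of the loop above rather than anything new, each iteration of `newton_solve` applies\n",
"\n",
"$$x_{n+1} = x_n - \\frac{f(x_n)}{f'(x_n)}$$\n",
"\n",
"and stops as soon as $|f(x_n)| \\le \\mathrm{tol}$ or after `max_iter` updates. Note that the stopping test is on the residual $f(x)$, not on the step size, so `tol` has to be a value that $|f|$ can actually reach."
]
},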
{
"cell_type": "markdown",
"id": "aa836684",
"metadata": {},
"source": [
"## Testing the solvers on a parabola\n",
"For confidence, let's start with an easy one: finding the root of the parabola $x^2$, which sits right at its minimum. No need for `jacobian`, neural networks, or anything fancy; we'll write the derivative out by hand."
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "4684a68d",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(-1.1, 1.1, -0.0499959945678711, 1.0499998092651368)"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAV0AAADnCAYAAAC9roUQAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAAqwUlEQVR4nO3deXyU9bX48c+ZJCSEAIaAyiZrQhI2A8omqIgLaMEFEK4KxaXgWku196r1quj9Wa3Xq0VrpW6IUkGtFqUKassiamWVxRBQLC4syhoIWchyfn98n2jEBALJzDPLeb9e/sFkMs/JmJznzHc5X1FVjDHGhEbA7wCMMSaWWNI1xpgQsqRrjDEhZEnXGGNCyJKuMcaEkCVdY4wJIUu6xhgTQpZ0jTEmhCzpGmNMCFnSNcaYEIr3O4BIsGLFiuPj4+OfBrphNypjalIBrCsrK7umd+/e3/kdTLiypFsL8fHxT5944olZLVq02BMIBKxZhTHVqKiokB07dmRv3779aWCE3/GEK6vaaqdbixYt9lnCNaZmgUBAW7RokY/7RGhqYEm3dgKWcI05Mu/vxPLKYQTlzRGRTBF5REQaBOP1jTEmmETkLhE5PRivHaw7UkfgV8DPgvT6Yami5KBsGX5D+pbhN6RXlByUmh47FsnJyTlHes6YMWParVixIgngtttuO7Hq13JycjLr4xrGRDsR6QhMAc4IxusHK+nOB7YA1wTp9cNORclB2XrxzZ2LV61PKV61PmXrJTd3Lt9XEPjRYxff3LkuifdIZs+e/WXv3r2LAaZOndqy6tdWrVqVF6zrGhNlrsKtxJgejBcPStJV1XLgOeA8EWkbjGuEm22jJncuWbMhhZLSACWlgZLVG1K+7HFJjx89tmZDyrZRkzvX5Tpz585t3KdPny5Dhw7t2KFDh64jRozoUFFRAUCfPn26LF68OPn6669vXVJSEsjMzMweMWJEB/ihis3Pzw/0798/Izs7OysjIyP7xRdfPO5I1/zNb37TskOHDl0HDBiQPnz48A533XXXCVWvB7Bt27b41q1bdwcoKytj0qRJbbp165aVkZGR/dBDDzUH+PLLLxNOOeWULpmZmdnp6eld582bl1JWVsbIkSPbp6end83IyMieMmXK8XV5f4ypCxGJByYA81X162BcI5hLxp4D7sT9APcF8TrhqaQ0oCWlQXnp9evXN/zkk0++aN++fWnv3r0z33333ZTzzjuvoPLrTzzxxJbp06cfn5eXl3vo9yYnJ1f8/e9//7xZs2YV27Zti+/bt2/mZZddtjcQqP7+u3jx4uQ333wzde3atbmlpaVy8sknZ+fk5BQeLr5HH320edOmTcvXrVu3vqioSE499dTM4cOH73vppZdShwwZkv/ggw9uLysrY//+/YGPPvooedu2bQmfffbZpwA7d+6Mq+PbY0xdnAe0Bn4ZrAsEbZZRVb8A/gFcJSJRP5vZ8tVHPk/s0aWAxISKap+QmFCR2LNLQctXH/m8rtfq3r37gU6dOpXGxcXRtWvXwk2bNtV6wrKiokJ+9atftcnIyMgePHhwxnfffdfgm2++qfHmu3DhwpRhw4btTUlJ0dTU1Ipzzjln75Gu8d577zV5+eWX0zIzM7NzcnKy9uzZE5+bm5vUr1+/Ay+99FLzX//6162WLl3aMDU1tSIzM7Pk66+/Tvz5z3/e9tVXX22SmppaXtufxZgguBrYAcwN1gWCnQyfAdoDZwX5Or4LJDbQlrMe2iTx8dUuLZP4eG016383BRIb1HnpWWJi4vevERcXR1lZWa3HiadNm9Zs165d8WvXrl2fl5eXm5aWVlpUVFTj78HhDi6Nj4/X8nKXIwsLC6XK98jDDz/8VV5eXm5eXl7uli1b1l5yySX7hg0bVrB48eINrVu3PjhhwoQOjz/+eFqLFi3K161blzt48OD9TzzxxPFjx45tX9ufxZj6JCInAsOB51X1YLCuE+yk+zqwhxiYUK
soOSjbxv6mk9aQALWsTLaOvbVTMCfSqoqPj9eSkpKfXCs/Pz+uefPmpYmJifrmm2823rp162Gr5DPPPLNg/vz5TQsLCyU/Pz/w3nvvHVf5tbZt25YsXbq0EcDMmTNTKx8/55xz8v/0pz+1qLz+mjVrEvft2xfYuHFjg9atW5fecsstO6+44oqdK1euTN62bVt8eXk5EyZM2Ps///M/W9auXZtcb2+CMUdnPG7I9ZlgXiSo24BVtVhEXgQmiUiaqu4K5vX89KOJtOp4k2vbRk3u3PrNP34W7Hguv/zyHVlZWdndunUrfOONN/5d+fg111yze9iwYZ27deuW1bVr18IOHToUH+51zjjjjMKhQ4fmZ2dnd23dunVJjx49DjRt2rQc4Lbbbvt2zJgxHWfNmpU2aNCgfZXfM3ny5J2bN29O7N69e5aqSrNmzUrfeuutTfPnz288derUE+Pj4zU5Obl85syZ/968eXPC1Vdf3b6iokIA7r333m+C9Z4YUxMREdzQwgeqGtSVPnK4j4/1cgGRHsBqYLKqPhrUiwXJ6tWrN/fs2XPn4Z6zZfgN6cWr1v+QdBMTKiQ+XrWsTKo+lpSTVRCKpFuf8vPzA02bNq3Yv39/oH///l2efPLJLwcOHHjYyTQTu1avXt28Z8+e7f2O42h4GyEWAVep6nPBvFbQJ7hUdQ3wMTDRu5tEpR9NpHmTZu3WvLbmR4/1qJ+JtFC74oor2mVmZmb36NEja/jw4Xss4ZooNBHIB2YH+0JBr3QBRORq4GlgoKp+EPQL1rPaVLrgjet663BbvvrI54HEBlrdY8GO1xg/RVqlKyLNgK3AM6p6Q7CvF6rWjrOBR3B3k4hLurUVSGyghw4dVPeYMSasjAMSgT+H4mIhWT+rqgXATOBSEUk90vONMSYUvCHPXwBLVXV1KK4Zyk0LfwaSgMtDeE1jjDmc/kBXQlTlQgiTrqquAlYQ5RNqxpiI8guggBBMoFUK9fbcPwPdgb4hvm5IFJcVy+nPnZ5++nOnpxeXFUtNjx0LEel90UUXdaj8d2lpKampqT0HDx7cGWDmzJlN77jjjhNrfgXYvHlzwtChQzseawx1tXXr1vgePXpkZmVlZc+bNy8lWNcZOXJk++eee+4nw1hVG/QcralTp6YFAoHeH3/8ccPKx9LT07tu2LChTj2jp06dmrZ58+aEurzGsYr1Vp4ichwwBpjpDYGGRKiT7ku4u8rEEF836IrLiuWs58/qvHTL0pSlW5amDJkxpHN+cX6g6mNnPX9W52NNvA0bNqzYsGFDw4KCAgF4/fXXm5xwwgnfd9S5/PLL8++///7th3uN9u3bl86bN++LY7l+fZg7d27jzp07F69fvz536NChIfslry8nnHDCwXvvvbflkZ9Zey+++GLzr776ypekWx/Kysr8DqEurgAaEsKhBQhx0lXV/bgJtbHRNqF27gvndl65bWVKSXlJoKS8JLBi64qUNo+06VH1sZXbVqac+8K5x9zacciQIfmvvPLKcQAvvfRSs5EjR+6u/NrUqVPTxo8ffxK4Sm/ChAltc3JyMtu0adO9surbsGFDg/T09K6Vzz/77LM7nXXWWZ1bt27d/f77729xzz33nJCVlZXds2fPzG+//TYOam7fWNvvr/Thhx82vPvuu9ssWLCgaWZmZnZBQYFMmzatWUZGRnZ6enrX6667rnXlc6tWYM8991zqyJEj2x/u56qoqGD8+PEnderUqeuZZ57ZeefOnTWuypk+fXpaTk5OZnp6etcFCxYkl5eX065du25bt26NBygvL+ekk07qtm3btp+8xpAhQ/I3btzYcPXq1YmHfu21115rcvLJJ2dmZ2dnDRs2rGN+fn5gwYIFyeeee24ngBdffPG4pKSkXsXFxVJYWCiV8a9bty55/PjxHSvfkzlz5jTOysrKzsjIyB49enT7oqIiAWjdunX3yZMnt6psyblq1aqkQ2NYvnx5Uvfu3bMyMzOzMzIysteuXZ
sIcM8995yQnp7eNT09veu99977k9aZF1xwQcfZs2c3rfz3yJEj20+fPv24mlp0zp07t3Hfvn0zhg8f3qFLly5da3qvw5k3xHktsExVV4by2n50/5qGu7uM8+HaIVNSXhIoOFgQV1JeUm/v8bhx43bPnj07tbCwUNavX5/cv3//AzU999tvv01Yvnx53pw5cz67++67W1f3nI0bNzb861//+sWyZcvW/+53v2udnJxcsX79+txTTjnlwLRp09KOFM/RfP+AAQOKbr/99q3Dhw/fk5eXl7tz5874e+65p/XChQs35ubmfrpq1apGL7zwwnFHumZ1P9cLL7xw3Oeff564YcOGT6dPn/7lypUraxy6KCwsDKxatSpv6tSpX06cOLFDXFwco0aN2vX00083A5gzZ06TrKysopYtW/6khAsEAtx8883bp0yZ8qNqd9u2bfH3339/y8WLF2/Mzc1d36tXr8L77rvvhIEDBxZ++umnyQCLFy9O6dy5c9HixYuTFyxY0CgnJ6fgyiuv3NOtW7fCGTNmfJGXl5cbCASYNGlSh9mzZ2/auHFjbllZGQ899FCLyus0b968LDc3d/1VV12144EHHjjh0Pgee+yxFtdff/23eXl5uWvWrFnfoUOHg++//37yX/7yl7QVK1asX758+foZM2a0+OCDDxpW/b4xY8bsnj17dipAcXGxfPDBB01GjRqVX7VF5+rVq9c///zzLfLy8hoArFmzptFDDz20ZdOmTZ8e6f9ZmDoNN4H2ZKgvHPKk602ofQxcG00Tau+Me+fzXi17FSTGJVbb2jExLrGid6veBe+Me+eYd6T17du36Jtvvkl86qmnmp199tn5h3vuiBEj9sbFxdG7d+/iXbt2VfvxdcCAAftTU1MrWrVqVZaSklI+evTovQDdu3cv3Lx580+qufr8/iVLljTq16/f/latWpUlJCQwZsyY3YsWLTriOG91P9eiRYsaX3rppbvj4+Np3759af/+/ffX9P2XXXbZboBhw4YVFBQUBHbu3Bl33XXX7Zw1a1YawLPPPtt8woQJNW6EmTRp0q6VK1emVCYfgIULFzbatGlTUp8+fTIzMzOzZ82alfbVV181SEhIoF27dsUrV65MWrlyZaObbrrp2wULFjRetGhR49NOO+0nwyurV69OatOmTUmPHj1KACZMmLBryZIljavEvgegT58+hV9//fVP3t/+/fsfePjhh1v+9re/PfGzzz5rkJKSogsXLkw5//zz9zZp0qSiadOmFRdccMGeBQsWNK76faNGjcr/8MMPmxQVFcmrr77atE+fPvtTUlK0phadAD169DiQmZkZtE5cIXAtsI8QTqBV8qvP7ZNAFjDIp+vXu6T4JH378rc3JcQlVLvjLCEuQd++/O1NSfFJddqRNnTo0L1333132/Hjx+8+3POSkn64Tk27Dhs0+GF3XCAQ+P57AoHA9+0ia2rfWNvvr8nhdkJWvRdXfrw+0s9V2/v3oc8TETp37lzavHnzsjfeeKPxqlWrGo0ePbrGG1pCQgI33njj9nvvvff7SUtVZeDAgfsqW1lu2rTp05dffvlLgAEDBhS88cYbTRMSEnT48OH7Pvroo5SPPvooZciQIT+5MRxpd2jlzx4fH6/Vvb/XXnvt7jlz5nzesGHDimHDhmW88cYbjWuz4zQ5OVn79eu3/7XXXmsye/bs1LFjx+724qm2Raf3PdX3jY4AItIcGA3MUNUaPy0Gi19J92VgL+5uExWKy4pl2MxhnUrLS6v96y8tL5VhM4d1qssKBoDrrrtu5y233LK1T58+RXV5ndqqqX1jXZ1++ukHPv7448bbtm2LLysr45VXXml25plnFgCkpaWVrly5Mqm8vJw5c+Yc8ZpnnHHG/ldeeaVZWVkZX375ZcK//vWvxjU996WXXkoFmD9/fkrjxo3L09LSygGuuuqqHddcc02HESNG7I6PP/xGzRtvvHHXkiVLmuzevTse4MwzzzywfPnylH
Xr1iUC7N+/P7BmzZpE72sF06ZNO/7UU08taNWqVdmePXviv/jii6TKs+xSUlLK8/Pz4wBOPvnk4i1btjSofJ0ZM2akDRo0qMaq/VC5ubkNsrKySu68887vzj333L2ffPJJw7POOqvgrbfeOm7//v2Bffv2Bd56663UwYMH/+Q1x44du3v69OnNly1b1rgysdbUorO28YSxCUAD3FBnyPnyBqpqITADGCkiLY70/EhQdSKtuq9XTq7VZSINoFOnTqX//d///V1dXuNo3Hbbbd8+88wzLXJycjIPN0F1tNq1a1d61113bTnjjDMysrKyuvbo0aPwiiuu2AswZcqULRdeeGHn/v37d6m6QqMm48aN29uxY8eSLl26dL366qtP6tOnT42JKjU1tTwnJyfzxhtvbDdt2rTNlY//x3/8R35hYWHcxIkTj9h+NCkpSSdOnPhdZdJt1apV2bRp0zaPHTu2Y0ZGRnbv3r0z165dmwQu6e7atSuh8oaSnZ1d1KVLl6LK45HGjx+/86abbmqXmZmZXVFRwZNPPrl59OjRnTIyMrIDgQC33nrrjiPFU+mFF15olpGR0TUzMzP7s88+S5o0adKugQMHFl522WW7evXqldW7d++scePG7TjttNN+csO++OKL9y1btqzxwIED91VW1JMnT96ZmZlZ3L1796z09PSuv/jFL9qVllZfVEQK7xSbicASVV3nSwyhaHhT7YVFsoFPgdtU9UFfgqil2jS8Of2509OXbln6fdJNjEusSIhL0NLyUqn6WJ/WfQoWX7nYejGEmcWLFydPnjy57YoVKzb4HUukC+eGNyJyNvAuME5VX/QjBt8+KqhqLrAQ1+A84g8jrDqRVjlp9s3kb9ZUfaxXy151mkgzwXHHHXecOHbs2E7333//Fr9jMUF3PbATeNWvAHyrdAFEZDRufPcCVX3Lt0COoLatHYvLiqVy+OCdce98nhSfpNU9Fux4jfFTuFa6ItIG+BJ4SFVv8yuOULV2rMnfgO24u0/YJl2goqKiQgKBwGETZlJ8kh46dFDdY8ZEK+/YpXBd2TAREHyaQKvk60ykqpYCTwHni0iHIz3fR+t27NjRtPIcL2PMT1VUVMiOHTuaAr5MUB2OiDTANbd5S1X/faTnB5PflS64fc93AJMA30r+wykrK7tm+/btT2/fvr0bPt+ojAljFcC6srKycDz9+yLgROAJn+Pwd0z3+yBEXsNtlGirqoc9ndYYY46WiCwETgLSVbXcz1jCpWp7AqjcJWKMMfVGRLoCZwBP+p1wIXwq3QCQC+xV1X5+x2OMiR4i8gRwFdBGVY+4CinYwqLSVdUK4I9AXxE51e94jDHRQUSaAuOBl8Ih4UKYJF3P87gG5zf6HYgxJmpMABoBj/kcx/fCYnihkog8jlvW0VZVQ9ZfwBgTfbxhyw3ADlUd4Hc8lcKp0gV4HNf9JxyXnBhjIsu5QGfCqMqFMKt0AUTkXSAT6KCqEX0AkzHGPyIyFzgFOElVw6bherhVuuCq3TbAhX4HYoyJTCLSCTgfmBZOCRfCs9KNAz4HvlLVM/yOxxgTeUTkEdykfDtV3ep3PFWFXaXrLV5+HDhdRE72ORxjTIQRkca4dbkvh1vChTBMup5ngULgZr8DMcZEnAlAE2Cqz3FUK+yGFyp5u0iuxpaPGWNqyVsmlgfsDtfdreFa6YK7SzXA9cA0xpjaGAqkA3/wO5CahG2lCyAi84AeQPtwm4E0xoQfEZkPdMPljCMeauqHcK50wVW7LYFRfgdijAlvIpKF2xDxp3BNuBD+lW4AWA/sA/poOAdrjPGViDyJm0Rrq6q1Pro+1MK60vW6jz2K21Vymr/RGGPClYg0x3UTeyGcEy6EedL1zAD2AJP9DsQYE7YmAQ1xRVpYC/ukq6oHgCeBi0Wko9/xGGPCi4gk4nafzVfVT/2O50jCPul6/giUA7/0OxBjTNgZgzt08hG/A6mNsJ5Iq0pEXsCd6NlGVfN9DscYEw
ZERICVuDX93SJhsj1SKl1wd7EUXJNzY4wBd+DkycAjkZBwIYIqXfj+GOWOQKdwXodnjAkNEXkT6IvrJlbkdzy1EUmVLsD/Am2xo9qNiXkikg38DHg8UhIuRF6lGwA+BYqBXpHyccIYU/9E5GngctxmiLA46bc2IqrS9TZLPIwbwznL32iMMX4RkROBccBzkZRwIcKSrudF4FvgVr8DMcb45kYgAfg/vwM5WhGXdFW1GHe651AR6eZ3PMaY0BKRFOB64HVV/dzveI5WxCVdz5O4kyWs2jUm9lwJpOIm1iNORE2kVSUij+H2W3dU1W/8jscYE3wikgB8BmxR1YhsghWplS64CbUA1gjHmFhyKdAOeNDvQI5VxFa6ACIyExgBnKSqe/yOxxgTPN6W30+AeKC7t5op4kRypQvwe9zW4Ov9DsQYE3RDccd3PRSpCRcivNIFEJG3gd5E0DZAY8zR89oAdMbN40TsmYmRXumCG9tpAfzc70CMMcEhIn1xzW0eieSEC9FR6QrwES7xdlHVMp9DMsbUMxF5DRiMm7/Z73c8dRHxla7Xf+EBXPexS30OxxhTz7zGNhfjGttEdMKFKKh04ftGOGuBCqBnJA+yG2N+TESeB0bh5m0iqs9CdSK+0oXvG+H8DugGXOBzOMaYeiIi7XGdxP4cDQkXoqTSBRCReNxOlW+B/tb20ZjIJyJ/xJ0WEzU7T6Oi0gXwJtAexHWRP9PfaIwxdeW1b7waeD5aEi5EUdL1TAe2A7/1OQ5jTN1NxrVv/L3fgdSnqEq6XtvHh4EhItLP73iMMcdGRNKAG4CXVfUzv+OpT1GVdD1PAruA//Y7EGPMMfsV0Aj4fz7HUe+iLumqagGum/z5ItLb73iMMUdHRI4Dfgm8pqrrfA6n3kVd0vU8DuwF7vQ5DmPM0bsJaALc53cgwRCVSVdV9wF/AC4SkR5+x2OMqR0RaYwbWnhTVT/xN5rgiMqk6/kDsB+rdo2JJNcDzYjSKheiaHNEdUTkfuA2XMPjT/2OxxhTMxFpBPwbWKmqQ/2OJ1iiudIFN6F2ALjL70CMMUd0A65b4BS/AwmmqK504UfVbo9onAk1Jhp4x6r/G1gRzVUuRH+lC26zxAFs3a4x4ewGoDlwj89xBF3UV7oAIvL/gNuxateYsONVuZuBZao6zOdwgi4WKl1wY7sF2NiuMeHoRiCNGKhyIUaSrqruAqYCo23drjHhQ0SaALcC81T1Y7/jCYWYSLqeh4F8onxm1JgIczOuyo2ZOZeYSbqqugc3zHCRiJzidzzGxDoRaYarcueo6nK/4wmVmEm6nkeB3UTxbhdjIsgtuB4LMTXXElNJ1+vJ8CAwVEQG+h2PMbFKRI7HDS3MVtU1fscTSjGxZKwqb6vhJiAPGGxnqRkTeiLyMK6xTVdVzfM5nJCKqUoXQFUPAPcDZwBn+xyOMTFHRNrgGtu8GGsJF2Kw0gUQkURgI/Ad0MeqXWNCR0T+DEwAMlR1s7/RhF7MVboAqloC3A2cAlziczjGxAwRyQCuAp6MxYQLMVrpAohIHLAWd+Pp5h3hbowJIhGZDVwAdFLVb/2Oxw8xWekCqGo57qj2LsB4n8MxJuqJSC/gUuCRWE24EMOVLoCICPAvoBWQ7h3hbowJAhGZB5wKdFTVfL/j8UvMVroA3gTa7UAbXNMNY0wQiMhZwHnAA7GccCHGK91KIvI20Bc3zrTH73iMiSYiEgCWAsfjVizE9CfKmK50q/gv4Dhc1WuMqV+XAr2BO2M94YJVut8TkenAWNyd+CufwzEmKnhr4tfjTubu5U1gxzSrdH9Q2XTDmuEYU3+uBToA/2UJ17FKtwoReRD4De6O/InP4RgT0UTkOOBz4BPgHNv56Vil+2O/A/YA/+stJzPGHLs7gGbAf1rC/YEl3SpUdS/uZIkhwPn+RmNM5BKRDrjWjTNUdaXf8YQTG144hIgkAOuACtzpwaU+h2RMxBGRWcAI3KajLX7HE06s0j2El2
T/E8gEfuFzOMZEHBHpD4wBHrKE+1NW6VbDG8/9J9AN6BzrO2iMqS3vb+dDoB1u+WWBzyGFHat0q+EN+t9CjJ1Sakw9GAP0w22EsIRbDat0D0NEngHG4Vo/bvQ7HmPCmYgkAxv44XAAW5dbDat0D++3QDHwsN+BGBMB/hPXPOpmS7g1s6R7GKq6HbdD7Wcicp7f8RgTrkTkJFwPk1mqusTveMKZDS8cgbd3fB1QCvS0JWTG/FSVJWKZ1rvk8KzSPQLvPLVbgCzgBp/DMSbsiMjpuAm031vCPTKrdGvBWwbzNtAftwwmZo8aMaYqEYkHVgJNgSxVLfQ5pLBnlW4teEvIfgk0BB7wORxjwsl1QHdgsiXc2rFK9yiIyO+A24ABqvqR3/EY4ycROQG3ROxjYKg1takdS7pHQUQaAXnADuBUWxZjYpmIPAtcAXRX1Q1+xxMpbHjhKKjqAdykWg4wyedwjPGN11/hStxx6pZwj4JVukfJm1R7FzgF6GKTaibWeJNny3Hb5LNsu+/RsUr3KHnjVtfjJtVsp5qJRTcBPXE7zyzhHiWrdI+RiNyLa4YzRFX/6Xc8xoSCiLTBHTS5CBhuk2dHz5LuMRKRhsBaoBzX7LzE55CMCToReRV3qkpXVf233/FEIhteOEaqWoTboZaBa/RhTFQTkfOBkcB9lnCPnVW6deTtOb8YV+3aLK6JSiKSAnwKFAA5qnrQ55AillW6dfcroBD4s4jY+2mi1X3AScAvLOHWjSWJOvLaP94KnA5c5XM4xtQ7ETkVtw3+T6r6od/xRDobXqgH3trdBbhlNFleIjYm4nmnYy8DWgDZdl5g3VmlWw+8ZTMTcWt3H/M5HGPq0y24YuJGS7j1wyrdeiQitwP3A6NU9a9+x2NMXYhIJvAJ8HdVHelzOFHDkm498j6K/QtojVvHuMvnkIw5JiISB7wPdMH9LtuQWT2x4YV65B3lcxVuT/qj/kZjTJ3chGvaf7Ml3PpllW4QiMgU4C7cNsm5fsdjzNEQkU643Zb/AEbYVt/6ZUk3CESkAbACV/F2U9XdPodkTK14wwoLgB64YYUtPocUdWx4IQi8xeM/xy2zsdUMJpLcDAwCfmkJNzis0g0iEbkLmAKMVtVX/Y7HmMMRkSxgFTAPuNiGFYLDkm4QeasZPgLa4YYZrOG5CUve7+qHQAfcsIL9rgaJDS8Ekbea4edAY2Cat3PNmHB0O+40lGst4QaXJd0gU9VPgd8CFwJX+xyOMT8hIn1xq23+YsNgwWfDCyHgdR97F+gHnKyqn/kckjHA9y0bVwENgJ6qutffiKKfVbohoKoVuGGGEmCmN35mTDh4FOgEjLeEGxqWdENEVb/BHdt+Ku6jnDG+EpGLcUNeD6rqIr/jiRU2vBBiIvIcMB44y37RjV9EpC2umc1moL81Jg8dS7oh5o2hrQSSceO7O30OycQYEYkH/gnkAL1sjiG04v0OINaoaoGIjMWt331WRC60RegmGDa1GNQAmOP988JOO94/uKnFoAaXJZ346V+Kt3dOIjChSMst4YaYjen6QFVXAr8BhuO6ORlTr7yEOxc4w/tv7qYWg1LeKdn1wazi7Z0vSmxRvq75gMu955kQsuEFn3gbJeYAQ4HTVHWZzyGZKLKpxaC3ccm2ofdQ0Y6Kg6Uj9qxqnBKIl78ddzKNJK4IWNRpx/vD/Is09lil6xNvSOFKYBvwiog08zkkE8XKVRv+ev+GJvu0XB5rnEkjifM7pJhlSddH3skSo4FWwPN2hLupRxcCS4AigMcKv+Kj0nzuSelIZnwjvMeXeM8zIWR/5D5T1aW4w/9+hhvnNabOOu14/yBwCXBw8cE9/LHoa0YmHs/opBMrn3IQuNh7ngkhG9MNA9747ixgFHCOqv7T55BMhKucSPu6vHjQJXs/SWoRaMBfj+tJwx+GFSor3Z9Z4g0tq3TDgDe+ew2QB8wWkZN8DslEvjlFWj7whn3rk8pQnmiSVTXhgptgG8gPS8
pMiFjSDROquh+4GNd45DURaXiEbzGmRhWq/Lbg8wbryw/wf4270D6uYRGQjzfGa/xjSTeMqOpG4AqgN/An679rjlW3XR/Of6NkR9yNDduWDm7QrHIooQ0/TK7ZRJpPbEw3DInIPcDdwGRVfdTfaEykEZEhwPwAvLU+7bSEOHfv/n5HGofsUvMt0BhlSTcMeUvHXsVVIReo6jyfQzIRQkTSgY+BrcAAVd3nc0jmEJZ0w5SINAI+wJ1Z1VdV83wOyYQ5EWkK/At3CnUfVf3C55BMNWxMN0yp6gFgBFAMvCkiaT6HZMKY1zlsFtAZGGkJN3xZ0g1jqvoVbkVDW+BvIpLkc0gmDHkTro/h+nhcb32aw5sl3TCnqh/ijvoZiGsFaf/PzKFuAa7FnQDxlN/BmMOzfroRQFVni0h74AHgC+BOfyMy4UJERgMPAS8Dd/gcjqkFm0iLEN5HyCeBicC1qjrN55CMz0RkEO6U6RXAEFUt9jkkUwuWdCOIN1nyOnA+MFpVX/M5JOMTEekOvA9sBwbasU+Rw5JuhBGRZOA9oBdwnk2axB5vqOlDoAK3FvcrfyMyR8OSbgTyGp4vAVoDZ6rqKp9DMiEiIsfjKtzjgUGqus7nkMxRspnwCKSqu4HzgL3AOyKS5W9EJhREJBV4B7eEcLgl3MhkSTdCqerXwNlAOfAPEenkc0gmiESkMfA2kAVcpKpLfA7JHCNLuhFMVT/DJd5EXOJt63NIJgi8cfw3gFOAS1X1HZ9DMnVgSTfCeR8xzwVSgYWWeKOLl3DfxJ3sO05Vrel4hLOkGwVUdQUu8TbHEm/UqJJwBwPjVfUln0My9cCSbpRQ1Y+Bc4A0XOK1I38imNdlbi5wJi7hvuhvRKa+WNKNIt7JwpWJ932vt6qJMCJyHDAfN6Twc0u40cXW6UYhEcnBLS0qx50uvNbnkEwtiUgLXMLtBlymqq/6HJKpZ1bpRiFvs8QgoAxYJCJ9fQ7J1IKItAEW45aFXWgJNzpZ0o1S3kkTA4HdwD9F5AKfQzKHISLdgI9wuwzPU9W3fQ7JBIkl3SimqpuB04BcYI6IXONvRKY6InI6bmtvHG5r72KfQzJBZEk3yqnqt7glR+8BT4nIFDvaPXyIyBjc+Pt2oL+qrvY5JBNklnRjgKoWAMOB54C7gFki0tDfqGKbOHfjzjVbhmvP+KXPYZkQsNULMcSrcH+DO4FiOW6yZpu/UcUe74b3LDAWeB6YpKol/kZlQsWSbgwSkYuAmbguZaNU9SNfA4ohItIOeA3IAW4Hfq/2RxhTbHghBqnq34D+uOPdF4nIdTbOG3wicg7uaJ1OuE8ZD1rCjT2WdGOUqq7Bda16F3gCeF5EUvyNKjqJSJyI3AnMA7YBp6rqmz6HZXxiSTeGqeoe3ATbPcAVwAoROdnPmKKNiLTErU64Dzdp1s9ryWlilCXdGKeqFao6BTgLSAE+FpFfioj9btSRiJwPrMYN5VwNXKGqB/yNyvjN/rAMAKq6EDgZN9zwB9wxQNap7BiISGMReQr4O/AtcIqqPmvjtwYs6ZoqVHUHbrhhEtAPWCsiV9okW+2JyGBgDa6y/T0u4eb6G5UJJ5Z0zY+o82egB/AJbj3pe9Ym8vBEJE1EngX+iWs0NFBV/8vW35pDWdI11VLVL3Dbh6/DrXJYKyJ3ikiSv5GFFxEJiMg4YD0wHrfxpKeqfuhvZCZc2eYIc0Qi0go3zjsK+DdwC/C3WB+jFJE+uPelH/AxMNFbimdMjazSNUekqltVdTTuVIpC3I6qf4jIqf5G5g8RaS8iM3CJtj1wJTDAEq6pDUu6ptZU9T3cCocbge7AUhF5TUS6+hpYiIjIiSLyOLARGA08CGSo6nRVrfA3OhMpbHjBHBMRaQz8CrgVaIyrfn/nnUwcVUSkPa5R0FVAPPAMcJ+qbvEzLhOZLOmaOhGRNGAyrvptiuvb+wfgbVUt9zO2uvLGbH+J6wZWAcwAHlDVz30NzEQ0S7qmXohIE+
BaXPXbEjfh9idghtdIPSJ4/SdGATfgVm0UAE8B/6eq3/gZm4kOlnRNvRKRBOBiXNI6HXci8TzgBWBuXbbByhRpAMzx/nmh3q0Hq3vsGGKOxy2PGweMBJJxS8D+CLygqvuONWZjDmVJ1wSNiGTj1q5egTtwsQjX/OV1YN7RVMBecp2LO2wTYAlwCW4suepjP6tN4vUq88G4G8RwoBmQD7yMu0EsifUlcSY4LOmaoBOROFzVewlwEdDG+9I64B+4ZLkM+KqmRCdT5G3gDKDymKEi4CDQ4JDHFundOqyaGFoAp+KazwwB+uAOgtyLS+av48ahi479JzXmyCzpmpDy+jj0As7GJb+B/JA0v8P1LVgP5OHGhbcAW7ibF5EfJd0fU+AgRexkOU8xBWgHZHr/dcetpwU33LEMl+z/gatoS+v3pzSmZpZ0ja9EpAEuKfbFVZ9dcYny0IbqFSRQQRJxJPBDA55yXL1bgltf8GMHcWtqc3GJdimw0juo0xhfWNI1YcerhlvhqtXW3n9pxJNGV66mnAbfPzkON8CQQAnLuY+DbMZVx18DmyN92ZqJPpZ0TUQ4ZCKtuiGGIo5iIs0Yv9g2YBMp5lBzwsV7fCA/LB8zJizF+x2AMceoutULxoQ9q3RNpLgQN3xQxA9DCW2qeexCvwI0pjZsTNdEjGDtSDMmlCzpGmNMCNnwgjHGhJAlXWOMCSFLusYYE0KWdI0xJoQs6RpjTAhZ0jXGmBCypGuMMSFkSdcYY0LIkq4xxoSQJV1jjAmh/w8gs1s2uBXlHAAAAABJRU5ErkJggg==\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"f = lambda x: x**2\n",
"fprime = lambda x: 2*x\n",
"\n",
"x = np.linspace(-1, 1, 100)\n",
"plt.plot(x, f(x), color='black')\n",
"x_init = 0.5\n",
"x_min = newton_solve(f, fprime, x_init)\n",
"plt.scatter([x_init], [f(x_init)], marker='X',s=100, c='crimson', label='Initial guess')\n",
"plt.scatter([x_min], [f(x_min)], marker='X', s=100, c='green', label='Minimum found by Newton solver')\n",
"plt.legend()\n",
"plt.axis('off')"
]
},
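{
"cell_type": "markdown",
"id": "note-parabola-iterations",
"metadata": {},
"source": [
"To see what the solver is doing here, the Newton update for $f(x) = x^2$ can be worked out by hand (a back-of-the-envelope sketch using the default `tol=1e-5`):\n",
"\n",
"$$x_{n+1} = x_n - \\frac{x_n^2}{2 x_n} = \\frac{x_n}{2}$$\n",
"\n",
"Each step halves the guess. Starting from $x_0 = 0.5$, the residual $x_n^2$ first drops below $10^{-5}$ at $n = 8$, where $x_8 = 0.5 / 2^8 \\approx 0.00195$ and $x_8^2 \\approx 3.8 \\times 10^{-6}$. So the green marker sits very close to, but not exactly at, zero."
]
},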
{
"cell_type": "markdown",
"id": "12279094",
"metadata": {},
"source": [
"## Now let's test the solvers on a neural net\n",
"\n",
"To make this work, we'll need to define a neural network which is likely not useful in any real setting. I'm calling it `TroughLayer` because it essentially represents a \"trough\": a low point in the middle of the domain surrounded by high walls. You'll see what I mean in a second. This is a carefully designed network with two sigmoid neurons whose weights and biases were chosen so that the output is what I want.\n",
"\n",
"Anyhow, despite being contrived, we can apply the `jacobian` function to the network to get an `fprime` that we can put into the Newton solver. As you will see, we find a good minimum. There was some tuning of the tolerance here; if you set it too low you will diverge."
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "0e529714",
"metadata": {},
"outputs": [],
"source": [
"class TroughLayer(nn.Module):\n",
"\n",
"    def __init__(self, activation=nn.Sigmoid):\n",
"        super().__init__()\n",
"        self.layer = nn.Linear(1, 2)\n",
"        self.activation = activation()\n",
"        self.init_parameters()\n",
"\n",
"    def forward(self, x):\n",
"        x = self.layer(x)\n",
"        x = self.activation(x)\n",
"        # Sum the two neurons into a single scalar output\n",
"        x = torch.sum(x)\n",
"        return x\n",
"\n",
"    def init_parameters(self):\n",
"        # Hand-picked weights and biases that produce the trough shape\n",
"        with torch.no_grad():\n",
"            self.layer.weight = nn.Parameter(torch.tensor([[1.5], [-1.5]]))\n",
"            self.layer.bias[:] = torch.tensor([-12.0, -12.0])"
]
},
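{
"cell_type": "markdown",
"id": "note-trough-function",
"metadata": {},
"source": [
"Written out, and assuming I have read the weights and biases in `init_parameters` correctly, the network above is\n",
"\n",
"$$M(x) = \\sigma(1.5x - 12) + \\sigma(-1.5x - 12), \\qquad \\sigma(z) = \\frac{1}{1 + e^{-z}}$$\n",
"\n",
"Each sigmoid crosses $0.5$ where its argument is zero, at $x = 8$ and $x = -8$, which gives the two walls. In between, both sigmoids are nearly zero; by symmetry the lowest point is at $x = 0$, where $M(0) = 2\\sigma(-12) \\approx 1.2 \\times 10^{-5}$. The derivative that Newton's method needs is\n",
"\n",
"$$\\frac{dM}{dx} = 1.5\\,\\sigma'(1.5x - 12) - 1.5\\,\\sigma'(-1.5x - 12), \\qquad \\sigma'(z) = \\sigma(z)(1 - \\sigma(z))$$\n",
"\n",
"Since $M$ never quite reaches zero, a tolerance below roughly $1.2 \\times 10^{-5}$ can never be satisfied, which is one plausible reason a too-small `tol` makes the solver wander off."
]
},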
{
"cell_type": "code",
"execution_count": 10,
"id": "848687f3",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/bzq/miniconda3/envs/all/lib/python3.7/site-packages/ipykernel_launcher.py:5: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
" \"\"\"\n",
"/home/bzq/miniconda3/envs/all/lib/python3.7/site-packages/numpy/ma/core.py:2830: FutureWarning: The input object of type 'Tensor' is an array-like implementing one of the corresponding protocols (`__array__`, `__array_interface__` or `__array_struct__`); but not a sequence (or 0-D). In the future, this object will be coerced as if it was first converted using `np.array(obj)`. To retain the old behaviour, you have to either modify the type 'Tensor', or assign to an empty array created with `np.empty(correct_shape, dtype=object)`.\n",
" order=order, subok=True, ndmin=ndmin)\n"
]
},
{
"data": {
"text/plain": [
"(-21.994999980926515, 21.89499959945679, -0.05, 1.5)"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"dtype = torch.float32\n",
"v = torch.tensor(np.arange(-20.0, 20.0, step=0.1), dtype=dtype)\n",
"\n",
"m = TroughLayer()\n",
"\n",
"model_v = np.hstack([m(torch.tensor([vv])).detach().numpy().flatten() for vv in v])\n",
"\n",
"# Take the derivative of the model\n",
"dm_dx = partial(jacobian, m)\n",
"# Easier application of the derivative of the model\n",
"dm_dx_flat = lambda x: dm_dx(x)[0]\n",
"\n",
"x_init = torch.tensor([-8.5])\n",
"x_min = newton_solve(m, dm_dx_flat, x_init, tol=1e-4) # <--- NOTE: Have to tune tolerance here to get good performance!\n",
"\n",
"plt.plot(v, model_v, color='black')\n",
"plt.scatter([x_init], \n",
" [m(x_init).detach().numpy()], marker='X',s=100, c='crimson', label='Initial guess')\n",
"plt.scatter([x_min.detach().numpy()], \n",
" [m(x_min).detach().numpy()], marker='X', s=100, c='green', label='Minimum found by Newton solver')\n",
"plt.gca().set_ylim([-0.05, 1.5])\n",
"plt.legend()\n",
"plt.axis('off')"
]
},
{
"cell_type": "markdown",
"id": "fa0bb775-40ea-4bf1-b506-fdea127c4f03",
"metadata": {},
"source": [
"## Time for the linear reservoir!\n",
"\n",
"It's the hydrologist's favorite model! We can pretty easily solve this one analytically, so let's use it to make sure that our solvers are capable of producing good solutions. I'll leave solving the ODE analytically as an exercise for the reader, but it's a pretty simple one. As you can see, we can solve this equation pretty well numerically. Sure, there's some discrepancy, but the point here is to show an end-to-end solution rather than to fine-tune each piece. I'll call this good enough. Let's move on."
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "e335fd55-46ca-4d23-b755-d822a9b6ce67",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/bzq/miniconda3/envs/all/lib/python3.7/site-packages/ipykernel_launcher.py:5: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
" \"\"\"\n"
]
},
{
"data": {
"text/plain": [
"Text(0, 0.5, 'storage')"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# Reservoir constant\n",
"k = -0.1\n",
"\n",
"# Initial value\n",
"S0 = 1.0\n",
"N_time = 50\n",
"\n",
"# Analytic solution\n",
"t_range = np.arange(N_time)\n",
"S_analytic = S0 * np.exp(k * t_range)\n",
"\n",
"# Format the ODE for the Newton solver\n",
"def f(x, S, k):\n",
"    return x - S - k*x\n",
"\n",
"# Derivative of the equation above\n",
"def fprime(x, S, k):\n",
"    return 1 - k\n",
"\n",
"# Start with the initial value\n",
"S_newt = [S0]\n",
"S_newt_ad = [S0]\n",
"\n",
"for i in range(N_time):\n",
"    # Partials are getting rid of the initial condition and parameter args\n",
"    fprime_Sk = partial(fprime, S=S_newt[-1], k=k)\n",
"    f_Sk = partial(f, S=S_newt[-1], k=k)\n",
"    # Calculate the update\n",
"    S_newt.append(newton_solve(\n",
"        f_Sk, # Function with conditions at current time\n",
"        fprime_Sk, # Derivative with conditions at current time\n",
"        (1+k)*S_newt[-1]) # Initial guess - I just made this one up\n",
"    )\n",
"    \n",
"    # Now do the same, but with an autodiff calculated derivative\n",
"    fprime_Sk_ad = partial(jacobian, f_Sk)\n",
"    S_newt_ad.append(newton_solve(f_Sk, fprime_Sk_ad, torch.tensor((1+k)*S_newt[-1])))\n",
"\n",
"plt.plot(S_analytic, color='black', linewidth=3, label='Analytic')\n",
"plt.plot(S_newt, color='crimson', linestyle='--', linewidth=2, label='Implicit Euler (Hand calculated derivative)')\n",
"plt.plot(S_newt_ad, color='orange', linestyle=':', linewidth=2, label='Implicit Euler (Autodiff calculated derivative)')\n",
"plt.legend()\n",
"plt.xlabel('time')\n",
"plt.ylabel('storage')"
]
},
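{
"cell_type": "markdown",
"id": "b7c1e2a4-implicit-euler-closed-form-note",
"metadata": {},
"source": [
"The discrepancy in the plot above is what you would expect from a first-order method like implicit Euler, not a sign that the Newton solver failed. For the linear reservoir the implicit step $x - S - kx = 0$ can be solved by hand as $x = S/(1-k)$, so $n$ steps give $S_0/(1-k)^n$, compared with the analytic $S_0 e^{kn}$. Here is a minimal sketch of that comparison in plain NumPy. It restates `k`, `S0`, and `N_time` with the values from the cell above and assumes a unit timestep; the names `S_closed`, `S_exact`, and `max_gap` are only illustrative."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c3d9f0b6-implicit-euler-closed-form-code",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# Same values as the linear reservoir cell above\n",
"k, S0, N_time = -0.1, 1.0, 50\n",
"n = np.arange(N_time + 1)\n",
"\n",
"# Implicit Euler with unit timestep: x - S - k*x = 0  =>  x = S / (1 - k)\n",
"S_closed = S0 / (1.0 - k) ** n\n",
"# Exact solution of dS/dt = k*S\n",
"S_exact = S0 * np.exp(k * n)\n",
"\n",
"# Largest gap between the two curves over the simulated period\n",
"max_gap = np.max(np.abs(S_closed - S_exact))\n",
"print(f'max |implicit Euler - analytic| = {max_gap:.4f}')"
]
},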
{
"cell_type": "markdown",
"id": "fcd38efc",
"metadata": {
"tags": []
},
"source": [
"## Now the nonlinear reservoir!\n",
"\n",
"Of course, a linear reservoir is too easy. Let's change the ODE so that the conductivity term, $K$, now depends on the current storage. I'll just make a nice little function here: if storage is low the reservoir starts filling up, and if storage is high it starts draining. Steady state is at a nice even value of 50."
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "b08f7e01",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Text(0, 0.5, 'Reservoir conductivity (K)')"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"dtype = torch.float32\n",
"v = torch.tensor(np.arange(0.0, 100.0, step=0.1), dtype=dtype)\n",
"\n",
"def nlres(t, x, k):\n",
"    return k(x)*x\n",
"\n",
"def kx(x, b=50, s=-0.1):\n",
"    return s * torch.tanh((x-b)/10)\n",
"\n",
"plt.plot(v, kx(v))\n",
"plt.axhline(0, color='grey', linestyle='--')\n",
"plt.xlabel('Storage (S)')\n",
"plt.ylabel('Reservoir conductivity (K)')"
]
},
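{
"cell_type": "markdown",
"id": "d4e8a1c7-steady-state-check-note",
"metadata": {},
"source": [
"A quick check of the fill/drain behavior, as a minimal sketch: it restates `kx` from the cell above and evaluates the right-hand side $K(S)S$ of the ODE at three storage values picked only for illustration. The result should be positive below 50 (filling), zero at 50 (steady state), and negative above 50 (draining)."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e5f9b2d8-steady-state-check-code",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"# Same conductivity function as in the cell above\n",
"def kx(x, b=50, s=-0.1):\n",
"    return s * torch.tanh((x - b) / 10)\n",
"\n",
"# Illustrative storage values: below, at, and above the steady state\n",
"S = torch.tensor([10.0, 50.0, 90.0])\n",
"dS_dt = kx(S) * S  # right-hand side of dS/dt = K(S) * S\n",
"print(dS_dt)"
]
},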
{
"cell_type": "markdown",
"id": "f641618c-0334-4693-8f4b-dcbb31fc82a5",
"metadata": {},
"source": [
"## Solving the nonlinear reservoir\n",
"\n",
"Given all of our machinery developed so far, we can now solve the nonlinear version for all sorts of initial conditions. Note that I am using autodiff's `jacobian` to get the derivative that the Newton solver needs, so the derivative of the conductivity function $K(S)$ never has to be written out by hand."
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "b793ad6d",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/bzq/miniconda3/envs/all/lib/python3.7/site-packages/ipykernel_launcher.py:5: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
" \"\"\"\n",
"/home/bzq/miniconda3/envs/all/lib/python3.7/site-packages/ipykernel_launcher.py:13: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
" del sys.path[0]\n"
]
},
{
"data": {
"text/plain": [
"Text(0, 0.5, 'storage')"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"def f(x, S, k):\n",
"    return x - S - k(x)*x\n",
"\n",
"def fprime(x, S, k):\n",
"    return 1 - jacobian(k, x) * x - k(x)\n",
"\n",
"def my_odeint(f, fprime, S_newt, kx):\n",
"    N_time = 30\n",
"    for i in range(N_time):\n",
"        f_Sk = partial(f, S=S_newt[-1], k=kx)\n",
"        #fprime_Sk = partial(fprime, S=S_newt[-1], k=kx) # <---- \"hand done\" derivative function\n",
"        fprime_Sk = partial(jacobian, f_Sk) # <---- autodiff derivative function\n",
"        # The initial guess reuses the global k = -0.1 from the linear example\n",
"        S_newt.append(newton_solve(f_Sk, fprime_Sk, torch.tensor((1+k)*S_newt[-1])))\n",
"    return S_newt\n",
"\n",
"all_S_newt = []\n",
"for S0 in [90, 70, 50, 30, 10]:\n",
"    all_S_newt.append(my_odeint(f, fprime, [S0], kx))\n",
"\n",
"plt.plot(all_S_newt[0], color='red', label='$S_{0}$=90')\n",
"plt.plot(all_S_newt[1], color='orange', label='$S_{0}$=70')\n",
"plt.plot(all_S_newt[2], color='gold', label='$S_{0}$=50')\n",
"plt.plot(all_S_newt[3], color='green', label='$S_{0}$=30')\n",
"plt.plot(all_S_newt[4], color='blue', label='$S_{0}$=10')\n",
"plt.legend()\n",
"plt.xlabel('time')\n",
"plt.ylabel('storage')"
]
},
{
"cell_type": "markdown",
"id": "5948bb88",
"metadata": {},
"source": [
"## Getting to the good stuff: A reservoir with a neural network for $K(S)$\n",
"\n",
"Now imagine you're actually a hydrologist - you can't directly measure the conductivity of the \"reservoir\" but you can measure storage levels. If you're interested in determining $K(S)$ from data you have all sorts of avenues, but the one we're interested in is a neural network. In this case, imagine we \"know\" what the dynamics look like (aka the ODE defining the system), but we do not know the functional form for $K(S)$. Our neural network, then, will solve the dynamics and, during training, update the weights of a sub-network that represents the conductivity. Once trained, we can pull out that sub-network and look at what $K(S)$ was determined to be from the data.\n",
"\n",
"Below is the network in question. Note it's just a simple MLP with a single `width` hyperparameter. You might modify this to be a more complex structure, but for simplicity let's give this a go."
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "79136a9c",
"metadata": {},
"outputs": [],
"source": [
"class NeuralReservoir(nn.Module):\n",
"    \n",
"    def __init__(self, width):\n",
"        super().__init__()\n",
"        # Parameterize K(S) as a neural network\n",
"        self.K = nn.Sequential(\n",
"            nn.Linear(1, width),\n",
"            nn.Tanh(),\n",
"            nn.Linear(width, width),\n",
"            nn.Tanh(),\n",
"            nn.Linear(width, 1),\n",
"            nn.Tanh())\n",
"        self.dK_dx = partial(jacobian, self.K)\n",
"    \n",
"    def forward(self, S0):\n",
"        # Reservoir equation (aka the \"dynamics\")\n",
"        def f(x, S, k):\n",
"            return x - S - k(x)*x\n",
"        \n",
"        f_Sk = partial(f, S=S0, k=self.K)\n",
"        fprime_Sk = partial(jacobian, f_Sk) \n",
"        S1_guess = torch.tensor(0.95*S0)\n",
"        return newton_solve(f_Sk, fprime_Sk, S1_guess)\n",
"    "
]
},
{
"cell_type": "markdown",
"id": "90ce48b4-81ab-47b9-8fca-87be9f6db02f",
"metadata": {},
"source": [
"## Training data\n",
"\n",
"Before we can train the network we need some data. This is where we use our previous solution to generate synthetic data, so we can see how well the network reconstructs the known conductivity function. I'll just run a few timesteps for a bunch of different initial conditions."
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "4d024f8b",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/bzq/miniconda3/envs/all/lib/python3.7/site-packages/ipykernel_launcher.py:5: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
" \"\"\"\n",
"/home/bzq/miniconda3/envs/all/lib/python3.7/site-packages/ipykernel_launcher.py:13: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
" del sys.path[0]\n"
]
}
],
"source": [
"all_S_newt = []\n",
"for S0 in [0.1, 1, 2, 5, 10, 15, 20, 30, 40, 60, 70, 80, 85, 90, 92, 95, 97, 100, 110, 130, 150, 200]:\n",
"    all_S_newt.append(my_odeint(f, fprime, [S0], kx))"
]
},
{
"cell_type": "markdown",
"id": "d3be1796-e3e2-4afb-a9a3-8a109a763de4",
"metadata": {},
"source": [
"## A standard epoch function\n",
"\n",
"As usual, let's define this and get it out of the way."
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "adbb5903",
"metadata": {},
"outputs": [],
"source": [
"def epoch(X, y, model, loss_fun, device=device, opt=None, monitor=None):\n",
"    total_loss, total_err, total_monitor = 0.,0.,0.\n",
"    model.eval() if opt is None else model.train()\n",
"    n_iter = X.shape[0]\n",
"    for i in tqdm(range(n_iter), leave=False):\n",
"        Xd, yd = X[i].to(device), y[i].to(device)\n",
"        if opt:\n",
"            opt.zero_grad()\n",
"        yp = model(Xd)\n",
"        # The solver output picks up extra singleton dims; match the target shape\n",
"        loss = loss_fun(yp.reshape(yd.shape), yd)\n",
"        if opt:\n",
"            loss.backward(retain_graph=True)\n",
"            # Skip the update if any gradient is NaN\n",
"            if sum(torch.sum(torch.isnan(p.grad)) for p in model.parameters()) == 0:\n",
"                opt.step()\n",
"        # One sample per iteration, so accumulate the plain per-sample loss\n",
"        total_loss += loss.item()\n",
"        if monitor is not None:\n",
"            total_monitor += monitor(model)\n",
"    return total_loss / len(X), total_monitor / len(X)"
]
},
{
"cell_type": "markdown",
"id": "9c075d39-952f-4785-91d6-006e266af400",
"metadata": {},
"source": [
"## Split out the input/output data\n",
"\n",
"Our network's training task is to predict the storage one timestep later, given its initial storage. We have 660 samples for training (22 initial conditions with 30 timesteps each), but you could add more by adding more `S0` values under the `Training data` heading or by increasing the number of timesteps in the `my_odeint` function."
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "fa5179cf-eea8-4735-819c-35c15f8e7f97",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([660, 1])"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"train_X = []\n",
"train_Y = []\n",
"for ss in all_S_newt:\n",
"    train_X.append(ss[0:-1])\n",
"    train_Y.append(ss[1:])\n",
"\n",
"train_X = torch.tensor(np.hstack(train_X).reshape(-1, 1), dtype=torch.float32)\n",
"train_Y = torch.tensor(np.hstack(train_Y).reshape(-1, 1), dtype=torch.float32)\n",
"\n",
"train_X.shape"
]
},
{
"cell_type": "markdown",
"id": "73f78dfa-7154-4841-b3a0-16efb627b8af",
"metadata": {},
"source": [
"## Let's train!\n",
"\n",
"I'll set up the `NeuralReservoir` with some user-defined width and train for a user-defined number of epochs. Feel free to play with these. Following training we can look at the loss curve to see if we have learned anything..."
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "242c9b50",
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "6d3108395b5e4dfbbc7a14b3410cf424",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/30 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/660 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/bzq/miniconda3/envs/all/lib/python3.7/site-packages/ipykernel_launcher.py:22: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
"/home/bzq/miniconda3/envs/all/lib/python3.7/site-packages/ipykernel_launcher.py:5: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
" \"\"\"\n",
"/home/bzq/miniconda3/envs/all/lib/python3.7/site-packages/torch/nn/modules/loss.py:528: UserWarning: Using a target size (torch.Size([1])) that is different to the input size (torch.Size([1, 1, 1, 1])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size.\n",
" return F.mse_loss(input, target, reduction=self.reduction)\n",
"/home/bzq/miniconda3/envs/all/lib/python3.7/site-packages/ipykernel_launcher.py:13: DeprecationWarning: Calling np.sum(generator) is deprecated, and in the future will give a different result. Use np.sum(np.fromiter(generator)) or the python sum builtin instead.\n",
" del sys.path[0]\n",
"/home/bzq/miniconda3/envs/all/lib/python3.7/site-packages/torch/nn/modules/loss.py:528: UserWarning: Using a target size (torch.Size([1])) that is different to the input size (torch.Size([1, 1])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size.\n",
" return F.mse_loss(input, target, reduction=self.reduction)\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/bzq/miniconda3/envs/all/lib/python3.7/site-packages/torch/nn/modules/loss.py:528: UserWarning: Using a target size (torch.Size([1])) that is different to the input size (torch.Size([1, 1, 1, 1, 1, 1, 1, 1])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size.\n",
" return F.mse_loss(input, target, reduction=self.reduction)\n"
]
}
],
"source": [
"max_epochs = 30\n",
"width = 16\n",
"\n",
"model = NeuralReservoir(width)\n",
"loss_fun = torch.nn.MSELoss()\n",
"opt = torch.optim.Adam(model.parameters())\n",
"\n",
"loss_history = []\n",
"for i in tqdm(range(max_epochs)):\n",
"    train_loss, train_mon = epoch(train_X, train_Y, model, loss_fun, opt=opt)\n",
"    loss_history.append(train_loss)"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "95e88bee",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Text(0, 0.5, 'MSE Loss')"
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"plt.plot(loss_history)\n",
"plt.xlabel('Epoch')\n",
"plt.ylabel('MSE Loss')"
]
},
{
"cell_type": "markdown",
"id": "ce8df5c0-3b5e-4b20-984a-352107e5cc8f",
"metadata": {},
"source": [
"## What did the network actually learn though?\n",
"\n",
"Of course, our network was evaluated on how well it predicts the next timestep's storage, but what we really wanted was $K(S)$. Lucky for us, we can just pull it out with `model.K` and start inputting storage values. Let's see how we did!\n",
"\n",
"That's the end of what I've got here - but hopefully it's been informative and you can see some other applications of this type of approach."
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "807244f3",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Text(0, 0.5, 'Reservoir constant')"
]
},
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"dtype = torch.float32\n",
"v = torch.tensor(np.arange(0.0, 100.0, step=0.1), dtype=dtype)\n",
"model_k = model.K\n",
"yhat = np.hstack([model_k(torch.tensor([vv])).detach().numpy().flatten() for vv in v])\n",
"plt.plot(v, kx(v), color='black', label='target')\n",
"plt.plot(v, yhat, color='crimson', label='neural net')\n",
"plt.legend()\n",
"plt.xlabel('Storage')\n",
"plt.ylabel('Reservoir constant')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e388eae4-7706-4a15-a872-15109649bdd2",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "all",
"language": "python",
"name": "all"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.11"
}
},
"nbformat": 4,
"nbformat_minor": 5
}