Last active
November 12, 2021 15:39
-
-
Save smsharma/6530a0cbd09b9ab8c57ceb7a9a7703f9 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "a11541a8-c5da-4f2b-8a44-aad993c2bc4d", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import functools\n", | |
"import sys\n", | |
"\n", | |
"from tqdm.notebook import tqdm\n", | |
"import numpy as np\n", | |
"import torch\n", | |
"import gpytorch\n", | |
"from gpytorch.models import ApproximateGP, PyroGP\n", | |
"import pyro\n", | |
"from pyro.nn import PyroParam, PyroModule, PyroSample\n", | |
"import matplotlib.pyplot as plt\n", | |
"\n", | |
"sys.path.append(\"../../fermi-gce-gp/notebooks/\")\n", | |
"\n", | |
"%matplotlib inline\n", | |
"%load_ext autoreload\n", | |
"%autoreload 2" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"id": "df09d0b4-7e29-47b0-afeb-506c61d65252", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import warnings\n", | |
"\n", | |
"import matplotlib\n", | |
"import matplotlib.pylab as pylab\n", | |
"\n", | |
"from plot_params import params\n", | |
"\n", | |
"# Silence matplotlib deprecation warnings. `matplotlib.cbook.mplDeprecation` is itself a\n", | |
"# deprecated alias, so filter on the public warning class instead.\n", | |
"warnings.filterwarnings(\"ignore\", category=matplotlib.MatplotlibDeprecationWarning)\n", | |
"\n", | |
"# Apply the shared plot style and keep the default colour cycle for later plots\n", | |
"pylab.rcParams.update(params)\n", | |
"cols_default = plt.rcParams['axes.prop_cycle'].by_key()['color']" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "443756ba-7f77-48f2-9f91-44f82c657e2b", | |
"metadata": {}, | |
"source": [ | |
"Define a toy \"observable\" `obs(T_nu, T_gamma)` that maps latent quantities (temperatures / Hubble rate) to observations (e.g., elemental abundances)." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"id": "ffdd403b-969d-4f48-9fde-3470e3d31f59", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def obs(T_nu, T_gamma):\n", | |
"    \"\"\" Toy observable mapping the latent (T_nu, T_gamma) to an observation.\n", | |
"\n", | |
"    Computes log(T_nu ** 4 * T_gamma + 0.1); the +0.1 offset keeps the log finite at T_nu = 0.\n", | |
"    The result is even in T_nu, so the sign of T_nu cannot be recovered from it.\n", | |
"    \"\"\"\n", | |
"    log_argument = T_nu ** 4 * T_gamma + 0.1\n", | |
"    return torch.log(log_argument)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "d7329a64-29f3-4491-934e-0739ad53cf0b", | |
"metadata": {}, | |
"source": [ | |
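"Sparse variational GP regression model for the latent function $T_\\nu(T_\\gamma)$, with Pyro `model`/`guide` methods for stochastic variational inference." | |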
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 20, | |
"id": "4f069eaf-845f-4a56-bc2d-bfc092c6e0b0", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"class GPRegressionBBN(ApproximateGP):\n", | |
"    \"\"\" Sparse variational GP over the latent function T_nu(x), fit with Pyro to noisy\n", | |
"    observations of `obs(T_nu(x), x)`.\n", | |
"    \"\"\"\n", | |
"\n", | |
"    def __init__(self, obs_sample, name_prefix=\"bbn\", num_inducing=30, inducing_range=(0.1, 20.), obs_scale=1.):\n", | |
"        \"\"\" Module for fitting a latent GP to observations\n", | |
"        :param obs_sample: Sample of observed values, one per input location passed to `model`\n", | |
"        :param name_prefix: Run tag, used as prefix for the Pyro site and module names\n", | |
"        :param num_inducing: Number of inducing points for sparse GP\n", | |
"        :param inducing_range: (low, high) bounds of the fixed, evenly spaced inducing points;\n", | |
"            should cover the inputs (the default matches the T_gamma grid of this notebook)\n", | |
"        :param obs_scale: Standard deviation of the Gaussian observation noise\n", | |
"        \"\"\"\n", | |
"\n", | |
"        self.name_prefix = name_prefix\n", | |
"        self.num_inducing = num_inducing\n", | |
"\n", | |
"        # Inducing locations are fixed (learn_inducing_locations=False below)\n", | |
"        inducing_points = torch.linspace(inducing_range[0], inducing_range[1], num_inducing)\n", | |
"\n", | |
"        # Initialize GP variational strategy\n", | |
"        variational_distribution = gpytorch.variational.CholeskyVariationalDistribution(num_inducing)\n", | |
"        variational_strategy = gpytorch.variational.VariationalStrategy(\n", | |
"            self, inducing_points, variational_distribution, learn_inducing_locations=False\n", | |
"        )\n", | |
"        super().__init__(variational_strategy)\n", | |
"\n", | |
"        # Specify GP mean and covariance structure\n", | |
"        self.mean_module = gpytorch.means.ZeroMean()\n", | |
"        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())\n", | |
"\n", | |
"        self.obs_sample = obs_sample\n", | |
"        self.obs_scale = obs_scale\n", | |
"\n", | |
"    def forward(self, x):\n", | |
"        \"\"\" GP prior at inputs x: zero mean, scaled RBF covariance\n", | |
"        \"\"\"\n", | |
"        mean = self.mean_module(x)\n", | |
"        covar = self.covar_module(x)\n", | |
"        return gpytorch.distributions.MultivariateNormal(mean, covar)\n", | |
"\n", | |
"    def guide(self, x):\n", | |
"        \"\"\" Variational guide: sample latent function values f(x) from q(f)\n", | |
"        :param x: Input locations (same as passed to `model`)\n", | |
"        \"\"\"\n", | |
"        # Get q(f) - variational (guide) distribution of latent function\n", | |
"        function_dist = self.pyro_guide(x)\n", | |
"\n", | |
"        # Use a plate here to mark conditional independencies\n", | |
"        with pyro.plate(self.name_prefix + \".data_plate\", dim=-1):\n", | |
"            # Sample from latent function distribution (site name must match `model`)\n", | |
"            pyro.sample(self.name_prefix + \".f(x)\", function_dist)\n", | |
"\n", | |
"    def model(self, x):\n", | |
"        \"\"\" Generative model: GP prior on T_nu(x), observed through `obs` with Gaussian noise\n", | |
"        :param x: Input locations (T_gamma values), aligned with `self.obs_sample`\n", | |
"        \"\"\"\n", | |
"        pyro.module(self.name_prefix + \".gp\", self)\n", | |
"\n", | |
"        # Get prior distribution of latent function. Called outside the data plate,\n", | |
"        # mirroring the `pyro_guide` call in `guide`, so model and guide share plate structure.\n", | |
"        function_dist = self.pyro_model(x)\n", | |
"\n", | |
"        with pyro.plate(self.name_prefix + \".data_plate\", dim=-1):\n", | |
"            # Latent function samples are used directly as T_nu (no link function applied)\n", | |
"            T_nu_samples = pyro.sample(self.name_prefix + \".f(x)\", function_dist)\n", | |
"\n", | |
"            # Predicted observations at the model inputs x (= T_gamma)\n", | |
"            obs_pred = obs(T_nu_samples, x)\n", | |
"\n", | |
"            # Gaussian likelihood of the observed sample around the prediction\n", | |
"            pyro.sample(self.name_prefix + \".y\",\n", | |
"                        pyro.distributions.Normal(loc=obs_pred, scale=self.obs_scale),\n", | |
"                        obs=self.obs_sample)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "1aa9eaee-b021-480e-b997-8697f6ba8e83", | |
"metadata": {}, | |
"source": [ | |
"Define the toy \"true\" latent function $T_\\nu(T_\\gamma)$ used to generate mock data." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 21, | |
"id": "e84e0e77-6906-43d0-bffe-6d0ef031d77e", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def T_nu(T_gamma):\n", | |
"    \"\"\" Ground-truth latent function used to generate the mock data:\n", | |
"    T_nu = log(T_gamma ** 3 + 2).\n", | |
"    \"\"\"\n", | |
"    shifted_cube = T_gamma ** 3 + 2.\n", | |
"    return shifted_cube.log()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 22, | |
"id": "596b478a-32bb-4ad5-8f62-e92d7acf9e8a", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"Text(0, 0.5, '$T_\\\\nu(T_\\\\gamma)$\\\\,[arb.]')" | |
] | |
}, | |
"execution_count": 22, | |
"metadata": {}, | |
"output_type": "execute_result" | |
}, | |
{ | |
"data": { | |
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAZEAAAFRCAYAAAC1yZDjAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAAt5ElEQVR4nO3deXwV9b3/8dc3C4GQ5SQhhIgiBCxYbaEQVFBRWRXFKirebmpvK1p7218XK1rbn9Heq4LW9nevrTXaattrLYu7VSugrYALm3VrUSHsBJKQnCQQQnKS7++PcxISsp3MOcmc5f18PHiUzJyZ83E68GbmO9/PGGstIiIiTiS4XYCIiEQvhYiIiDimEBEREccUIiIi4phCREREHEtyu4BwMcbUAclAudu1iIjEkFyg0Vqb2tlKEyuP+BpjGhMSEpLy8vJ6vW1tbS3p6emOv7uuro7U1E6Pb59u6/b2bh63ULd3+7iHcuzcrt3N465zrv+P3YEDB2hubvZZa5M7/YC1NiZ+Afvy8/OtE/7D4NykSZNc2dbt7d08bqFu7/ZxD+XYuV27m8dd55xzTo9dfn6+BfbZLv7u1ZiIiIg4phAJg4ULF7qybSRs7+Z3R/Nxd/O7o/m4hyqa/9vdPnZdiaUxkX35+fn5+/btc7ItsXIc+pOOm3M6ds7ouDnn9NidcMIJlJaWllprT+hsva5ERETEMYWIiIg4phAB7rjjDrdLiEo6bs7p2Dmj4+ZcXx07jYmIiEiXNCYiIiJ9RiEiIhLjmurqaKio6JN9x0zvLBER8bPWUr97N9UbN1GzaTO1//yIrKlTKfjRzWH/rpgKkbq6OgoLCwH/xJxInZwjIhJuTfVHqX3/Pao3bqJ64yYaysvara/Z/C62qQmTmBjU/oqLiykuLqbCfwXTZdMuDayLiESpo6X7qd64keqNG6l9/wOafY0dPpMybBiZkyaRMXEimRO/gEnq3bVDTwPrMXUlIiISy6zPx6EtW6jesBHv+g3U79nd4TMmKYn0008nc9IkMgsnkTJ8OMaYPqtJISIiEsF8hw5Rs2kz3vXrqd64iaa6wx0+MyA7h8zJhWQWFpI+/vMkDhrUb/UpREREIkz9vn1Uv7Me7/r1HPrnv7DNTcd9wpA2diyZZ0wmc3Ihg0aO7NOrje4oREREXGabmjj8ySd431mP9531nd6mShyUSsbEiXgmF5JROInkzEwXKu1IISIi4oLmo0epee89vG+9TfWGjTRWezt8JiUvj8zJk/GceQZpp59OQi8HxftD5FUkIhKjfLW1/kHxt9+mevNmmo8ePe4TgdtUZ07Gc8YZDBwxwrXbVMFSiIiI9KGGigq8b7+D9623qf3www7jGwlJyaRPmIDnrDPxTJ5McnaWS5U6oxAREQmz+tJSvG++RdWbb3H4k487rE8anEbmGZPxnHUmGV/4Qr8+TRVuChERkRBZazmyc2drcBzZuaPDZwbkDPFfbUw5i7TTTovI8Q0nYuO/QkSkn1lrqdtWgnfdOqrefIv6fXs7fGbg8OFkTZmCZ8pZpJ5ySsSPbzihEBERCZK1lrqtW6la9yZV69ZxdP/+Dp9JHTWKrKlT8UydwqARI1yosn8pREREutEaHGvX+YPjwIEOnxn8mbFkTZ1C1tSppOQPc6FK9yhERESOY63lyPbtVL6xpssrjrRx48g65xw8U6aQMjTXhSojQ0yFiFrBi0gojuzeTdUba6h8Y02nYxxp48aRde45ZE2dyoAhQ1yosP+oFbyISBCOlu6ncu0aKv++ptOnqtLGBoLj7NgPjs6oFbyIyHEaK6uoXLuWqjfWcOjjLR3Wp44eQ/a555B17jmkDB3qQoXRQyEiInGh6fBhqt56m8q//53a997H2uZ26weNGEH2tGlkTTuXgfn5LlUZfRQiIhKzmn0+ajZvpvL1v+F9Zz3NjQ3t1qcMzSP7PH9wpI4c6U6RUU4hIiIxxVrL4Y8/pvL1v1G5Zg
2+2tp265MyMsk+9xyyzz+PwWPHxuQEwP6kEBGRmHB0/34Ovv43Kl//G/Wl7R+wSUhJwXPWWeRccD4Z48f3+j3j0jUdSRGJWk2HD1O17k0OvvYatR991G6dMQmkT5hAzvTz8Zx5ZlQ3OYxkChERiSq2qYna9z+gYtVqvG+/TXND+3dypI4aRc706WRPmxZ1bdWjUUSHiDGm7WxBj7V2iWvFiIir6ktLObhqNQdXv0bDwYp265Kzssm54HyyLzhfA+T9LGJDxBhzS9vQMMZ4jDGLrbWL3KxLRPpPU309VWvXcXDVqg63qxKSkvFMOYucGdM1zuGiSD7qo9v+YK31GmM8LtUiIv2k5emqildXUbVmDU31R9qtH/yZsQyZMZ2s86aRNHiwS1VKi0gOkUJjzERr7Wa3CxGRvuerqeHg63+j4tVXObJrV7t1yZkecmZMJ2fG9Lhorx5NIjlEFgGbjDGLrLVLAuMjupUlEkOstf5B8r/+laq33sL6fK3rTEIimZMmkjNrFpmTC2PmTYCxJmL/X7HWrjLGTMIfJLcBk6y1XpfLEpEwaKyu5uCq1VT89dUOczpShuUzZPYscmZMZ0B2tksVSrAiNkSMMQXATCALWAxsM8bMstau6mqb2traLmef3nHHHRQVFfVFqSISBGsthz78iPKXX+5w1ZGQlIxn6lRy58wm7XOnaxa5C4qKirjzzju7Wp3e1YqIbQVvjHnYWntDm5+vBB4BRnV2RaJW8CKRyXf4MAdXr6b85Veo37On3bqBJ55I7oUXknPB+SRlZLhToHQrKlvBBwJjZdtl1toVba5OVrhSmIgE7fDWbZS/9BKVf3+j3YRAk5RE1tlnk3vhHNJOO01XHVEuIkOkG3pSSySCNTc0ULV2HWV/eYnDn3zcbl3KsHxyL7qQnBkzSM7UVUesiNQQWYX/1tXxVxyzNNlQJPI0lJdT/vIrlP/1VXw11a3LjUkg84zJ5M6dS8aE8ZiEBBerlL4QkSESmFi4yBhzC+ANLPYA97hWlIi0Y63l0D//SdkLL+J9621sc1PruuRMD0PmzCb3wjkMyM11sUrpaxEZIgDW2hJAvbJEIkxzQwOVb7xB2XMvULdje7t1aWPHkXvJxWSdPZWE5GSXKpT+FLEhIiKRpaGykvKXXqb85Vfa37JKSiJ72jSGzruEwWPGuFihuEEhIiLdqttWwoHnnqNyzZp2czuSs7IZevFchsyZTbLH416B4iqFiIh0YJubqd64kQPPPkftBx+0Wzf4M2PJu3QenrOnqhWJKERE5JjmhgYOvv46B559rt3EQGMS8EydSt5ll5I2bpyLFUqkUYiICL6aGspfepmyF/9CY7W3dXnioFSGzJ7F0HmXkJKX516BErEUIiJx7GhZGQeeeZaKlStpPnpsVvmAIUMYOm8eQ+bM1js7pFsKEZE4dGTHTvY/9RSVb6xpN78jdeQo8q64nKxzztF4hwRFZ4lIHDm0ZQv7ly3Hu2FDu+UZ48cz7IorSJ8wXr2spFdiKkTq6uooLCwEYOHChSxcuNDlikTcZ62l9h/vUbpsGbUffti63JgEss6eSt4V8zW/QzooLi6muLiYiooKgNSuPhexreB7S63gRdqz1lK9fj2lS5dx+NNPW5ebpCRypk9n2JVXMDA/38UKJRpEZSt4EXHONjdTte5NSpcu48jOHa3LE1IGknvRheRdfpneGChhoxARiRG2qYmqtev84bF7V+vyxNTBDJ13CXmXztOLnyTsFCIiUc42NVG5Zi2lS5e2myCYlJ5O3mWXkXvxXD2mK31GISISpWxzM1Vr17HvySfbhUdypoe8yy8jd+5FJA4a5GKFEg8UIiJRxjY3433rbfb96U8c2XXstlVypodhV8xnyEUXkTgwxcUKJZ4oRESihLWW6g0b2ffEE9SVlLQuT8rIZNgV88mdO1fhIf1OISISBWrf/4C9f/gjhz7e0rosKS3dHx6XXEziwIEuVifxTCEiEsEOf/ope/
/wR2r+8Y/WZYmpqeRddhl5l84jUQPm4jKFiEgEqt+7l71/fIKqdWtblyUMSGHovEsYdsV8ktLTXaxO5BiFiEgEaaispPTJP1Px6srWxogmKYncOXPIX7CA5OwslysUaU8hIhIBmurqOPDMs+x/5pk2LdkNOeefxwlf+TIpw4a5Wp9IVxQiIi6yPh8VK1ex74k/tXsZVOakSQy/7lpSR450rTaRYMRUiKiLr0SLlsd19zz2OPV7drcuHzzmFIZfdy0Z4z/vYnUi6uIrErHqduxgz6O/pea991qXDcgdyonXXkPWuedgEhJcrE6kPXXxFYkQjV4v+/73Cf+guW0G/M0R8xdcxdB5l5AwYIDLFYr0nkJEpI81+3yUv/Ai+/78Z5rq6gAwCYnkXnQh+V/6EsmZ6qwr0UshItKHqjduYvcjj1K/b2/rssxJkzjx37/OoBEjXKxMJDwUIiJ9oL60lD2PPNruXeYDhw/npG98g8zJhS5WJhJeChGRMGqqP8r+FSvY/9RTWJ8P8LcpOeHf/o3ceZeQkKQ/chJbdEaLhIG1lup31rOr+BEaystalw+ZMYPh115DcpZmmktsUoiIhOjo/v3sevgRqjceu3WVOnoMI751A2ljx7pYmUjfU4iIONTs83HgmWco/fMymhv8rUqSBqcx/JqvMWTObExiossVivQ9hYiIA7Uf/ZOdD/6q3WzzITNmMPzr15GcmeliZSL9SyEi0gu+Q4fY89jjVLz6auuyQSeNYMRNN5J++ukuVibiDoWISBCstVStXcfu4kdo9FYBkJA8gPwvXU3e5ZfrqSuJWzrzRXrQcPAgux76Dd533mldljFhAiNu+hYD8/NdrEzEfQoRkS5Ya6l4dSV7fve71nYlSenpnPTNb5J9wfkYY9wtUCQCxFSIqBW8hMvRAwfY+T8Ptuu0mz3tPE5a+E0NnEtcUCt4EQdsczPlL7/Cnscep/loPQADsnMY8e2b8Jwx2eXqRPqfWsGLBOloWRk7/9//UPP+sauPIbNnc+K/f52kwYNdrEwkcilEJO5Zazm4ajW7H3mUpiP+sY8BQ4Yw8rvfJeMLE9wtTiTCKUQkrjVWVbHzwV/hXb++ddmQWbM58Ru6+hAJRrchYoxZGuL+DWCBDdba+0Pcl0hYVb31NjsffBBfTQ0AyVnZjPzud8gsnORyZSLRo6crEWOtXRDql4QhjETCpunIEXY/8lsqVh6bdZ59zjmMuOlbJKWnu1iZSPTpKURKwvQ928O0H5GQHP7kE0ru+zlH95cC/necn3zTt8g+b5rLlYlEp25DxFp7azi+JFz7EXHKNjdz4Oln2Pu/T2Cb/C+LSv/c5xj1/e8xIDfX5epEopcG1iXmNVRWsuOBX7ROHDRJSQz/6lfJu/wyTEKCy9WJRLewhogx5nrgXmttTpj25wEWAl6gEvBaa1eFY98SH6o3bWL7A7/EV1MNQMqwfAoW/YjBY8a4XJlIbAj3lcgywjSOEgiQ5dbaWYGfC4BNgN4zKj1q9vnY979PsP+pp1qX5UyfzogbFpKY2mUHBxHppbCGiLW2Glgdpt3dBjzcZt8lxhg9eyk9aigvp2TxfRz6eAsACSkDOfnb3yLnggtcrkwk9jgOEWPMSGAJMCOwaBWwyFq7I/SyALiF4646rLXhelpMYlT1pk1s//kD+GprAUgdOYqCW29h4PDhLlcmEpschYgxZhSwHFjKsauFicByY8xVoQZJ4NYVQLYxZib+MZGJ1toloexXYpdtamLfk3+mdOky/PNbIffCCznp+m+SMGCAu8WJxDCnj6ZcYa0ttNbeZ61dHfh1n7V2MnBlGOoqwD+24rHWrggMpm82xjzc3Ua1tbUYYzr9VVRUFIayJBL5amr49M67KF26FLAkpKQw6oc/5ORv36QAEQlSUVFRp393lpaWAnQ5C9dRK3hjzAxrbadjH92t68X+ZwIrrbXmuOVVwChrrbeTbdQKPg4d3rqVbXffS0N5GQADTzyJ0b
ctYtCIES5XJhIb+qoVfGU36w463GdbJfhvYXWmANgchu+QKFexajW7fvVrmn2NAGSdfQ4j/893SBw0yOXKROKH0xCxxpiRx499BAbbs0MtCn9IecKwH4lBzT4fe377O8pefBEAk5DI8Guv8U8e1CtrRfpVjyFijNlAx7/QDTDKGFNy3LIs4PpQi7LWeo0xm40xnuNuXXmstboKiWON1TWULF5C7QfvA5CUkUHBLbeQMf7zLlcmEp+CuRIxwAK6vr10vO5udfXGPfhnqy8BMMZc2fJ7iU9Hduxk68/+k6NlBwBILShg9O0/JmXoUJcrE4lfwYTI9dbad/u8kuNYa1cYYxYaYxYSuBKy1i7q7zokMnjXr2f7fT+nqf4IANnTpnHyd75D4sAUlysTiW89hkhnAWKMyQisq+mLotp8d3Ff7l8in7WWA888y57HHqdl/sfwa77GsCuv1PiHSARwOrC+AngV0NsKpc80+3zseug3VLzqf3lUQspARv3wB2RNOcvlykSkhdMQedha+1RnK4wxE6y1/3Bekgj4Dh+m5J57W9u3D8gZwpg7fkrqqFEuVyYibTmdsb7NGDO9i3VXOy1GBOBoWRkf/2hRa4AMHnMK4x64XwEiEoGcXokswd/X6gu0b/1ugFH4O/CK9FrdthI+vfMuGqv8D/llTZnCyB/8QAPoIhEqlFbw19PxsV8D3BvCPiWOVW/eTMk9i1ufwMr74hc58d+/rrcPikQwpyGyqKvHfo0x94RQj8Spg6+9zo7//p/A+88NJ13/DfIuvdTtskSkB45CpIvHfkfhbwe/LdSiJL7sf/oZ9jz2GAAJScmMuvmHZJ091eWqRCQYYbtPYK3dHnhiqzBc++yturo6CgsLKSwspLhYU0winbWW3b/9XWuAJKYO5pSf3akAEYkAxcXFFBYWUlFRAdDlO6UdtYIHMMb8CJhJ+4aLHmCFtbbfB9bVCj66WJ+PnQ/+iorV/rcGJGdlc8pdRaSOHOlqXSLSXp+0gg8ECMCt+K88SvD3zCrA/5pckS41Nzay/f6fU/XmmwAMzD+BU352Jyl5eS5XJiK95XRg3WutfQTAGOMFbKAt/LvGmAnAP8JRnMSepvqjbLv7bmre9Q+rpY4cxSl3FZGcleVuYSLiiNMxkdYXT1lrt+O/rdUiHO8TkRjUVFfH1qKi1gBJO/VUPnPv3QoQkSjmNESMMWakMebmwM+Fxpjxgd/PCkNdEmN8hw/z6f8tovajjwDImDCBU+66k6TBg12uTERC4ShEAk9hXQWMCSy6FXjdGNNEeF6PKzHEV1vLJ7f/lEMfbwHAM3kyY376ExIHDnS5MhEJleMZ69ba+9r83ou/DcqowO0tESAQID/5KXUl/u44WVOnMupHN5OQFEqzBBGJFN1eiXTTZLFTXQVIb/cjseH4AMk+91wKFCAiMaWn21k3hOl7wrUfiRIdAmTaNEb94PsYBYhITOnpT/QsY8xS/I0Vnc1K9G8rccR3+DCf/PSO9gHy/e8pQERiULd/qq21elxXeqWpro6td9xJ3batgP8WlgJEJHapx7aETVP9Ubb+7D9bn8LKmjJFt7BEYpxCRMKi2edj2933UPvhhwBkFk5m1C0/UoCIxDiFiITMNjWx/f4HqHl3MwAZ48cz+rZFegpLJA7EVIioFXz/s9ay89cPUbVuLQBpY8cx+ie3kzBggMuViUgo+rwVfKRRK3h37Hn89+x/6ikABp08krH33k1SWprLVYlIuPTUCj6mrkSkfx14/vnWAEkZls8pdxUpQETiTK9uWhtjMoCr8TdZLAAy26z24n+vyFJglbW2Jkw1SgSqXLOG3Y/8FoDkrCw+c9edDMjWE+Ei8SaoEAm8P30x/gmHS4FFQKW1trrNZzLxt4GfCCwxxmQB91hr/xHuosVdNe+9z/YHfgFYEgelckrRHaTkD3O7LBFxQY8hYoy5Hv9LpxZ097lAoFQD24GnWrY1xsy01t4fjmLFfUd27GTb3X
djfT5MUhKjb7+N1IICt8sSEZd0GyLGmCuAZW2vOHrDWvuIMSbTGDPfWvu0owolYjRUVvLpXXfRVFcHwKjvfY+M8eN72EpEYllPbU+eCvULAgGkAIlyTfX1bPvZf9FQXg7AiddeS/Z501yuSkTcFrans4wxM9TyPTb5JxP+nMNbPwVgyKzZ5F0x3+WqRCQShG1KsbV2NbR7d0ilBtVjw94//BHvO+8A/tfajrjpRoxRc2YRCWOIABhj5uMfXAcYbYwpBKqALPyP/e4I5/dJ3zu4+jX2P+2/GznopBEU3Kp2JiJyTNj+NjDG/BVY1NXVR+ARYIkih7ZsYceDDwKQlJ7OmJ/+hKTBg12uSkQiSThnrI/u7vaV0ye8xB0N5eVs+697/I/yJiYx+tZbNRdERDroMUSMMQeNMRuMMfcYY24OzFrvzMOB21kS5ZobGth29700eqsAGHHjDaR//nMuVyUikSiY21nbrbWTe/qQtfa+wBNaIzX2Eb2stez69UOtT2INnTuX3AvnuFyViESqYG5nlQS7M2vtajcDRK3gQ1f+0stUrF4NQNqpp3Li9d90uSIRcUPYWsEbY35jrb0xzPWFnVrBh+7Qv/7Fxz++HevzkZyVzam/fEBNFUXiXDhawcfGC0ekW43V1ZTcu+RYT6wf36oAEZEeBRMiVxtjPjXGPGSMubybgXWg3WRDiRK2qYntP3+AhsqDAJz0jW+QNm6cy1WJSDQIJkQ2Au/if4/IU0BVD6FyVbiLlL5Vumw5Ne++C0D2ueeSe/FclysSkWgRzNNZm621t0LrhMFZwMzA/94AWGNMCbASWIX/ZVUSJWree499f3oSgIHDh3Pyf3xbLU1EJGjBhEjrTPPAhMEVgV/Hh8ps4EY0hhI1Gr1ett//AGBJGJBCwa2LSEzt8iEMEZEOggmRLq8sOgkVD/6rEYlwtrmZHb/4ZbsJhakjR7pblIhEnWDGRCYbY4J685C11ot/DEUiXNnzL1C9eTMA2dOmkTNzhssViUg0CiZEFgC399DypK1tIdbUKWPMw32x33h0eOs29vz+9wCk5OUx4qZvaRxERBzp8XaWtXYVgVtUxpgrjDErrbU13Xz+vjDWR+B7r8Q/7iIhaqqvZ/t997c2Vhz1o5vVmVdEHOtVK/hwvC63twLjLBImex97nPp9ewE44StfJm3sWJcrEpFo1u3trHBNHAxxPzPRYH1YVG/cRNlLLwGQftppDJt/ucsViUi062lMZFOgBXwwYyEdGGMyjTE343Cw3RgzEdjsZFtpz1dTw47//m8AEgelMvL738MkJrpclYhEu25vZwUe4b3NGHNvYE7Icmvtaz3t1BgzA//M9a3W2vtDqK/AWrtCt7RCY61l568forHK/zjvSQu/SUpenstViUgsCOrNhoEZ67fif2/6ssBLqv4aaH1yT+B/X21ZDnwB/6tyHQeIMWamtXZFb7apra3FGNPpr6KiIqelRL2qNWupWrcOAM9ZU8iZocd5RaS9oqKiTv/uLC0tBUjvarseW8F3uaExowAPkA1UAl5r7XZHO+u4bw9QGHgyrOXnTdba0d1so1bwnWj0evnopm/jq60lKSOD0379K5Iz9bp7EQlOT63ge/V0VlvhCowuLITWMRGAHCDbGHML/l5eGmgP0q7fFOOrrQX8s9IVICISTo5DpC9Za5e0/dkYUwBcefxy6V7V2nVUrVsLQNaUKWSdc47LFYlIrAlqTCQYxpjfGGPmh2t/EhpfTQ07H/oNAElp6Yz41o2alS4iYRe2EAEWA2cEBtxHhmungdnqi4ECY8ziNre4pBt7fvcYvppqAE66YSHJWVkuVyQisShst7MCYyS3BgbB7w28m/0fYdhva5dgCU7Ne+9TsXo1AJmFhWSfN83likQkVoXzSgTwd/K11t6I/z0j0s+aGxrY9atfA5CQksKIG2/QbSwR6TPhHBPZEJgzMt8Yk2GtvU9jJP2vdNly6kv9jzmf8OUva1KhiPSpHm
9nGWMOAiX4+1cdBIq76OK7DH+LklnAjcaYLPzzR54OX7nSnSO7d7P/KX+PzNSCAvIunedyRSIS64IZE9lurZ3c04fatIBf3bLMGPMFp4VJ71hr2fXQw/4W7ybB/670pIh8gltEYkgwt7NKnO7cWvuu022ld6rWrqX2g/cByJ17EYNPOcXlikQkHgQTIpV9XoWEpOnIEfY8+jsAkjIyOeGrX3G5IhGJF8GEiLPmWtJvSv+8lIbKgwCceN21JKWluVyRiMSLYELkamPMp4FOvZf39G6RcL3ISoJzZPceDjz3PABpY8eRM0OHX0T6TzAhshF4F7gaeAqo6iFUrgp3kcGqq6ujsLCQwsJCiouL3SqjX+159FFskw8wnHTjDZiEsE/9EZE4VFxcTGFhIRUVFQCpXX2ux1bwxph7A+8TIfBiqln4X1k7EyjAf7urBFiJ/zHgG6y1c8LxH9Eb8dgKvnrjJj69804AcufM4eT/+LbLFYlIrAlHK/jW3uGBNx22tiE5LlRmAzeiMZR+0ezzsfu3/sH0xNRUDaaLiCuCCZGCrlZ0Eioe/Fcj0scqXvkr9Xt2A5C/YAHJHo+7BYlIXArmBvpkY8z4YHZmrfXiH0ORPuQ7dIh9f/oTACl5eQzVzHQRcUkwIbIAuN0Yc3NPT2YFbAuxJulB6dJlrW8rPPG660hITna5IhGJVz3ezgq8irblXedXGGNWdtE7q+Xz93W1TkJ3tKyMshdfBCDts5/Fc/ZUlysSkXjWq+ZK1tqn+qoQCc6+J/6E9fkAOPHr16nNu4i4SpMKokjdjh0cfO11wP/O9LRx41yuSETinUIkiuz9/R8Bi0lIZPg1X3O7HBERhUi0qP3wQ6o3bgAgZ+YMBp54ossViYgoRKKCtZa9j/8BgIQBKZzw5S+5XJGIiJ9CJArUbNrMoY+3ADB03iUMyMlxuSIRET+FSISz1rLvCf/EwsRBqQy7Qq+tF5HIoRCJcNXvrOfw1k8ByPvipSSlp7tckYjIMTEVIrHWCt42Nx+7CkkdzNDLvuhyRSISL8LWCj5axGIr+Kq169i2eDEAw7/yFfL/7WqXKxKReNNTK/iYuhKJJba5mX1PPglAUlq6miyKSERSiEQo71tvc2TXLgDy5l9OYmqXV5MiIq5RiEQgay2ly5YBgauQi+e6XJGISOcUIhGoZuMm6kpKABh66TxdhYhIxFKIRBhrLaVL/VchiampGgsRkYimEIkwte+/f2x2+ty5JA0e7HJFIiJdU4hEmNKlywF/jyzNCxGRSKcQiSCHtmyh9oP3Aci96EKSMzNdrkhEpHsKkQhy4OlnADBJSeRdfpm7xYiIBEEhEiHq9+6l6q23Acg57zx16hWRqKAQiRAHnn0O8LegyZt/ubvFiIgESSESARqrqji4ajUAnsmTGTRihMsViYgERyESAcpe/AvNvkYA8vS+EBGJIjEVItHYCr7pyBHK//ISAGljx5H22c+6XJGIiFrBR42yF15kVyDwRv/4NrKmTHG5IhGRY9QKPoLZ5mYOPP8CACnD8vGccYbLFYmI9I5CxEXVGzZwdH8pAHmXzsMkJrpckYhI7yhEXHTgOf9VSGJqKjkzprtcjYhI7ylEXFJXUtLa4mTI7Nlq9y4iUUkh4pIDzz0PgDEJDL3kYperERFxRiHigsbKKirfeAMAz9QppOTluVyRiIgzChEXlL/yCtbnAyDv0ktdrkZExLkktwvoijHGAywAPMBowGutXeRmTeHQ7PNR/sorAKSOHsPgU8e5XJGIiHMRGyLAAmtt67RzY8xiY8xKa+0sN4sKlXfdmzRWVQEw9JKLMca4XJGIiHMRGSLGmAL8VyBt3QNUGWM81lpvvxcVJmUvvAhAUkYG2dPOdbkaEZHQRPKYyG1tf2gTHAX9X0p4HN66tfX96blz5pAwYIDLFYmIhCYir0SstSVAVttlgasTrLWbXSkqDFquQkxCIrlzL3K5GhGR0EXylcjxFgFLuvtAbW0txp
hOfxUVFfVPlV1orK6m6o01AHimnMWAIUNcrUdEpK2ioqJO/+4sLS0FSO9qu6jo4muMmQgs7m5QPdK7+JYuW87eP/4RgLF330365053uSIRkZ7FShff26L5qSzr81H+8ssADDp5JGmnn+ZyRSIi4RHxIWKMWQxc73YdofBu2EiD/8UuDL14rh7rFZGYEdEhYoy5Bbin5cksY4ynZYA9mpT/5S8AJKYOJvv881yuRkQkfCI2RIwxM4EVx80JWQBUulORM0d276HmvfcAyJkxncRBg1yuSEQkfCLyEd/A1cbKwO/brvK2ncUeDcpfeqn190MvnutiJSIi4ReRIRKYJxL1AwdNR45w8LXXAMiYMIGBw4e7XJGISHhF7O2sWFD5+t9oqqsDYOjFemeIiMQehUgfsdZSFriVNSA3l8zJhS5XJCISfgqRPnLogw85snMnALkXXYRJTHS5IhGR8FOI9JGyF/19shKSkhkye7bL1YiI9A2FSB84WlaO9+13AMg+bxrJmRkuVyQi0jcUIn2g/KWXsLYZgNxLLnG5GhGRvqMQCbPmhgYqXn0VgLRx4xg8ZrTLFYmI9J2YCpG6ujoKCwspLCykuNidOYmVb7yBr7YWgKHz5rlSg4hIqIqLiyksLKTC3/cvtavPRUUr+GBEQit4ay3/+v4PqNu2jeSsbD73u0dJSIrI+ZwiIkGJlVbwUaH2/Q+o27YNgNyLLlSAiEjMU4iE0YFnnwUgIXkAuXPVJ0tEYp9CJEyO7NxJ9caNAOTMnKHHekUkLihEwuTAM88GfmfIu+yLbpYiItJvFCJh0FBZycG//x2ArClnMfCETsefRERijkIkDMqefwHr8wGQN/9yl6sREek/CpEQ+Q4fpvzlVwBIO/VU0saNc7kiEZH+oxAJ0YFnnqWp7jAAw+bPd7kaEZH+pRAJQaPXy4FnnwNg8CmnkHnmGS5XJCLSvxQiIShdtpzmo/UADL/ma8e/D15EJOYpRBw6WlZG+csvA5Dx+fFkTJjgbkEiIi5QiDhU+uSfW5/IOuGar7pcjYiIOxQiDtSVlHBw9WsAeM48k7SxY12uSETEHTEVIv3RCr65sZHtD/wCa5sxCYkM/6quQkQk9qgVfB/Z89jj7H/6aQBO+NKXOOHLX+rT7xMRcZNawYdR7Uf/ZP/TzwCQOnoMwxZc5XJFIiLuUogEyXf4MDt++UvAkpCUzKgffl/vCxGRuKcQCUJjdTWf3P4Tju7fD8Dw665l0EknuVyViIj79E/pHjRUVPDJT/8v9Xv2AOA5awpD513iclUiIpFBIdIF29xM9cZN7PrNwzSUlwGQc8EFjPzudzAJuoATEQGFSDtNR47QUFbGoS1bOPDs89Tv2d26bujFF3PSwusVICIibShEgC2LbqN+9y58tbUd1iUOSiX/6gXkzb9cvbFERI6jEAF81d4OATIgZwhDL53HkDmzSRo82KXKREQim0IEyJo6FV9NDQPyhjIgdygpw/JIHTNGj/CKiPRAf0vib+MuIiK9p1FiERFxTCEiIiKOKURERMSxmAoRp63gi4qK+q6oGKbj5pyOnTM6bs719tipFXzvtiVWjkN/0nFzTsfOGR0355weO7WCFxGRPqMQERERxxQiYRDKq3hDfY2v29u7+d3RfNzd/O5oPu6hiub/drePXZestTHxC9iXn59vnfAfBucmTZrkyrZub+/mcQt1e7ePeyjHzu3a3TzuOuecc3rs8vPzLbDPdvF3bywNrDcmJCQk5eXl9Xrb0tJS8vPzHX93RUUFQ4YM6fdt3d7ezeMW6vZuH/dQjp3btbt53HXO9f+xO3DgAM3NzT5rbXJn62MpROqAZKDcwebpQMcWvsFLBepc2Nbt7d08bqFu7/ZxD+XYuV27m8dd55xzTo9dLtBore30Md+YCREREel/GlgXERHHFCIiIuKYQkRERBxTiIiIiGNx/VIqY8xCoDLwY7a1NkJn80QOY8yVwCxgcWDRlcBma+0q96qKTM
YYD7AQGG2tvaGT9Tr/OtHdcdP5173AsVsAeIDRgNdau+i4z4T1vIvbK5HAgSyx1q6w1q4AKgPLpHvZQCGwDViJ/yTVH+DjGGMmAjMBbxfrdf51oqfjhs6/niyw1hZba5e0BLAxZmXLyr447+L2EV9jzCZr7aSelkl7xpgrAyefBKHlX86d/Ita5183ujluOv+6YIwpAK601i5ps8wDVAFZ1lpvX5x3cXklEjiwBZ2smhhYJ9JndP5JH7qt7Q/WWm/gtwV9dd7F65hIIcfuCbblxX+QN/drNVHGGDOzzY8T2/7LR4Ki8y8EOv86Z60tAbLaLgtcnWCt3Rw4bmE/7+I1RDx0fs+1Ev89V+laCf770CXg/1e1MWbx8YN30i0POv+c0vnXO4uAlpD10AfnXVzezhLnrLWbW/4AB35eAdziYkkSR3T+BS/wkEJBXwdsvIaIF38qHy+bzi/3pActl80SFC86/8JK51+nbrPWzmrzs5c+OO/iNUQ20vnlmwf/5bJ0whhTYIzZ1skqb3/XEuV0/jmg8y94xpjFwPXHLe6T8y4uQyTwxEJnyVvS5mkG6dzitj8EnurwtL3FIN3T+RcSnX89MMbcAtzTci4Fxo0K+uq8i8sQCXi47SSbwO8Xd/P5uNfFH9TbgA6zsaWdzv71p/OvZ+2Om86/ngWewFpxXCgs4Fh4hP28i9vJhnBs9ib+yzm1nQhCm7YKEHjaQ8eto5aJX8DV+B+fLAZWtp1drfOvo56Om86/rgWOXae3+6y1WW0+F9bzLq5DREREQhPPt7NERCREChEREXFMISIiIo4pRERExDGFiIiIOKYQERERxxQiIiLimEJEREQcU4iIiIhjChEREXFMISIiIo4pRERExDGFiEgfMsZUGWOWB1p0u1XD4kANy92qQWKXuvhK3Av8Bb8c/5vfSvC/KW8iUAgsC3wsG39r8oK2bbWD2Pdya+1VYapzIvAIgLV2koPtw1aLSIsktwsQiQA3ADOstZtbFgT+1V5prW33wiNjzKb+Lq6FtXazMeZ6/IEnEhF0O0sENrQNkICZwKpOPtvZsv7kdfn7RdpRiEhcM8Zcif/teW2XefC/9W1lJ5sc7PuqRKKHQkTi3arj3kcN/qsQ6PyqQ69iFWlDYyIS1zoJEIBZQEln67r4fK8Ern7AP1jvsdYuabNuIrA48OMi/IP5NwBXtf3uwOda9jERWGGtLQm1NpHe0pWISEddjYeEzBhzC7DZWrvCWlvcZhngHzzHHx7ZQLa1dgUdb6sVAF5r7WZr7apACC03xhT0Rc0i3VGIiLQRGA8poPPxkHAYDVzZ5udV+K982vICE621qwCstUuOuwIq6eSqYynHrmBE+o1uZ4m01914SMjaPjIcuHIoxH/Vcbze3praDNwWQmkijuhKRKS9LsdDwsEY4zHGPGyMWRhYpHEMiWoKEZH2uh0PMcbcYozZZoxZ2WZZb1qabAcettYWB25JVbbZTyhjGhNxfw6LxCGFiEhAT+MhgbAosdaOBha1ecoq2P0X4H8aq+3ExgL8c1LAHwTByO4kcK7GPyDfcrWzsONmIuGn3lkS9wJPR43GPz7R8i/6EmB5y+B24HMFbQe0W65A2n6mk32361cV+K4c/EFVGfie24Bt+Ht3Efj5SmAJsPS4diwe/FdL3sD2HR7xbekFdnyPL/XOkr6gEBFxyEmIuCmSapHYodtZIs55aDOmIRKPFCIizmV30rhRJK4oRESc01WIxD2FiIhznU0S7CBS3mwIbHCrBoldGlgXcSDwlFS2mh5KvFOIiIiIY7qdJSIijilERETEMYWIiIg4phARERHHFCIiIuKYQkRERBz7/4dCfncVf+W1AAAAAElFTkSuQmCC\n", | |
"text/plain": [ | |
"<Figure size 443.077x360 with 1 Axes>" | |
] | |
}, | |
"metadata": { | |
"needs_background": "light" | |
}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"# Input grid of T_gamma values, reused by the cells below for data generation, training and plotting\n", | |
"T_gamma = torch.linspace(0.1, 20, 100)\n", | |
"\n", | |
"fig, ax = plt.subplots()\n", | |
"ax.plot(T_gamma, T_nu(T_gamma))\n", | |
"ax.set_xlabel(r\"$T_\\gamma$\\,[arb.]\")\n", | |
"ax.set_ylabel(r\"$T_\\nu(T_\\gamma)$\\,[arb.]\")" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "e03d8360-f04c-42da-84da-009200016d96", | |
"metadata": {}, | |
"source": [ | |
"Generate mock data: evaluate the observable on the true latent function, add unit-variance Gaussian noise, and set up the GP model." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 23, | |
"id": "5da3258b-2159-4155-90a0-f0a792ca5dc9", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Seed torch / numpy / random so the mock data and the subsequent fit are reproducible\n", | |
"pyro.set_rng_seed(0)\n", | |
"\n", | |
"# Mean and standard deviation of the observed sample\n", | |
"loc = obs(T_nu(T_gamma), T_gamma)\n", | |
"scale = 1.\n", | |
"\n", | |
"# Draw a noisy sample of observations\n", | |
"obs_sample = torch.distributions.Normal(loc=loc, scale=scale).sample()\n", | |
"\n", | |
"gpr = GPRegressionBBN(obs_sample, num_inducing=10)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "867f9283-65c6-4240-af52-c4828c17c0f1", | |
"metadata": {}, | |
"source": [ | |
"Train with stochastic variational inference (SVI) using the Adam optimizer." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 27, | |
"id": "bb1d408d-73eb-4591-a6af-6172f87ccb66", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "bae78e510516411795c9ba370de24e66", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
" 0%| | 0/1000 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"num_particles = 4\n", | |
"num_iter = 1000\n", | |
"learning_rate = 0.1\n", | |
"\n", | |
"def train(gp=None, x=None, num_iter=num_iter, num_particles=num_particles, lr=learning_rate):\n", | |
"    \"\"\" Fit the GP's variational parameters and hyperparameters with SVI (Adam optimizer).\n", | |
"    :param gp: Model with Pyro `model`/`guide` methods (defaults to the global `gpr`)\n", | |
"    :param x: Training inputs passed to `model`/`guide` (defaults to the global `T_gamma`)\n", | |
"    :param num_iter: Number of SVI steps\n", | |
"    :param num_particles: Number of vectorized ELBO particles per step\n", | |
"    :param lr: Adam learning rate\n", | |
"    :return: List of per-step losses returned by `svi.step`\n", | |
"    \"\"\"\n", | |
"    gp = gpr if gp is None else gp\n", | |
"    x = T_gamma if x is None else x\n", | |
"\n", | |
"    # Reset Pyro's global parameter store before (re)building SVI\n", | |
"    pyro.clear_param_store()\n", | |
"    optimizer = pyro.optim.Adam({\"lr\": lr})\n", | |
"    elbo = pyro.infer.Trace_ELBO(num_particles=num_particles, vectorize_particles=True, retain_graph=True)\n", | |
"    svi = pyro.infer.SVI(gp.model, gp.guide, optimizer, elbo)\n", | |
"\n", | |
"    gp.train()\n", | |
"    losses = []\n", | |
"    iterator = tqdm(range(num_iter))\n", | |
"    for _ in iterator:\n", | |
"        # svi.step computes gradients, applies the optimizer update and zeroes the gradients itself\n", | |
"        loss = svi.step(x)\n", | |
"        losses.append(loss)\n", | |
"        iterator.set_postfix(loss=loss, lengthscale=gp.covar_module.base_kernel.lengthscale.item())\n", | |
"    return losses\n", | |
"\n", | |
"losses = train()" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "f71e9da4-9f9f-4efb-a1c9-209f13a74f16", | |
"metadata": {}, | |
"source": [ | |
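"Draw samples from the trained GP posterior. Absolute values are taken because `obs` depends on $T_\\nu$ only through $T_\\nu^4$, so the sign of $T_\\nu$ is not identifiable." | |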
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 28, | |
"id": "0c72e29d-e968-47be-bf84-51476f224dba", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"gpr.eval()\n", | |
"\n", | |
"# Sample the trained GP at the training inputs. obs() depends on T_nu only via T_nu ** 4,\n", | |
"# so the sign of T_nu is not identifiable; abs() folds both signs onto the positive branch.\n", | |
"sample_shape = torch.Size([500])\n", | |
"with torch.no_grad():\n", | |
"    posterior_f = gpr(T_gamma)\n", | |
"    gpr_samples = posterior_f(sample_shape).abs().detach().numpy()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 29, | |
"id": "f689725f-fa42-4e01-85ac-fa294cb548e6", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"<matplotlib.legend.Legend at 0x7fdcb93b4400>" | |
] | |
}, | |
"execution_count": 29, | |
"metadata": {}, | |
"output_type": "execute_result" | |
}, | |
{ | |
"data": { | |
"image/png": "\n", | |
"text/plain": [ | |
"<Figure size 443.077x360 with 1 Axes>" | |
] | |
}, | |
"metadata": { | |
"needs_background": "light" | |
}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"# Central 68% and 95% intervals of the posterior samples at each T_gamma\n", | |
"x_grid = T_gamma.detach().numpy()\n", | |
"lo_68, hi_68, lo_95, hi_95 = np.percentile(gpr_samples, [16, 84, 2.5, 97.5], axis=0)\n", | |
"\n", | |
"fig, ax = plt.subplots()\n", | |
"ax.fill_between(x_grid, lo_68, hi_68, alpha=0.3, color=cols_default[1], label=\"GP post.\")\n", | |
"ax.fill_between(x_grid, lo_95, hi_95, alpha=0.1, color=cols_default[1])\n", | |
"ax.plot(T_gamma, T_nu(T_gamma), label=\"True\")\n", | |
"ax.set_xlabel(r\"$T_\\gamma$\\,[arb.]\")\n", | |
"ax.set_ylabel(r\"$T_\\nu(T_\\gamma)$\\,[arb.]\")\n", | |
"ax.legend(loc='upper left')" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "9f55affa-c4f9-4fbd-949c-e83848ae1029", | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "9ff531ba-f0ca-4386-8423-60a86e2debc8", | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.6.13" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 5 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment