Created
January 30, 2025 23:30
-
-
Save avivajpeyi/bfa87f9d16c4c16b3580c2da3a33c20d to your computer and use it in GitHub Desktop.
testing_vae_continuous.ipynb
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"nbformat": 4, | |
"nbformat_minor": 0, | |
"metadata": { | |
"colab": { | |
"provenance": [], | |
"authorship_tag": "ABX9TyPoBUofCojwEVnLA8NGeZEc", | |
"include_colab_link": true | |
}, | |
"kernelspec": { | |
"name": "python3", | |
"display_name": "Python 3" | |
}, | |
"language_info": { | |
"name": "python" | |
} | |
}, | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "view-in-github", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"<a href=\"https://colab.research.google.com/gist/avivajpeyi/bfa87f9d16c4c16b3580c2da3a33c20d/testing_vae_continuous.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
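"# VAE on 1D continuous data\n", | |
"\n", | |
"Trains a variational autoencoder (VAE) on 1D curves drawn from a Gaussian-process prior (squared-exponential kernel), then plots the validation ELBO and example reconstructions.\n", | |
"\n", | |
"Code taken from https://github.com/theorashid/jax-vae/blob/main/vae-continuous.py\n", | |
"\n" | |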
], | |
"metadata": { | |
"id": "PGF4cDjtgMFF" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"! pip install dm-haiku tinygp" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "mvDmNce3gyA-", | |
"outputId": "de9aa222-242e-41fd-e6da-df571c2602c9" | |
}, | |
"execution_count": 2, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"Collecting dm-haiku\n", | |
" Downloading dm_haiku-0.0.13-py3-none-any.whl.metadata (19 kB)\n", | |
"Collecting tinygp\n", | |
" Downloading tinygp-0.3.0-py3-none-any.whl.metadata (2.5 kB)\n", | |
"Requirement already satisfied: absl-py>=0.7.1 in /usr/local/lib/python3.11/dist-packages (from dm-haiku) (1.4.0)\n", | |
"Collecting jmp>=0.0.2 (from dm-haiku)\n", | |
" Downloading jmp-0.0.4-py3-none-any.whl.metadata (8.9 kB)\n", | |
"Requirement already satisfied: numpy>=1.18.0 in /usr/local/lib/python3.11/dist-packages (from dm-haiku) (1.26.4)\n", | |
"Requirement already satisfied: tabulate>=0.8.9 in /usr/local/lib/python3.11/dist-packages (from dm-haiku) (0.9.0)\n", | |
"Collecting equinox (from tinygp)\n", | |
" Downloading equinox-0.11.11-py3-none-any.whl.metadata (18 kB)\n", | |
"Requirement already satisfied: jax in /usr/local/lib/python3.11/dist-packages (from tinygp) (0.4.33)\n", | |
"Requirement already satisfied: jaxlib in /usr/local/lib/python3.11/dist-packages (from tinygp) (0.4.33)\n", | |
"Collecting jax (from tinygp)\n", | |
" Downloading jax-0.5.0-py3-none-any.whl.metadata (22 kB)\n", | |
"Collecting jaxtyping>=0.2.20 (from equinox->tinygp)\n", | |
" Downloading jaxtyping-0.2.37-py3-none-any.whl.metadata (6.6 kB)\n", | |
"Requirement already satisfied: typing-extensions>=4.5.0 in /usr/local/lib/python3.11/dist-packages (from equinox->tinygp) (4.12.2)\n", | |
"Collecting jaxlib (from tinygp)\n", | |
" Downloading jaxlib-0.5.0-cp311-cp311-manylinux2014_x86_64.whl.metadata (978 bytes)\n", | |
"Requirement already satisfied: ml_dtypes>=0.4.0 in /usr/local/lib/python3.11/dist-packages (from jax->tinygp) (0.4.1)\n", | |
"Requirement already satisfied: opt_einsum in /usr/local/lib/python3.11/dist-packages (from jax->tinygp) (3.4.0)\n", | |
"Requirement already satisfied: scipy>=1.11.1 in /usr/local/lib/python3.11/dist-packages (from jax->tinygp) (1.13.1)\n", | |
"Collecting wadler-lindig>=0.1.3 (from jaxtyping>=0.2.20->equinox->tinygp)\n", | |
" Downloading wadler_lindig-0.1.3-py3-none-any.whl.metadata (17 kB)\n", | |
"Downloading dm_haiku-0.0.13-py3-none-any.whl (373 kB)\n", | |
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m373.9/373.9 kB\u001b[0m \u001b[31m17.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", | |
"\u001b[?25hDownloading tinygp-0.3.0-py3-none-any.whl (44 kB)\n", | |
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m44.0/44.0 kB\u001b[0m \u001b[31m3.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", | |
"\u001b[?25hDownloading jmp-0.0.4-py3-none-any.whl (18 kB)\n", | |
"Downloading equinox-0.11.11-py3-none-any.whl (179 kB)\n", | |
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m179.2/179.2 kB\u001b[0m \u001b[31m14.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", | |
"\u001b[?25hDownloading jax-0.5.0-py3-none-any.whl (2.3 MB)\n", | |
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m2.3/2.3 MB\u001b[0m \u001b[31m51.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", | |
"\u001b[?25hDownloading jaxlib-0.5.0-cp311-cp311-manylinux2014_x86_64.whl (102.0 MB)\n", | |
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m102.0/102.0 MB\u001b[0m \u001b[31m8.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", | |
"\u001b[?25hDownloading jaxtyping-0.2.37-py3-none-any.whl (56 kB)\n", | |
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m56.3/56.3 kB\u001b[0m \u001b[31m4.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", | |
"\u001b[?25hDownloading wadler_lindig-0.1.3-py3-none-any.whl (20 kB)\n", | |
"Installing collected packages: wadler-lindig, jmp, jaxtyping, jaxlib, dm-haiku, jax, equinox, tinygp\n", | |
" Attempting uninstall: jaxlib\n", | |
" Found existing installation: jaxlib 0.4.33\n", | |
" Uninstalling jaxlib-0.4.33:\n", | |
" Successfully uninstalled jaxlib-0.4.33\n", | |
" Attempting uninstall: jax\n", | |
" Found existing installation: jax 0.4.33\n", | |
" Uninstalling jax-0.4.33:\n", | |
" Successfully uninstalled jax-0.4.33\n", | |
"Successfully installed dm-haiku-0.0.13 equinox-0.11.11 jax-0.5.0 jaxlib-0.5.0 jaxtyping-0.2.37 jmp-0.0.4 tinygp-0.3.0 wadler-lindig-0.1.3\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"\"\"\"Variational Autoencoder example on continuous Gaussian process priors.\"\"\"\n", | |
"\n", | |
"from typing import Iterator, NamedTuple, Sequence, Tuple, Type\n", | |
"\n", | |
"import haiku as hk\n", | |
"import jax\n", | |
"import jax.numpy as jnp\n", | |
"import numpy as np\n", | |
"import optax\n", | |
"from tinygp import GaussianProcess, kernels\n", | |
"\n", | |
"\n", | |
"class flags(NamedTuple):\n", | |
"    \"\"\"Training and evaluation hyperparameters, bundled as a NamedTuple.\"\"\"\n", | |
"\n", | |
"    # NOTE: train_size / test_size are passed to generate_gp_samples as num_draws,\n", | |
"    # so they count *batches* of batch_size curves, not individual curves.\n", | |
"    train_size: int = 16000\n", | |
"    test_size: int = 4000\n", | |
"    batch_size: int = 128\n", | |
"    learning_rate: float = 0.001\n", | |
"\n", | |
"    # One training batch is consumed per step, so training_steps must not\n", | |
"    # exceed train_size. Validation runs every eval_frequency steps.\n", | |
"    training_steps: int = 10000\n", | |
"    eval_frequency: int = 100\n", | |
"    # Seeds the hk.PRNGSequence used for init, sampling z and evaluation.\n", | |
"    random_seed: int = 42\n", | |
"    # NOTE(review): not read anywhere in this notebook; presumably a leftover\n", | |
"    # from an absl-flags version of the script.\n", | |
"    alsologtostderr: bool = True\n", | |
"\n", | |
"\n", | |
"# Default hyperparameters used throughout the notebook.\n", | |
"FLAGS = flags()\n", | |
"\n", | |
"\n", | |
"\n", | |
"PRNGKey = jnp.ndarray\n", | |
"# NOTE(review): batches are arrays, not types; Type[...] here is annotation-only.\n", | |
"Batch = Type[jnp.ndarray]\n", | |
"\n", | |
"# Shape of one curve; prod(SAMPLE_SHAPE) must equal the number of GP evaluation\n", | |
"# points (len(X) in the main cell).\n", | |
"SAMPLE_SHAPE: Sequence[int] = (100, 1)\n", | |
"\n", | |
"\n", | |
"def generate_gp_samples(\n", | |
"    X: jnp.ndarray,\n", | |
"    var: float,\n", | |
"    scale: float,\n", | |
"    num_draws: int,\n", | |
"    batch_size: int,\n", | |
"    sample_shape: Sequence[int] = SAMPLE_SHAPE,\n", | |
"    seed: int = 1,\n", | |
") -> Iterator[Batch]:\n", | |
"    \"\"\"Draw batches of curves from a GP prior with an ExpSquared kernel.\n", | |
"\n", | |
"    The GP uses tinygp's default (zero) mean and covariance var * ExpSquared(scale).\n", | |
"\n", | |
"    Args:\n", | |
"        X: input locations at which every curve is evaluated (len(X) points).\n", | |
"        var: multiplier (variance) applied to the kernel.\n", | |
"        scale: length scale of the ExpSquared kernel.\n", | |
"        num_draws: number of *batches* to generate.\n", | |
"        batch_size: number of curves per batch.\n", | |
"        sample_shape: shape of one curve; prod(sample_shape) must equal len(X).\n", | |
"        seed: seed of the JAX PRNG key used for all draws.\n", | |
"\n", | |
"    Returns:\n", | |
"        An iterator over num_draws arrays of shape (batch_size, *sample_shape).\n", | |
"\n", | |
"    Raises:\n", | |
"        ValueError: if len(X) does not equal prod(sample_shape); otherwise the\n", | |
"            reshape below could silently re-chunk curves across batches.\n", | |
"    \"\"\"\n", | |
"    num_points = int(np.prod(sample_shape))\n", | |
"    if X.shape[0] != num_points:\n", | |
"        raise ValueError(\n", | |
"            f\"X has {X.shape[0]} points but sample_shape {tuple(sample_shape)} \"\n", | |
"            f\"needs {num_points}\"\n", | |
"        )\n", | |
"\n", | |
"    kernel = var * kernels.ExpSquared(scale=scale)\n", | |
"    gp = GaussianProcess(kernel, X)\n", | |
"\n", | |
"    # All draws are materialised up front, shape (num_draws, batch_size, len(X)):\n", | |
"    # num_draws * batch_size * len(X) floats are held in memory at once.\n", | |
"    draws = gp.sample(jax.random.PRNGKey(seed), shape=(num_draws, batch_size))\n", | |
"    draws = jnp.reshape(draws, (num_draws, batch_size, *sample_shape))\n", | |
"\n", | |
"    return iter(draws)\n", | |
"\n", | |
"\n", | |
"class Encoder(hk.Module):\n", | |
"    \"\"\"Encoder model: maps a batch of inputs to the mean and stddev of q(z|x).\n", | |
"\n", | |
"    Two ReLU MLP layers (hidden_size1, hidden_size2) followed by two linear\n", | |
"    heads of width latent_size, one for the mean and one for the log-stddev.\n", | |
"    \"\"\"\n", | |
"\n", | |
"    def __init__(\n", | |
"        self,\n", | |
"        hidden_size1: int = 50,\n", | |
"        hidden_size2: int = 25,\n", | |
"        latent_size: int = 10,\n", | |
"    ):\n", | |
"        super().__init__()\n", | |
"        self._hidden_size1 = hidden_size1\n", | |
"        self._hidden_size2 = hidden_size2\n", | |
"        self._latent_size = latent_size\n", | |
"        self.act = jax.nn.relu\n", | |
"\n", | |
"    def __call__(self, x: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:\n", | |
"        \"\"\"Return (mean, stddev), each with latent_size entries per example.\"\"\"\n", | |
"        # hk.Flatten keeps the leading (batch) axis and flattens the rest.\n", | |
"        x = hk.Flatten()(x)\n", | |
"        x = hk.Sequential(\n", | |
"            [\n", | |
"                hk.Linear(self._hidden_size1),\n", | |
"                self.act,\n", | |
"                hk.Linear(self._hidden_size2),\n", | |
"                self.act,\n", | |
"            ]\n", | |
"        )(x)\n", | |
"\n", | |
"        mean = hk.Linear(self._latent_size)(x)\n", | |
"        # Predict log-stddev so that exp() guarantees a positive stddev.\n", | |
"        log_stddev = hk.Linear(self._latent_size)(x)\n", | |
"        stddev = jnp.exp(log_stddev)\n", | |
"\n", | |
"        return mean, stddev\n", | |
"\n", | |
"\n", | |
"class Decoder(hk.Module):\n", | |
"    \"\"\"Decoder model: maps latent codes z to reconstructions.\n", | |
"\n", | |
"    Two ReLU MLP layers (hidden_size1, hidden_size2) followed by a linear layer\n", | |
"    with prod(output_shape) units, reshaped to (-1, *output_shape).\n", | |
"    \"\"\"\n", | |
"\n", | |
"    def __init__(\n", | |
"        self,\n", | |
"        hidden_size1: int = 25,\n", | |
"        hidden_size2: int = 50,\n", | |
"        output_shape: Sequence[int] = SAMPLE_SHAPE,\n", | |
"    ):\n", | |
"        super().__init__()\n", | |
"        self._hidden_size1 = hidden_size1\n", | |
"        self._hidden_size2 = hidden_size2\n", | |
"        self._output_shape = output_shape\n", | |
"        self.act = jax.nn.relu\n", | |
"\n", | |
"    def __call__(self, z: jnp.ndarray) -> jnp.ndarray:\n", | |
"        \"\"\"Return reconstructions reshaped to (-1, *output_shape).\"\"\"\n", | |
"        output = hk.Sequential(\n", | |
"            [\n", | |
"                hk.Linear(self._hidden_size1),\n", | |
"                self.act,\n", | |
"                hk.Linear(self._hidden_size2),\n", | |
"                self.act,\n", | |
"                # No activation on the last layer: outputs are unbounded reals,\n", | |
"                # suitable for continuous-valued data.\n", | |
"                hk.Linear(np.prod(self._output_shape)),\n", | |
"            ]\n", | |
"        )(z)\n", | |
"\n", | |
"        # -1 lets the leading (batch) dimension be inferred.\n", | |
"        output = jnp.reshape(output, (-1, *self._output_shape))\n", | |
"\n", | |
"        return output\n", | |
"\n", | |
"\n", | |
"class VAEOutput(NamedTuple):\n", | |
"    \"\"\"Outputs of VariationalAutoEncoder for one batch.\"\"\"\n", | |
"\n", | |
"    # Mean of the encoder's Gaussian q(z|x).\n", | |
"    mean: jnp.ndarray\n", | |
"    # Standard deviation of q(z|x); same shape as mean.\n", | |
"    stddev: jnp.ndarray\n", | |
"    # Decoder reconstruction, reshaped to (-1, *output_shape).\n", | |
"    output: jnp.ndarray\n", | |
"\n", | |
"\n", | |
"class VariationalAutoEncoder(hk.Module):\n", | |
"    \"\"\"Main VAE model class, uses Encoder & Decoder under the hood.\n", | |
"\n", | |
"    The decoder mirrors the encoder: it is built with the hidden sizes in\n", | |
"    reverse order (hidden_size2, then hidden_size1).\n", | |
"    \"\"\"\n", | |
"\n", | |
"    def __init__(\n", | |
"        self,\n", | |
"        hidden_size1: int = 50,\n", | |
"        hidden_size2: int = 25,\n", | |
"        latent_size: int = 10,\n", | |
"        output_shape: Sequence[int] = SAMPLE_SHAPE,\n", | |
"    ):\n", | |
"        super().__init__()\n", | |
"        self._hidden_size1 = hidden_size1\n", | |
"        self._hidden_size2 = hidden_size2\n", | |
"        self._latent_size = latent_size\n", | |
"        self._output_shape = output_shape\n", | |
"\n", | |
"    def __call__(self, x: jnp.ndarray) -> VAEOutput:\n", | |
"        \"\"\"Encode x, sample z from q(z|x) and decode it.\n", | |
"\n", | |
"        Draws noise with hk.next_rng_key(), so the transformed function must be\n", | |
"        applied with an rng key.\n", | |
"        \"\"\"\n", | |
"        x = x.astype(jnp.float32)\n", | |
"        mean, stddev = Encoder(\n", | |
"            self._hidden_size1, self._hidden_size2, self._latent_size\n", | |
"        )(x)\n", | |
"        # Reparameterisation trick: z = mean + stddev * eps with eps ~ N(0, I),\n", | |
"        # so gradients flow through mean and stddev.\n", | |
"        z = mean + stddev * jax.random.normal(hk.next_rng_key(), mean.shape)\n", | |
"        output = Decoder(\n", | |
"            self._hidden_size2,\n", | |
"            self._hidden_size1,\n", | |
"            self._output_shape,\n", | |
"        )(z)\n", | |
"\n", | |
"        return VAEOutput(mean, stddev, output)\n", | |
"\n", | |
"\n", | |
"def mean_squared_error(x1: jnp.ndarray, x2: jnp.ndarray) -> jnp.ndarray:\n", | |
"    \"\"\"Calculate the per-example mean squared error between two tensors.\n", | |
"\n", | |
"    Args:\n", | |
"        x1: variable tensor; its leading axis is kept as the example axis\n", | |
"        x2: variable tensor, must be of same shape as x1\n", | |
"\n", | |
"    Returns:\n", | |
"        An array of shape (x1.shape[0],): the squared error averaged over all\n", | |
"        non-leading axes, one value per example (not a single scalar).\n", | |
"\n", | |
"    Raises:\n", | |
"        ValueError: if x1 and x2 have different shapes.\n", | |
"    \"\"\"\n", | |
"    if x1.shape != x2.shape:\n", | |
"        raise ValueError(\"x1 and x2 must be of the same shape\")\n", | |
"\n", | |
"    # Flatten everything but the leading axis so the mean runs over all features.\n", | |
"    x1 = jnp.reshape(x1, (x1.shape[0], -1))\n", | |
"    x2 = jnp.reshape(x2, (x2.shape[0], -1))\n", | |
"\n", | |
"    return jnp.mean(jnp.square(x1 - x2), axis=-1)\n", | |
"\n", | |
"\n", | |
"def kl_gaussian(mean: jnp.ndarray, var: jnp.ndarray) -> jnp.ndarray:\n", | |
"    r\"\"\"Calculate KL divergence between given and standard gaussian distributions.\n", | |
"\n", | |
"    KL(p, q) = H(p, q) - H(p) = -\\int p(x)log(q(x))dx - -\\int p(x)log(p(x))dx\n", | |
"            = 0.5 * [log(|s2|/|s1|) - 1 + tr(s1/s2) + (m1-m2)^2/s2]\n", | |
"            = 0.5 * [-log(|s1|) - 1 + tr(s1) + m1^2] (if m2 = 0, s2 = 1)\n", | |
"\n", | |
"    The bracketed expression is applied per dimension of a diagonal Gaussian\n", | |
"    and summed over the last axis.\n", | |
"\n", | |
"    Args:\n", | |
"        mean: mean vector of the first distribution\n", | |
"        var: diagonal vector of covariance matrix of the first distribution\n", | |
"\n", | |
"    Returns:\n", | |
"        An array of shape mean.shape[:-1] (one KL value per example), not a scalar.\n", | |
"    \"\"\"\n", | |
"    # Closed-form per-dimension KL against N(0, 1), summed over the last axis.\n", | |
"    return 0.5 * jnp.sum(-jnp.log(var) - 1.0 + var + jnp.square(mean), axis=-1)\n", | |
"\n", | |
"\n", | |
"\n" | |
], | |
"metadata": { | |
"id": "21PP8Q9DgTFi" | |
}, | |
"execution_count": 3, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"## MAIN\n", | |
"\n", | |
"# hk.transform turns the module into pure init/apply functions; apply needs an\n", | |
"# rng key because the VAE samples z with hk.next_rng_key().\n", | |
"model = hk.transform(\n", | |
"    lambda x: VariationalAutoEncoder()(x)\n", | |
") # pylint: disable=unnecessary-lambda\n", | |
"optimizer = optax.adam(FLAGS.learning_rate)\n", | |
"\n", | |
"# [step, validation loss] pairs recorded during training (loss = negative ELBO).\n", | |
"val_losses = []\n", | |
"\n", | |
"@jax.jit\n", | |
"def loss_fn(\n", | |
"    params: hk.Params,\n", | |
"    rng_key: PRNGKey,\n", | |
"    batch: Batch,\n", | |
") -> jnp.ndarray:\n", | |
"    \"\"\"Negative ELBO averaged over the batch.\n", | |
"\n", | |
"    ELBO = E_q[log p(x|z)] - KL(q(z|x) || p(z)), where q(z|x) is the encoder's\n", | |
"    diagonal Gaussian, p(z) = N(0, I) and p(x|z) = N(decoder(z), I). The\n", | |
"    constant -D/2 * log(2 * pi) of the Gaussian log-likelihood is dropped.\n", | |
"\n", | |
"    The reconstruction term is summed over the D features of each example so it\n", | |
"    is on the same scale as the KL term (summed over latent dimensions). Using\n", | |
"    the plain per-example *mean* squared error shrinks it by a factor of D, so\n", | |
"    the KL term dominates and the encoder collapses to the prior.\n", | |
"    \"\"\"\n", | |
"    outputs: VAEOutput = model.apply(params, rng_key, batch)\n", | |
"\n", | |
"    # mean_squared_error averages over D = prod(batch.shape[1:]) features; scale\n", | |
"    # it back up to a sum. Shapes are static under jit, so this is a Python int.\n", | |
"    num_features = int(np.prod(batch.shape[1:]))\n", | |
"    log_likelihood = -0.5 * num_features * mean_squared_error(batch, outputs.output)\n", | |
"    kl = kl_gaussian(outputs.mean, jnp.square(outputs.stddev))\n", | |
"    elbo = log_likelihood - kl\n", | |
"\n", | |
"    return -jnp.mean(elbo)\n", | |
"\n", | |
"@jax.jit\n", | |
"def update(\n", | |
"    params: hk.Params,\n", | |
"    rng_key: PRNGKey,\n", | |
"    opt_state: optax.OptState,\n", | |
"    batch: Batch,\n", | |
") -> Tuple[hk.Params, optax.OptState]:\n", | |
"    \"\"\"Single optimizer update step (Adam, via the module-level `optimizer`).\n", | |
"\n", | |
"    Differentiates loss_fn w.r.t. params on one batch, applies the optax update\n", | |
"    and returns the new params and optimizer state.\n", | |
"    \"\"\"\n", | |
"    grads = jax.grad(loss_fn)(params, rng_key, batch)\n", | |
"    updates, new_opt_state = optimizer.update(grads, opt_state)\n", | |
"    new_params = optax.apply_updates(params, updates)\n", | |
"    return new_params, new_opt_state\n", | |
"\n", | |
"rng_seq = hk.PRNGSequence(FLAGS.random_seed)\n", | |
"params = model.init(next(rng_seq), np.zeros((1, *SAMPLE_SHAPE)))\n", | |
"opt_state = optimizer.init(params)\n", | |
"\n", | |
"# One training batch is consumed per step; fail fast instead of hitting\n", | |
"# StopIteration partway through training.\n", | |
"if FLAGS.training_steps > FLAGS.train_size:\n", | |
"    raise ValueError(\"training_steps exceeds the number of training batches\")\n", | |
"\n", | |
"# GP evaluation grid: one input location per entry of a flattened curve.\n", | |
"X = jnp.linspace(0, 10, int(np.prod(SAMPLE_SHAPE)))\n", | |
"\n", | |
"# Use distinct seeds so the validation curves are not drawn from the same PRNG\n", | |
"# key as the training curves.\n", | |
"train_ds = generate_gp_samples(\n", | |
"    X,\n", | |
"    var=1.0,\n", | |
"    scale=1.0,\n", | |
"    num_draws=FLAGS.train_size,\n", | |
"    batch_size=FLAGS.batch_size,\n", | |
"    seed=FLAGS.random_seed + 1,\n", | |
")\n", | |
"valid_ds = generate_gp_samples(\n", | |
"    X,\n", | |
"    var=1.0,\n", | |
"    scale=1.0,\n", | |
"    num_draws=FLAGS.test_size,\n", | |
"    batch_size=FLAGS.batch_size,\n", | |
"    seed=FLAGS.random_seed + 2,\n", | |
")\n", | |
"\n", | |
"for step in range(FLAGS.training_steps):\n", | |
"    params, opt_state = update(\n", | |
"        params,\n", | |
"        next(rng_seq),\n", | |
"        opt_state,\n", | |
"        next(train_ds),\n", | |
"    )\n", | |
"\n", | |
"    if step % FLAGS.eval_frequency == 0:\n", | |
"        val_loss = loss_fn(params, next(rng_seq), next(valid_ds))\n", | |
"        # Store a plain Python float so np.array(val_losses) is a clean float array.\n", | |
"        val_losses.append([step, float(val_loss)])\n", | |
"        print(f\"STEP: {step}; Validation ELBO: {-val_loss:.3f}\")\n" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "eUD5TbgEgtXU", | |
"outputId": "5d8c26dd-7d72-4140-c6ff-04ee4302fcb6" | |
}, | |
"execution_count": 5, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"STEP: 0; Validation ELBO: -2.123\n", | |
"STEP: 100; Validation ELBO: -1.000\n", | |
"STEP: 200; Validation ELBO: -1.017\n", | |
"STEP: 300; Validation ELBO: -1.015\n", | |
"STEP: 400; Validation ELBO: -1.094\n", | |
"STEP: 500; Validation ELBO: -0.986\n", | |
"STEP: 600; Validation ELBO: -1.120\n", | |
"STEP: 700; Validation ELBO: -1.007\n", | |
"STEP: 800; Validation ELBO: -0.904\n", | |
"STEP: 900; Validation ELBO: -1.028\n", | |
"STEP: 1000; Validation ELBO: -1.011\n", | |
"STEP: 1100; Validation ELBO: -1.037\n", | |
"STEP: 1200; Validation ELBO: -0.992\n", | |
"STEP: 1300; Validation ELBO: -1.002\n", | |
"STEP: 1400; Validation ELBO: -0.976\n", | |
"STEP: 1500; Validation ELBO: -1.071\n", | |
"STEP: 1600; Validation ELBO: -1.028\n", | |
"STEP: 1700; Validation ELBO: -1.039\n", | |
"STEP: 1800; Validation ELBO: -0.927\n", | |
"STEP: 1900; Validation ELBO: -0.960\n", | |
"STEP: 2000; Validation ELBO: -0.990\n", | |
"STEP: 2100; Validation ELBO: -0.953\n", | |
"STEP: 2200; Validation ELBO: -1.076\n", | |
"STEP: 2300; Validation ELBO: -0.984\n", | |
"STEP: 2400; Validation ELBO: -0.966\n", | |
"STEP: 2500; Validation ELBO: -1.084\n", | |
"STEP: 2600; Validation ELBO: -0.977\n", | |
"STEP: 2700; Validation ELBO: -1.025\n", | |
"STEP: 2800; Validation ELBO: -0.975\n", | |
"STEP: 2900; Validation ELBO: -0.948\n", | |
"STEP: 3000; Validation ELBO: -1.046\n", | |
"STEP: 3100; Validation ELBO: -0.953\n", | |
"STEP: 3200; Validation ELBO: -0.963\n", | |
"STEP: 3300; Validation ELBO: -0.971\n", | |
"STEP: 3400; Validation ELBO: -0.998\n", | |
"STEP: 3500; Validation ELBO: -0.957\n", | |
"STEP: 3600; Validation ELBO: -1.029\n", | |
"STEP: 3700; Validation ELBO: -0.940\n", | |
"STEP: 3800; Validation ELBO: -0.936\n", | |
"STEP: 3900; Validation ELBO: -0.936\n", | |
"STEP: 4000; Validation ELBO: -1.013\n", | |
"STEP: 4100; Validation ELBO: -1.111\n", | |
"STEP: 4200; Validation ELBO: -0.948\n", | |
"STEP: 4300; Validation ELBO: -0.891\n", | |
"STEP: 4400; Validation ELBO: -0.975\n", | |
"STEP: 4500; Validation ELBO: -0.998\n", | |
"STEP: 4600; Validation ELBO: -0.935\n", | |
"STEP: 4700; Validation ELBO: -1.082\n", | |
"STEP: 4800; Validation ELBO: -1.068\n", | |
"STEP: 4900; Validation ELBO: -0.982\n", | |
"STEP: 5000; Validation ELBO: -1.067\n", | |
"STEP: 5100; Validation ELBO: -0.975\n", | |
"STEP: 5200; Validation ELBO: -0.988\n", | |
"STEP: 5300; Validation ELBO: -1.039\n", | |
"STEP: 5400; Validation ELBO: -0.959\n", | |
"STEP: 5500; Validation ELBO: -1.003\n", | |
"STEP: 5600; Validation ELBO: -0.961\n", | |
"STEP: 5700; Validation ELBO: -1.057\n", | |
"STEP: 5800; Validation ELBO: -0.971\n", | |
"STEP: 5900; Validation ELBO: -0.908\n", | |
"STEP: 6000; Validation ELBO: -1.001\n", | |
"STEP: 6100; Validation ELBO: -0.877\n", | |
"STEP: 6200; Validation ELBO: -0.914\n", | |
"STEP: 6300; Validation ELBO: -0.937\n", | |
"STEP: 6400; Validation ELBO: -0.964\n", | |
"STEP: 6500; Validation ELBO: -0.932\n", | |
"STEP: 6600; Validation ELBO: -1.000\n", | |
"STEP: 6700; Validation ELBO: -1.032\n", | |
"STEP: 6800; Validation ELBO: -1.022\n", | |
"STEP: 6900; Validation ELBO: -0.993\n", | |
"STEP: 7000; Validation ELBO: -1.069\n", | |
"STEP: 7100; Validation ELBO: -0.983\n", | |
"STEP: 7200; Validation ELBO: -1.052\n", | |
"STEP: 7300; Validation ELBO: -1.029\n", | |
"STEP: 7400; Validation ELBO: -0.990\n", | |
"STEP: 7500; Validation ELBO: -0.986\n", | |
"STEP: 7600; Validation ELBO: -1.111\n", | |
"STEP: 7700; Validation ELBO: -0.953\n", | |
"STEP: 7800; Validation ELBO: -0.997\n", | |
"STEP: 7900; Validation ELBO: -1.051\n", | |
"STEP: 8000; Validation ELBO: -0.986\n", | |
"STEP: 8100; Validation ELBO: -0.999\n", | |
"STEP: 8200; Validation ELBO: -0.910\n", | |
"STEP: 8300; Validation ELBO: -1.035\n", | |
"STEP: 8400; Validation ELBO: -0.993\n", | |
"STEP: 8500; Validation ELBO: -0.960\n", | |
"STEP: 8600; Validation ELBO: -1.011\n", | |
"STEP: 8700; Validation ELBO: -1.008\n", | |
"STEP: 8800; Validation ELBO: -1.010\n", | |
"STEP: 8900; Validation ELBO: -0.992\n", | |
"STEP: 9000; Validation ELBO: -1.072\n", | |
"STEP: 9100; Validation ELBO: -0.925\n", | |
"STEP: 9200; Validation ELBO: -1.018\n", | |
"STEP: 9300; Validation ELBO: -1.065\n", | |
"STEP: 9400; Validation ELBO: -0.987\n", | |
"STEP: 9500; Validation ELBO: -0.985\n", | |
"STEP: 9600; Validation ELBO: -1.096\n", | |
"STEP: 9700; Validation ELBO: -1.015\n", | |
"STEP: 9800; Validation ELBO: -0.962\n", | |
"STEP: 9900; Validation ELBO: -0.991\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"import matplotlib.pyplot as plt\n", | |
"\n", | |
"\n", | |
"def plot_losses(val_losses):\n", | |
"    \"\"\"Plot the validation ELBO against the training step.\n", | |
"\n", | |
"    Args:\n", | |
"        val_losses: sequence of [step, loss] pairs, where loss is the negative\n", | |
"            ELBO returned by loss_fn (so it is negated again for plotting).\n", | |
"    \"\"\"\n", | |
"    steps, losses = np.array(val_losses).T\n", | |
"    _, ax = plt.subplots()\n", | |
"    ax.plot(steps, -losses)\n", | |
"    ax.set_xlabel(\"Steps\")\n", | |
"    ax.set_ylabel(\"Validation ELBO\")\n", | |
"    ax.set_title(\"Validation ELBO during training\")\n", | |
"    plt.show()\n", | |
"\n", | |
"\n", | |
"plot_losses(val_losses)" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 449 | |
}, | |
"id": "1Qv9cFP0g72t", | |
"outputId": "07f76d5c-c6c3-49ea-e279-6791a3fafab1" | |
}, | |
"execution_count": 10, | |
"outputs": [ | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": [ | |
"<Figure size 640x480 with 1 Axes>" | |
], | |
"image/png": "\n" | |
}, | |
"metadata": {} | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "_h0VznNShWhu", | |
"outputId": "eab6917d-7cd8-4709-8b93-225f918cd3c9" | |
}, | |
"execution_count": 8, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"array([[0.00000000e+00, 1.00000000e+02, 2.00000000e+02, 3.00000000e+02,\n", | |
" 4.00000000e+02, 5.00000000e+02, 6.00000000e+02, 7.00000000e+02,\n", | |
" 8.00000000e+02, 9.00000000e+02, 1.00000000e+03, 1.10000000e+03,\n", | |
" 1.20000000e+03, 1.30000000e+03, 1.40000000e+03, 1.50000000e+03,\n", | |
" 1.60000000e+03, 1.70000000e+03, 1.80000000e+03, 1.90000000e+03,\n", | |
" 2.00000000e+03, 2.10000000e+03, 2.20000000e+03, 2.30000000e+03,\n", | |
" 2.40000000e+03, 2.50000000e+03, 2.60000000e+03, 2.70000000e+03,\n", | |
" 2.80000000e+03, 2.90000000e+03, 3.00000000e+03, 3.10000000e+03,\n", | |
" 3.20000000e+03, 3.30000000e+03, 3.40000000e+03, 3.50000000e+03,\n", | |
" 3.60000000e+03, 3.70000000e+03, 3.80000000e+03, 3.90000000e+03,\n", | |
" 4.00000000e+03, 4.10000000e+03, 4.20000000e+03, 4.30000000e+03,\n", | |
" 4.40000000e+03, 4.50000000e+03, 4.60000000e+03, 4.70000000e+03,\n", | |
" 4.80000000e+03, 4.90000000e+03, 5.00000000e+03, 5.10000000e+03,\n", | |
" 5.20000000e+03, 5.30000000e+03, 5.40000000e+03, 5.50000000e+03,\n", | |
" 5.60000000e+03, 5.70000000e+03, 5.80000000e+03, 5.90000000e+03,\n", | |
" 6.00000000e+03, 6.10000000e+03, 6.20000000e+03, 6.30000000e+03,\n", | |
" 6.40000000e+03, 6.50000000e+03, 6.60000000e+03, 6.70000000e+03,\n", | |
" 6.80000000e+03, 6.90000000e+03, 7.00000000e+03, 7.10000000e+03,\n", | |
" 7.20000000e+03, 7.30000000e+03, 7.40000000e+03, 7.50000000e+03,\n", | |
" 7.60000000e+03, 7.70000000e+03, 7.80000000e+03, 7.90000000e+03,\n", | |
" 8.00000000e+03, 8.10000000e+03, 8.20000000e+03, 8.30000000e+03,\n", | |
" 8.40000000e+03, 8.50000000e+03, 8.60000000e+03, 8.70000000e+03,\n", | |
" 8.80000000e+03, 8.90000000e+03, 9.00000000e+03, 9.10000000e+03,\n", | |
" 9.20000000e+03, 9.30000000e+03, 9.40000000e+03, 9.50000000e+03,\n", | |
" 9.60000000e+03, 9.70000000e+03, 9.80000000e+03, 9.90000000e+03],\n", | |
" [2.12302637e+00, 9.99525905e-01, 1.01709914e+00, 1.01476967e+00,\n", | |
" 1.09447312e+00, 9.86405730e-01, 1.11957288e+00, 1.00677228e+00,\n", | |
" 9.03819323e-01, 1.02800286e+00, 1.01082420e+00, 1.03712881e+00,\n", | |
" 9.92117643e-01, 1.00238907e+00, 9.76334572e-01, 1.07137012e+00,\n", | |
" 1.02780819e+00, 1.03860843e+00, 9.26625848e-01, 9.59651232e-01,\n", | |
" 9.90324736e-01, 9.52687860e-01, 1.07620740e+00, 9.84351337e-01,\n", | |
" 9.66414928e-01, 1.08449459e+00, 9.76938486e-01, 1.02538371e+00,\n", | |
" 9.75432575e-01, 9.47643876e-01, 1.04604673e+00, 9.53336358e-01,\n", | |
" 9.62800860e-01, 9.71372604e-01, 9.97625828e-01, 9.56852436e-01,\n", | |
" 1.02912521e+00, 9.40208673e-01, 9.35671449e-01, 9.36323881e-01,\n", | |
" 1.01344752e+00, 1.11106837e+00, 9.47826743e-01, 8.91282439e-01,\n", | |
" 9.75206256e-01, 9.97945905e-01, 9.35161948e-01, 1.08165908e+00,\n", | |
" 1.06789613e+00, 9.81801093e-01, 1.06721389e+00, 9.75085139e-01,\n", | |
" 9.88278747e-01, 1.03863251e+00, 9.58567619e-01, 1.00275397e+00,\n", | |
" 9.61375415e-01, 1.05719841e+00, 9.71031070e-01, 9.07814860e-01,\n", | |
" 1.00135756e+00, 8.77296329e-01, 9.14272428e-01, 9.36788321e-01,\n", | |
" 9.64334786e-01, 9.32047009e-01, 1.00033331e+00, 1.03161788e+00,\n", | |
" 1.02169251e+00, 9.92682040e-01, 1.06889462e+00, 9.82906163e-01,\n", | |
" 1.05177510e+00, 1.02857947e+00, 9.89680469e-01, 9.86090541e-01,\n", | |
" 1.11118913e+00, 9.53116179e-01, 9.97361183e-01, 1.05149746e+00,\n", | |
" 9.85543489e-01, 9.98808086e-01, 9.10445154e-01, 1.03497601e+00,\n", | |
" 9.93169904e-01, 9.60192680e-01, 1.01141238e+00, 1.00759196e+00,\n", | |
" 1.00983500e+00, 9.91928697e-01, 1.07226229e+00, 9.25239801e-01,\n", | |
" 1.01844692e+00, 1.06501722e+00, 9.87044334e-01, 9.84877408e-01,\n", | |
" 1.09594595e+00, 1.01509380e+00, 9.62336659e-01, 9.90623474e-01]])" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 8 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"def plot_reconstruction(xs, x_hats):\n", | |
"    \"\"\"Overlay true curves and their reconstructions on a 3x3 grid of panels.\n", | |
"\n", | |
"    Args:\n", | |
"        xs: batch of true curves; the first 9 entries are plotted.\n", | |
"        x_hats: reconstructions aligned index-by-index with xs.\n", | |
"    \"\"\"\n", | |
"    fig, axs = plt.subplots(3, 3, figsize=(5, 5))\n", | |
"\n", | |
"    for i, ax in enumerate(axs.flatten()):\n", | |
"        ax.plot(xs[i], label='True', lw=1, alpha=1, color='black')\n", | |
"        ax.plot(x_hats[i], label='Reconstructed', lw=0.5, alpha=0.3, color='tab:red')\n", | |
"        ax.axis(\"off\")\n", | |
"\n", | |
"    # The per-line labels are otherwise never shown: add one figure-level legend.\n", | |
"    handles, labels = axs.flat[0].get_legend_handles_labels()\n", | |
"    fig.legend(handles, labels, loc=\"lower center\", ncol=2, frameon=False)\n", | |
"    fig.suptitle(\"Reconstructions\")\n", | |
"    # remove all space between subplots\n", | |
"    fig.subplots_adjust(wspace=0, hspace=0)\n", | |
"\n", | |
"\n", | |
"valid_databatch = next(valid_ds)\n", | |
"model_outputs: VAEOutput = model.apply(params, next(rng_seq), valid_databatch)\n", | |
"plot_reconstruction(valid_databatch, model_outputs.output)" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 472 | |
}, | |
"id": "6IqddkS9hWvN", | |
"outputId": "63cfeee4-66fd-435f-ac6d-a1307f2676e4" | |
}, | |
"execution_count": 22, | |
"outputs": [ | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": [ | |
"<Figure size 500x500 with 9 Axes>" | |
], | |
"image/png": "\n" | |
}, | |
"metadata": {} | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [], | |
"metadata": { | |
"id": "p8_i2UecjJAB" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
} | |
] | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment