Created
November 14, 2024 10:44
-
-
Save fehiepsi/b7def6a77bf9ca150cf2f17f2ba1a2b5 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from functools import partial\n", | |
"import matplotlib.pyplot as plt\n", | |
"import numpy as np\n", | |
"import jax\n", | |
"import jax.numpy as jnp\n", | |
"import jax.random as random\n", | |
"import flax\n", | |
"import flax.linen as nn\n", | |
"import optax\n", | |
"import numpyro\n", | |
"import numpyro.distributions as dist\n", | |
"from numpyro.ops.indexing import Vindex\n", | |
"import coix" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"numpyro.set_platform(\"cpu\")\n", | |
"coix.set_backend(\"coix.numpyro\")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"class LSTM_MDN(nn.Module):\n", | |
"  \"\"\"LSTM cell whose output feeds three dense heads (mixture parameters).\n", | |
"\n", | |
"  Attributes:\n", | |
"    n_mixture_components: output size of each of the three dense heads.\n", | |
"    n_features: `features` argument given to the LSTM cell.\n", | |
"  \"\"\"\n", | |
"\n", | |
"  n_mixture_components: int\n", | |
"  n_features: int\n", | |
"\n", | |
"  @nn.compact\n", | |
"  def __call__(self, z_prev, x_curr, carry):\n", | |
"    \"\"\"Advances the cell by one step.\n", | |
"\n", | |
"    Args:\n", | |
"      z_prev: value stacked first into the cell input.\n", | |
"      x_curr: value stacked second into the cell input.\n", | |
"      carry: LSTM carry handed to the cell.\n", | |
"\n", | |
"    Returns:\n", | |
"      A tuple `(mu, sigma, pi, carry)`: the first head's output, the\n", | |
"      exponential of the second head's output, the softmax of the third\n", | |
"      head's output, and the carry returned by the cell.\n", | |
"    \"\"\"\n", | |
"    cell = nn.LSTMCell(name=\"lstm_cell\", features=self.n_features)\n", | |
"    cell_input = jnp.stack([z_prev, x_curr], axis=-1)\n", | |
"    new_carry, hidden = cell(carry, cell_input)\n", | |
"    # Heads are created in the order mu, log-sigma, pi; keep this order so the\n", | |
"    # auto-generated names of the unnamed Dense layers do not change.\n", | |
"    n_out = self.n_mixture_components\n", | |
"    mu_t = nn.Dense(n_out)(hidden)\n", | |
"    log_sigma_t = nn.Dense(n_out)(hidden)\n", | |
"    pi_logits_t = nn.Dense(n_out)(hidden)\n", | |
"    return mu_t, jnp.exp(log_sigma_t), nn.softmax(pi_logits_t), new_carry" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"lstm_mdn = LSTM_MDN(n_mixture_components=3, n_features=50)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 18, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def ssm_proposal(proposal, t, inputs):\n", | |
"  \"\"\"Proposal program for the latent site \"z\" at step `t`.\n", | |
"\n", | |
"  When `t` is the Python integer 0, \"z\" is drawn from `Normal(0, 5)`.\n", | |
"  Otherwise `proposal` is vmapped over axis 0 of the previous latent values\n", | |
"  `inputs[\"zs\"][..., t - 1]` and of `inputs[\"carry\"]` (the observation\n", | |
"  `inputs[\"xs\"][t]` is shared), the carry it returns is written back to\n", | |
"  `inputs[\"carry\"]`, and \"z\" is drawn from the resulting mixture of normals.\n", | |
"\n", | |
"  Returns:\n", | |
"    A 1-tuple containing the `inputs` dict.\n", | |
"  \"\"\"\n", | |
"  is_initial_step = isinstance(t, int) and (t == 0)\n", | |
"  if is_initial_step:\n", | |
"    z_dist = dist.Normal(0, 5)\n", | |
"  else:\n", | |
"    batched_proposal = jax.vmap(proposal, in_axes=(0, None, 0))\n", | |
"    z_prev = inputs[\"zs\"][..., t - 1]\n", | |
"    mu_t, sigma_t, pi_t, next_carry = batched_proposal(\n", | |
"        z_prev, inputs[\"xs\"][t], inputs[\"carry\"])\n", | |
"    inputs[\"carry\"] = next_carry\n", | |
"    z_dist = dist.MixtureSameFamily(\n", | |
"        dist.Categorical(pi_t), dist.Normal(mu_t, sigma_t))\n", | |
"  # The sampled value is not used here; the call registers the \"z\" site.\n", | |
"  numpyro.sample(\"z\", z_dist)\n", | |
"  return (inputs,)\n", | |
"\n", | |
"def f(z, t):\n", | |
"  \"\"\"Returns z / 2 + 25 z / (1 + z^2) + 8 cos(1.2 t).\"\"\"\n", | |
"  linear_term = z / 2\n", | |
"  rational_term = 25 * z / (1 + z ** 2)\n", | |
"  cosine_term = 8 * jnp.cos(1.2 * t)\n", | |
"  return linear_term + rational_term + cosine_term\n", | |
"\n", | |
"def g(z):\n", | |
"  \"\"\"Returns the square of `z` divided by 20.\"\"\"\n", | |
"  z_squared = z ** 2\n", | |
"  return z_squared / 20\n", | |
"\n", | |
"def ssm_target(t, inputs, simulate=False):\n", | |
"  \"\"\"Target program for step `t`: samples \"z\", then \"x\", and records both.\n", | |
"\n", | |
"  \"z\" is Normal with location 0 and scale 5 when `t == 0`; otherwise its\n", | |
"  location is `f` applied to `inputs[\"zs\"][..., t - 1]` and its scale is\n", | |
"  sqrt(10). \"x\" is Normal(g(z), 1); it is observed at `inputs[\"xs\"][t]`\n", | |
"  unless `simulate` is true, in which case it is sampled.\n", | |
"\n", | |
"  Returns:\n", | |
"    A 1-tuple containing `inputs`, with position `t` of \"zs\" (last axis) and\n", | |
"    of \"xs\" (first axis) set to the values of \"z\" and \"x\".\n", | |
"  \"\"\"\n", | |
"  is_first = t == 0\n", | |
"  z_prev = inputs[\"zs\"][..., t - 1]\n", | |
"  z_loc = jnp.where(is_first, 0, f(z_prev, t))\n", | |
"  z_scale = jnp.where(is_first, 5, jnp.sqrt(10))\n", | |
"  z_t = numpyro.sample(\"z\", dist.Normal(z_loc, z_scale))\n", | |
"  if simulate:\n", | |
"    x_obs = None\n", | |
"  else:\n", | |
"    x_obs = inputs[\"xs\"][t]\n", | |
"  x_t = numpyro.sample(\"x\", dist.Normal(g(z_t), 1), obs=x_obs)\n", | |
"  inputs[\"zs\"] = inputs[\"zs\"].at[..., t].set(z_t)\n", | |
"  inputs[\"xs\"] = inputs[\"xs\"].at[t].set(x_t)\n", | |
"  return (inputs,)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 19, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def nasmc(targets, proposals, *, num_targets, unroll=False):\n", | |
"  \"\"\"Chains `coix.propose` steps over targets 0 .. num_targets - 1.\n", | |
"\n", | |
"  Step 0 proposes `targets(0)` from the detached `proposals(0)`. Every later\n", | |
"  step `i` proposes `targets(i)` from the detached `proposals(i)` composed\n", | |
"  with a resampled version of the program built so far. All steps use\n", | |
"  `coix.loss.rws_loss` and `chain=True`.\n", | |
"\n", | |
"  Args:\n", | |
"    targets: callable mapping a step index to a target program.\n", | |
"    proposals: callable mapping a step index to a proposal program.\n", | |
"    num_targets: number of steps (keyword-only).\n", | |
"    unroll: if true, steps 1.. are chained with a Python loop; otherwise\n", | |
"      they are chained with `coix.fori_loop`.\n", | |
"\n", | |
"  Returns:\n", | |
"    The chained program.\n", | |
"  \"\"\"\n", | |
"  rws = coix.loss.rws_loss\n", | |
"\n", | |
"  def body_fun(i, q):\n", | |
"    target_i = targets(i)\n", | |
"    proposal_i = coix.compose(coix.detach(proposals(i)), coix.resample(q))\n", | |
"    return coix.propose(target_i, proposal_i, loss_fn=rws, chain=True)\n", | |
"\n", | |
"  program = coix.propose(\n", | |
"      targets(0), coix.detach(proposals(0)), loss_fn=rws, chain=True)\n", | |
"  if not unroll:\n", | |
"    return coix.fori_loop(1, num_targets, body_fun, program)\n", | |
"  for step in range(1, num_targets):\n", | |
"    program = body_fun(step, program)\n", | |
"  return program" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 20, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def make_ssm(params, num_particles=10, T_max=1000):\n", | |
"  \"\"\"Builds the `nasmc` program for the state-space model.\n", | |
"\n", | |
"  `lstm_mdn` is bound to `params` and used as the proposal network. Each\n", | |
"  per-step target and proposal is wrapped in a fresh \"particle\" plate of size\n", | |
"  `num_particles` at dim -1. The program spans `T_max` steps and is built\n", | |
"  with `unroll=False`.\n", | |
"  \"\"\"\n", | |
"  network = coix.util.BindModule(lstm_mdn, params)\n", | |
"\n", | |
"  def in_particle_plate(fn):\n", | |
"    # A new plate object is created for every wrapped program.\n", | |
"    return numpyro.plate(\"particle\", num_particles, dim=-1)(fn)\n", | |
"\n", | |
"  def targets(t):\n", | |
"    return in_particle_plate(partial(ssm_target, t))\n", | |
"\n", | |
"  def proposals(t):\n", | |
"    return in_particle_plate(partial(ssm_proposal, network, t))\n", | |
"\n", | |
"  return nasmc(targets, proposals, num_targets=T_max, unroll=False)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 21, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def simulate_ssm(T_max=1000):\n", | |
"  \"\"\"Composes `ssm_target` in simulate mode for steps 0 .. T_max - 1.\n", | |
"\n", | |
"  Starts from the step-0 program and, for each later step `t`, wraps the\n", | |
"  program so far with `coix.compose(step_t, program)`.\n", | |
"  \"\"\"\n", | |
"  def simulated_step(t):\n", | |
"    return partial(ssm_target, t, simulate=True)\n", | |
"\n", | |
"  program = simulated_step(0)\n", | |
"  for t in range(1, T_max):\n", | |
"    program = coix.compose(simulated_step(t), program)\n", | |
"  return program" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 22, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def loss_fn(params, key, init_carry, num_particles=10, T_max=1000):\n", | |
"  \"\"\"Simulates one series and evaluates the NASMC program on it.\n", | |
"\n", | |
"  `key` is split in two: one half seeds `simulate_ssm`, the other seeds the\n", | |
"  evaluation of the program from `make_ssm`. The program receives zeros of\n", | |
"  shape `(num_particles, T_max)` as \"zs\", the simulated \"xs\", and `init_carry`\n", | |
"  repeated `num_particles` times along a new leading axis as \"carry\".\n", | |
"\n", | |
"  Returns:\n", | |
"    `(metrics[\"loss\"], metrics)`.\n", | |
"  \"\"\"\n", | |
"  data_key, rng_key = random.split(key)\n", | |
"  simulator = numpyro.handlers.seed(simulate_ssm(T_max=T_max), rng_seed=data_key)\n", | |
"  blank_series = {\"xs\": jnp.zeros(T_max), \"zs\": jnp.zeros(T_max)}\n", | |
"  data = simulator(inputs=blank_series)[0]\n", | |
"  assert data[\"xs\"].shape[0] == data[\"zs\"].shape[0] == T_max\n", | |
"\n", | |
"  def tile_over_particles(leaf):\n", | |
"    return jnp.repeat(leaf[None], num_particles, axis=0)\n", | |
"\n", | |
"  inputs = {\n", | |
"      \"zs\": jnp.zeros((num_particles, T_max)),\n", | |
"      \"xs\": data[\"xs\"],\n", | |
"      \"carry\": jax.tree.map(tile_over_particles, init_carry),\n", | |
"  }\n", | |
"  program = make_ssm(params, num_particles=num_particles, T_max=T_max)\n", | |
"  _out, _trace, metrics = coix.traced_evaluate(program, seed=rng_key)(inputs)\n", | |
"  return metrics[\"loss\"], metrics\n", | |
"\n", | |
"def batch_loss_fn(params, key, init_carry, batch_size=10, num_particles=10, T_max=1000):\n", | |
"  \"\"\"Averages `loss_fn` over `batch_size` independent keys.\n", | |
"\n", | |
"  With `batch_size == 1`, `loss_fn` is called directly with `key`. Otherwise\n", | |
"  `key` is split into `batch_size` keys, `loss_fn` is vmapped over them\n", | |
"  (`params` and `init_carry` are shared), and the mean loss is returned\n", | |
"  together with `jnp.mean` applied to every leaf of the metrics.\n", | |
"  \"\"\"\n", | |
"  single_loss = partial(loss_fn, num_particles=num_particles, T_max=T_max)\n", | |
"  if batch_size == 1:\n", | |
"    return single_loss(params, key, init_carry)\n", | |
"  batched_loss = jax.vmap(single_loss, in_axes=(None, 0, None))\n", | |
"  keys = jax.random.split(key, batch_size)\n", | |
"  losses, metrics = batched_loss(params, keys, init_carry)\n", | |
"  return losses.mean(), jax.tree.map(jnp.mean, metrics)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 11, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Compiling the first train step...\n", | |
"Time to compile a train step: 8.777919054031372\n", | |
"=====\n", | |
"Step 50 | ess 2.4568 | log_Z -187.3532 | log_density -17.5152 | log_weight -196.0697 | loss 385.9447 | squared_grad_norm 460799.5938\n", | |
"Step 100 | ess 14.6745 | log_Z -109.4901 | log_density -11.9905 | log_weight -116.3763 | loss 312.4891 | squared_grad_norm 100631.6953\n", | |
"Step 150 | ess 14.1414 | log_Z -107.8167 | log_density -13.8655 | log_weight -116.1320 | loss 287.2689 | squared_grad_norm 62360.5312\n", | |
"Step 200 | ess 7.5109 | log_Z -148.9827 | log_density -22.9591 | log_weight -165.1962 | loss 264.5852 | squared_grad_norm 34423.0000\n", | |
"Step 250 | ess 10.4280 | log_Z -117.4725 | log_density -16.2886 | log_weight -127.2894 | loss 251.7576 | squared_grad_norm 9126.8506\n", | |
"Step 300 | ess 13.8720 | log_Z -95.6815 | log_density -13.7289 | log_weight -104.8602 | loss 220.5114 | squared_grad_norm 722.7039\n", | |
"Step 350 | ess 14.2133 | log_Z -82.9974 | log_density -14.0265 | log_weight -91.2859 | loss 238.8261 | squared_grad_norm 823.7309\n", | |
"Step 400 | ess 11.1576 | log_Z -107.7981 | log_density -19.8433 | log_weight -121.2751 | loss 244.7899 | squared_grad_norm 797.5330\n", | |
"Step 450 | ess 20.9421 | log_Z -108.1516 | log_density -16.1086 | log_weight -117.6074 | loss 243.2255 | squared_grad_norm 785.7372\n", | |
"Step 500 | ess 13.3322 | log_Z -84.9444 | log_density -16.9784 | log_weight -96.3718 | loss 228.7092 | squared_grad_norm 429.4864\n", | |
"Step 550 | ess 11.4026 | log_Z -103.4894 | log_density -13.4355 | log_weight -111.0630 | loss 226.2614 | squared_grad_norm 275.1570\n", | |
"Step 600 | ess 17.1433 | log_Z -101.0487 | log_density -18.6916 | log_weight -113.2903 | loss 241.0913 | squared_grad_norm 408.4709\n", | |
"Step 650 | ess 15.1463 | log_Z -115.8731 | log_density -19.5854 | log_weight -128.9296 | loss 220.9007 | squared_grad_norm 333.0921\n", | |
"Step 700 | ess 14.2602 | log_Z -104.8662 | log_density -15.6090 | log_weight -114.3933 | loss 223.7206 | squared_grad_norm 320.8732\n", | |
"Step 750 | ess 11.3065 | log_Z -118.1048 | log_density -15.2429 | log_weight -127.5417 | loss 225.1173 | squared_grad_norm 580.1074\n", | |
"Step 800 | ess 13.7209 | log_Z -83.1169 | log_density -13.3460 | log_weight -90.7723 | loss 209.2690 | squared_grad_norm 357.6276\n", | |
"Step 850 | ess 10.1412 | log_Z -112.6723 | log_density -12.7744 | log_weight -119.0864 | loss 198.0578 | squared_grad_norm 178.7914\n", | |
"Step 900 | ess 11.0142 | log_Z -113.9407 | log_density -13.3378 | log_weight -121.3964 | loss 204.0146 | squared_grad_norm 201.4111\n", | |
"Step 950 | ess 14.5444 | log_Z -92.2560 | log_density -12.1194 | log_weight -98.4538 | loss 187.6152 | squared_grad_norm 1153.7579\n", | |
"Step 1000 | ess 17.6082 | log_Z -97.2455 | log_density -9.8393 | log_weight -101.0580 | loss 187.0929 | squared_grad_norm 572.7567\n" | |
] | |
} | |
], | |
"source": [ | |
"# Training configuration.\n", | |
"num_particles = 100\n", | |
"num_steps = 1000\n", | |
"batch_size = 10\n", | |
"T_max = 100\n", | |
"# Size the initial carry from the network's own `n_features` rather than\n", | |
"# repeating the literal, so the carry cannot drift from the model definition.\n", | |
"# The shape argument stands for the cell input: LSTM_MDN stacks z_prev and\n", | |
"# x_curr on the last axis, i.e. an input of shape (2,). Only the leading\n", | |
"# dimensions of that shape affect the carry shape.\n", | |
"init_carry = nn.LSTMCell(features=lstm_mdn.n_features).initialize_carry(\n", | |
"    random.key(1), (2,))\n", | |
"init_params = lstm_mdn.init(random.key(2), z_prev=0., x_curr=0., carry=init_carry)\n", | |
"lstm_mdn_params, _ = coix.util.train(\n", | |
"    partial(\n", | |
"        batch_loss_fn,\n", | |
"        init_carry=init_carry,\n", | |
"        num_particles=num_particles,\n", | |
"        T_max=T_max,\n", | |
"        batch_size=batch_size,\n", | |
"    ),\n", | |
"    init_params,\n", | |
"    optax.adam(3e-4),\n", | |
"    num_steps=num_steps,\n", | |
"    jit_compile=True,\n", | |
")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 42, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def eval_program(params, key, init_carry, num_particles=10, T_max=1000):\n", | |
"  \"\"\"Simulates one series and runs the NASMC program on it for inspection.\n", | |
"\n", | |
"  `key` is split in two: one half seeds `simulate_ssm`, the other seeds the\n", | |
"  evaluation of the program from `make_ssm`. The program inputs are built the\n", | |
"  same way as in `loss_fn`.\n", | |
"\n", | |
"  Returns:\n", | |
"    `(true_zs, pred_zs, metrics)` where `true_zs` is the simulated \"zs\" and\n", | |
"    `pred_zs` is the row of the program's output \"zs\" selected by the argmax\n", | |
"    of `metrics['log_weight']`.\n", | |
"  \"\"\"\n", | |
"  data_key, rng_key = random.split(key)\n", | |
"  simulator = numpyro.handlers.seed(simulate_ssm(T_max=T_max), rng_seed=data_key)\n", | |
"  blank_series = {\"xs\": jnp.zeros(T_max), \"zs\": jnp.zeros(T_max)}\n", | |
"  data = simulator(inputs=blank_series)[0]\n", | |
"  assert data[\"xs\"].shape[0] == data[\"zs\"].shape[0] == T_max\n", | |
"\n", | |
"  def tile_over_particles(leaf):\n", | |
"    return jnp.repeat(leaf[None], num_particles, axis=0)\n", | |
"\n", | |
"  inputs = {\n", | |
"      \"zs\": jnp.zeros((num_particles, T_max)),\n", | |
"      \"xs\": data[\"xs\"],\n", | |
"      \"carry\": jax.tree.map(tile_over_particles, init_carry),\n", | |
"  }\n", | |
"  program = make_ssm(params, num_particles=num_particles, T_max=T_max)\n", | |
"  out, _trace, metrics = coix.traced_evaluate(program, seed=rng_key)(inputs)\n", | |
"  best_particle = metrics['log_weight'].argmax()\n", | |
"  return data[\"zs\"], out[0][\"zs\"][best_particle], metrics" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 43, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"true_zs, pred_zs, metrics = eval_program(\n", | |
" lstm_mdn_params, random.key(3), init_carry, num_particles=100, T_max=1000)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 44, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"image/png": "", | |
"text/plain": [ | |
"<Figure size 640x480 with 1 Axes>" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"# Compare the last 100 steps of the selected particle with the simulated truth.\n", | |
"fig, ax = plt.subplots()\n", | |
"ax.plot(pred_zs[-100:], color='r', label='pred_zs (max log_weight particle)')\n", | |
"ax.plot(true_zs[-100:], label='true_zs (simulated)')\n", | |
"ax.set(xlabel='step (last 100)', ylabel='z', title='Predicted vs. true latent z')\n", | |
"ax.legend();" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 45, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"Array(11.857068, dtype=float32)" | |
] | |
}, | |
"execution_count": 45, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"jnp.sqrt(((pred_zs - true_zs) ** 2).mean())" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 46, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"Array(73.36286, dtype=float32)" | |
] | |
}, | |
"execution_count": 46, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"metrics[\"ess\"]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 47, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"Array(-1862.2279, dtype=float32)" | |
] | |
}, | |
"execution_count": 47, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"metrics[\"log_Z\"]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3 (ipykernel)", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.12.6" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 4 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment