{
"cells": [
{
"cell_type": "markdown",
"id": "c8d4b5b3",
"metadata": {},
"source": [
"# Self-organizing maps in JAX\n",
"\n",
"In this notebook, we'll develop implementations of online and batch self-organizing map training in JAX, refining each as we go to get better performance. We'll start with the easiest option: simply using JAX as a drop-in replacement for numpy (more or less like we did with cuPy).\n", | |
"\n", | |
"## Accelerating NumPy functions with JAX" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "b6a2e5d6", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import jax\n", | |
"import jax.numpy as jnp\n", | |
"import numpy as np" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "4daae423", | |
"metadata": {}, | |
"source": [ | |
"Recall that we're initializing our map with random vectors. The result of this function is a matrix with a row for every element in a self-organizing map; each row contains uniformly-sampled random numbers between 0 and 1.\n", | |
"\n", | |
"Because JAX uses a purely functional approach to random number generation, we'll need to rewrite this code from the numpy implementation -- instead of using a stateful generator like numpy's `Generator` or `RandomState`, we'll create a `PRNGKey` object and pass that to `jax.random.uniform`. (For this example, we're not doing anything with the key — for a real application, we'd want to _split_ it so we could get the next number in the seeded sequence.)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "14584d07", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def init_som(xdim, ydim, fdim, seed):\n", | |
" key = jax.random.PRNGKey(seed)\n", | |
" return jnp.array(jax.random.uniform(key, shape=(xdim * ydim * fdim,)).reshape(xdim * ydim, fdim))" | |
] | |
}, | |
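{
"cell_type": "markdown",
"id": "key-split-md",
"metadata": {},
"source": [
"As an aside — this cell isn't part of the SOM code, just a minimal sketch of the splitting idiom mentioned above — reusing a key reproduces the same draws, while `jax.random.split` gives us fresh keys for new draws:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "key-split-demo",
"metadata": {},
"outputs": [],
"source": [
"key = jax.random.PRNGKey(42)\n",
"\n",
"# the same key always produces the same values...\n",
"same1 = jax.random.uniform(key, shape=(2,))\n",
"same2 = jax.random.uniform(key, shape=(2,))\n",
"\n",
"# ...so we split to get a fresh subkey for each new draw\n",
"key, subkey = jax.random.split(key)\n",
"fresh = jax.random.uniform(subkey, shape=(2,))\n",
"\n",
"same1, same2, fresh"
]
},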
{
"cell_type": "markdown",
"id": "590b8f4b",
"metadata": {},
"source": [
"We can see that JAX is not returning a numpy array:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a64bac02",
"metadata": {},
"outputs": [],
"source": [
"x_size = 192\n",
"y_size = 108\n",
"feature_dims = 3\n",
"\n",
"random_map = init_som(x_size, y_size, feature_dims, 42)\n",
"type(random_map)"
]
},
{
"cell_type": "markdown",
"id": "43ca779b",
"metadata": {},
"source": [
"...and we should be able to see that this array is stored in GPU memory (if we're actually running on a GPU)." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "f6b015ce", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"random_map.device()" | |
] | |
}, | |
{
"cell_type": "markdown",
"id": "c552417c",
"metadata": {},
"source": [
"As before, you can visualize the result if you want — unlike CuPy, JAX will transfer arrays from device memory back to the host automatically when a plotting library needs them."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f95d1235",
"metadata": {},
"outputs": [],
"source": [
"import plotly.express as px\n",
"import plotly.io as pio\n",
"pio.renderers.default = 'notebook'\n",
"\n",
"px.imshow(random_map.reshape(x_size, y_size, feature_dims).swapaxes(0, 1))"
]
},
{
"cell_type": "markdown",
"id": "dbfdb8b6",
"metadata": {},
"source": [
"Our neighborhood function is very similar to the numpy implementation; the only difference is that we need to change `np` to `jnp`."
]
},
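{
"cell_type": "markdown",
"id": "hood-formula-md",
"metadata": {},
"source": [
"(For reference, here's a restatement of what the function below computes: the neighborhood is a separable Gaussian centered on the best matching unit,\n",
"\n",
"$$h(x, y) = \\exp\\left(-\\frac{(x - c_x)^2}{\\sigma_x^2}\\right) \\cdot \\exp\\left(-\\frac{(y - c_y)^2}{\\sigma_y^2}\\right),$$\n",
"\n",
"which is why we can compute it as an outer product of two one-dimensional Gaussians.)"
]
},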
{
"cell_type": "code",
"execution_count": null,
"id": "ddc781bd",
"metadata": {},
"outputs": [],
"source": [
"def neighborhood(range_x, range_y, center_x, center_y, x_sigma, y_sigma):\n",
"    x_distance = jnp.abs(center_x - range_x)\n",
"    x_neighborhood = jnp.exp(-jnp.square(x_distance) / jnp.square(x_sigma))\n",
"\n",
"    y_distance = jnp.abs(center_y - range_y)\n",
"    y_neighborhood = jnp.exp(-jnp.square(y_distance) / jnp.square(y_sigma))\n",
"\n",
"    return jnp.outer(x_neighborhood, y_neighborhood)"
]
},
{
"cell_type": "markdown",
"id": "109383e2",
"metadata": {},
"source": [
"Plotting results is a good way to make sure that they look the way we expect them to."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ec73094f",
"metadata": {},
"outputs": [],
"source": [
"center_x = 12\n",
"center_y = 48\n",
"sigma_x = 96\n",
"sigma_y = 54\n",
"\n",
"px.imshow(neighborhood(np.arange(x_size), np.arange(y_size), center_x, center_y, sigma_x, sigma_y).T)"
]
},
{
"cell_type": "markdown",
"id": "7c55b268",
"metadata": {},
"source": [
"We're now ready to see the basic online (i.e., one sample at a time) training algorithm. Most of it is unchanged from the numpy implementation, with a few key differences:\n",
"\n",
"1. The first difference relates to how we shuffle the example array. Because we aren't using a stateful random number generator, we'll need to split the random state key into two parts (one representing the key for the very next generation and one representing the key for the rest of the stream). We'll declare a little helper function that splits the key, shuffles the array, and returns both the key and the shuffled array.\n",
"2. The second difference relates to how JAX handles arrays. In JAX, arrays offer an _immutable_ interface: instead of changing an array directly, JAX's API lets you make a copy of the array with a change. (In practice, this does not always mean the array is actually copied!) This impacts our code because the numpy version used some functions with output parameters, which indicate where to write the return value (rather than merely returning a new array). So, instead of `np.add(a, b, a)`, we'd do `a = jnp.add(a, b)`; there's a short demo of this immutable style right after this cell."
]
},
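{
"cell_type": "markdown",
"id": "immutable-demo-md",
"metadata": {},
"source": [
"Here's a minimal sketch of that immutable style (not part of the SOM code, just an illustration): in-place assignment fails, and `Array.at[...].set(...)` returns a modified copy instead."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "immutable-demo",
"metadata": {},
"outputs": [],
"source": [
"a = jnp.zeros(3)\n",
"# a[0] = 1.0  # would raise a TypeError: JAX arrays are immutable\n",
"b = a.at[0].set(1.0)  # functional update: a is unchanged, b is a modified copy\n",
"a, b"
]
},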
{
"cell_type": "code",
"execution_count": null,
"id": "d2b1e9e3",
"metadata": {},
"outputs": [],
"source": [
"def shuffle(key, examples):\n",
"    key, nextkey = jax.random.split(key)\n",
"    # jax.random.shuffle is deprecated; permutation shuffles along axis 0\n",
"    examples = jax.random.permutation(nextkey, examples)\n",
"    return key, examples\n",
"\n",
"def train_som_online(examples, xdim, ydim, x_sigma, y_sigma, max_iter, seed=42, frame_callback=None):\n",
"    t = 0\n",
"\n",
"    exs = examples.copy()\n",
"    fdim = exs.shape[-1]\n",
"\n",
"    x_sigmas = jnp.linspace(x_sigma, max(2, x_sigma * .05), max_iter)\n",
"    y_sigmas = jnp.linspace(y_sigma, max(2, y_sigma * .05), max_iter)\n",
"    alphas = jnp.geomspace(0.35, 0.01, max_iter)\n",
"\n",
"    range_x, range_y = jnp.arange(xdim), jnp.arange(ydim)\n",
"\n",
"    hood = None\n",
"    som = init_som(xdim, ydim, fdim, seed)\n",
"\n",
"    key = jax.random.PRNGKey(seed)\n",
"    while t < max_iter:\n",
"        key, exs = shuffle(key, exs)\n",
"        for ex in exs:\n",
"            t = t + 1\n",
"            if t == max_iter:\n",
"                break\n",
"\n",
"            # best matching unit (by Euclidean distance)\n",
"            bmu_idx = jnp.argmin(jnp.linalg.norm(ex - som, axis=1))\n",
"\n",
"            # row i of the map corresponds to (x, y) = (i // ydim, i % ydim)\n",
"            center_x = bmu_idx // ydim\n",
"            center_y = bmu_idx % ydim\n",
"\n",
"            hood = neighborhood(range_x, range_y, center_x, center_y, x_sigmas[t], y_sigmas[t]).reshape(-1, 1)\n",
"\n",
"            update = jnp.multiply((ex - som) * alphas[t], hood)\n",
"\n",
"            frame_callback and frame_callback(t - 1, ex, hood, som)\n",
"\n",
"            # NB: JAX methods don't have output parameters\n",
"            som = jnp.add(som, update)\n",
"\n",
"            som = jnp.clip(som, 0, 1)\n",
"\n",
"    frame_callback and frame_callback(t, ex, hood, som)\n",
"    return som"
]
},
{
"cell_type": "markdown",
"id": "fe777ff3",
"metadata": {},
"source": [
"As before, we'll use the history callback class to track our progress."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "552a3143",
"metadata": {},
"outputs": [],
"source": [
"class HistoryCallback(object):\n",
"\n",
"    def __init__(self, xdim, ydim, fdim, epoch_pred):\n",
"        self.frames = dict()\n",
"        self.meta = dict()\n",
"        self.xdim = xdim\n",
"        self.ydim = ydim\n",
"        self.fdim = fdim\n",
"        self.epoch_pred = epoch_pred\n",
"\n",
"    def __call__(self, epoch, ex, hood, som, **meta):\n",
"        if self.epoch_pred(epoch):\n",
"            self.frames[epoch] = (ex, hood, som)\n",
"            # **meta is always a dict (possibly empty), so test for contents\n",
"            if meta:\n",
"                self.meta[epoch] = meta"
]
},
{
"cell_type": "markdown",
"id": "2982e4e3",
"metadata": {},
"source": [
"Here we'll train a small map on random color data, storing one history snapshot for every 20 examples."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8a78da40",
"metadata": {},
"outputs": [],
"source": [
"%%time\n",
"\n",
"fc = HistoryCallback(240, 135, 3, lambda x: x % 20 == 0)\n",
"\n",
"examples = jax.random.uniform(jax.random.PRNGKey(42), shape=(1000, 3))\n",
"\n",
"color_som = \\\n",
"    train_som_online(examples,\n",
"                     240, 135,\n",
"                     120, 70,\n",
"                     50000, 42, fc)"
]
},
{
"cell_type": "markdown",
"id": "6b51b600",
"metadata": {},
"source": [
"Depending on your computer, this may have actually been slower than the numpy version! Let's try using JAX's _just-in-time_ compilation to improve our performance. We'll make just-in-time compiled versions of our `neighborhood` and `shuffle` functions (as well as of the inner part of the training loop). We'll also add a progress bar."
]
},
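{
"cell_type": "markdown",
"id": "jit-demo-md",
"metadata": {},
"source": [
"Before wiring JIT into the training loop, here's a minimal sketch of the pattern on a throwaway function (nothing below is SOM-specific): `jax.jit` compiles the function on its first call, and we time with `block_until_ready()` so that JAX's asynchronous dispatch doesn't make the call look faster than it is."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "jit-demo",
"metadata": {},
"outputs": [],
"source": [
"def scaled_sum(x):\n",
"    return jnp.sum(x * 2.0)\n",
"\n",
"jit_scaled_sum = jax.jit(scaled_sum)\n",
"\n",
"xs = jnp.ones((1000, 1000))\n",
"\n",
"# the first call traces and compiles the function...\n",
"jit_scaled_sum(xs).block_until_ready()\n",
"\n",
"# ...and later calls reuse the compiled code\n",
"%timeit jit_scaled_sum(xs).block_until_ready()"
]
},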
{
"cell_type": "code",
"execution_count": null,
"id": "2606210a",
"metadata": {},
"outputs": [],
"source": [
"import tqdm\n",
"\n",
"jit_neighborhood = jax.jit(neighborhood)\n",
"jit_shuffle = jax.jit(shuffle)\n",
"\n",
"@jax.jit\n",
"def som_step(ex, som, ydim, range_x, range_y, x_sigma, y_sigma, alpha):\n",
"    # best matching unit (by Euclidean distance)\n",
"    bmu_idx = jnp.argmin(jnp.linalg.norm(ex - som, axis=1))\n",
"\n",
"    # row i of the map corresponds to (x, y) = (i // ydim, i % ydim)\n",
"    center_x = bmu_idx // ydim\n",
"    center_y = bmu_idx % ydim\n",
"\n",
"    hood = jit_neighborhood(range_x, range_y, center_x, center_y, x_sigma, y_sigma).reshape(-1, 1)\n",
"\n",
"    update = jnp.multiply((ex - som) * alpha, hood)\n",
"\n",
"    return jnp.clip(jnp.add(som, update), 0, 1)\n",
"\n",
"def train_som_online2(examples, xdim, ydim, x_sigma, y_sigma, max_iter, seed=42, frame_callback=None):\n",
"    t = 0\n",
"\n",
"    exs = examples.copy()\n",
"    fdim = exs.shape[-1]\n",
"\n",
"    x_sigmas = jnp.linspace(x_sigma, max(2, x_sigma * .05), max_iter)\n",
"    y_sigmas = jnp.linspace(y_sigma, max(2, y_sigma * .05), max_iter)\n",
"    alphas = jnp.geomspace(0.35, 0.01, max_iter)\n",
"\n",
"    range_x, range_y = jnp.arange(xdim), jnp.arange(ydim)\n",
"\n",
"    # the jitted step doesn't expose the neighborhood, so history frames record None for it\n",
"    hood = None\n",
"    som = init_som(xdim, ydim, fdim, seed)\n",
"\n",
"    key = jax.random.PRNGKey(seed)\n",
"    with tqdm.tqdm(total=max_iter) as progress:\n",
"        while t < max_iter:\n",
"            key, exs = jit_shuffle(key, exs)\n",
"            for ex in exs:\n",
"                t = t + 1\n",
"                progress.update(1)\n",
"                if t == max_iter:\n",
"                    break\n",
"\n",
"                som = som_step(ex, som, ydim, range_x, range_y, x_sigmas[t], y_sigmas[t], alphas[t])\n",
"                frame_callback and frame_callback(t, ex, hood, som)\n",
"\n",
"    return som\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4a6b7dab",
"metadata": {},
"outputs": [],
"source": [
"%%time\n",
"\n",
"fc = HistoryCallback(240, 135, 3, lambda x: x % 20 == 0)\n",
"\n",
"examples = jax.random.uniform(jax.random.PRNGKey(42), shape=(1000, 3))\n",
"\n",
"color_som = \\\n",
"    train_som_online2(examples,\n",
"                      240, 135,\n",
"                      120, 70,\n",
"                      50000, 42, fc)"
]
},
{
"cell_type": "markdown",
"id": "796bb384",
"metadata": {},
"source": [
"Let's check our final map to make sure it looks somewhat reasonable."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "28fee4ab",
"metadata": {
"scrolled": false
},
"outputs": [],
"source": [
"px.imshow(color_som.reshape(240, 135, 3).swapaxes(0, 1)).show()"
]
},
{
"cell_type": "markdown",
"id": "c1451ce3",
"metadata": {},
"source": [
"✅ One challenging aspect of the online algorithm is its sensitivity to hyperparameter settings:\n",
"* Try running the code again with some different values for `x_sigma` and `y_sigma` and see how your results change!\n",
"* The `alphas` variable (which we didn't expose as a parameter) indicates how much of an effect each example has on the map. We've set it to `jnp.geomspace(0.35, 0.01, max_iter)`; try some different values and see if you get better or worse results!\n",
"\n",
"Let's now consider the batch variant of the algorithm. It can be much faster, can be implemented in parallel (or even on a cluster), and is less sensitive to hyperparameter settings. In order to exploit additional parallelism, we're going to use `jax.vmap` to calculate weight updates for each training example in parallel. This should result in a dramatic performance improvement."
]
},
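{
"cell_type": "markdown",
"id": "vmap-demo-md",
"metadata": {},
"source": [
"Before we use it in earnest, here's a minimal, SOM-free sketch of what `jax.vmap` does: it turns a function written for a single example into one that runs over a whole batch, without us writing any batching logic by hand."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "vmap-demo",
"metadata": {},
"outputs": [],
"source": [
"def dist_to_origin(v):\n",
"    return jnp.linalg.norm(v)\n",
"\n",
"# map dist_to_origin over the leading axis of a batch of vectors\n",
"batch_dist = jax.vmap(dist_to_origin, in_axes=0, out_axes=0)\n",
"\n",
"vs = jnp.arange(12.0).reshape(4, 3)\n",
"batch_dist(vs)  # shape (4,): one distance per row"
]
},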
{
"cell_type": "code",
"execution_count": null,
"id": "9e7d93e8",
"metadata": {},
"outputs": [],
"source": [
"from functools import partial\n",
"\n",
"@partial(jax.vmap, in_axes=(0, None, None, None, None, None, None), out_axes=0)\n",
"def batch_step(ex, som, range_x, range_y, ydim, x_sigma, y_sigma):\n",
"    # best matching unit (by Euclidean distance)\n",
"    bmu_idx = jnp.argmin(jnp.linalg.norm(som - ex, axis=1))\n",
"\n",
"    # row i of the map corresponds to (x, y) = (i // ydim, i % ydim)\n",
"    center_x = bmu_idx // ydim\n",
"    center_y = bmu_idx % ydim\n",
"\n",
"    hood = jit_neighborhood(range_x, range_y, center_x, center_y, x_sigma, y_sigma).reshape(-1, 1)\n",
"\n",
"    return ex * hood\n",
"\n",
"def train_som_batch(examples, xdim, ydim, x_sigma, y_sigma, epochs, min_sigma_frac=0.1, seed=42, frame_callback=None):\n",
"    t = 0\n",
"\n",
"    exs = examples.copy()\n",
"    fdim = exs.shape[-1]\n",
"\n",
"    x_sigmas = jnp.geomspace(x_sigma, max(2, xdim * min_sigma_frac), epochs)\n",
"    y_sigmas = jnp.geomspace(y_sigma, max(2, ydim * min_sigma_frac), epochs)\n",
"\n",
"    range_x, range_y = jnp.arange(xdim), jnp.arange(ydim)\n",
"\n",
"    hood = None\n",
"    som = init_som(xdim, ydim, fdim, seed)\n",
"\n",
"    for t in tqdm.trange(epochs):\n",
"        updates = batch_step(exs, som, range_x, range_y, ydim, x_sigmas[t], y_sigmas[t])\n",
"        frame_callback and frame_callback(t, None, hood, som)\n",
"\n",
"        # sum the weighted examples, then normalize each map row to unit length\n",
"        summed = jnp.sum(updates, axis=0)\n",
"        som = jnp.divide(summed, jnp.linalg.norm(summed, axis=1).reshape(-1, 1))\n",
"\n",
"    frame_callback and frame_callback(t, None, hood, som)\n",
"    return som"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "138e1ced",
"metadata": {},
"outputs": [],
"source": [
"%%time\n",
"\n",
"bfc = HistoryCallback(240, 135, 3, lambda x: True)\n",
"\n",
"examples = jax.random.uniform(jax.random.PRNGKey(42), shape=(1000, 3))\n",
"\n",
"color_som_batch = \\\n",
"    train_som_batch(examples,\n",
"                    240, 135,\n",
"                    120, 70,\n",
"                    50, min_sigma_frac=.2, seed=42, frame_callback=bfc)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "77e9a8ef",
"metadata": {},
"outputs": [],
"source": [
"px.imshow(color_som_batch.reshape(240, 135, 3).swapaxes(0, 1)).show()"
]
},
{
"cell_type": "markdown",
"id": "79a37bd4",
"metadata": {},
"source": [
"✅ What would you need to do to rewrite the SOM training to _not_ use `jax.vmap`? (You don't have to actually implement this unless you're interested in a puzzle!)\n",
"\n",
"✅ Modify `batch_step` to use an alternate distance metric. This will involve modifying the following line of code:\n",
"\n",
"```bmu_idx = jnp.argmin(jnp.linalg.norm(som - ex, axis=1))```\n",
"\n",
"so that you're taking the `argmin` (or `argmax`, if you're looking for similarity!) of a different function over each entry in the map and the current example. If you don't have a favorite distance or similarity measure, a common example is cosine similarity, which you can calculate for two vectors by dividing their dot product by the product of their magnitudes, like this:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "bc064b33",
"metadata": {},
"outputs": [],
"source": [
"example_som = np.random.random(size=(128, 3))\n",
"example_vec = np.random.random(size=(3,))\n",
"\n",
"jnp.divide(jnp.dot(example_som, example_vec), (jnp.linalg.norm(example_som, axis=1) * jnp.linalg.norm(example_vec)))"
]
},
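{
"cell_type": "markdown",
"id": "cosine-bmu-md",
"metadata": {},
"source": [
"For example, a cosine-similarity BMU search might look like the sketch below (a suggestion, not a worked answer — `cosine_bmu` is just a name we're making up here, and note the `argmax`, since a larger cosine similarity means _more_ similar):"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "cosine-bmu",
"metadata": {},
"outputs": [],
"source": [
"def cosine_bmu(ex, som):\n",
"    # cosine similarity of the example with every row of the map\n",
"    sims = jnp.divide(jnp.dot(som, ex),\n",
"                      jnp.linalg.norm(som, axis=1) * jnp.linalg.norm(ex))\n",
"    return jnp.argmax(sims)\n",
"\n",
"cosine_bmu(jnp.asarray(example_vec), jnp.asarray(example_som))"
]
},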
{
"cell_type": "markdown",
"id": "663f61a6",
"metadata": {},
"source": [
"✅ If you implemented cosine similarity, what change did you notice to the performance of batch training? What changes could you make to `train_som_batch` to improve performance?"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.10"
}
},
"nbformat": 4,
"nbformat_minor": 5
} |