Last active
July 13, 2020 06:01
-
-
Save shoyer/dc33a5850337b6a87d48ed97b4727d29 to your computer and use it in GitHub Desktop.
Simple JAX GMRES vectorized
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"nbformat": 4, | |
"nbformat_minor": 0, | |
"metadata": { | |
"colab": { | |
"name": "Simple JAX GMRES vectorized", | |
"provenance": [], | |
"collapsed_sections": [], | |
"include_colab_link": true | |
}, | |
"kernelspec": { | |
"name": "python3", | |
"display_name": "Python 3" | |
} | |
}, | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "view-in-github", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"<a href=\"https://colab.research.google.com/gist/shoyer/dc33a5850337b6a87d48ed97b4727d29/simple-jax-gmres-vectorized.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "q4LEr98cuYQf", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"# Simple JAX GMRES\n", | |
"\n", | |
"Author: [email protected]\n", | |
"\n", | |
"Date: July 11, 2020\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "SR3HPqI2q8Th", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"# Copyright 2020 Google LLC.\n", | |
"# SPDX-License-Identifier: Apache-2.0\n", | |
"import numpy as np\n", | |
"import functools\n", | |
"from jax import random\n", | |
"from jax import lax\n", | |
"import jax.numpy as jnp\n", | |
"import jax.ops\n", | |
"import jax.scipy as jsp\n", | |
"from jax.tree_util import Partial\n", | |
"import scipy.sparse.linalg\n", | |
"from jax.experimental import loops\n", | |
"\n", | |
"def _identity(x):\n", | |
"  \"\"\"Return the argument unchanged; serves as the default (no-op) preconditioner.\"\"\"\n", | |
"  return x\n", | |
"\n", | |
"# jnp.dot with the matmul precision pinned to the highest available setting.\n", | |
"_dot = functools.partial(jnp.dot, precision=lax.Precision.HIGHEST)\n", | |
"\n", | |
"def _iterative_classical_gram_schmidt(Q, x, iterations=2):\n", | |
"  \"\"\"Remove from x its components along the columns of Q.\n", | |
"\n", | |
"  Classical Gram-Schmidt is applied `iterations` times in a row.\n", | |
"\n", | |
"  Args:\n", | |
"    Q: matrix whose columns x is orthogonalized against.\n", | |
"    x: vector to orthogonalize.\n", | |
"    iterations: number of projection passes.\n", | |
"\n", | |
"  Returns:\n", | |
"    A pair (residual, coeffs): x after the projections were subtracted, and\n", | |
"    the projection coefficients accumulated over all passes (the scalar 0 if\n", | |
"    no pass is run).\n", | |
"  \"\"\"\n", | |
"  # \"twice is enough\"\n", | |
"  # http://slepc.upv.es/documentation/reports/str1.pdf\n", | |
"  Q_adjoint = Q.T.conj()\n", | |
"  residual = x\n", | |
"  coeffs = 0\n", | |
"  for _ in range(iterations):\n", | |
"    overlap = _dot(Q_adjoint, residual)\n", | |
"    residual = residual - _dot(Q, overlap)\n", | |
"    coeffs = coeffs + overlap\n", | |
"  return residual, coeffs\n", | |
"\n", | |
"def arnoldi_iteration(A, b, n, M=None):\n", | |
"  \"\"\"Run n Arnoldi steps for the composed operator A(M(x)), starting from b.\n", | |
"\n", | |
"  Args:\n", | |
"    A: callable applying the linear operator to a vector.\n", | |
"    b: starting vector of shape (m,); it is normalized to form the first\n", | |
"      basis vector.\n", | |
"    n: number of Arnoldi steps.\n", | |
"    M: optional callable applied to each basis vector before A; defaults to\n", | |
"      the identity.\n", | |
"\n", | |
"  Returns:\n", | |
"    A pair (Q, H). Q has shape (m, n + 1) and holds the basis vectors as\n", | |
"    columns. H has shape (n, n + 1); row k holds the coefficients produced at\n", | |
"    step k, so H.T is the (n + 1, n) upper Hessenberg matrix.\n", | |
"  \"\"\"\n", | |
"  # https://en.wikipedia.org/wiki/Arnoldi_iteration#The_Arnoldi_iteration\n", | |
"  if M is None:\n", | |
"    M = _identity\n", | |
"  m = b.shape[0]\n", | |
"  q = b / jnp.linalg.norm(b)\n", | |
"  # Columns that are not filled in yet stay zero, so projecting against the\n", | |
"  # full Q inside the loop only picks up the vectors computed so far.\n", | |
"  Q = jnp.concatenate([q[:, jnp.newaxis], jnp.zeros((m, n))], axis=1)\n", | |
"  # Give H the same dtype as Q so that coefficients for a complex b are not\n", | |
"  # cast to the default real float when they are written below.\n", | |
"  H = jnp.zeros((n, n + 1), dtype=Q.dtype)\n", | |
"\n", | |
"  def f(carry, k):\n", | |
"    Q, H = carry\n", | |
"    q = Q[:, k]\n", | |
"    v = A(M(q))\n", | |
"    v, h = _iterative_classical_gram_schmidt(Q, v, iterations=1)\n", | |
"    v_norm = jnp.linalg.norm(v)\n", | |
"    # NOTE: a zero v_norm (exact breakdown) is not handled and yields NaNs.\n", | |
"    Q = Q.at[:, k + 1].set(v / v_norm)\n", | |
"    h = h.at[k + 1].set(v_norm)\n", | |
"    H = H.at[k, :].set(h)\n", | |
"    return (Q, H), None\n", | |
"\n", | |
"  (Q, H), _ = lax.scan(f, (Q, H), jnp.arange(n))\n", | |
"  return Q, H\n", | |
"\n", | |
"@jax.jit\n", | |
"def lstsq(a, b):\n", | |
"  \"\"\"Solve min_x ||a x - b|| through the normal equations.\n", | |
"\n", | |
"  Args:\n", | |
"    a: coefficient matrix, assumed to have full column rank so that the Gram\n", | |
"      matrix is positive definite.\n", | |
"    b: right-hand side.\n", | |
"\n", | |
"  Returns:\n", | |
"    The least-squares solution x.\n", | |
"  \"\"\"\n", | |
"  # slightly faster than jnp.linalg.lstsq\n", | |
"  # Use the conjugate transpose so the normal equations are also right for\n", | |
"  # complex input; for real input this equals a.T.\n", | |
"  a_adjoint = a.conj().T\n", | |
"  gram = _dot(a_adjoint, a)\n", | |
"  rhs = _dot(a_adjoint, b)\n", | |
"  # Solve with an explicit Cholesky factorization instead of the `sym_pos`\n", | |
"  # keyword of `solve`, which newer releases replaced with `assume_a='pos'`.\n", | |
"  return jsp.linalg.cho_solve(jsp.linalg.cho_factor(gram), rhs)\n", | |
"\n", | |
"def _gmres(A, b, x0, n, M):\n", | |
"  \"\"\"Run one cycle of n GMRES steps for A(x) = b, starting from x0.\n", | |
"\n", | |
"  Args:\n", | |
"    A: callable applying the linear operator to a vector.\n", | |
"    b: right-hand side vector.\n", | |
"    x0: initial guess.\n", | |
"    n: number of Arnoldi steps (size of the Krylov subspace).\n", | |
"    M: callable preconditioner, applied before A during the Arnoldi steps and\n", | |
"      to the final Krylov-space correction.\n", | |
"\n", | |
"  Returns:\n", | |
"    The updated approximate solution.\n", | |
"  \"\"\"\n", | |
"  # https://www-users.cs.umn.edu/~saad/Calais/PREC.pdf\n", | |
"  # TODO: exit based on achieving some error tolerance\n", | |
"  # The Krylov basis must start from the initial residual, not from b; the\n", | |
"  # two only coincide when x0 is zero.\n", | |
"  # NOTE: a zero residual (x0 already exact) is not handled and yields NaNs.\n", | |
"  r0 = b - A(x0)\n", | |
"  beta = jnp.linalg.norm(r0)\n", | |
"  Q, H = arnoldi_iteration(A, r0, n, M)\n", | |
"  e1 = jnp.concatenate([jnp.ones((1,)), jnp.zeros((n,))])\n", | |
"  # H is stored transposed, so H.T is the (n + 1, n) Hessenberg matrix.\n", | |
"  y = lstsq(H.T, beta * e1)\n", | |
"  x = x0 + M(_dot(Q[:, :-1], y))\n", | |
"  return x\n", | |
"\n", | |
"def gmres(A, b, x0=None, n=5, M=None):\n", | |
"  \"\"\"Approximately solve A(x) = b with a single cycle of n GMRES steps.\n", | |
"\n", | |
"  Args:\n", | |
"    A: callable applying the linear operator to a vector.\n", | |
"    b: right-hand side vector.\n", | |
"    x0: optional initial guess; defaults to zeros shaped like b.\n", | |
"    n: number of GMRES steps.\n", | |
"    M: optional callable preconditioner; defaults to the identity.\n", | |
"\n", | |
"  Returns:\n", | |
"    The approximate solution.\n", | |
"  \"\"\"\n", | |
"  initial_guess = jnp.zeros_like(b) if x0 is None else x0\n", | |
"  preconditioner = _identity if M is None else M\n", | |
"  return _gmres(A, b, initial_guess, n, preconditioner)" | |
], | |
"execution_count": 5, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "j5REsLhwxfsc", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"## Tests" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "SVgN9XUExrtj", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"Verify correctness:" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "pcIJkLIKX7cK", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"A = random.normal(random.PRNGKey(0), (100, 100))\n", | |
"b = random.normal(random.PRNGKey(1), (100,))" | |
], | |
"execution_count": 6, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "dtJZvMbtq9nH", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"np.testing.assert_allclose(\n", | |
" gmres(functools.partial(jnp.dot, A), b, n=20),\n", | |
" scipy.sparse.linalg.gmres(np.array(A), np.array(b), restart=20, maxiter=1)[0],\n", | |
" atol=1e-6,\n", | |
")" | |
], | |
"execution_count": 7, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "IL39G9FZueDG", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"Verify we can calculate gradients through a fixed number of loops.\n", | |
"\n", | |
"(Note that if you're running GMRES to convergence, there's a better way to calculate gradients via the [adjoint rule](https://dolfin-adjoint-doc.readthedocs.io/en/latest/documentation/maths/3-gradients.html#the-adjoint-approach).)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "5YbP5hrVtjYZ", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"@jax.grad\n", | |
"def loss(A, b):\n", | |
"  \"\"\"Sum of the GMRES solution entries; `jax.grad` differentiates it w.r.t. A.\"\"\"\n", | |
"  matvec = functools.partial(jnp.dot, A)\n", | |
"  solution = gmres(matvec, b)\n", | |
"  return jnp.sum(solution)" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "BiDj1N3BuJo4", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 238 | |
}, | |
"outputId": "5df39242-89d6-489b-8fe8-850344adf928" | |
}, | |
"source": [ | |
"loss(A, b)" | |
], | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"DeviceArray([[-0.00888865, -0.0110899 , -0.01395201, ..., -0.01434979,\n", | |
" -0.00233699, 0.0087676 ],\n", | |
" [ 0.00685218, 0.00968965, 0.00116033, ..., -0.0108919 ,\n", | |
" -0.00220355, 0.01377206],\n", | |
" [-0.00557139, -0.00477797, -0.01392098, ..., -0.01569233,\n", | |
" -0.00254976, 0.01301789],\n", | |
" ...,\n", | |
" [-0.00446863, -0.00590283, -0.00807492, ..., -0.01217444,\n", | |
" -0.00532266, 0.0111393 ],\n", | |
" [ 0.00431959, 0.00333032, 0.0005375 , ..., 0.00552948,\n", | |
" 0.00076819, -0.0026694 ],\n", | |
" [ 0.00614914, 0.00756271, 0.0005134 , ..., -0.00826863,\n", | |
" -0.00276195, 0.01154379]], dtype=float32)" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 9 | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "2fzwBVr9xhJ_", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"## Performance" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "9XGa2EY3u73b", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"Despite our naive implementation, out-of-the-box performance beats SciPy by about 3x:" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "4u1SmaDKFbHz", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"@functools.partial(jax.jit, static_argnums=(2,))\n", | |
"def explicit_gmres(A, b, n):\n", | |
"  \"\"\"Jitted GMRES for an explicit matrix A; n is static, so each n recompiles.\"\"\"\n", | |
"  matvec = functools.partial(jnp.dot, A)\n", | |
"  return gmres(matvec, b, n=n)" | |
], | |
"execution_count": 8, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "lOoQpMb4tWTD", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 34 | |
}, | |
"outputId": "300703d3-272b-4972-8056-39c7cca7e7f2" | |
}, | |
"source": [ | |
"# scipy CPU\n", | |
"b2 = np.asarray(b)\n", | |
"A2 = np.asarray(A)\n", | |
"%timeit scipy.sparse.linalg.gmres(A2, b2, restart=30, maxiter=1)" | |
], | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"1000 loops, best of 3: 1.49 ms per loop\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "JM4JrERzwfrk", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 51 | |
}, | |
"outputId": "990513c4-bc4d-4255-e81f-d0aa6c45bdd4" | |
}, | |
"source": [ | |
"# CPU\n", | |
"%timeit explicit_gmres(A, b, 30).block_until_ready()" | |
], | |
"execution_count": 13, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"The slowest run took 6.77 times longer than the fastest. This could mean that an intermediate result is being cached.\n", | |
"1000 loops, best of 3: 499 µs per loop\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "Z-3mUTzazOjg", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"GPU is a bit slower (for this matrix size), because there's not enough compute happening inside each loop iteration:" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "XEuNmYR4r5-H", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 34 | |
}, | |
"outputId": "0ddeb0b5-a30f-4705-d08b-f885b97d081e" | |
}, | |
"source": [ | |
"# GPU\n", | |
"%timeit explicit_gmres(A, b, 30).block_until_ready()" | |
], | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"100 loops, best of 3: 2.66 ms per loop\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "GbyGUXO9xeps", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"We can also `vmap` it! This gives us a big speed-up on GPUs:" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "-rYZBcm9uM3-", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"A_stack = random.normal(random.PRNGKey(0), (1000, 100, 100))" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "EUZPgiciHABY", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"stacked_explicit_gmres = jax.jit(jax.vmap(explicit_gmres, in_axes=(0, None, None)), static_argnums=(2,))" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "JA1teSvJwbIA", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 34 | |
}, | |
"outputId": "d28dcb9c-fd39-4eb1-d2fb-dc1d0a7f1478" | |
}, | |
"source": [ | |
"# CPU\n", | |
"%timeit stacked_explicit_gmres(A_stack, b, 30).block_until_ready()" | |
], | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"1 loop, best of 3: 416 ms per loop\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "DIQxwSQj1R9I", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 34 | |
}, | |
"outputId": "7e734a9e-361a-4acb-f19b-d8a4c332c4b8" | |
}, | |
"source": [ | |
"# GPU\n", | |
"%timeit stacked_explicit_gmres(A_stack, b, 30).block_until_ready()" | |
], | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"10 loops, best of 3: 24.5 ms per loop\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "9Hvg7mMzIaTY", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"" | |
], | |
"execution_count": null, | |
"outputs": [] | |
} | |
] | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment