{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "Simple JAX GMRES vectorized",
"provenance": [],
"collapsed_sections": [],
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/shoyer/dc33a5850337b6a87d48ed97b4727d29/simple-jax-gmres-vectorized.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "q4LEr98cuYQf",
"colab_type": "text"
},
"source": [
"# Simple JAX GMRES\n",
"\n",
"Author: [email protected]\n",
"\n",
"Date: July 11, 2020\n"
]
},
{
"cell_type": "code",
"metadata": {
"id": "SR3HPqI2q8Th",
"colab_type": "code",
"colab": {}
},
"source": [
"# Copyright 2020 Google LLC.\n",
"# SPDX-License-Identifier: Apache-2.0\n",
"import numpy as np\n",
"import functools\n",
"from jax import random\n",
"from jax import lax\n",
"import jax.numpy as jnp\n",
"import jax.ops\n",
"import jax.scipy as jsp\n",
"from jax.tree_util import Partial\n",
"import scipy.sparse.linalg\n",
"from jax.experimental import loops\n",
"\n",
"def _identity(x):\n",
" return x\n",
"\n",
"_dot = functools.partial(jnp.dot, precision=lax.Precision.HIGHEST)\n",
"\n",
"def _iterative_classical_gram_schmidt(Q, x, iterations=2):\n",
" \"\"\"Orthogonalize x against the columns of Q.\"\"\"\n",
" # \"twice is enough\"\n",
" # http://slepc.upv.es/documentation/reports/str1.pdf\n",
" q = x\n",
" r = 0\n",
" for _ in range(iterations):\n",
" h = _dot(Q.T.conj(), q)\n",
" q = q - _dot(Q, h)\n",
" r = r + h\n",
" return q, r\n",
"\n",
"def arnoldi_iteration(A, b, n, M=None):\n",
" # https://en.wikipedia.org/wiki/Arnoldi_iteration#The_Arnoldi_iteration\n",
" if M is None:\n",
" M = _identity\n",
" m = b.shape[0]\n",
" q = b / jnp.linalg.norm(b)\n",
" Q = jnp.concatenate([q[:, jnp.newaxis], jnp.zeros((m, n))], axis=1)\n",
" H = jnp.zeros((n, n+1))\n",
" def f(carry, k):\n",
" Q, H = carry\n",
" q = Q[:, k]\n",
" v = A(M(q))\n",
" v, h = _iterative_classical_gram_schmidt(Q, v, iterations=1)\n",
" v_norm = jnp.linalg.norm(v)\n",
" Q = Q.at[:, k+1].set(v / v_norm)\n",
" h = h.at[k+1].set(v_norm)\n",
" H = H.at[k, :].set(h) \n",
" return (Q, H), None\n",
" (Q, H), _ = lax.scan(f, (Q, H), jnp.arange(n))\n",
" return Q, H\n",
"\n",
"@jax.jit\n",
"def lstsq(a, b):\n",
" # slightly faster than jnp.linalg.lstsq\n",
" return jsp.linalg.solve(_dot(a.T, a), _dot(a.T, b), sym_pos=True)\n",
"\n",
"def _gmres(A, b, x0, n, M):\n",
" # https://www-users.cs.umn.edu/~saad/Calais/PREC.pdf\n",
" # TODO: exit based on acheiving some error tolerance\n",
" Q, H = arnoldi_iteration(A, b, n, M)\n",
" beta = jnp.linalg.norm(b - A(x0))\n",
" e1 = jnp.concatenate([jnp.ones((1,)), jnp.zeros((n,))])\n",
" y = lstsq(H.T, beta * e1)\n",
" x = x0 + M(_dot(Q[:, :-1], y))\n",
" return x\n",
"\n",
"def gmres(A, b, x0=None, n=5, M=None):\n",
" if x0 is None:\n",
" x0 = jnp.zeros_like(b)\n",
" if M is None:\n",
" M = _identity\n",
" return _gmres(A, b, x0, n, M)"
],
"execution_count": 5,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "j5REsLhwxfsc",
"colab_type": "text"
},
"source": [
"## Tests"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "SVgN9XUExrtj",
"colab_type": "text"
},
"source": [
"Verify correctness:"
]
},
{
"cell_type": "code",
"metadata": {
"id": "pcIJkLIKX7cK",
"colab_type": "code",
"colab": {}
},
"source": [
"A = random.normal(random.PRNGKey(0), (100, 100))\n",
"b = random.normal(random.PRNGKey(1), (100,))"
],
"execution_count": 6,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "dtJZvMbtq9nH",
"colab_type": "code",
"colab": {}
},
"source": [
"np.testing.assert_allclose(\n",
" gmres(functools.partial(jnp.dot, A), b, n=20),\n",
" scipy.sparse.linalg.gmres(np.array(A), np.array(b), restart=20, maxiter=1)[0],\n",
" atol=1e-6,\n",
")"
],
"execution_count": 7,
"outputs": []
},
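{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `gmres` wrapper above also accepts a right preconditioner `M` (applied as `A(M(q))` inside `arnoldi_iteration`). As a rough illustration only (this cell is not part of the original notebook; `diag_A`, `jacobi_M`, and `x_prec` are made-up names), here is how a simple Jacobi (diagonal) preconditioner could be passed in. For this dense random matrix it isn't expected to help much:"
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# Illustrative sketch (not in the original notebook): right-preconditioned GMRES\n",
"# with a Jacobi (diagonal) preconditioner, M(x) = x / diag(A).\n",
"diag_A = jnp.diag(A)\n",
"jacobi_M = lambda x: x / diag_A\n",
"x_prec = gmres(functools.partial(jnp.dot, A), b, n=20, M=jacobi_M)"
],
"execution_count": null,
"outputs": []
},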
{
"cell_type": "markdown",
"metadata": {
"id": "IL39G9FZueDG",
"colab_type": "text"
},
"source": [
"Verify we can calculate gradients through a fixed number of loops.\n",
"\n",
"(Note that if you're running GMRES to convergence, there's a better way to calculate gradients via the [adjoint rule](https://dolfin-adjoint-doc.readthedocs.io/en/latest/documentation/maths/3-gradients.html#the-adjoint-approach).)"
]
},
{
"cell_type": "code",
"metadata": {
"id": "5YbP5hrVtjYZ",
"colab_type": "code",
"colab": {}
},
"source": [
"@jax.grad\n",
"def loss(A, b):\n",
" return jnp.sum(gmres(functools.partial(jnp.dot, A), b))"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "BiDj1N3BuJo4",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 238
},
"outputId": "5df39242-89d6-489b-8fe8-850344adf928"
},
"source": [
"loss(A, b)"
],
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"DeviceArray([[-0.00888865, -0.0110899 , -0.01395201, ..., -0.01434979,\n",
" -0.00233699, 0.0087676 ],\n",
" [ 0.00685218, 0.00968965, 0.00116033, ..., -0.0108919 ,\n",
" -0.00220355, 0.01377206],\n",
" [-0.00557139, -0.00477797, -0.01392098, ..., -0.01569233,\n",
" -0.00254976, 0.01301789],\n",
" ...,\n",
" [-0.00446863, -0.00590283, -0.00807492, ..., -0.01217444,\n",
" -0.00532266, 0.0111393 ],\n",
" [ 0.00431959, 0.00333032, 0.0005375 , ..., 0.00552948,\n",
" 0.00076819, -0.0026694 ],\n",
" [ 0.00614914, 0.00756271, 0.0005134 , ..., -0.00826863,\n",
" -0.00276195, 0.01154379]], dtype=float32)"
]
},
"metadata": {
"tags": []
},
"execution_count": 9
}
]
},
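{
"cell_type": "markdown",
"metadata": {},
"source": [
"For completeness, here is a minimal sketch of that adjoint approach (not part of the original notebook). It uses `jax.custom_vjp` with a dense `jnp.linalg.solve` standing in for GMRES run to convergence, and the `linear_solve*` names are illustrative only. For `x = solve(A, b)`, the cotangent `v` pulls back to `lam = solve(A.T, v)` for `b` and to `-outer(lam, x)` for `A`:"
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# Illustrative sketch (not in the original notebook): the adjoint rule for a\n",
"# converged linear solve x = A^{-1} b, written with jax.custom_vjp.\n",
"# jnp.linalg.solve stands in for GMRES run to convergence.\n",
"@jax.custom_vjp\n",
"def linear_solve(A, b):\n",
"  return jnp.linalg.solve(A, b)\n",
"\n",
"def _linear_solve_fwd(A, b):\n",
"  x = linear_solve(A, b)\n",
"  return x, (A, x)\n",
"\n",
"def _linear_solve_bwd(res, v):\n",
"  A, x = res\n",
"  lam = jnp.linalg.solve(A.T, v)    # adjoint solve with A^T\n",
"  return -jnp.outer(lam, x), lam    # cotangents w.r.t. (A, b)\n",
"\n",
"linear_solve.defvjp(_linear_solve_fwd, _linear_solve_bwd)\n",
"\n",
"adjoint_grad = jax.grad(lambda A, b: jnp.sum(linear_solve(A, b)))(A, b)"
],
"execution_count": null,
"outputs": []
},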
{
"cell_type": "markdown",
"metadata": {
"id": "2fzwBVr9xhJ_",
"colab_type": "text"
},
"source": [
"## Performance"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "9XGa2EY3u73b",
"colab_type": "text"
},
"source": [
"Despite our naive implementation, out of the box performance beats SciPy by about 3x:"
]
},
{
"cell_type": "code",
"metadata": {
"id": "4u1SmaDKFbHz",
"colab_type": "code",
"colab": {}
},
"source": [
"@functools.partial(jax.jit, static_argnums=(2,))\n",
"def explicit_gmres(A, b, n):\n",
" return gmres(functools.partial(jnp.dot, A), b, n=n)"
],
"execution_count": 8,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "lOoQpMb4tWTD",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
},
"outputId": "300703d3-272b-4972-8056-39c7cca7e7f2"
},
"source": [
"# scipy CPU\n",
"b2 = np.asarray(b)\n",
"A2 = np.asarray(A)\n",
"%timeit scipy.sparse.linalg.gmres(A2, b2, restart=30, maxiter=1)"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"text": [
"1000 loops, best of 3: 1.49 ms per loop\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "JM4JrERzwfrk",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 51
},
"outputId": "990513c4-bc4d-4255-e81f-d0aa6c45bdd4"
},
"source": [
"# CPU\n",
"%timeit explicit_gmres(A, b, 30).block_until_ready()"
],
"execution_count": 13,
"outputs": [
{
"output_type": "stream",
"text": [
"The slowest run took 6.77 times longer than the fastest. This could mean that an intermediate result is being cached.\n",
"1000 loops, best of 3: 499 µs per loop\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Z-3mUTzazOjg",
"colab_type": "text"
},
"source": [
"GPU is bit slower (for this matrix size), because there's not enough compute happening inside each loop iteration:"
]
},
{
"cell_type": "code",
"metadata": {
"id": "XEuNmYR4r5-H",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
},
"outputId": "0ddeb0b5-a30f-4705-d08b-f885b97d081e"
},
"source": [
"# GPU\n",
"%timeit explicit_gmres(A, b, 30).block_until_ready()"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"text": [
"100 loops, best of 3: 2.66 ms per loop\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "GbyGUXO9xeps",
"colab_type": "text"
},
"source": [
"We can also `vmap` it! This gives us a big speed-up on GPUs:"
]
},
{
"cell_type": "code",
"metadata": {
"id": "-rYZBcm9uM3-",
"colab_type": "code",
"colab": {}
},
"source": [
"A_stack = random.normal(random.PRNGKey(0), (1000, 100, 100))"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "EUZPgiciHABY",
"colab_type": "code",
"colab": {}
},
"source": [
"stacked_explicit_gmres = jax.jit(jax.vmap(explicit_gmres, in_axes=(0, None, None)), static_argnums=(2,))"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "JA1teSvJwbIA",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
},
"outputId": "d28dcb9c-fd39-4eb1-d2fb-dc1d0a7f1478"
},
"source": [
"# CPU\n",
"%timeit stacked_explicit_gmres(A_stack, b, 30).block_until_ready()"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"text": [
"1 loop, best of 3: 416 ms per loop\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "DIQxwSQj1R9I",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
},
"outputId": "7e734a9e-361a-4acb-f19b-d8a4c332c4b8"
},
"source": [
"# GPU\n",
"%timeit stacked_explicit_gmres(A_stack, b, 30).block_until_ready()"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"text": [
"10 loops, best of 3: 24.5 ms per loop\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "9Hvg7mMzIaTY",
"colab_type": "code",
"colab": {}
},
"source": [
""
],
"execution_count": null,
"outputs": []
}
]
}