Skip to content

Instantly share code, notes, and snippets.

@shoyer
Created December 1, 2020 21:03
Show Gist options
  • Save shoyer/146902a92016d3d936ad0b224910a967 to your computer and use it in GitHub Desktop.
Save shoyer/146902a92016d3d936ad0b224910a967 to your computer and use it in GitHub Desktop.
new gmres benchmark.ipynb
Display the source blob
Display the rendered blob
Raw
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "new gmres benchmark.ipynb",
"provenance": [],
"collapsed_sections": [],
"authorship_tag": "ABX9TyMlCNHS4jQHhr3l8ybYddtv",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/shoyer/146902a92016d3d936ad0b224910a967/new-gmres-benchmark.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "f9fJ9EInZY0v",
"outputId": "b550cd68-35d6-4312-a122-22910bb080e3"
},
"source": [
"! pip install -U git+https://github.com/shoyer/jax.git@gmres-cleanup"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"text": [
"Collecting git+https://github.com/shoyer/jax.git@gmres-cleanup\n",
" Cloning https://github.com/shoyer/jax.git (to revision gmres-cleanup) to /tmp/pip-req-build-3ym4dw3w\n",
" Running command git clone -q https://github.com/shoyer/jax.git /tmp/pip-req-build-3ym4dw3w\n",
" Running command git checkout -b gmres-cleanup --track origin/gmres-cleanup\n",
" Switched to a new branch 'gmres-cleanup'\n",
" Branch 'gmres-cleanup' set up to track remote branch 'gmres-cleanup' from 'origin'.\n",
"Requirement already satisfied, skipping upgrade: numpy>=1.12 in /usr/local/lib/python3.6/dist-packages (from jax==0.2.6) (1.18.5)\n",
"Requirement already satisfied, skipping upgrade: absl-py in /usr/local/lib/python3.6/dist-packages (from jax==0.2.6) (0.10.0)\n",
"Requirement already satisfied, skipping upgrade: opt_einsum in /usr/local/lib/python3.6/dist-packages (from jax==0.2.6) (3.3.0)\n",
"Requirement already satisfied, skipping upgrade: six in /usr/local/lib/python3.6/dist-packages (from absl-py->jax==0.2.6) (1.15.0)\n",
"Building wheels for collected packages: jax\n",
" Building wheel for jax (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
" Created wheel for jax: filename=jax-0.2.6-cp36-none-any.whl size=606591 sha256=d4a0a4550f5d613321df9534315c72dc39ae5a552163d3be201603eca460b5a7\n",
" Stored in directory: /tmp/pip-ephem-wheel-cache-cuqvuuby/wheels/99/39/0d/df246aefe5c610292921f884fdf7709e8bfb9b118f22da8c85\n",
"Successfully built jax\n",
"Installing collected packages: jax\n",
" Found existing installation: jax 0.2.6\n",
" Uninstalling jax-0.2.6:\n",
" Successfully uninstalled jax-0.2.6\n",
"Successfully installed jax-0.2.6\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "EuqchR8uZPZu"
},
"source": [
"import jax\n",
"import jax.config\n",
"import scipy as sp\n",
"import jax.numpy as jnp\n",
"import scipy.sparse.linalg\n",
"import numpy as np\n",
"\n",
"jax.config.update(\"jax_enable_x64\", True)\n",
"\n",
"import matplotlib.pyplot as plt\n",
"import time\n",
"import functools\n"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "kF44TAgMueLJ",
"outputId": "5dd0cfd8-b767-415c-ef97-7696a5cfe14c"
},
"source": [
"def gmres_incremental(A, b):\n",
"  \"\"\"Solve ``A @ x = b`` with JAX GMRES, leaving ``solve_method`` at its default.\n",
"\n",
"  ``A`` is applied through ``jnp.dot``. The Krylov restart length is read from\n",
"  the module-level ``restart`` variable (set by the benchmark loop), not\n",
"  passed in. Both tolerances are zero and ``maxiter=1``.\n",
"\n",
"  Returns whatever ``jax.scipy.sparse.linalg.gmres`` returns; the benchmark\n",
"  uses element ``[0]`` of it.\n",
"  \"\"\"\n",
"  matvec = functools.partial(jnp.dot, A)\n",
"  solve = jax.scipy.sparse.linalg.gmres\n",
"  return solve(matvec, b, tol=0, atol=0, restart=restart, maxiter=1)\n",
"\n",
"def gmres_direct(A, b):\n",
"  \"\"\"Solve ``A @ x = b`` with JAX GMRES using ``solve_method='direct'``.\n",
"\n",
"  ``A`` is applied through ``jnp.dot``. The Krylov restart length is read from\n",
"  the module-level ``restart`` variable (set by the benchmark loop), not\n",
"  passed in. Both tolerances are zero and ``maxiter=1``.\n",
"\n",
"  Returns whatever ``jax.scipy.sparse.linalg.gmres`` returns; the benchmark\n",
"  uses element ``[0]`` of it.\n",
"  \"\"\"\n",
"  matvec = functools.partial(jnp.dot, A)\n",
"  solve = jax.scipy.sparse.linalg.gmres\n",
"  return solve(\n",
"      matvec, b, tol=0, atol=0, restart=restart, maxiter=1,\n",
"      solve_method='direct')\n",
"\n",
"def gmres_scipy(A, b):\n",
"  \"\"\"Solve ``A @ x = b`` with SciPy's GMRES as the CPU reference.\n",
"\n",
"  Uses the ``A`` and ``b`` passed in (previously the body ignored them and\n",
"  read the globals ``A_`` / ``b_``). The restart length is still read from\n",
"  the module-level ``restart`` variable set by the benchmark loop. Both\n",
"  tolerances are zero and ``maxiter=1``, matching the JAX variants.\n",
"\n",
"  Returns the result of ``scipy.sparse.linalg.gmres`` unchanged.\n",
"  \"\"\"\n",
"  return scipy.sparse.linalg.gmres(A, b, restart=restart, maxiter=1, atol=0, tol=0)\n",
"\n",
"# Benchmark cases are (N, restart) pairs: A is N x N and b has length N.\n",
"# `restart` must keep this name: the gmres_* helpers read it as a global.\n",
"# NOTE(review): this cell's saved output reports restart=100 for N=2000, so it\n",
"# does not match the (2000, 200) entry below; re-run to refresh the numbers.\n",
"for N, restart in [\n",
"  (20, 10),\n",
"  (200, 50),\n",
"  (2000, 200),\n",
"]:\n",
"  print(f\"\\nN={N}, restart={restart}\")\n",
"  A = jax.random.normal(jax.random.PRNGKey(0), (N, N))\n",
"  b = jax.random.normal(jax.random.PRNGKey(1), (N,))\n",
"\n",
"  # Convert to NumPy once, outside the timed statement, for the SciPy run.\n",
"  A_, b_ = np.asarray(A), np.asarray(b)\n",
"  print(\"SciPy CPU:\")\n",
"  %timeit gmres_scipy(A_, b_)\n",
"\n",
"  # One row per JAX measurement: printed label, solver function, jit backend.\n",
"  for label, solver, backend in [\n",
"    ('JAX incremental CPU:', gmres_incremental, 'cpu'),\n",
"    ('JAX direct CPU:', gmres_direct, 'cpu'),\n",
"    ('JAX incremental GPU:', gmres_incremental, 'gpu'),\n",
"    ('JAX direct GPU:', gmres_direct, 'gpu'),\n",
"  ]:\n",
"    print(label)\n",
"    gmres_ = jax.jit(solver, backend=backend)\n",
"    # Untimed first call, so one-time jit compilation stays out of %timeit.\n",
"    gmres_(A, b)[0].block_until_ready()\n",
"    # block_until_ready() makes the timing wait for the computed result.\n",
"    # NOTE(review): A and b are created without an explicit device, so the\n",
"    # timings may include a cross-device copy -- confirm before comparing.\n",
"    %timeit gmres_(A, b)[0].block_until_ready()\n"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"text": [
"\n",
"N=20, restart=10\n",
"SciPy CPU:\n",
"1000 loops, best of 3: 355 µs per loop\n",
"JAX incremental CPU:\n",
"10000 loops, best of 3: 108 µs per loop\n",
"JAX direct CPU:\n",
"10000 loops, best of 3: 112 µs per loop\n",
"JAX incremental GPU:\n",
"100 loops, best of 3: 4.83 ms per loop\n",
"JAX direct GPU:\n",
"100 loops, best of 3: 1.92 ms per loop\n",
"\n",
"N=200, restart=50\n",
"SciPy CPU:\n",
"100 loops, best of 3: 2.63 ms per loop\n",
"JAX incremental CPU:\n",
"1000 loops, best of 3: 762 µs per loop\n",
"JAX direct CPU:\n",
"1000 loops, best of 3: 1.05 ms per loop\n",
"JAX incremental GPU:\n",
"10 loops, best of 3: 67.6 ms per loop\n",
"JAX direct GPU:\n",
"100 loops, best of 3: 8.77 ms per loop\n",
"\n",
"N=2000, restart=100\n",
"SciPy CPU:\n",
"10 loops, best of 3: 163 ms per loop\n",
"JAX incremental CPU:\n",
"10 loops, best of 3: 174 ms per loop\n",
"JAX direct CPU:\n",
"10 loops, best of 3: 176 ms per loop\n",
"JAX incremental GPU:\n",
"1 loop, best of 3: 252 ms per loop\n",
"JAX direct GPU:\n",
"10 loops, best of 3: 38.5 ms per loop\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "TDz03Lx3foQM"
},
"source": [
""
],
"execution_count": null,
"outputs": []
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment