Created
December 1, 2020 21:03
-
-
Save shoyer/146902a92016d3d936ad0b224910a967 to your computer and use it in GitHub Desktop.
new gmres benchmark.ipynb
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"nbformat": 4, | |
"nbformat_minor": 0, | |
"metadata": { | |
"colab": { | |
"name": "new gmres benchmark.ipynb", | |
"provenance": [], | |
"collapsed_sections": [], | |
"authorship_tag": "ABX9TyMlCNHS4jQHhr3l8ybYddtv", | |
"include_colab_link": true | |
}, | |
"kernelspec": { | |
"name": "python3", | |
"display_name": "Python 3" | |
}, | |
"accelerator": "GPU" | |
}, | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "view-in-github", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"<a href=\"https://colab.research.google.com/gist/shoyer/146902a92016d3d936ad0b224910a967/new-gmres-benchmark.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "f9fJ9EInZY0v", | |
"outputId": "b550cd68-35d6-4312-a122-22910bb080e3" | |
}, | |
"source": [ | |
"! pip install -U git+https://github.com/shoyer/jax.git@gmres-cleanup" | |
], | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"Collecting git+https://github.com/shoyer/jax.git@gmres-cleanup\n", | |
" Cloning https://github.com/shoyer/jax.git (to revision gmres-cleanup) to /tmp/pip-req-build-3ym4dw3w\n", | |
" Running command git clone -q https://github.com/shoyer/jax.git /tmp/pip-req-build-3ym4dw3w\n", | |
" Running command git checkout -b gmres-cleanup --track origin/gmres-cleanup\n", | |
" Switched to a new branch 'gmres-cleanup'\n", | |
" Branch 'gmres-cleanup' set up to track remote branch 'gmres-cleanup' from 'origin'.\n", | |
"Requirement already satisfied, skipping upgrade: numpy>=1.12 in /usr/local/lib/python3.6/dist-packages (from jax==0.2.6) (1.18.5)\n", | |
"Requirement already satisfied, skipping upgrade: absl-py in /usr/local/lib/python3.6/dist-packages (from jax==0.2.6) (0.10.0)\n", | |
"Requirement already satisfied, skipping upgrade: opt_einsum in /usr/local/lib/python3.6/dist-packages (from jax==0.2.6) (3.3.0)\n", | |
"Requirement already satisfied, skipping upgrade: six in /usr/local/lib/python3.6/dist-packages (from absl-py->jax==0.2.6) (1.15.0)\n", | |
"Building wheels for collected packages: jax\n", | |
" Building wheel for jax (setup.py) ... \u001b[?25l\u001b[?25hdone\n", | |
" Created wheel for jax: filename=jax-0.2.6-cp36-none-any.whl size=606591 sha256=d4a0a4550f5d613321df9534315c72dc39ae5a552163d3be201603eca460b5a7\n", | |
" Stored in directory: /tmp/pip-ephem-wheel-cache-cuqvuuby/wheels/99/39/0d/df246aefe5c610292921f884fdf7709e8bfb9b118f22da8c85\n", | |
"Successfully built jax\n", | |
"Installing collected packages: jax\n", | |
" Found existing installation: jax 0.2.6\n", | |
" Uninstalling jax-0.2.6:\n", | |
" Successfully uninstalled jax-0.2.6\n", | |
"Successfully installed jax-0.2.6\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "EuqchR8uZPZu" | |
}, | |
"source": [ | |
"import jax\n", | |
"import jax.config\n", | |
"import scipy as sp\n", | |
"import jax.numpy as jnp\n", | |
"import scipy.sparse.linalg\n", | |
"import numpy as np\n", | |
"\n", | |
"jax.config.update(\"jax_enable_x64\", True)\n", | |
"\n", | |
"import matplotlib.pyplot as plt\n", | |
"import time\n", | |
"import functools\n" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "kF44TAgMueLJ", | |
"outputId": "5dd0cfd8-b767-415c-ef97-7696a5cfe14c" | |
}, | |
"source": [ | |
"def gmres_incremental(A, b):\n", | |
" \"\"\"Call JAX GMRES on ``A`` and ``b`` without passing ``solve_method`` (library default).\n", | |
"\n", | |
" ``A`` is wrapped as the linear operator ``jnp.dot(A, .)``. ``restart`` is not a\n", | |
" parameter: it is looked up from the notebook's global scope. ``maxiter`` is 1 and\n", | |
" both ``atol`` and ``tol`` are 0.\n", | |
" \"\"\"\n", | |
" matvec = functools.partial(jnp.dot, A)\n", | |
" options = dict(restart=restart, maxiter=1, atol=0, tol=0)\n", | |
" return jax.scipy.sparse.linalg.gmres(matvec, b, **options)\n", | |
"\n", | |
"def gmres_direct(A, b):\n", | |
" \"\"\"Call JAX GMRES on ``A`` and ``b`` with ``solve_method='direct'``.\n", | |
"\n", | |
" ``A`` is wrapped as the linear operator ``jnp.dot(A, .)``. ``restart`` is not a\n", | |
" parameter: it is looked up from the notebook's global scope. ``maxiter`` is 1 and\n", | |
" both ``atol`` and ``tol`` are 0.\n", | |
" \"\"\"\n", | |
" matvec = functools.partial(jnp.dot, A)\n", | |
" solve = functools.partial(jax.scipy.sparse.linalg.gmres, solve_method='direct')\n", | |
" return solve(matvec, b, restart=restart, maxiter=1, atol=0, tol=0)\n", | |
"\n", | |
"def gmres_scipy(A, b):\n", | |
" \"\"\"SciPy baseline: call ``scipy.sparse.linalg.gmres`` on the given ``A`` and ``b``.\n", | |
"\n", | |
" ``restart`` is looked up from the notebook's global scope. ``maxiter`` is 1 and\n", | |
" both ``atol`` and ``tol`` are 0, matching the JAX variants in this cell.\n", | |
" \"\"\"\n", | |
" # Use the arguments rather than the notebook globals A_/b_, so the function\n", | |
" # solves whatever system it is actually handed.\n", | |
" return scipy.sparse.linalg.gmres(A, b, restart=restart, maxiter=1, atol=0, tol=0)\n", | |
"\n", | |
"# JAX variants to time, as (label, solver, backend), in the order they are reported.\n", | |
"jax_cases = [\n", | |
" (\"JAX incremental CPU\", gmres_incremental, 'cpu'),\n", | |
" (\"JAX direct CPU\", gmres_direct, 'cpu'),\n", | |
" (\"JAX incremental GPU\", gmres_incremental, 'gpu'),\n", | |
" (\"JAX direct GPU\", gmres_direct, 'gpu'),\n", | |
"]\n", | |
"\n", | |
"# `restart` has to remain a global loop variable: the gmres_* helpers read it from\n", | |
"# the notebook's global scope instead of taking it as an argument.\n", | |
"# NOTE(review): the saved output reports restart=100 for N=2000, so it appears to\n", | |
"# predate the 200 used here -- re-run the cell to refresh the timings.\n", | |
"for N, restart in [\n", | |
" (20, 10),\n", | |
" (200, 50),\n", | |
" (2000, 200),\n", | |
"]:\n", | |
" print(f\"\\nN={N}, restart={restart}\")\n", | |
" A = jax.random.normal(jax.random.PRNGKey(0), (N, N))\n", | |
" b = jax.random.normal(jax.random.PRNGKey(1), (N,))\n", | |
"\n", | |
" # NumPy versions of the same system for the SciPy baseline.\n", | |
" A_, b_ = np.asarray(A), np.asarray(b)\n", | |
" print(\"SciPy CPU:\")\n", | |
" %timeit gmres_scipy(A_, b_)\n", | |
"\n", | |
" for label, solver, backend in jax_cases:\n", | |
"  print(f\"{label}:\")\n", | |
"  gmres_ = jax.jit(solver, backend=backend)\n", | |
"  # Call once before timing so the first-call jit compilation is not measured.\n", | |
"  gmres_(A, b)[0].block_until_ready()\n", | |
"  %timeit gmres_(A, b)[0].block_until_ready()\n" | |
], | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"\n", | |
"N=20, restart=10\n", | |
"SciPy CPU:\n", | |
"1000 loops, best of 3: 355 µs per loop\n", | |
"JAX incremental CPU:\n", | |
"10000 loops, best of 3: 108 µs per loop\n", | |
"JAX direct CPU:\n", | |
"10000 loops, best of 3: 112 µs per loop\n", | |
"JAX incremental GPU:\n", | |
"100 loops, best of 3: 4.83 ms per loop\n", | |
"JAX direct GPU:\n", | |
"100 loops, best of 3: 1.92 ms per loop\n", | |
"\n", | |
"N=200, restart=50\n", | |
"SciPy CPU:\n", | |
"100 loops, best of 3: 2.63 ms per loop\n", | |
"JAX incremental CPU:\n", | |
"1000 loops, best of 3: 762 µs per loop\n", | |
"JAX direct CPU:\n", | |
"1000 loops, best of 3: 1.05 ms per loop\n", | |
"JAX incremental GPU:\n", | |
"10 loops, best of 3: 67.6 ms per loop\n", | |
"JAX direct GPU:\n", | |
"100 loops, best of 3: 8.77 ms per loop\n", | |
"\n", | |
"N=2000, restart=100\n", | |
"SciPy CPU:\n", | |
"10 loops, best of 3: 163 ms per loop\n", | |
"JAX incremental CPU:\n", | |
"10 loops, best of 3: 174 ms per loop\n", | |
"JAX direct CPU:\n", | |
"10 loops, best of 3: 176 ms per loop\n", | |
"JAX incremental GPU:\n", | |
"1 loop, best of 3: 252 ms per loop\n", | |
"JAX direct GPU:\n", | |
"10 loops, best of 3: 38.5 ms per loop\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "TDz03Lx3foQM" | |
}, | |
"source": [ | |
"" | |
], | |
"execution_count": null, | |
"outputs": [] | |
} | |
] | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment