Created October 12, 2023 08:33 by maedoc (gist 51925edc367bc10d81b287eb1e88a0c3)
Short test of taichi on tight numerical loops
{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "f4071430-99df-4748-a39c-50d92af79400",
   "metadata": {},
   "source": [
    "# taichi vs jax vs numba tight loop comparison\n",
    "\n",
    "Taichi is another attempt at high-performance Python, and it compares in some interesting ways to existing solutions: it's closer to DrJit and Mojo than to Jax or Numba.\n",
    "\n",
    "First litmus test: does it pip install on Windows and just work like jax or numba? Nope, there are some differences to note off the bat:\n",
    "- it needs an explicit `ti.init()` call 🤷‍♂️\n",
    "- kernel arguments require type annotations 👍\n",
    "- (scalar) kernel arguments are immutable, so we need to copy them first (`a_ = a`) 👍\n",
    "- kernels need a return type annotation too"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "64875d89-36d7-46d2-af76-0d465e36c735",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[Taichi] version 1.6.0, llvm 15.0.1, commit f1c6fbbd, win, python 3.10.13\n",
      "[Taichi] Starting on arch=x64\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(4.95027390356502, 4.9502716064453125)"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import taichi as ti\n",
    "ti.init()  # defaults to the CPU backend and 32-bit floats (default_fp=ti.f32)\n",
    "\n",
    "# plain Python reference: divide a by b, n times\n",
    "def foo(n: int, a: float, b: float) -> float:\n",
    "    a_ = a  # kernel arguments are immutable, so work on a copy\n",
    "    for i in range(n):\n",
    "        a_ = a_ / b\n",
    "    return a_\n",
    "\n",
    "# compile the same function as a taichi kernel\n",
    "ti_foo = ti.kernel(foo)\n",
    "\n",
    "foo(10, 5.0, 1.001), ti_foo(10, 5.0, 1.001)"
   ]
  },
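  {
   "cell_type": "markdown",
   "id": "added-f32-check-md",
   "metadata": {},
   "source": [
    "The two results agree only to about six significant digits. A likely explanation, assumed here rather than verified, is that `ti.init()` defaults to 32-bit floats (`default_fp=ti.f32`), while plain Python computes in float64. A minimal sketch of a check with NumPy compares three values: the closed form $a / b^n$, the float64 loop `foo`, and the same loop rounded to float32. `foo_closed_form` and `foo_f32` are hypothetical helpers introduced only for this check. If the float32 loop lands close to taichi's value, passing `default_fp=ti.f64` to `ti.init()` should bring the kernel in line with Python."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "added-f32-check-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def foo_closed_form(n, a, b):\n",
    "    # dividing by b n times equals a / b**n in exact arithmetic\n",
    "    return a / b**n\n",
    "\n",
    "def foo_f32(n, a, b):\n",
    "    # same loop as foo, but every value is rounded to float32\n",
    "    a_, b_ = np.float32(a), np.float32(b)\n",
    "    for _ in range(n):\n",
    "        a_ = a_ / b_\n",
    "    return float(a_)\n",
    "\n",
    "n, a, b = 10, 5.0, 1.001\n",
    "foo_closed_form(n, a, b), foo(n, a, b), foo_f32(n, a, b)"
   ]
  },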
  {
   "cell_type": "markdown",
   "id": "ed6087e9-c4c6-429d-a4d8-ffd692e1d65c",
   "metadata": {},
   "source": [
    "This tests (a) kernel launch overhead and (b) the performance of a tight loop that needs only two registers (`a_` and `b`)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "bbc57c06-03bd-4f00-a800-4f0b9a86f01e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "294 ns ± 5.63 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)\n",
      "110 µs ± 4.78 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)\n",
      "\n",
      "578 ns ± 11.3 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)\n",
      "110 µs ± 3.19 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)\n",
      "\n",
      "3.41 µs ± 151 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)\n",
      "109 µs ± 2.73 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)\n",
      "\n",
      "39.3 µs ± 912 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)\n",
      "111 µs ± 2.87 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)\n",
      "\n",
      "422 µs ± 12.3 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)\n",
      "149 µs ± 13.6 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# for each trip count: plain Python first, then the taichi kernel\n",
    "for n in [1, 10, 100, 1000, 10_000]:\n",
    "    %timeit foo(n, 5.0, 1.001)\n",
    "    %timeit ti_foo(n, 5.0, 1.001)\n",
    "    print()"
   ]
  },
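  {
   "cell_type": "markdown",
   "id": "added-per-iter-md",
   "metadata": {},
   "source": [
    "To put rough numbers on launch overhead versus per-iteration cost, here is a small sketch that uses the mean timings copied by hand from the output above. It treats the marginal cost between 1,000 and 10,000 iterations as the per-iteration cost. This assumes the fixed overhead is the same at both trip counts, and it ignores the reported noise. `per_iter_ns` is a hypothetical helper. With these numbers the cost works out to roughly 42 ns per iteration for Python and about 4 ns for taichi."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "added-per-iter-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "# mean timings in µs, copied from the %timeit output above\n",
    "trip_counts = [1, 10, 100, 1000, 10_000]\n",
    "t_py = [0.294, 0.578, 3.41, 39.3, 422.0]\n",
    "t_ti = [110.0, 110.0, 109.0, 111.0, 149.0]\n",
    "\n",
    "def per_iter_ns(t, n_lo=1000, n_hi=10_000):\n",
    "    # marginal cost per iteration between two trip counts, in ns\n",
    "    i, j = trip_counts.index(n_lo), trip_counts.index(n_hi)\n",
    "    return 1e3 * (t[j] - t[i]) / (n_hi - n_lo)\n",
    "\n",
    "print(f'python: {per_iter_ns(t_py):.1f} ns/iter')\n",
    "print(f'taichi: {per_iter_ns(t_ti):.1f} ns/iter, plus ~{t_ti[0]:.0f} µs per launch')"
   ]
  },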
  {
   "cell_type": "markdown",
   "id": "9d3b288a-8017-47b1-aa05-ba084387cd8d",
   "metadata": {},
   "source": [
    "So there's a clear launch overhead of about 110 µs per call, but the tight loop itself seems to be JIT-compiled nicely. Going from 1,000 to 10,000 iterations adds about 38 µs for taichi, versus about 380 µs for plain Python.\n",
    "\n",
    "We can see what numba does in this case:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "d577eecc-20bf-4aca-9446-4d01b751a971",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "221 ns ± 7.39 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)\n",
      "221 ns ± 4.36 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)\n",
      "533 ns ± 25.1 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)\n",
      "4.79 µs ± 994 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)\n",
      "37.4 µs ± 1.12 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)\n"
     ]
    }
   ],
   "source": [
    "import numba as nb\n",
    "nb_foo = nb.njit(foo)\n",
    "nb_foo(1, 1.0, 1.0)  # first call triggers compilation, so it isn't timed\n",
    "for n in [1, 10, 100, 1000, 10_000]:\n",
    "    %timeit nb_foo(n, 5.0, 1.001)"
   ]
  },
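  {
   "cell_type": "markdown",
   "id": "added-numba-eager-md",
   "metadata": {},
   "source": [
    "The `nb_foo(1, 1.0, 1.0)` call above exists only to trigger compilation before timing. As an alternative, here is a sketch of eager compilation: numba compiles immediately when it is given an explicit signature. The signature string below assumes that Python ints are typed as `int64`, which is the case on 64-bit platforms. `nb_foo_eager` is just an illustrative name."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "added-numba-eager-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numba as nb\n",
    "\n",
    "# compiles immediately for this one signature instead of on the first call\n",
    "nb_foo_eager = nb.njit('float64(int64, float64, float64)')(foo)\n",
    "print(nb_foo_eager.signatures)\n",
    "nb_foo_eager(10, 5.0, 1.001)"
   ]
  },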
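  {
   "cell_type": "markdown",
   "id": "21867174-357b-468f-a20c-92cd3fca43c0",
   "metadata": {},
   "source": [
    "Numba's call overhead is far lower (about 220 ns per call, versus about 110 µs for taichi). The per-iteration cost is similar, though: going from 1,000 to 10,000 iterations, numba adds about 33 µs and taichi about 38 µs. As long as work can be fused into a few large kernels, taichi's launch overhead isn't a big deal.\n",
    "\n",
    "What about Jax? It JITs by tracing, with loops expressed through function composition (here `jax.lax.scan` over a step function), so we need a slightly different approach:"
   ]
  },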
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "9afef1d1-57f8-4746-9250-538dc4a7d966",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "9.41 µs ± 1.08 µs per loop (mean ± std. dev. of 7 runs, 100,000 loops each)\n",
      "11.3 µs ± 462 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)\n",
      "11.6 µs ± 826 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)\n",
      "11.5 µs ± 558 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)\n",
      "13.6 µs ± 1.12 µs per loop (mean ± std. dev. of 7 runs, 100,000 loops each)\n"
     ]
    }
   ],
   "source": [
    "import jax\n",
    "\n",
    "def make_jx_foo(n, b):\n",
    "    # n and b are captured by the closure, i.e. baked in at trace time\n",
    "    def jx_foo_loop(c, x):\n",
    "        return c / b, x\n",
    "    @jax.jit\n",
    "    def jx_foo_scan(a):\n",
    "        i = jax.numpy.r_[:n]\n",
    "        # carry a through n steps, unrolling 20 steps per loop iteration; [0] is the final carry\n",
    "        return jax.lax.scan(jx_foo_loop, a, i, unroll=20)[0]\n",
    "    return jx_foo_scan\n",
    "\n",
    "for n in [1, 10, 100, 1000, 10_000]:\n",
    "    # %timeit -n25 -r25 jx_foo(jax.numpy.r_[:n], 5.0)\n",
    "    jx_foo = make_jx_foo(n, 1.001)\n",
    "    jx_foo(5.0)  # warm-up call compiles for this n\n",
    "    %timeit jx_foo(5.0)"
   ]
  },
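  {
   "cell_type": "markdown",
   "id": "added-jax-timing-md",
   "metadata": {},
   "source": [
    "One caveat on timing jax: depending on backend and version, jax may dispatch work asynchronously. In that case `%timeit jx_foo(5.0)` measures dispatch rather than the computation. A more conservative sketch waits for each result with `block_until_ready()`. `time_blocking` is a hypothetical helper built on the standard library's `timeit`, and the repeat count of 1000 is arbitrary."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "added-jax-timing-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import timeit\n",
    "\n",
    "def time_blocking(f, *args, number=1000):\n",
    "    # warm up (compile) once, then time calls that wait for the result\n",
    "    f(*args).block_until_ready()\n",
    "    t = timeit.timeit(lambda: f(*args).block_until_ready(), number=number)\n",
    "    return t / number * 1e6  # mean µs per call\n",
    "\n",
    "for n in [1, 10, 100, 1000, 10_000]:\n",
    "    print(n, f'{time_blocking(make_jx_foo(n, 1.001), 5.0):.1f} µs')"
   ]
  },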
  {
   "cell_type": "markdown",
   "id": "1e9b8f58-bad1-41f2-9782-34adc5755b02",
   "metadata": {},
   "source": [
    "So jax has an overhead of roughly 9 to 12 µs per call, but it catches up with numba somewhere between 1,000 and 10,000 iterations. At 10,000 iterations it is actually faster (13.6 µs vs 37.4 µs).\n",
    "\n",
    "The other major remark is that this is not a convenient way to write the code:\n",
    "- the loop trip count has to be known at compile time\n",
    "- op parameters like `b` can't be passed as arguments; they are baked into the closure instead\n",
    "\n",
    "Can we write a better version? It turns out we can, with `jax.lax.fori_loop`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "4b060902-e3b0-41d1-aee8-81ec353f5347",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "12.9 µs ± 135 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)\n",
      "13.4 µs ± 347 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)\n",
      "14.1 µs ± 367 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)\n",
      "16.3 µs ± 337 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)\n",
      "36 µs ± 704 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)\n"
     ]
    }
   ],
   "source": [
    "@jax.jit\n",
    "def jx_foo2(n, a, b):\n",
    "    # n, a and b are all traced arguments, so one compiled program serves every n;\n",
    "    # the body multiplies by b, so pass 1/1.001 to divide by 1.001\n",
    "    return jax.lax.fori_loop(0, n, lambda i, a: a*b, a)\n",
    "\n",
    "for n in [1, 10, 100, 1000, 10_000]:\n",
    "    jx_foo2(n, 5.0, 1/1.001)  # warm-up call\n",
    "    %timeit jx_foo2(n, 5.0, 1/1.001)"
   ]
  },
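  {
   "cell_type": "markdown",
   "id": "4948a32b-9869-4ae2-be72-d0856e66cccd",
   "metadata": {},
   "source": [
    "The caveat is that it's a bit slower than the scan version. That may be because there `n` and `b` are compile-time constants, whereas with a traced `n`, `fori_loop` is lowered to a `while_loop` that isn't unrolled. Multiplying by `1/b` instead of dividing by `b` helps, especially at 10,000 iterations, where it takes 36 µs, on par with numba's 37.4 µs."
   ]
  }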
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "taichi",
   "language": "python",
   "name": "taichi"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}