Short test of taichi on tight numerical loops
{
"cells": [
{
"cell_type": "markdown",
"id": "f4071430-99df-4748-a39c-50d92af79400",
"metadata": {},
"source": [
"# taichi vs jax vs numba tight loop comparison\n",
"\n",
"Taichi is another attempt at high performance in Python, which compares in some interesting ways to existing solutions; it's closer to DrJit & Mojo than Jax or Numba.\n",
"\n",
"First litmus test, it pip installs on Windows and just works like jax or numba? nope, some differences to note off the bat\n",
"- needs `ti.init()` call explicitly 🤷‍♂️\n",
"- arguments require type annotations👍\n",
"- (scalar) kernel arguments are immutable, so we need to copy first 👍\n",
"- need a return type too"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "64875d89-36d7-46d2-af76-0d465e36c735",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[Taichi] version 1.6.0, llvm 15.0.1, commit f1c6fbbd, win, python 3.10.13\n",
"[Taichi] Starting on arch=x64\n"
]
},
{
"data": {
"text/plain": [
"(4.95027390356502, 4.9502716064453125)"
]
},
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import taichi as ti\n",
"ti.init()\n",
"\n",
"def foo(n: int, a: float, b: float) -> float:\n",
" a_ = a\n",
" for i in range(n):\n",
" a_ = a_ / b\n",
" return a_\n",
"\n",
"ti_foo = ti.kernel(foo)\n",
"\n",
"foo(10, 5.0, 1.001), ti_foo(10, 5.0, 1.001)"
]
},
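{
"cell_type": "markdown",
"id": "ti-kernel-decorator-sketch",
"metadata": {},
"source": [
"the same kernel can also be written with the usual `@ti.kernel` decorator and taichi's own scalar annotations; a minimal sketch, not timed here (the `ti_foo_dec` name and the explicit `ti.i32`/`ti.f32` types are just illustrative):\n",
"\n",
"```python\n",
"@ti.kernel\n",
"def ti_foo_dec(n: ti.i32, a: ti.f32, b: ti.f32) -> ti.f32:\n",
"    a_ = a  # scalar kernel args are immutable, so copy first\n",
"    for i in range(n):\n",
"        a_ = a_ / b\n",
"    return a_\n",
"```"
]
},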
{
"cell_type": "markdown",
"id": "ed6087e9-c4c6-429d-a4d8-ffd692e1d65c",
"metadata": {},
"source": [
"This is a test of (a) kernel launch overhead and (b) perf of tight loop with only two registers"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "bbc57c06-03bd-4f00-a800-4f0b9a86f01e",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"294 ns ± 5.63 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)\n",
"110 µs ± 4.78 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)\n",
"\n",
"578 ns ± 11.3 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)\n",
"110 µs ± 3.19 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)\n",
"\n",
"3.41 µs ± 151 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)\n",
"109 µs ± 2.73 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)\n",
"\n",
"39.3 µs ± 912 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)\n",
"111 µs ± 2.87 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)\n",
"\n",
"422 µs ± 12.3 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)\n",
"149 µs ± 13.6 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)\n",
"\n"
]
}
],
"source": [
"for n in [1, 10, 100, 1000, 10_000]:\n",
" %timeit foo(n, 5.0, 1.001)\n",
" %timeit ti_foo(n, 5.0, 1.001)\n",
" print()"
]
},
{
"cell_type": "markdown",
"id": "9d3b288a-8017-47b1-aa05-ba084387cd8d",
"metadata": {},
"source": [
"so there's a clear launch overhead, but it seems to JIT the tight loop nicely.\n",
"\n",
"we can see what numba does in this case,"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "d577eecc-20bf-4aca-9446-4d01b751a971",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"221 ns ± 7.39 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)\n",
"221 ns ± 4.36 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)\n",
"533 ns ± 25.1 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)\n",
"4.79 µs ± 994 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)\n",
"37.4 µs ± 1.12 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)\n"
]
}
],
"source": [
"import numba as nb\n",
"nb_foo = nb.njit(foo)\n",
"nb_foo(1, 1.0, 1.0)\n",
"for n in [1, 10, 100, 1000, 10000]:\n",
" %timeit nb_foo(n, 5.0, 1.001)"
]
},
{
"cell_type": "markdown",
"id": "21867174-357b-468f-a20c-92cd3fca43c0",
"metadata": {},
"source": [
"numba is generating more efficient launches, but as long as kernels can be fused, it's not a big deal.\n",
"\n",
"what about Jax? it jits using tracing with function composition, so we need a slightly different approach,"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "9afef1d1-57f8-4746-9250-538dc4a7d966",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"9.41 µs ± 1.08 µs per loop (mean ± std. dev. of 7 runs, 100,000 loops each)\n",
"11.3 µs ± 462 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)\n",
"11.6 µs ± 826 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)\n",
"11.5 µs ± 558 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)\n",
"13.6 µs ± 1.12 µs per loop (mean ± std. dev. of 7 runs, 100,000 loops each)\n"
]
}
],
"source": [
"import jax\n",
"\n",
"def make_jx_foo(n, b):\n",
" def jx_foo_loop(c, x):\n",
" return c / b, x\n",
" @jax.jit\n",
" def jx_foo_scan(a):\n",
" i = jax.numpy.r_[:n]\n",
" return jax.lax.scan(jx_foo_loop, a, i, unroll=20)[0]\n",
" return jx_foo_scan\n",
"\n",
"for n in [1, 10, 100, 1000, 10000]:\n",
" # %timeit -n25 -r25 jx_foo(jax.numpy.r_[:n], 5.0)\n",
" jx_foo = make_jx_foo(n, 1.001)\n",
" jx_foo(5.0)\n",
" %timeit jx_foo(5.0)"
]
},
{
"cell_type": "markdown",
"id": "1e9b8f58-bad1-41f2-9782-34adc5755b02",
"metadata": {},
"source": [
"so at about 10k iterations, jax matches numba, but has an overhead at first. \n",
"\n",
"the other major remark is that this is not a convenient way to write the code:\n",
"- the loop trip count has to be known at compile time\n",
"- can't pass op parameters as arguments\n",
"\n",
"can we write a better version? turns out yes,"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "4b060902-e3b0-41d1-aee8-81ec353f5347",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"12.9 µs ± 135 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)\n",
"13.4 µs ± 347 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)\n",
"14.1 µs ± 367 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)\n",
"16.3 µs ± 337 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)\n",
"36 µs ± 704 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)\n"
]
}
],
"source": [
"@jax.jit\n",
"def jx_foo2(n, a, b):\n",
" return jax.lax.fori_loop(0, n, lambda i, a: a*b, a)\n",
"\n",
"for n in [1, 10, 100, 1000, 10_000]:\n",
" jx_foo2(n, 5.0, 1.001)\n",
" %timeit jx_foo2(n, 5.0, 1/1.001)"
]
},
{
"cell_type": "markdown",
"id": "4948a32b-9869-4ae2-be72-d0856e66cccd",
"metadata": {},
"source": [
"with the caveat that it's a bit slower, maybe because the constant b is known. using multiply instead of divide closes the gap especially for 10k iterations."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "taichi",
"language": "python",
"name": "taichi"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.13"
}
},
"nbformat": 4,
"nbformat_minor": 5
}