This is wonderful, thanks for sharing!
Thank you @jpivarski -- this notebook has been extremely helpful in getting me started using Jax for fractal computation on a GPU.
However I just wanted to let you know (and any others) that the Jax routine performs loop unrolling, which converts the loop of 20 iterations into code that simply gets repeated and compiled as 20 consecutive blocks of code.
This isn't noticeable when only performing 20 iterations, but quickly becomes unmanageable with a more realistic fractal calculation of, say, 1,000 iterations -- which turns out to take something like 60 seconds to compile for its first run.
After I asked elsewhere for help, I was informed the solution was to use lax.fori_loop instead of a traditional Python loop, which ensures a genuine loop is actually compiled and performed.
And while this compiled almost instantly, I'll note that this came at a post-compilation performance cost -- instead of the 1,000 iterations executing in ~80 ms, it took ~1,000 ms.
However, I then managed to get the best of both worlds by combining the two methods: 10 iterations of lax.fori_loop whose body performed 100 iterations of a traditional Python loop (which was unrolled for performance).
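Roughly, the hybrid looks like this (a sketch rather than my exact code; the name run_jax_hybrid, the constants, and the escape bookkeeping with jnp.where are just for illustration, assuming c is the notebook's grid of complex starting points):

import jax
import jax.numpy as jnp
from jax import lax

NUM_OUTER = 10    # lax.fori_loop iterations: a genuine compiled loop
NUM_INNER = 100   # Python-loop iterations: unrolled at trace time

@jax.jit
def run_jax_hybrid(c):
    def outer_body(outer, state):
        z, fractal = state
        for inner in range(NUM_INNER):                    # unrolled into 100 blocks
            i = outer * NUM_INNER + inner                 # global iteration number
            z = z**2 + c                                  # z -> z^2 + c
            escaped_now = (jnp.absolute(z) > 2) & (fractal == NUM_OUTER * NUM_INNER)
            fractal = jnp.where(escaped_now, i, fractal)  # record first escape only
        return z, fractal

    fractal = jnp.full(c.shape, NUM_OUTER * NUM_INNER, dtype=jnp.int32)
    _, fractal = lax.fori_loop(0, NUM_OUTER, outer_body, (c, fractal))
    return fractal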
I'm sure you probably don't want to do anything like update this guide to compare 1,000 iterations instead of 20, but I thought you might want to know that this does add a wrinkle at least to Jax, and I don't know about the other GPU solutions as well.
Thanks for pointing this out! This is worth knowing and I'll mention it whenever I show this example.
This is really cool! It would be interesting to see results from CPU parallelism as well; a single CPU core vs. a GPU isn't really a fair comparison. E.g., Jax's and Numba's parallelism support. But I totally understand if you don't have time at this point.
In a variation of this, a long-form project at a Fast and Efficient Python Computing School, I have an example that uses parallel processing on CPUs (In [18]):
rng = np.random.default_rng()  # can be forked to run multiple rngs in parallel
rngs = rng.spawn(NUM_TILES_1D * NUM_TILES_1D)

@nb.jit(parallel=True)
def compute_parallel(rngs, numer, denom):
    for i in nb.prange(NUM_TILES_1D):
        for j in nb.prange(NUM_TILES_1D):
            rng = rngs[NUM_TILES_1D * i + j]
            denom[i, j] = 100
            numer[i, j] = count_mandelbrot(rng, denom[i, j], xmin(j), width, ymin(i), height)

numer = np.zeros((NUM_TILES_1D, NUM_TILES_1D), dtype=np.int64)
denom = np.zeros((NUM_TILES_1D, NUM_TILES_1D), dtype=np.int64)
compute_parallel(rngs, numer, denom)
When the job is split into embarrassingly parallel tasks, like the non-overlapping tiles/parts of the complex plane above, there aren't many reasons why the CPU rate can't just be multiplied by the number of parallel threads. "Not many": maybe false sharing of parts of the output array, memory bus contention at really high rates (Gbps), and certainly stragglers, which would affect this problem strongly. If some tiles are much slower than others (because they're on the complicated boundary of the fractal), they'll still be running while others have finished, and only a few threads will be active. But that wouldn't be intrinsic to a CPU/GPU comparison; you'd optimize the workflow (by making more, smaller tiles, which are less susceptible to stragglers), and then scaling would be closer to linear in the number of threads.
To ensure that all of the CPU numbers were single-threaded, I even had to force JAX into a box that it couldn't escape:
It's hard to keep JAX from using all your CPU cores, so I ran this notebook with taskset -c 0 jupyter lab to bind it to exactly one core (the one numbered 0). The following assertion will fail and you would get misleading results if you don't run this notebook under the same conditions.
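(The assertion itself isn't reproduced above, but something along the lines of this single-core guard would do the same job; a sketch assuming Linux's os.sched_getaffinity, not necessarily the notebook's exact check:)

import os

# Fail early unless this process is pinned to exactly one CPU core,
# e.g. by launching with `taskset -c 0 jupyter lab`.
assert len(os.sched_getaffinity(0)) == 1, "run under taskset -c 0 for single-core timings"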
That was to make sure that the results are interpretable—not that you'd pit one CPU core against a GPU, but that you'd scale by the number of CPU cores that you actually have (and optimize the workflow). That was the idea, anyway.
Makes sense, thank you!
This is missing PyPy numbers.
Okay. I installed PyPy 7.3.15 and CPython 3.9.18 in two conda environments, with NumPy in each, and repeated In [2] and In [4] above.

Here is the CPython run, which reproduces the order of magnitude above: I find that it's now 145 ms per run_python(200, 300), rather than 249 ms... within a factor of 2, with a different Python version, environment, and who-knows-what since I first ran this test on June 14, 2022.
Python 3.9.18 | packaged by conda-forge | (main, Dec 23 2023, 16:33:10)
[GCC 12.3.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import timeit
>>> import numpy as np
>>> def run_python(height, width):
... y, x = np.ogrid[-1:0:height*1j, -1.5:0:width*1j]
... c = x + y*1j
... fractal = np.full(c.shape, 20, dtype=np.int32)
... for h in range(height):
... for w in range(width): # for each pixel (h, w)...
... z = c[h, w]
... for i in range(20): # iterate at most 20 times
... z = z**2 + c[h, w] # applying z → z² + c
... if abs(z) > 2: # if it diverges (|z| > 2)
... fractal[h, w] = i # color the plane with the iteration number
... break # we're done, no need to keep iterating
... return fractal
...
>>> timeit.timeit(lambda: run_python(200, 300), number=10)
1.442052590999083
>>> timeit.timeit(lambda: run_python(200, 300), number=10)
1.4416049659994314
>>> timeit.timeit(lambda: run_python(200, 300), number=10)
1.445725762998336
>>> timeit.timeit(lambda: run_python(200, 300), number=10)
1.4504199469993182
>>> timeit.timeit(lambda: run_python(200, 300), number=10)
1.4678468480015
Now here's the PyPy run. You can see that it implements the same Python language version (3.9.18).
Python 3.9.18 | packaged by conda-forge | (9c4f8ef1, Mar 08 2024, 07:32:51)
[PyPy 7.3.15 with GCC 12.3.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>>> import timeit
>>>> import numpy as np
>>>> def run_python(height, width):
.... y, x = np.ogrid[-1:0:height*1j, -1.5:0:width*1j]
.... c = x + y*1j
.... fractal = np.full(c.shape, 20, dtype=np.int32)
.... for h in range(height):
.... for w in range(width): # for each pixel (h, w)...
.... z = c[h, w]
.... for i in range(20): # iterate at most 20 times
.... z = z**2 + c[h, w] # applying z → z² + c
.... if abs(z) > 2: # if it diverges (|z| > 2)
.... fractal[h, w] = i # color the plane with the iteration number
.... break # we're done, no need to keep iterating
.... return fractal
....
>>>> timeit.timeit(lambda: run_python(200, 300), number=10)
8.935209687999304
>>>> timeit.timeit(lambda: run_python(200, 300), number=10)
8.972501267999178
>>>> timeit.timeit(lambda: run_python(200, 300), number=10)
8.909452853000403
>>>> timeit.timeit(lambda: run_python(200, 300), number=10)
8.971896891998767
>>>> timeit.timeit(lambda: run_python(200, 300), number=10)
8.998815111001022
PyPy is more than 6× slower than CPython. I don't know why: I would have thought that PyPy's JIT compiler would have made it several times faster than CPython, though nothing like the factor of 200× that you can get by compiling without maintaining Python's dynamic programming environment (that is, C++, Numba, JAX...).
My experience is that PyPy doesn't interact well with NumPy. I think this is part of the motivation for the HPy project (https://hpyproject.org/).
Good point... CPython is a little faster when dealing with its own builtin types, too.
Replacing the NumPy data structure with a list of lists (in the hot part of the loop), here's CPython again:
Python 3.9.18 | packaged by conda-forge | (main, Dec 23 2023, 16:33:10)
[GCC 12.3.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import timeit
>>> import numpy as np
>>> def run_python(height, width):
... y, x = np.ogrid[-1:0:height*1j, -1.5:0:width*1j]
... c = x + y*1j
... fractal = np.full(c.shape, 20, dtype=np.int32).tolist()
... x, y, c = x.tolist(), y.tolist(), c.tolist()
... for h in range(height):
... for w in range(width): # for each pixel (h, w)...
... z = c[h][w]
... for i in range(20): # iterate at most 20 times
... z = z**2 + c[h][w] # applying z → z² + c
... if abs(z) > 2: # if it diverges (|z| > 2)
... fractal[h][w] = i # color the plane with the iteration number
... break # we're done, no need to keep iterating
... return fractal
...
>>> timeit.timeit(lambda: run_python(200, 300), number=10)
1.000359961999493
>>> timeit.timeit(lambda: run_python(200, 300), number=10)
0.9851951990021917
>>> timeit.timeit(lambda: run_python(200, 300), number=10)
0.9806630249986483
>>> timeit.timeit(lambda: run_python(200, 300), number=10)
0.9808432169993466
>>> timeit.timeit(lambda: run_python(200, 300), number=10)
0.9776434310006152
and here's pypy:
Python 3.9.18 | packaged by conda-forge | (9c4f8ef1, Mar 08 2024, 07:32:51)
[PyPy 7.3.15 with GCC 12.3.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>>> import timeit
>>>> import numpy as np
>>>> def run_python(height, width):
.... y, x = np.ogrid[-1:0:height*1j, -1.5:0:width*1j]
.... c = x + y*1j
.... fractal = np.full(c.shape, 20, dtype=np.int32).tolist()
.... x, y, c = x.tolist(), y.tolist(), c.tolist()
.... for h in range(height):
.... for w in range(width): # for each pixel (h, w)...
.... z = c[h][w]
.... for i in range(20): # iterate at most 20 times
.... z = z**2 + c[h][w] # applying z → z² + c
.... if abs(z) > 2: # if it diverges (|z| > 2)
.... fractal[h][w] = i # color the plane with the iteration number
.... break # we're done, no need to keep iterating
.... return fractal
....
>>>> timeit.timeit(lambda: run_python(200, 300), number=10)
0.6174074949994974
>>>> timeit.timeit(lambda: run_python(200, 300), number=10)
0.6332233270004508
>>>> timeit.timeit(lambda: run_python(200, 300), number=10)
0.7321933960010938
>>>> timeit.timeit(lambda: run_python(200, 300), number=10)
0.6341267679999874
>>>> timeit.timeit(lambda: run_python(200, 300), number=10)
0.6687725560004765
Now PyPy is roughly 1.5-1.6× faster than CPython, which is in the ballpark of what I'd expect.
Personally, I'm still a lot more swayed by the 200× that you get through other methods. For any numerical work, I'd try to get the operation on numerical data compiled with known types, no boxing, no garbage collectors, and all the rest.
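To make that concrete, here's a minimal sketch of what I mean by "compiled with known types," using Numba (the notebook has its own versions of this; the name run_numba is just for illustration):

import numba as nb
import numpy as np

@nb.njit
def run_numba(c):
    # Same algorithm as run_python, but compiled: complex128 input,
    # int32 output, and no boxed Python objects in the hot loop.
    fractal = np.full(c.shape, 20, dtype=np.int32)
    for h in range(c.shape[0]):
        for w in range(c.shape[1]):        # for each pixel (h, w)...
            z = c[h, w]
            for i in range(20):            # iterate at most 20 times
                z = z**2 + c[h, w]         # applying z -> z^2 + c
                if abs(z) > 2:             # if it diverges (|z| > 2)
                    fractal[h, w] = i      # color with the iteration number
                    break
    return fractal

y, x = np.ogrid[-1:0:200*1j, -1.5:0:300*1j]
c = x + y*1j
fractal = run_numba(c)   # first call compiles; later calls run at compiled speed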
I got intrigued and found a dual scalar / handwritten-portable-SIMD implementation of the Mandelbrot algorithm: https://pythonspeed.com/articles/optimizing-with-simd/
The development of this notebook was discussed on jax-ml/jax#11078.