Created
February 25, 2022 18:29
-
-
Save lgarrison/3dcf371cb9edb891cef92fe3db5c7d33 to your computer and use it in GitHub Desktop.
optimize calc_fenv
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"id": "c15d3118", | |
"metadata": {}, | |
"source": [ | |
"The original calc_fenv() gave wrong results with `parallel=True` because of a bad parfor fusion (a Numba bug). There is a minimal reproducer at the bottom of the notebook that demonstrates the issue. But it turns out that a small manipulation of the code (dividing by `Nmask - 1` instead of by `np.max` of the ranks) avoids the bug entirely, allowing us to use `parallel=True` safely.\n", | |
"\n", | |
"If one is paranoid, one can also disable parfor fusion entirely with `njit(parallel=dict(fusion=False))`, at some performance penalty.\n", | |
"\n", | |
"The original Numba bug will be fixed by: https://github.com/numba/numba/pull/7582" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"id": "9abe8d6d-d0c2-4c1a-a911-831029816c34", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import numpy as np\n", | |
"from numba import njit\n", | |
"import numba" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"id": "f3ed5448", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"@njit(parallel=True)\n", | |
"def calc_fenv_opt(Menv, mbins, halosM):\n", | |
"    # Rank Menv among the halos of each halosM bin, scaled to [-0.5, 0.5]\n", | |
"    # (rank 0 -> -0.5, highest rank -> +0.5). Bin i is the open interval\n", | |
"    # (mbins[i], mbins[i+1]). Halos outside every bin (including exactly on an\n", | |
"    # edge) or in a bin with <= 1 member are left at 0.0.\n", | |
"    # Presumably Menv and halosM are same-length per-halo arrays -- TODO confirm.\n", | |
"    # Compared with calc_fenv_orig, the only change is dividing by Nmask-1 instead of\n", | |
"    # np.max(new_fenv_rank). That change avoids the bad parfor fusion, so keep\n", | |
"    # this statement structure as is.\n", | |
"    fenv_rank = np.zeros(len(Menv))\n", | |
"    # prange iterations run in parallel and each writes fenv_rank[mmask]. The writes\n", | |
"    # are disjoint only if mbins is increasing -- NOTE(review): not checked here.\n", | |
"    for ibin in numba.prange(len(mbins)-1):\n", | |
"        mmask = (halosM > mbins[ibin]) & (halosM < mbins[ibin + 1])\n", | |
"        Nmask = np.sum(mmask)\n", | |
"        # Need >= 2 members, otherwise the scale factor Nmask-1 would be zero.\n", | |
"        if Nmask > 1:\n", | |
"            # argsort of argsort gives each element its 0-based rank within the bin.\n", | |
"            new_fenv_rank = Menv[mmask].argsort().argsort()\n", | |
"            fenv_rank[mmask] = new_fenv_rank / (Nmask-1) - 0.5 # max rank is always Nmask - 1\n", | |
"    return fenv_rank\n", | |
"\n", | |
"def calc_fenv_orig(Menv, mbins, halosM):\n", | |
"    # Original version, kept undecorated so it can be compiled both serially and\n", | |
"    # with parallel=True in the benchmark cell. Under parallel=True, the np.max\n", | |
"    # reduction and the division on the same line are fused into one loop (a Numba\n", | |
"    # bug), which gives wrong results. Do not edit: this function exists to\n", | |
"    # reproduce that bug.\n", | |
"    # Result: Menv ranked within each open halosM bin (mbins[i], mbins[i+1]),\n", | |
"    # scaled to [-0.5, 0.5]. Halos not covered by a bin with >= 2 members stay 0.0.\n", | |
"    fenv_rank = np.zeros(len(Menv))\n", | |
"    # numba.prange behaves like range when not compiled with parallel=True.\n", | |
"    for ibin in numba.prange(len(mbins)-1):\n", | |
"        mmask = (halosM > mbins[ibin]) & (halosM < mbins[ibin + 1])\n", | |
"        if np.sum(mmask) > 1:\n", | |
"            new_fenv_rank = Menv[mmask].argsort().argsort()\n", | |
"            fenv_rank[mmask] = new_fenv_rank / np.max(new_fenv_rank) - 0.5\n", | |
"    return fenv_rank\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"id": "ac5706fe", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"2.78 s ± 32.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n", | |
"False\n", | |
"376 ms ± 13.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n", | |
"True\n" | |
] | |
} | |
], | |
"source": [ | |
"N = 10**7\n", | |
"mbins = np.linspace(0, 1, 10)\n", | |
"rng = np.random.default_rng()\n", | |
"halosM = rng.random(N)\n", | |
"Menv = rng.random(N)\n", | |
"numba.set_num_threads(12)\n", | |
"\n", | |
"parallel_calc_fenv_orig = njit(parallel=True)(calc_fenv_orig)\n", | |
"serial_calc_fenv_orig = njit(parallel=False)(calc_fenv_orig)\n", | |
"\n", | |
"parres_orig = parallel_calc_fenv_orig(Menv, mbins, halosM)\n", | |
"%timeit global serres_orig; serres_orig = serial_calc_fenv_orig(Menv, mbins, halosM) # 2.8 sec\n", | |
"\n", | |
"print((parres_orig == serres_orig).all()) # False\n", | |
"\n", | |
"%timeit global parres_opt; parres_opt = calc_fenv_opt(Menv, mbins, halosM) # 380 ms\n", | |
"\n", | |
"print((parres_opt == serres_orig).all()) # True" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"id": "e6d57adb", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
" \n", | |
"================================================================================\n", | |
" Parallel Accelerator Optimizing: Function calc_fenv_orig, \n", | |
"/tmp/ipykernel_883107/1823462690.py (12) \n", | |
"================================================================================\n", | |
"\n", | |
"\n", | |
"Parallel loop listing for Function calc_fenv_orig, /tmp/ipykernel_883107/1823462690.py (12) \n", | |
"---------------------------------------------------------------------------------|loop #ID\n", | |
"def calc_fenv_orig(Menv, mbins, halosM): | \n", | |
" fenv_rank = np.zeros(len(Menv))----------------------------------------------| #0\n", | |
" for ibin in numba.prange(len(mbins)-1):--------------------------------------| #5\n", | |
" mmask = (halosM > mbins[ibin]) & (halosM < mbins[ibin + 1])--------------| #1\n", | |
" if np.sum(mmask) > 1:----------------------------------------------------| #3\n", | |
" new_fenv_rank = Menv[mmask].argsort().argsort() | \n", | |
" fenv_rank[mmask] = new_fenv_rank / np.max(new_fenv_rank) - 0.5-------| #2, 4\n", | |
" return fenv_rank | \n", | |
"------------------------------ After Optimisation ------------------------------\n", | |
"Parallel region 0:\n", | |
"+--5 (parallel)\n", | |
" +--1 (serial, fused with loop(s): 3)\n", | |
" +--4 (serial, fused with loop(s): 2)\n", | |
"\n", | |
"\n", | |
" \n", | |
"Parallel region 0 (loop #5) had 2 loop(s) fused and 2 loop(s) serialized as part\n", | |
" of the larger parallel loop (#5).\n", | |
"--------------------------------------------------------------------------------\n", | |
"--------------------------------------------------------------------------------\n", | |
" \n" | |
] | |
} | |
], | |
"source": [ | |
"# For fun: the diagnostics show the bad fusion of loops #2 and #4. These come from the\n", | |
"# np.max reduction and the elementwise division on the same line. They are fused into\n", | |
"# one loop even though the division needs the completed max.\n", | |
"parallel_calc_fenv_orig.parallel_diagnostics(level=1)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"id": "f731cc43", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"False\n" | |
] | |
} | |
], | |
"source": [ | |
"# The original calc_fenv() had a bad parfor fusion due to a numba bug.\n", | |
"# f() is a minimal reproducer that demonstrates the issue: a reduction (a.max())\n", | |
"# whose result feeds a following array fill (res[:] = amx).\n", | |
"# This will be fixed by: https://github.com/numba/numba/pull/7582\n", | |
"\n", | |
"def f():\n", | |
"    # Serially this returns [1., 1.]: the max of arange(2), broadcast into res.\n", | |
"    a = np.arange(2)\n", | |
"    amx = a.max()\n", | |
"    res = np.empty(len(a))\n", | |
"    res[:] = amx\n", | |
"    return res\n", | |
"\n", | |
"# One thread is enough to trigger the bug, so it is a compilation problem, not a race.\n", | |
"# NOTE: this thread count persists for the rest of the kernel session.\n", | |
"numba.set_num_threads(1)\n", | |
"f_parallel = njit(parallel=True)(f)\n", | |
"f_serial = njit(parallel=False)(f)\n", | |
"\n", | |
"# Prints False: the parallel build disagrees with the serial one.\n", | |
"print(np.all(f_parallel() == f_serial()))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"id": "31d477e5", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
" \n", | |
"================================================================================\n", | |
" Parallel Accelerator Optimizing: Function f, \n", | |
"/tmp/ipykernel_883107/3926855332.py (5) \n", | |
"================================================================================\n", | |
"\n", | |
"\n", | |
"Parallel loop listing for Function f, /tmp/ipykernel_883107/3926855332.py (5) \n", | |
"------------------------------|loop #ID\n", | |
"def f(): | \n", | |
" a = np.arange(2) | \n", | |
" amx = a.max()-------------| #13\n", | |
" res = np.empty(len(a)) | \n", | |
" res[:] = amx--------------| #11\n", | |
" return res | \n", | |
"------------------------------ After Optimisation ------------------------------\n", | |
"Parallel region 0:\n", | |
"+--12 (parallel, fused with loop(s): 11, 13)\n", | |
"\n", | |
"\n", | |
" \n", | |
"Parallel region 0 (loop #12) had 2 loop(s) fused.\n", | |
"--------------------------------------------------------------------------------\n", | |
"--------------------------------------------------------------------------------\n", | |
" \n" | |
] | |
} | |
], | |
"source": [ | |
"# As before, the fusion is bad: loop #13 (the a.max() reduction) is fused with\n", | |
"# loop #11 (the res[:] = amx fill), even though the fill needs the finished max.\n", | |
"f_parallel.parallel_diagnostics(level=1)" | |
] | |
} | |
], | |
"metadata": { | |
"interpreter": { | |
"hash": "5cd5cbe25001faa61ab76a271aef4113321f63d42e31cbdebc9d4a65270c2765" | |
}, | |
"kernelspec": { | |
"display_name": "MyEnv", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.9.9" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 5 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment