Last active
June 24, 2025 20:14
-
-
Save lgarrison/5a7cec176339736e0fec8e97e7959db8 to your computer and use it in GitHub Desktop.
benchmarking different index matching algorithms
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# /// script
# requires-python = ">=3.11"
# dependencies = [
#     "numba",
#     "numpy",
# ]
# ///
# Example output:
# > uv run zipper.py
# findIndices with 1e+05 needles took 0.989 seconds
# zipper_search with 1e+05 needles took 0.5674 seconds
# zipper_search_simple with 1e+05 needles took 0.08588 seconds
# findIndices with 1e+06 needles took 1.348 seconds
# zipper_search with 1e+06 needles took 0.6168 seconds
# zipper_search_simple with 1e+06 needles took 0.1017 seconds
# findIndices with 5e+06 needles took 3.629 seconds
# zipper_search with 5e+06 needles took 0.9215 seconds
# zipper_search_simple with 5e+06 needles took 0.1719 seconds
# findIndices with 1e+07 needles took 10.44 seconds
# zipper_search with 1e+07 needles took 1.309 seconds
# zipper_search_simple with 1e+07 needles took 0.1879 seconds
from timeit import default_timer | |
import numpy as np | |
import numba as nb | |
@nb.jit
def zipper_search_simple(A, B):
    """Locate every element of ``B`` in ``A`` with one merge-style pass over both arrays.

    Both ``A`` and ``B`` must already be sorted in ascending order; otherwise the result is
    meaningless.  The cursor into ``A`` only ever moves forward, so the total work is
    O(len(A) + len(B)).

    Returns an int64 array with one entry per element of ``B``: the position in ``A`` of the
    first element equal to it, or -1 when it does not occur in ``A``.  Positions refer to the
    sorted ``A`` that was passed in.
    """
    size_a = len(A)
    size_b = len(B)
    result = np.empty(size_b, dtype=np.int64)
    pos = 0  # cursor into A; never moves backwards because B is ascending
    for k in range(size_b):
        target = B[k]
        # Skip past everything in A that is strictly smaller than the current target.
        while pos < size_a and A[pos] < target:
            pos += 1
        hit = pos < size_a and A[pos] == target
        result[k] = pos if hit else -1
    return result
@nb.jit
def zipper_search(A, B, sorterA, sorterB):
    """Locate every element of ``B`` in ``A`` without reordering either array.

    ``sorterA`` and ``sorterB`` are index permutations that visit ``A`` and ``B`` in ascending
    order (e.g. ``np.argsort(A)`` and ``np.argsort(B)``).  Both arrays are walked in that
    sorted order with a single merge-style pass.

    Returns an int64 array aligned with the *original* order of ``B``.  The entry at
    ``sorterB[k]`` holds the original index in ``A`` of the matching element, or -1 when that
    element does not occur in ``A``.  The result is allocated with ``np.empty``, so any index of
    ``B`` that ``sorterB`` never visits is left unset.
    """
    size_a = len(A)
    size_b = len(B)
    result = np.empty(size_b, dtype=np.int64)
    pos = 0  # cursor into the sorted view of A (i.e. into sorterA)
    for k in range(size_b):
        b_idx = sorterB[k]
        target = B[b_idx]
        # Skip every element of A (in sorted order) that is strictly smaller than the target.
        while pos < size_a and A[sorterA[pos]] < target:
            pos += 1
        if pos < size_a and A[sorterA[pos]] == target:
            result[b_idx] = sorterA[pos]
        else:
            result[b_idx] = -1
    return result
def findIndices(set_all, set_find, algorithm='auto'):
    """
    Find the indices of one set of integers in another.

    This function is useful for finding the indices of a set of IDs in another large array of IDs.
    For example, the load() function does not return halos in the requested order if halo_ids are
    passed to the function. This function makes it easy to obtain the original ordering.

    Two algorithms are implemented: the standard np.where() function which searches each ID
    separately, and an algorithm based on np.searchsorted(). The latter is faster when set_find is
    large. By default, the best algorithm is chosen.

    Parameters
    ----------
    set_all: array_like
        A set of unique integer numbers (does not need to be sorted).
    set_find: array_like
        The numbers to be found in set_all (does not need to be sorted, and must be smaller or
        equal in size to set_all).
    algorithm: str
        If ``auto``, the function chooses ``individual`` when set_find has fewer than 100
        elements and ``array`` otherwise. If ``individual``, the np.where() function is used.
        If ``array``, the np.searchsorted() function is used.

    Returns
    -------
    idx_find: array_like
        Array of indices of the same size as set_find, pointing to set_all.

    Raises
    ------
    ValueError
        If ``algorithm`` is not one of ``auto``, ``individual`` or ``array``; or if the number of
        elements of set_all that occur in set_find differs from the size of set_find (an ID is
        missing from set_all, or an ID is duplicated). In the latter case the elements of
        set_find that are absent from set_all are printed first.
    """
    # Accept any array_like (e.g. lists): the boolean-mask indexing and element-wise
    # comparisons below only work on ndarrays.
    set_all = np.asarray(set_all)
    set_find = np.asarray(set_find)
    n_find = len(set_find)

    # Validate the algorithm up front; an unknown value used to leave use_individual unbound.
    if algorithm == 'auto':
        use_individual = n_find < 100
    elif algorithm == 'individual':
        use_individual = True
    elif algorithm == 'array':
        use_individual = False
    else:
        raise ValueError(
            f"findIndices(): unknown algorithm {algorithm!r}; "
            "expected 'auto', 'individual' or 'array'."
        )

    # We start by finding the mask of the all set in the find set (not the other way around). If
    # the input arrays are unique, each element needs to be matched once in the all-set.
    mask = np.isin(set_all, set_find)
    n_matches = np.count_nonzero(mask)
    if n_matches != n_find:
        # mask is aligned with set_all, so the missing elements must be computed from
        # set_find's side (indexing set_find with mask fails when the lengths differ).
        print('findIndices(): Elements not found:')
        print(set_find[np.logical_not(np.isin(set_find, set_all))])
        msg = (
            'At least one ID could not be found in set, or an ID was duplicated (looking for %d IDs, found %d matches).'
            % (n_find, n_matches)
        )
        raise ValueError(msg)

    if use_individual:
        idx_find = np.zeros(n_find, dtype=int)
        for i in range(n_find):
            idx_find[i] = np.where(set_find[i] == set_all)[0][0]
    else:
        # Sort the matched IDs while remembering where each one sits in set_all, then
        # binary-search every element of set_find among them.
        matches = set_all[mask]
        sorter = np.argsort(matches)
        sorted_matches = matches[sorter]
        mask_idx = np.where(mask)[0][sorter]
        positions = np.searchsorted(sorted_matches, set_find)
        idx_find = mask_idx[positions]
    return idx_find
def main(n_haystack=10**7, seed=0xBEEF):
    """Benchmark findIndices, zipper_search and zipper_search_simple on random uint64 IDs.

    A haystack of ``n_haystack`` random uint64 values is drawn once from a generator seeded with
    ``seed``. For each needle count (1%, 10%, 50% and 100% of the haystack), a random subset of
    the haystack is used as needles. Each matcher is called once as a warm-up, which also
    triggers numba JIT compilation. It is then timed once, and the results are cross-checked.
    """
    rng = np.random.default_rng(seed=seed)
    haystack_orig = rng.integers(np.iinfo(np.uint64).max, size=n_haystack, dtype=np.uint64)

    for n_needles in (n_haystack // 100, n_haystack // 10, n_haystack // 2, n_haystack):
        # Fresh copy every round: zipper_search_simple's timing below sorts it in place.
        haystack = haystack_orig.copy()
        # select a subset of the haystack to be needles
        needles = rng.choice(haystack, size=n_needles, replace=False)

        # findIndices: warm up with the same algorithm that is timed.
        findIndices(haystack, needles, algorithm='array')
        start = default_timer()
        fres = findIndices(haystack, needles, algorithm='array')
        elapsed = default_timer() - start
        print(f'findIndices with {n_needles:.1g} needles took {elapsed:.4g} seconds')

        # zipper_search: the argsort calls are evaluated inside the timed region, so their
        # cost is included in the measurement.
        zipper_search(haystack, needles, np.argsort(haystack), np.argsort(needles))
        start = default_timer()
        zres = zipper_search(haystack, needles, np.argsort(haystack), np.argsort(needles))
        elapsed = default_timer() - start
        print(f'zipper_search with {n_needles:.1g} needles took {elapsed:.4g} seconds')
        assert np.array_equal(fres, zres), 'zipper_search disagrees with findIndices'

        # zipper_search_simple: the warm-up only triggers compilation (unsorted inputs give
        # meaningless results); the timed region includes the in-place sorts.
        zipper_search_simple(haystack, needles)
        start = default_timer()
        haystack.sort()
        needles.sort()
        sres = zipper_search_simple(haystack, needles)
        elapsed = default_timer() - start
        print(f'zipper_search_simple with {n_needles:.1g} needles took {elapsed:.4g} seconds')
        # sres indexes the *sorted* haystack, so it cannot be compared with fres directly;
        # instead check that every needle was found at a position holding that needle.
        assert np.all(sres >= 0) and np.array_equal(haystack[sres], needles), (
            'zipper_search_simple returned wrong positions'
        )


if __name__ == '__main__':
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment