Skip to content

Instantly share code, notes, and snippets.

@lgarrison
Last active June 24, 2025 20:14
Show Gist options
Save lgarrison/5a7cec176339736e0fec8e97e7959db8 to your computer and use it in GitHub Desktop.
Benchmarking different index-matching algorithms
# /// script
# requires-python = ">=3.11"
# dependencies = [
# "numba",
# "numpy",
# ]
# ///
# Example output:
# > uv run zipper.py
# findIndices with 1e+05 needles took 0.989 seconds
# zipper_search with 1e+05 needles took 0.5674 seconds
# zipper_search_simple with 1e+05 needles took 0.08588 seconds
# findIndices with 1e+06 needles took 1.348 seconds
# zipper_search with 1e+06 needles took 0.6168 seconds
# zipper_search_simple with 1e+06 needles took 0.1017 seconds
# findIndices with 5e+06 needles took 3.629 seconds
# zipper_search with 5e+06 needles took 0.9215 seconds
# zipper_search_simple with 5e+06 needles took 0.1719 seconds
# findIndices with 1e+07 needles took 10.44 seconds
# zipper_search with 1e+07 needles took 1.309 seconds
# zipper_search_simple with 1e+07 needles took 0.1879 seconds
from timeit import default_timer
import numpy as np
import numba as nb
@nb.jit
def zipper_search_simple(A, B):
    """Locate each element of sorted ``B`` inside sorted ``A`` in one merge-style pass.

    Both ``A`` and ``B`` must already be sorted in ascending order; no check is made.
    For every position ``k`` of ``B`` the returned int64 array holds the index of the
    first element of ``A`` equal to ``B[k]``, or -1 when ``A`` has no such element.
    """
    size_a = len(A)
    size_b = len(B)
    out = np.empty(size_b, dtype=np.int64)
    cursor = 0  # position in A; it only ever moves forward
    for k in range(size_b):
        target = B[k]
        # Step past every element of A strictly smaller than the current needle.
        while cursor < size_a and A[cursor] < target:
            cursor += 1
        if cursor >= size_a or A[cursor] != target:
            out[k] = -1
        else:
            out[k] = cursor
    return out
@nb.jit
def zipper_search(A, B, sorterA, sorterB):
    """Locate each element of ``B`` inside ``A`` by walking both in sorted order.

    ``sorterA`` and ``sorterB`` must be permutations that put ``A`` and ``B`` in
    ascending order (e.g. the output of ``np.argsort``). The returned int64 array is
    aligned with ``B``'s original order: entry ``k`` is the original index into ``A``
    of an element equal to ``B[k]``, or -1 if that value does not occur in ``A``.
    """
    n_a = len(A)
    result = np.empty(len(B), dtype=np.int64)
    pos = 0  # rank within the sorted view of A; never decreases
    for rank in range(len(B)):
        b_idx = sorterB[rank]
        needle = B[b_idx]
        # Skip sorted-A entries strictly smaller than the current needle.
        while pos < n_a and A[sorterA[pos]] < needle:
            pos += 1
        match = -1
        if pos < n_a and A[sorterA[pos]] == needle:
            match = sorterA[pos]
        result[b_idx] = match
    return result
def findIndices(set_all, set_find, algorithm='auto'):
    """
    Find the indices of one set of integers in another.

    This function is useful for finding the indices of a set of IDs in another large array of IDs.
    For example, the load() function does not return halos in the requested order if halo_ids are
    passed to the function. This function makes it easy to obtain the original ordering.

    Two algorithms are implemented: the standard np.where() function which searches each ID
    separately, and an algorithm based on np.searchsorted(). The latter is faster when set_find is
    large (greater than 100). By default, the best algorithm is chosen.

    Parameters
    ----------
    set_all: array_like
        A set of unique integer numbers (does not need to be sorted).
    set_find: array_like
        The numbers to be found in set_all (does not need to be sorted, and must be smaller or
        equal in size to set_all).
    algorithm: str
        If ``auto``, the function automatically chooses the best algorithm (``individual`` for
        fewer than 100 IDs, ``array`` otherwise). If ``individual``, the np.where() function is
        used. If ``array``, the np.searchsorted() function is used.

    Returns
    -------
    idx_find: numpy.ndarray
        Array of indices of the same size as set_find, pointing to set_all.

    Raises
    ------
    ValueError
        If ``algorithm`` is not one of the values above, or if the number of elements of set_all
        that occur in set_find differs from len(set_find) (an ID is missing or duplicated).

    Notes
    -----
    The validation only compares counts, so duplicated values in set_all (which violate the
    uniqueness requirement) can go undetected and yield wrong indices.
    """
    if algorithm not in ('auto', 'individual', 'array'):
        # Fail early and explicitly; otherwise the algorithm choice below would be left unbound.
        raise ValueError(
            f"Unknown algorithm {algorithm!r}; expected 'auto', 'individual' or 'array'."
        )

    # Accept any array_like (e.g. plain lists): boolean-mask indexing below needs ndarrays.
    set_all = np.asarray(set_all)
    set_find = np.asarray(set_find)
    n_find = len(set_find)

    # We start by finding the mask of the all set in the find set (not the other way around). If
    # the input arrays are unique, each element needs to be matched once in the all-set.
    mask = np.isin(set_all, set_find)
    n_matched = np.count_nonzero(mask)
    if n_matched != n_find:
        # `mask` is aligned with set_all, so it cannot index set_find; compute the IDs of
        # set_find that are absent from set_all separately (only on this error path).
        print('findIndices(): Elements not found:')
        print(set_find[np.logical_not(np.isin(set_find, set_all))])
        msg = (
            'At least one ID could not be found in set, or an ID was duplicated (looking for %d IDs, found %d matches).'
            % (n_find, n_matched)
        )
        raise ValueError(msg)

    if algorithm == 'auto':
        use_individual = n_find < 100
    else:
        use_individual = algorithm == 'individual'

    if use_individual:
        # One full scan of set_all per ID: cheap when only a few IDs are requested.
        idx_find = np.empty(n_find, dtype=int)
        for i, needle in enumerate(set_find):
            idx_find[i] = np.where(set_all == needle)[0][0]
    else:
        # Sort only the matched subset of set_all, binary-search every ID in it, and map the
        # sorted positions back to positions in the full set_all.
        matches = set_all[mask]
        sorter = np.argsort(matches)
        mask_idx = np.where(mask)[0][sorter]
        idx_find = mask_idx[np.searchsorted(matches[sorter], set_find)]
    return idx_find
N = 10**7
rng = np.random.default_rng(seed=0xBEEF)
haystack_orig = rng.integers(np.iinfo(np.uint64).max, size=N, dtype=np.uint64)

for N_needles in (N // 100, N // 10, N // 2, N):
    # Select a subset of the haystack to be needles. A fresh copy is taken every round because
    # the zipper_search_simple section below sorts the haystack in place.
    haystack = haystack_orig.copy()
    needles = rng.choice(haystack, size=N_needles, replace=False)

    # findIndices: warm-up, then timed run.
    findIndices(haystack, needles)
    t = -default_timer()
    fres = findIndices(haystack, needles, algorithm='array')
    t += default_timer()
    print(f'findIndices with {N_needles:.1g} needles took {t:.4g} seconds')
    # Validate the reference result (outside the timed region): each index points at its needle.
    assert np.array_equal(haystack[fres], needles)

    # zipper_search: the warm-up call lets numba compile before timing. The argsort calls are
    # evaluated inside the timed region, so the measurement includes the sorting cost.
    zipper_search(haystack, needles, np.argsort(haystack), np.argsort(needles))
    t = -default_timer()
    zres = zipper_search(haystack, needles, np.argsort(haystack), np.argsort(needles))
    t += default_timer()
    print(f'zipper_search with {N_needles:.1g} needles took {t:.4g} seconds')
    assert np.array_equal(fres, zres)

    # zipper_search_simple: the warm-up runs on still-unsorted inputs, so its result is unused.
    # The in-place sorts are inside the timed region.
    zipper_search_simple(haystack, needles)
    t = -default_timer()
    haystack.sort()
    needles.sort()
    zres = zipper_search_simple(haystack, needles)
    t += default_timer()
    print(f'zipper_search_simple with {N_needles:.1g} needles took {t:.4g} seconds')
    # This result indexes the *sorted* haystack in sorted-needle order, so it cannot be compared
    # with fres; check it against the sorted arrays instead (-1 would mark a missed needle).
    assert np.all(zres >= 0) and np.array_equal(haystack[zres], needles)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment