Skip to content

Instantly share code, notes, and snippets.

@ayghri
Last active March 8, 2025 23:09
Show Gist options
  • Save ayghri/0cf3987c0653a6316f8c3983965f6e7c to your computer and use it in GitHub Desktop.
Save ayghri/0cf3987c0653a6316f8c3983965f6e7c to your computer and use it in GitHub Desktop.
A NumPy/Numba function optimized for sampling unique indices per row while excluding the provided indices.
# Ayoub Ghriss, 2024
import numpy as np
from numba import njit, prange, int64, int32
@njit(int64[:, :](int64, int64, int64[:, :]), parallel=True, cache=True, nogil=True)
def sample_excluded_parallel(n: int, k: int, to_exclude: np.ndarray) -> np.ndarray:
    """
    Sample k unique integers per row from [0, n), skipping per-row excluded values.

    Row ``i`` of the result is drawn without replacement from
    ``{0, ..., n-1}`` minus the values in ``to_exclude[i]``. Rejection
    sampling is avoided: each row draws ``k`` distinct ranks from the pool of
    allowed values, then maps every rank to the allowed integer at that rank.

    Parameters:
        n (int): Number of output rows, and also the exclusive upper bound of
            the sampled values (samples lie in [0, n)).
        k (int): Number of unique integers to sample per row.
        to_exclude (np.ndarray int64, shape (n, m)): Values to exclude for each
            row. Rows need not be sorted. Duplicates and values outside [0, n)
            are ignored.

    Returns:
        np.ndarray int64, shape (n, k): rand_ints[i] contains k unique integers
        in [0, n - 1], none of which appear in to_exclude[i].

    Raises:
        ValueError: If to_exclude has fewer than n rows. np.random.choice also
            fails when a row has fewer than k allowed values.
    """
    # Numba does not bounds-check by default, so guard the per-row access below.
    if to_exclude.shape[0] < n:
        raise ValueError("to_exclude must have at least n rows")
    rand_ints = np.zeros((n, k), dtype=np.int64)
    for i in prange(n):
        row = to_exclude[i]
        # Drop out-of-range values and deduplicate. np.unique also sorts,
        # which the rank -> value mapping below requires.
        excluded = np.unique(row[(row >= 0) & (row < n)])
        num_excluded = excluded.shape[0]
        # Draw k distinct ranks among the n - num_excluded allowed values.
        subset = np.random.choice(n - num_excluded, size=k, replace=False)
        # For sorted unique exclusions e_0 < e_1 < ..., the allowed value with
        # rank r is r + #{t : e_t - t <= r}.
        offset_excluded = excluded - np.arange(num_excluded)
        # Plain range: a prange nested inside a prange runs serially anyway.
        for j in range(k):
            subset[j] += np.sum(offset_excluded <= subset[j])
        rand_ints[i] = subset
    return rand_ints
if __name__ == "__main__":
    # Demo: n rows, each sampling k unique values from [0, n) while avoiding
    # num_excluded random values. Requires n - num_excluded >= k.
    n = 8
    k = 3
    num_excluded = 4
    to_exclude = np.array(
        [np.random.choice(n, num_excluded, replace=False) for _ in range(n)]
    )
    # The jitted signature only accepts int64[:, :]. NumPy's default integer
    # is int32 on some platforms (e.g. Windows with NumPy < 2), so cast explicitly.
    to_exclude = np.sort(to_exclude, axis=1).astype(np.int64)
    sampled = sample_excluded_parallel(n, k, to_exclude)
    print(to_exclude)
    print(sampled)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment