Last active
March 8, 2025 23:09
-
-
Save ayghri/0cf3987c0653a6316f8c3983965f6e7c to your computer and use it in GitHub Desktop.
A numpy-numba function optimized for sampling unique indices per row, excluding provided indices
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Ayoub Ghriss, 2024 | |
import numpy as np | |
from numba import njit, prange, int64, int32 | |
@njit(int64[:, :](int64, int64, int64[:, :]), parallel=True, cache=True, nogil=True)
def sample_excluded_parallel(n: int, k: int, to_exclude: np.ndarray) -> np.ndarray:
    """
    Sample k distinct integers per row from [0, n), skipping per-row exclusions.

    For every row i, k values are drawn without replacement from the
    ``n - m`` integers of ``[0, n)`` that do not appear in ``to_exclude[i]``.
    The draw is done in the compacted range ``[0, n - m)`` and each drawn
    rank is then shifted past the excluded values that precede it.

    Parameters:
        n (int): Number of rows to generate. It is also the exclusive upper
            bound of the sampled integer range.
        k (int): Number of integers to generate per row.
        to_exclude (np.ndarray int64, (n, m)): values to exclude, one row per
            output row. The rank-shifting step is only correct when each row
            is sorted in ascending order and holds distinct values in
            ``[0, n)``; this is not checked here.

    Returns:
        np.ndarray(int64) shape (n, k): rand_ints[i] contains k unique
        integers in [0, n - 1], none of which is in to_exclude[i].

    Raises:
        ValueError: if to_exclude has fewer than n rows, or if k is larger
            than the n - m values left after exclusion.
    """
    num_excluded = to_exclude.shape[1]
    # Validate before entering the parallel region: rows are indexed up to
    # n - 1 below, and a draw without replacement needs k <= n - m.
    if to_exclude.shape[0] < n:
        raise ValueError("to_exclude must have at least n rows")
    if k > n - num_excluded:
        raise ValueError("k is larger than the number of non-excluded values")

    # Every row is fully overwritten below, so no zero-fill is needed.
    rand_ints = np.empty((n, k), dtype=np.int64)
    # Loop-invariant: position of each excluded value within its sorted row.
    ranks = np.arange(num_excluded)
    for i in prange(n):
        # Draw in the compacted range that has the excluded values removed.
        subset = np.random.choice(n - num_excluded, size=k, replace=False)
        # offset_excluded[p] = how many allowed values lie below the p-th
        # excluded value of this row.
        offset_excluded = to_exclude[i] - ranks
        # Plain range: only the outermost prange is parallelised anyway.
        for j in range(k):
            # Shift the rank by the number of excluded values at or below it.
            subset[j] += np.sum(subset[j] >= offset_excluded)
        rand_ints[i] = subset
    return rand_ints
if __name__ == "__main__":
    # Small demo. sample_excluded_parallel(n, k, to_exclude) uses n both as
    # the number of rows and as the size of the index range [0, n), so the
    # exclusions are drawn from that same range.
    n = 8  # rows to sample, and size of the index range
    m = 4  # indices excluded per row
    k = 3  # indices sampled per row; must satisfy k <= n - m
    # Explicit int64: the compiled signature only accepts int64[:, :].
    to_exclude = np.array(
        [np.random.choice(n, m, replace=False) for _ in range(n)],
        dtype=np.int64,
    )
    # Each row is sorted ascending before being handed to the sampler.
    to_exclude = np.sort(to_exclude, axis=1)
    sampled = sample_excluded_parallel(n, k, to_exclude)
    print(to_exclude)
    print(sampled)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment