Skip to content

Instantly share code, notes, and snippets.

@agoose77
Created March 16, 2022 14:19
Show Gist options
  • Save agoose77/299ea87d32b8f98f803faade57430a0e to your computer and use it in GitHub Desktop.
Save agoose77/299ea87d32b8f98f803faade57430a0e to your computer and use it in GitHub Desktop.
import math
@nb.njit
def _reservoir_sample_prepare(n, w, i, out):
    """Reset the state of an Algorithm L reservoir sample.

    The reservoir ``out`` is filled with the indices ``0 .. k-1`` (where
    ``k = len(out)``). The cursor ``i`` is parked on the last of those
    indices, and the weight ``w`` is set to 1, ready for the first call
    to ``_reservoir_sample_step``.

    Parameters
    ----------
    n : int
        Population size. Not used here.
        NOTE(review): when ``len(out) > n`` the reservoir is still seeded
        with ``0 .. len(out)-1``, i.e. with indices outside the population.
    w : array
        Mutable state cell for the running weight, written through
        ``w[()]`` (presumably a 0-d float array -- confirm with caller).
    i : array
        Mutable state cell for the index of the last candidate considered,
        written through ``i[()]`` (presumably a 0-d integer array).
    out : array
        Reservoir of sampled indices, overwritten in place.
    """
    sample_size = len(out)
    # Unit weight: the step routine multiplies in a fresh draw before use.
    w[()] = 1
    # The first `sample_size` candidates are already in the reservoir.
    i[()] = sample_size - 1
    out[:] = np.arange(sample_size, dtype=out.dtype)
@nb.njit
def _reservoir_sample_step(n, w, i, random, out):
    """Advance an Algorithm L reservoir sample using a buffer of draws.

    Continues from the state set up by ``_reservoir_sample_prepare``.
    Each loop iteration uses up to three values from ``random``:

    * one to update the weight (``w <- w * u**(1/k)``);
    * one to compute how many candidates to skip
      (``floor(log(u) / log(1 - w)) + 1``);
    * one to pick the reservoir slot that receives the new candidate.

    If fewer than three draws remain at the top of an iteration, the loop
    stops without touching the state. The caller can then refill
    ``random`` and call again. ``w`` and ``i`` are written back before
    returning, so the next call resumes exactly where this one stopped.

    Parameters
    ----------
    n : int
        Population size; candidate indices run over ``0 .. n-1``.
    w : array
        Mutable state cell holding the running weight, read and written
        through ``w[()]`` (presumably a 0-d float array -- confirm).
    i : array
        Mutable state cell holding the index of the last candidate
        considered, read and written through ``i[()]`` (presumably a 0-d
        integer array -- confirm).
    random : array
        Uniform draws, assumed to lie in [0, 1).
    out : array
        Reservoir of sampled indices, updated in place; its length is k.

    Returns
    -------
    bool
        True if the draws ran out before the population was exhausted
        (call again with fresh draws). False once ``i >= n``; after
        exhaustion ``i`` holds a value ``>= n``.
    """
    k = len(out)
    if k == 0:
        # Empty reservoir: nothing to sample, and log(u) / k would divide
        # by zero.
        return False

    # Work on scalar copies of the state cells. Written back below.
    cur_w = w[()]
    cur_i = i[()]
    n_random = len(random)
    j = 0
    need_more = False
    while cur_i < n:
        # Stop before mutating anything if a full iteration's draws are not
        # available, so that a later call resumes exactly here.
        if n_random - j < 3:
            need_more = True
            break

        # Weight update: w <- w * u**(1/k)
        cur_w *= np.exp(np.log(random[j]) / k)
        j += 1

        # Skip length: floor(log(u) / log(1 - w))
        denom = np.log(1 - cur_w)
        if denom == 0.0:
            # w is 0 (e.g. a draw of exactly 0.0) or too small to move
            # 1 - w: the skip length is unbounded, so the stream is done.
            cur_i = n
            break
        skip = np.log(random[j]) / denom
        j += 1
        if not (skip < n - cur_i):
            # The skip runs past the end. This also catches inf/nan from a
            # draw of exactly 0.0, which must never reach math.floor.
            cur_i = n
            break
        cur_i += math.floor(skip) + 1
        if cur_i >= n:
            break

        # Overwrite a uniformly chosen reservoir slot with the candidate.
        out[math.floor(random[j] * k)] = cur_i
        j += 1

    w[()] = cur_w
    i[()] = cur_i
    return need_more
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment