Last active
January 2, 2022 19:46
-
-
Save dkirkby/a36ef2b097710db5500b7a17a01028bf to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# See https://stackoverflow.com/questions/6910641/how-do-i-get-indices-of-n-maximum-values-in-a-numpy-array
# argpartition requires numpy >= 1.8.0
# See also http://seanlaw.github.io/2020/01/03/finding-top-or-bottom-k-in-a-numpy-array/
def argkmax1(a, k, axis=-1):
    """Return the indices of the k largest elements of a along an axis.

    The k largest entries are isolated with np.argpartition and only those
    k entries are then sorted, so the whole array is never fully sorted.

    With k=1, this is identical to argmax except that it
    returns an array of length 1 instead of a scalar.

    Note: np.take_along_axis requires numpy >= 1.15.

    Parameters
    ----------
    a : array_like
        Values to search.
    k : int
        Number of indices to return along ``axis``; must be non-negative.
    axis : int or None
        Axis to search along (default: last). None searches the flattened
        array and returns flat indices.

    Returns
    -------
    ndarray of int
        Indices along ``axis``, with the same shape as ``a`` except that
        ``axis`` has length k. They are ordered by increasing value, so the
        last entry along ``axis`` is the position of the maximum.

    Raises
    ------
    ValueError
        If k is negative. np.argpartition also rejects a k that is larger
        than the length of ``axis``.
    """
    a = np.asarray(a)
    if axis is None:
        a = a.ravel()
        axis = 0
    if k < 0:
        raise ValueError("k must be non-negative")
    n = a.shape[axis]
    if k == 0:
        # -0 == 0, so a "last k" selection would return every index
        # instead of none; build the empty result explicitly.
        shape = list(a.shape)
        shape[axis] = 0
        return np.empty(shape, dtype=np.intp)
    part = np.argpartition(a, -k, axis=axis)
    # Keep the last k positions along the requested axis (not along axis 0).
    top = np.take(part, np.arange(n - k, n), axis=axis)
    # argpartition leaves the top k unordered; sort just those k by value.
    order = np.argsort(np.take_along_axis(a, top, axis=axis), axis=axis)
    return np.take_along_axis(top, order, axis=axis)
def argkmax2(a, k, axis=-1):
    """Return the indices of the k largest elements of a.

    Finds the maximum k times with np.argmax, overwriting each maximum
    that was found with the array minimum so that the next pass finds the
    next-largest value. The overwritten entries of ``a`` are restored
    before returning, even if the search is interrupted.

    With k=1, this is identical to argmax except that it
    returns an array of length 1 instead of a scalar.

    Parameters
    ----------
    a : ndarray
        One-dimensional, writable array. It is modified temporarily and
        restored to its original contents on exit.
    k : int
        Number of indices to return; 0 <= k <= a.size.
    axis : int
        Unused; accepted only so the signature matches argkmax1.

    Returns
    -------
    ndarray of int
        k distinct indices into ``a`` ordered by increasing value, so the
        last entry is the position of the maximum.

    Raises
    ------
    ValueError
        If ``a`` is not one-dimensional or k is outside [0, a.size].
    """
    if a.ndim != 1:
        raise ValueError("argkmax2 only supports one-dimensional arrays")
    if not 0 <= k <= a.size:
        raise ValueError("k must satisfy 0 <= k <= a.size")
    idx = np.empty(k, int)
    save = np.empty(k, a.dtype)
    amin = np.min(a)
    # idx[first:] / save[first:] hold the picks made so far; slots are filled
    # from the end so the result comes out in increasing order of value.
    first = k
    try:
        while first > 0:
            j = np.argmax(a)
            if a[j] == amin:
                # Everything left ties with the minimum, so masking with amin
                # would no longer change argmax and the same index would be
                # returned repeatedly. Fill the remaining slots with indices
                # that have not been picked yet.
                unused = np.setdiff1d(np.arange(a.size), idx[first:])
                idx[:first] = unused[:first]
                break
            first -= 1
            idx[first] = j
            save[first] = a[j]
            a[j] = amin  # mask this maximum for the next pass
    finally:
        # Undo the temporary masking (only the entries actually overwritten).
        a[idx[first:]] = save[first:]
    return idx
# First form is faster for large k, second is faster for small k. | |
rng = np.random.RandomState(123) | |
R = rng.normal(size=10000000) | |
%time K = argkmax1(R, 5); # Wall time: 117 ms | |
%time K = argkmax2(R, 5); # Wall time: 66.6 ms | |
%time K = argkmax1(R, 15); # Wall time: 118 ms | |
%time K = argkmax2(R, 15); # Wall time: 165 ms | |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment