Created December 28, 2020 16:51
-
-
Save kingjr/980c94acc2ad611072fb3b69bbe7a508 to your computer and use it in GitHub Desktop.
Fast CPU AUC with a regular masked array (same number of kept samples in every column)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
from numba import jit | |
@jit
def fast_auc(ytrue_sorted):
    """Compute ROC AUC from binary labels already sorted by prediction.

    Parameters
    ----------
    ytrue_sorted : array, shape (n_samples, ...)
        0/1 (or boolean) labels. Along axis 0 they are ordered by increasing
        predicted score, independently for every trailing index (e.g.
        ``y_true[np.argsort(y_pred, axis=0)]``). ``1 - label`` is used as the
        negative indicator, so labels other than 0/1 give wrong results.

    Returns
    -------
    auc : float array, shape ytrue_sorted.shape[1:]
        Fraction of (positive, negative) pairs in which the positive is
        ranked above the negative.

    Notes
    -----
    Only the label order is seen here, so tied predictions get no 0.5
    credit: their relative order is whatever the sort produced. When a
    column contains a single class, the result is 0 / 0, i.e. nan.
    """
    n_samples = len(ytrue_sorted)
    out_shape = ytrue_sorted.shape[1:]
    # Running count of negatives seen so far, i.e. negatives ranked below
    # the current sample.
    n_neg_seen = np.zeros(out_shape)
    # Sum, over positives, of the negatives ranked below each of them.
    n_correct_pairs = np.zeros(out_shape)
    for label in ytrue_sorted:
        n_neg_seen += 1 - label
        n_correct_pairs += label * n_neg_seen
    # Normalise by n_negatives * n_positives. Done in place so a 0-d input
    # still returns an array.
    n_correct_pairs /= n_neg_seen * (n_samples - n_neg_seen)
    return n_correct_pairs
def mask_array(x, mask):
    """Keep, per column, the axis-0 entries of ``x`` selected by ``mask``.

    Every 1-D slice ``x[:, i, j, ...]`` is filtered by the matching slice of
    ``mask``, preserving order along axis 0. Because every column must keep
    the same number of entries, the result is a regular array again.

    Parameters
    ----------
    x : array
        Data whose axis 0 is filtered.
    mask : bool array, same shape as ``x``
        Must have the same number of True values in every column.

    Returns
    -------
    array, shape (n_kept, *x.shape[1:])
    """
    assert np.array_equal(x.shape, mask.shape)
    assert mask.dtype == bool
    per_column = np.unique(mask.sum(0))
    assert len(per_column) == 1
    n_kept = per_column[0]
    # Boolean indexing walks in C order. With axis 0 moved last, each
    # column's kept entries come out contiguously and in their original order.
    x_last = np.moveaxis(x, 0, -1)
    mask_last = np.moveaxis(mask, 0, -1)
    kept = x_last[mask_last].reshape(*mask.shape[1:], n_kept)
    return np.moveaxis(kept, -1, 0)
if __name__ == "__main__":
    # Demo: one AUC per (row, col) decoder, e.g. 20 x 20 time points.
    n_samples, n_rows, n_cols = 100, 20, 20
    y_pred = np.random.randn(n_samples, n_rows, n_cols)
    y_true = np.random.rand(n_samples) > .5

    # Sort y_true by y_pred, independently for every (row, col) column:
    # y_true_sorted[:, i, j] holds the labels ordered by increasing
    # y_pred[:, i, j].
    y_sort = np.argsort(y_pred, axis=0)
    y_true_sorted = y_true[y_sort]

    # Optionally score only a subset of samples. The mask is defined per
    # sample and reordered exactly like y_true, so every column keeps the
    # same samples. This is also what mask_array requires: an equal number
    # of kept entries per column. A fully random (n_samples, n_rows, n_cols)
    # mask would violate that and trip its assertion.
    sample_mask = np.random.rand(n_samples) > .5
    mask = sample_mask[y_sort]
    y_true_sorted_masked = mask_array(y_true_sorted, mask)

    # Compute the AUC of every column.
    auc = fast_auc(y_true_sorted_masked)
    print("AUC shape:", auc.shape, "mean AUC:", np.nanmean(auc))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.