Last active: December 7, 2021 01:45
-
-
Save cjauvin/661d5dcc11f313af3d7c55cfbdd8a624 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
def weighted_quantile_type_7(a, q, weights=None, axis=None):
    """Compute (optionally weighted) quantiles with the R "type 7" definition.

    Type 7 is the linear-interpolation method that ``np.quantile`` uses by
    default, so with ``weights=None`` the values match
    ``np.quantile(a, q, axis=axis)``.

    For one 1-D slice of ``n`` sorted samples, each quantile ``p`` has a
    fractional 1-based rank ``h = p * (n - 1) + 1``. Each sample owns the
    interval of the cumulative (normalized) weight axis between its preceding
    and its own cumulative weight. The result is the sum of sample values
    weighted by how much of the window ``[(h - 1) / n, h / n]`` each interval
    covers, rescaled by ``n`` so the shares sum to 1. With equal weights this
    reduces to interpolating between the ``floor(h)``-th and the next sample.

    Parameters
    ----------
    a : array_like
        Input data.
    q : float or 1-D sequence of float
        Quantile(s) to compute, each in ``[0, 1]``.
    weights : array_like, optional
        Non-negative weights, one per element of each reduced slice. This is
        the product of the sizes of the reduced axes, or ``a.size`` when
        ``axis`` is None. Give them either flat, in C order over the reduced
        axes taken in the order listed in ``axis``, or shaped like those axes.
        The same weights are used for every slice and are rescaled to sum
        to 1. ``None`` means equal weights.
    axis : None, int or tuple of int, optional
        Axis or axes to reduce. Negative values count from the end. ``None``
        reduces over the flattened array.

    Returns
    -------
    numpy.ndarray
        If ``axis`` is None: shape ``(len(q),)``, or a 0-d array when ``q`` is
        a scalar. Otherwise: shape ``(len(q), *remaining_dims)``. A leading
        length-1 axis is kept even for scalar ``q``, unlike ``np.quantile``.

    Raises
    ------
    ValueError
        If ``q`` is not scalar/1-D or lies outside ``[0, 1]``, if ``axis`` is
        out of bounds or repeated, if a reduced slice is empty, or if
        ``weights`` has the wrong size, a negative entry, or a non-positive
        sum.
    """
    q_arr = np.asarray(q, dtype=float)
    if q_arr.ndim > 1:
        raise ValueError("q must be a scalar or a 1-D sequence")
    q_vec = np.atleast_1d(q_arr)
    if np.any((q_vec < 0) | (q_vec > 1)):
        raise ValueError("quantiles must lie in [0, 1]")

    a = np.asarray(a)
    ndim = a.ndim
    reduced = tuple(range(ndim)) if axis is None else _normalize_axes(axis, ndim)
    kept = [ax for ax in range(ndim) if ax not in reduced]

    # Put the kept axes first and the reduced axes last, in the order given by
    # `axis`. A C-order reshape then yields one row per output position, and
    # each row is a reduced slice flattened in the same order as `weights`.
    slice_size = int(np.prod([a.shape[ax] for ax in reduced]))
    n_slices = int(np.prod([a.shape[ax] for ax in kept]))
    rows = np.transpose(a, kept + list(reduced)).reshape(n_slices, slice_size)

    norm_weights = _normalized_weights(weights, slice_size)
    results = np.array(
        [_type7_quantiles_1d(row, q_vec, norm_weights) for row in rows],
        dtype=float,
    ).reshape(n_slices, q_vec.size).T  # -> (len(q), n_slices)

    if axis is None:
        # There is a single slice: return shape q.shape (0-d for a scalar q).
        return results.reshape(q_arr.shape)
    return results.reshape((q_vec.size,) + tuple(a.shape[ax] for ax in kept))


def _normalize_axes(axis, ndim):
    """Return `axis` as a tuple of non-negative ints, keeping the caller's order."""
    normalized = []
    for ax in (int(v) for v in np.atleast_1d(axis)):
        if not -ndim <= ax < ndim:
            raise ValueError(
                f"axis {ax} is out of bounds for array of dimension {ndim}"
            )
        normalized.append(ax % ndim)
    if len(set(normalized)) != len(normalized):
        raise ValueError("repeated axis")
    return tuple(normalized)


def _normalized_weights(weights, size):
    """Validate `weights` against the slice length and scale them to sum to 1."""
    if size == 0:
        raise ValueError("cannot compute quantiles of an empty slice")
    if weights is None:
        return np.full(size, 1.0 / size)
    # dtype=float lets integer weights be divided; ravel accepts weights
    # shaped like the reduced axes (C order matches the row flattening).
    w = np.asarray(weights, dtype=float).ravel()
    if w.size != size:
        raise ValueError(
            f"expected {size} weights (one per element of each reduced slice), "
            f"got {w.size}"
        )
    if np.any(w < 0):
        raise ValueError("weights must be non-negative")
    total = w.sum()
    if not total > 0:
        raise ValueError("weights must have a positive sum")
    return w / total


def _type7_quantiles_1d(values, q, weights):
    """Return the type-7 weighted quantiles of one 1-D slice for every entry of `q`."""
    n = values.size
    order = np.argsort(values)
    values = values[order]
    cum = np.concatenate(([0.0], np.cumsum(weights[order])))
    # h: fractional 1-based rank of each quantile, as a column for broadcasting.
    h = q[:, np.newaxis] * (n - 1) + 1
    # Clamp each cumulative-weight boundary into the window [(h-1)/n, h/n].
    # After rescaling by n, consecutive differences give every sorted value's
    # share of that window; the shares of one row sum to 1.
    u = np.clip(cum, (h - 1) / n, h / n)
    shares = np.diff(u * n - h + 1, axis=1)
    return shares @ values
if __name__ == "__main__":
    # Demo / self-check. It runs only when the file is executed as a script,
    # not on import. The fixed seed makes the random multidimensional cases
    # reproducible.
    rng = np.random.default_rng(0)

    ##############################
    # Base 1D case, non-weighted #
    ##############################
    a = [1, 1.9, 2.2, 3, 3.7, 4.1, 5]
    q = [.2, .4, .6, .8]
    wq_equal_weights = weighted_quantile_type_7(a, q)
    # With equal weights the result must match NumPy's default (linear) method.
    assert np.allclose(wq_equal_weights, np.quantile(a, q))
    print(wq_equal_weights)

    ##########################
    # Base 1D case, weighted #
    ##########################
    weights = [0.05, 0.05, 0.1, 0.15, 0.15, 0.25, 0.25]
    wq_with_weights = weighted_quantile_type_7(a, q, weights)
    print(wq_with_weights)

    #######################################
    # Multidimensional case, non-weighted #
    #######################################
    a = rng.random((3, 4, 5))
    q = [0.1, 0.9]
    nq = len(q)
    # (axis, expected result shape) pairs, checked against np.quantile.
    for ax, shp in [(None, (nq,)),
                    (0, (nq, 4, 5)),
                    (2, (nq, 3, 4)),
                    ((0, 2), (nq, 4)),
                    ((0, 1, 2), (nq,))]:
        mine = weighted_quantile_type_7(a, q, axis=ax)
        theirs = np.quantile(a, q, axis=ax)
        assert np.allclose(mine, theirs)
        assert mine.shape == shp
        assert theirs.shape == shp
        print(ax, mine)

    ###################################
    # Multidimensional case, weighted #
    ###################################
    ax = (0, 2)
    # Since ax=(0, 2), the weights size must be: [size of dim0 x size of dim2] = 3 x 5 = 15
    weights = rng.random(15)
    wq_multi = weighted_quantile_type_7(a, q, weights, axis=ax)
    # One result per quantile for each of the 4 positions along the kept axis 1.
    assert wq_multi.shape == (nq, 4)
    print(wq_multi)
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.