Last active
December 20, 2021 20:59
-
-
Save cwindolf/498204babd1d2a7516b9573e88034348 to your computer and use it in GitHub Desktop.
argmedian.py: Compute the median of data distributed according to a pmf (or many pmfs)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
def argmedian(p, axis=None, check_pmf=True):
    """Compute the median of a pmf or many pmfs.

    Here, `p` represents one or many probability mass functions, and
    the index to a median will be returned for each pmf.

    - If `axis=None`, the flattened version of `p` will be considered
      a single pmf on its flattened index space. (cf `unravel_index`)
    - If `axis` is an integer, then we assume that `p.sum(axis=axis)`
      is an array of 1s, so that `p` contains many pmfs along that
      axis.

    The returned index is the smallest ``m`` whose cdf reaches 1/2,
    i.e. ``P(X <= m) >= 1/2``. Every earlier index has cdf below 1/2,
    so ``P(X >= m) >= 1/2`` as well, and ``m`` is a true median. When
    the cdf equals 1/2 exactly at some bin (the median "lies between
    bins"), this is the lower of the valid median indices.

    As with `argmin`, `take_along_axis` is a useful companion function.

    Arguments
    ---------
    p : array_like
        Array of one or more pmfs along `axis`. Entries should be
        nonnegative.
    axis : int or None
        Axis holding each pmf. None flattens `p` into a single pmf.
    check_pmf : bool
        If True, verify that `p` is nonnegative and sums to 1 along
        `axis` before computing.

    Returns
    -------
    Indices to medians. This is an integer scalar when `axis` is None,
    and otherwise an integer array shaped like `p` with `axis` removed.
    If the mass along `axis` never reaches 1/2 (only possible with
    `check_pmf=False`), the last index along `axis` is returned.

    Raises
    ------
    ValueError
        If `check_pmf` is set and `p` has a negative entry or does not
        sum to 1 along `axis`, or if the pmf axis has length 0.
    """
    p = np.asarray(p)
    if axis is None:
        p = np.ravel(p)
    if check_pmf:
        if (p < 0).any():
            raise ValueError("argmedian: positivity error.")
        total = p.sum(axis=axis)
        if not np.allclose(total, 1):
            raise ValueError(
                "argmedian: Sum to 1 violation",
                np.abs(total - 1).max(),
            )
    cdf = np.cumsum(p, axis=axis)
    # After the ravel above, the axis=None case is a 1-D cdf.
    n = cdf.shape[0 if axis is None else axis]
    if n == 0:
        raise ValueError("argmedian: empty pmf")
    # With nonnegative mass the float cdf is nondecreasing, so the number
    # of entries strictly below 1/2 equals the first index where the cdf
    # reaches 1/2. Picking the cdf value *closest* to 1/2 instead is wrong:
    # for cdf [0.3, 0.35, 1.0] it would choose index 1, which is not a median.
    ix = (cdf < 0.5).sum(axis=axis)
    # Keep the index in range if the mass never reaches 1/2.
    return np.minimum(ix, n - 1)
if __name__ == "__main__":
    # Self-checks for argmedian; these run only when the file is executed
    # directly, not when it is imported.
    rng = np.random.default_rng(0)

    # A uniform pmf over 10 bins; `indicator` flags the expected median bin.
    uniform = [0.1] * 10
    indicator = np.array([0, 0, 0, 0, 1, 0, 0, 0, 0, 0])
    assert argmedian(uniform) == 4
    assert indicator[argmedian(uniform)]

    # The same pmf stacked 20 times: first as rows of a nested list
    # (pmf along axis 1), then transposed into columns (pmf along axis 0).
    stacked = [uniform] * 20
    for arr, ax in ((stacked, 1), (np.array(stacked).T, 0)):
        med = argmedian(arr, axis=ax)
        assert med.shape == (20,)
        assert indicator[med].shape == (20,)
        assert np.all(med == 4)
        assert np.all(indicator[med])

    # Two adjacent half-mass bins shifted right by k; the expected median
    # index is the shift amount k.
    pair = [0.5, 0.5, 0, 0, 0]
    shifted = np.array([np.roll(pair, k) for k in range(4)])
    bins = np.arange(5)
    med = argmedian(shifted, axis=1)
    assert med.shape == (4,)
    assert np.all(bins[med] == bins[:4])

    # Random shifts on an (11, 13) grid with each pmf along the last axis.
    shifts = rng.integers(0, 4, size=(11, 13))
    grid = np.array(
        [[np.roll(pair, shifts[r, c]) for c in range(13)] for r in range(11)]
    )
    med = argmedian(grid, axis=-1)
    assert med.shape == (11, 13)
    assert np.all(bins[med] == shifts)

    # Move the pmf axis to the middle; the result must follow the new layout.
    grid = grid.transpose(1, 2, 0)
    med = argmedian(grid, axis=1)
    assert med.shape == (13, 11)
    assert np.all(bins[med] == shifts.T)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment