Skip to content

Instantly share code, notes, and snippets.

@cwindolf
Last active December 20, 2021 20:59
Show Gist options
  • Save cwindolf/498204babd1d2a7516b9573e88034348 to your computer and use it in GitHub Desktop.
Save cwindolf/498204babd1d2a7516b9573e88034348 to your computer and use it in GitHub Desktop.
argmedian.py: Compute the median of data distributed according to a pmf (or many pmfs)
import numpy as np
def argmedian(p, axis=None, check_pmf=True):
    """Compute the index of the median of a pmf or of many pmfs.

    Here, `p` represents one or many probability mass functions, and
    the index of a median will be returned for each pmf.

    - If `axis=None`, the flattened version of `p` is considered
      a single pmf on its flattened index space (cf. `unravel_index`).
    - If `axis` is an integer, then `p.sum(axis=axis)` should be an
      array of 1s, so that `p` contains many pmfs along that axis.

    The returned index is the smallest `k` whose cumulative mass
    `P(X <= k)` is at least half of the pmf's total mass. For a
    non-negative pmf this is a true median: `P(X <= k) >= 1/2` and
    `P(X >= k) >= 1/2`. When the cdf hits exactly 1/2 at some bin, every
    point up to the next bin with positive mass is a median, and the lower
    one is returned (up to floating-point rounding in the cumulative sum).

    As with `argmin`, `take_along_axis` is a useful companion function
    (apply `np.expand_dims(ix, axis)` to the result first).

    Arguments
    ---------
    p : array_like
        Array of one or more pmfs along `axis`.
    axis : int or None
        Axis along which the pmfs lie. `None` treats the flattened `p` as a
        single pmf.
    check_pmf : bool
        If True, raise ValueError unless `p` is non-negative and each pmf
        sums to 1 (within `np.allclose` tolerance). If False, no validation
        is done. Non-negative but unnormalized weights are still handled,
        because the cdf is compared against half of each pmf's own total.

    Returns
    -------
    ix : integer scalar or ndarray of integers
        Indices of the medians. When `axis=None` this is a scalar index into
        the flattened `p`. Otherwise it is an array shaped like `p` with
        `axis` removed.

    Raises
    ------
    ValueError
        If `check_pmf` is True and `p` has a negative entry or a pmf does
        not sum to 1. Also raised if the pmf axis has length zero.
    """
    p = np.asarray(p)
    if axis is None:
        p = np.ravel(p)
    # Total mass of each pmf. The reduced axis is kept with length 1 so it
    # broadcasts against the cumulative sums below.
    total = np.sum(p, axis=axis, keepdims=True)
    if check_pmf:
        if (p < 0).any():
            raise ValueError("argmedian: positivity error.")
        if not np.allclose(total, 1):
            raise ValueError(
                "argmedian: Sum to 1 violation",
                np.abs(total - 1).max(),
            )
    cdf = np.cumsum(p, axis=axis)
    # The median is the first bin at which the cdf reaches half the mass.
    # Choosing the bin whose cdf is merely *closest* to 0.5 is wrong: for
    # [0.4, 0.6] it would pick bin 0, but the median is bin 1.
    # argmax on a boolean array returns the index of the first True.
    return np.argmax(cdf >= 0.5 * total, axis=axis)
if __name__ == "__main__":
    # Self-checks for argmedian; executed only when this file is run directly.
    rng = np.random.default_rng(0)

    # One uniform pmf over 10 bins: the cdf first reaches 0.5 at bin 4,
    # which is the only bin flagged in `marker`.
    uniform = [0.1] * 10
    marker = np.array([0, 0, 0, 0, 1, 0, 0, 0, 0, 0])
    assert argmedian(uniform) == 4
    assert marker[argmedian(uniform)]

    # Twenty copies of that pmf, laid out first as rows (pmf axis 1) and
    # then as columns (pmf axis 0). Every median must be bin 4.
    stacked = [uniform] * 20
    for pmfs, pmf_axis in ((stacked, 1), (np.array(stacked).T, 0)):
        medians = argmedian(pmfs, axis=pmf_axis)
        assert medians.shape == (20,)
        assert marker[medians].shape == (20,)
        assert np.all(medians == 4)
        assert np.all(marker[medians])

    # A two-bin spike of mass [0.5, 0.5] starting at bin k has median k.
    spike = [0.5, 0.5, 0, 0, 0]
    bins = np.arange(5)
    shifted = np.array([np.roll(spike, k) for k in range(4)])
    medians = argmedian(shifted, axis=1)
    assert medians.shape == (4,)
    assert np.all(bins[medians] == bins[:4])

    # An (11, 13) grid of randomly shifted spikes, reduced along the last
    # axis, and then the same data with the pmf axis moved to the middle.
    shifts = rng.integers(0, 4, size=(11, 13))
    grid = np.array(
        [[np.roll(spike, s) for s in row] for row in shifts]
    )
    medians = argmedian(grid, axis=-1)
    assert medians.shape == (11, 13)
    assert np.all(bins[medians] == shifts)
    medians = argmedian(grid.transpose(1, 2, 0), axis=1)
    assert medians.shape == (13, 11)
    assert np.all(bins[medians] == shifts.T)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment