import numpy as np


def argmedian(p, axis=None, check_pmf=True):
    """Return the index of a median of a pmf or of many pmfs.

    Here, `p` represents one or many probability mass functions, and the
    index to a median will be returned for each pmf.

    - If `axis=None`, the flattened version of `p` will be considered a
      single pmf on its flattened index space (cf `unravel_index`).
    - If `axis` is an integer, then we assume that `p.sum(axis=axis)` is an
      array of 1s, so that `p` contains many pmfs along that axis.

    The returned index is the lower median: the smallest index ``k`` at
    which the cumulative sum of the pmf reaches one half of the pmf's total
    mass (which is 1 for a valid pmf). A tiny allowance, proportional to
    the pmf length, absorbs floating-point round-off in the cumulative sum.
    When the cdf equals exactly one half at ``k``, any index up to the next
    bin with positive mass is also a median; ``k`` (the lowest) is returned.

    As with `argmin`, `take_along_axis` is a useful companion function.

    Arguments
    ---------
    p : array_like
        Array of one or more pmfs along `axis`.
    axis : int or None
        Axis along which each pmf lies. ``None`` flattens `p` and treats it
        as a single pmf.
    check_pmf : bool
        If True, raise unless `p` is nonnegative and sums to 1 (within
        ``np.allclose`` tolerance) along `axis`. If False, no validation is
        done and the median is taken relative to each pmf's total mass.

    Returns
    -------
    Indices to medians: an integer scalar when `axis` is None, otherwise an
    integer array with the shape of `p` minus `axis`.

    Raises
    ------
    ValueError
        If `check_pmf` is True and `p` has a negative entry or does not sum
        to 1 along `axis`; or if the pmf axis has length zero.
    """
    p = np.asarray(p)
    if axis is None:
        # Treat the whole array as one pmf over its flattened index space.
        p = np.ravel(p)
        axis = 0
    if check_pmf:
        if (p < 0).any():
            raise ValueError("argmedian: positivity error.")
        totals = p.sum(axis=axis)
        if not np.allclose(totals, 1):
            raise ValueError(
                "argmedian: Sum to 1 violation",
                np.abs(totals - 1).max(),
            )
    cdf = np.cumsum(p, axis=axis, dtype=np.float64)
    n = cdf.shape[axis]
    if n == 0:
        raise ValueError("argmedian: cannot take the median of an empty pmf.")
    # Total mass of each pmf, kept as a length-1 axis so that it broadcasts
    # against `cdf`.
    total = np.take(cdf, [-1], axis=axis)
    # Round-off allowance so that a cdf equal to one half in exact
    # arithmetic is not missed because cumsum landed a hair below it.
    slack = n * np.finfo(np.float64).eps * np.abs(total)
    # Choosing "cdf closest to 0.5" (the previous rule) can select a bin
    # whose cdf is still below one half, e.g. index 0 for [0.3, 0.0, 0.7];
    # the median is the first bin where the cdf *reaches* one half.
    reached_half = cdf >= 0.5 * total - slack
    # argmax on a boolean array returns the position of the first True.
    return reached_half.argmax(axis=axis)


if __name__ == "__main__":
    # tests
    rg = np.random.default_rng(0)

    p = [0.1] * 10
    x = np.array([0, 0, 0, 0, 1, 0, 0, 0, 0, 0])
    assert argmedian(p) == 4
    assert x[argmedian(p)]

    p = [p] * 20
    am = argmedian(p, axis=1)
    assert am.shape == (20,)
    assert x[am].shape == (20,)
    assert np.all(am == 4)
    assert np.all(x[am])

    p = np.array(p).T
    am = argmedian(p, axis=0)
    assert am.shape == (20,)
    assert x[am].shape == (20,)
    assert np.all(am == 4)
    assert np.all(x[am])

    p = [0.5, 0.5, 0, 0, 0]
    p = np.array([np.roll(p, i) for i in range(4)])
    x = np.arange(5)
    am = argmedian(p, axis=1)
    assert am.shape == (4,)
    assert np.all(x[am] == x[:4])

    rolls = rg.integers(0, 4, size=(11, 13))
    p = [0.5, 0.5, 0, 0, 0]
    p = np.array(
        [[np.roll(p, rolls[i, j]) for j in range(13)] for i in range(11)]
    )
    am = argmedian(p, axis=-1)
    assert am.shape == (11, 13)
    assert np.all(x[am] == rolls)

    p = p.transpose(1, 2, 0)
    am = argmedian(p, axis=1)
    assert am.shape == (13, 11)
    assert np.all(x[am] == rolls.T)

    # Regression: the median is where the cdf first reaches one half, not
    # where it is merely closest to one half.
    assert argmedian([0.3, 0.0, 0.7]) == 2
    assert argmedian([0.3, 0.1, 0.6]) == 2
    assert argmedian([0.49, 0.51]) == 1
    assert argmedian([[0.0, 0.0], [0.2, 0.8]]) == 3
    assert argmedian([1, 3], check_pmf=False) == 1