import numpy as np


def argmedian(p, axis=None, check_pmf=True):
    """Compute the index of the median of a pmf or many pmfs.

    Here, `p` represents one or many probability mass functions, and
    the index of a median will be returned for each pmf.

        - If `axis=None`, the flattened version of `p` is considered
          a single pmf on its flattened index space (cf `unravel_index`).
        - If `axis` is an integer, then we assume that `p.sum(axis=axis)`
          is an array of 1s, so that `p` contains many pmfs along that
          axis.

    The returned index is the *lower median*: the smallest index ``k`` at
    which the cdf reaches half of the pmf's total mass (``cdf[k] >= 0.5``
    for a normalized pmf). At that index both ``P(X <= k) >= 1/2`` and
    ``P(X >= k) >= 1/2`` hold. When the cdf equals exactly 0.5 at some bin,
    every index up to the next bin with positive mass is also a median; the
    lowest one is returned.

    As with `argmin`, `take_along_axis` is a useful companion function.

    Parameters
    ----------
    p : array_like
        Array of one or more non-negative pmfs along `axis`.
    axis : int or None, optional
        Axis along which each pmf lies. ``None`` (default) treats the
        flattened array as a single pmf.
    check_pmf : bool, optional
        If True (default), verify that `p` is non-negative and sums to 1
        along `axis` before computing. If False, `p` may be unnormalized
        non-negative weights; the median is then taken at half of each
        pmf's total weight.

    Returns
    -------
    ndarray or numpy integer
        Indices of medians. `axis` is removed from the shape of `p`. When
        `axis` is None, the result is a scalar index into the flattened array.

    Raises
    ------
    ValueError
        If `check_pmf` is True and `p` has negative entries or does not
        sum to 1 along `axis`, or if the pmf axis is empty.
    """
    p = np.asarray(p)

    if axis is None:
        # Treat the flattened array as a single pmf on its flat index space.
        p = np.ravel(p)
        axis = 0

    if check_pmf and (p < 0).any():
        raise ValueError("argmedian: positivity error.")

    # Total mass of each pmf; keepdims so it broadcasts against the cdf.
    totals = p.sum(axis=axis, keepdims=True)

    if check_pmf and not np.allclose(totals, 1):
        raise ValueError(
            "argmedian: Sum to 1 violation",
            np.abs(totals - 1).max(),
        )

    # Small relative slack: a cdf value that should equal exactly half the
    # mass but lands an ulp or two below it (cumsum round-off) should still
    # count as having reached the median.
    rtol = 1e-10
    cdf = np.cumsum(p, axis=axis)
    reached = cdf >= 0.5 * totals * (1.0 - rtol)
    # argmax on a boolean array returns the first True along `axis`, i.e.
    # the lowest index whose cdf reaches half of the total mass.
    return reached.argmax(axis=axis)


if __name__ == "__main__":
    # Self-tests: run this module as a script to exercise argmedian.
    rng = np.random.default_rng(0)

    # A single uniform pmf over 10 bins; its median sits at index 4.
    uniform = [0.1] * 10
    indicator = np.zeros(10, dtype=int)
    indicator[4] = 1
    assert argmedian(uniform) == 4
    assert indicator[argmedian(uniform)]

    # Twenty copies of that pmf, laid out as rows (reduce axis 1) and then
    # as columns (reduce axis 0); every median must again be index 4.
    stacked = [uniform] * 20
    for pmfs, ax in ((stacked, 1), (np.array(stacked).T, 0)):
        med = argmedian(pmfs, axis=ax)
        assert med.shape == (20,)
        assert indicator[med].shape == (20,)
        assert np.all(med == 4)
        assert np.all(indicator[med])

    # Two-bin pmfs shifted right by 0..3: the median equals the shift.
    base = [0.5, 0.5, 0, 0, 0]
    bins = np.arange(5)
    shifted = np.array([np.roll(base, k) for k in range(4)])
    med = argmedian(shifted, axis=1)
    assert med.shape == (4,)
    assert np.all(bins[med] == bins[:4])

    # An (11, 13) grid of randomly shifted pmfs along the last axis ...
    shifts = rng.integers(0, 4, size=(11, 13))
    grid = np.array([[np.roll(base, s) for s in row] for row in shifts])
    med = argmedian(grid, axis=-1)
    assert med.shape == (11, 13)
    assert np.all(bins[med] == shifts)

    # ... and the same data with the pmf axis moved to the middle.
    med = argmedian(grid.transpose(1, 2, 0), axis=1)
    assert med.shape == (13, 11)
    assert np.all(bins[med] == shifts.T)