import numpy as np

def kronmatvec(A, x):
    """Given a list of matrices A = [Ar,...,A1] and a vector x,
    return (Ar o ... o A1)x where o is the Kronecker product.

    The full Kronecker product is never formed.  Instead the identity
    (A o B) vec(X) = vec(B X A^T) (with column-major vec) is applied
    recursively: x is reshaped so that each column is one block of
    length n / A[0].shape[0], the columns are mixed by A[0], and the
    remaining factors A[1:] are applied to each column in turn.

    Note
    ----
    This function assumes each A[i] is square, and that the vector
    x is of suitable dimension to make the matrix-vector product
    sensible.

    If all A[i].shape = (c,c) and x.shape = (n,), this function
    uses at most 2 (c / log c) * n log(n) flops.

    The dtype of the result follows NumPy's usual promotion rules
    across x and all of the A[i]; e.g. an integer x combined with a
    floating point factor yields a floating point result.

    Examples
    --------

    >>> A = [np.random.randn(5,5)]
    >>> x = np.random.randn(5)
    >>> expected = A[0] @ x
    >>> actual = kronmatvec(A, x)
    >>> assert np.allclose(expected, actual)

    >>> A = [np.random.randn(5,5), np.random.randn(5,5)]
    >>> x = np.random.randn(5*5)
    >>> expected = np.kron(A[0], A[1]) @ x
    >>> actual = kronmatvec(A, x)
    >>> assert np.allclose(expected, actual)

    >>> A = [np.random.randn(5,5), np.random.randn(4,4), np.random.randn(3,3)]
    >>> x = np.random.randn(5*4*3)
    >>> expected = np.kron(A[0], np.kron(A[1], A[2])) @ x
    >>> actual = kronmatvec(A, x)
    >>> assert np.allclose(expected, actual)

    Parameters
    ----------
    A : list[np.array]
        A non-empty list of square numpy arrays.
    x : np.array
        A numpy array whose length equals the product of the sizes
        of the matrices in A.

    Returns
    -------
    np.array
        A numpy array of the same shape as x, which is equal to the
        Kronecker product (Ar o ... o A1) times x.
    """
    if len(A) == 1:
        return A[0] @ x

    head, rest = A[0], A[1:]
    n, nr = x.shape[0], head.shape[0]

    # Column-major reshape: column j holds the j-th contiguous block of x.
    # Right-multiplying by head.T forms, in column j, the combination
    # sum_i head[j, i] * block_i.  The matmul allocates a fresh array, so
    # the caller's x is never modified.
    mixed = x.reshape(n // nr, nr, order="F") @ head.T

    # Apply the remaining factors to every column.  The results are
    # gathered in a list rather than written back into `mixed`, because an
    # in-place assignment would silently cast them to mixed.dtype (e.g.
    # truncating floats to ints, or dropping imaginary parts).
    columns = [kronmatvec(rest, mixed[:, j]) for j in range(nr)]

    # Joining the columns end to end is exactly the column-major flatten.
    return np.concatenate(columns)


if __name__ == "__main__":
    # Benchmark: compare kronmatvec against explicitly forming the full
    # Kronecker product and multiplying it with x.
    import timeit

    def kron(A):
        """Return the explicit Kronecker product A[0] o ... o A[-1]."""
        if len(A) == 1:
            return A[0]
        return np.kron(A[0],kron(A[1:]))

    # Setup code run by timeit before each timing; it imports the function
    # under test from this script, so `kron` must stay at module level.
    setup = """
from __main__ import {func}
import numpy as np
np.random.seed(1)
r = {r}
c = {c}
A = [np.random.randn(c,c) for _ in range(r)]
x = np.random.randn(c**r)
    """

    c = 2
    N = 250        # executions per timing
    REPEAT = 5     # independent timings; the fastest one is reported
    for r in [6,7,8,9,10,11,12]:
        # timeit.repeat returns the TOTAL time of N executions for each
        # repetition, so divide by N to get the per-call average that the
        # printed label promises.
        naive = np.min(timeit.repeat("kron(A) @ x", setup.format(func="kron",r=r,c=c), repeat=REPEAT, number=N)) / N
        ours  = np.min(timeit.repeat("kronmatvec(A, x)", setup.format(func="kronmatvec",r=r,c=c), repeat=REPEAT, number=N)) / N
        print("r={r}, c={c}, n={n}".format(r=r, c=c, n=c**r))
        print("\tNaive Implementation Runtime (ave of {}, best of {}): {:0.5f}ms".format(N, REPEAT, naive * 1e3))
        print("\tOur   Implementation Runtime (ave of {}, best of {}): {:0.5f}ms".format(N, REPEAT, ours * 1e3))
        # Relative reduction in runtime compared to the naive approach.
        print("\tOur Implementation Speedup: {:0.2f}%".format(-(ours - naive) / naive * 100))