import numpy as np


def kronmatvec(A, x):
    """Given a list of matrices A = [Ar,...,A1] and a vector x, return
    (Ar o ... o A1)x where o is the Kronecker product, without ever forming
    the Kronecker product explicitly.

    Note
    ----
    The factors A[i] may be rectangular. The first axis of x must have
    length equal to the product of the factors' column counts, and the
    result's first axis has length equal to the product of their row counts.
    If x has more than one dimension, the operator is applied along its
    first axis (every column at once), exactly like ``np.kron(...) @ x``.

    The product is evaluated one factor at a time: each step is a single
    matrix-matrix product followed by a transpose that rotates the next
    factor's axis to the front. There is no Python-level recursion, and each
    intermediate takes the promoted dtype of its operands, so mixing integer,
    real and complex factors is safe. If all A[i].shape = (c,c) and
    x.shape = (n,), this function uses at most 2 (c / log c) * n log(n) flops.

    Examples
    --------
    >>> A = [np.random.randn(5,5)]
    >>> x = np.random.randn(5)
    >>> expected = A[0] @ x
    >>> actual = kronmatvec(A, x)
    >>> assert np.allclose(expected, actual)

    >>> A = [np.random.randn(5,5), np.random.randn(5,5)]
    >>> x = np.random.randn(5*5)
    >>> expected = np.kron(A[0], A[1]) @ x
    >>> actual = kronmatvec(A, x)
    >>> assert np.allclose(expected, actual)

    >>> A = [np.random.randn(5,5), np.random.randn(4,4), np.random.randn(3,3)]
    >>> x = np.random.randn(5*4*3)
    >>> expected = np.kron(A[0], np.kron(A[1], A[2])) @ x
    >>> actual = kronmatvec(A, x)
    >>> assert np.allclose(expected, actual)

    >>> A = [np.random.randn(2,3), np.random.randn(4,5)]
    >>> X = np.random.randn(3*5, 2)
    >>> expected = np.kron(A[0], A[1]) @ X
    >>> actual = kronmatvec(A, X)
    >>> assert np.allclose(expected, actual)

    Parameters
    ----------
    A : list[np.array]
        A non-empty list of 2-D numpy arrays (the Kronecker factors, with
        A[0] the outermost factor).
    x : np.array
        A numpy array (usually a vector) whose first axis has length
        prod(A[i].shape[1]).

    Returns
    -------
    np.array
        The Kronecker product (Ar o ... o A1) times x, with shape
        (prod(A[i].shape[0]),) + x.shape[1:]. When every factor is square
        this is the same shape as x.

    Raises
    ------
    ValueError
        If A is empty, x is 0-dimensional, or x.shape[0] does not match the
        product of the factors' column counts.
    """
    if len(A) == 0:
        raise ValueError("A must contain at least one matrix")
    y = np.asarray(x)
    if y.ndim == 0:
        raise ValueError("x must be at least one-dimensional")

    n_in = int(np.prod([Ai.shape[1] for Ai in A]))
    n_out = int(np.prod([Ai.shape[0] for Ai in A]))
    if y.shape[0] != n_in:
        raise ValueError(
            "x.shape[0] = {} does not match the product of the factors' "
            "column counts ({})".format(y.shape[0], n_in)
        )
    trailing = y.shape[1:]

    # Read y (C order) as a tensor with axes (q_0, ..., q_{r-1}, *trailing),
    # where q_i = A[i].shape[1]. Each step contracts the leading axis with
    # A[i]; the transpose then moves the new p_i = A[i].shape[0] axis to the
    # back. After all r factors the axes are (*trailing, p_0, ..., p_{r-1}).
    for Ai in A:
        y = (Ai @ y.reshape(Ai.shape[1], -1)).T

    # Move the trailing (column) axes back behind the output axis.
    return y.reshape(-1, n_out).T.reshape((n_out,) + trailing)


if __name__ == "__main__":
    # Benchmark kronmatvec against explicitly forming the Kronecker product.
    import timeit
    from functools import reduce

    def kron(A):
        """Return the explicit Kronecker product A[0] o A[1] o ... o A[-1]."""
        return reduce(np.kron, A)

    c = 2
    N = 250  # calls per timing run
    R = 5  # timing runs; the fastest run is reported
    for r in [6, 7, 8, 9, 10, 11, 12]:
        np.random.seed(1)
        A = [np.random.randn(c, c) for _ in range(r)]
        x = np.random.randn(c**r)

        # Make sure both implementations agree before timing them.
        assert np.allclose(kron(A) @ x, kronmatvec(A, x))

        env = {"kron": kron, "kronmatvec": kronmatvec, "A": A, "x": x}
        # timeit.repeat returns the *total* time of N calls for each run, so
        # divide by N to report a per-call average.
        naive = min(timeit.repeat("kron(A) @ x", globals=env, number=N, repeat=R)) / N
        ours = min(timeit.repeat("kronmatvec(A, x)", globals=env, number=N, repeat=R)) / N

        print("r={r}, c={c}, n={n}".format(r=r, c=c, n=c**r))
        print("\tNaive Implementation Runtime (ave of {}, best of {}): {:0.5f}ms".format(N, R, naive * 1e3))
        print("\tOur Implementation Runtime (ave of {}, best of {}): {:0.5f}ms".format(N, R, ours * 1e3))
        print("\tOur Implementation Speedup: {:0.2f}x ({:0.2f}% less time)".format(
            naive / ours, (naive - ours) / naive * 100))