@cdipaolo
Last active July 27, 2019 18:32
Computes (Ar * ... * A2 * A1) x efficiently, where * is the Kronecker product.
import numpy as np


def kronmatvec(A, x):
    """Given a list of matrices A = [Ar,...,A1] and a vector x,
    return (Ar o ... o A1)x where o is the Kronecker product.

    Note
    ----
    This function assumes each A[i] is square, and that the vector
    x is of suitable dimension to make the matrix-vector product
    sensible.

    If all A[i].shape = (c,c) and x.shape = (n,), this function
    uses at most 2 (c / log c) * n log(n) flops.

    Examples
    --------
    >>> A = [np.random.randn(5,5)]
    >>> x = np.random.randn(5)
    >>> expected = A[0] @ x
    >>> actual = kronmatvec(A, x)
    >>> assert np.allclose(expected, actual)
    >>> A = [np.random.randn(5,5), np.random.randn(5,5)]
    >>> x = np.random.randn(5*5)
    >>> expected = np.kron(A[0], A[1]) @ x
    >>> actual = kronmatvec(A, x)
    >>> assert np.allclose(expected, actual)
    >>> A = [np.random.randn(5,5), np.random.randn(4,4), np.random.randn(3,3)]
    >>> x = np.random.randn(5*4*3)
    >>> expected = np.kron(A[0], np.kron(A[1], A[2])) @ x
    >>> actual = kronmatvec(A, x)
    >>> assert np.allclose(expected, actual)

    Parameters
    ----------
    A : list[np.array]
        A list of square numpy arrays.
    x : np.array
        A numpy array.

    Returns
    -------
    np.array
        A numpy array of the same shape as x, which is equal to the
        Kronecker product (Ar o ... o A1) times x.
    """
    # Base case: a single factor is an ordinary matrix-vector product.
    if len(A) == 1:
        return A[0] @ x
    n, nr = x.shape[0], A[0].shape[0]
    # Column i of the reshaped array is the i-th length-(n/nr) block of x.
    x = x.reshape(n//nr, nr, order="F").copy()
    # Apply the leading factor A[0] across the blocks.
    x = x @ A[0].T
    # Apply the remaining factors to each column recursively.
    for i in range(nr):
        x[:,i] = kronmatvec(A[1:], x[:,i])
    return x.flatten("F")
if __name__ == "__main__":
    # Benchmark speedup.
    import timeit

    def kron(A):
        # Form the full Kronecker product matrix explicitly.
        if len(A) == 1:
            return A[0]
        return np.kron(A[0],kron(A[1:]))

    # The setup code stays flush-left because timeit re-indents it itself.
    setup = """
from __main__ import {func}
import numpy as np
np.random.seed(1)
r = {r}
c = {c}
A = [np.random.randn(c,c) for _ in range(r)]
x = np.random.randn(c**r)
"""
    c = 2
    N = 250
    # timeit.repeat returns the total time for N executions per repeat,
    # so the runtimes printed below are totals over N runs.
    for r in [6,7,8,9,10,11,12]:
        naive = np.min(timeit.repeat("kron(A) @ x", setup.format(func="kron",r=r,c=c), number=N))
        ours = np.min(timeit.repeat("kronmatvec(A, x)", setup.format(func="kronmatvec",r=r,c=c), number=N))
        print("r={r}, c={c}, n={n}".format(r=r, c=c, n=c**r))
        print("\tNaive Implementation Runtime (ave of {}, best of 5): {:0.5f}ms".format(N, naive * 1e3))
        print("\tOur   Implementation Runtime (ave of {}, best of 5): {:0.5f}ms".format(N, ours * 1e3))
        print("\tOur Implementation Speedup: {:0.2f}%".format(-(ours - naive) / naive * 100))

cdipaolo commented Jul 21, 2019

Benchmarking results are below. In these runs the implementation breaks even with the naive approach around n = 1024 and pulls ahead quickly after that, reaching about 85% less runtime at n = 4096. Note that we assume the input always comes as a list [Ar, ..., A1], so the naive timing includes forming the Kronecker product matrix. If you were also given the explicit Kronecker product matrix Ar * ... * A2 * A1, the advantage of this method would take slightly longer to appear.
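To illustrate the last point, the sketch below times the matrix-vector product against a Kronecker matrix that is built once, outside the timed region. It assumes `kronmatvec` behaves as in the gist (a compact copy is included only so the snippet runs on its own). The helper `dense_kron` and the chosen sizes are hypothetical, and timings will vary by machine.

```python
import timeit
from functools import reduce

import numpy as np


def kronmatvec(A, x):
    # Compact copy of the gist's function so this snippet is self-contained.
    if len(A) == 1:
        return A[0] @ x
    n, nr = x.shape[0], A[0].shape[0]
    X = x.reshape(n // nr, nr, order="F") @ A[0].T
    for i in range(nr):
        X[:, i] = kronmatvec(A[1:], X[:, i])
    return X.flatten("F")


def dense_kron(A):
    # Hypothetical helper: form A[0] o A[1] o ... o A[-1] explicitly.
    return reduce(np.kron, A)


rng = np.random.default_rng(1)
r, c = 8, 2  # illustrative sizes, giving n = 256
A = [rng.standard_normal((c, c)) for _ in range(r)]
x = rng.standard_normal(c**r)

K = dense_kron(A)  # formed once, not counted in the first timing

calls = 100
t_given = min(timeit.repeat(lambda: K @ x, number=calls, repeat=5))
t_formed = min(timeit.repeat(lambda: dense_kron(A) @ x, number=calls, repeat=5))
t_ours = min(timeit.repeat(lambda: kronmatvec(A, x), number=calls, repeat=5))

print(f"K given,  K @ x        : {t_given * 1e3:.3f} ms per {calls} calls")
print(f"K formed, kron(A) @ x  : {t_formed * 1e3:.3f} ms per {calls} calls")
print(f"kronmatvec(A, x)       : {t_ours * 1e3:.3f} ms per {calls} calls")
```

The first timing is the scenario where the explicit matrix is already available, and the second matches what the benchmark in the gist measures.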

For a runtime analysis of the floating point operations of this method, see this Math StackExchange Question. In short, if n is the size of the input vector x and the sizes of all A[i] are bounded above by a constant, this formulation uses on the order of n log(n) floating point operations.
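As a rough illustration of that count, the sketch below compares the docstring's bound of 2 (c / log c) * n log(n) flops with the roughly 2 n^2 flops of a dense matrix-vector product. The function names are hypothetical, every factor is assumed to be c x c, and the dense figure ignores the cost of forming the Kronecker matrix.

```python
import math


def kronmatvec_flops(c, r):
    # r levels of recursion, each doing 2*c*n flops in total
    # (an (n/c, c) array times a (c, c) matrix, or the equivalent base cases).
    n = c**r
    return 2 * c * n * r


def dense_flops(c, r):
    # One multiply and one add per entry of an n x n matrix.
    n = c**r
    return 2 * n * n


for r in [6, 8, 10, 12]:
    c = 2
    n = c**r
    ours, dense = kronmatvec_flops(c, r), dense_flops(c, r)
    print(f"n={n:5d}  kronmatvec={ours:8d}  dense={dense:10d}  ratio={dense / ours:8.1f}")
```

Flop counts alone do not predict wall-clock time, so these ratios are only a guide to the asymptotic trend.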

$ python3 kronmatvec.py
r=6, c=2, n=64
	Naive Implementation Runtime (ave of 250, best of 5): 22.91785ms
	Our   Implementation Runtime (ave of 250, best of 5): 67.42920ms
	Our Implementation Speedup: -194.22%
r=7, c=2, n=128
	Naive Implementation Runtime (ave of 250, best of 5): 42.45826ms
	Our   Implementation Runtime (ave of 250, best of 5): 130.58122ms
	Our Implementation Speedup: -207.55%
r=8, c=2, n=256
	Naive Implementation Runtime (ave of 250, best of 5): 65.68554ms
	Our   Implementation Runtime (ave of 250, best of 5): 274.95310ms
	Our Implementation Speedup: -318.59%
r=9, c=2, n=512
	Naive Implementation Runtime (ave of 250, best of 5): 228.06848ms
	Our   Implementation Runtime (ave of 250, best of 5): 570.64526ms
	Our Implementation Speedup: -150.21%
r=10, c=2, n=1024
	Naive Implementation Runtime (ave of 250, best of 5): 1369.34180ms
	Our   Implementation Runtime (ave of 250, best of 5): 1241.07181ms
	Our Implementation Speedup: 9.37%
r=11, c=2, n=2048
	Naive Implementation Runtime (ave of 250, best of 5): 7688.64983ms
	Our   Implementation Runtime (ave of 250, best of 5): 2986.38670ms
	Our Implementation Speedup: 61.16%
r=12, c=2, n=4096
	Naive Implementation Runtime (ave of 250, best of 5): 38280.55533ms
	Our   Implementation Runtime (ave of 250, best of 5): 5728.85020ms
	Our Implementation Speedup: 85.03%
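
The slowdowns at small n may partly come from Python-level recursion and the per-column loop, though that is an assumption and was not measured here. One alternative formulation, which is not the gist's method, views x as an r-way tensor and applies each factor along its own axis with `np.tensordot`. The name `kronmatvec_tensor` is hypothetical, and this is a minimal sketch rather than a tuned implementation:

```python
import numpy as np


def kronmatvec_tensor(A, x):
    """Hypothetical loop-over-factors variant of kronmatvec.

    A is ordered as in the gist, so A[0] is the leftmost Kronecker factor.
    """
    # With C ordering, axis k of the tensor pairs with factor A[k].
    X = x.reshape([a.shape[1] for a in A])
    for k, a in enumerate(A):
        # Contract the columns of a with axis k; tensordot puts the new
        # axis first, so move it back to position k.
        X = np.moveaxis(np.tensordot(a, X, axes=([1], [k])), 0, k)
    return X.reshape(-1)


rng = np.random.default_rng(0)
A = [rng.standard_normal((5, 5)),
     rng.standard_normal((4, 4)),
     rng.standard_normal((3, 3))]
x = rng.standard_normal(5 * 4 * 3)

expected = np.kron(A[0], np.kron(A[1], A[2])) @ x
actual = kronmatvec_tensor(A, x)
print(np.allclose(expected, actual))
```

Whether this is faster than the recursive version depends on the sizes involved and would need to be benchmarked separately.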
