@cdipaolo
Created July 29, 2019 05:23
```python
import numpy as np
import scipy.linalg


def kronsolve(A, y):
    """Given a list of symmetric positive definite matrices A = [Ar,...,A1]
    and a vector y, return x so that (Ar o ... o A1)x = y, where o is
    the Kronecker product.

    Note
    ----
    This function assumes each A[i] is square and symmetric positive
    definite (each one is Cholesky-factored), and that the length of y
    equals the product of the sizes of the A[i].

    This method is implemented iteratively internally: it applies the
    inverse of one factor at a time along one axis of y, so the full
    Kronecker product is never formed.

    Examples
    --------
    >>> np.random.seed(1)
    >>> A = [np.random.randn(5,5)/np.sqrt(5)]; A[0] = A[0] @ A[0].T
    >>> y = np.random.randn(5)
    >>> expected = np.linalg.solve(A[0], y)
    >>> actual = kronsolve(A, y)
    >>> assert np.allclose(expected, actual)

    >>> np.random.seed(1)
    >>> A = [np.random.randn(5,5), np.random.randn(5,5)]
    >>> A = [a @ a.T / 5 for a in A]
    >>> y = np.random.randn(5*5)
    >>> expected = np.linalg.solve(np.kron(A[0], A[1]), y)
    >>> actual = kronsolve(A, y)
    >>> assert np.allclose(expected, actual)

    >>> np.random.seed(1)
    >>> A = [np.random.randn(i,i) for i in range(1,7)]
    >>> A = [a @ a.T / a.shape[0] for a in A]
    >>> y = np.random.randn(6*5*4*3*2*1)
    >>> Akron = np.kron(A[0],
    ...         np.kron(A[1],
    ...         np.kron(A[2],
    ...         np.kron(A[3],
    ...         np.kron(A[4], A[5])))))
    >>> expected = np.linalg.solve(Akron, y)
    >>> actual = kronsolve(A, y)
    >>> assert np.allclose(expected, actual)

    Parameters
    ----------
    A : list[np.array]
        A list of square, symmetric positive definite numpy arrays.
    y : np.array
        A 1-D numpy array whose length is the product of the sizes of the A[i].

    Returns
    -------
    np.array
        A numpy array of the same shape as y, equal to the
        solution of the linear system (Ar o ... o A1)x = y.
    """
    y = y.copy()
    for Ai in A:
        ci = Ai.shape[0]
        # Split the leading axis as (-1, ci) in Fortran order; the new axis 1
        # is the slowest-varying part of the leading axis, i.e. the index Ai acts on.
        shape = (-1, ci) + y.shape[1:]
        y = y.reshape(*shape, order="F")
        # Swap the first two axes so that Ai's axis comes first.
        perm = (1, 0) + tuple(range(2, len(shape)))
        yt = y.transpose(perm)
        # Fold all remaining axes into columns so cho_solve always gets a 2-D
        # right-hand side instead of relying on how LAPACK wrappers treat N-D input.
        rhs = yt.reshape(ci, -1, order="F")
        Ai_chol = scipy.linalg.cho_factor(Ai, check_finite=False)
        sol = scipy.linalg.cho_solve(Ai_chol, rhs, check_finite=False)
        # Undo the folding and the axis swap.
        y = sol.reshape(yt.shape, order="F").transpose(perm)
    return y.flatten("F")


if __name__ == "__main__":
    import timeit

    def kron(A):
        """Recursively form the dense Kronecker product A[0] o A[1] o ..."""
        if len(A) == 1:
            return A[0]
        return np.kron(A[0], kron(A[1:]))

    def kronsolve_naive(A, y):
        """Baseline: build the full Kronecker product, then Cholesky-solve it."""
        Achol = scipy.linalg.cho_factor(kron(A), check_finite=False)
        return scipy.linalg.cho_solve(Achol, y, check_finite=False)

    setup = """
from __main__ import {func}
import numpy as np
np.random.seed(1)
r = {r}
c = {c}
A = [np.random.randn(c,c) for _ in range(r)]
A = [a @ a.T / c for a in A]
y = np.random.randn(c**r)
"""
    c = 2
    N = 250  # calls per timing run
    for r in [3, 4, 5, 6, 7, 8, 9, 10]:
        # Each entry returned by timeit.repeat is the TOTAL time for N calls;
        # keep the best of 5 repeats.
        naive = np.min(timeit.repeat("kronsolve_naive(A, y)",
                                     setup.format(func="kronsolve_naive", r=r, c=c),
                                     repeat=5, number=N))
        ours = np.min(timeit.repeat("kronsolve(A, y)",
                                    setup.format(func="kronsolve", r=r, c=c),
                                    repeat=5, number=N))
        print("r={r}, c={c}, n={n}".format(r=r, c=c, n=c**r))
        print("\tNaive Implementation Runtime (total of {} calls, best of 5): {:0.5f}ms".format(N, naive * 1e3))
        print("\tOur   Implementation Runtime (total of {} calls, best of 5): {:0.5f}ms".format(N, ours * 1e3))
        # Percentage of runtime saved relative to the naive solver.
        print("\tOur   Implementation Speedup: {:0.2f}%".format(-(ours - naive) / naive * 100))
```
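The docstring's two-factor example rests on an identity that explains why the loop works. With `np.kron`'s index convention, `(A0 ⊗ A1) x = y` is equivalent to `A0 @ X @ A1.T == Y`, where `X` and `Y` are `x` and `y` reshaped in C order to shape `(c0, c1)`. The system can therefore be solved one factor at a time, which is what `kronsolve` does axis by axis. The sketch below checks this for two hypothetical random SPD factors; it uses plain `np.linalg.solve` instead of Cholesky, and the sizes, seed, and variable names are illustrative assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)
c0, c1 = 3, 4
# Hypothetical SPD factors (the identity shift just keeps them well conditioned).
A0 = rng.standard_normal((c0, c0))
A0 = A0 @ A0.T / c0 + np.eye(c0)
A1 = rng.standard_normal((c1, c1))
A1 = A1 @ A1.T / c1 + np.eye(c1)
y = rng.standard_normal(c0 * c1)

# Dense reference: solve the full (c0*c1) x (c0*c1) system.
x_dense = np.linalg.solve(np.kron(A0, A1), y)

# Factor-wise: (A0 kron A1) x = y  <=>  A0 @ X @ A1.T = Y  (C-order reshapes).
Y = y.reshape(c0, c1)
Z = np.linalg.solve(A0, Y)        # apply A0^{-1} along axis 0
X = np.linalg.solve(A1, Z.T).T    # apply A1^{-1} along axis 1
x_factored = X.reshape(-1)

print(np.allclose(x_dense, x_factored))
```

Each factor only ever sees a small `c x c` solve with many right-hand sides, which is where the savings over the dense `n x n` solve come from.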
@cdipaolo
Author

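Benchmarking results below. This implementation roughly breaks even with the naive solver at n = 64 and pulls ahead quickly after that: it is about 1.7x faster at n = 128 and about 46x faster at n = 1024. Each reported time is the total for 250 calls (best of 5 repeats) rather than a per-call average, despite the "ave" in the printed label, and the "Speedup" line is the percentage of runtime saved relative to the naive solver. Note that we do assume the input always comes in the form of a list [Ar, ..., A1], so the naive timing includes building the Kronecker product Ar ⊗ ... ⊗ A2 ⊗ A1. If you were given the full Kronecker product matrix in addition to the list, the naive solver could skip that step, and the advantage of this method would show up at a somewhat larger n.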
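A back-of-the-envelope operation count helps show why the crossover exists and why a prebuilt Kronecker matrix would shift it only modestly. This is a rough model, not a measurement: it assumes textbook costs (about c³/3 flops for a c×c Cholesky factorization and about 2c² for a forward-plus-back triangular solve per right-hand side) and ignores Python call overhead, reshapes, and memory traffic, which plausibly dominate at the small sizes where the iterative version loses. `naive_flops` and `kron_flops` are hypothetical helper names.

```python
from math import prod


def naive_flops(n, build_kron=True):
    """Approximate flops to solve the dense n x n Kronecker system."""
    flops = n**3 / 3        # dense Cholesky factorization
    flops += 2 * n**2       # forward and back triangular solves, one right-hand side
    if build_kron:
        flops += n**2       # roughly one multiply per entry of the final np.kron
    return flops


def kron_flops(sizes):
    """Approximate flops for a factor-by-factor Cholesky solve."""
    n = prod(sizes)
    # Per factor of size c: Cholesky (~c^3/3) plus forward and back solves
    # against n/c right-hand sides (~2 c^2 * n/c = 2 c n).
    return sum(c**3 / 3 + 2 * c * n for c in sizes)


for r in [3, 6, 10]:
    sizes = [2] * r
    n = prod(sizes)
    cost = kron_flops(sizes)
    print(
        f"n={n:5d}  dense/iterative flop ratio: "
        f"{naive_flops(n) / cost:10.1f}x (building kron), "
        f"{naive_flops(n, build_kron=False) / cost:10.1f}x (prebuilt)"
    )
```

In this model the n² cost of forming the Kronecker product is small next to the n³/3 dense factorization, so dropping it barely changes the ratio at large n. The model's ratios also far exceed the measured ones, which is consistent with fixed per-call overhead mattering a lot at these sizes.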

```
$ python3 kronsolve.py
r=3, c=2, n=8
	Naive Implementation Runtime (ave of 250, best of 5): 16.42836ms
	Our   Implementation Runtime (ave of 250, best of 5): 19.03734ms
	Our   Implementation Speedup:                        -15.88%
r=4, c=2, n=16
	Naive Implementation Runtime (ave of 250, best of 5): 21.09975ms
	Our   Implementation Runtime (ave of 250, best of 5): 26.37621ms
	Our   Implementation Speedup:                        -25.01%
r=5, c=2, n=32
	Naive Implementation Runtime (ave of 250, best of 5): 26.45909ms
	Our   Implementation Runtime (ave of 250, best of 5): 33.67146ms
	Our   Implementation Speedup:                        -27.26%
r=6, c=2, n=64
	Naive Implementation Runtime (ave of 250, best of 5): 43.24035ms
	Our   Implementation Runtime (ave of 250, best of 5): 42.89743ms
	Our   Implementation Speedup:                        0.79%
r=7, c=2, n=128
	Naive Implementation Runtime (ave of 250, best of 5): 88.76367ms
	Our   Implementation Runtime (ave of 250, best of 5): 51.34740ms
	Our   Implementation Speedup:                        42.15%
r=8, c=2, n=256
	Naive Implementation Runtime (ave of 250, best of 5): 228.24591ms
	Our   Implementation Runtime (ave of 250, best of 5): 66.51551ms
	Our   Implementation Speedup:                        70.86%
r=9, c=2, n=512
	Naive Implementation Runtime (ave of 250, best of 5): 1153.86708ms
	Our   Implementation Runtime (ave of 250, best of 5): 84.68914ms
	Our   Implementation Speedup:                        92.66%
r=10, c=2, n=1024
	Naive Implementation Runtime (ave of 250, best of 5): 7292.92084ms
	Our   Implementation Runtime (ave of 250, best of 5): 159.07176ms
	Our   Implementation Speedup:                        97.82%
```
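Using only the figures printed in the log, the short script below converts each 250-call total into a per-call time, computes the naive/ours speed ratio, and re-derives the logged percentage as (naive - ours) / naive * 100. `RESULTS` and `summarize` are hypothetical names; the 250-call divisor follows the benchmark's `N = 250`.

```python
# (n, naive total ms, ours total ms), copied from the benchmark log.
RESULTS = [
    (8, 16.42836, 19.03734),
    (16, 21.09975, 26.37621),
    (32, 26.45909, 33.67146),
    (64, 43.24035, 42.89743),
    (128, 88.76367, 51.34740),
    (256, 228.24591, 66.51551),
    (512, 1153.86708, 84.68914),
    (1024, 7292.92084, 159.07176),
]
CALLS = 250  # each logged time is the total for this many calls


def summarize(naive_ms, ours_ms, calls=CALLS):
    """Return per-call times (microseconds), speed ratio, and % runtime saved."""
    naive_us = naive_ms / calls * 1e3
    ours_us = ours_ms / calls * 1e3
    ratio = naive_ms / ours_ms                      # > 1 means ours is faster
    saved = (naive_ms - ours_ms) / naive_ms * 100   # the log's "Speedup" figure
    return naive_us, ours_us, ratio, saved


for n, naive_ms, ours_ms in RESULTS:
    naive_us, ours_us, ratio, saved = summarize(naive_ms, ours_ms)
    print(f"n={n:5d}  naive {naive_us:9.1f} us/call  ours {ours_us:6.1f} us/call  "
          f"ratio {ratio:6.2f}x  saved {saved:+7.2f}%")
```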
