Created
July 29, 2019 05:23
-
-
Save cdipaolo/3d1eaf946ebbcbc93fe081c4b1359075 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import scipy.linalg | |
def kronsolve(A, y):
    """Solve a linear system whose matrix is a Kronecker product of SPD factors.

    Given a list of symmetric positive definite matrices A = [Ar,...,A1]
    and an array y, return x so that (Ar o ... o A1)x = y, where o is the
    Kronecker product. The factor ordering matches ``np.kron``: A[0] is the
    leftmost factor.

    The full Kronecker product is never formed. Instead, y is viewed as a
    tensor with one axis per factor, and each factor's Cholesky solve is
    applied along its own axis. The cost is roughly sum(ci**3) + n*sum(ci)
    rather than O(n**3) for the explicit n-by-n matrix, with n = prod(ci).

    Parameters
    ----------
    A : list[np.ndarray]
        Square, symmetric positive definite matrices [Ar, ..., A1]. The
        Cholesky factorization reads only one triangle of each matrix, so
        non-symmetric inputs give silently wrong results. An empty list is
        treated as the identity, and a copy of y is returned.
    y : array_like
        Right-hand side of shape (n,) or (n, k, ...), where
        n = prod(Ai.shape[0]). Any trailing axes hold independent
        right-hand sides.

    Returns
    -------
    np.ndarray
        The solution x, with the same shape as y.

    Raises
    ------
    ValueError
        If y is 0-dimensional, or if its leading axis does not have length
        prod(Ai.shape[0]). ``scipy.linalg.cho_factor`` also raises
        ValueError for non-square factors.
    numpy.linalg.LinAlgError
        If some Ai is not positive definite (raised by ``cho_factor``).

    Examples
    --------
    >>> np.random.seed(1)
    >>> A = [np.random.randn(5,5)/np.sqrt(5)]; A[0] = A[0] @ A[0].T
    >>> y = np.random.randn(5)
    >>> expected = np.linalg.solve(A[0], y)
    >>> actual = kronsolve(A, y)
    >>> assert np.allclose(expected, actual)
    >>> np.random.seed(1)
    >>> A = [np.random.randn(5,5), np.random.randn(5,5)]
    >>> A = [a @ a.T / 5 for a in A]
    >>> y = np.random.randn(5*5)
    >>> expected = np.linalg.solve(np.kron(A[0], A[1]), y)
    >>> actual = kronsolve(A, y)
    >>> assert np.allclose(expected, actual)
    >>> np.random.seed(1)
    >>> A = [np.random.randn(i,i) for i in range(1,7)]
    >>> A = [a @ a.T / a.shape[0] for a in A]
    >>> y = np.random.randn(6*5*4*3*2*1)
    >>> Akron = np.kron(A[0],
    ...         np.kron(A[1],
    ...         np.kron(A[2],
    ...         np.kron(A[3],
    ...         np.kron(A[4], A[5])))))
    >>> expected = np.linalg.solve(Akron, y)
    >>> actual = kronsolve(A, y)
    >>> assert np.allclose(expected, actual)
    >>> Y = np.random.randn(6*5*4*3*2*1, 2)
    >>> X = kronsolve(A, Y)
    >>> assert X.shape == Y.shape and np.allclose(Akron @ X, Y)
    """
    y = np.asarray(y)
    if len(A) == 0:
        # An empty product is treated as the identity: return y unchanged.
        return y.copy()

    dims = tuple(Ai.shape[0] for Ai in A)
    n = int(np.prod(dims))
    if y.ndim == 0 or y.shape[0] != n:
        raise ValueError(
            "y must have a leading axis of length {} (the product of the "
            "factor sizes {}), got shape {}".format(n, dims, y.shape)
        )

    # Row i of Ar o ... o A1 (np.kron ordering) is the C-order flattening of
    # a multi-index (i_r, ..., i_1). This reshape therefore gives one tensor
    # axis per factor, followed by any trailing right-hand-side axes of y.
    x = y.reshape(dims + y.shape[1:])
    for axis, (Ai, ci) in enumerate(zip(A, dims)):
        Ai_chol = scipy.linalg.cho_factor(Ai, check_finite=False)
        # Bring this factor's axis to the front and flatten everything else
        # into columns, so cho_solve sees an ordinary 2-D right-hand side.
        front = np.moveaxis(x, axis, 0)
        solved = scipy.linalg.cho_solve(
            Ai_chol, front.reshape(ci, -1), check_finite=False
        )
        x = np.moveaxis(solved.reshape(front.shape), 0, axis)
    return x.reshape(y.shape)
if __name__ == "__main__":
    # Micro-benchmark: compare kronsolve against a naive solver that forms
    # the explicit Kronecker product and Cholesky-factors it directly.
    import functools
    import timeit

    def kron(A):
        """Return the Kronecker product A[0] o A[1] o ... o A[-1]."""
        return functools.reduce(np.kron, A)

    def kronsolve_naive(A, y):
        """Solve (A[0] o ... o A[-1]) x = y by factoring the explicit product."""
        Achol = scipy.linalg.cho_factor(kron(A), check_finite=False)
        return scipy.linalg.cho_solve(Achol, y, check_finite=False)

    c = 2        # size of every factor matrix
    N = 250      # solver calls per timing run
    REPEAT = 5   # number of timing runs; the fastest one is reported
    for r in [3, 4, 5, 6, 7, 8, 9, 10]:
        # Seeded inputs: r SPD c-by-c factors and a right-hand side of
        # length c**r, shared by both solvers.
        np.random.seed(1)
        A = [np.random.randn(c, c) for _ in range(r)]
        A = [a @ a.T / c for a in A]
        y = np.random.randn(c**r)

        # timeit.repeat returns the *total* seconds for N calls in each run,
        # so divide by N to report the per-call average the labels promise.
        naive = min(timeit.repeat(functools.partial(kronsolve_naive, A, y),
                                  repeat=REPEAT, number=N)) / N
        ours = min(timeit.repeat(functools.partial(kronsolve, A, y),
                                 repeat=REPEAT, number=N)) / N

        print("r={r}, c={c}, n={n}".format(r=r, c=c, n=c**r))
        print("\tNaive Implementation Runtime (ave of {}, best of {}): {:0.5f}ms".format(N, REPEAT, naive * 1e3))
        print("\tOur Implementation Runtime (ave of {}, best of {}): {:0.5f}ms".format(N, REPEAT, ours * 1e3))
        # Percentage of the naive runtime saved (positive means kronsolve is faster).
        print("\tOur Implementation Speedup: {:0.2f}%".format((naive - ours) / naive * 100))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Benchmarking results are below. This implementation outperforms the naive solver once n > 64, and its advantage grows quickly with n. Note that we assume the input always arrives as the list of factors [Ar, ..., A1]. Suppose you were given the explicit Kronecker product $A_r \otimes \cdots \otimes A_2 \otimes A_1$ in addition to this list. Then the naive solver would not have to build that matrix, so the advantage of this method would only appear at somewhat larger n.