-
-
Save ahwillia/f65bc70cb30206d4eadec857b98c4065 to your computer and use it in GitHub Desktop.
| import numpy as np | |
| import numpy.random as npr | |
| from functools import reduce | |
# Goal
# ----
# Compute (As[0] kron As[1] kron ... kron As[-1]) @ v
# ==== HELPER FUNCTIONS ==== #
def unfold(tens, mode, dims):
    """
    Flatten a tensor into a matrix along one axis (mode-``mode`` unfolding).

    Parameters
    ----------
    tens : ndarray, tensor whose shape is ``dims``
    mode : int, axis that becomes the rows of the result
    dims : sequence of int, shape of ``tens``

    Returns
    -------
    matrix : ndarray with shape ``(dims[mode], prod(dims) // dims[mode])``;
        the remaining axes are flattened in their original (C) order.
    """
    # Axis 0 is already in front, so only the other modes need a transpose.
    leading = tens if mode == 0 else np.moveaxis(tens, mode, 0)
    return leading.reshape(dims[mode], -1)
def refold(vec, mode, dims):
    """
    Inverse of ``unfold``: rebuild a tensor from its unfolding along ``mode``.

    Parameters
    ----------
    vec : ndarray with prod(dims) elements, laid out as ``unfold`` produces
        them (the unfolded matrix itself, or that matrix flattened in C order)
    mode : int, axis that ``unfold`` moved to the front
    dims : sequence of int, shape of the tensor to rebuild

    Returns
    -------
    tens : ndarray with shape == dims
    """
    if mode != 0:
        # In the unfolded layout the ``mode`` axis comes first, followed by
        # the remaining axes in their original order.
        rest = [size for axis, size in enumerate(dims) if axis != mode]
        front_first = vec.reshape([dims[mode]] + rest)
        # Move that leading axis back to position ``mode``.
        return np.moveaxis(front_first, 0, mode)
    # Axis 0 was never moved, so a plain reshape undoes the unfolding.
    return vec.reshape(dims)
# ==== KRON-VEC PRODUCT COMPUTATIONS ==== #
def kron_vec_prod(As, v):
    """
    Compute ``kron(As[0], As[1], ..., As[-1]) @ v`` without forming the
    full Kronecker product.

    ``v`` is reshaped into a tensor with one axis per factor, and each
    factor is applied along its own axis. The factors may be rectangular.

    Parameters
    ----------
    As : sequence of 2-D ndarrays; ``As[i]`` has shape ``(m_i, n_i)``
    v : ndarray with ``prod(n_i)`` elements

    Returns
    -------
    ndarray, 1-D with ``prod(m_i)`` elements
    """
    # Each input axis is sized by the column count of its factor (this
    # equals the row count for square factors, so square inputs behave as
    # before).
    dims = [A.shape[1] for A in As]
    vt = v.reshape(dims)
    for i, A in enumerate(As):
        # Multiplying along axis i changes that axis' size from n_i to m_i.
        out_dims = list(dims)
        out_dims[i] = A.shape[0]
        vt = refold(A @ unfold(vt, i, dims), i, out_dims)
        dims = out_dims
    return vt.ravel()
def kron_brute_force(As, v):
    """
    Reference implementation: build the Kronecker product of all matrices in
    ``As`` explicitly, then multiply it by ``v``.

    The materialized matrix has (product of all row counts) x (product of
    all column counts) entries, so use this only on small problems, e.g. to
    check ``kron_vec_prod``.
    """
    full_matrix = reduce(np.kron, As)
    return full_matrix @ v
# Quick demonstration.
if __name__ == "__main__":
    # Fix the seed so repeated runs report the same error.
    npr.seed(0)

    # Create random problem.
    # NOTE: the brute-force reference materializes a 3**8 x 3**8 float64
    # matrix (~344 MB), so keep these sizes modest.
    _dims = [3, 3, 3, 3, 3, 3, 3, 3]
    As = [npr.randn(d, d) for d in _dims]
    v = npr.randn(np.prod(_dims))

    # Test accuracy: report the error relative to the reference's scale so
    # the number is meaningful regardless of problem size.
    actual = kron_vec_prod(As, v)
    expected = kron_brute_force(As, v)
    print(np.linalg.norm(actual - expected) / np.linalg.norm(expected))
Some back-of-the-envelope calculations and comments on computational complexity. Suppose M square matrices, each of dimension N x N, are Kronecker-multiplied. The brute-force approach constructs an N^M x N^M matrix, which is then multiplied by a vector of length N^M. In total, this costs O(N^{2M}) flops.
The code above implements a faster approach, which involves M sequential matrix multiplies. Each matrix multiply is between an N x N matrix and an N x N^{M-1} matrix, which (assuming no Strassen-type algorithms for matrix multiplication) takes O(N^{M+1}) flops. Since this is repeated M times, the total computational cost is O(M * N^{M+1}). This is often substantially faster, because the exponent drops from 2M to M + 1 at the price of only a linear factor of M.
Very helpful code! But does this only work for square matrices?
Yes I think this only works for square matrices. I'm not immediately sure how to extend it to the non-square case.
What if one just wanted the kron result without the subsequent vector product?
To extend this to the non-square case, all we need is a minor change to the kron_vec_prod function:
def kron_vec_prod(As, v):
    """
    Compute ``kron(As[0], As[1], ..., As[-1]) @ v`` without forming the
    full Kronecker product.

    ``v`` is reshaped into a tensor with one axis per factor, and each
    factor is applied along its own axis. The factors may be rectangular.

    Parameters
    ----------
    As : sequence of 2-D ndarrays; ``As[i]`` has shape ``(m_i, n_i)``
    v : ndarray with ``prod(n_i)`` elements

    Returns
    -------
    ndarray, 1-D with ``prod(m_i)`` elements
    """
    # Each input axis is sized by the column count of its factor.
    dims = [A.shape[1] for A in As]
    vt = v.reshape(dims)
    for i, A in enumerate(As):
        # Multiplying along axis i changes that axis' size from n_i to m_i;
        # a fresh list keeps the input and output shapes independent.
        out_dims = list(dims)
        out_dims[i] = A.shape[0]
        vt = refold(A @ unfold(vt, i, dims), i, out_dims)
        dims = out_dims
    return vt.ravel()
The modified code is as follows:
import numpy as np
import numpy.random as npr
from functools import reduce
# Goal
# ----
# Compute (As[0] kron As[1] kron ... kron As[-1]) @ v
# ==== HELPER FUNCTIONS ==== #
def unfold(tens, mode, dims):
    """
    Flatten a tensor into a matrix along one axis (mode-``mode`` unfolding).

    Parameters
    ----------
    tens : ndarray, tensor whose shape is ``dims``
    mode : int, axis that becomes the rows of the result
    dims : sequence of int, shape of ``tens``

    Returns
    -------
    matrix : ndarray with shape ``(dims[mode], prod(dims) // dims[mode])``;
        the remaining axes are flattened in their original (C) order.
    """
    # Axis 0 is already in front, so only the other modes need a transpose.
    leading = tens if mode == 0 else np.moveaxis(tens, mode, 0)
    return leading.reshape(dims[mode], -1)
def refold(vec, mode, dims):
    """
    Inverse of ``unfold``: rebuild a tensor from its unfolding along ``mode``.

    Parameters
    ----------
    vec : ndarray with prod(dims) elements, laid out as ``unfold`` produces
        them (the unfolded matrix itself, or that matrix flattened in C order)
    mode : int, axis that ``unfold`` moved to the front
    dims : sequence of int, shape of the tensor to rebuild

    Returns
    -------
    tens : ndarray with shape == dims
    """
    if mode != 0:
        # In the unfolded layout the ``mode`` axis comes first, followed by
        # the remaining axes in their original order.
        rest = [size for axis, size in enumerate(dims) if axis != mode]
        front_first = vec.reshape([dims[mode]] + rest)
        # Move that leading axis back to position ``mode``.
        return np.moveaxis(front_first, 0, mode)
    # Axis 0 was never moved, so a plain reshape undoes the unfolding.
    return vec.reshape(dims)
# ==== KRON-VEC PRODUCT COMPUTATIONS ==== #
def kron_vec_prod(As, v):
    """
    Compute ``kron(As[0], As[1], ..., As[-1]) @ v`` without forming the
    full Kronecker product.

    ``v`` is reshaped into a tensor with one axis per factor, and each
    factor is applied along its own axis. The factors may be rectangular.

    Parameters
    ----------
    As : sequence of 2-D ndarrays; ``As[i]`` has shape ``(m_i, n_i)``
    v : ndarray with ``prod(n_i)`` elements

    Returns
    -------
    ndarray, 1-D with ``prod(m_i)`` elements
    """
    # Each input axis is sized by the column count of its factor.
    dims = [A.shape[1] for A in As]
    vt = v.reshape(dims)
    for i, A in enumerate(As):
        # Multiplying along axis i changes that axis' size from n_i to m_i;
        # a fresh list keeps the input and output shapes independent.
        out_dims = list(dims)
        out_dims[i] = A.shape[0]
        vt = refold(A @ unfold(vt, i, dims), i, out_dims)
        dims = out_dims
    return vt.ravel()
def kron_brute_force(As, v):
    """
    Reference implementation: build the Kronecker product of all matrices in
    ``As`` explicitly, then multiply it by ``v``.

    The materialized matrix has (product of all row counts) x (product of
    all column counts) entries, so use this only on small problems, e.g. to
    check ``kron_vec_prod``.
    """
    full_matrix = reduce(np.kron, As)
    return full_matrix @ v
# Quick demonstration.
if __name__ == "__main__":
    # Fix the seed so repeated runs report the same error.
    npr.seed(0)

    # Create random problem with rectangular factors:
    # As[i] has shape (_xaxes[i], _yaxes[i]).
    _yaxes = [2, 3, 4]
    _xaxes = [1, 2, 1]
    As = [npr.rand(x, y) for x, y in zip(_xaxes, _yaxes)]
    # A random vector is a stronger check than all-ones, which only
    # probes the row sums of the Kronecker product.
    v = npr.randn(np.prod(_yaxes))

    # Test accuracy: report the error relative to the reference's scale.
    actual = kron_vec_prod(As, v)
    expected = kron_brute_force(As, v)
    print(np.linalg.norm(actual - expected) / np.linalg.norm(expected))
Speed comparison for dims = [3, 3, 3, 3, 3, 3, 3, 3].

Speed comparison for dims = [20, 20, 20].