Last active
October 6, 2024 06:10
-
-
Save ahwillia/f65bc70cb30206d4eadec857b98c4065 to your computer and use it in GitHub Desktop.
Efficient computation of a Kronecker–vector product (with multiple matrices).
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np
import numpy.random as npr
from functools import reduce

# Goal
# ----
# Compute (As[0] kron As[1] kron ... As[-1]) @ v

# ==== HELPER FUNCTIONS ==== #
def unfold(tens, mode, dims):
    """
    Unfolds tensor into matrix.

    Parameters
    ----------
    tens : ndarray, tensor with shape == dims
    mode : int, which axis to move to the front
    dims : list, holds tensor shape

    Returns
    -------
    matrix : ndarray, shape (dims[mode], prod of the remaining dims)
    """
    if mode == 0:
        # Axis 0 is already leading; a plain reshape suffices.
        return tens.reshape(dims[0], -1)
    # Bring the requested axis to the front, then flatten the rest.
    return np.moveaxis(tens, mode, 0).reshape(dims[mode], -1)
def refold(vec, mode, dims):
    """
    Refolds vector into tensor (inverse of `unfold`).

    Parameters
    ----------
    vec : ndarray, with vec.size == prod(dims)
    mode : int, which axis was unfolded along.
    dims : list, holds tensor shape

    Returns
    -------
    tens : ndarray, tensor with shape == dims
    """
    if mode == 0:
        return vec.reshape(dims)
    # Reshape with dims[mode] leading, then move that axis back to
    # its appropriate spot (undoing the `unfold` operation).
    tens = vec.reshape(
        [dims[mode]] +
        [d for m, d in enumerate(dims) if m != mode]
    )
    return np.moveaxis(tens, 0, mode)
# ==== KRON-VEC PRODUCT COMPUTATIONS ==== #
def kron_vec_prod(As, v):
    """
    Computes matrix-vector multiplication between
    matrix kron(As[0], As[1], ..., As[N]) and vector
    v without forming the full kronecker product.

    Supports rectangular As[i]; v must have length equal to
    prod(A.shape[1] for A in As). For square As this reduces to
    the original behavior.
    """
    # Start from the column dimensions; the i-th entry becomes
    # As[i].shape[0] after the i-th multiply.
    dims = [A.shape[1] for A in As]
    vt = v.reshape(dims)
    for i, A in enumerate(As):
        out_dims = list(dims)
        out_dims[i] = A.shape[0]
        vt = refold(A @ unfold(vt, i, dims), i, out_dims)
        dims = out_dims
    return vt.ravel()
def kron_brute_force(As, v):
    """
    Computes kron-matrix times vector by brute
    force (instantiates the full kron product).
    Used as a reference implementation for testing.
    """
    return reduce(np.kron, As) @ v
# Quick demonstration.
if __name__ == "__main__":
    # Create random problem (8 square 3x3 factors).
    _dims = [3, 3, 3, 3, 3, 3, 3, 3]
    As = [npr.randn(d, d) for d in _dims]
    v = npr.randn(np.prod(_dims))

    # Test accuracy: the two methods should agree to round-off,
    # so the printed norm should be ~1e-12 or smaller.
    actual = kron_vec_prod(As, v)
    expected = kron_brute_force(As, v)
    print(np.linalg.norm(actual - expected))
Very helpful code! But does this only work for square matrices?
Yes I think this only works for square matrices. I'm not immediately sure how to extend it to the non-square case.
What if one just wanted the kron result without the subsequent vector product?
For extending to the non-square case, all we need is a minor change to the kron_vec_prod
method:
def kron_vec_prod(As, v):
    """
    Computes matrix-vector multiplication between
    matrix kron(As[0], As[1], ..., As[N]) and vector
    v without forming the full kronecker product.
    """
    # The input vector folds into the column dimensions of the factors.
    shape_in = [A.shape[1] for A in As]
    vt = v.reshape(shape_in)
    current = list(shape_in)
    for mode, A in enumerate(As):
        # Multiplying along `mode` swaps that dimension for A's row count.
        updated = list(current)
        updated[mode] = A.shape[0]
        vt = refold(A @ unfold(vt, mode, current), mode, updated)
        current = updated
    return vt.ravel()
The modified code is as follows:
import numpy as np
import numpy.random as npr
from functools import reduce
# Goal
# ----
# Compute (As[0] kron As[1] kron ... As[-1]) @ v
# ==== HELPER FUNCTIONS ==== #
def unfold(tens, mode, dims):
    """
    Unfolds tensor into matrix.

    Parameters
    ----------
    tens : ndarray, tensor with shape == dims
    mode : int, which axis to move to the front
    dims : list, holds tensor shape

    Returns
    -------
    matrix : ndarray, shape (dims[mode], prod of the other dims)
    """
    # Fast path: axis 0 is already leading, so no transpose is needed.
    if mode == 0:
        return tens.reshape(dims[0], -1)
    # Otherwise bring the requested axis forward and flatten the rest.
    fronted = np.moveaxis(tens, mode, 0)
    return fronted.reshape(dims[mode], -1)
def refold(vec, mode, dims):
    """
    Refolds vector into tensor.

    Parameters
    ----------
    vec : ndarray, with len == prod(dims)
    mode : int, which axis was unfolded along.
    dims : list, holds tensor shape

    Returns
    -------
    tens : ndarray, tensor with shape == dims
    """
    if mode == 0:
        return vec.reshape(dims)
    # The unfolded matrix had dims[mode] in front; rebuild that shape,
    # then move the leading axis back to where it came from (this is
    # the exact inverse of `unfold`).
    lead_shape = [dims[mode]]
    for axis, size in enumerate(dims):
        if axis != mode:
            lead_shape.append(size)
    return np.moveaxis(vec.reshape(lead_shape), 0, mode)
# ==== KRON-VEC PRODUCT COMPUTATIONS ==== #
def kron_vec_prod(As, v):
    """
    Computes matrix-vector multiplication between
    matrix kron(As[0], As[1], ..., As[N]) and vector
    v without forming the full kronecker product.
    """
    # v is folded by the column dimensions of the factors.
    shape_in = [A.shape[1] for A in As]
    vt = v.reshape(shape_in)
    current = list(shape_in)
    for mode, A in enumerate(As):
        # After multiplying along `mode`, that axis holds A.shape[0] rows.
        updated = list(current)
        updated[mode] = A.shape[0]
        vt = refold(A @ unfold(vt, mode, current), mode, updated)
        current = updated
    return vt.ravel()
def kron_brute_force(As, v):
    """
    Computes kron-matrix times vector by brute
    force (instantiates the full kron product).
    """
    # Accumulate the full Kronecker product, then apply it to v.
    full = reduce(np.kron, As)
    return full @ v
# Quick demonstration.
if __name__ == "__main__":
    # Create a random rectangular problem: factor i is _xaxes[i] x _yaxes[i].
    _yaxes = [2, 3, 4]
    _xaxes = [1, 2, 1]
    # As = [np.ones((x,y)) for (x, y) in zip(_xaxes, _yaxes)]
    As = [np.random.rand(rows, cols) for rows, cols in zip(_xaxes, _yaxes)]
    v = np.ones((np.prod(_yaxes), ))

    # Test accuracy: fast and brute-force results should agree to round-off.
    actual = kron_vec_prod(As, v)
    expected = kron_brute_force(As, v)
    print(np.linalg.norm(actual - expected))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Some back-of-the-envelope calculations and comments on computational complexity. Suppose M square matrices of dimension N x N are Kronecker-multiplied. The brute-force approach constructs an N^M x N^M matrix, which is then multiplied by a vector of length N^M. In total, this results in O(N^{2M}) flops.

The code above implements a faster approach, which involves M sequential matrix multiplies. Each matrix multiply is between an N x N matrix and an N x N^{M-1} matrix, which (assuming no Strassen-type algorithms for matrix multiplication) takes O(N^{M+1}) flops. Since this is repeated M times, the total computational cost is O(M * N^{M+1}). This is often substantially faster, as it avoids the factor of 2 in the exponent.