Kronecker-vector products: compute (As[0] ⊗ As[1] ⊗ ... ⊗ As[-1]) @ v without forming the full Kronecker matrix.
Save ahwillia/f65bc70cb30206d4eadec857b98c4065 to your computer and use it in GitHub Desktop.
import numpy as np
import numpy.random as npr
from functools import reduce

# Goal
# ----
# Compute (As[0] kron As[1] kron ... As[-1]) @ v
# ==== HELPER FUNCTIONS ==== #
def unfold(tens, mode, dims):
    """
    Unfolds tensor into matrix.

    Parameters
    ----------
    tens : ndarray, tensor with shape == dims
    mode : int, which axis to move to the front
    dims : list, holds tensor shape

    Returns
    -------
    matrix : ndarray, shape (dims[mode], prod(dims[/mode]))
    """
    if mode == 0:
        # Axis 0 is already in front; a plain reshape suffices.
        return tens.reshape(dims[0], -1)
    else:
        # Bring the requested axis to the front, then flatten the rest.
        return np.moveaxis(tens, mode, 0).reshape(dims[mode], -1)
def refold(vec, mode, dims):
    """
    Refolds vector into tensor (inverse of `unfold`).

    Parameters
    ----------
    vec : ndarray, tensor with len == prod(dims)
    mode : int, which axis was unfolded along.
    dims : list, holds tensor shape

    Returns
    -------
    tens : ndarray, tensor with shape == dims
    """
    if mode == 0:
        return vec.reshape(dims)
    else:
        # Reshape and then move dims[mode] back to its
        # appropriate spot (undoing the `unfold` operation).
        tens = vec.reshape(
            [dims[mode]] +
            [d for m, d in enumerate(dims) if m != mode]
        )
        return np.moveaxis(tens, 0, mode)
# ==== KRON-VEC PRODUCT COMPUTATIONS ==== #
def kron_vec_prod(As, v):
    """
    Computes matrix-vector multiplication between
    matrix kron(As[0], As[1], ..., As[N]) and vector
    v without forming the full kronecker product.

    Assumes every As[i] is square. For M factors of size
    N x N this costs O(M * N^(M+1)) flops, versus
    O(N^(2M)) for materializing the Kronecker product.
    """
    dims = [A.shape[0] for A in As]
    # View v as an M-way tensor and multiply one mode at a time.
    vt = v.reshape(dims)
    for i, A in enumerate(As):
        vt = refold(A @ unfold(vt, i, dims), i, dims)
    return vt.ravel()
def kron_brute_force(As, v):
    """
    Computes kron-matrix times vector by brute
    force (instantiates the full kron product).

    Reference implementation for checking kron_vec_prod;
    costs O(N^(2M)) flops for M square factors of size N.
    """
    return reduce(np.kron, As) @ v
# Quick demonstration.
if __name__ == "__main__":
    # Create random problem.
    _dims = [3, 3, 3, 3, 3, 3, 3, 3]
    As = [npr.randn(d, d) for d in _dims]
    v = npr.randn(np.prod(_dims))
    # Test accuracy: printed norm should be ~1e-12 (float round-off).
    actual = kron_vec_prod(As, v)
    expected = kron_brute_force(As, v)
    print(np.linalg.norm(actual - expected))
Some back-of-the-envelope calculations and comments on computational complexity. Suppose M square matrices of dimension N x N are Kronecker-multiplied. The brute-force approach constructs an N^M x N^M matrix, which is then multiplied by a vector of length N^M. In total, this results in O(N^{2M}) flops.
The code above implements a faster approach, which involves M sequential matrix multiplies. Each matrix multiply is between an N x N matrix and an N x N^{M-1} matrix, which (assuming no Strassen-type algorithms for matrix multiplication) takes O(N^{M+1}) flops. Since this is repeated M times, the total computational cost is O(M * N^{M+1}). This is often substantially faster, as it avoids the factor of 2 in the exponent.
Very helpful code! But does this only work for square matrices?
Yes I think this only works for square matrices. I'm not immediately sure how to extend it to the non-square case.
What if one just wanted the kron result without the subsequent vector product?
For extending to the non-square case, all we need is a minor change to the kron_vec_prod
method:
def kron_vec_prod(As, v):
    """
    Computes matrix-vector multiplication between
    matrix kron(As[0], As[1], ..., As[N]) and vector
    v without forming the full kronecker product.

    Supports rectangular factors: As[i] may have shape
    (m_i, n_i); then v must have length prod(n_i) and the
    result has length prod(m_i).
    """
    # Current tensor shape of vt; starts as the input (column) dims.
    dims = [A.shape[1] for A in As]
    vt = v.reshape(dims)
    for i, A in enumerate(As):
        # After multiplying along mode i, that axis has A.shape[0] entries.
        new_dims = list(dims)
        new_dims[i] = A.shape[0]
        vt = refold(A @ unfold(vt, i, dims), i, new_dims)
        dims = new_dims
    return vt.ravel()
The modified code is as follows:
import numpy as np
import numpy.random as npr
from functools import reduce
# Goal
# ----
# Compute (As[0] kron As[1] kron ... As[-1]) @ v
# ==== HELPER FUNCTIONS ==== #
def unfold(tens, mode, dims):
    """
    Unfolds tensor into matrix.

    Parameters
    ----------
    tens : ndarray, tensor with shape == dims
    mode : int, which axis to move to the front
    dims : list, holds tensor shape

    Returns
    -------
    matrix : ndarray, shape (dims[mode], prod of remaining dims)
    """
    # Rotate the selected axis to the front (no-op when mode == 0),
    # then collapse all trailing axes into a single dimension.
    if mode != 0:
        tens = np.moveaxis(tens, mode, 0)
    return tens.reshape(dims[mode], -1)
def refold(vec, mode, dims):
    """
    Refolds vector into tensor (inverse of `unfold`).

    Parameters
    ----------
    vec : ndarray, tensor with len == prod(dims)
    mode : int, which axis was unfolded along.
    dims : list, holds tensor shape

    Returns
    -------
    tens : ndarray, tensor with shape == dims
    """
    if mode == 0:
        return vec.reshape(dims)
    # The unfolded layout has dims[mode] leading; rebuild that shape,
    # then rotate the leading axis back to position `mode`.
    shape = [dims[mode]] + [d for ax, d in enumerate(dims) if ax != mode]
    return np.moveaxis(vec.reshape(shape), 0, mode)
# ==== KRON-VEC PRODUCT COMPUTATIONS ==== #
def kron_vec_prod(As, v):
    """
    Computes matrix-vector multiplication between
    matrix kron(As[0], As[1], ..., As[N]) and vector
    v without forming the full kronecker product.

    Supports rectangular factors: As[i] may have shape
    (m_i, n_i); then v must have length prod(n_i) and the
    result has length prod(m_i).
    """
    # Current tensor shape of vt; starts as the input (column) dims.
    dims = [A.shape[1] for A in As]
    vt = v.reshape(dims)
    for i, A in enumerate(As):
        # After multiplying along mode i, that axis has A.shape[0] entries.
        new_dims = list(dims)
        new_dims[i] = A.shape[0]
        vt = refold(A @ unfold(vt, i, dims), i, new_dims)
        dims = new_dims
    return vt.ravel()
def kron_brute_force(As, v):
    """
    Computes kron-matrix times vector by brute
    force (instantiates the full kron product).
    """
    # Accumulate the full Kronecker product left-to-right,
    # then apply it to v in one dense matrix-vector multiply.
    full = As[0]
    for A in As[1:]:
        full = np.kron(full, A)
    return full @ v
# Quick demonstration.
if __name__ == "__main__":
    # Create a random rectangular problem: As[i] has shape
    # (_xaxes[i], _yaxes[i]), so v has length prod(_yaxes).
    _yaxes = [2, 3, 4]
    _xaxes = [1, 2, 1]
    As = [npr.rand(x, y) for (x, y) in zip(_xaxes, _yaxes)]
    v = np.ones((np.prod(_yaxes),))
    # Test accuracy: printed norm should be ~0 (float round-off).
    actual = kron_vec_prod(As, v)
    expected = kron_brute_force(As, v)
    print(np.linalg.norm(actual - expected))
Speed comparison for dims = [3, 3, 3, 3, 3, 3, 3, 3]. Speed comparison for dims = [20, 20, 20].