Last active February 17, 2025 08:46
Efficient computation of a Kronecker-vector product (with multiple matrices).
import numpy as np
import numpy.random as npr
from functools import reduce
# Goal
# ----
# Compute (As[0] kron As[1] kron ... As[-1]) @ v
# ==== HELPER FUNCTIONS ==== #
def unfold(tens, mode, dims):
    """
    Unfolds tensor into matrix.

    tens : ndarray, tensor with shape == dims
    mode : int, which axis to move to the front
    dims : list, holds tensor shape

    Returns
    matrix : ndarray, shape (dims[mode], product of all other dims)
    """
    if mode == 0:
        return tens.reshape(dims[0], -1)
    return np.moveaxis(tens, mode, 0).reshape(dims[mode], -1)

def refold(vec, mode, dims):
    """
    Refolds vector into tensor.

    vec : ndarray, array with size == prod(dims)
    mode : int, which axis was unfolded along.
    dims : list, holds tensor shape

    Returns
    tens : ndarray, tensor with shape == dims
    """
    if mode == 0:
        return vec.reshape(dims)
    # Reshape and then move dims[mode] back to its
    # appropriate spot (undoing the `unfold` operation).
    tens = vec.reshape(
        [dims[mode]] +
        [d for m, d in enumerate(dims) if m != mode]
    )
    return np.moveaxis(tens, 0, mode)

def kron_vec_prod(As, v):
    """
    Computes matrix-vector multiplication between
    matrix kron(As[0], As[1], ..., As[N]) and vector
    v without forming the full kronecker product.
    """
    # Each A is assumed to be square here.
    dims = [A.shape[0] for A in As]
    vt = v.reshape(dims)
    for i, A in enumerate(As):
        # Apply A along axis i of the reshaped vector.
        vt = refold(A @ unfold(vt, i, dims), i, dims)
    return vt.ravel()

def kron_brute_force(As, v):
    """
    Computes kron-matrix times vector by brute
    force (instantiates the full kron product).
    """
    return reduce(np.kron, As) @ v

# Quick demonstration.
if __name__ == "__main__":
    # Create random problem.
    _dims = [3, 3, 3, 3, 3, 3, 3, 3]
    As = [npr.randn(d, d) for d in _dims]
    # v needs one entry per column of the full kron product.
    v = npr.randn(np.prod(_dims))

    # Test accuracy.
    actual = kron_vec_prod(As, v)
    expected = kron_brute_force(As, v)
    print(np.linalg.norm(actual - expected))
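
For intuition, the two-matrix case reduces to a familiar identity: with row-major (C-order) reshaping, kron(A, B) @ v equals A @ V @ B.T flattened, where V is v reshaped to a matrix. The sketch below checks this numerically. The function name `kron2_vec_prod` and the matrix sizes are made up for illustration and are not part of the gist.

```python
import numpy as np


def kron2_vec_prod(A, B, v):
    """Hypothetical two-matrix helper: kron(A, B) @ v without forming kron(A, B)."""
    # Reshape v (row-major) so rows pair with A's columns and columns with B's columns.
    V = v.reshape(A.shape[1], B.shape[1])
    # Apply A along the first axis and B along the second, then flatten row-major.
    return (A @ V @ B.T).ravel()


rng = np.random.default_rng(0)
A = rng.standard_normal((3, 3))
B = rng.standard_normal((4, 4))
v = rng.standard_normal(12)

fast = kron2_vec_prod(A, B, v)
slow = np.kron(A, B) @ v
print(np.linalg.norm(fast - slow))  # should be at the level of floating-point round-off
```

The general routine above applies the same idea one axis at a time: each matrix acts on its own axis of the reshaped vector.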

ahwillia commented Nov 6, 2019

Speed comparison for dims = [3, 3, 3, 3, 3, 3, 3, 3].

%timeit kron_brute_force(As, v)                                                                                         
946 ms ± 17.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

%timeit kron_vec_prod(As, v)                                                                                            
299 µs ± 2.86 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

Speed comparison for dims = [20, 20, 20].

%timeit kron_brute_force(As, v)                                                                                         
1.07 s ± 13.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

%timeit kron_vec_prod(As, v)                                                                                            
105 µs ± 845 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)

ahwillia commented Nov 6, 2019

Some back-of-the-envelope calculations and comments on computational complexity. Suppose M square matrices of dimension N x N are Kronecker-multiplied. The brute-force approach constructs an N^M x N^M matrix, which is then multiplied by a vector of length N^M. In total, this takes O(N^{2 * M}) flops.

The code above implements a faster approach, which involves M sequential matrix multiplies. Each matrix multiply is between an N x N matrix and an N x N^{M-1} matrix, which (assuming no Strassen-type algorithms for matrix-multiply) takes O(N^{M+1}) flops. Since this is repeated M times, the total computational cost is O(M * N^{M+1}). This is often substantially faster because the exponent is M + 1 rather than 2 * M.
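
To make the two bounds concrete, here is a rough sketch that evaluates them for the two problem sizes timed earlier in the thread. The function names `brute_force_flops` and `sequential_flops` are made up for illustration, and the counts drop all constant factors, so they are order-of-magnitude estimates rather than predictions of wall-clock time.

```python
def brute_force_flops(N, M):
    """Order-of-magnitude cost of multiplying the full N^M x N^M matrix by a vector."""
    return N ** (2 * M)


def sequential_flops(N, M):
    """Order-of-magnitude cost of M multiplies of N x N by N x N^(M-1)."""
    return M * N ** (M + 1)


# (N, M) pairs matching dims = [3] * 8 and dims = [20] * 3.
for N, M in [(3, 8), (20, 3)]:
    brute = brute_force_flops(N, M)
    seq = sequential_flops(N, M)
    print(f"N={N}, M={M}: brute ~{brute}, sequential ~{seq}, ratio ~{brute / seq:.0f}")

# N=3,  M=8: 43046721 vs 157464
# N=20, M=3: 64000000 vs 480000
```

The measured speedups reported above are larger than these ratios, which is plausible since the brute-force timing also includes building and storing the full matrix.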

weiT1993 commented Sep 3, 2021

Very helpful code! But does this only work for square matrices?

ahwillia commented Sep 3, 2021

Yes, I think this only works for square matrices. I'm not immediately sure how to extend it to the non-square case.

renatomello commented Nov 10, 2021

What if one just wanted the kron result without the subsequent vector product?

sambroy commented Aug 14, 2023

To extend this to the non-square case, all we need is a minor change to the kron_vec_prod function:

def kron_vec_prod(As, v):
    """
    Computes matrix-vector multiplication between
    matrix kron(As[0], As[1], ..., As[N]) and vector
    v without forming the full kronecker product.
    """
    # Input tensor shape: one axis per matrix, sized by its column count.
    dims_in = [A.shape[1] for A in As]
    vt = v.reshape(dims_in)
    for i, A in enumerate(As):
        # change the ith entry of dims to A.shape[0]
        dims_fin = list(dims_in)
        dims_fin[i] = A.shape[0]
        vt = refold(A @ unfold(vt, i, dims_in), i, dims_fin)
        dims_in = dims_fin
    return vt.ravel()

The modified code is as follows:

import numpy as np
import numpy.random as npr
from functools import reduce

# Goal
# ----
# Compute (As[0] kron As[1] kron ... As[-1]) @ v

# ==== HELPER FUNCTIONS ==== #

def unfold(tens, mode, dims):
    """
    Unfolds tensor into matrix.

    tens : ndarray, tensor with shape == dims
    mode : int, which axis to move to the front
    dims : list, holds tensor shape

    Returns
    matrix : ndarray, shape (dims[mode], product of all other dims)
    """
    if mode == 0:
        return tens.reshape(dims[0], -1)
    return np.moveaxis(tens, mode, 0).reshape(dims[mode], -1)

def refold(vec, mode, dims):
    """
    Refolds vector into tensor.

    vec : ndarray, array with size == prod(dims)
    mode : int, which axis was unfolded along.
    dims : list, holds tensor shape

    Returns
    tens : ndarray, tensor with shape == dims
    """
    if mode == 0:
        return vec.reshape(dims)
    # Reshape and then move dims[mode] back to its
    # appropriate spot (undoing the `unfold` operation).
    tens = vec.reshape(
        [dims[mode]] +
        [d for m, d in enumerate(dims) if m != mode]
    )
    return np.moveaxis(tens, 0, mode)


def kron_vec_prod(As, v):
    """
    Computes matrix-vector multiplication between
    matrix kron(As[0], As[1], ..., As[N]) and vector
    v without forming the full kronecker product.
    """
    # Input tensor shape: one axis per matrix, sized by its column count.
    dims_in = [A.shape[1] for A in As]
    vt = v.reshape(dims_in)
    for i, A in enumerate(As):
        # change the ith entry of dims to A.shape[0]
        dims_fin = list(dims_in)
        dims_fin[i] = A.shape[0]
        vt = refold(A @ unfold(vt, i, dims_in), i, dims_fin)
        dims_in = dims_fin
    return vt.ravel()

def kron_brute_force(As, v):
    """
    Computes kron-matrix times vector by brute
    force (instantiates the full kron product).
    """
    return reduce(np.kron, As) @ v

# Quick demonstration.
if __name__ == "__main__":

    # Create random problem.
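    _yaxes = [2, 3, 4]  # number of columns of each matrix
    _xaxes = [1, 2, 1]  # number of rows of each matrix
    # As = [np.ones((x,y)) for (x, y) in zip(_xaxes, _yaxes)]
    As = [np.random.rand(x, y) for (x, y) in zip(_xaxes, _yaxes)]

    # v needs one entry per column of the full kron product.
    v = np.ones((np.prod(_yaxes), ))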

    # Test accuracy.
    actual = kron_vec_prod(As, v)
    expected = kron_brute_force(As, v)
    print(np.linalg.norm(actual - expected))
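
As a cross-check on the non-square version, the same contraction can also be written with `np.tensordot`, which avoids the explicit unfold/refold bookkeeping. This is only an alternative sketch, not the code posted above. The name `kron_vec_prod_tensordot` is hypothetical, and the shapes mirror the demonstration (row counts 1, 2, 1 and column counts 2, 3, 4).

```python
import numpy as np
from functools import reduce


def kron_vec_prod_tensordot(As, v):
    """Hypothetical variant: kron(As[0], ..., As[-1]) @ v for possibly non-square As."""
    # One tensor axis per matrix, sized by that matrix's column count.
    vt = np.asarray(v).reshape([A.shape[1] for A in As])
    for i, A in enumerate(As):
        # Contract A's columns with axis i; tensordot puts A's row axis first,
        # so move it back to position i.
        vt = np.moveaxis(np.tensordot(A, vt, axes=(1, i)), 0, i)
    return vt.ravel()


rng = np.random.default_rng(0)
shapes = [(1, 2), (2, 3), (1, 4)]  # (rows, cols) for each matrix
As = [rng.standard_normal(s) for s in shapes]
v = rng.standard_normal(2 * 3 * 4)

actual = kron_vec_prod_tensordot(As, v)
expected = reduce(np.kron, As) @ v
print(actual.shape, np.linalg.norm(actual - expected))
```

The output has length equal to the product of the row counts (here 1 * 2 * 1 = 2), while the input has length equal to the product of the column counts.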
