Skip to content

Instantly share code, notes, and snippets.

@nschloe
Last active January 13, 2022 20:43
Show Gist options
  • Save nschloe/8c5039d8bc450aa001cf310e4fe43d97 to your computer and use it in GitHub Desktop.
Save nschloe/8c5039d8bc450aa001cf310e4fe43d97 to your computer and use it in GitHub Desktop.
Gram-Schmidt row-wise vs column-wise
import numpy as np
import perfplot
rng = np.random.default_rng(1)
def setup(n):
U = rng.random((n, 10))
return U, np.ascontiguousarray(U.T)
def gram_schmidt_cols(data):
    """Orthonormalize the columns of a matrix with classical Gram-Schmidt.

    Parameters
    ----------
    data : tuple
        Pair ``(V, _)``; only the first entry is used. ``V`` is a 2-D array
        whose columns are the vectors to orthonormalize.

    Returns
    -------
    numpy.ndarray
        New floating-point array with the shape of ``V`` whose columns are
        orthonormal. ``V`` itself is not modified.

    Notes
    -----
    Linear dependence is not checked; a zero norm leads to a division by
    zero.
    """
    V, _ = data
    V = np.asarray(V)
    # Work on a floating-point copy: the in-place ``-=`` / ``/=`` below cannot
    # store float results in an integer array.
    dtype = V.dtype if np.issubdtype(V.dtype, np.inexact) else np.float64
    U = np.array(V, dtype=dtype)
    for i in range(U.shape[1]):
        # Subtract the projections onto all previously processed columns at
        # once; the coefficients are computed from the not-yet-updated column.
        U[:, i] -= U[:, :i] @ (U[:, i] @ U[:, :i])
        U[:, i] /= np.sqrt(U[:, i] @ U[:, i])
    return U
def gram_schmidt_rows(data):
    """Orthonormalize the rows of a matrix with classical Gram-Schmidt.

    Parameters
    ----------
    data : tuple
        Pair ``(_, V)``; only the second entry is used. ``V`` is a 2-D array
        whose rows are the vectors to orthonormalize.

    Returns
    -------
    numpy.ndarray
        Transpose of the orthonormalized copy, i.e. the resulting vectors are
        the columns of the returned array. ``V`` itself is not modified.

    Notes
    -----
    Linear dependence is not checked; a zero norm leads to a division by
    zero.
    """
    _, V = data
    V = np.asarray(V)
    # Work on a floating-point copy: the in-place ``-=`` / ``/=`` below cannot
    # store float results in an integer array.
    dtype = V.dtype if np.issubdtype(V.dtype, np.inexact) else np.float64
    U = np.array(V, dtype=dtype)
    for i in range(U.shape[0]):
        # Subtract the projections onto all previously processed rows at
        # once; the coefficients are computed from the not-yet-updated row.
        U[i] -= (U[:i] @ U[i]) @ U[:i]
        U[i] /= np.sqrt(U[i] @ U[i])
    return U.T
def modified_gram_schmidt_cols(data):
    """Orthonormalize the columns of a matrix with modified Gram-Schmidt.

    Each column is normalized and then immediately projected out of every
    later column, one column at a time.

    Parameters
    ----------
    data : tuple
        Pair ``(V, _)``; only the first entry is used. ``V`` is a 2-D array
        whose columns are the vectors to orthonormalize.

    Returns
    -------
    numpy.ndarray
        New floating-point array with the shape of ``V`` whose columns are
        orthonormal. ``V`` itself is not modified.

    Notes
    -----
    Linear dependence is not checked; a zero norm leads to a division by
    zero.
    """
    V, _ = data
    V = np.asarray(V)
    # Work on a floating-point copy: the in-place ``-=`` / ``/=`` below cannot
    # store float results in an integer array.
    dtype = V.dtype if np.issubdtype(V.dtype, np.inexact) else np.float64
    U = np.array(V, dtype=dtype)
    n = U.shape[1]
    for i in range(n):
        U[:, i] /= np.sqrt(U[:, i] @ U[:, i])
        # Remove the freshly normalized direction from each remaining column.
        for j in range(i + 1, n):
            U[:, j] -= U[:, i] * (U[:, i] @ U[:, j])
    return U
def modified_gram_schmidt_cols2(data):
    """Orthonormalize columns with modified Gram-Schmidt, vectorized update.

    Same scheme as the loop-based column variant, but the projection onto the
    current column is removed from all later columns with a single rank-one
    update instead of an inner Python loop.

    Parameters
    ----------
    data : tuple
        Pair ``(V, _)``; only the first entry is used. ``V`` is a 2-D array
        whose columns are the vectors to orthonormalize.

    Returns
    -------
    numpy.ndarray
        New floating-point array with the shape of ``V`` whose columns are
        orthonormal. ``V`` itself is not modified.

    Notes
    -----
    Linear dependence is not checked; a zero norm leads to a division by
    zero.
    """
    V, _ = data
    V = np.asarray(V)
    # Work on a floating-point copy: the in-place ``-=`` / ``/=`` below cannot
    # store float results in an integer array.
    dtype = V.dtype if np.issubdtype(V.dtype, np.inexact) else np.float64
    U = np.array(V, dtype=dtype)
    for i in range(U.shape[1]):
        U[:, i] /= np.sqrt(U[:, i] @ U[:, i])
        # Rank-one update: subtract q_i * (q_i . u_j) from every later column.
        U[:, i + 1 :] -= np.outer(U[:, i], U[:, i] @ U[:, i + 1 :])
    return U
def modified_gram_schmidt_rows(data):
    """Orthonormalize the rows of a matrix with modified Gram-Schmidt.

    Each row is normalized and then immediately projected out of every later
    row, one row at a time.

    Parameters
    ----------
    data : tuple
        Pair ``(_, V)``; only the second entry is used. ``V`` is a 2-D array
        whose rows are the vectors to orthonormalize.

    Returns
    -------
    numpy.ndarray
        Transpose of the orthonormalized copy, i.e. the resulting vectors are
        the columns of the returned array. ``V`` itself is not modified.

    Notes
    -----
    Linear dependence is not checked; a zero norm leads to a division by
    zero.
    """
    _, V = data
    V = np.asarray(V)
    # Work on a floating-point copy: the in-place ``-=`` / ``/=`` below cannot
    # store float results in an integer array.
    dtype = V.dtype if np.issubdtype(V.dtype, np.inexact) else np.float64
    U = np.array(V, dtype=dtype)
    n = U.shape[0]
    for i in range(n):
        U[i] /= np.sqrt(U[i] @ U[i])
        # Remove the freshly normalized direction from each remaining row.
        for j in range(i + 1, n):
            U[j] -= U[i] * (U[i] @ U[j])
    return U.T
def modified_gram_schmidt_rows2(data):
    """Orthonormalize rows with modified Gram-Schmidt, vectorized update.

    Same scheme as the loop-based row variant, but the projection onto the
    current row is removed from all later rows with a single rank-one update
    instead of an inner Python loop.

    Parameters
    ----------
    data : tuple
        Pair ``(_, V)``; only the second entry is used. ``V`` is a 2-D array
        whose rows are the vectors to orthonormalize.

    Returns
    -------
    numpy.ndarray
        Transpose of the orthonormalized copy, i.e. the resulting vectors are
        the columns of the returned array. ``V`` itself is not modified.

    Notes
    -----
    Linear dependence is not checked; a zero norm leads to a division by
    zero.
    """
    _, V = data
    V = np.asarray(V)
    # Work on a floating-point copy: the in-place ``-=`` / ``/=`` below cannot
    # store float results in an integer array.
    dtype = V.dtype if np.issubdtype(V.dtype, np.inexact) else np.float64
    U = np.array(V, dtype=dtype)
    for i in range(U.shape[0]):
        U[i] /= np.sqrt(U[i] @ U[i])
        # Rank-one update: subtract (u_j . q_i) * q_i from every later row.
        U[i + 1 :] -= np.outer(U[i + 1 :] @ U[i], U[i])
    return U.T
# Benchmark the six Gram-Schmidt kernels on inputs produced by ``setup`` for
# vector lengths n = 2**4, 2**5, ..., 2**23.
b = perfplot.bench(
    xlabel="n",
    title="10 vectors of length n",
    n_range=[2 ** k for k in range(4, 24)],
    setup=setup,
    kernels=[
        # classical Gram-Schmidt
        gram_schmidt_cols,
        gram_schmidt_rows,
        # modified Gram-Schmidt, column storage (looped / vectorized)
        modified_gram_schmidt_cols,
        modified_gram_schmidt_cols2,
        # modified Gram-Schmidt, row storage (looped / vectorized)
        modified_gram_schmidt_rows,
        modified_gram_schmidt_rows2,
    ],
)

# Write the result to disk first, then show it.
b.save("out0.png")
b.show()
@nschloe
Copy link
Author

nschloe commented Jan 13, 2022

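The row-wise variants are faster. Reason: in a row-major (C-ordered) array, U[j] is a contiguous block in memory, whereas U[:, j] is a strided view whose elements are not adjacent.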

out0

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment