Spectrum of the extended feature Gram matrix of a single hidden layer ReLU MLP
"""Empirical evaluation of the extended feature Gram matrix of a ReLU MLP | |
Here we try to estimate the spectrum of the H^\infty matrix as defined in: | |
Gradient Descent Provably Optimizes Over-parameterized Neural Networks (2018) | |
Simon S. Du, Xiyu Zhai, Barnabas Poczos, Aarti Singh | |
https://arxiv.org/abs/1810.02054 | |
Theorem 4.1 relies on the assumption that H^\infty has a strictly positive | |
minimum eigenvalue. The following computes an estimate of this eigenvalue | |
for a toy digits dataset with 1797 samples of 64 dimensions. In this case | |
we find that this assumption holds with \lambda_0 > 1.3e-2. | |
""" | |
from time import time

import numpy as np
import numba
import matplotlib.pyplot as plt
from sklearn.datasets import load_digits
from sklearn.preprocessing import normalize
# Workaround: https://github.com/numba/numba/issues/3341
numba.config.THREADING_LAYER = 'workqueue'
@numba.jit(parallel=True)
def compute_h_inf(X, n_iter=int(1e4), seed=0):
    n_samples, n_features = X.shape
    H_inf = np.zeros(shape=(n_samples, n_samples), dtype=X.dtype)
    W = np.random.RandomState(seed).randn(n_iter, n_features)
    W_X = W @ X.T > 0
    Gram = X @ X.T
    # Could be implemented with einsum as follows:
    # np.einsum('ij,ki,kj->ij', Gram, W_X, W_X) / n_iter
    # but using explicit numba loops makes it possible to use multi-threading.
    scale = 1. / n_iter
    for k in range(n_iter):
        for i in numba.prange(n_samples):
            for j in range(n_samples):
                H_inf[i, j] += scale * Gram[i, j] * W_X[k, i] * W_X[k, j]
    return H_inf
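# Optional cross-check (a sketch, assuming the rows of X have unit norm, which
# holds after the call to normalize(X) below): the expectation estimated above
# is also known in closed form,
#
#     H^inf_{ij} = x_i^T x_j * (pi - arccos(x_i^T x_j)) / (2 * pi),
#
# so this hypothetical helper can be compared against the Monte Carlo estimate,
# e.g. np.abs(compute_h_inf(X) - compute_h_inf_closed_form(X)).max() should
# shrink as n_iter grows.
def compute_h_inf_closed_form(X):
    Gram = X @ X.T
    # Clip to guard arccos against round-off slightly outside [-1, 1].
    cos_angles = np.clip(Gram, -1.0, 1.0)
    return Gram * (np.pi - np.arccos(cos_angles)) / (2 * np.pi)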
digits = load_digits()
X, y = digits.data, digits.target
n_samples, n_features = X.shape
print(f"Loaded digits data (n_samples={n_samples}, n_features={n_features})")

print("Normalizing X...")
X = normalize(X)

print("Computing the spectrum of the data Gram matrix", end="", flush=True)
t0 = time()
eigvals_gram = np.linalg.eigvalsh(X @ X.T)
print(f" done in {time() - t0:0.3f}s")
print(f"lambda_min(XX^T): {eigvals_gram.min():0.3e}")
# We only have 64 features, so the rank of this Gram matrix is bounded by 64.

fig, (ax0, ax1) = plt.subplots(nrows=2, sharex=True, constrained_layout=True,
                               figsize=(12, 8))
ax0.semilogy(eigvals_gram[::-1])
ax0.set_title('Spectrum of the data Gram matrix $XX^T$')
ax0.set_ylabel('Eigenvalue (logscale)')
for n_iter in [1_000, 10_000, 100_000]:
    print(f"Computing extended feature Gram H_inf with n_iter={n_iter}...",
          end="", flush=True)
    t0 = time()
    H_inf = compute_h_inf(X, n_iter=n_iter)
    print(f" done in {time() - t0:0.3f}s")
    print(f"H_inf.shape={H_inf.shape}")

    print("Checking that H_inf is symmetric...", end="", flush=True)
    np.testing.assert_allclose(H_inf, H_inf.T)
    print(" ok")

    print("Computing the spectrum of H_inf...", end="", flush=True)
    t0 = time()
    eigvals = np.linalg.eigvalsh(H_inf)
    print(f" done in {time() - t0:0.3f}s")
    print(f"lambda_min(H_inf): {eigvals.min():0.3e}")

    ax1.semilogy(eigvals[::-1])

ax1.set_title('Spectrum of the extended feature Gram matrix: $H^\infty$')
ax1.set_ylabel('Eigenvalue (logscale)')
ax1.set_xlabel('Eigenvalue rank')
plt.show()
Loaded digits data (n_samples=1797, n_features=64)
Normalizing X...
Computing the spectrum of the data Gram matrix done in 0.763s
lambda_min(XX^T): -2.796e-13
Computing extended feature Gram H_inf with n_iter=1000... done in 2.793s
H_inf.shape=(1797, 1797)
Checking that H_inf is symmetric... ok
Computing the spectrum of H_inf... done in 0.592s
lambda_min(H_inf): 3.083e-03
Computing extended feature Gram H_inf with n_iter=10000... done in 18.585s
H_inf.shape=(1797, 1797)
Checking that H_inf is symmetric... ok
Computing the spectrum of H_inf... done in 0.578s
lambda_min(H_inf): 1.112e-02
Computing extended feature Gram H_inf with n_iter=100000... done in 209.125s
H_inf.shape=(1797, 1797)
Checking that H_inf is symmetric... ok
Computing the spectrum of H_inf... done in 0.607s
lambda_min(H_inf): 1.354e-02
Here are the spectra (log scale):