Last active: October 15, 2019 23:00
-
-
Save 0x0L/c863a5e09c202e3c6b58b86a48b01370 to your computer and use it in GitHub Desktop.
RIE
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import matplotlib.pyplot as plt | |
import numpy as np | |
import sklearn.covariance as cov | |
def _norm2(x): | |
return np.square(x.real) + np.square(x.imag) | |
def rie_estimator(X, *, eta=None):
    """Clean the sample correlation matrix of ``X`` with a rotationally invariant estimator.

    The sample eigenvectors are kept. Each sample eigenvalue is replaced by
    the RIE shrinkage formula, which is evaluated slightly off the real axis
    (Bun, Bouchaud & Potters). It is then debiased with a Marchenko-Pastur
    law whose lower edge is the smallest sample eigenvalue.

    Parameters
    ----------
    X : array_like, shape (T, N)
        T observations (rows) of N variables (columns). A centred,
        unit-variance copy is used; the caller's array is not modified.
    eta : float, optional
        Distance below the real axis at which the Stieltjes transforms are
        evaluated (``z = lambda - 1j * eta``). Defaults to ``1 / sqrt(N)``.

    Returns
    -------
    psi : ndarray, shape (N,)
        Cleaned eigenvalues, in the same order as the columns of ``U``
        (ascending sample eigenvalues). They are rescaled so that their sum
        equals the sum of the sample eigenvalues.
    U : ndarray, shape (N, N)
        Eigenvectors (columns) of the sample correlation matrix.

    Raises
    ------
    ValueError
        If ``T <= N``: the debiasing step divides by ``1 - sqrt(N/T)`` and
        treats the smallest sample eigenvalue as the lower Marchenko-Pastur
        edge, which requires ``N/T < 1``.
        If a column of ``X`` has zero variance (it cannot be standardized).
    """
    X = np.asarray(X)
    T, N = X.shape
    if T <= N:
        raise ValueError(
            f"rie_estimator requires more observations than variables "
            f"(got T={T}, N={N})"
        )
    q = N / T

    # Standardize a copy so E below is the sample *correlation* matrix.
    X = X - X.mean(0)
    std = X.std(0)
    if np.any(std == 0):
        raise ValueError("rie_estimator: X has a zero-variance column")
    X /= std
    E = X.T @ X / T
    λ, U = np.linalg.eigh(E)  # eigh returns eigenvalues in ascending order

    if eta is None:
        eta = 1 / N**0.5
    # Small negative imaginary part regularizes the poles at the eigenvalues.
    z = λ - 1j * eta

    # Leave-one-out sample Stieltjes transform:
    # g_k = (1/N) * sum_{j != k} 1 / (z_k - λ_j)   (diagonal zeroed, then mean).
    stieltjes = 1.0 / (z[:, None] - λ)
    np.fill_diagonal(stieltjes, 0)
    stieltjes = stieltjes.mean(1)
    # RIE shrinkage formula.
    ξ = λ / _norm2(1 - q + q * z * stieltjes)

    # Debiasing: fit a Marchenko-Pastur law with lower edge λ0 = σ²(1-√q)².
    # Its upper edge is λ1 = σ²(1+√q)².
    rq = q**0.5
    λ0 = λ[0]
    λ1 = λ0 * ((1 + rq) / (1 - rq))**2
    σ2 = λ0 / (1 - rq)**2
    # MP Stieltjes transform; the product of principal square roots selects
    # the branch matching Im(z) < 0.
    g_mp = z + σ2 * (q - 1) - (z - λ0)**0.5 * (z - λ1)**0.5
    g_mp /= 2 * q * z * σ2
    Γ = σ2 / λ * _norm2(1 - q + q * z * g_mp)
    # The correction only ever scales ξ up (factor >= 1).
    ψ = ξ * np.maximum(Γ, 1)

    # Rescale so the cleaned spectrum keeps the total of the sample spectrum.
    s = np.sum(λ) / np.sum(ψ)
    return ψ * s, U
if __name__ == "__main__":
    # Synthetic benchmark: sample data whose correlation matrix Q is known.
    # Measure how well each estimator's eigenvalues, paired with the sample
    # eigenvectors W_emp, reconstruct Q.
    rng = np.random.default_rng(0)  # fixed seed: reproducible figure/errors
    show_oas = False   # set True to add the OAS estimator to the plot
    log_scale = False  # set True to plot eigenvalues on a log y-axis

    N, T = 2490, 2500
    q = N / T

    # Ground truth: H.T @ H normalized to unit diagonal (a correlation matrix).
    H = np.identity(N) + 5 * rng.laplace(size=(N, N))
    Q = H.T @ H
    v = np.diag(Q)**0.5
    Q = (1 / v[:, None]) * Q * (1 / v)
    w, W = np.linalg.eigh(Q)

    # T samples with covariance H.T @ H, standardized column-wise so that E
    # is the sample correlation matrix.
    X = rng.standard_normal((T, N)) @ H
    X = (X - X.mean(0)) / X.std(0)
    E = X.T @ X / X.shape[0]
    w_emp, W_emp = np.linalg.eigh(E)

    w_ledoit, _ = np.linalg.eigh(cov.ledoit_wolf(X)[0])
    # rie_estimator re-standardizes X; its eigenvalues are ordered like w_emp.
    w_rie, _ = rie_estimator(X)
    w_rie_s = np.sort(w_rie)

    def error(eigvals):
        """Squared Frobenius distance from Q to W_emp diag(eigvals) W_emp^T, to 2 d.p."""
        # (W_emp * eigvals) scales column k of W_emp by eigvals[k].
        return np.round(np.sum((Q - (W_emp * eigvals) @ W_emp.T)**2), 2)

    # (legend label, eigenvalues, plot kwargs); list order drives the colour cycle.
    curves = [
        ('ground truth', w, {'linewidth': 4}),
        (f'empirical {error(w_emp)}', w_emp, {}),
        (f'ledoit {error(w_ledoit)}', w_ledoit, {}),
        (f'rie {error(w_rie)}', w_rie, {}),
        (f'rie_sorted {error(w_rie_s)}', w_rie_s, {'c': 'k'}),
    ]
    if show_oas:
        w_oas, _ = np.linalg.eigh(cov.oas(X)[0])
        curves.insert(3, (f'oas {error(w_oas)}', w_oas, {}))

    plt.figure(figsize=(12, 7))
    for _, values, style in curves:
        plt.plot(values, **style)
    if log_scale:
        plt.semilogy()
    plt.title(f'N={N}, T={T}, q={q:.3f}')
    plt.legend([label for label, _, _ in curves])
    plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.