Skip to content

Instantly share code, notes, and snippets.

@sylvchev
Last active April 15, 2020 08:28
Show Gist options
  • Save sylvchev/6c4e10a45417d67214f1204745895d29 to your computer and use it in GitHub Desktop.
Save sylvchev/6c4e10a45417d67214f1204745895d29 to your computer and use it in GitHub Desktop.
Verification of weighted Riemannian means of SPD (symmetric positive-definite) matrices: is weighting the mean equivalent to scaling the input matrices?
import pyriemann
import numpy as np
from pyriemann.utils.mean import mean_riemann
def generate_cov(Nt, Ne, random_state=1234):
    """Generate a reproducible set of covariance matrices for test purposes.

    Every matrix is built as ``A @ diag(d_i) @ A.T``. ``A`` is a single
    random mixing matrix shared by all samples, with each row normalized
    to unit Euclidean norm. ``d_i`` is a per-sample diagonal drawn
    around 2.0.

    Parameters
    ----------
    Nt : int
        Number of matrices to generate.
    Ne : int
        Size of each square matrix (matrices are ``Ne x Ne``).
    random_state : int, default 1234
        Seed for ``np.random.RandomState``. The default reproduces the
        seed that was previously hard-coded.

    Returns
    -------
    covmats : ndarray, shape (Nt, Ne, Ne)
        The generated matrices.
    diags : ndarray, shape (Nt, Ne)
        The diagonal used for each matrix.
    A : ndarray, shape (Ne, Ne)
        The shared mixing matrix (unit-norm rows).
    """
    rs = np.random.RandomState(random_state)
    # Draw the diagonals before A so the random stream (and therefore the
    # output) matches the original implementation for the same seed.
    # Values ~ N(2, 0.1^2) are positive in practice, so each matrix is
    # positive definite as long as A is nonsingular (almost surely).
    diags = 2.0 + 0.1 * rs.randn(Nt, Ne)
    A = 2 * rs.rand(Ne, Ne) - 1
    # Scale every row of A to unit L2 norm.
    A /= np.linalg.norm(A, axis=1, keepdims=True)
    # (A * d[n, None, :]) equals A @ diag(d[n]) for every n at once;
    # the matmul then broadcasts A.T across the Nt stacked matrices.
    covmats = (A * diags[:, np.newaxis, :]) @ A.T
    return covmats, diags, A
# Compare two things: a weighted Riemannian mean of `covs`, and the
# unweighted Riemannian mean of the same matrices each pre-multiplied by
# its weight. The question is whether scaling can stand in for weighting.
# NOTE(review): the affine-invariant Riemannian mean is jointly
# homogeneous, i.e. mean(a_i * X_i) = (prod a_i)^(1/n) * mean(X_i).
# So mean2 is expected NOT to match mean1, and mean_equi2 is expected to
# be about 0.25 * mean_equi. Confirm against the printed output.
covs, diags, A = generate_cov(4, 5)

# Case 1: non-uniform weights, one per matrix.
weights = np.array([0.1, 0.2, 0.3, 0.4])
mean1 = mean_riemann(covs, sample_weight=weights)
# Scale matrix i by weights[i], broadcasting over the two matrix axes.
covs2 = covs * weights[:, np.newaxis, np.newaxis]
mean2 = mean_riemann(covs2)
print("weighted mean, weights", weights, ":\n", mean1)
print("unweighted mean of weight-scaled matrices:\n", mean2)

# Case 2: uniform weights (equivalent to the plain unweighted mean).
equi_weights = np.full(covs.shape[0], 0.25)
mean_equi = mean_riemann(covs, sample_weight=equi_weights)
covs3 = covs * equi_weights[:, np.newaxis, np.newaxis]
mean_equi2 = mean_riemann(covs3)
print("weighted mean, uniform weights 0.25:\n", mean_equi)
print("unweighted mean of matrices scaled by 0.25:\n", mean_equi2)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment