Last active
April 22, 2021 20:45
-
-
Save larsoner/6dcf350225a682d0a1b7c8dab8987c4e to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# -*- coding: utf-8 -*-
# Simulate two coherent oscillators (one vertex per hemisphere of the MNE
# "sample" subject), build DICS beamformers with different orientation
# options, and report how far each filter's max-power orientation lies from
# the true (surface-normal) orientation at the simulated vertices.
import faulthandler
import os.path as op

import numpy as np
from scipy.signal import welch, coherence, unit_impulse
from matplotlib import pyplot as plt

import mne
from mne.simulation import simulate_raw, add_noise
from mne.datasets import sample
from mne.time_frequency import csd_morlet
from mne.beamformer import make_dics, apply_dics_csd

# Print Python tracebacks on fatal errors (e.g. segfaults in compiled code).
faulthandler.enable()

# --- Input files from the MNE "sample" dataset (download=False: not fetched)
data_path = sample.data_path(download=False)
subjects_dir = op.join(data_path, 'subjects')
meg_path = op.join(data_path, 'MEG', 'sample')
raw_fname, fwd_fname, cov_fname = (
    op.join(meg_path, fname) for fname in (
        'sample_audvis_raw.fif',
        'sample_audvis-meg-eeg-oct-6-fwd.fif',
        'sample_audvis-cov.fif'))
fwd = mne.read_forward_solution(fwd_fname)

# Single seeded stream shared by every random draw in the script.
rand = np.random.RandomState(42)

# --- Simulation parameters
sfreq = 50.  # Sampling frequency of the generated signal (Hz)
n_samp = round(10. * sfreq)  # number of samples in 10 seconds of signal
times = np.arange(n_samp) / sfreq
n_times = times.shape[0]
t_rand = 0.001  # Variation in the instantaneous frequency of the signal
std = 0.1  # Std-dev of the random fluctuations added to the signal
snr = 10.  # Signal-to-noise ratio. Decrease to add more noise.
phase = 45  # relative phase of the second signal (in degrees)

# One vertex per hemisphere, given as indices into each hemisphere's
# fwd['src'][i]['vertno'] (the MNE example used [[3554], [635]]).
vertidx = [[1000], [2000]]

# Logged results: angle (degrees) between each filter's max-power orientation
# and the true normal for the [lh, rh] vertex, as printed by the comparison
# loop. NOTE(review): these appear to come from earlier runs using the
# real_ori variants ('True', 'False', 'svd_complex', 'svd') plus the
# real_filter=True filter; the complex_filter key has no logged results.
#
# phase=0:
#  real_filter: [0.1 0.4]
#         True: [3.3 5.1]
#        False: [3.3 9.7]
#  svd_complex: [9.4 9. ]
#          svd: [9.4 9.3]
#
# phase=45:
#  real_filter: [0.1 0.4]
#         True: [3.  8.1]
#        False: [4.  9. ]
#  svd_complex: [ 8.  11.4]
#          svd: [ 8.2 11.5]
#
# phase=90:
#  real_filter: [0.  0.3]
#         True: [1.1 7.3]
#        False: [4.2 7.3]
#  svd_complex: [ 5.2 10.3]
#          svd: [ 5.5 10.4]
#
# phase=180:
#  real_filter: [0.4 0.2]
#         True: [2.7 1.2]
#        False: [4.8 2.9]
#  svd_complex: [1.7 5.2]
#          svd: [2.  5.2]
def coh_signal_gen(phase=0, base_freq=10.):
    """Generate an oscillating signal.

    The oscillator's instantaneous frequency jitters around ``base_freq``:
    Gaussian steps of size ``t_rand`` are accumulated into a random-walk
    phase. Independent Gaussian noise with standard deviation ``std`` is
    then added before the result is scaled to MEG source magnitudes.

    Parameters
    ----------
    phase : float
        Phase offset of the oscillator in degrees. Defaults to 0.
    base_freq : float
        Base frequency of the oscillator in Hz. Defaults to 10.

    Returns
    -------
    signal : ndarray, shape (len(times),)
        The generated signal in Am (on the order of 100 nAm).

    Notes
    -----
    Reads the module-level ``times``, ``sfreq``, ``t_rand``, ``std`` and
    ``rand``. Every call draws from the shared ``rand`` stream, so the
    order of calls changes the output.
    """
    n_times = len(times)
    samples = np.arange(n_times)
    # Random walk (in cycles) so the instantaneous frequency drifts slightly.
    jitter = np.cumsum(t_rand * rand.randn(n_times))
    signal = np.sin(2.0 * np.pi * (base_freq * samples / sfreq + jitter) +
                    np.deg2rad(phase))
    # Add some random fluctuations to the signal.
    signal += std * rand.randn(n_times)
    # Scale the signal to be in the right order of magnitude (~100 nAm)
    # for MEG data.
    signal *= 100e-9
    return signal
# Two realisations of the oscillator; the second is shifted by `phase` degrees.
signal1 = coh_signal_gen()
signal2 = coh_signal_gen(phase=phase)

fig, axes = plt.subplots(2, 2, figsize=(8, 4))
# Plot the timeseries. The signals are in Am, so 1e9 * signal is in nAm.
ax = axes[0][0]
ax.plot(times, 1e9 * signal1, lw=0.5)
ax.set(xlabel='Time (s)', xlim=times[[0, -1]], ylabel='Amplitude (nAm)',
       title='Signal 1')
ax = axes[0][1]
ax.plot(times, 1e9 * signal2, lw=0.5)
ax.set(xlabel='Time (s)', xlim=times[[0, -1]], title='Signal 2')

# Power spectrum of the first timeseries. welch() returns a power spectral
# density, so its decibel value is 10 * log10 (20 * log10 is for amplitudes).
n_plot = 100  # only plot the first 100 frequency bins
f, p = welch(signal1, fs=sfreq, nperseg=128, nfft=256)
ax = axes[1][0]
ax.plot(f[:n_plot], 10 * np.log10(p[:n_plot]), lw=1.)
ax.set(xlabel='Frequency (Hz)', xlim=f[[0, n_plot - 1]],
       ylabel='Power (dB)', title='Power spectrum of signal 1')

# Compute the coherence between the two timeseries
f, coh = coherence(signal1, signal2, fs=sfreq, nperseg=100, noverlap=64)
ax = axes[1][1]
ax.plot(f[:50], coh[:50], lw=1.)
ax.set(xlabel='Frequency (Hz)', xlim=f[[0, 49]], ylabel='Coherence',
       title='Coherence between the timeseries')
fig.tight_layout()
# Place signal1 / signal2 at the chosen vertex of the left / right hemisphere.
vertices = [s['vertno'][v] for s, v in zip(fwd['src'], vertidx)]
data = np.vstack((signal1, signal2))
stc_signal = mne.SourceEstimate(
    data, vertices, tmin=0, tstep=1. / sfreq, subject='sample')
# Same vertices and timing, but with zero source activity.
stc_noise = stc_signal * 0.

# Gradiometer-only measurement info at the simulation sampling rate; the stim
# channel is kept so the simulated events can be found again below.
info = mne.io.read_info(raw_fname)
info.update(sfreq=sfreq, bads=[])
picks = mne.pick_types(info, meg='grad', stim=True, exclude=())
mne.pick_info(info, picks, copy=False)
cov = mne.cov.make_ad_hoc_cov(info)
cov['data'] *= (20. / snr) ** 2  # Scale the noise to achieve the desired SNR

# The (stc, stim) pairs are stacked in time: the first block carries the
# signal and an impulse of value 1 on its first sample, the second block has
# zero source activity and an impulse of value 2.
stcs = [(stc_signal, unit_impulse(n_samp, dtype=int) * 1),
        (stc_noise, unit_impulse(n_samp, dtype=int) * 2)]
raw = simulate_raw(info, stcs, forward=fwd)
add_noise(raw, cov, iir_filter=[4, -4, 0.8], random_state=rand)

# initial_event=True because the first impulse sits on the very first sample.
events = mne.find_events(raw, initial_event=True)
tmax = (len(stc_signal.times) - 1) / sfreq  # each epoch spans one full block
epochs = mne.Epochs(raw, events, event_id=dict(signal=1, noise=2),
                    tmin=0, tmax=tmax, baseline=None, preload=True)
# Explicit check rather than `assert` so it still runs under `python -O`.
if len(epochs) != 2:
    raise RuntimeError(
        f'Expected 2 epochs (signal + noise), got {len(epochs)}')
# Cross-spectral density of the signal epoch at the 10 Hz oscillator frequency.
csd_signal = csd_morlet(epochs['signal'], frequencies=[10])

# One DICS beamformer per orientation variant; every other option is shared.
# Keys are the labels printed by the orientation comparison.
kwargs = {
    'reg': 0.05,
    'pick_ori': 'max-power',
    'depth': None,
    'inversion': 'matrix',
    'weight_norm': 'unit-noise-gain',
}
filters = {}
for label, use_real in (('real_filter', True), ('complex_filter', False)):
    filters[label] = make_dics(
        info, fwd, csd_signal, real_filter=use_real, **kwargs)
# NOTE(review): earlier runs also compared these variants, presumably with a
# development make_dics that accepted a `real_ori` argument:
# for label in (True, False, 'svd_complex', 'svd'):
#     filters[str(label)] = make_dics(
#         info, fwd, csd_signal, real_ori=label, **kwargs)
# Fixed, surface-oriented forward: source_nn then holds one cortical normal
# per source, i.e. the true orientation of each simulated dipole.
mne.convert_forward_solution(fwd, surf_ori=True, force_fixed=True, copy=False)
# Row of each simulated vertex in the concatenated (lh, rh) source list.
offsets = [0, fwd['src'][0]['nuse']]
idx = np.concatenate([offset + np.searchsorted(s['vertno'], v)
                      for offset, s, v in zip(offsets, fwd['src'], vertices)])
nn = fwd['source_nn'][idx]
for title, filt in filters.items():
    # Max-power orientation at the simulated vertices for the first (only)
    # frequency. NOTE(review): assumes max_power_ori is indexed as
    # (freq, source, xyz) -- confirm against make_dics.
    got_nn = filt['max_power_ori'][0, idx].real
    # The real part of a complex orientation need not be a unit vector, which
    # would inflate the angle; renormalize before taking the dot product.
    got_nn = got_nn / np.linalg.norm(got_nn, axis=1, keepdims=True)
    # abs(): orientations are only defined up to sign. clip(): keeps
    # rounding from pushing the cosine above 1 (arccos would return NaN).
    cos = np.clip(np.abs(np.sum(got_nn * nn, axis=1)), 0., 1.)
    angles = np.rad2deg(np.arccos(cos))
    print(f'{title.rjust(11)}: {np.round(angles, 1)}')
    # DICS power map with the simulated vertices marked in blue.
    power, f = apply_dics_csd(csd_signal, filt)
    brain = power.plot(
        'sample', subjects_dir=subjects_dir, hemi='both',
        size=600, time_label=title, title=title)
    brain.add_foci(vertices[0], coords_as_verts=True, hemi='lh', color='b')
    brain.add_foci(vertices[1], coords_as_verts=True, hemi='rh', color='b')
    # Rotate the view.
    brain.show_view(view={'azimuth': 0, 'elevation': 0, 'distance': 550,
                          'focalpoint': [0, 0, 0]})
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment