Created October 13, 2016 17:07
-
-
Save larsoner/7d311845d19c90f4de7b6dfd8278087c to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python2 | |
# -*- coding: utf-8 -*- | |
import numpy as np | |
import matplotlib.pyplot as plt | |
from scipy.signal import butter, lfilter | |
from scipy import linalg | |
from scipy.fftpack import fft, ifft | |
import mne | |
from mne.filter import next_fast_len | |
def make_x_xt(ac):
    """Assemble the block-Toeplitz input correlation matrix X X^T.

    Parameters
    ----------
    ac : ndarray, shape (n_ch, n_ch, 2 * len_trf - 1)
        Correlations between every pair of input channels, where
        ``ac[ch0, ch1, len_trf - 1 + k]`` holds the value at lag ``k`` for
        ``k = -(len_trf - 1), ..., len_trf - 1``.

    Returns
    -------
    xxt : ndarray, shape (n_ch * len_trf, n_ch * len_trf)
        Block matrix whose ``(ch0, ch1)`` block (rows
        ``ch0 * len_trf:(ch0 + 1) * len_trf``, columns likewise for ``ch1``)
        has entry ``[r, c] = ac[ch0, ch1, len_trf - 1 + c - r]``.

    Raises
    ------
    ValueError
        If ``ac`` is not 3-D with equal first two dimensions and an odd
        last dimension.
    """
    ac = np.asarray(ac, dtype=np.float64)
    if ac.ndim != 3 or ac.shape[0] != ac.shape[1] or ac.shape[2] % 2 != 1:
        raise ValueError('ac must have shape (n_ch, n_ch, 2 * len_trf - 1), '
                         'got %s' % (ac.shape,))
    # Integer division: "/" would give a float on Python 3
    len_trf = (ac.shape[2] + 1) // 2
    n_ch = ac.shape[0]
    # lag_idx[r, c] = (len_trf - 1) + (c - r): each block entry depends only
    # on the lag difference c - r, so every block is a Toeplitz matrix
    steps = np.arange(len_trf)
    lag_idx = (len_trf - 1) + steps[np.newaxis, :] - steps[:, np.newaxis]
    # blocks[ch0, ch1, r, c]; reorder to (ch0, r, ch1, c) so the reshape lays
    # the channel blocks out along rows (ch0) and columns (ch1)
    blocks = ac[:, :, lag_idx]
    return blocks.transpose(0, 2, 1, 3).reshape(n_ch * len_trf,
                                                n_ch * len_trf)
def trf_corr(x_in, x_out, fs, t_start, t_stop):
    """Compute the correlation terms needed to estimate a TRF.

    Parameters
    ----------
    x_in : ndarray, shape (n_ch_in, n_samples)
        Input signals.
    x_out : ndarray, shape (n_ch_out, n_samples)
        Output signals; must have the same number of samples as ``x_in``.
    fs : float
        Sampling rate in Hz.
    t_start, t_stop : float
        First and last TRF lag in seconds (both inclusive), converted to
        sample lags with ``floor(t * fs)``.

    Returns
    -------
    x_xt : ndarray, shape (n_ch_in * len_trf, n_ch_in * len_trf)
        Input auto/cross-correlation matrix divided by ``n_samples``.
    x_y : ndarray, shape (n_ch_out, n_ch_in * len_trf)
        Output/input cross-correlations at the TRF lags divided by
        ``n_samples``.
    t_trf : ndarray, shape (len_trf,)
        TRF lag times in seconds.

    Raises
    ------
    ValueError
        If ``t_stop <= t_start`` or the signals differ in length.
    """
    if t_stop <= t_start:
        raise ValueError("t_stop must be after t_start")
    trf_start_ind = int(np.floor(t_start * fs))
    trf_stop_ind = int(np.floor(t_stop * fs))
    trf_inds = np.arange(trf_start_ind, trf_stop_ind + 1, dtype=int)
    t_trf = trf_inds / float(fs)
    len_trf = len(t_trf)
    n_ch_in, len_sig = x_in.shape
    n_ch_out = x_out.shape[0]
    if x_out.shape[-1] != len_sig:
        raise ValueError('x_in and x_out must have the same number of '
                         'samples, got %d and %d'
                         % (len_sig, x_out.shape[-1]))
    # Zero-pad so that the circular correlations computed with the FFT equal
    # the linear ones at every lag that is read below. The autocorrelations
    # use lags up to +/-(len_trf - 1); the cross-correlations use lags up to
    # max(|trf_start_ind|, |trf_stop_ind|), which exceeds len_trf - 1
    # whenever the TRF window does not contain lag 0.
    max_lag = max(len_trf - 1, abs(trf_start_ind), abs(trf_stop_ind))
    n_fft = next_fast_len(len_sig + max_lag)
    # Eventually we could use rfft/irfft for these operations,
    # but the memory layout is really annoying for doing X * Y.conj()...
    x_in_fft = fft(x_in, n_fft)
    x_out_fft = fft(x_out, n_fft)
    # compute the autocorrelations at lags -(len_trf - 1) ... len_trf - 1;
    # negative integer indices pick negative lags from the end of the
    # circular result (this also works for len_trf == 1)
    ac_lags = np.arange(-(len_trf - 1), len_trf)
    ac = np.zeros((n_ch_in, n_ch_in, len_trf * 2 - 1))
    for ch0 in range(n_ch_in):
        for ch1 in range(ch0, n_ch_in):
            ac_temp = np.real(ifft(x_in_fft[ch0] * np.conj(x_in_fft[ch1])))
            ac[ch0, ch1] = ac_temp[ac_lags]
            if ch0 != ch1:
                # the (ch1, ch0) correlation is the (ch0, ch1) one time-reversed
                ac[ch1, ch0] = ac[ch0, ch1][::-1]
    # compute the crosscorrelations at exactly the TRF lags (any signs)
    cc = np.zeros((n_ch_out, n_ch_in, len_trf))
    for ch_in in range(n_ch_in):
        for ch_out in range(n_ch_out):
            cc_temp = np.real(ifft(x_out_fft[ch_out] *
                                   np.conj(x_in_fft[ch_in])))
            cc[ch_out, ch_in] = cc_temp[trf_inds]
    # make xxt and xy
    x_xt = make_x_xt(ac) / len_sig
    x_y = cc.reshape([n_ch_out, n_ch_in * len_trf]) / len_sig
    return x_xt, x_y, t_trf
def trf_reg(x_xt, x_y, n_ch_in, lambda_=0, reg_type='ridge'):
    """Solve the regularized normal equations for the TRF weights.

    Solves ``(x_xt + lambda_ * R) w.T = x_y.T`` in the least-squares sense,
    where ``R`` is the regularization matrix selected by ``reg_type``.

    Parameters
    ----------
    x_xt : ndarray, shape (n_ch_in * len_trf, n_ch_in * len_trf)
        Input correlation matrix.
    x_y : ndarray, shape (n_ch_out, n_ch_in * len_trf)
        Output/input cross-correlations.
    n_ch_in : int
        Number of input channels; ``len_trf = x_y.shape[1] // n_ch_in``.
    lambda_ : float
        Regularization strength.
    reg_type : 'ridge' | 'laplacian'
        ``'ridge'`` uses the identity matrix. ``'laplacian'`` uses
        ``D^T D`` per input channel (block diagonal), where ``D`` takes
        first differences between adjacent lags.

    Returns
    -------
    w : ndarray, shape (n_ch_out, n_ch_in, len_trf)
        Estimated TRF weights.

    Raises
    ------
    ValueError
        If ``reg_type`` is unknown or ``x_y.shape[1]`` is not a multiple of
        ``n_ch_in``.
    """
    n_ch_out, n_coef = x_y.shape
    if n_coef % n_ch_in != 0:
        raise ValueError('x_y.shape[1] (%d) must be a multiple of n_ch_in '
                         '(%d)' % (n_coef, n_ch_in))
    # Integer division: "/" would give a float on Python 3
    len_trf = n_coef // n_ch_in
    if reg_type == 'ridge':
        reg = np.eye(x_xt.shape[0])
    elif reg_type == 'laplacian':
        # D^T D has 1, 2, ..., 2, 1 on the diagonal and -1 beside it; it is
        # all zeros when len_trf == 1. Channels are not coupled (block diag).
        diff = np.diff(np.eye(len_trf), axis=0)
        reg = np.kron(np.eye(n_ch_in), diff.T.dot(diff))
    else:
        raise ValueError("reg_type must be either 'ridge' or 'laplacian'")
    mat = x_xt + lambda_ * reg
    w = linalg.lstsq(mat, x_y.T)[0].T
    return w.reshape([n_ch_out, n_ch_in, len_trf])
###############################################################################
# Build a synthetic demo: known low-pass filters map the inputs to the outputs
rng = np.random.RandomState(0)

# signal parameters
fs = 200
n_ch_in = 2  # e.g., number of audio sources in
n_ch_out = 5  # e.g., number of electrodes out
len_sig = 120 * fs  # 120 seconds worth of samples

# trf parameters (seconds)
trf_start = -100e-3
trf_stop = 300e-3

# Independent noise per input plus one component shared by every input, so
# the input channels are correlated with one another
x_in = rng.randn(n_ch_in, len_sig)
x_in = x_in + rng.randn(1, len_sig)

# One 8th-order Butterworth low-pass per input channel; the normalized
# cutoff (ch + 1) / (3 * n_ch_in) grows with the channel index
filters = [butter(8, (ch + 1.) / (3. * n_ch_in), 'lowpass')
           for ch in range(n_ch_in)]

# Unit impulse at lag 0 of the TRF window, passed through each filter: this
# is the per-input response the outputs are built from
n_lags = int(round((trf_stop - trf_start) * fs)) + 1
impulse = np.zeros(n_lags)
impulse[-int(round(trf_start * fs))] = 1.
ideal = np.array([lfilter(b, a, impulse) for b, a in filters])

# Continuously filter each input signal with its own filter
x_in_filt = np.array([lfilter(b, a, row)
                      for (b, a), row in zip(filters, x_in)])

# Random mixing into the outputs, plus a constant offset and noise
w_in_out = rng.randn(n_ch_out, n_ch_in)
noise = rng.randn(n_ch_out, len_sig)
x_out = w_in_out.dot(x_in_filt) + 1 + noise

# Ground truth to be recovered: entry [o, i] is w_in_out[o, i] times the
# filtered impulse of input i
ideal = w_in_out[:, :, np.newaxis] * ideal[np.newaxis, :, :]
###############################################################################
# Estimate the TRF by solving the regularized normal equations
#     w = (X X^T + lambda * R) \ (X y)

# Correlation terms X X^T and X y (each divided by the number of samples)
x_xt, x_y, t_trf = trf_corr(x_in, x_out, fs,
                            t_start=trf_start, t_stop=trf_stop)

# Laplacian (first-difference) regularization with lambda = 0.1
w = trf_reg(x_xt, x_y, n_ch_in,
            lambda_=1e-1, reg_type='laplacian')
###############################################################################
# Plot the results: for each output channel, the ideal TRF of every input
# (solid, translucent) against the estimate (dashed), in matching colors
fig, axes = plt.subplots(n_ch_out, figsize=(3, 8), squeeze=False)
axes = axes[:, 0]  # always a 1-D array of Axes, even when n_ch_out == 1
colors = plt.rcParams['axes.prop_cycle'].by_key()['color']
# Shared y-limits so panels are comparable; include the ideal curves so they
# are not clipped when they exceed the estimate's range
y_lim = [min(w.min(), ideal.min()), max(w.max(), ideal.max())]
labels = ['Signal %d' % si for si in range(n_ch_in)]
for ai, ax in enumerate(axes):
    h = list()
    for si in range(n_ch_in):
        color = colors[si % len(colors)]  # wrap if inputs outnumber colors
        h.append(ax.plot(t_trf, ideal[ai, si], color=color, alpha=0.5,
                         linewidth=2)[0])
        ax.plot(t_trf, w[ai, si], color=color, linestyle='--',
                dashes=(10, 5))
    ax.set(xlim=t_trf[[0, -1]], ylim=y_lim, ylabel='Channel %d' % ai)
    if ai == n_ch_out // 2:
        # one legend, on the middle panel only
        leg = ax.legend(h, labels, loc='upper right')
        leg.set_frame_on(False)
axes[-1].set(xlabel='Time (sec)')
mne.viz.tight_layout()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.