Skip to content

Instantly share code, notes, and snippets.

@larsoner
Created October 13, 2016 17:07
Show Gist options
  • Save larsoner/7d311845d19c90f4de7b6dfd8278087c to your computer and use it in GitHub Desktop.
Save larsoner/7d311845d19c90f4de7b6dfd8278087c to your computer and use it in GitHub Desktop.
#!/usr/bin/env python2
# -*- coding: utf-8 -*-
import numpy as np
import matplotlib.pyplot as plt
from scipy.signal import butter, lfilter
from scipy import linalg
from scipy.fftpack import fft, ifft
import mne
from mne.filter import next_fast_len
def make_x_xt(ac):
    """Assemble the block-Toeplitz input covariance matrix from lagged values.

    Parameters
    ----------
    ac : ndarray, shape (n_ch, n_ch, 2 * len_trf - 1)
        Lagged correlation values between channel pairs. ``ac[ch0, ch1, k]``
        is the value at lag ``k - (len_trf - 1)``, so the centre sample of
        the last axis is lag zero.

    Returns
    -------
    xxt : ndarray, shape (n_ch * len_trf, n_ch * len_trf)
        Block matrix whose ``(ch0, ch1)`` block (size ``len_trf``) has entry
        ``[i, j] = ac[ch0, ch1, len_trf - 1 + j - i]``.

    Raises
    ------
    ValueError
        If the lag axis of ``ac`` does not have odd length.
    """
    ac = np.asarray(ac, dtype=float)
    n_ch, n_lags = ac.shape[0], ac.shape[2]
    if n_lags % 2 != 1:
        raise ValueError('ac.shape[2] must be odd (2 * len_trf - 1), '
                         'got %d' % n_lags)
    # Floor division: under Python 3 "/" would give a float and break shapes.
    len_trf = (n_lags + 1) // 2
    # Lag index for every (row, column) pair of one Toeplitz block.
    steps = np.arange(len_trf)
    lag_idx = (len_trf - 1) + steps[np.newaxis, :] - steps[:, np.newaxis]
    # ac[:, :, lag_idx] is indexed [ch0, ch1, i, j]; reorder to
    # [ch0, i, ch1, j] so a row-major reshape lays out the blocks.
    blocks = ac[:, :, lag_idx].transpose(0, 2, 1, 3)
    return blocks.reshape(n_ch * len_trf, n_ch * len_trf)
def trf_corr(x_in, x_out, fs, t_start, t_stop):
    """Compute the covariance terms needed to estimate a TRF.

    Parameters
    ----------
    x_in : ndarray, shape (n_ch_in, len_sig)
        Input signals.
    x_out : ndarray, shape (n_ch_out, len_sig)
        Output signals.
    fs : float
        Sampling rate in Hz.
    t_start, t_stop : float
        First and last TRF lag in seconds; each is floored to a sample index.

    Returns
    -------
    x_xt : ndarray, shape (n_ch_in * len_trf, n_ch_in * len_trf)
        Lag-expanded input covariance (built by ``make_x_xt``), divided by
        ``len_sig``.
    x_y : ndarray, shape (n_ch_out, n_ch_in * len_trf)
        Output/input cross-correlations at the TRF lags, divided by
        ``len_sig``.
    t_trf : ndarray, shape (len_trf,)
        TRF lag times in seconds.

    Raises
    ------
    ValueError
        If ``t_stop <= t_start``.
    """
    if t_stop <= t_start:
        raise ValueError("t_stop must be after t_start")
    trf_start_ind = int(np.floor(t_start * fs))
    trf_stop_ind = int(np.floor(t_stop * fs))
    trf_inds = np.arange(trf_start_ind, trf_stop_ind + 1, dtype=int)
    t_trf = trf_inds / float(fs)
    len_trf = len(t_trf)
    n_ch_in, len_sig = x_in.shape
    n_ch_out = x_out.shape[0]
    # Signed sample lags of the input auto-correlations: -(L-1) .. L-1.
    ac_lags = np.arange(-(len_trf - 1), len_trf)
    # Pad so the circular (FFT) correlation equals the linear one for every
    # lag we read back: n_fft >= signal length + largest |lag| requested.
    max_lag = max(len_trf - 1, abs(trf_start_ind), abs(trf_stop_ind))
    n_fft = next_fast_len(max(len_sig, x_out.shape[-1]) + max_lag)
    # Eventually we could use rfft/irfft for these operations,
    # but the memory layout is really annoying for doing X * Y.conj()...
    x_in_fft = fft(x_in, n_fft)
    x_out_fft = fft(x_out, n_fft)
    # Input auto-correlations. Negative lags live at the end of the circular
    # result, and numpy's negative fancy indices wrap to exactly there.
    ac = np.zeros((n_ch_in, n_ch_in, len_trf * 2 - 1))
    for ch0 in range(n_ch_in):
        for ch1 in range(ch0, n_ch_in):
            ac_full = np.real(ifft(x_in_fft[ch0] * np.conj(x_in_fft[ch1])))
            ac[ch0, ch1] = ac_full[ac_lags]
            if ch0 != ch1:
                # The swapped pair is the same sequence mirrored in lag.
                ac[ch1, ch0] = ac[ch0, ch1][::-1]
    # Output/input cross-correlations at the requested TRF lags.
    cc = np.zeros((n_ch_out, n_ch_in, len_trf))
    for ch_in in range(n_ch_in):
        for ch_out in range(n_ch_out):
            cc_full = np.real(ifft(x_out_fft[ch_out] *
                                   np.conj(x_in_fft[ch_in])))
            cc[ch_out, ch_in] = cc_full[trf_inds]
    # make xxt and xy
    x_xt = make_x_xt(ac) / len_sig
    x_y = cc.reshape([n_ch_out, n_ch_in * len_trf]) / len_sig
    return x_xt, x_y, t_trf
def trf_reg(x_xt, x_y, n_ch_in, lambda_=0, reg_type='ridge'):
    """Solve the regularized normal equations for the TRF weights.

    Solves ``(x_xt + lambda_ * reg) @ w_flat.T = x_y.T`` in the
    least-squares sense and reshapes the solution per channel and lag.

    Parameters
    ----------
    x_xt : ndarray, shape (n_ch_in * len_trf, n_ch_in * len_trf)
        Lag-expanded input covariance.
    x_y : ndarray, shape (n_ch_out, n_ch_in * len_trf)
        Output/input cross-covariance.
    n_ch_in : int
        Number of input channels (``len_trf`` is inferred from ``x_y``).
    lambda_ : float
        Regularization strength.
    reg_type : {'ridge', 'laplacian'}
        ``'ridge'`` uses the identity (penalizes squared weights).
        ``'laplacian'`` uses ``D.T @ D`` per input channel, where ``D`` is
        the first-difference operator over lags. It penalizes squared
        differences between adjacent lags, never across channels.

    Returns
    -------
    w : ndarray, shape (n_ch_out, n_ch_in, len_trf)

    Raises
    ------
    ValueError
        If ``reg_type`` is not recognized.
    """
    n_ch_out = x_y.shape[0]
    # Floor division keeps len_trf an int under Python 3.
    len_trf = x_y.shape[1] // n_ch_in
    if reg_type == 'ridge':
        reg = np.eye(x_xt.shape[0])
    elif reg_type == 'laplacian':
        # D.T @ D is tridiagonal with diagonal [1, 2, ..., 2, 1] and -1 on
        # the off-diagonals. For len_trf == 1 it is correctly all zeros.
        diff_op = np.diff(np.eye(len_trf), axis=0)
        reg = np.kron(np.eye(n_ch_in), diff_op.T.dot(diff_op))
    else:
        raise ValueError("reg_type must be either 'ridge' or 'laplacian'")
    mat = x_xt + lambda_ * reg
    w = linalg.lstsq(mat, x_y.T)[0].T
    return w.reshape([n_ch_out, n_ch_in, len_trf])
###############################################################################
# First make a demo input and output signal
rng = np.random.RandomState(0)
# signal parameters
fs = 200
n_ch_in = 2  # e.g., number of audio sources in
n_ch_out = 5  # e.g., number of electrodes out
len_sig = fs * 120  # 120 s of samples
# trf parameters (seconds)
trf_start = -100e-3
trf_stop = 300e-3
# make the signals with some correlations: independent noise per channel
# plus one noise component shared by all input channels
x_in = rng.randn(n_ch_in, len_sig) + rng.randn(1, len_sig)
# Build the ideal kernel on the same lag grid trf_corr uses (each endpoint
# floored to a sample), so its length and lag-zero index match t_trf.
trf_start_ind = int(np.floor(trf_start * fs))
trf_stop_ind = int(np.floor(trf_stop * fs))
ideal = np.zeros((n_ch_in, trf_stop_ind - trf_start_ind + 1))
ideal[:, -trf_start_ind] = 1.  # unit impulse at lag zero
# continuously mix the signals: each input channel gets its own low-pass,
# applied identically to the impulse and to the signal
x_in_filt = np.copy(x_in)
for ch in range(n_ch_in):
    ba = butter(8, (ch + 1.) / (3. * n_ch_in), 'lowpass')
    ideal[ch] = lfilter(ba[0], ba[1], ideal[ch])
    x_in_filt[ch] = lfilter(ba[0], ba[1], x_in[ch])
w_in_out = rng.randn(n_ch_out, n_ch_in)
# random channel mixing, a constant offset of 1, and additive noise
x_out = np.dot(w_in_out, x_in_filt) + 1 + rng.randn(n_ch_out, len_sig)
# this is what ideally would be recovered
ideal = w_in_out[..., np.newaxis] * ideal[np.newaxis]
###############################################################################
# Estimate the TRF by solving  w = (X X^T + lambda * reg) \ (X y).
# Step 1: lag-expanded input covariance X X^T and input/output
# cross-covariance X y, plus the lag times of the TRF.
x_xt, x_y, t_trf = trf_corr(x_in, x_out, fs,
                            t_start=trf_start, t_stop=trf_stop)
# Step 2: regularized inverse with a Laplacian smoothness penalty.
w = trf_reg(x_xt, x_y, n_ch_in, lambda_=1e-1, reg_type='laplacian')
###############################################################################
# Plot the results: ideal kernels (solid, translucent) vs. estimates (dashed)
fig, axes = plt.subplots(n_ch_out, figsize=(3, 8), squeeze=False)
axes = axes[:, 0]  # always a 1-D array of Axes, even when n_ch_out == 1
colors = plt.rcParams['axes.prop_cycle'].by_key()['color']
y_lim = [w.min(), w.max()]  # shared y-range across all panels
signal_labels = ['Signal %d' % si for si in range(n_ch_in)]
for ai, ax in enumerate(axes):
    handles = list()
    for si in range(n_ch_in):
        # wrap around the colour cycle if there are many inputs
        color = colors[si % len(colors)]
        handles.append(ax.plot(
            t_trf, ideal[ai, si], color=color, alpha=0.5, linewidth=2)[0])
        ax.plot(t_trf, w[ai, si], color=color, linestyle='--',
                dashes=(10, 5))
    ax.set(xlim=t_trf[[0, -1]], ylim=y_lim, ylabel='Channel %d' % ai)
    if ai == n_ch_out // 2:
        # one legend, on the middle panel
        leg = ax.legend(handles, signal_labels, loc='upper right')
        leg.set_frame_on(False)
axes[-1].set(xlabel='Time (sec)')
mne.viz.tight_layout()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment