Skip to content

Instantly share code, notes, and snippets.

@larsoner
Created October 10, 2018 20:04
Show Gist options
  • Save larsoner/bace255395d1b47b6f88acfaaa203688 to your computer and use it in GitHub Desktop.
Save larsoner/bace255395d1b47b6f88acfaaa203688 to your computer and use it in GitHub Desktop.
# -*- coding: utf-8 -*-
"""
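Test bootstrapping when boosting SNR.

Compare false alarm rates of bootstrap confidence intervals and one-sample
t-tests on pure-noise data, where each "trial" is an average of epochs drawn
either without reuse across trials (independent trials) or with reuse
(trials sharing epochs).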
"""
import numpy as np
from scipy import stats
import matplotlib.pyplot as plt
rng = np.random.RandomState(0)
n_epochs, n_time, n_labels = 40, 384, 10
# Zero-mean noise data: every cell the tests flag is a false alarm
epochs = rng.randn(n_epochs, n_time, n_labels, n_labels)
n_trials = 10  # number of high-SNR samples to produce
n_samples = 1000  # number of bootstrap samples
# One entry per test: True when the trials were built from disjoint epochs
independent = list()
# Epoch indices used to build the trials:
# - no_rep: row i is a disjoint block of n_epochs // n_trials epochs
# - with_rep: row i is a full permutation, so different rows may share epochs
no_rep = np.reshape(np.arange(n_epochs), (n_trials, -1))
with_rep = np.array([rng.permutation(n_epochs) for _ in range(n_trials)])
assert no_rep.shape == (n_trials, n_epochs // n_trials)
assert with_rep.shape == (n_trials, n_epochs)
# Bootstrap resampling indices (with replacement) into the n_trials trials.
# Drawn from the seeded ``rng`` rather than the global np.random state so the
# bootstrap is reproducible.  Drawn after ``with_rep`` so that ``epochs`` and
# ``with_rep`` are the same as before this was seeded.
resampling_inds = rng.choice(n_trials, size=(n_samples, n_trials),
                             replace=True)
# Numbers of epochs averaged into each trial
tests = (1, 2, 4, 10, 20, 25, 30, 35, 39, 40)
# Uncorrected alpha, and alpha Bonferroni-corrected over all tested cells
alphas = np.array((0.05, 0.05 / (n_time * n_labels * n_labels)))
# Two-sided percentile bounds: one (lower, upper) row per alpha
use_bounds = np.column_stack([100 * alphas / 2., 100 - 100 * alphas / 2.])
# p-value thresholds, broadcastable against (n_time, n_labels, n_labels)
use_thresh = alphas[:, np.newaxis, np.newaxis, np.newaxis]
# False alarm rates, indexed as [alpha index, test index]
boot_far = np.empty((len(alphas), len(tests)))
ttest_far = np.empty((len(alphas), len(tests)))
# For every bootstrap sample, the weight of each trial is the number of times
# it was drawn divided by n_trials.  Averaging the resampled trials is then a
# single weighted sum.  This avoids materializing the
# (n_samples, n_trials, n_time, n_labels, n_labels) array of resampled trials,
# which is about 3 GB of float64 with the default sizes.
boot_weights = (resampling_inds[:, :, np.newaxis] ==
                np.arange(n_trials)).sum(axis=1) / float(n_trials)
assert boot_weights.shape == (n_samples, n_trials)
for ni, n_epochs_per_trial in enumerate(tests):
    print('Testing %2d trial(s)/sample' % (n_epochs_per_trial,))
    # Build n_trials trials, each averaging n_epochs_per_trial epochs.
    # When there are enough epochs, every trial uses disjoint epochs
    # (w/o replacement in SNR boosting).  Otherwise epochs are reused across
    # trials (w/replacement in SNR boosting).
    is_independent = n_trials * n_epochs_per_trial <= n_epochs
    independent.append(is_independent)
    picks = no_rep if is_independent else with_rep
    trials = epochs[picks[:, :n_epochs_per_trial]]
    assert trials.shape == (n_trials, n_epochs_per_trial,
                            n_time, n_labels, n_labels)
    trials = trials.mean(axis=1)  # averaged across epochs
    # One-sample t-test against 0: fraction of cells with p < alpha,
    # one value per alpha (use_thresh broadcasts over the cells)
    p_values = stats.ttest_1samp(trials, 0, axis=0)[1]
    ttest_far[:, ni] = (p_values < use_thresh).mean(axis=(-3, -2, -1))
    # Bootstrap distribution of our measure (mean across trials), computed
    # as the weighted sum over the trial axis
    boot = np.tensordot(boot_weights, trials, axes=1)
    assert boot.shape == (n_samples, n_time, n_labels, n_labels)
    for ai, bounds in enumerate(use_bounds):
        ci = np.percentile(boot, bounds, axis=0)
        rej = (ci[0] > 0) | (ci[1] < 0)  # confidence interval excludes 0
        boot_far[ai, ni] = rej.mean()
# Plot the results: false alarm rate vs. number of epochs per trial
fig, ax = plt.subplots(1, figsize=(4, 4))
linestyles = (':', '-')  # one style per alpha, cycled if there are more
for ai, alpha in enumerate(alphas):
    ls = linestyles[ai % len(linestyles)]
    ax.plot(tests, boot_far[ai], color='r', ls=ls,
            label=u'bootstrap (α=%0.1g)' % alpha, zorder=4)
    ax.plot(tests, ttest_far[ai], color='orange', ls=ls,
            label=u't-test (α=%0.1g)' % (alpha), zorder=4)
# Reference line at the nominal 0.05 level
ax.axhline(0.05, color='k', ls='--', zorder=2)
# Raw string: "\m" is an invalid escape sequence in a normal string literal
ax.set(xlabel=r'N$_\mathrm{epochs/trial}$', ylabel='False alarm rate',
       xlim=[tests[0], tests[-1]], ylim=[0, 1])
# Shade the tests whose trials were built from disjoint epochs; ``where``
# expects a boolean array
ax.fill_between(tests, 0, 1, where=np.asarray(independent, dtype=bool),
                color='b', alpha=0.1, zorder=2, lw=0)
for key in ('top', 'right'):
    ax.spines[key].set_visible(False)
ax.legend(loc='upper left', frameon=True, columnspacing=0.1, labelspacing=0.1,
          fontsize=8, fancybox=True, handlelength=2.0)
fig.tight_layout()
fig.savefig('FAR.png')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment