Created
October 10, 2018 20:04
-
-
Save larsoner/bace255395d1b47b6f88acfaaa203688 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# -*- coding: utf-8 -*-
"""
Test bootstrapping when boosting SNR.

Pure-noise "epochs" are averaged in groups to form higher-SNR "trials".
For several group sizes, the false alarm rate (fraction of tests that
reject the true null hypothesis of zero mean) is measured for a
one-sample t-test and for a percentile-bootstrap confidence interval,
then plotted and saved to ``FAR.png``.
"""
import numpy as np
from scipy import stats
import matplotlib.pyplot as plt

# One seeded generator drives *all* randomness so runs are reproducible.
rng = np.random.RandomState(0)

n_epochs, n_time, n_labels = 40, 384, 10
epochs = rng.randn(n_epochs, n_time, n_labels, n_labels)
n_trials = 10  # number of high-SNR samples to produce
n_samples = 1000  # number of bootstrap samples

# Row i holds the trial indices (drawn with replacement) that make up
# bootstrap sample i.  Drawn from ``rng`` rather than the global
# ``np.random`` state, which would ignore the seed set above.
resampling_inds = rng.choice(n_trials, size=(n_samples, n_trials),
                             replace=True)

# One flag per entry of ``tests``: True when the trials were built from
# disjoint sets of epochs.
independent = list()

# Epoch indices used to build each trial:
# - no_rep: disjoint consecutive groups, so no epoch is shared by trials
# - with_rep: an independent permutation per trial, so trials can share
#   epochs
no_rep = np.reshape(np.arange(n_epochs), (n_trials, -1))
with_rep = np.array([rng.permutation(n_epochs) for _ in range(n_trials)])
assert no_rep.shape == (n_trials, n_epochs // n_trials)
assert with_rep.shape == (n_trials, n_epochs)

# Numbers of epochs averaged into each trial
tests = (1, 2, 4, 10, 20, 25, 30, 35, 39, 40)

# Uncorrected alpha, and alpha divided by the total number of tests
alphas = np.array((0.05, 0.05 / (n_time * n_labels * n_labels)))
# One row per alpha: (lower, upper) percentiles of the two-sided CI
use_bounds = np.reshape([100 * alphas / 2., 100 - 100 * alphas / 2.],
                        [2, 2]).T
# Shaped to broadcast against (n_time, n_labels, n_labels) p-values
use_thresh = alphas[:, np.newaxis, np.newaxis, np.newaxis]

boot_far = np.empty((len(alphas), len(tests)))
ttest_far = np.empty((len(alphas), len(tests)))
for ni, n_epochs_per_trial in enumerate(tests):
    print('Testing %2d trial(s)/sample' % (n_epochs_per_trial,))
    if n_trials * n_epochs_per_trial <= n_epochs:
        independent.append(True)  # w/o replacement in SNR boosting
        trials = epochs[no_rep[:, :n_epochs_per_trial]]
    else:
        independent.append(False)  # w/replacement in SNR boosting
        trials = epochs[with_rep[:, :n_epochs_per_trial]]
    assert trials.shape == (n_trials, n_epochs_per_trial,
                            n_time, n_labels, n_labels)
    trials = trials.mean(axis=1)  # averaged across epochs

    # t-test across trials; FAR is the fraction of p-values below alpha
    pvals = stats.ttest_1samp(trials, 0, axis=0)[1]
    ttest_far[:, ni] = (pvals < use_thresh).mean(-1).mean(-1).mean(-1)

    # Our bootstrap samples are the resampled trials
    # NOTE: this array has n_samples * n_trials * n_time * n_labels ** 2
    # float64 entries (about 3 GB with the sizes above).
    boot = trials[resampling_inds]
    assert boot.shape == (n_samples, n_trials, n_time, n_labels, n_labels)
    # Our measure is simple: mean across trials
    boot = boot.mean(axis=1)
    for ai, bounds in enumerate(use_bounds):
        ci = np.percentile(boot, bounds, axis=0)
        # Reject the null when the confidence interval excludes zero
        rej = (ci[0] > 0) | (ci[1] < 0)
        boot_far[ai, ni] = rej.mean(-1).mean(-1).mean(-1)

# Plot the results
fig, ax = plt.subplots(1, figsize=(4, 4))
linestyles = (':', '-')
for ai, alpha in enumerate(alphas):
    ax.plot(tests, boot_far[ai], color='r', ls=linestyles[ai],
            label=u'bootstrap (α=%0.1g)' % alpha, zorder=4)
    ax.plot(tests, ttest_far[ai], color='orange', ls=linestyles[ai],
            label=u't-test (α=%0.1g)' % (alpha), zorder=4)
ax.axhline(0.05, color='k', ls='--', zorder=2)
# Raw string: ``\m`` is not a valid escape sequence in a normal literal
ax.set(xlabel=r'N$_\mathrm{epochs/trial}$', ylabel='False alarm rate',
       xlim=[tests[0], tests[-1]], ylim=[0, 1])
# Shade the region where trials were built from disjoint epochs
ax.fill_between(tests, 0, 1, where=independent, color='b', alpha=0.1,
                zorder=2, lw=0)
for key in ('top', 'right'):
    ax.spines[key].set_visible(False)
ax.legend(loc='upper left', frameon=True, columnspacing=0.1, labelspacing=0.1,
          fontsize=8, fancybox=True, handlelength=2.0)
fig.tight_layout()
fig.savefig('FAR.png')
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment