Last active
February 7, 2021 18:31
-
-
Save ahwillia/661adb1703ba409a5630577155117946 to your computer and use it in GitHub Desktop.
Two-sample permutation test in Python
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
A simple implementation of a permutation test between two
independent samples.
""" | |
import numpy as np | |
from sklearn.utils.validation import check_random_state | |
from more_itertools import distinct_permutations | |
from scipy.stats import percentileofscore | |
from math import factorial | |
def permtest(
        x, y, statistic="mean", max_samples=100000,
        random_state=None):
    """
    Conducts a permutation test between two independent
    samples.

    Parameters
    ----------
    x : ndarray
        First set of datapoints. Any array-like is accepted
        and converted with ``np.asarray``.
    y : ndarray
        Second set of datapoints. Any array-like is accepted
        and converted with ``np.asarray``.
    statistic : str or callable, optional
        Function that takes in samples x and y and reports
        statistic of interest. By default, "mean" reports
        the difference of sample means. List of default
        options are ("mean", "median"). The second tail is
        obtained by calling the statistic with its arguments
        swapped, so a custom callable is expected to behave
        sensibly under ``statistic(y, x)``.
    max_samples : int, optional.
        Maximum number of label permutations to try.
    random_state : np.random.RandomState, int, or None.
        If specified, used to seed the random number
        generator to shuffle the ordering of the
        datapoints.

    Returns
    -------
    pval : float
        Two-sided p-value: twice the smaller of the two tail
        percentiles of the observed statistic within the
        permutation distribution, capped at 1.
    """
    # Accept lists/tuples as well as arrays; `.size` is needed below.
    x = np.asarray(x)
    y = np.asarray(y)

    # Initialize random state and function of interest.
    rs = check_random_state(random_state)
    stat_func = _get_stat_func(statistic)

    # Concatenate samples in random order.
    xy = np.concatenate((x, y))
    rs.shuffle(xy)

    # Create data labels (True if sample is in "x" and False if sample
    # is in "y"), and randomly shuffle before generating permutations.
    labels = np.zeros_like(xy, dtype="bool")
    labels[:x.size] = True
    rs.shuffle(labels)

    # Number of distinct relabellings (a binomial coefficient).
    n_perms = factorial(xy.size) // factorial(x.size) // factorial(y.size)

    # Allocate space for computed statistics.
    n_samples = int(min(max_samples, n_perms))
    shuffled_stats = np.full(n_samples, np.nan)

    # Print coverage. `n_perms` is an exact Python int that can exceed
    # the float range for large samples, in which case the "e" format
    # raises OverflowError; fall back to a plain description then.
    try:
        n_perms_text = "{0:.2e}".format(n_perms)
    except OverflowError:
        n_perms_text = "more than 1e308"
    print("Computing {0:} / {1:} ({2:2.2f}%) of label permutations: ".format(
        n_samples, n_perms_text, 100 * n_samples / n_perms
    ))

    # Enumerate exhaustively only when every relabelling will be visited.
    # The enumeration order is deterministic, so a truncated prefix of it
    # would not be a random sample of relabellings and would bias the
    # null distribution. Otherwise draw random permutations.
    if n_samples == n_perms:
        print("Iterating over distinct permutations...")
        itr = distinct_permutations(labels.tolist())
    else:
        print("Sampling random permutations...")
        itr = _randperms(rs, labels)

    # Evaluate the statistic under each relabelling. `zip` with `range`
    # stops after `n_samples` draws (the random generator is infinite).
    for i, perm in zip(range(n_samples), itr):
        in_x = np.asarray(perm, dtype=bool)
        shuffled_stats[i] = stat_func(xy[in_x], xy[~in_x])

    # Compute a two-sided p-value. We take the smallest
    # percentile and then multiply by two.
    pval = 2 * 0.01 * min(
        percentileofscore(shuffled_stats, stat_func(x, y)),
        percentileofscore(shuffled_stats, stat_func(y, x)),
    )

    # Ties between the observed and permuted statistics can push the
    # doubled percentile above 1; cap it. `pval` is the first argument
    # so that a NaN result is passed through rather than masked.
    return min(pval, 1.0)
def _get_stat_func(name_or_func): | |
""" | |
Instantiates functions that compute default statistics of | |
interest. | |
""" | |
# If specified function is callable | |
if not isinstance(name_or_func, str): | |
if callable(name_or_func): | |
return name_or_func | |
else: | |
raise ValueError( | |
"`statistic` should be a string like ('mean', 'median')" | |
" or a function that takes in samples x, y and returns" | |
" the statistic of interest." | |
) | |
# Default functions. | |
if name_or_func == "mean": | |
return lambda x, y: np.mean(x) - np.mean(y) | |
elif name_or_func == "median": | |
return lambda x, y: np.median(x) - np.median(y) | |
else: | |
raise ValueError( | |
"Did not recognize statistic." | |
) | |
def _randperms(rs, labels): | |
perm = labels.copy() | |
while True: | |
rs.shuffle(perm) | |
yield perm | |
if __name__ == "__main__":
    # Demo: two unit-variance Gaussian samples of 100 points each, the
    # second shifted by 0.5. The global NumPy generator is seeded because
    # `random_state=None` is passed to permtest below.
    np.random.seed(123)
    x = np.random.randn(100)
    y = np.random.randn(100) + .5

    pval = permtest(x, y, random_state=None)
    print(f"p = {pval}")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment