A kernel two-sample test for equality of distributions (Gretton et al. 2012)
import numpy as np
from scipy.spatial.distance import cdist, pdist


def mmd_two_sample_test(X, Y):
    """
    Implements Gretton's test for equality of
    distributions in high-dimensional settings
    using concentration bounds on the maximum
    mean discrepancy (MMD). This function uses
    the unbiased estimator of the MMD (see
    Lemma 6, Gretton et al., 2012) and upper
    bounds the p-value using a Hoeffding
    large-deviation bound (see Theorem 10,
    Gretton et al., 2012).

    The test considers two sets of observed
    datapoints, X and Y, which are assumed to
    be drawn i.i.d. from underlying probability
    distributions P and Q. The null hypothesis
    is that P = Q.

    Note that this function assumes that the
    number of samples from each distribution
    is equal.

    Reference
    ---------
    Gretton et al. (2012). A Kernel Two-Sample Test.
    Journal of Machine Learning Research 13: 723-773.

    Parameters
    ----------
    X : ndarray (num_samples x num_features)
        First set of observed samples, assumed to be
        drawn from some unknown distribution P.
    Y : ndarray (num_samples x num_features)
        Second set of observed samples, assumed to be
        drawn from some unknown distribution Q.

    Returns
    -------
    pvalue : float
        An upper bound on the probability of observing
        an MMD distance greater than or equal to the
        observed value, assuming that the null hypothesis
        (i.e. that P = Q) is true.
    """
    assert X.shape == Y.shape
    m = X.shape[0]

    # Compute pairwise distances within X, within Y, and across X and Y.
    xd = pdist(X, metric="euclidean")
    yd = pdist(Y, metric="euclidean")
    xyd = cdist(X, Y, metric="euclidean").ravel()

    # Set the Gaussian kernel bandwidth (Gretton et al. suggest
    # using the median pairwise distance, the "median heuristic").
    sigma_sq = np.median(
        np.concatenate((xd, yd, xyd))
    ) ** 2

    # Compute the unbiased MMD estimate (Lemma 6). The mean over
    # pdist's condensed output runs over all pairs i < j, which by
    # symmetry equals the average over i != j; the cross term
    # averages over all m^2 pairs.
    kxx = np.mean(np.exp(-(xd ** 2) / (2 * sigma_sq)))
    kyy = np.mean(np.exp(-(yd ** 2) / (2 * sigma_sq)))
    kxy = np.mean(np.exp(-(xyd ** 2) / (2 * sigma_sq)))
    mmd_obs = kxx + kyy - 2 * kxy

    # Apply Theorem 10 to upper bound the p-value. The unbiased
    # estimate can be negative, in which case the bound is vacuous.
    if mmd_obs < 0:
        return 1.0
    else:
        return np.exp(
            -((mmd_obs ** 2) * (m // 2)) / 8
        )
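
# A convenience wrapper, not part of the original gist or of Gretton et
# al.'s method: `mmd_two_sample_test` asserts equal sample sizes, so one
# rough workaround for sets of unequal size is to randomly subsample the
# larger set down to the size of the smaller one. This discards data;
# the paper's estimator for unequal m, n is not implemented here.
def mmd_two_sample_test_unequal(X, Y, seed=0):
    assert X.shape[1] == Y.shape[1]
    rs = np.random.RandomState(seed)
    m = min(X.shape[0], Y.shape[0])
    Xs = X[rs.choice(X.shape[0], size=m, replace=False)]
    Ys = Y[rs.choice(Y.shape[0], size=m, replace=False)]
    return mmd_two_sample_test(Xs, Ys)
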
if __name__ == "__main__":

    # TEST THAT WE FAIL TO REJECT THE NULL
    d = 10
    num_samples = 1000
    pvals = np.empty(100)
    for seed in range(pvals.size):
        # Draw random samples from equal distributions.
        rs = np.random.RandomState(seed)
        X = rs.randn(num_samples, d)
        Y = rs.randn(num_samples, d)
        pvals[seed] = mmd_two_sample_test(X, Y)

    print("FIRST TEST -- NULL HYPOTHESIS TRUE")
    print(f"{np.sum(pvals < 0.05)} / {pvals.size} tests reject the null.")

    # TEST THAT WE REJECT THE NULL
    for seed in range(pvals.size):
        # Draw random samples from different distributions
        # (Y is shifted by one in every dimension).
        rs = np.random.RandomState(seed)
        X = rs.randn(num_samples, d)
        Y = rs.randn(num_samples, d) + 1
        pvals[seed] = mmd_two_sample_test(X, Y)

    print("SECOND TEST -- NULL HYPOTHESIS FALSE")
    print(f"{np.sum(pvals < 0.05)} / {pvals.size} tests reject the null.")
Link to paper: https://www.jmlr.org/papers/v13/gretton12a.html