{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"This is a re-implementation of the experiments from [Kernel Two-Sample Hypothesis Testing Using Kernel Set Classification](https://arxiv.org/abs/1706.05612)." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"import numpy as np" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"from sklearn.metrics.pairwise import euclidean_distances\n", | |
"\n", | |
"def median_heuristic(X, Y, median_samples=1000):\n", | |
" sub = lambda feats, n: feats[np.random.choice(\n", | |
" feats.shape[0], min(feats.shape[0], n), replace=False)]\n", | |
" Z = np.r_[sub(X, median_samples // 2), sub(Y, median_samples // 2)]\n", | |
" D2 = euclidean_distances(Z, squared=True)\n", | |
" upper = D2[np.triu_indices_from(D2, k=1)]\n", | |
" kernel_width = np.median(upper, overwrite_input=True)\n", | |
" bandwidth = np.sqrt(kernel_width / 2)\n", | |
" return bandwidth" | |
] | |
}, | |
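{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick sanity check that is not part of the original notebook: with the bandwidth $\\sigma$ returned above, the kernel $\\exp\\left( - \\lVert x - y \\rVert^2 / (2 \\sigma^2) \\right)$ evaluates to $e^{-1}$ at the median pairwise distance of the pooled sample. The `_chk` variables below are throwaway names for this check."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sanity check of the bandwidth convention used by median_heuristic (illustration only)\n",
"X_chk = np.random.randn(200, 5)\n",
"Y_chk = np.random.randn(200, 5)\n",
"sigma = median_heuristic(X_chk, Y_chk)\n",
"D2_chk = euclidean_distances(np.r_[X_chk, Y_chk], squared=True)\n",
"med = np.median(D2_chk[np.triu_indices_from(D2_chk, k=1)])\n",
"np.exp(-med / (2 * sigma**2))  # should be exp(-1), about 0.368"
]
},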
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Problem from the paper: $N(0, I)$ versus $N(0, v I)$." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"def gen_data(d, n=20, m=100, var=1.01):\n", | |
" X = np.random.randn(n, d)\n", | |
" Y = np.random.randn(m, d) * var\n", | |
" return X, Y" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## MMD version" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"First, here's [some existing wrapper code](https://github.com/dougalsutherland/opt-mmd/blob/master/two_sample/mmd_test.py) to compute MMD tests with Shogun's fast permutations for getting thresholds." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import modshogun as sg\n", | |
"\n", | |
"import multiprocessing as mp\n", | |
"num_threads = mp.cpu_count()\n", | |
"sg.get_global_parallel().set_num_threads(num_threads)\n", | |
"\n", | |
"\n", | |
"def rbf_mmd_test(X, Y, bandwidth='median', null_samples=1001,\n", | |
" median_samples=1000, cache_size=32):\n", | |
" '''\n", | |
" Run an MMD test using a Gaussian kernel.\n", | |
" Parameters\n", | |
" ----------\n", | |
" X : row-instance feature array\n", | |
" Y : row-instance feature array\n", | |
" bandwidth : float or 'median'\n", | |
" The bandwidth of the RBF kernel (sigma).\n", | |
" If 'median', estimates the median pairwise distance in the\n", | |
" aggregate sample and uses that.\n", | |
" null_samples : int\n", | |
" How many times to sample from the null distribution.\n", | |
" median_samples : int\n", | |
" How many points to use for estimating the bandwidth.\n", | |
" Returns\n", | |
" -------\n", | |
" p_val : float\n", | |
" The obtained p value of the test.\n", | |
" stat : float\n", | |
" The test statistic.\n", | |
" null_samples : array of length null_samples\n", | |
" The samples from the null distribution.\n", | |
" bandwidth : float\n", | |
" The used kernel bandwidth\n", | |
" '''\n", | |
"\n", | |
" if bandwidth == 'median':\n", | |
" bandwidth = median_heuristic(X, Y, median_samples=median_samples)\n", | |
" kernel_width = 2 * bandwidth**2\n", | |
"\n", | |
" mmd = sg.QuadraticTimeMMD()\n", | |
" mmd.set_p(sg.RealFeatures(X.T.astype(np.float64)))\n", | |
" mmd.set_q(sg.RealFeatures(Y.T.astype(np.float64)))\n", | |
" mmd.set_kernel(sg.GaussianKernel(cache_size, kernel_width))\n", | |
"\n", | |
" mmd.set_num_null_samples(null_samples)\n", | |
" samps = mmd.sample_null()\n", | |
" stat = mmd.compute_statistic()\n", | |
"\n", | |
" p_val = np.mean(stat <= samps)\n", | |
" return p_val, stat, samps, bandwidth" | |
] | |
}, | |
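{
"cell_type": "markdown",
"metadata": {},
"source": [
"If Shogun isn't available, here is a rough pure-NumPy sketch of the same idea, added for illustration and not part of the original notebook: compute the biased quadratic-time MMD$^2$ estimate from the pooled Gram matrix, and build the null distribution by recomputing it under random permutations of the pooled sample. The names `mmd2_biased` and `numpy_mmd_test` are made up here; this version is much slower than Shogun's and is only meant to show what the test does, not to reproduce the timings below."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.metrics.pairwise import rbf_kernel\n",
"\n",
"def mmd2_biased(K, n):\n",
"    # Biased MMD^2 estimate from the Gram matrix of the pooled sample [X; Y],\n",
"    # where the first n rows/columns correspond to X.\n",
"    return K[:n, :n].mean() + K[n:, n:].mean() - 2 * K[:n, n:].mean()\n",
"\n",
"def numpy_mmd_test(X, Y, bandwidth='median', null_samples=1000):\n",
"    if bandwidth == 'median':\n",
"        bandwidth = median_heuristic(X, Y)\n",
"    Z = np.vstack([X, Y])\n",
"    K = rbf_kernel(Z, gamma=1 / (2 * bandwidth**2))\n",
"    n = X.shape[0]\n",
"    stat = mmd2_biased(K, n)\n",
"    # Permutation null: randomly reassign the pooled points to the two groups\n",
"    samps = np.empty(null_samples)\n",
"    for i in range(null_samples):\n",
"        perm = np.random.permutation(len(Z))\n",
"        samps[i] = mmd2_biased(K[np.ix_(perm, perm)], n)\n",
"    p_val = np.mean(stat <= samps)\n",
"    return p_val, stat, samps, bandwidth"
]
},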
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Type 2 error: don't reject when null is false." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"CPU times: user 49.1 s, sys: 37.8 s, total: 1min 26s\n", | |
"Wall time: 3.87 s\n" | |
] | |
} | |
], | |
"source": [ | |
"%%time\n", | |
"mmd_ps = np.array([\n", | |
" rbf_mmd_test(*gen_data(5))[0] for _ in range(100)])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"0.92000000000000004" | |
] | |
}, | |
"execution_count": 6, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"(mmd_ps > .05).mean()" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Type 1 error: reject when null is true." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"CPU times: user 47.1 s, sys: 35.5 s, total: 1min 22s\n", | |
"Wall time: 3.65 s\n" | |
] | |
} | |
], | |
"source": [ | |
"%%time\n", | |
"mmd_ps_null = np.array([\n", | |
" rbf_mmd_test(*gen_data(5, var=1))[0] for _ in range(100)])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"0.070000000000000007" | |
] | |
}, | |
"execution_count": 8, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"(mmd_ps_null <= .05).mean()" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"This test is basically not doing anything, as claimed in the paper." | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Kernel set classification method" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"from sklearn.metrics.pairwise import rbf_kernel\n", | |
"from sklearn.svm import OneClassSVM" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"def set_classification_test(X, Y, bw=1, nu=0.1, sub_size=7, n_subs=100):\n", | |
" gamma = 1 / (2 * bw**2)\n", | |
" KX = rbf_kernel(X, gamma=gamma)\n", | |
" KXY = rbf_kernel(X, Y, gamma=gamma)\n", | |
" \n", | |
" n = X.shape[0]\n", | |
" subs = np.vstack([\n", | |
" np.random.choice(n, sub_size, replace=False) for _ in range(n_subs)])\n", | |
" \n", | |
" K_subs = np.array([\n", | |
" [KX[np.ix_(si, sj)].mean() for sj in subs]\n", | |
" for si in subs])\n", | |
" K_test = np.array([\n", | |
" KXY[sub, :].mean() for sub in subs])[np.newaxis]\n", | |
" \n", | |
" svm = OneClassSVM(kernel='precomputed', nu=nu)\n", | |
" svm.fit(K_subs)\n", | |
" return svm.predict(K_test)[0] > 0" | |
] | |
}, | |
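{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a usage illustration (not in the original notebook), here's a single call on one draw from the problem above; the boolean return value is what the cells below treat as a rejection."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# One run of the set-classification test; the result is used as a rejection flag below\n",
"X_demo, Y_demo = gen_data(5)\n",
"set_classification_test(X_demo, Y_demo)"
]
},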
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Type 2 error: don't reject when null is false." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 11, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"CPU times: user 44.8 s, sys: 414 ms, total: 45.2 s\n", | |
"Wall time: 44.7 s\n" | |
] | |
} | |
], | |
"source": [ | |
"%%time\n", | |
"svm_rejs = np.array([\n", | |
" set_classification_test(*gen_data(5)) for _ in range(100)])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 12, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"1.0" | |
] | |
}, | |
"execution_count": 12, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"(~svm_rejs).mean()" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Type 1 error: reject when null is true." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 13, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"CPU times: user 44.5 s, sys: 539 ms, total: 45.1 s\n", | |
"Wall time: 44.5 s\n" | |
] | |
} | |
], | |
"source": [ | |
"%%time\n", | |
"svm_rejs_null = np.array([\n", | |
" set_classification_test(*gen_data(5, var=1)) for _ in range(100)])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 14, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"0.0" | |
] | |
}, | |
"execution_count": 14, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"svm_rejs_null.mean()" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"So this test _never_ rejected the null, unlike the results from the paper." | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"### Theoretical best possible test" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"If we know the exact generative model, the best test would compute the probability that $X$ is from either of the two distributions, $N(0, I)$ versus $N(0, v I)$; the same for $Y$, and then look at the probability that they match up (since they're indpendent)." | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"The likelihood of $N(0, v I)$ is\n", | |
"$$\n", | |
"\\exp\\left( -\\frac{d}{2} \\log (2 \\pi v) - \\frac{1}{2 v} \\lVert x \\rVert^2 \\right)\n", | |
".$$" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"so, for a sample set $X$, the probability that it came from $N(0, I)$ as opposed to $N(0, v I)$ is\n", | |
"\\begin{align}\n", | |
"p_X\n", | |
"&=\n", | |
"\\frac{\n", | |
"\\exp\\left( -\\frac{n d}{2} \\log (2 \\pi) - \\sum_{i=1}^n \\frac{1}{2} \\lVert X_i \\rVert^2 \\right)\n", | |
"}{\n", | |
"\\exp\\left( -\\frac{n d}{2} \\log (2 \\pi) - \\sum_{i=1}^n \\frac{1}{2} \\lVert X_i \\rVert^2 \\right)\n", | |
"+ \\exp\\left( -\\frac{n d}{2} \\log (2 \\pi) - \\frac{n d}{2} \\log v - \\sum_{i=1}^n \\frac{1}{2 v} \\lVert X_i \\rVert^2 \\right)\n", | |
"}\n", | |
"\\\\ &=\n", | |
"\\frac{\n", | |
"\\exp\\left( - \\frac12 \\sum_{i=1}^n \\lVert X_i \\rVert^2 \\right)\n", | |
"}{\n", | |
"\\exp\\left( - \\frac12 \\sum_{i=1}^n \\lVert X_i \\rVert^2 \\right)\n", | |
"+ \\exp\\left( - \\frac{n d}{2} \\log v - \\frac1{2 v} \\sum_{i=1}^n \\lVert X_i \\rVert^2 \\right)\n", | |
"}\n", | |
"\\end{align}" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"So $p_X$ is the probability that sample $X$ came from $N(0, 1)$, $p_Y$ the probability that sample $Y$ did. Thus the probability that the two came from the _same_ distribution (either of the two) is $p_X p_Y + (1 - p_X) (1 - p_Y)$." | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Let's try that in the same way:" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 15, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"def prob_1(X, v=1.01):\n", | |
" n, d = X.shape\n", | |
" s = .5 * (X ** 2).sum()\n", | |
" l1 = np.exp(-s)\n", | |
" l2 = np.exp(-n * d / 2 * np.log(v) - s / v)\n", | |
" return l1 / (l1 + l2)" | |
] | |
}, | |
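{
"cell_type": "markdown",
"metadata": {},
"source": [
"An aside that's not in the original notebook: for larger $n$ or $d$, both `np.exp` calls in `prob_1` can underflow to zero. An equivalent, more stable version works with the difference of log-likelihoods via the logistic sigmoid (this sketch assumes SciPy is available; `prob_1_stable` is a name made up here)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from scipy.special import expit  # logistic sigmoid, 1 / (1 + exp(-t))\n",
"\n",
"def prob_1_stable(X, v=1.01):\n",
"    # Same quantity as prob_1, but computed from the log-likelihood difference\n",
"    # so that neither likelihood is exponentiated on its own.\n",
"    n, d = X.shape\n",
"    s = .5 * (X ** 2).sum()\n",
"    log_l1 = -s\n",
"    log_l2 = -n * d / 2 * np.log(v) - s / v\n",
"    return expit(log_l1 - log_l2)  # equals l1 / (l1 + l2)"
]
},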
{ | |
"cell_type": "code", | |
"execution_count": 16, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"def prob_same(X, Y, v=1.01):\n", | |
" pX = prob_1(X, v)\n", | |
" pY = prob_1(Y, v)\n", | |
" return pX * pY + (1 - pX) * (1 - pY)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 17, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"probs = np.array([\n", | |
" prob_same(*gen_data(5)) for _ in range(100)])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 19, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"(0.50022106675412803, 0.49573053015762769, 0.5102717879843921)" | |
] | |
}, | |
"execution_count": 19, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"np.mean(probs), np.min(probs), np.max(probs)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"So the distributions are basically theoretically indistinguishable at this sample size, even with $v = 1.01$ instead of $1 + 10^{-21}$." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 20, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"probs_null = np.array([\n", | |
" prob_same(*gen_data(5, 1)) for _ in range(100)])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 21, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"(0.50002290599779586, 0.49912412130653572, 0.50083149225372237)" | |
] | |
}, | |
"execution_count": 21, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"np.mean(probs_null), np.min(probs_null), np.max(probs_null)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"They're also indistinguishable when they're the same distributions, unsurprisingly." | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python [conda env:shogun]", | |
"language": "python", | |
"name": "conda-env-shogun-py" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.6.1" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |