@djsutherland
Last active March 30, 2025 03:30
{
"cells": [
{
"cell_type": "markdown",
"id": "d325051b-74dd-4c95-9254-14c3275c2aae",
"metadata": {},
"source": [
"Our goal is to efficiently compute permutation tests for MMD which are finite-sample valid,\n",
"based on samples $(X_i)_{i=1}^m$, $(Y_j)_{j=1}^n$."
]
},
{
"cell_type": "markdown",
"id": "4986363a-a710-438d-af2a-49c8546261f3",
"metadata": {},
"source": [
"## Computation\n",
"Write $Z_i = \\begin{cases}X_i & 1 \\le i \\le m \\\\ Y_{i-m} & m < i \\le m + n .\\end{cases}$\n",
"\n",
"First, note that the plug-in estimator (\"the biased estimator\") for MMD is\n",
"\\begin{align*}\n",
"\\widehat{\\operatorname{MMD}_b^2}\n",
" &= \\frac1{m^2} \\sum_{i=1}^m \\sum_{i'=1}^m k(X_i, X_{i'})\n",
" + \\frac{1}{n^2} \\sum_{j=1}^n \\sum_{j'=1}^n k(Y_j, Y_{j'})\n",
" - 2 \\frac{1}{mn} \\sum_{i=1}^m \\sum_{j=1}^n k(X_i, Y_j)\n",
"\\\\&= \\begin{bmatrix} \\mathbf{1}_m /m \\\\ -\\mathbf{1}_n / n \\end{bmatrix}^\\top\n",
" \\underbrace{\\begin{bmatrix} K_X & K_{XY} \\\\ K_{YX} & K_Y \\end{bmatrix}}_K\n",
" \\begin{bmatrix} \\mathbf{1}_m /m \\\\ -\\mathbf{1}_n / n \\end{bmatrix}\n",
",\\end{align*}\n",
"where $\\mathbf 1_m$ is a vector of $m$ ones, and $K$ an $(m + n) \\times (m + n)$ matrix with entries $K_{ij} = k(Z_i, Z_j)$.\n",
"To get a permuted value of this estimator, we just need to permute the vector we're hitting it with."
]
},
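{
"cell_type": "markdown",
"id": "permuted-biased-estimator-sketch",
"metadata": {},
"source": [
"To make the permutation step concrete, here is a sketch in notation that is not part of the original notes: $\\sigma$ denotes a permutation of $[m + n]$ and $v^\\sigma$ the correspondingly permuted weight vector (both names are assumptions introduced for illustration).\n",
"\\begin{align*}\n",
"v^\\sigma_{\\sigma(i)} &= \\begin{cases} 1/m & 1 \\le i \\le m \\\\ -1/n & m < i \\le m + n \\end{cases}\n",
"\\\\\n",
"\\widehat{\\operatorname{MMD}_b^2}(\\sigma)\n",
" &= (v^\\sigma)^\\top K v^\\sigma\n",
" = \\sum_{a=1}^{m+n} \\sum_{b=1}^{m+n} v^\\sigma_a v^\\sigma_b K_{ab}\n",
".\\end{align*}\n",
"The identity permutation recovers the original estimate. Under this formulation $K$ only needs to be computed once, and each additional permutation costs one quadratic form."
]
},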
{
"cell_type": "markdown",
"id": "9b44f61f-1dd5-41e1-b69f-c50611152b06",
"metadata": {},
"source": [
"The typical unbiased estimator (claimed to be the MVUE but I'm not sure this is actually true...) is\n",
"\\begin{align*}\n",
"\\widehat{\\operatorname{MMD}_u^2}\n",
" &= \\frac{1}{m (m-1)} \\sum_{i \\ne i'} k(X_i, X_{i'})\n",
" + \\frac{1}{n (n-1)} \\sum_{j \\ne j'} k(Y_j, Y_{j'})\n",
" - \\frac{2}{m n} \\sum_{i, j} k(X_i, Y_j)\n",
";\\end{align*}\n",
"this doesn't seem especially amenable to easy permutation in the same way as the previous one."
]
},
{
"cell_type": "markdown",
"id": "a65afbca-9469-4100-9a2c-f20fc5337fe1",
"metadata": {},
"source": [
"But the U-statistic estimator, which assumes $m = n$ and is unbiased but not quite minimum variance, is\n",
"\\begin{align*}\n",
"\\widehat{\\operatorname{MMD}_U^2}\n",
" &= \\frac{1}{n (n-1)} \\sum_{i \\ne j} \\left[ k(X_i, X_j) + k(Y_i, Y_j) - k(X_i, Y_j) - k(X_j, Y_i) \\right]\n",
"\\\\&= \\frac{1}{n(n-1)} \\sum_{i \\ne j} k(X_i, X_j)\n",
" + \\frac{1}{n(n-1)} \\sum_{i \\ne j} k(Y_i, Y_j)\n",
" - \\frac{2}{n(n-1)} \\sum_{i \\ne j} k(X_i, Y_j)\n",
"\\\\&= \\frac{1}{n(n-1)} \\left( \\sum_{i,j} k(X_i, X_j) - \\sum_i k(X_i, X_i) \\right)\n",
" + \\frac{1}{n(n-1)} \\left( \\sum_{i,j} k(Y_i, Y_j) - \\sum_i k(Y_i, Y_i) \\right)\n",
"\\\\&\\qquad\n",
" - \\frac{2}{n(n-1)} \\left( \\sum_{i,j} k(X_i, Y_j) - \\sum_i k(X_i, Y_i) \\right)\n",
"\\\\&= \\frac{n}{n-1} \\widehat{\\operatorname{MMD}_b^2}\n",
" - \\frac{1}{n(n-1)} \\left( \\sum_i k(X_i, X_i) + \\sum_i k(Y_i, Y_i) - 2 \\sum_i k(X_i, Y_i) \\right)\n",
";\\end{align*}\n",
"the first two terms of the correction are simply the trace of $K$ and don't depend on the particular permutation.\n",
"The third term does, but isn't so bad."
]
},
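{
"cell_type": "markdown",
"id": "permuted-u-statistic-sketch",
"metadata": {},
"source": [
"Spelling out the same derivation under a permutation, as a sketch: $\\sigma$ (a permutation of $[2n]$) and $s^\\sigma$ (a vector of signs) are names introduced here for illustration. The permuted first sample is $Z_{\\sigma(1)}, \\dots, Z_{\\sigma(n)}$ and the permuted second sample is $Z_{\\sigma(n+1)}, \\dots, Z_{\\sigma(2n)}$.\n",
"\\begin{align*}\n",
"s^\\sigma_{\\sigma(i)} &= \\begin{cases} +1 & 1 \\le i \\le n \\\\ -1 & n < i \\le 2n \\end{cases}\n",
"\\\\\n",
"(s^\\sigma)^\\top K s^\\sigma &= n^2 \\, \\widehat{\\operatorname{MMD}_b^2}(\\sigma)\n",
"\\\\\n",
"\\widehat{\\operatorname{MMD}_U^2}(\\sigma)\n",
" &= \\frac{1}{n(n-1)} \\left( (s^\\sigma)^\\top K s^\\sigma - \\operatorname{tr}(K) + 2 \\sum_{i=1}^n K_{\\sigma(i), \\sigma(n+i)} \\right)\n",
".\\end{align*}\n",
"Only the quadratic form and the sum of $n$ paired entries $K_{\\sigma(i), \\sigma(n+i)}$ change with $\\sigma$; the trace is computed once."
]
},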
{
"cell_type": "markdown",
"id": "a23f6a33-72d7-4549-af56-313e5b11b9cf",
"metadata": {},
"source": [
"## p-value\n",
"Per Theorem 2 of Hemerik and Goeman (TEST 2018), [Exact testing with random permutations](https://arxiv.org/abs/1411.7565),\n",
"let $T_1, T_2, \\dots, T_w$ be the $w$ permuted test statistics returned by the previous procedure,\n",
"where $T_1$ is the actual data split\n",
"and $2$ through $w$ are uniformly random permutations.\n",
"\n",
"Their Theorem 2 establishes that a test which rejects when $T_1 > T^{\\left( \\lceil (1-\\alpha) w \\rceil \\right)}$, where $T^{(i)}$ is the $i$th order statistic, has level $\\alpha$.\n",
"We can find a (possibly conservative) $p$-value as the smallest value of $\\alpha$ for which this test can reject.\n",
"\n",
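"Let $k^* = \\max \\{ k : T^{(k)} < T_1 \\}$ be the number of order statistics strictly below $T_1$. (If that set is empty, return $p$-value 1.)\n",
"Then, the smallest $\\alpha$ which works is\n",
"the smallest one satisfying $\\lceil (1-\\alpha) w \\rceil \\le k^*$,\n",
"which, since $k^*$ is an integer, is attained at\n",
"$(1 - \\alpha) w = k^*$,\n",
"i.e. $\\alpha = (w - k^*) / w$.\n",
"Note that $w - k^*$ is exactly $\\lvert \\{ i \\in [w] : T_i \\ge T_1 \\} \\rvert$, the number of permuted statistics (including $T_1$ itself) that are at least as large as the observed one."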
]
},
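{
"cell_type": "markdown",
"id": "p-value-worked-example",
"metadata": {},
"source": [
"As a worked example with made-up numbers (not the output of any run in this notebook): take $w = 500$ and suppose exactly $12$ of $T_1, \\dots, T_{500}$, counting $T_1$ itself, are at least $T_1$. Then $488$ order statistics lie strictly below $T_1$, and\n",
"\\begin{align*}\n",
"\\lceil (1 - \\alpha) \\cdot 500 \\rceil \\le 488\n",
" \\iff (1 - \\alpha) \\cdot 500 \\le 488\n",
" \\iff \\alpha \\ge \\frac{500 - 488}{500} = \\frac{12}{500} = 0.024\n",
",\\end{align*}\n",
"so the $p$-value is $0.024$.\n",
"Because $T_1 \\ge T_1$ always holds, the count in the numerator is at least $1$, so the smallest $p$-value this procedure can report is $1/w$ (here $0.002$)."
]
},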
{
"cell_type": "code",
"execution_count": 1,
"id": "01140287-8c45-4dff-b1d4-cddd0f3db08a",
"metadata": {},
"outputs": [],
"source": [
"import torch"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "8e31cd9f-392e-4c09-821e-bd7827a7379c",
"metadata": {},
"outputs": [],
"source": [
"from collections import namedtuple\n",
"\n",
"PermutationResult = namedtuple(\n",
"    \"PermutationResult\", [\"estimate\", \"p_value\", \"permuted_estimates\"]\n",
")\n",
"\n",
"\n",
"def mmd2_permutation(joint_kernel, n_X, n_perm=500, u_stat=False):\n",
"    \"\"\"\n",
"    joint_kernel: should be an array of shape [n_X + n_Y, n_X + n_Y] with kernel values,\n",
"                  ie [ K_XX K_XY ]\n",
"                     [ K_YX K_YY ]\n",
"    n_X: number of entries in the first set\n",
"    n_perm: total number of permutations, including the identity\n",
"\n",
"    If u_stat is False (the default), uses the plug-in estimator (MMD between empirical distributions).\n",
"    If True, it uses the U-statistic estimator, which is unbiased but drops k(x_i, y_i) terms.\n",
"    (I'm not sure how to implement the \"unbiased estimator\" (which includes those terms) efficiently.)\n",
"    \"\"\"\n",
"    K = joint_kernel = torch.as_tensor(joint_kernel)\n",
"    device = K.device\n",
"    dtype = K.dtype\n",
"\n",
"    n = K.shape[0]\n",
"    if K.shape != (n, n):\n",
"        raise ValueError(f\"joint_kernel should be square, got {K.shape}\")\n",
"    n_X = int(n_X)\n",
"    n_Y = n - n_X\n",
"    if n_X <= 0 or n_Y <= 0:\n",
"        raise ValueError(\"need a positive number of samples from each\")\n",
"\n",
"    if u_stat:\n",
"        if n_X != n_Y:\n",
"            raise ValueError(\"u-stat estimator only defined for equal sample sizes\")\n",
"        w_X = 1\n",
"        w_Y = -1\n",
"    else:\n",
"        w_X = 1 / n_X\n",
"        w_Y = -1 / n_Y\n",
"\n",
"    # construct permutations; row 0 is the identity, i.e. the actual data split\n",
"    # there probably should be a faster way to do this but, idk\n",
"    perms = torch.stack(\n",
"        [torch.arange(n, device=device)]\n",
"        + [torch.randperm(n, device=device) for _ in range(n_perm - 1)]\n",
"    )\n",
"    X_inds = perms[:, :n_X]\n",
"    Y_inds = perms[:, n_X:]\n",
"\n",
"    # set weights to w_X for things in X_inds, w_Y for others\n",
"    ws = torch.full((n_perm, n), w_Y, device=device, dtype=dtype)\n",
"    ws.scatter_(1, X_inds, w_X)\n",
"\n",
"    # the \"basic\" estimate; either the biased est or a constant times it\n",
"    ests = torch.einsum(\"pi,ij,pj->p\", ws, joint_kernel, ws)\n",
"\n",
"    if u_stat:\n",
"        # need to subtract \\sum_i k(X_i, X_i) + k(Y_i, Y_i) - 2 k(X_i, Y_i)\n",
"        # first two are just trace\n",
"        # for the last one, we need to see which ones were lined up\n",
"        # NOTE: take() makes an unnecessary copy if joint_kernel isn't already contiguous,\n",
"        # but this generally shouldn't be a big deal\n",
"        cross_terms = joint_kernel.take(X_inds * n + Y_inds).sum(1)\n",
"        ests = (ests - joint_kernel.trace() + 2 * cross_terms) / (n_X * (n_X - 1))\n",
"\n",
"    # fraction of permuted statistics (including the identity) at least as large as the observed one\n",
"    p_val = (ests >= ests[0]).float().mean()\n",
"    return PermutationResult(ests[0], p_val, ests)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "5e149931-01e5-426a-9a90-bef389e14a4f",
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "c45c7919-859e-434e-9679-f40ab7b12dbe",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor(0.0660)"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"X = torch.randn(100, 3)\n",
"Y = torch.randn(100, 3) * 1.3\n",
"Z = torch.cat((X, Y))\n",
"K = torch.exp(-0.5 * torch.cdist(Z, Z) ** 2)\n",
"\n",
"res = mmd2_permutation(K, X.shape[0], u_stat=True)\n",
"\n",
"plt.hist(res.permuted_estimates, bins=\"auto\")\n",
"plt.axvline(res.estimate, color=\"r\")\n",
"res.p_value"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.7"
}
},
"nbformat": 4,
"nbformat_minor": 5
}