Last active
March 30, 2025 03:30
djsutherland/6f564b460d14f11ba8ee1df664b12136
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"id": "d325051b-74dd-4c95-9254-14c3275c2aae", | |
"metadata": {}, | |
"source": [ | |
"Our goal is to efficiently compute permutation tests for MMD which are finite-sample valid,\n", | |
"based on samples $(X_i)_{i=1}^m$, $(Y_j)_{j=1}^n$." | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "4986363a-a710-438d-af2a-49c8546261f3", | |
"metadata": {}, | |
"source": [ | |
"## Computation\n", | |
"Write $Z_i = \\begin{cases}X_i & 1 \\le i \\le m \\\\ Y_{i-m} & m < i \\le m + n .\\end{cases}$\n", | |
"\n", | |
"First, note that the plug-in estimator (\"the biased estimator\") for MMD is\n", | |
"\\begin{align*}\n", | |
"\\widehat{\\operatorname{MMD}_b^2}\n", | |
" &= \\frac1{m^2} \\sum_{i=1}^m \\sum_{i'=1}^m k(X_i, X_{i'})\n", | |
" + \\frac{1}{n^2} \\sum_{j=1}^n \\sum_{j'=1}^n k(Y_j, Y_{j'})\n", | |
" - 2 \\frac{1}{mn} \\sum_{i=1}^m \\sum_{j=1}^n k(X_i, Y_j)\n", | |
"\\\\&= \\begin{bmatrix} \\mathbf{1}_m /m \\\\ -\\mathbf{1}_n / n \\end{bmatrix}^\\top\n", | |
" \\underbrace{\\begin{bmatrix} K_X & K_{XY} \\\\ K_{YX} & K_Y \\end{bmatrix}}_K\n", | |
" \\begin{bmatrix} \\mathbf{1}_m /m \\\\ -\\mathbf{1}_n / n \\end{bmatrix}\n", | |
",\\end{align*}\n", | |
"where $\\mathbf 1_m$ is a vector of $m$ ones, and $K$ is an $(m + n) \\times (m + n)$ matrix with entries $K_{ij} = k(Z_i, Z_j)$.\n", | |
"To get a permuted value of this estimator, we just need to permute the entries of the vector we're hitting $K$ with; $K$ itself only needs to be computed once." | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "9b44f61f-1dd5-41e1-b69f-c50611152b06", | |
"metadata": {}, | |
"source": [ | |
"The typical unbiased estimator (claimed to be the MVUE but I'm not sure this is actually true...) is\n", | |
"\\begin{align*}\n", | |
"\\widehat{\\operatorname{MMD}_u^2}\n", | |
" &= \\frac{1}{m (m-1)} \\sum_{i \\ne i'} k(X_i, X_{i'})\n", | |
" + \\frac{1}{n (n-1)} \\sum_{j \\ne j'} k(Y_j, Y_{j'})\n", | |
" - \\frac{2}{m n} \\sum_{i, j} k(X_i, Y_j)\n", | |
";\\end{align*}\n", | |
"this doesn't seem especially amenable to easy permutation in the same way as the previous one." | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "a65afbca-9469-4100-9a2c-f20fc5337fe1", | |
"metadata": {}, | |
"source": [ | |
"But the U-statistic estimator, which assumes $m = n$ and is unbiased but not quite minimum variance, is\n", | |
"\\begin{align*}\n", | |
"\\widehat{\\operatorname{MMD}_U^2}\n", | |
" &= \\frac{1}{n (n-1)} \\sum_{i \\ne j} \\left[ k(X_i, X_j) + k(Y_i, Y_j) - k(X_i, Y_j) - k(X_j, Y_i) \\right]\n", | |
"\\\\&= \\frac{1}{n(n-1)} \\sum_{i \\ne j} k(X_i, X_j)\n", | |
" + \\frac{1}{n(n-1)} \\sum_{i \\ne j} k(Y_i, Y_j)\n", | |
" - \\frac{2}{n(n-1)} \\sum_{i \\ne j} k(X_i, Y_j)\n", | |
"\\\\&= \\frac{1}{n(n-1)} \\left( \\sum_{i,j} k(X_i, X_j) - \\sum_i k(X_i, X_i) \\right)\n", | |
" + \\frac{1}{n(n-1)} \\left( \\sum_{i,j} k(Y_i, Y_j) - \\sum_i k(Y_i, Y_i) \\right)\n", | |
"\\\\&\\qquad\n", | |
" - \\frac{2}{n(n-1)} \\left( \\sum_{i,j} k(X_i, Y_j) - \\sum_i k(X_i, Y_i) \\right)\n", | |
"\\\\&= \\frac{n}{n-1} \\widehat{\\operatorname{MMD}_b^2}\n", | |
" - \\frac{1}{n(n-1)} \\left( \\sum_i k(X_i, X_i) + \\sum_i k(Y_i, Y_i) - 2 \\sum_i k(X_i, Y_i) \\right)\n", | |
";\\end{align*}\n", | |
"the first two terms of the correction together are simply the trace of $K$, which doesn't depend on the particular permutation.\n", | |
"The third term does, but it's only a sum of $n$ entries of $K$ (the pairs $(X_i, Y_i)$ lined up by the permutation), so it's cheap to compute for each permutation." | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "a23f6a33-72d7-4549-af56-313e5b11b9cf", | |
"metadata": {}, | |
"source": [ | |
"## p-value\n", | |
"Per Theorem 2 of Hemerik and Goeman (TEST 2018), [Exact testing with random permutations](https://arxiv.org/abs/1411.7565),\n", | |
"let $T_1, T_2, \\dots, T_w$ be the $w$ permuted test statistics returned by the previous procedure,\n", | |
"where $T_1$ is computed from the actual data split\n", | |
"and $T_2$ through $T_w$ are computed from uniformly random permutations.\n", | |
"\n", | |
"Their Theorem 2 establishes that a test which rejects when $T_1 > T^{\\left( \\lceil (1-\\alpha) w \\rceil \\right)}$, where $T^{(i)}$ is the $i$th order statistic, has level $\\alpha$.\n", | |
"We can find a (possibly conservative) $p$-value as the smallest value of $\\alpha$ for which this test can reject.\n", | |
"\n", | |
"Let $r = \\max \\{ i : T^{(i)} < T_1 \\}$. (If that set is empty, return $p$-value 1.)\n", | |
"Then, the smallest $\\alpha$ which works is\n", | |
"the smallest one satisfying $\\lceil (1-\\alpha) w \\rceil \\le r$,\n", | |
"which will have\n", | |
"$(1 - \\alpha) w = r$,\n", | |
"i.e. $\\alpha = (w - r) / w$.\n", | |
"Note that $w - r$ is exactly $\\lvert \\{ i \\in [w] : T_i \\ge T_1 \\} \\rvert$." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"id": "01140287-8c45-4dff-b1d4-cddd0f3db08a", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import torch" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"id": "8e31cd9f-392e-4c09-821e-bd7827a7379c", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from collections import namedtuple\n", | |
"\n", | |
"PermutationResult = namedtuple(\n", | |
"    \"PermutationResult\", [\"estimate\", \"p_value\", \"permuted_estimates\"]\n", | |
")\n", | |
"\n", | |
"\n", | |
"def mmd2_permutation(joint_kernel, n_X, n_perm=500, u_stat=False):\n", | |
"    \"\"\"\n", | |
"    joint_kernel: should be an array of shape [n_X + n_Y, n_X + n_Y] with kernel values,\n", | |
"                  ie [ K_XX K_XY ]\n", | |
"                     [ K_YX K_YY ]\n", | |
"    n_X: number of entries in the first set\n", | |
"    n_perm: total number of permutations, including the identity\n", | |
"\n", | |
"    If u_stat is False, uses the plug-in estimator (MMD between empirical distributions).\n", | |
"    If True, it uses the U-statistic estimator, which is unbiased but drops k(x_i, y_i) terms.\n", | |
"    (I'm not sure how to implement the \"unbiased estimator\" (which includes those terms) efficiently.)\n", | |
"    \"\"\"\n", | |
"    K = joint_kernel = torch.as_tensor(joint_kernel)\n", | |
"    device = K.device\n", | |
"    dtype = K.dtype\n", | |
"\n", | |
"    n = K.shape[0]\n", | |
"    if K.shape != (n, n):\n", | |
"        raise ValueError(f\"joint_kernel should be square, got {K.shape}\")\n", | |
"    n_X = int(n_X)\n", | |
"    n_Y = n - n_X\n", | |
"    if n_X <= 0 or n_Y <= 0:\n", | |
"        raise ValueError(\"need a positive number of samples from each\")\n", | |
"\n", | |
"    if u_stat:\n", | |
"        if n_X != n_Y:\n", | |
"            raise ValueError(\"u-stat estimator only defined for equal sample sizes\")\n", | |
"        w_X = 1\n", | |
"        w_Y = -1\n", | |
"    else:\n", | |
"        w_X = 1 / n_X\n", | |
"        w_Y = -1 / n_Y\n", | |
"\n", | |
"    # construct permutations; row 0 is the identity, i.e. the observed split\n", | |
"    # there probably should be a faster way to do this but, idk\n", | |
"    perms = torch.stack(\n", | |
"        [torch.arange(n, device=device)]\n", | |
"        + [torch.randperm(n, device=device) for _ in range(n_perm - 1)]\n", | |
"    )\n", | |
"    X_inds = perms[:, :n_X]\n", | |
"    Y_inds = perms[:, n_X:]\n", | |
"\n", | |
"    # set weights to w_X for things in X_inds, w_Y for others\n", | |
"    ws = torch.full((n_perm, n), w_Y, device=device, dtype=dtype)\n", | |
"    ws.scatter_(1, X_inds, w_X)\n", | |
"\n", | |
"    # the \"basic\" estimate; either the biased est or a constant times it\n", | |
"    ests = torch.einsum(\"pi,ij,pj->p\", ws, joint_kernel, ws)\n", | |
"\n", | |
"    if u_stat:\n", | |
"        # need to subtract \\sum_i k(X_i, X_i) + k(Y_i, Y_i) - 2 k(X_i, Y_i)\n", | |
"        # first two are just trace\n", | |
"        # for the last one, we need to see which ones were lined up\n", | |
"        # NOTE: take() makes an unnecessary copy if joint_kernel isn't already contiguous,\n", | |
"        #       but this generally shouldn't be a big deal\n", | |
"        cross_terms = joint_kernel.take(X_inds * n + Y_inds).sum(1)\n", | |
"        ests = (ests - joint_kernel.trace() + 2 * cross_terms) / (n_X * (n_X - 1))\n", | |
"\n", | |
"    # fraction of statistics (including the observed one) that are >= the observed one\n", | |
"    p_val = (ests >= ests[0]).float().mean()\n", | |
"    return PermutationResult(ests[0], p_val, ests)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"id": "5e149931-01e5-426a-9a90-bef389e14a4f", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import matplotlib.pyplot as plt" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"id": "c45c7919-859e-434e-9679-f40ab7b12dbe", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"tensor(0.0660)" | |
] | |
}, | |
"execution_count": 4, | |
"metadata": {}, | |
"output_type": "execute_result" | |
}, | |
{ | |
"data": { | |
"text/plain": [ | |
"<Figure size 640x480 with 1 Axes>" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"X = torch.randn(100, 3)\n", | |
"Y = torch.randn(100, 3) * 1.3\n", | |
"Z = torch.cat((X, Y))\n", | |
"K = torch.exp(-0.5 * torch.cdist(Z, Z) ** 2)\n", | |
"\n", | |
"res = mmd2_permutation(K, X.shape[0], u_stat=True)\n", | |
"\n", | |
"plt.hist(res.permuted_estimates, bins=\"auto\")\n", | |
"plt.axvline(res.estimate, color=\"r\")\n", | |
"res.p_value" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3 (ipykernel)", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.11.7" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 5 | |
} |
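
The two identities in the Computation section can be sanity-checked by brute force on a tiny problem. The sketch below is not part of the notebook: it uses only the Python standard library instead of torch, the helper names (`gaussian_gram`, `quad_form`, `weights`, `mmd2_biased_direct`, `mmd2_ustat_direct`) are made up for illustration, and the toy data and sizes are arbitrary. For each permutation it compares the quadratic-form version of the biased estimator with the explicit double sums, compares the "trace plus paired cross terms" correction with the U-statistic written as a sum over pairs, and then forms the p-value as the fraction of statistics at least as large as the observed one.

```python
import math
import random


def gaussian_gram(points):
    """Gram matrix with entries exp(-||z_i - z_j||^2 / 2)."""
    return [
        [math.exp(-0.5 * sum((a - b) ** 2 for a, b in zip(p, q))) for q in points]
        for p in points
    ]


def quad_form(K, w):
    """w^T K w for a dense list-of-lists matrix K."""
    size = len(w)
    return sum(w[i] * K[i][j] * w[j] for i in range(size) for j in range(size))


def weights(x_idx, y_idx, w_x, w_y):
    """Weight vector: w_x at indices assigned to X, w_y at those assigned to Y."""
    w = [0.0] * (len(x_idx) + len(y_idx))
    for i in x_idx:
        w[i] = w_x
    for j in y_idx:
        w[j] = w_y
    return w


def mmd2_biased_direct(K, x_idx, y_idx):
    """Plug-in estimator as three explicit double sums."""
    m, n = len(x_idx), len(y_idx)
    xx = sum(K[i][j] for i in x_idx for j in x_idx) / m**2
    yy = sum(K[i][j] for i in y_idx for j in y_idx) / n**2
    xy = sum(K[i][j] for i in x_idx for j in y_idx) / (m * n)
    return xx + yy - 2 * xy


def mmd2_ustat_direct(K, x_idx, y_idx):
    """U-statistic estimator as a sum over pairs a != b (requires m == n)."""
    n = len(x_idx)
    total = 0.0
    for a in range(n):
        for b in range(n):
            if a != b:
                xa, xb, ya, yb = x_idx[a], x_idx[b], y_idx[a], y_idx[b]
                total += K[xa][xb] + K[ya][yb] - K[xa][yb] - K[xb][ya]
    return total / (n * (n - 1))


rng = random.Random(0)
n = 6  # samples per group; equal sizes so the U-statistic is defined
Z = [[rng.gauss(0, 1) for _ in range(2)] for _ in range(2 * n)]
K = gaussian_gram(Z)
trace = sum(K[i][i] for i in range(2 * n))

n_perm = 40
biased_gap = ustat_gap = 0.0
stats = []
for p in range(n_perm):
    perm = list(range(2 * n))
    if p > 0:  # index 0 is the identity, i.e. the observed split
        rng.shuffle(perm)
    x_idx, y_idx = perm[:n], perm[n:]

    # biased estimator: weights 1/n and -1/n
    biased = quad_form(K, weights(x_idx, y_idx, 1 / n, -1 / n))
    biased_gap = max(biased_gap, abs(biased - mmd2_biased_direct(K, x_idx, y_idx)))

    # U-statistic: weights +1 and -1, then remove the trace and add back
    # twice the paired cross terms k(X_i, Y_i)
    cross = sum(K[i][j] for i, j in zip(x_idx, y_idx))
    ustat = (quad_form(K, weights(x_idx, y_idx, 1.0, -1.0)) - trace + 2 * cross) / (
        n * (n - 1)
    )
    ustat_gap = max(ustat_gap, abs(ustat - mmd2_ustat_direct(K, x_idx, y_idx)))
    stats.append(ustat)

# p-value: fraction of statistics at least as large as the observed one
p_value = sum(t >= stats[0] for t in stats) / n_perm
print(f"biased gap {biased_gap:.1e}, U-stat gap {ustat_gap:.1e}, p-value {p_value:.3f}")
```

Both gaps should be at the level of floating-point rounding. Because the observed statistic is always counted among the `n_perm` values, the p-value can never fall below `1 / n_perm`.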