@djsutherland
Last active March 30, 2025 03:30
{
"cells": [
{
"cell_type": "markdown",
"id": "d325051b-74dd-4c95-9254-14c3275c2aae",
"metadata": {},
"source": [
"Our goal is to efficiently compute permutation tests for MMD which are finite-sample valid,\n",
"based on samples $(X_i)_{i=1}^m$, $(Y_j)_{j=1}^n$."
]
},
{
"cell_type": "markdown",
"id": "4986363a-a710-438d-af2a-49c8546261f3",
"metadata": {},
"source": [
"## Computation\n",
"Write $Z_i = \\begin{cases}X_i & 1 \\le i \\le m \\\\ Y_{i-m} & m < i \\le m + n .\\end{cases}$\n",
"\n",
"First, note that the plug-in estimator (\"the biased estimator\") for MMD is\n",
"\\begin{align*}\n",
"\\widehat{\\operatorname{MMD}_b^2}\n",
" &= \\frac1{m^2} \\sum_{i=1}^m \\sum_{i'=1}^m k(X_i, X_{i'})\n",
" + \\frac{1}{n^2} \\sum_{j=1}^n \\sum_{j'=1}^n k(Y_j, Y_{j'})\n",
" - 2 \\frac{1}{mn} \\sum_{i=1}^m \\sum_{j=1}^n k(X_i, Y_j)\n",
"\\\\&= \\begin{bmatrix} \\mathbf{1}_m /m \\\\ -\\mathbf{1}_n / n \\end{bmatrix}^\\top\n",
" \\underbrace{\\begin{bmatrix} K_X & K_{XY} \\\\ K_{YX} & K_Y \\end{bmatrix}}_K\n",
" \\begin{bmatrix} \\mathbf{1}_m /m \\\\ -\\mathbf{1}_n / n \\end{bmatrix}\n",
",\\end{align*}\n",
"where $\\mathbf 1_m$ is a vector of $m$ ones, and $K$ an $(m + n) \\times (m + n)$ matrix with entries $K_{ij} = k(Z_i, Z_j)$.\n",
"To get the value of this estimator on a permuted split, we just need to permute the entries of the weight vector on both sides of $K$; the matrix $K$ itself never changes."
]
},
{
"cell_type": "markdown",
"id": "9b44f61f-1dd5-41e1-b69f-c50611152b06",
"metadata": {},
"source": [
"The typical unbiased estimator (claimed to be the MVUE but I'm not sure this is actually true...) is\n",
"\\begin{align*}\n",
"\\widehat{\\operatorname{MMD}_u^2}\n",
" &= \\frac{1}{m (m-1)} \\sum_{i \\ne i'} k(X_i, X_{i'})\n",
" + \\frac{1}{n (n-1)} \\sum_{j \\ne j'} k(Y_j, Y_{j'})\n",
" - \\frac{2}{m n} \\sum_{i, j} k(X_i, Y_j)\n",
";\\end{align*}\n",
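"this doesn't seem especially amenable to easy permutation in the same way as the previous one: since the diagonal terms are dropped, it can't be written as $v^\\top K v$ for a single weight vector $v$."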
]
},
{
"cell_type": "markdown",
"id": "a65afbca-9469-4100-9a2c-f20fc5337fe1",
"metadata": {},
"source": [
"But the U-statistic estimator, which assumes $m = n$ and is unbiased but not quite minimum variance, is\n",
"\\begin{align*}\n",
"\\widehat{\\operatorname{MMD}_U^2}\n",
" &= \\frac{1}{n (n-1)} \\sum_{i \\ne j} \\left[ k(X_i, X_j) + k(Y_i, Y_j) - k(X_i, Y_j) - k(X_j, Y_i) \\right]\n",
"\\\\&= \\frac{1}{n(n-1)} \\sum_{i \\ne j} k(X_i, X_j)\n",
" + \\frac{1}{n(n-1)} \\sum_{i \\ne j} k(Y_i, Y_j)\n",
" - \\frac{2}{n(n-1)} \\sum_{i \\ne j} k(X_i, Y_j)\n",
"\\\\&= \\frac{1}{n(n-1)} \\left( \\sum_{i,j} k(X_i, X_j) - \\sum_i k(X_i, X_i) \\right)\n",
" + \\frac{1}{n(n-1)} \\left( \\sum_{i,j} k(Y_i, Y_j) - \\sum_i k(Y_i, Y_i) \\right)\n",
"\\\\&\\qquad\n",
" - \\frac{2}{n(n-1)} \\left( \\sum_{i,j} k(X_i, Y_j) - \\sum_i k(X_i, Y_i) \\right)\n",
"\\\\&= \\frac{n}{n-1} \\widehat{\\operatorname{MMD}_b^2}\n",
" - \\frac{1}{n(n-1)} \\left( \\sum_i k(X_i, X_i) + \\sum_i k(Y_i, Y_i) - 2 \\sum_i k(X_i, Y_i) \\right)\n",
";\\end{align*}\n",
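"the first two terms of the correction add up to the trace of $K$, which doesn't depend on the particular permutation.\n",
"The third term does, since it pairs the $i$th point assigned to $X$ with the $i$th point assigned to $Y$, but it is just a sum of $n$ entries of $K$, so it isn't so bad."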
]
},
{
"cell_type": "markdown",
"id": "a23f6a33-72d7-4549-af56-313e5b11b9cf",
"metadata": {},
"source": [
"## p-value\n",
"Following Hemerik and Goeman (TEST 2018), [Exact testing with random permutations](https://arxiv.org/abs/1411.7565),\n",
"let $T_1, T_2, \\dots, T_w$ be the $w$ permuted test statistics returned by the previous procedure,\n",
"where $T_1$ is the statistic on the actual data split\n",
"and $T_2$ through $T_w$ come from uniformly random permutations.\n",
"\n",
"Their Theorem 2 establishes that a test which rejects when $T_1 > T^{\\left( \\lceil (1-\\alpha) w \\rceil \\right)}$, where $T^{(i)}$ is the $i$th smallest of $T_1, \\dots, T_w$, has level $\\alpha$.\n",
"We can find a (possibly conservative) $p$-value as the smallest value of $\\alpha$ for which this test can reject.\n",
"\n",
"Let $k = \\max \\{ j : T^{(j)} < T_1 \\}$, with $k = 0$ if that set is empty (in which case the $p$-value is 1).\n",
"Since the order statistics are sorted, $T^{(j)} < T_1$ exactly when $j \\le k$,\n",
"so the test rejects iff $\\lceil (1-\\alpha) w \\rceil \\le k$, i.e. (as $k$ is an integer) iff $(1-\\alpha) w \\le k$.\n",
"The smallest such $\\alpha$ has\n",
"$(1 - \\alpha) w = k$,\n",
"i.e. $\\alpha = (w - k) / w$.\n",
"Note that $w - k$ is exactly $\\lvert \\{ j \\in [w] : T_j \\ge T_1 \\} \\rvert$, which always includes $j = 1$."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "01140287-8c45-4dff-b1d4-cddd0f3db08a",
"metadata": {},
"outputs": [],
"source": [
"import torch"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "8e31cd9f-392e-4c09-821e-bd7827a7379c",
"metadata": {},
"outputs": [],
"source": [
"from collections import namedtuple\n",
"\n",
"PermutationResult = namedtuple(\n",
" \"PermutationResult\", [\"estimate\", \"p_value\", \"permuted_estimates\"]\n",
")\n",
"\n",
"\n",
"def mmd2_permutation(joint_kernel, n_X, n_perm=500, u_stat=False):\n",
"    \"\"\"\n",
"    joint_kernel: should be an array of shape [n_X + n_Y, n_X + n_Y] with kernel values,\n",
"        ie [ K_XX K_XY ]\n",
"           [ K_YX K_YY ]\n",
"    n_X: number of entries in the first set\n",
"    n_perm: total number of permutations, including the identity\n",
"    u_stat: which estimator to use (see below)\n",
"\n",
"    If u_stat is False (the default), uses the plug-in estimator (MMD between empirical distributions).\n",
"    If True, it uses the U-statistic estimator, which requires n_X == n_Y and is unbiased but drops k(x_i, y_i) terms.\n",
"    (I'm not sure how to implement the \"unbiased estimator\" (which includes those terms) efficiently.)\n",
"\n",
"    Returns a PermutationResult(estimate, p_value, permuted_estimates),\n",
"    where permuted_estimates[0] is the estimate on the actual split.\n",
"    \"\"\"\n",
"    K = joint_kernel = torch.as_tensor(joint_kernel)\n",
"    device = K.device\n",
"    dtype = K.dtype\n",
"\n",
"    n = K.shape[0]\n",
"    if K.shape != (n, n):\n",
"        raise ValueError(f\"joint_kernel should be square, got {K.shape}\")\n",
"    n_X = int(n_X)\n",
"    n_Y = n - n_X\n",
"    if n_X <= 0 or n_Y <= 0:\n",
"        raise ValueError(\"need a positive number of samples from each\")\n",
"\n",
"    if u_stat:\n",
"        if n_X != n_Y:\n",
"            raise ValueError(\"u-stat estimator only defined for equal sample sizes\")\n",
"        w_X = 1\n",
"        w_Y = -1\n",
"    else:\n",
"        w_X = 1 / n_X\n",
"        w_Y = -1 / n_Y\n",
"\n",
"    # construct permutations: the identity (the actual split) first, then random ones\n",
"    # there probably should be a faster way to do this but, idk\n",
"    perms = torch.stack(\n",
"        [torch.arange(n, device=device)]\n",
"        + [torch.randperm(n, device=device) for _ in range(n_perm - 1)]\n",
"    )\n",
"    X_inds = perms[:, :n_X]\n",
"    Y_inds = perms[:, n_X:]\n",
"\n",
"    # set weights to w_X for things in X_inds, w_Y for others\n",
"    ws = torch.full((n_perm, n), w_Y, device=device, dtype=dtype)\n",
"    ws.scatter_(1, X_inds, w_X)\n",
"\n",
"    # the \"basic\" estimate; either the biased est or a constant times it\n",
"    ests = torch.einsum(\"pi,ij,pj->p\", ws, joint_kernel, ws)\n",
"\n",
"    if u_stat:\n",
"        # need to subtract \\sum_i k(X_i, X_i) + k(Y_i, Y_i) - 2 k(X_i, Y_i)\n",
"        # first two are just trace\n",
"        # for the last one, we need to see which ones were lined up\n",
"        # NOTE: take() makes an unnecessary copy if joint_kernel isn't already contiguous,\n",
"        # but this generally shouldn't be a big deal\n",
"        cross_terms = joint_kernel.take(X_inds * n + Y_inds).sum(1)\n",
"        ests = (ests - joint_kernel.trace() + 2 * cross_terms) / (n_X * (n_X - 1))\n",
"\n",
"    # p-value: fraction of permuted statistics at least as large as the observed one\n",
"    p_val = (ests >= ests[0]).float().mean()\n",
"    return PermutationResult(ests[0], p_val, ests)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "5e149931-01e5-426a-9a90-bef389e14a4f",
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "c45c7919-859e-434e-9679-f40ab7b12dbe",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor(0.0660)"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
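"# toy data: Y is drawn with a larger scale than X, so the null hypothesis is false\n",
"X = torch.randn(100, 3)\n",
"Y = torch.randn(100, 3) * 1.3\n",
"Z = torch.cat((X, Y))\n",
"# Gaussian kernel with bandwidth 1: k(z, z') = exp(-||z - z'||^2 / 2)\n",
"K = torch.exp(-0.5 * torch.cdist(Z, Z) ** 2)\n",
"\n",
"res = mmd2_permutation(K, X.shape[0], u_stat=True)\n",
"\n",
"# histogram of the permuted statistics, with the observed statistic in red\n",
"plt.hist(res.permuted_estimates, bins=\"auto\")\n",
"plt.axvline(res.estimate, color=\"r\")\n",
"res.p_value"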
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.7"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
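This is a quick pure-Python check of the quadratic-form identity for the biased estimator in the notebook's "Computation" cell. It also checks the claim that permuting the weight vector, while keeping $K$ fixed, gives the estimator on the relabelled split. It is a sketch, not code from the notebook. It assumes scalar samples and a Gaussian kernel with bandwidth 1, and the helper names `gaussian_kernel`, `mmd2_biased_sums`, and `quadform` are hypothetical.

```python
import math
import random


def gaussian_kernel(a, b):
    # hypothetical kernel choice for this sketch: k(a, b) = exp(-(a - b)^2 / 2)
    return math.exp(-0.5 * (a - b) ** 2)


def mmd2_biased_sums(X, Y):
    # plug-in estimator written out as the three double sums
    m, n = len(X), len(Y)
    kxx = sum(gaussian_kernel(a, b) for a in X for b in X) / m**2
    kyy = sum(gaussian_kernel(a, b) for a in Y for b in Y) / n**2
    kxy = sum(gaussian_kernel(a, b) for a in X for b in Y) / (m * n)
    return kxx + kyy - 2 * kxy


def quadform(K, w):
    # w^T K w for a kernel matrix K given as a list of rows
    N = len(w)
    return sum(w[i] * K[i][j] * w[j] for i in range(N) for j in range(N))


random.seed(0)
X = [random.gauss(0, 1) for _ in range(5)]
Y = [random.gauss(0, 1.3) for _ in range(4)]
Z = X + Y
m, n = len(X), len(Y)
K = [[gaussian_kernel(a, b) for b in Z] for a in Z]

# actual split: weight 1/m on the X points and -1/n on the Y points
w = [1 / m] * m + [-1 / n] * n
assert math.isclose(quadform(K, w), mmd2_biased_sums(X, Y), abs_tol=1e-12)

# a permuted split: the first m entries of perm are relabelled as "X";
# only the weights move, K stays fixed
perm = list(range(m + n))
random.shuffle(perm)
w_perm = [-1 / n] * (m + n)
for idx in perm[:m]:
    w_perm[idx] = 1 / m
X_perm = [Z[i] for i in perm[:m]]
Y_perm = [Z[i] for i in perm[m:]]
assert math.isclose(quadform(K, w_perm), mmd2_biased_sums(X_perm, Y_perm), abs_tol=1e-12)
```

The loop that fills `w_perm` plays the role that `ws.scatter_(1, X_inds, w_X)` plays for all permutations at once in `mmd2_permutation`.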
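The notebook says the usual unbiased estimator $\widehat{\operatorname{MMD}_u^2}$ doesn't permute as neatly as the biased one, and the code cell leaves its efficient implementation open. Here is one possible approach, offered as a sketch rather than anything from the notebook. Let $a \in \{0, 1\}^{m+n}$ be the indicator of the points assigned to $X$, and let $r = K \mathbf 1$ be the row sums of $K$. Then every block sum the estimator needs is a function of $a^\top K a$, $a^\top r$, and $\sum_i a_i K_{ii}$, plus quantities that depend only on $K$. So each permutation should cost one quadratic form and two dot products. The code is pure Python for easy checking, assumes $m, n \ge 2$, and uses hypothetical function names.

```python
import math
import random


def gaussian_kernel(a, b):
    # hypothetical kernel choice for this sketch: k(a, b) = exp(-(a - b)^2 / 2)
    return math.exp(-0.5 * (a - b) ** 2)


def mmd2_unbiased_direct(X, Y):
    # the usual unbiased estimator, straight from its definition
    m, n = len(X), len(Y)
    kxx = sum(gaussian_kernel(X[i], X[j]) for i in range(m) for j in range(m) if i != j)
    kyy = sum(gaussian_kernel(Y[i], Y[j]) for i in range(n) for j in range(n) if i != j)
    kxy = sum(gaussian_kernel(x, y) for x in X for y in Y)
    return kxx / (m * (m - 1)) + kyy / (n * (n - 1)) - 2 * kxy / (m * n)


def mmd2_unbiased_from_indicator(K, a, row_sums, total, trace):
    # a[i] is 1 if point i is assigned to X and 0 otherwise;
    # row_sums, total and trace depend only on K, so they are shared across permutations
    N = len(a)
    m = sum(a)
    n = N - m
    s_xx = sum(a[i] * K[i][j] * a[j] for i in range(N) for j in range(N))  # a^T K a
    a_r = sum(a[i] * row_sums[i] for i in range(N))  # a^T K 1
    d_x = sum(a[i] * K[i][i] for i in range(N))  # diagonal entries inside the X block
    s_xy = a_r - s_xx  # a^T K (1 - a)
    s_yy = total - 2 * a_r + s_xx  # (1 - a)^T K (1 - a), using symmetry of K
    d_y = trace - d_x
    return (s_xx - d_x) / (m * (m - 1)) + (s_yy - d_y) / (n * (n - 1)) - 2 * s_xy / (m * n)


random.seed(1)
X = [random.gauss(0, 1) for _ in range(5)]
Y = [random.gauss(0, 1.3) for _ in range(6)]
Z = X + Y
m, N = len(X), len(Z)
K = [[gaussian_kernel(p, q) for q in Z] for p in Z]
row_sums = [sum(row) for row in K]
total = sum(row_sums)
trace = sum(K[i][i] for i in range(N))

a = [1] * m + [0] * (N - m)
est = mmd2_unbiased_from_indicator(K, a, row_sums, total, trace)
assert math.isclose(est, mmd2_unbiased_direct(X, Y), abs_tol=1e-12)

perm = list(range(N))
random.shuffle(perm)
a_perm = [0] * N
for idx in perm[:m]:
    a_perm[idx] = 1
est_perm = mmd2_unbiased_from_indicator(K, a_perm, row_sums, total, trace)
X_perm = [Z[i] for i in perm[:m]]
Y_perm = [Z[i] for i in perm[m:]]
assert math.isclose(est_perm, mmd2_unbiased_direct(X_perm, Y_perm), abs_tol=1e-12)
```

In a vectorized version, the $a^\top K a$ terms for all permutations could be computed at once with the same `einsum` pattern the notebook applies to `ws`, using a 0/1 indicator matrix in place of the weights.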
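This is a pure-Python check of the notebook's U-statistic identity. It follows the same route as the `u_stat` branch of `mmd2_permutation`:

- a quadratic form with weights $+1$ on the points assigned to $X$ and $-1$ on those assigned to $Y$,
- minus the trace,
- plus twice the lined-up cross terms $\sum_i k(X_i, Y_i)$, gathered through row-major flat indices `x * N + y`, as `take()` does in the notebook.

It is a sketch that assumes scalar samples, equal sample sizes, and a Gaussian kernel; the function names are hypothetical.

```python
import math
import random


def k(a, b):
    # hypothetical kernel choice for this sketch: k(a, b) = exp(-(a - b)^2 / 2)
    return math.exp(-0.5 * (a - b) ** 2)


def mmd2_U_direct(X, Y):
    # U-statistic straight from its definition: sum over index pairs i != j
    n = len(X)
    total = 0.0
    for i in range(n):
        for j in range(n):
            if i != j:
                total += k(X[i], X[j]) + k(Y[i], Y[j]) - k(X[i], Y[j]) - k(X[j], Y[i])
    return total / (n * (n - 1))


def mmd2_U_notebook_route(K, perm):
    # one permutation's worth of the u_stat branch of mmd2_permutation
    N = len(K)
    n = N // 2
    X_inds, Y_inds = perm[:n], perm[n:]
    w = [0] * N
    for i in X_inds:
        w[i] = 1
    for i in Y_inds:
        w[i] = -1
    quad = sum(w[i] * K[i][j] * w[j] for i in range(N) for j in range(N))  # n^2 * biased estimate
    trace = sum(K[i][i] for i in range(N))
    flat = [K[i][j] for i in range(N) for j in range(N)]  # row-major flattening, as take() sees K
    cross = sum(flat[x * N + y] for x, y in zip(X_inds, Y_inds))  # sum_i k(X_i, Y_i)
    return (quad - trace + 2 * cross) / (n * (n - 1))


random.seed(3)
n = 6
X = [random.gauss(0, 1) for _ in range(n)]
Y = [random.gauss(0, 1.3) for _ in range(n)]
Z = X + Y
K = [[k(p, q) for q in Z] for p in Z]

identity = list(range(2 * n))
assert math.isclose(mmd2_U_notebook_route(K, identity), mmd2_U_direct(X, Y), abs_tol=1e-12)

perm = identity[:]
random.shuffle(perm)
X_perm = [Z[i] for i in perm[:n]]
Y_perm = [Z[i] for i in perm[n:]]
assert math.isclose(mmd2_U_notebook_route(K, perm), mmd2_U_direct(X_perm, Y_perm), abs_tol=1e-12)
```

Because `cross` pairs `X_inds[i]` with `Y_inds[i]`, reordering points within a group generally changes the U-statistic, even though it leaves the biased estimate unchanged.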
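The notebook's $p$-value argument can be checked mechanically. The sketch below is not from the notebook; `permutation_p_value` and `rejects` are hypothetical names. It computes $p = \lvert \{ j : T_j \ge T_1 \} \rvert / w$, which is the quantity the last lines of `mmd2_permutation` compute. It then confirms, on random integer-valued statistics where ties are common, that the order-statistic test rejects at $\alpha = p$ but not at any slightly smaller $\alpha$. Exact `Fraction` arithmetic keeps $\lceil (1 - \alpha) w \rceil$ from being thrown off by floating-point rounding.

```python
import math
import random
from fractions import Fraction


def permutation_p_value(stats):
    # stats[0] is T_1 (the actual split); stats[1:] come from random permutations.
    # Same quantity as (ests >= ests[0]).float().mean() in the notebook, kept exact here.
    t1 = stats[0]
    return Fraction(sum(1 for t in stats if t >= t1), len(stats))


def rejects(stats, alpha):
    # reject when T_1 > T^(ceil((1 - alpha) w)), where T^(i) is the i-th smallest statistic
    assert 0 < alpha < 1
    w = len(stats)
    i = math.ceil((1 - alpha) * w)  # exact because alpha is a Fraction
    return stats[0] > sorted(stats)[i - 1]


random.seed(2)
for _ in range(500):
    w = random.randint(2, 30)
    # small integer statistics so that ties with T_1 actually happen
    stats = [random.randint(0, 10) for _ in range(w)]
    p = permutation_p_value(stats)
    if p < 1:
        assert rejects(stats, p)  # the test rejects at alpha = p ...
        assert not rejects(stats, p - Fraction(1, 2 * w))  # ... but not just below it
    else:
        # nothing is strictly below T_1, so no alpha < 1 lets the test reject
        assert not any(rejects(stats, Fraction(j, w)) for j in range(1, w))
```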