Skip to content

Instantly share code, notes, and snippets.

@SomeoneSerge
Last active October 5, 2020 19:42
Show Gist options
  • Save SomeoneSerge/9663aca1ea70b03ec52385eb881a30ff to your computer and use it in GitHub Desktop.
Save SomeoneSerge/9663aca1ea70b03ec52385eb881a30ff to your computer and use it in GitHub Desktop.
Fair inverse transform implementation
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Inverse transform"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"See evaluated notebook at: https://gist.github.com/newkozlukov/9663aca1ea70b03ec52385eb881a30ff"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import numpy as np\n",
"\n",
"\n",
"import matplotlib.pyplot as plt"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Example CDF"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"torch.manual_seed(42)\n",
"\n",
"cdf = torch.tensor([0, 0, 0, .1, .1, .2, .75, 1.0])\n",
"t_left = torch.arange(len(cdf), dtype=cdf.dtype)\n",
"sizes = torch.ones_like(t_left)\n",
"t = t_left + torch.rand_like(sizes) * sizes\n",
"\n",
"plt.fill_between(torch.cat((t_left, (t_left + sizes)[-1:])),\n",
" torch.cat((cdf, cdf[-1:])),\n",
" alpha=.5, color=\"orange\", step=\"post\", label=\"CDF\")\n",
"plt.scatter(t, cdf, color=\"red\", label=\"Sampled $t$'s\")\n",
"plt.vlines(t, 0, cdf, color=\"red\")\n",
"plt.vlines(t_left, 0, cdf,\n",
" color=\"magenta\", linestyles=\"dashed\", linewidth=.5,\n",
" label=\"$[t_{\\\\mathrm{left}}..t_{\\\\mathrm{right}}]$\")\n",
"plt.vlines((t_left + sizes)[-1:], 0, cdf[-1:],\n",
" color=\"magenta\", linestyles=\"dashed\", linewidth=.5)\n",
"#plt.step(t, cdf, color=\"red\", where=\"mid\")\n",
"plt.title(\"CDF\")\n",
"plt.xlabel(\"$t$\")\n",
"plt.ylabel(\"$u$\")\n",
"plt.yticks(cdf)\n",
"plt.xticks(torch.sort(torch.cat((t_left, t, (t_left + sizes)[-1:]))).values, rotation=90)\n",
"_ = plt.legend()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The first bucket with non-zero probability mass is $[t_3..t_4]$ (its mass is $\\mathrm{cdf}_3 - \\mathrm{cdf}_2 = 0.1$), so the smallest value we should ever sample is the left bound of that bucket, that is $t = 3$"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Inverse transform implementation"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"def inverse_transform(t_left, sizes, cdf, min_divisor=1e-6):\n",
" \"\"\"Assuming sample dimension is the innermost, as required by :func:`searchsorted`\"\"\"\n",
" n_samples = t_left.shape[-1]\n",
" \n",
" def transform(u):\n",
" \"\"\"transform(u)\n",
" - :obj:`u` uniform sample of type :obj:`torch.Tensor`\"\"\"\n",
" \n",
" iright = torch.searchsorted(cdf, u, right=True).clamp(1, n_samples - 1)\n",
" ileft = (iright - 1).clamp(0, n_samples - 2) # we must put $u$ _before_ found indices\n",
" \n",
" tleft = torch.gather(t_left, -1, ileft) + torch.gather(sizes, -1, ileft)\n",
" s = torch.gather(sizes, -1, iright)\n",
" qleft = torch.gather(cdf, -1, ileft)\n",
" qright = torch.gather(cdf, -1, iright)\n",
" \n",
" u = (u - qleft) / (qright - qleft).clamp(min_divisor)\n",
" t = tleft + u * s\n",
" return t\n",
" \n",
" return transform"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Why `right=True`"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"With `right=False`, if we sample $u=0$ (e.g. deterministically, via `linspace`), `searchsorted` returns the index $0$ (see the first output below). After clamping that index to $1$, the sample ends up at $t = t_{\\mathrm{left}}[0] + \\mathrm{sizes}[0] = 1$, i.e. in the interval $[1, 2]$, which is known not to contain any probability mass (the first interval that does have mass starts at `t_left[2] + sizes[2]` $= 3$)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor([0.0000, 0.0000, 0.0000, 0.1000, 0.1000, 0.2000, 0.7500, 1.0000])\n",
"tensor([0, 3, 3, 6, 6, 7])\n"
]
}
],
"source": [
"print(cdf)\n",
"print(torch.searchsorted(cdf, torch.tensor([0, .05, .1, .7, .75, 1.0])))"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor([0.0000, 0.0000, 0.0000, 0.1000, 0.1000, 0.2000, 0.7500, 1.0000])\n",
"tensor([3, 3, 5, 6, 7, 8])\n"
]
}
],
"source": [
"print(cdf)\n",
"print(torch.searchsorted(cdf, torch.tensor([0, .05, .1, .7, .75, 1.0]), right=True))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Eval"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"T = inverse_transform(t_left, sizes, cdf)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 720x360 with 2 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"u = torch.linspace(0, 1, 1000)\n",
"\n",
"f, (ax_t_of_u, ax_u_of_t) = plt.subplots(1, 2, figsize=(10, 5))\n",
"\n",
"ax_t_of_u.scatter(u, T(u))\n",
"ax_t_of_u.set_xlabel(\"$u$\")\n",
"ax_t_of_u.set_ylabel(\"$T(u)$\")\n",
"ax_t_of_u.hlines(t_left, 0, 1, linewidth=0.5, label=\"t_left, t_right\", color=\"magenta\")\n",
"ax_t_of_u.set_xticks(cdf)\n",
"\n",
"ax_u_of_t.scatter(T(u), u, )\n",
"ax_u_of_t.set_xlabel(\"$t$\")\n",
"ax_u_of_t.set_ylabel(\"$T^{-1}(t)$\")\n",
"ax_u_of_t.set_yticks(cdf)\n",
"ax_u_of_t.vlines(t_left, 0, 1, linewidth=0.5, label=\"t_left, t_right\", color=\"magenta\")\n",
"_= f.suptitle(\"Uniformly sampled $u$'s and corresponding $T(u)$'s\")"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[0.0000, 1.0000, 2.0000, 3.0000, 4.0000, 5.0000, 6.0000, 7.0000],\n",
" [0.0000, 1.0000, 2.0000, 3.0000, 4.0000, 5.0000, 6.0000, 7.0000],\n",
" [1.0000, 2.0000, 3.0000, 4.0000, 5.0000, 6.0000, 7.0000, 8.0000],\n",
" [0.8823, 1.9150, 2.3829, 3.9593, 4.3904, 5.6009, 6.2566, 7.7936],\n",
" [0.0000, 0.0000, 0.0000, 0.1000, 0.1000, 0.2000, 0.7500, 1.0000]])"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"torch.stack((torch.arange(len(cdf)), t_left, t_left + sizes, t, cdf))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The $u$-values in $[0, 0.1)$ are mapped linearly onto the $t$-values in $[3, 4)$"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([3.0000, 3.0100, 4.0000])"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"T(torch.tensor([0, .001, 0.1-1e-8]))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This is because the `searchsorted(cdf, ., right=True)` query returns the same bucket index, $3$, for all of these values:"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([3, 3, 3])"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"torch.searchsorted(cdf, torch.tensor([0, .001, 0.1-1e-8]), right=True)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Note the jump at the right boundary of this bucket: just below $u=0.1$ the returned index is $3$, while at $u=0.1$ it is already $5$"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([3, 5])"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"torch.searchsorted(cdf, torch.tensor([.1-1e-8, .1]), right=True)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The result is that the query for $u=0.1$ skips the interval $[4,5]$ (index $4$), which has the probability mass $\\mathrm{cdf}_4 - \\mathrm{cdf}_3 = 0$, and lands directly at $t=5$"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([5.0000, 5.9990, 6.0000, 6.9818])"
]
},
"execution_count": 23,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"T(torch.tensor([.1, .1999, .2, .74]))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Eval stats"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we sample $t$'s using the inverse transform and compare their empirical PMF/CDF with the true CDF"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA6IAAAILCAYAAADynCEVAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/d3fzzAAAACXBIWXMAAAsTAAALEwEAmpwYAAAg0UlEQVR4nO3df5Tld13f8deb3WD4HSUrP/KDiTVQIwrS3Yj1RxFEEhCitrZEK8rRppxD/FGOQrTU1h9tUauHQ0FzUkRU1KAYadAo1uNBpQpkEQTCD13Cwm7Cjw0gQmQbA+/+ce+yN5PZmbu7dz93ZvbxOGdO5t77nXvf90fms8/5fudOdXcAAABglHssewAAAABOL0IUAACAoYQoAAAAQwlRAAAAhhKiAAAADCVEAQAAGEqIAgAAMJQQBQAAYCghCrBFVNVNVfW4k91mjtvZX1XfcDLXsRlV1cur6qeWPcc8quoRVfWWqvpkVX3/nF+zrZ63k7k/W+m5BjhdCVGAkzT9B/Onq+pTMx8vXvTtdPeXdvfrTnabk7Hqvn64qn65qu47c9kdVXX2qq95a1V1Va0c43qOfDz0VM09r6r69qraO53ng1X1B1X1Natm/mRV/V1V/UVVPauq7jHz9Yu6X89N8rruvl93v+gYs26r8Bypqm6pqkfPnN4/+/oE4NQTogCL8dTuvu/Mx5Ujb7yqdg68uad2932TPCbJniTPn7nsfUkun5nry5Lca73rmfm49ZRNPIeqek6SFyb5b0kelOT8JL+Q5LKZzZ7a3fdL8rAkL0jyvCS/tOqqFnG/HpbkphP4OjYw/UHJFyZ517JnATidCVGAU2i6p+WHq+ptVXV7Vf1SVT1ouqftk1X1x1X1+au2/5GqemdVfXy6x/HMmcu+YdW2z6uqtyW5vap2zm5TVedV1XVVdaiqPjq7l7aqrqqq905neGdVfcvx3rfuviXJHyR55MzZv5bkGTOnvyvJrx7vdc875/T+/tD08f1EVb1y5vH6iqr6q+nXvjLJmevczgOS/ESSZ3f3dd19e3f/Y3e/prt/eI37/onuvj7Jv0nyXVX1yNXbzHHfvqSqXjfdu3pTVT1tev6fJPn6JC+e7lF9+Bpf+2uZhPJrpts8d3rRo9d6LKZf89Cq+p3p6+F9tc4hv9PX1S3Tx+49VfWE6fnzPB9zvd7Xe62vMc+6s8/7XFfVFyc5kMm/fz46/f9i56pt1rzvACyWEAU49f5lkicmeXiSp2YSbz+a5OxMvg+vDoLvSPKkJP9k+jXPz7FdnuQpSc7q7juPnFlVO5L8XpL3J1lJck6Sa2e+7r1JvjbJA5L8eJJXVNVDjudOVdV5SZ6c5C0zZ78hyf2nkbUjk1B7xfFc7yrzzPmvk1yS5IIkX57ku6vqnklenUkYf0GS387keTiWr8okXn73eIbr7jclOTidcW5VdUaS1yT5o0z2zn1fkl+vqkd09+OT/HmSK6d7VP9mjdv9ziQfyNG9rz8zvehuj8X09u4xvb2/zuS18IQkP1hVT1pjtkckuTLJnune3ycl2T+9eJ7n43he7xu+1jea/Xie6+7el+SHkrxq+rg9sLvv7O6V7t6/wX0HYIGEKMBivHq6Z+vIx7+buex/dveHp3sQ/zzJG7v7Ld39/zIJn69YdV0v7u4D3f2xJP81M4e6ruFF020/ver8i5M8NMkPT/fuHe7u1x+5sLt/u7tv7e7Pdvcrk/zt9Gvmvq9JXp/kTzM5lHXWkb2iT0zy7iS3rHc9049Xr7XBnHO+aLrNxzIJlkcneWySM5K8cLpn81VJblznPj0wyW2zMX8cbs0kgOa+X9P57pvkBd19R3f/SSY/OFjvuZ7HWo9FMjmEeld3/8T09m5O8r+SPH2N6/hMks9LclFVndHd+7v7vcncz8fxvN7nea1vNPvxPtePSvLWY1x2zPsOwGKN/J0igO3sm7v7j49x2YdnPv/0Gqfvu2r7AzOfvz+ToDyWA8
c4/7wk7z9WWFXVM5I8J5O9pZnOcPZa265hvfuaTEL0zzLZK7feYbkbXc+8c35o5vN/yOTxemiSW7q7Zy57/zo39dEkZ1fVzhOI0XOSfGzm9Ib3azrfge7+7Kr5zjnO215trccimfzO6UOnP0A4YkcmoXgX3b2vqn4wyX9J8qVV9dokz+nuW+d8Po7n9T7Pa32j2Y/3uX50JntQ72a9+77O9QFwAuwRBdh8zpv5/PxM9rgdSx/j/ANJzl/9+29JUlUPy2SP0pVJHtjdZyV5R5I6oWlXD9T9/kzetOjJSa470es5yTk/mOScqprd9vx1tv/LJIeTfPNxzrgnk3h8/UbbrnJrkvNq5h13p/Mda+/xWo713K/lQJL3dfdZMx/36+4nr3nF3b/R3V+TSQR2kp8+Ra+beV7rG80+93M9fbwfmWPvEV3zvs99bwCYmxAF2HyeXVXnVtUXZPK7da88get4Uyb/QH9BVd2nqs6sqq+eXnafTP6BfShJquqZuesbDi3C9yR5fHfffhLXcTJz/mWSO5N8f03exOlbs86hx939iSQ/luQlVfXNVXXvqjqjqi6tqp9ZvX1V3b+qvimT37t9RXe//bjuWfLGJLcnee70dh6Xye9TXrveF63y4SRfNOe2b0ry99M34rlXVe2oqkdOQ/ouavI3TB9fVZ+XSZx/OpNDVk/F62ae1/pGsx/Pc32v6cea//5Z574DsGBCFGAxjrx76ZGP43rTm1V+I5M3sbl5+vFTx3sF3f2ZTMLmizN5U5uDmbxxULr7nUl+LpN/wH84yZcl+b8nMe9at//e7t57ktdxwnN29x1JvjWTN+v5eCb3fd29s93985kcdvr8TGLrQCZ7/149s9lrquqT08v+Y5KfT/LMOe/S6vmeluTSJLdl8mdintHd7z6Oq/nvSZ4//V3UH9rg9o68Hh6dyd7q25K8NJM3HVrt8zL50zS3ZXKo7xcm+dFT9LrZ8LW+0ezH81xPfzBydZJ3VtXBNTZZ876f4H0DYB1111+pAGCZqmp/ku+d43cMYUvzWgc4vdkjCgAAwFBCFAAAgKEcmgsAAMBQ9ogCAAAwlBAFAABgKCEKAADAUEIUAACAoYQoAAAAQwlRAAAAhhKiAAAADCVEAQAAGEqIAgAAMJQQBQAAYCghCgAAwFBCFAAAgKGEKAAAAEMJUQAAAIYSogAAAAwlRAEAABhKiAIAADCUEAUAAGAoIQoAAMBQQhQAAIChhCgAAABDCVEAAACGEqIAAAAMJUQBAAAYSogCAAAwlBAFAABgKCEKAADAUEIUAACAoYQoAAAAQwlRAAAAhhKiAAAADCVEAQAAGEqIAgAAMJQQBQAAYCghCgAAwFBCFAAAgKGEKAAAAEMJUQAAAIYSogAAAAwlRAEAABhKiAIAADDUzmXd8Nlnn90rKyvLunkAtpk3v/nNt3X3rmXPsZVZmwFYpPXW5qWF6MrKSvbu3busmwdgm6mq9y97hq3O2gzAIq23Njs0FwAAgKGEKAAAAEMJUQAAAIYSogAAAAwlRAEAABhKiAIAADCUEAUAAGAoIQoAAMBQQhQAAIChhCgAAABDCVEAAACGEqIAAAAMJUQBAAAYSogCAAAwlBAFAABgKCEKAADAUBuGaFW9rKo+UlXvOMblVVUvqqp9VfW2qnrM4scEAI6wNgOw1c2zR/TlSS5Z5/JLk1w4/bgiyS+e/FgAwDpeHmszAFvYhiHa3X+W5GPrbHJZkl/tiTckOauqHrKoAQGAu7I2A7DV7VzAdZyT5MDM6YPT8z64esOquiKTn8zm/PPPX8BNH4errko+9KGxtwmwVT34wckLXrDsKThx1maAAVYe/G3LHuGU2P+Cp5zy21hEiNYa5/VaG3b3NUmuSZLdu3evuc0p86EPJSsrQ28SYMvav3/ZE3ByrM0AIxxe9gBb1yLeNfdgkvNmTp+b5NYFXC8AcGKszQBsaosI0euTPGP6Dn2PTfKJ7r7boT8AwDDWZgA2tQ
0Pza2q30zyuCRnV9XBJP85yRlJ0t1XJ7khyZOT7EvyD0meeaqGBQCszQBsfRuGaHdfvsHlneTZC5sIAFiXtRmArW4Rh+YCAADA3IQoAAAAQwlRAAAAhhKiAAAADCVEAQAAGEqIAgAAMJQQBQAAYCghCgAAwFBCFAAAgKGEKAAAAEMJUQAAAIYSogAAAAwlRAEAABhKiAIAADCUEAUAAGAoIQoAAMBQQhQAAIChhCgAAABDCVEAAACGEqIAAAAMJUQBAAAYSogCAAAw1M5lDwAAADDayuE9yx7htGaPKAAAAEMJUQAAAIYSogAAAAwlRAEAABhKiAIAADCUEAUAAGAoIQoAAMBQQhQAAIChhCgAAABDCVEAAACGEqIAAAAMJUQBAAAYSogCAAAwlBAFAABgKCEKAADAUEIUAACAoYQoAAAAQwlRAAAAhhKiAAAADCVEAQAAGEqIAgAAMJQQBQAAYCghCgAAwFBCFAAAgKGEKAAAAEMJUQAAAIbauewBAAAA5rVyeM+yR2AB7BEFAABgKCEKAADAUEIUAACAoYQoAAAAQwlRAAAAhhKiAAAADCVEAQAAGEqIAgAAMJQQBQAAYCghCgAAwFBCFAAAgKGEKAAAAEMJUQAAAIYSogAAAAwlRAEAABhKiAIAADCUEAUAAGConcseAAAA2P5WDu9Z9ghsIvaIAgAAMJQQBQAAYCghCgAAwFBCFAAAgKGEKAAAAEPNFaJVdUlVvaeq9lXVVWtc/oCqek1V/XVV3VRVz1z8qADAEdZmALayDUO0qnYkeUmSS5NclOTyqrpo1WbPTvLO7n5Ukscl+bmquueCZwUAYm0GYOubZ4/oxUn2dffN3X1HkmuTXLZqm05yv6qqJPdN8rEkdy50UgDgCGszAFvaPCF6TpIDM6cPTs+b9eIkX5Lk1iRvT/ID3f3Z1VdUVVdU1d6q2nvo0KETHBkATnvWZgC2tHlCtNY4r1edflKStyZ5aJJHJ3lxVd3/bl/UfU137+7u3bt27TrOUQGAKWszAFvaPCF6MMl5M6fPzeSnq7OemeS6ntiX5H1J/uliRgQAVrE2A7ClzROiNya5sKoumL7JwdOTXL9qmw8keUKSVNWDkjwiyc2LHBQA+BxrMwBb2s6NNujuO6vqyiSvTbIjycu6+6aqetb08quT/GSSl1fV2zM5XOh53X3bKZwbAE5b1mYAtroNQzRJuvuGJDesOu/qmc9vTfKNix0NADgWazMAW9k8h+YCAADAwghRAAAAhhKiAAAADCVEAQAAGEqIAgAAMJQQBQAAYCghCgAAwFBCFAAAgKGEKAAAAEMJUQAAAIYSogAAAAwlRAEAABhKiAIAADCUEAUAAGAoIQoAAMBQQhQAAIChhCgAAABDCVEAAACGEqIAAAAMtXPZAwAAAJvbyuE9yx6BbcYeUQAAAIYSogAAAAwlRAEAABhKiAIAADCUEAUAAGAoIQoAAMBQQhQAAIChhCgAAABDCVEAAACGEqIAAAAMJUQBAAAYSogCAAAwlBAFAABgKCEKAADAUEIUAACAoYQoAAAAQwlRAAAAhhKiAAAADCVEAQAAGEqIAgAAMJQQBQAAYCghCgAAwFBCFAAAgKGEKAAAAEMJUQAAAIYSogAAAAwlRAEAABhKiAIAADCUEAUAAGAoIQoAAMBQQhQAAIChhCgAAABDCVEAAACGEqIAAAAMJUQBAAAYSogCAAAwlBAFAABgqJ3LHgDYHFYO71n2CAu3/8wblz0CAABrsEcUAACAoYQoAAAAQwlRAAAAhhKiAAAADCVEAQAAGEqIAgAAMJQQBQAAYCghCgAAwFBCFAAAgKGEKAAAAEMJUQAAAIYSogAAAAwlRAEAABhKiAIAADCUEAUAAGAoIQoAAMBQc4VoVV1SVe+pqn1VddUxtnlcVb21qm6qqj9d7JgAwCxrMwBb2c6NNqiqHUlekuSJSQ4mubGqru/ud85sc1aSX0hySXd/oKq+8BTNCwCnPWszAF
vdPHtEL06yr7tv7u47klyb5LJV23x7kuu6+wNJ0t0fWeyYAMAMazMAW9o8IXpOkgMzpw9Oz5v18CSfX1Wvq6o3V9Uz1rqiqrqiqvZW1d5Dhw6d2MQAgLUZgC1tnhCtNc7rVad3JvlnSZ6S5ElJ/lNVPfxuX9R9TXfv7u7du3btOu5hAYAk1mYAtrgNf0c0k5+ynjdz+twkt66xzW3dfXuS26vqz5I8KsnfLGRKAGCWtRmALW2ePaI3Jrmwqi6oqnsmeXqS61dt87+TfG1V7ayqeyf5yiTvWuyoAMCUtRmALW3DPaLdfWdVXZnktUl2JHlZd99UVc+aXn51d7+rqv4wyduSfDbJS7v7HadycICNrBzes+wRPmf/mTcuewS2EWszAFvdPIfmprtvSHLDqvOuXnX6Z5P87OJGAwCOxdoMwFY2V4gCAABbz2Y6OghmzfM7ogAAALAwQhQAAIChhCgAAABDCVEAAACGEqIAAAAMJUQBAAAYSogCAAAwlBAFAABgKCEKAADAUEIUAACAoYQoAAAAQwlRAAAAhhKiAAAADCVEAQAAGEqIAgAAMJQQBQAAYCghCgAAwFBCFAAAgKGEKAAAAEMJUQAAAIYSogAAAAwlRAEAABhKiAIAADCUEAUAAGAoIQoAAMBQQhQAAIChhCgAAABDCVEAAACGEqIAAAAMJUQBAAAYSogCAAAwlBAFAABgKCEKAADAUEIUAACAoYQoAAAAQwlRAAAAhhKiAAAADCVEAQAAGEqIAgAAMJQQBQAAYCghCgAAwFBCFAAAgKGEKAAAAEMJUQAAAIYSogAAAAy1c9kDAAAAd7dyeM+yR4BTxh5RAAAAhhKiAAAADCVEAQAAGEqIAgAAMJQQBQAAYCghCgAAwFBCFAAAgKGEKAAAAEMJUQAAAIYSogAAAAwlRAEAABhKiAIAADCUEAUAAGAoIQoAAMBQQhQAAIChhCgAAABDCVEAAACGEqIAAAAMJUQBAAAYaueyBwBOzsrhPcseAQAAjos9ogAAAAwlRAEAABhKiAIAADCUEAUAAGAoIQoAAMBQQhQAAIChhCgAAABDzRWiVXVJVb2nqvZV1VXrbLenqj5TVf9qcSMCAKtZmwHYyjYM0arakeQlSS5NclGSy6vqomNs99NJXrvoIQGAo6zNAGx18+wRvTjJvu6+ubvvSHJtksvW2O77kvxOko8scD4A4O6szQBsafOE6DlJDsycPjg973Oq6pwk35Lk6vWuqKquqKq9VbX30KFDxzsrADBhbQZgS5snRGuN83rV6RcmeV53f2a9K+rua7p7d3fv3rVr15wjAgCrWJsB2NJ2zrHNwSTnzZw+N8mtq7bZneTaqkqSs5M8uaru7O5XL2JIAOAurM0AbGnzhOiNSS6sqguS3JLk6Um+fXaD7r7gyOdV9fIkv2ehA4BTxtoMwJa2YYh2951VdWUm77i3I8nLuvumqnrW9PJ1f/cEAFgsazMAW908e0TT3TckuWHVeWsuct393Sc/FgCwHmszAFvZPG9WBAAAAAsjRAEAABhKiAIAADCUEAUAAGAoIQoAAMBQQhQAAIChhCgAAABDCVEAAACGEqIAAAAMJUQBAAAYSogCAAAwlBAFAABgKCEKAADAUEIUAACAoYQoAAAAQwlRAAAAhhKiAAAADCVEAQAAGEqIAgAAMJQQBQAAYCghCgAAwFBCFAAAgKGEKAAAAEMJUQAAAIYSogAAAAwlRAEAABhKiAIAADCUEAUAAGAoIQoAAMBQQhQAAIChhCgAAABDCVEAAACG2rnsAQBOByuH9yzkevafeeNCrgcAYJmEKAAALNCifvgI25lDcwEAABhKiAIAADCUEAUAAGAoIQoAAMBQQhQAAIChhCgAAABDCVEAAACGEqIAAAAMJUQBAAAYSogCAAAwlBAFAABgKCEKAADAUEIUAACAoYQoAAAAQwlRAAAAhhKiAAAADCVEAQAAGEqIAgAAMNTOZQ8AwPxWDu8Zc0MP3p
Nc9ftDbmr/C54y5HYAgM3DHlEAAACGEqIAAAAMJUQBAAAYSogCAAAwlBAFAABgKCEKAADAUEIUAACAoYQoAAAAQwlRAAAAhhKiAAAADCVEAQAAGEqIAgAAMJQQBQAAYCghCgAAwFBCFAAAgKGEKAAAAEMJUQAAAIYSogAAAAwlRAEAABhKiAIAADCUEAUAAGCouUK0qi6pqvdU1b6qumqNy7+jqt42/fiLqnrU4kcFAI6wNgOwlW0YolW1I8lLklya5KIkl1fVRas2e1+Sf9HdX57kJ5Ncs+hBAYAJazMAW908e0QvTrKvu2/u7juSXJvkstkNuvsvuvvj05NvSHLuYscEAGZYmwHY0uYJ0XOSHJg5fXB63rF8T5I/WOuCqrqiqvZW1d5Dhw7NPyUAMMvaDMCWNk+I1hrn9ZobVn19Jovd89a6vLuv6e7d3b17165d808JAMyyNgOwpe2cY5uDSc6bOX1ukltXb1RVX57kpUku7e6PLmY8AGAN1mY4RVYO71n2CHBamGeP6I1JLqyqC6rqnkmenuT62Q2q6vwk1yX5zu7+m8WPCQDMsDYDsKVtuEe0u++sqiuTvDbJjiQv6+6bqupZ08uvTvJjSR6Y5BeqKknu7O7dp25sADh9WZsB2OrmOTQ33X1DkhtWnXf1zOffm+R7FzsaAHAs1mYAtrJ5Ds0FAACAhRGiAAAADCVEAQAAGEqIAgAAMJQQBQAAYKi53jUXAAA2s5XDe5Y9AnAc7BEFAABgKCEKAADAUEIUAACAoYQoAAAAQwlRAAAAhhKiAAAADCVEAQAAGEqIAgAAMJQQBQAAYCghCgAAwFA7lz0AAACnt5XDe5Y9AjCYPaIAAAAMJUQBAAAYSogCAAAwlBAFAABgKCEKAADAUEIUAACAoYQoAAAAQwlRAAAAhhKiAAAADCVEAQAAGEqIAgAAMNTOZQ8Ao60c3rPsEQAA4LS2LUJ05arf33ijB39bcvjUz7Io+8+8cdkjAACsyw93gRPl0FwAAACG2hZ7RLejRf2E0Z5VAABgs7FHFAAAgKHsEQUA2IDfhQRYLHtEAQAAGEqIAgAAMJRDc7e5RRxK5A2PAACARbJHFAAAgKGEKAAAAEMJUQAAAIYSogAAAAwlRAEAABjKu+YCANvaIt5BHoDFskcUAACAoYQoAAAAQzk0lw0t6pCm/WfeuJDrAQAAtjZ7RAEAABhKiAIAADCUEAUAAGAoIQoAAMBQQhQAAIChhCgAAABD+fMtAMCmtKg/HwbA5mOPKAAAAEMJUQAAAIYSogAAAAwlRAEAABhKiAIAADCUEAUAAGAof76FLcPb+AMAwPZgjygAAABDCVEAAACGEqIAAAAMJUQBAAAYypsVAQALt3LV72+80YO/LTl86mcBYPOxRxQAAIChhCgAAABDCVEAAACGEqIAAAAMJUQBAAAYSogCAAAwlBAFAABgKCEKAADAUEIUAACAoeYK0aq6pKreU1X7quqqNS6vqnrR9PK3VdVjFj8qAHCEtRmArWzDEK2qHUlekuTSJBclubyqLlq12aVJLpx+XJHkFxc8JwAwZW0GYKubZ4/oxUn2dffN3X1HkmuTXLZqm8uS/GpPvCHJWVX1kAXPCgBMWJsB2NJ2zrHNOUkOzJw+mOQr59jmnCQfnN2oqq7I5KeySfKpqnrPcU17Ys5Octv5yTlnJGcMuL1N6xPJmQ9IDi/r9r/4JL9+nhfrvJb9WGwmHoujPBZHfSw54+PJ3464rfrphV3VwxZ2TZuftXmb8H3nKI/FUR6LozwWR223tXmef9vXGuf1CWyT7r4myTVz3ObCVNXe7t498jY3q6rae8hjkcRjMctjcZTH4ijfOzc9a/M24fvOUR6LozwWR3ksjtpu3zvnOTT3YJLzZk6fm+TWE9gGAFgMazMAW9o8IXpjkgur6oKqumeSpye5ftU21yd5xvQd+h6b5BPd/cHVVwQALIS1GYAtbcNDc7v7zqq6Mslrk+xI8rLuvq
mqnjW9/OokNyR5cpJ9Sf4hyTNP3cjHbejhRpucx+Ioj8VRHoujPBZHeSw2MWvztuKxOMpjcZTH4iiPxVHb6rGo7rv9uggAAACcMvMcmgsAAAALI0QBAAAYatuGaFWdWVVvqqq/rqqbqurHlz3TMlXVjqp6S1X93rJnWbaq2l9Vb6+qt1bV3mXPs0xVdVZVvaqq3l1V76qqr1r2TMtQVY+Yvh6OfPx9Vf3gsudahqr6D9Pvme+oqt+sqjOXPRPbh7X5rqzNR1mbj7I2T1ibj9qua/O2/R3Rqqok9+nuT1XVGUlen+QHuvsNSx5tKarqOUl2J7l/d3/TsudZpqran2R3d9+27FmWrap+Jcmfd/dLp++8ee/u/rslj7VUVbUjyS1JvrK737/seUaqqnMy+V55UXd/uqp+K8kN3f3y5U7GdmFtvitr81HW5qOszXdnbd6ea/O23SPaE5+anjxj+rE9q3sDVXVukqckeemyZ2HzqKr7J/m6JL+UJN19x+m+0E09Icl7T7eFbsbOJPeqqp1J7h1/d5IFsjYfZW1mLdbmY7I2b8O1eduGaPK5Q17emuQjSf5Pd79xySMtywuTPDfJZ5c8x2bRSf6oqt5cVVcse5gl+qIkh5L88vTQsJdW1X2WPdQm8PQkv7nsIZahu29J8j+SfCDJBzP5u5N/tNyp2G6szZ/zwlibZ1mbJ6zNa7M2b8O1eVuHaHd/prsfneTcJBdX1SOXPNJwVfVNST7S3W9e9iybyFd392OSXJrk2VX1dcseaEl2JnlMkl/s7q9IcnuSq5Y70nJND4F6WpLfXvYsy1BVn5/ksiQXJHlokvtU1b9d7lRsN9Zma/MxWJsnrM2rWJu379q8rUP0iOkhDa9LcslyJ1mKr07ytOnvXlyb5PFV9YrljrRc3X3r9L8fSfK7SS5e7kRLczDJwZm9Ea/KZPE7nV2a5K+6+8PLHmRJviHJ+7r7UHf/Y5LrkvzzJc/ENmVttjbPsjZ/jrX57qzN23Rt3rYhWlW7quqs6ef3yuRJfPdSh1qC7v6R7j63u1cyOazhT7p7W/wU5URU1X2q6n5HPk/yjUnesdyplqO7P5TkQFU9YnrWE5K8c4kjbQaX5zQ99GfqA0keW1X3nr6pzBOSvGvJM7GNWJsnrM13ZW0+ytq8JmvzNl2bdy57gFPoIUl+ZfouW/dI8lvdfdq/PTp5UJLfnfx/nJ1JfqO7/3C5Iy3V9yX59elhLzcneeaS51maqrp3kicm+ffLnmVZuvuNVfWqJH+V5M4kb0lyzXKnYpuxNrMWa/NdWZunrM3be23etn++BQAAgM1p2x6aCwAAwOYkRAEAABhKiAIAADCUEAUAAGAoIQoAAMBQQhQAAIChhCgAAABD/X+CuSIQfx2OIgAAAABJRU5ErkJggg==\n",
"text/plain": [
"<Figure size 1152x576 with 2 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"torch.manual_seed(42)\n",
"u = torch.rand(1000)\n",
"Tu = T(u)\n",
"Tu = torch.sort(Tu).values\n",
"\n",
"# _1, _2 = np.histogram(Tu, bins=torch.cat((t_left, (t_left + sizes)[-1:])), normed=True)\n",
"\n",
"_1, _2 = np.histogram(Tu, bins=np.linspace(Tu.min(), Tu.max(), 20), density=True)\n",
"\n",
"f, (axd, axc) = plt.subplots(1, 2, figsize=(16, 8))\n",
"axd.bar(_2[1:], _1)\n",
"axc.bar(_2[1:], np.cumsum(_1 / _1.sum()))\n",
"axd.fill_between([Tu.min(), Tu.max()], 1, color=\"red\", alpha=.5)\n",
"axc.fill_between([Tu.min(), Tu.max()], 1, color=\"red\", alpha=.5)\n",
"\n",
"_ = f.suptitle(\"Empirical PMF and CDF of the sampled $t$'s\")"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 576x576 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"_1, _2 = np.histogram(Tu, bins=np.linspace(Tu.min(), Tu.max(), 2 + len(t_left)), density=True)\n",
"ecdf = np.cumsum(_1 / _1.sum())\n",
" \n",
"f, ax = plt.subplots(figsize=(8, 8))\n",
"ax.bar(_2[:-1], ecdf, alpha=.5, color=\"red\", label=\"ECDF of the sample\")\n",
"ax.bar(t_left, width=sizes, height=cdf, alpha=.75, color=\"green\", label=\"True CDF\")\n",
"_ = ax.legend()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "nerfs",
"language": "python",
"name": "nerfs"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.5"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment