Skip to content

Instantly share code, notes, and snippets.

@fepegar
Created February 10, 2021 15:52
Show Gist options
  • Save fepegar/9e029c85827f48360ecc1b62b2530fa7 to your computer and use it in GitHub Desktop.
Save fepegar/9e029c85827f48360ecc1b62b2530fa7 to your computer and use it in GitHub Desktop.
Test label sampler.ipynb
Display the source blob
Display the rendered blob
Raw
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "Test label sampler.ipynb",
"provenance": [],
"authorship_tag": "ABX9TyNrlDa49HLQ/jzNDBnNeFDL",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/fepegar/9e029c85827f48360ecc1b62b2530fa7/test-label-sampler.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"metadata": {
"id": "KrbGyxQCb62F"
},
"source": [
"# Install TorchIO, the library whose LabelSampler is exercised in the cells below\n",
"!pip install -q torchio"
],
"execution_count": 85,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "ETNalyapb8t3",
"outputId": "3cb5991a-7d4f-41a8-a1e5-52391274dcbc"
},
"source": [
"import torch\n",
"import torchio as tio\n",
"import matplotlib.pyplot as plt\n",
"\n",
"# Seed the RNG so the sampled fraction printed below is reproducible across runs\n",
"# (assumes the sampler draws from torch's global RNG - TODO confirm for this torchio version)\n",
"torch.manual_seed(0)\n",
"\n",
"# 5x5 single-slice label map with one foreground voxel at (2, 2)\n",
"size = 5\n",
"tensor = torch.zeros(1, size, size, 1)\n",
"tensor[0, 2, 2] = 1\n",
"label = tio.LabelMap(tensor=tensor)\n",
"\n",
"subject = tio.Subject(label=label)\n",
"patch_size = 2, 2, 1\n",
"# Equal weights for background (0) and foreground (1): the fraction of patches\n",
"# whose centre voxel is foreground is expected to be about 0.5\n",
"sampler = tio.LabelSampler(patch_size, label_probabilities={0: 1, 1: 1})\n",
"num_patches = 10000\n",
"# For the even 2x2x1 patch, [1, 1, 0] is presumably the voxel the sampler centres on - TODO confirm\n",
"values = torch.as_tensor([patch.label.data[0, 1, 1, 0] for patch in sampler(subject, num_patches)])\n",
"print(values.mean())"
],
"execution_count": 86,
"outputs": [
{
"output_type": "stream",
"text": [
"tensor(0.6166)\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 287
},
"id": "qWNBeroSctgu",
"outputId": "f06fa6be-b3b0-46b0-e32b-c73505a50ee9"
},
"source": [
"# Display the label map slice; imshow returns the image (not an Axes), which feeds the colorbar\n",
"image = plt.imshow(tensor[0, ..., 0])\n",
"plt.colorbar(image)"
],
"execution_count": 87,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"<matplotlib.colorbar.Colorbar at 0x7f6b2e792630>"
]
},
"metadata": {
"tags": []
},
"execution_count": 87
},
{
"output_type": "display_data",
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAScAAAD8CAYAAAA11GIZAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAQvklEQVR4nO3dbYxc1X3H8e+PxbDhwUHUVCXYiZHqRrVoC5FlU/lFKA+NIZH9olWFI9KmQvWbUJGGJgK1Iil9lValUSUr7TZYpEkKpSSqVtSpSxIjlAqIl4da2A6t5bbBBMk1T4FGYHv31xd3jIbV7sxddu7Omb2/j3Slebhz5m8Lfj733DPnyDYREaU5Y9gFRETMJeEUEUVKOEVEkRJOEVGkhFNEFCnhFBFFSjhFxKJJ2iXpmKRn53lfkv5K0mFJ+yV9qF+bCaeIGIR7gS093r8eWNc5dgBf7tdgwikiFs32o8DLPU7ZBvydK48DF0i6uFebZw6ywNPO0tke59wmmo4I4E3+jxN+S4tp4yO/dq5fenm61rlP7n/rAPBm10sTticW8HWXAM93PT/aee3F+T7QSDiNcy6bdE0TTUcE8IS/u+g2Xnp5mh/seX+tc8cu/s83bW9Y9JcuQCPhFBHlMzDDzFJ93QvAmq7nqzuvzStjThEtZcxJT9c6BmAS+O3OXbsrgddsz3tJB+k5RbTaoHpOku4DrgJWSToKfB5YAWD7r4HdwA3AYeCnwO/2azPhFNFSxkwPaMkk29v7vG/gUwtpM+EU0WIzlLueW8IpoqUMTCecIqJE6TlFRHEMnCx4me6EU0RLGeeyLiIKZJguN5sSThFtVc0QL1fCKaK1xDSL+u1woxJOES1VDYgnnCKiMNU8p4RTRBRoJj2niChNek4RUSQjpgteNalWZZK2SHqus3PC7U0XFRFLY8aqdQxD356TpDFgJ3Ad1bq/+yRN2j7YdHER0RwjTnhs2GXMq07PaSNw2PYR2yeA+6l2UoiIEVZNwjyj1jEMdcac5to1YdPskyTtoNqPinHOGUhxEdGsVgyId7aJmQBYqQsL/sVORADYYtrlDojXCacF75oQEaNhZsR7TvuAdZIupQqlG4GPN1pVRDSuGhAvdzZR38psn5J0C7AHGAN22T7QeGUR0ajTA+KlqhWbtndTbe0SEcvIdH6+EhGlKX2GeMIposVmRvxuXUQsQ9UPfxNOEVEYI04W/POVhFNES9mM/CTMiFiWNPKTMCNiGTLpOUVEoTIgHhHFMcNbSK6OhFNES1VbQ5UbAeVWFhENy6aaEVEgkxniEVGokntO5cZmRDTKFjM+o9bRT78dmiS9X9JeSU9L2i/phn5tpucU0VLVgPjif75Sc4emPwYesP1lSeuplmBa26vdhFNEaw1sDfG3d2gCkHR6h6bucDKwsvP4vcCP+zWacIpoqWpAvPaY0ypJU13PJzqbmkC9HZq+APyrpN8HzgWu7feFCaeIFlvADPHjtjcs4qu2A/fa/gtJvwp8TdJltmfm+0DCKaKlBjhDvM4OTTcDWwBsPyZpHFgFHJuv0dyti2ixAe34+/YOTZLOotqhaXLWOT8CrgGQ9IvAOPC/vRpNzymipWw4ObP4/sl8OzRJuguYsj0J3Ab8raQ/oBru+qTtnpvvJpwiWqq6rBvMxdNcOzTZvrPr8UFg80LaTDhFtFjJM8QTThEttcCpBEsu4RTRWoO7rGtCwimixbKGeEQUp7pbl62hIqIwWaY3IoqVy7qIKE7u1kVEsXK3LiKKY4tTCaeIKFEu6yKiOKWPOfXt00naJemYpGeXoqCIWDozVq1jGOpccN5LZ5GoiFg+Ts9zKjWc+l7W2X5U0trmS4mIpZZ5ThFRHBtODWCxuaYMLJwk7QB2AIxzzqCajYgGlTwgPrBw6mwTMwGwUhf2XH4zIoYvv62LiGK54HCqM5XgPuAx4IOSjkq6ufmyImIpzKBaxz
DUuVu3fSkKiYilZbdkzCkiRo2YbsPduogYPSWPOSWcIlqq9N/WJZwi2srVuFOpEk4RLZafr0REcZwB8YgoVS7rIqJIuVsXEcWxE04RUahMJYiIImXMKSKKY8RM7tZFRIkK7jjV2uAgIpajzoB4naMfSVskPSfpsKTb5znntyQdlHRA0t/3azM9p4g2G0DXSdIYsBO4DjgK7JM0aftg1znrgDuAzbZfkfSz/dpNzymixQbUc9oIHLZ9xPYJ4H5g26xzfg/YafuV6nt9rF+j6TmNmD0/fmbYJSzIR953+bBLiHkYmJmpPZVglaSprucTnX0DAC4Bnu967yiwadbnfwFA0r8BY8AXbP9Lry9MOEW0lYH685yO296wiG87E1gHXAWsBh6V9Eu2X53vA7msi2gxu97RxwvAmq7nqzuvdTsKTNo+afu/gP+gCqt5JZwi2sw1j972AeskXSrpLOBGYHLWOf9E1WtC0iqqy7wjvRrNZV1Ea9WbJtCP7VOSbgH2UI0n7bJ9QNJdwJTtyc57vy7pIDANfNb2S73aTThFtNmAZmHa3g3snvXanV2PDXymc9SScIpoK4Pr361bcgmniFZLOEVEiQr+cV3CKaLNEk4RUZyFTcJccgmniBbLYnMRUabcrYuIEik9p4goTr2fpgxNwimitZQB8YgoVHpOEVGkmWEXML+EU0RbFT7Pqe96TpLWSNrbtWvCrUtRWEQ0T653DEOdntMp4DbbT0k6H3hS0sPdOytExIgqeMypb8/J9ou2n+o8fh04RLWgeUREYxY05iRpLXAF8MQc7+0AdgCMc84ASouIpi2LSZiSzgO+CXza9k9mv9/ZJmYCYKUuLPiPHBFAZ2+ocgfEa4WTpBVUwfQN299qtqSIWDIFdyP6hpMkAfcAh2zf3XxJEbFUSr6sq7M11GbgE8DVkp7pHDc0XFdELIXBbA3ViL49J9vfp+SFhiPi3Su455QZ4hEtNcwJlnUknCLabNTv1kXE8pSeU0SUKeEUEcXJmFNEFCvhFBElUsGLzdWZhBkRseTSc4pos1zWRURxMiAeEcVKOEVEkRJOEVEakbt1EVGimjuv1BmXkrRF0nOSDku6vcd5vyHJkjb0azPhFNFmA1jPSdIYsBO4HlgPbJe0fo7zzgduZY49COaScIpos8EsNrcROGz7iO0TwP3AtjnO+1Pgi8CbdUrLmNOI+cj7Lh92CbGMLGAqwSpJU13PJzqbmkC1VdzzXe8dBTa943ukDwFrbP+zpM/W+cKEU0Sb1Q+n47b7jhPNRdIZwN3AJxfyuYRTRFt5YHfrXgDWdD1f3XnttPOBy4BHqv1S+DlgUtJW2929sXdIOEW02WDmOe0D1km6lCqUbgQ+/vZX2K8Bq04/l/QI8Ie9ggkyIB7RaoOYSmD7FHALsAc4BDxg+4CkuyRtfbe1pecU0WYDmiFuezewe9Zrd85z7lV12kw4RbTVEPekqyPhFNFSIqsSREShEk4RUaaEU0QUKeEUEcXJSpgRUayEU0SUqOTF5hJOES2Wy7qIKE8mYUZEsRJOEVGakZ8hLmkceBQ4u3P+g7Y/33RhEdE8zZSbTnV6Tm8BV9t+Q9IK4PuSvm378YZri4gmjfqYk20Db3SerugcBf+RIqKuki/rai02J2lM0jPAMeBh27W2domIwg1m95VG1Aon29O2L6daG3ijpMtmnyNph6QpSVMneWvQdUZEAwa1qWYTFrRMr+1Xgb3Aljnem7C9wfaGFZw9qPoiokmj3HOSdJGkCzqP3wNcB/yw6cIiomGd3VfqHMNQ527dxcBXO1sOn0G1ePlDzZYVEU0b+XlOtvcDVyxBLRGx1FxuOmWGeESLjXTPKSKWqVGfhBkRy1fWc4qIIiWcIqI8JgPiEVGmDIhHRJkSThFRmpGfhBkRy5Q98ovNRcRyVW42JZwi2iyXdRFRHgO5rIuIIpWbTQtbbC4ilpdBrYQpaYuk5yQdlnT7HO9/RtJBSfslfVfSB/q1mX
CKaDHNuNbRs41qrbedwPXAemC7pPWzTnsa2GD7l4EHgT/rV1vCKaKt6i7R27/ntBE4bPuI7RPA/cC2d3yVvdf2TztPH6faj6CnjDlFtFQ1CbP2oNMqSVNdzydsT3QeXwI83/XeUWBTj7ZuBr7d7wsTThFtVn9VguO2Nyz26yTdBGwAPtzv3IRTRIstoOfUywvAmq7nqzuvvfO7pGuBPwI+bLvv/nEZc4poq8GNOe0D1km6VNJZwI3AZPcJkq4A/gbYavtYnfLSc4porcH8ts72KUm3AHuAMWCX7QOS7gKmbE8Cfw6cB/yjJIAf2d7aq92EU0SbDWixOdu7gd2zXruz6/G1C20z4RTRVs4yvRFRqizTGxFFKjebEk4RbaaZcq/rEk4RbWUWMglzySWcIlpKeFCTMBuRcIpos4RTRBQp4RQRxcmYU0SUKnfrIqJAzmVdRBTIJJwiolDlXtXVX89J0pikpyU91GRBEbF0ZNc6hmEhPadbgUPAyoZqiYilVvBlXa2ek6TVwEeBrzRbTkQsGRumZ+odQ1C35/Ql4HPA+fOdIGkHsANgnHMWX1lENG+Ue06SPgYcs/1kr/NsT9jeYHvDCs4eWIER0SC73jEEdXpOm4Gtkm4AxoGVkr5u+6ZmS4uIRhkYwBriTenbc7J9h+3VttdS7arwvQRTxHJg8Ey9YwgyzymirczQBrvrWFA42X4EeKSRSiJi6RU8IJ6eU0SbJZwiojz54W9ElMhAlkyJiCKl5xQR5fHyuVsXEcuIwUOaw1RHwimizQqeIZ5wimizjDlFRHHs3K2LiEKl5xQR5TGenh52EfNKOEW0VeFLpiScItqs4KkEtXdfiYjlxYBnXOvoR9IWSc9JOizp9jneP1vSP3Tef0LS2n5tJpwi2sqDWWxO0hiwE7geWA9sl7R+1mk3A6/Y/nngL4Ev9isv4RTRYp6ernX0sRE4bPuI7RPA/cC2WedsA77aefwgcI0k9Wq0kTGn13nl+Hf84P8MuNlVwPEBt9mkUap3lGqF0aq3qVo/sNgGXueVPd/xg6tqnj4uaarr+YTtic7jS4Dnu947Cmya9fm3z7F9StJrwM/Q4++mkXCyfdGg25Q0ZXvDoNttyijVO0q1wmjVW3KttrcMu4ZeclkXEYv1ArCm6/nqzmtzniPpTOC9wEu9Gk04RcRi7QPWSbpU0llUuzRNzjpnEvidzuPfpNrFqedtwFGa5zTR/5SijFK9o1QrjFa9o1Tru9IZQ7oF2AOMAbtsH5B0FzBlexK4B/iapMPAy1QB1pP6hFdExFDksi4iipRwiogijUQ49ZsaXxJJuyQdk/TssGvpR9IaSXslHZR0QNKtw65pPpLGJf1A0r93av2TYddUh6QxSU9LemjYtYya4sOp5tT4ktwLFD1/pMsp4Dbb64ErgU8V/Hf7FnC17V8BLge2SLpyyDXVcStwaNhFjKLiw4l6U+OLYftRqrsRxbP9ou2nOo9fp/qf6JLhVjU3V97oPF3ROYq+myNpNfBR4CvDrmUUjUI4zTU1vsj/gUZZ51fiVwBPDLeS+XUukZ4BjgEP2y621o4vAZ8Dyl2XpGCjEE7RMEnnAd8EPm37J8OuZz62p21fTjUDeaOky4Zd03wkfQw4ZvvJYdcyqkYhnOpMjY93SdIKqmD6hu1vDbueOmy/Cuyl7LG9zcBWSf9NNRRxtaSvD7ek0TIK4VRnany8C50lK+4BDtm+e9j19CLpIkkXdB6/B7gO+OFwq5qf7Ttsr7a9luq/2e/ZvmnIZY2U4sPJ9ing9NT4Q8ADtg8Mt6r5SboPeAz4oKSjkm4edk09bAY+QfWv+jOd44ZhFzWPi4G9kvZT/YP1sO3cnl/G8vOViChS8T2niGinhFNEFCnhFBFFSjhFRJESThFRpIRTRBQp4RQRRfp/9slH0Kzqm7cAAAAASUVORK5CYII=\n",
"text/plain": [
"<Figure size 432x288 with 2 Axes>"
]
},
"metadata": {
"tags": [],
"needs_background": "light"
}
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 323
},
"id": "foNGhpRikpXN",
"outputId": "50c8769b-f836-45d5-d0e6-45b75e85d54c"
},
"source": [
"pmap = sampler.get_probability_map(subject)\n",
"print(pmap.shape)\n",
"image = plt.imshow(pmap[0, ..., 0])\n",
"plt.colorbar(image)\n",
"print(pmap.min())\n",
"print(pmap.max())\n",
"# Total probability mass on background vs. foreground voxels; with equal\n",
"# label_probabilities these should match. Compare with a tolerance because\n",
"# float32 sums over different numbers of voxels need not be bit-identical.\n",
"p0, p1 = pmap[tensor == 0].sum(), pmap[tensor == 1].sum()\n",
"assert torch.isclose(p0, p1), (p0, p1)"
],
"execution_count": 88,
"outputs": [
{
"output_type": "stream",
"text": [
"torch.Size([1, 5, 5, 1])\n",
"tensor(0.0208)\n",
"tensor(0.5000)\n"
],
"name": "stdout"
},
{
"output_type": "display_data",
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAScAAAD8CAYAAAA11GIZAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAPp0lEQVR4nO3df8idZ33H8fenWdrY2iJSYTXJbP8oQnBbu2VRKOjWKaYq6R9u0IqygiMbWFanQ1o2Otb9NQfiP2EYtCj4o3PqH5mLhqItImhNWrPONHaG4tZ0QqydUzeaNs/z3R/nVE+z58m5nz3nfs51nvv9Kjec+5zT63xbmk+v+7qv+7pSVUhSay6adwGStBLDSVKTDCdJTTKcJDXJcJLUJMNJUpMMJ0nrlmRvkseTnEpy5wqf35bkh0mOj48/nNbmL/VTqqShSLIFOAC8CTgNHE1yqKoeO++rf19Vt3dt156TpPXaA5yqqieq6jngPuDm9TbaS8/p4lxS27isj6YlAc/y3zxXZ7OeNt78O5fVj55Z6vTdhx89ewJ4duKtg1V1cPx6O/DkxGengdeu0Mzbk7we+FfgT6vqyRW+83O9hNM2LuO1+d0+mpYEPFRfWXcbP3pmiW8d+ZVO391y1feerard6/i5fwQ+U1Vnk/wR8Angxgv9DV7WSQNVwHLHv6Z4Ctg5cb5j/N4vfqvqR1V1dnz6UeA3pzXqgLg0UEXxfHW7rJviKHBtkmsYhdItwDsmv5Dkqqr6wfh0H3ByWqOGkzRgHXpFU1XVuSS3A0eALcC9VXUiyT3Asao6BPxJkn3AOeAZ4LZp7RpO0kAVxdKMlkyqqsPA4fPeu3vi9V3AXWtp03CSBmyZdtdzM5ykgSpgyXCS1CJ7TpKaU8DzDS/TbThJA1WUl3WSGlSw1G42GU7SUI1miLfLcJIGKyyxrmeHe2U4SQM1GhA3nCQ1ZjTPyXCS1KBle06SWmPPSVKTirDU8JJunSqbtrOCpMW0XOl0zMPUntMadlaQtECK8FxtmXcZq+rSc+plZwVJ8zWahHlRp2Meuow5ddpZIcl+YD/ANi6dSXGS+jWIAfHxNjEHAa7Iyxt+YkcSQFVYqnYHxLuE09SdFSQtpuUF7zlN3VlB0uIZDYi3O5toamWr7azQe2WSevXCgHirOsXmSjsrSFp8Sz6+Iqk1rc8QN5ykAVte8Lt1kjah0YO/hpOkxhTh+YYfXzGcpIGqYuEnYUralLLwkzAlbUKFPSdJjXJAXFJzivktJNeF4SQN1GhrqHYjoN3KJPXMTTUlNahwhrikRtlzktScqthzktSe0YB4u4+vtBubkno2WkO8yzG1pY57WyZ5e5JKsntam/acpIEaDYivf8yp696WSS4H7gAe6tKuPSdpwJa4qNMxRde9Lf8a+Bvg2S61GU7SQL0wQ7zjduRXJjk2ceyfaGqlvS23T/5Wkt8AdlbVP3Wtz8s6acDWsMHB01U1dZxoJUkuAj4E3LaWv89wkgaqCp5fnsnF07S9LS8HXgM8mATgl4FDSfZV1bHVGjWcpIEaXdbNJJwuuLdlVf0XcOUL50keBP7sQsEEhpM0aLOYIb7a3pZJ7gGOVdWh/0+7hpM0ULOaSgAr721ZVXev8t3f7tKm4SQNlo+vSGqUa4hLas7obl27z9YZTtJAuUyvpGZ5WSepObO8W9cHw0kaMO/WSWpOVThnOElqkZd1kprT+pjT1D5dknuTnEnynY0oSNLGWcN6ThuuywXnx4G9PdchaYOtcbG5DTf1sq6qvpbk6v5LkbTRnOckqTlVcG42i831YmbhNF5TeD/ANi6dVbOSetTygPjMwqmqDgIHAa7Iy2tW7Urqh8/WSWpWNRxOXaYSfAb4BvDqJKeTvLv/siRthGXS6ZiHLnfrbt2IQiRtrKqBjDlJWjRhaQh36yQtnpbHnAwnaaBaf7bOcJKGqkbjTq0ynKQB8/EVSc
0pB8QltcrLOklN8m6dpOZUGU6SGuVUAklNcsxJUnOKsOzdOkktarjjZDhJg+WAuKRmNdx1MpykAbPnpJk58h/H513Cmrz5ldfNuwStooDl5XbDqd2hekn9KqDS7Zgiyd4kjyc5leTOFT7/4yT/kuR4kq8n2TWtTcNJGrCqbseFJNkCHABuAnYBt64QPp+uql+tquuADwIfmlab4SQNWXU8LmwPcKqqnqiq54D7gJtf9DNVP5k4vaxLq445SYOVtQyIX5nk2MT5wfFelQDbgScnPjsNvPb//FryHuB9wMXAjdN+0HCShqz7VIKnq2r3un6q6gBwIMk7gL8A/uBC3zecpKEqqNncrXsK2DlxvmP83mruA/5uWqOOOUmDlo7HBR0Frk1yTZKLgVuAQy/6leTaidO3At+b1qg9J2nIZjBDvKrOJbkdOAJsAe6tqhNJ7gGOVdUh4PYkbwSeB/6TKZd0YDhJwzajx1eq6jBw+Lz37p54fcda2zScpKF6YRJmowwnacBcbE5Smxp+ts5wkgYs9pwkNafboylzYzhJg9VtxYF5MZykIbPnJKlJy/MuYHWGkzRUjc9zmvpsXZKdSR5I8liSE0nWPNNTUptS3Y556NJzOge8v6oeSXI58HCS+6vqsZ5rk9S3hsecpvacquoHVfXI+PVPgZOMFpeSpN6sacwpydXA9cBDK3y2H9gPsI1LZ1CapL5tikmYSV4KfB5473nrAQMwXrLzIMAVeXnD/8iSgPHeUO0OiHcKpyRbGQXTp6rqC/2WJGnDNNyNmBpOSQJ8DDhZVVO3c5G0OFq+rOuyTO8NwLuAG8cb4h1P8pae65K0EWazNVQvpvacqurrdFhEWNICarjn5AxxaaDmOcGyC8NJGrJFv1snaXOy5ySpTYaTpOY45iSpWYaTpBal4cXmukzClKQNZ89JGjIv6yQ1xwFxSc0ynCQ1yXCS1JrQ9t06w0kaKsecJDXLcJLUJMNJs/LmV1437xK0iXhZJ6lNDYeTj69IQ1Wju3VdjmmS7E3yeJJTSe5c4fP3JXksyaNJvpLkVdPaNJykIZvBBgdJtgAHgJuAXcCtSXad97VvA7ur6teAzwEfnFaa4SQN2AvriE87ptgDnKqqJ6rqOeA+4ObJL1TVA1X1P+PTbwI7pjVqOElD1r3ndGWSYxPH/olWtgNPTpyfHr+3mncDX5pWmgPi0lCtbU+6p6tq93p/Msk7gd3AG6Z913CSBirMbCrBU8DOifMd4/de/HvJG4E/B95QVWenNeplnTRgMxpzOgpcm+SaJBcDtwCHXvQ7yfXAR4B9VXWmS22GkzRkM7hbV1XngNuBI8BJ4LNVdSLJPUn2jb/2t8BLgX9IcjzJoVWa+zkv66Qhm9EkzKo6DBw+7727J16/ca1tGk7SULkqgaRmGU6SWuRic5Ka5GWdpPasbRLmhjOcpCEznCS1ZoYzxHsxNZySbAO+Blwy/v7nquov+y5MUv+y3G46dek5nQVurKqfJdkKfD3Jl6rqmz3XJqlPiz7mVFUF/Gx8unV8NPyPJKmrli/rOj1bl2RLkuPAGeD+qnqo37IkbYgZPFvXl07hVFVLVXUdo6UQ9iR5zfnfSbL/hYWonmfqagiSGjCjVQl6saZVCarqx8ADwN4VPjtYVburavdWLplVfZL6tMg9pySvSPKy8euXAG8Cvtt3YZJ6NsPdV/rQ5W7dVcAnxjssXMRorZYv9luWpL4t/DynqnoUuH4DapG00arddHKGuDRgC91zkrRJLfokTEmbl+s5SWqS4SSpPYUD4pLa5IC4pDYZTpJas/CTMCVtUlULv9icpM2q3WwynKQh87JOUnsK8LJOUpPazSbDSRoyL+skNcm7dZLa46oEklo0moTZbjoZTtKQuSqBpBbZc5LUHsecJLWp7Wfr1rSppqRNpqrbMUWSvUkeT3IqyZ0rfP76JI8kOZfk97qUZjhJQzWjTTXHe1oeAG4CdgG3Jtl13tf+HbgN+HTX8r
ysk4ZsNgPie4BTVfUEQJL7gJuBx37xM/X98Wed7w/ac5KGrDoecGWSYxPH/olWtgNPTpyfHr+3LvacpAHLcueOzNNVtbvPWs5nOElDVcxqEuZTwM6J8x3j99bFyzppoEKR6nZMcRS4Nsk1SS4GbgEOrbc+w0kashlMJaiqc8DtwBHgJPDZqjqR5J4k+wCS/FaS08DvAx9JcmJaaV7WSUM2o8dXquowcPi89+6eeH2U0eVeZ4aTNFSzG3PqheEkDdga7tZtOMNJGqxuj6bMi+EkDVVhOElqVLtXdd2nEiTZkuTbSb7YZ0GSNs6M5jn1Yi09pzsYzWG4oqdaJG20hi/rOvWckuwA3gp8tN9yJG2YKlha7nbMQdee04eBDwCXr/aF8VPK+wG2cen6K5PUv0XuOSV5G3Cmqh6+0Peq6mBV7a6q3Vu5ZGYFSurRjFbC7EOXntMNwL4kbwG2AVck+WRVvbPf0iT1qoBFXkO8qu6qqh1VdTWjp42/ajBJm0FBLXc75sB5TtJQFXMb7O5iTeFUVQ8CD/ZSiaSN1/CAuD0nacgMJ0nt8cFfSS0qwCVTJDXJnpOk9tTmuVsnaRMpqDnNYerCcJKGrOEZ4oaTNGSOOUlqTpV36yQ1yp6TpPYUtbQ07yJWZThJQ9X4kimGkzRkTiWQ1JoCyp6TpOZU2XOS1KaWB8RTPdxKTPJD4N9m3OyVwNMzbrNPi1TvItUKi1VvX7W+qqpesZ4GknyZUX1dPF1Ve9fze2vVSzj1Icmxqto97zq6WqR6F6lWWKx6F6nW1nTejlySNpLhJKlJixROB+ddwBotUr2LVCssVr2LVGtTFmbMSdKwLFLPSdKAGE6SmrQQ4ZRkb5LHk5xKcue867mQJPcmOZPkO/OuZZokO5M8kOSxJCeS3DHvmlaTZFuSbyX553GtfzXvmrpIsiXJt5N8cd61LJrmwynJFuAAcBOwC7g1ya75VnVBHwc2dLLaOpwD3l9Vu4DXAe9p+N/tWeDGqvp14Dpgb5LXzbmmLu4ATs67iEXUfDgBe4BTVfVEVT0H3AfcPOeaVlVVXwOemXcdXVTVD6rqkfHrnzL6Q7R9vlWtrEZ+Nj7dOj6avpuTZAfwVuCj865lES1COG0Hnpw4P02jf4AWWZKrgeuBh+ZbyerGl0jHgTPA/VXVbK1jHwY+ALT7dG3DFiGc1LMkLwU+D7y3qn4y73pWU1VLVXUdsAPYk+Q1865pNUneBpypqofnXcuiWoRwegrYOXG+Y/yeZiDJVkbB9Kmq+sK86+miqn4MPEDbY3s3APuSfJ/RUMSNST4535IWyyKE01Hg2iTXJLkYuAU4NOeaNoUkAT4GnKyqD827ngtJ8ookLxu/fgnwJuC7861qdVV1V1XtqKqrGf03+9Wqeuecy1oozYdTVZ0DbgeOMBqw/WxVnZhvVatL8hngG8Crk5xO8u5513QBNwDvYvR/9ePj4y3zLmoVVwEPJHmU0f+w7q8qb89vYj6+IqlJzfecJA2T4SSpSYaTpCYZTpKaZDhJapLhJKlJhpOkJv0vOse+S26N7BcAAAAASUVORK5CYII=\n",
"text/plain": [
"<Figure size 432x288 with 2 Axes>"
]
},
"metadata": {
"tags": [],
"needs_background": "light"
}
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 496
},
"id": "T91VpRClktEP",
"outputId": "c6a43c3c-9ec3-4feb-9e25-e33a7bff4c30"
},
"source": [
"# Work on a copy of the (single) probability map: clear_probability_borders\n",
"# appears to modify its argument in place (the minimum drops to 0 afterwards)\n",
"noborder = pmap[0].clone()\n",
"sampler.clear_probability_borders(noborder, torch.as_tensor(patch_size))\n",
"image = plt.imshow(noborder[..., 0])\n",
"plt.colorbar(image)\n",
"print(noborder.min())\n",
"print(noborder.max())\n",
"# Background and foreground should still carry equal mass after clearing the\n",
"# borders; in the recorded run this fails with (0.3125, 0.5), which is the issue shown here.\n",
"p0, p1 = noborder[tensor[0] == 0].sum(), noborder[tensor[0] == 1].sum()\n",
"assert torch.isclose(p0, p1), (p0, p1)"
],
"execution_count": 89,
"outputs": [
{
"output_type": "stream",
"text": [
"tensor(0.)\n",
"tensor(0.5000)\n"
],
"name": "stdout"
},
{
"output_type": "error",
"ename": "AssertionError",
"evalue": "ignored",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mAssertionError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-89-7d8d55210a52>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 6\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnoborder\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmax\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 7\u001b[0m \u001b[0mp0\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mp1\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnoborder\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mtensor\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msum\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnoborder\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mtensor\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msum\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 8\u001b[0;31m \u001b[0;32massert\u001b[0m \u001b[0mp0\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0mp1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mp0\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mp1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
"\u001b[0;31mAssertionError\u001b[0m: (tensor(0.3125), tensor(0.5000))"
]
},
{
"output_type": "display_data",
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAScAAAD8CAYAAAA11GIZAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAQZ0lEQVR4nO3db4xc1X3G8e/DxsbhjxtRIxVhNyDVSmTRFlrXjoREWgdUQyK7UlIJIqIgUbmVYpU0aSNQK6q6r5JKNG/8IiuwgpoEJyV5saWOLJIYIaTE8QIujTE0FkpjIyTXQAi0wvbuPn0xAx22u3vvsnP3ntn7fNCV5s5cn/mB4OHcc8+cI9tERJTmgrYLiIiYS8IpIoqUcIqIIiWcIqJICaeIKFLCKSKKlHCKiCWTtF3S85JOSLp7js/vkPRfko72jz+pavM9zZQaEV0haQzYC9wEnAKOSJqw/eysS79pe3fddtNzioil2gKcsP2C7XPAfmDnUhttpOe0Whd6DRc30XREAG/y35zzWS2ljT/8g4v98ivTta598pmzx4A3B94atz3ef30lcHLgs1PA1jma+bikG4D/AP7C9sk5rnlbI+G0hovZqo800XREAIf9/SW38fIr0/z44K/Xunbsip++aXvzEr7uX4CHbJ+V9KfAg8C2hf5AbusiOsrATM2/KrwIbBg4X99/7/++y37Z9tn+6f3A71Y1mgHxiI4y5rzr3dZVOAJslHQ1vVC6Ffjk4AWSrrD9Uv90B3C8qtGEU0SH1egVVbI9JWk3cBAYA/bZPiZpDzBpewL4c0k7gCngFeCOqnYTThEdZcz0kJZMsn0AODDrvXsHXt8D3LOYNhNOER02Q7nruSWcIjrKwHTCKSJKlJ5TRBTHwPmCl+lOOEV0lHFu6yKiQIbpcrMp4RTRVb0Z4uVKOEV0lphmSb8dblTCKaKjegPiCaeIKExvnlPCKSIKNJOeU0SUJj2niCiSEdMFL+lWq7KqnRUiYjTNWLWONlT2nBaxs0JEjBAjznms7TLmVafn1MjOChHRrt4kzAtqHW2oM+ZUa2cFSbuAXQBruGgoxUVEszoxIN7fJmYcYK0uK/gXOxEBYItplzsgXiecKndWiIjRNDPiPafKnRUiYvT0BsTLnU1UWdl8Oys0XllENOqtAfFS1YrNuXZWiIjRN52fr0REaUqfIZ5wiuiwmRF/WhcRK1Dvh78Jp4gojBHnC/75SsIpoqNsRn4SZkSsSBr5SZgRsQKZ9JwiolAZEI+I4pj2FpKrI+EU0VG9raHKjYByK4uIhmVTzYgokMkM8YgoVMk9p3JjMyIaZYsZX1DrqFJ3hyZJH5dkSZur2kzPKaKjegPiS//5St0dmiRdCtwFHK7TbnpOEZ3VW0O8zlGh7g5Nfw98EXizTnWd7zmNrV3bdgkRi6Y3lt6v6A2I1x5zWidpcuB8vL+pCdTYoUnS7wAbbP+rpL+q84WdD6eILlvEDPEztivHieYi6QLgPuCOxfy5hFNERw1xhnjVDk2XAtcAj0kC+DVgQtIO24O9sXdIOEV02JA2OFhwhybbrwHr3jqX9BjwlwsFEyScIjrLhvMzQxi7mmeHJkl7gEnbE++m3YRTREf1buuG88B+rh2abN87z7W/X6fNhFNEh5U8QzzhFNFRi5xKsOwSThGdNbzbuiYknCI6LGuIR0Rxek/rsjVURBQmy/RGRLFyWxcRxcnTuogoVp7WRURxbDGVcIqIEuW2LiKKU/qYU2WfTtI+Sacl/WQ5CoqI5TNj1TraUOeG86vA9obriIhl9tY8p1LDqfK2zvbjkq5qvpSIWG6Z5xQRxbFhagiLzTVlaOEkaRewC2ANFw2r2YhoUMkD4kMLp/42MeMAa3WZh9VuRDQjv62LiGK54HCqM5XgIeCHwAcknZJ0Z/NlRcRymEG1jjbUeVp323IUEhHLy+7ImFNEjBox3YWndRExekoec0
o4RXRU6b+tSzhFdJV7406lSjhFdFh+vhIRxXEGxCOiVLmti4gi5WldRBTHTjhFRKEylSAiipQxp4gojhEzeVoXESUquONUa4ODiFiJ+gPidY4qkrZLel7SCUl3z/H5n0n6d0lHJT0haVNVmwmniC5zzWMBksaAvcDNwCbgtjnC5xu2f9P2tcCXgPuqSks4RXTYkHpOW4ATtl+wfQ7YD+x85/f4lwOnF1PjjjJjTiPmwHOPt13CotzywRvaLiHmYWBmpvZUgnWSJgfOx/v7BgBcCZwc+OwUsHV2A5I+A3wOWA1sq/rChFNEVxmoP8/pjO3NS/o6ey+wV9Ingb8BPr3Q9bmti+gwu95R4UVgw8D5+v5789kP/FFVowmniC4bwoA4cATYKOlqSauBW4GJwQskbRw4/Sjw06pGc1sX0Vn1pglUsT0laTdwEBgD9tk+JmkPMGl7Atgt6UbgPPAqFbd0kHCK6LYhzcK0fQA4MOu9ewde37XYNhNOEV1lcP2ndcsu4RTRaQmniChRwT+uSzhFdFnCKSKKs7hJmMsu4RTRYVlsLiLKlKd1EVEipecUEcWp99OU1iScIjpLGRCPiEKl5xQRRZppu4D5JZwiuqrweU6V6zlJ2iDpkKRnJR2TtOhfF0dEmeR6Rxvq9JymgM/bfkrSpcCTkh61/WzDtUVE0woec6rsOdl+yfZT/devA8fpLWgeEdGYRY05SboKuA44PMdnu4BdAGu4aAilRUTTVsQkTEmXAN8GPjtrDyoA+tvEjAOs1WUF/y1HBNDfG6rcAfFa4SRpFb1g+rrt7zRbUkQsm4K7EZXhJEnAA8Bx25VbCEfE6Cj5tq7O1lDXA58Ctkk62j9uabiuiFgOw9kaqhGVPSfbT1DyQsMR8e4V3HPKDPGIjmpzgmUdCaeILhv1p3URsTKl5xQRZUo4RURxMuYUEcVKOEVEiVTwYnN1JmFGRCy79Jwiuiy3dRFRnAyIR0SxEk4RUaSEU0SURuRpXUSUqObOK3XGpSRtl/S8pBOS7p7j88/1d3B6RtL3Jb2/qs2EU0SXDWE9J0ljwF7gZmATcJukTbMuexrYbPu3gIeBL1WVlnCK6LLhLDa3BThh+wXb54D9wM53fI19yPb/9E9/BKyvajRjTiPmlg/e0HYJsYIsYirBOkmTA+fj/U1NoLdV3MmBz04BWxdo607gu1VfmHCK6LL64XTG9ualfp2k24HNwIerrk04RXSVh/a07kVgw8D5+v577yDpRuCvgQ/bPlvVaMacIrpsOGNOR4CNkq6WtBq4FZgYvEDSdcBXgB22T9cpLT2niA4bxs9XbE9J2g0cBMaAfbaPSdoDTNqeAP4BuAT4595uc/zc9o6F2k04RXTZkGaI2z4AHJj13r0Dr29cbJsJp4iuanFPujoSThEdJbIqQUQUKuEUEWVKOEVEkRJOEVGcrIQZEcVKOEVEiUpebC7hFNFhua2LiPJkEmZEFCvhFBGlGfkZ4pLWAI8DF/avf9j23zZdWEQ0TzPlplOdntNZYJvtNyStAp6Q9F3bP2q4toho0qiPOdk28Eb/dFX/KPhvKSLqKvm2rtZKmJLGJB0FTgOP2j7cbFkRsSyGsxJmI2qFk+1p29fSWxt4i6RrZl8jaZekSUmT56lcHjgiCjCsTTWbsKg1xG3/AjgEbJ/js3Hbm21vXsWFw6ovIpo0yj0nSZdLel//9XuBm4Dnmi4sIhrW332lztGGOk/rrgAe7G85fAHwLduPNFtWRDRt5Oc52X4GuG4ZaomI5eZy0ykzxCM6bKR7ThGxQo36JMyIWLmynlNEFCnhFBHlMRkQj4gyZUA8IsqUcIqI0oz8JMyIWKHskV9sLiJWqnKzKeEU0WW5rYuI8hjIbV1EFKncbFrcYnMRsbIMayVMSdslPS/phKS75/j8BklPSZqS9Ik6tSWcIjpMM651LNhGb623vcDNwCbgNkmbZl32c+AO4Bt1a8ttXURXDW9Vgi3ACdsvAEjaD+
wEnn37q+yf9T+r/Wu+hFNER/UmYdZOp3WSJgfOx22P919fCZwc+OwUsHWp9SWcIrqs/qoEZ2xvbrCS/yfhFNFhi+g5LeRFYMPA+fr+e0uSAfGIrqq7LVR1fh0BNkq6WtJq4FZgYqnlJZwiOqvek7qqp3W2p4DdwEHgOL0dmo5J2iNpB4Ck35N0Cvhj4CuSjlVVl9u6iC4b0mJztg8AB2a9d+/A6yP0bvdqSzhFdJWzTG9ElCrL9EZEkcrNpoRTRJdpptz7uoRTRFeZxUzCXHYJp4iOEh7WJMxGJJwiuizhFBFFSjhFRHEy5hQRpcrTuogokHNbFxEFMgmniChUuXd19ZdMkTQm6WlJjzRZUEQsH9m1jjYspud0F721WtY2VEtELLeCb+tq9ZwkrQc+CtzfbDkRsWxsmJ6pd7Sgbs/py8AXgEvnu0DSLmAXwBouWnplEdG8Ue45SfoYcNr2kwtdZ3vc9mbbm1dx4dAKjIgG2fWOFtTpOV0P7JB0C7AGWCvpa7Zvb7a0iGiUgYr1wdtU2XOyfY/t9bavorerwg8STBErgcEz9Y4WZJ5TRFeZ1ga761hUONl+DHiskUoiYvkVPCCenlNElyWcIqI8+eFvRJTIQJZMiYgipecUEeXxynlaFxEriMEtzWGqI+EU0WUFzxBPOEV0WcacIqI4dp7WRUSh0nOKiPIYT0+3XcS8Ek4RXVX4kikJp4guK3gqQe3dVyJiZTHgGdc6qkjaLul5SSck3T3H5xdK+mb/88OSrqpqM+EU0VUezmJzksaAvcDNwCbgNkmbZl12J/Cq7d8A/hH4YlV5CaeIDvP0dK2jwhbghO0XbJ8D9gM7Z12zE3iw//ph4COStFCjjYw5vc6rZ77nh/9zyM2uA84MuU14begtvqWZepsxSrXCaNXbVK3vX2oDr/Pqwe/54XU1L18jaXLgfNz2eP/1lcDJgc9OAVtn/fm3r7E9Jek14FdZ4J9NI+Fk+/Jhtylp0vbmYbfblFGqd5RqhdGqt+RabW9vu4aF5LYuIpbqRWDDwPn6/ntzXiPpPcCvAC8v1GjCKSKW6giwUdLVklbT26VpYtY1E8Cn+68/QW8XpwUfA47SPKfx6kuKMkr1jlKtMFr1jlKt70p/DGk3cBAYA/bZPiZpDzBpewJ4APgnSSeAV+gF2IJUEV4REa3IbV1EFCnhFBFFGolwqpoaXxJJ+ySdlvSTtmupImmDpEOSnpV0TNJdbdc0H0lrJP1Y0r/1a/27tmuqQ9KYpKclPdJ2LaOm+HCqOTW+JF8Fip4/MmAK+LztTcCHgM8U/M/2LLDN9m8D1wLbJX2o5ZrquAs43nYRo6j4cKLe1Phi2H6c3tOI4tl+yfZT/dev0/uP6Mp2q5qbe97on67qH0U/zZG0HvgocH/btYyiUQinuabGF/kf0Cjr/0r8OuBwu5XMr3+LdBQ4DTxqu9ha+74MfAEod12Sgo1COEXDJF0CfBv4rO1ftl3PfGxP276W3gzkLZKuabum+Uj6GHDa9pNt1zKqRiGc6kyNj3dJ0ip6wfR1299pu546bP8COETZY3vXAzsk/YzeUMQ2SV9rt6TRMgrhVGdqfLwL/SUrHgCO276v7XoWIulySe/rv34vcBPwXLtVzc/2PbbX276K3r+zP7B9e8tljZTiw8n2FPDW1PjjwLdsH2u3qvlJegj4IfABSack3dl2TQu4HvgUvf+rH+0ft7Rd1DyuAA5Jeobe/7AetZ3H8ytYfr4SEUUqvucUEd2UcIqIIiWcIqJICaeIKFLCKSKKlHCKiCIlnCKiSP8LWXxR538JAYkAAAAASUVORK5CYII=\n",
"text/plain": [
"<Figure size 432x288 with 2 Axes>"
]
},
"metadata": {
"tags": [],
"needs_background": "light"
}
}
]
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment