Skip to content

Instantly share code, notes, and snippets.

@shoyer
Created July 19, 2018 18:41
Show Gist options
  • Save shoyer/c0f1ddf409667650a076c058f9a17276 to your computer and use it in GitHub Desktop.
Save shoyer/c0f1ddf409667650a076c058f9a17276 to your computer and use it in GitHub Desktop.
resize non-integer local mean.ipynb
Display the source blob
Display the rendered blob
Raw
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "resize non-integer local mean.ipynb",
"version": "0.3.2",
"provenance": [],
"collapsed_sections": [],
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"[View in Colaboratory](https://colab.research.google.com/gist/shoyer/c0f1ddf409667650a076c058f9a17276/resize-non-integer-local-mean.ipynb)"
]
},
{
"metadata": {
"id": "Yl-_CK2J0APU",
"colab_type": "code",
"colab": {}
},
"cell_type": "code",
"source": [
"# Copyright 2018 Google LLC\n",
"#\n",
"# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
"# you may not use this file except in compliance with the License.\n",
"# You may obtain a copy of the License at\n",
"#\n",
"# https://www.apache.org/licenses/LICENSE-2.0\n",
"#\n",
"# Unless required by applicable law or agreed to in writing, software\n",
"# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
"# See the License for the specific language governing permissions and\n",
"# limitations under the License.\n",
"\n",
"import numpy as np\n",
"from typing import Tuple, Iterable\n",
"\n",
"\n",
"def _reflect_breaks(size: int) -> np.ndarray:\n",
"  \"\"\"Calculate cell boundaries with reflecting boundary conditions.\n",
"\n",
"  Interior boundaries sit halfway between consecutive cell centers, while the\n",
"  outer boundaries coincide with the centers of the first and last cells,\n",
"  e.g., size=4 gives [0, 0.5, 1.5, 2.5, 3].\n",
"\n",
"  Args:\n",
"    size: number of cells.\n",
"\n",
"  Returns:\n",
"    Array with shape (size + 1,) of boundaries running from 0 to size - 1.\n",
"  \"\"\"\n",
"  midpoints = np.arange(size - 1) + 0.5\n",
"  breaks = np.concatenate([[0], midpoints, [size - 1]])\n",
"  assert len(breaks) == size + 1\n",
"  return breaks\n",
"\n",
"def _interval_overlap(first_breaks: np.ndarray,\n",
"                      second_breaks: np.ndarray) -> np.ndarray:\n",
"  \"\"\"Return the overlap distance between all pairs of intervals.\n",
"\n",
"  Interval i of the first set is [first_breaks[i], first_breaks[i + 1]] (and\n",
"  likewise for the second set); pairs that do not intersect overlap by zero.\n",
"\n",
"  Args:\n",
"    first_breaks: breaks between entries in the first set of intervals, with\n",
"      shape (N+1,). Must be a non-decreasing sequence.\n",
"    second_breaks: breaks between entries in the second set of intervals, with\n",
"      shape (M+1,). Must be a non-decreasing sequence.\n",
"\n",
"  Returns:\n",
"    Array with shape (N, M) giving the size of the overlapping region between\n",
"    each pair of intervals.\n",
"  \"\"\"\n",
"  # Lay the first set along rows (N, 1) and the second along columns (1, M) so\n",
"  # every pairwise comparison happens in a single broadcast.\n",
"  a_lo = first_breaks[:-1, np.newaxis]\n",
"  a_hi = first_breaks[1:, np.newaxis]\n",
"  b_lo = second_breaks[np.newaxis, :-1]\n",
"  b_hi = second_breaks[np.newaxis, 1:]\n",
"  extent = np.minimum(a_hi, b_hi) - np.maximum(a_lo, b_lo)\n",
"  # Disjoint pairs yield a negative extent; clamp them to zero overlap.\n",
"  return np.maximum(extent, 0)\n",
"\n",
"def _resize_weights(\n",
"    old_size: int, new_size: int, reflect: bool = False) -> np.ndarray:\n",
"  \"\"\"Create a weight matrix for resizing with the local mean along an axis.\n",
"\n",
"  Input and output cells tile the same domain; each output cell's weights are\n",
"  the lengths of its overlap with every input cell, normalized to sum to 1.\n",
"\n",
"  Args:\n",
"    old_size: old size.\n",
"    new_size: new size.\n",
"    reflect: whether or not there are reflecting boundary conditions. If so,\n",
"      the domain runs between the centers of the first and last cells,\n",
"      [0, old_size - 1], rather than between their outer edges, [0, old_size].\n",
"\n",
"  Returns:\n",
"    NumPy array with shape (new_size, old_size). Rows sum to 1.\n",
"  \"\"\"\n",
"  if not reflect:\n",
"    old_breaks = np.linspace(0, old_size, num=old_size + 1)\n",
"    new_breaks = np.linspace(0, old_size, num=new_size + 1)\n",
"  elif old_size == 1:\n",
"    # A lone input cell has zero width under reflection, so overlap weights\n",
"    # would be 0 / 0 = NaN; every output cell simply copies it instead.\n",
"    return np.ones((new_size, 1))\n",
"  else:\n",
"    old_breaks = _reflect_breaks(old_size)\n",
"    if new_size == 1:\n",
"      # The general scaling below divides by new_size - 1 == 0; a single\n",
"      # output cell covers the whole reflected domain instead.\n",
"      new_breaks = np.array([0.0, old_size - 1.0])\n",
"    else:\n",
"      scale = (old_size - 1) / (new_size - 1)\n",
"      new_breaks = scale * _reflect_breaks(new_size)\n",
"\n",
"  weights = _interval_overlap(new_breaks, old_breaks)\n",
"  weights /= np.sum(weights, axis=1, keepdims=True)\n",
"  assert weights.shape == (new_size, old_size)\n",
"  return weights\n",
"\n",
"def resize(array: np.ndarray,\n",
"           shape: Tuple[int, ...],\n",
"           reflect_axes: Iterable[int] = ()) -> np.ndarray:\n",
"  \"\"\"Resize an array with the local mean / bilinear scaling.\n",
"\n",
"  Works for both upsampling and downsampling in a fashion equivalent to\n",
"  block_mean and zoom, but allows for resizing by non-integer multiples. Prefer\n",
"  block_mean and zoom when possible, as this implementation is probably slower.\n",
"\n",
"  Args:\n",
"    array: array to resize.\n",
"    shape: shape of the resized array, with one entry per axis of array.\n",
"    reflect_axes: iterable of axis numbers with reflecting boundary conditions,\n",
"      mirrored over the center of the first and last cell.\n",
"\n",
"  Returns:\n",
"    Array resized to shape.\n",
"\n",
"  Raises:\n",
"    ValueError: if any values in reflect_axes fall outside the interval\n",
"      [-array.ndim, array.ndim), or if shape does not have exactly array.ndim\n",
"      entries.\n",
"  \"\"\"\n",
"  reflect_axes_set = set()\n",
"  for axis in reflect_axes:\n",
"    if not -array.ndim <= axis < array.ndim:\n",
"      raise ValueError('invalid axis: {}'.format(axis))\n",
"    reflect_axes_set.add(axis % array.ndim)\n",
"\n",
"  shape = tuple(shape)\n",
"  if len(shape) != array.ndim:\n",
"    # Without this check, zip() below would silently ignore unmatched axes.\n",
"    raise ValueError('shape {} does not match array.ndim={}'.format(\n",
"        shape, array.ndim))\n",
"\n",
"  output = array\n",
"  for axis, (old_size, new_size) in enumerate(zip(array.shape, shape)):\n",
"    reflect = axis in reflect_axes_set\n",
"    weights = _resize_weights(old_size, new_size, reflect=reflect)\n",
"    # weights has shape (new_size, old_size): contract its last dimension with\n",
"    # this axis, then move the new dimension (appended last) back into place.\n",
"    product = np.tensordot(output, weights, [[axis], [-1]])\n",
"    output = np.moveaxis(product, -1, axis)\n",
"  return output"
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "WCOBQufI0g_U",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 125
},
"outputId": "fd76e2ca-044c-43a5-bec5-15f24ed962cf"
},
"cell_type": "code",
"source": [
"import unittest\n",
"import skimage.measure\n",
"\n",
"class ResizeLocalMeanTest(unittest.TestCase):\n",
"  \"\"\"Unit tests for the local-mean resize helpers and resize().\"\"\"\n",
"\n",
"  def test_reflect_breaks(self):\n",
"    expected = np.array([0, 0.5, 1.5, 2.5, 3])\n",
"    actual = _reflect_breaks(4)\n",
"    np.testing.assert_array_equal(expected, actual)\n",
"\n",
"  def test_interval_overlap(self):\n",
"    actual = _interval_overlap(\n",
"        np.array([0, 3, 7, 10]), np.array([0, 1, 4, 6, 9, 10]))\n",
"    expected = np.array([[1, 2, 0, 0, 0], [0, 1, 2, 1, 0], [0, 0, 0, 2, 1]])\n",
"    np.testing.assert_array_equal(expected, actual)\n",
"\n",
"  def test_resize_weights_downscale(self):\n",
"    expected = np.array([[0.4, 0.4, 0.2, 0, 0], [0, 0, 0.2, 0.4, 0.4]])\n",
"    actual = _resize_weights(5, 2)\n",
"    np.testing.assert_allclose(expected, actual)\n",
"\n",
"  def test_resize_weights_upscale(self):\n",
"    expected = np.array([[1, 0], [1, 0], [0.5, 0.5], [0, 1], [0, 1]])\n",
"    actual = _resize_weights(2, 5)\n",
"    np.testing.assert_allclose(expected, actual)\n",
"\n",
"  def test_resize_weights_downscale_reflect(self):\n",
"    expected = np.array([[0.5, 0.5, 0, 0, 0],\n",
"                         [0, 0.25, 0.5, 0.25, 0],\n",
"                         [0, 0, 0, 0.5, 0.5]])\n",
"    actual = _resize_weights(5, 3, reflect=True)\n",
"    np.testing.assert_allclose(expected, actual)\n",
"\n",
"  def test_resize_weights_rows_sum_to_one(self):\n",
"    # Non-integer ratios, in both directions and with both boundary modes.\n",
"    for old_size, new_size in [(7, 3), (3, 7), (10, 4), (4, 10)]:\n",
"      for reflect in [False, True]:\n",
"        weights = _resize_weights(old_size, new_size, reflect=reflect)\n",
"        self.assertEqual(weights.shape, (new_size, old_size))\n",
"        np.testing.assert_allclose(weights.sum(axis=1), 1.0)\n",
"\n",
"  def test_resize_identity(self):\n",
"    x = np.random.RandomState(0).randn(10)\n",
"    actual = resize(x, (10,))\n",
"    np.testing.assert_allclose(x, actual)\n",
"\n",
"  def test_resize_identity_reflect(self):\n",
"    x = np.random.RandomState(0).randn(10)\n",
"    actual = resize(x, (10,), reflect_axes={0})\n",
"    np.testing.assert_allclose(x, actual)\n",
"\n",
"  def test_resize_negative_reflect_axis(self):\n",
"    x = np.random.RandomState(0).randn(4, 6)\n",
"    expected = resize(x, (3, 5), reflect_axes={1})\n",
"    actual = resize(x, (3, 5), reflect_axes={-1})\n",
"    np.testing.assert_allclose(expected, actual)\n",
"\n",
"  def test_resize_downscale(self):\n",
"    x = np.random.RandomState(0).randn(10, 15)\n",
"    expected = skimage.measure.block_reduce(x, (5, 5), func=np.mean)\n",
"    actual = resize(x, (2, 3))\n",
"    np.testing.assert_allclose(expected, actual)\n",
"\n",
"  def test_resize_upscale(self):\n",
"    x = np.random.RandomState(0).randn(5, 5)\n",
"    expected = np.repeat(np.repeat(x, 2, axis=0), 3, axis=1)\n",
"    actual = resize(x, (10, 15))\n",
"    np.testing.assert_allclose(expected, actual)\n",
"\n",
"  def test_resize_3d(self):\n",
"    x = np.random.RandomState(0).randn(5, 6, 7)\n",
"    new_shape = (1, 2, 3)\n",
"    actual = resize(x, new_shape)\n",
"    self.assertEqual(actual.shape, new_shape)\n",
"\n",
"  def test_resize_reflect_axes_invalid(self):\n",
"    with self.assertRaisesRegex(ValueError, 'invalid axis'):\n",
"      resize(np.zeros(5), (10,), reflect_axes={1})\n",
"\n",
"# Run every test_* method of the case inline; the TextTestResult returned by\n",
"# run() is the cell's displayed value.\n",
"loader = unittest.defaultTestLoader\n",
"suite = loader.loadTestsFromTestCase(ResizeLocalMeanTest)\n",
"unittest.TextTestRunner().run(suite)\n"
],
"execution_count": 17,
"outputs": [
{
"output_type": "stream",
"text": [
"...........\n",
"----------------------------------------------------------------------\n",
"Ran 11 tests in 0.015s\n",
"\n",
"OK\n"
],
"name": "stderr"
},
{
"output_type": "execute_result",
"data": {
"text/plain": [
"<unittest.runner.TextTestResult run=11 errors=0 failures=0>"
]
},
"metadata": {
"tags": []
},
"execution_count": 17
}
]
},
{
"metadata": {
"id": "i-3ksd4G02NF",
"colab_type": "code",
"colab": {}
},
"cell_type": "code",
"source": [
""
],
"execution_count": 0,
"outputs": []
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment