Created
July 19, 2018 18:41
-
-
Save shoyer/c0f1ddf409667650a076c058f9a17276 to your computer and use it in GitHub Desktop.
resize non-integer local mean.ipynb
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"nbformat": 4, | |
"nbformat_minor": 0, | |
"metadata": { | |
"colab": { | |
"name": "resize non-integer local mean.ipynb", | |
"version": "0.3.2", | |
"provenance": [], | |
"collapsed_sections": [], | |
"include_colab_link": true | |
}, | |
"kernelspec": { | |
"name": "python3", | |
"display_name": "Python 3" | |
} | |
}, | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "view-in-github", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"[View in Colaboratory](https://colab.research.google.com/gist/shoyer/c0f1ddf409667650a076c058f9a17276/resize-non-integer-local-mean.ipynb)" | |
] | |
}, | |
{ | |
"metadata": { | |
"id": "Yl-_CK2J0APU", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"cell_type": "code", | |
"source": [ | |
"# Copyright 2018 Google LLC\n", | |
"#\n", | |
"# Licensed under the Apache License, Version 2.0 (the \"License\");\n", | |
"# you may not use this file except in compliance with the License.\n", | |
"# You may obtain a copy of the License at\n", | |
"#\n", | |
"# https://www.apache.org/licenses/LICENSE-2.0\n", | |
"#\n", | |
"# Unless required by applicable law or agreed to in writing, software\n", | |
"# distributed under the License is distributed on an \"AS IS\" BASIS,\n", | |
"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", | |
"# See the License for the specific language governing permissions and\n", | |
"# limitations under the License.\n", | |
"\n", | |
"import numpy as np\n", | |
"from typing import Tuple, Iterable\n", | |
"\n", | |
"\n", | |
"def _reflect_breaks(size: int) -> np.ndarray:\n", | |
" \"\"\"Calculate cell boundaries with reflecting boundary conditions.\"\"\"\n", | |
" result = np.concatenate([[0], 0.5 + np.arange(size - 1), [size - 1]])\n", | |
" assert len(result) == size + 1\n", | |
" return result\n", | |
"\n", | |
"def _interval_overlap(first_breaks: np.ndarray,\n", | |
" second_breaks: np.ndarray) -> np.ndarray:\n", | |
" \"\"\"Return the overlap distance between all pairs of intervals.\n", | |
"\n", | |
" Args:\n", | |
" first_breaks: breaks between entries in the first set of intervals, with\n", | |
" shape (N+1,). Must be a non-decreasing sequence.\n", | |
" second_breaks: breaks between entries in the second set of intervals, with\n", | |
" shape (M+1,). Must be a non-decreasing sequence.\n", | |
"\n", | |
" Returns:\n", | |
" Array with shape (N, M) giving the size of the overlapping region between\n", | |
" each pair of intervals.\n", | |
" \"\"\"\n", | |
" first_upper = first_breaks[1:]\n", | |
" second_upper = second_breaks[1:]\n", | |
" upper = np.minimum(first_upper[:, np.newaxis], second_upper[np.newaxis, :])\n", | |
"\n", | |
" first_lower = first_breaks[:-1]\n", | |
" second_lower = second_breaks[:-1]\n", | |
" lower = np.maximum(first_lower[:, np.newaxis], second_lower[np.newaxis, :])\n", | |
"\n", | |
" return np.maximum(upper - lower, 0)\n", | |
"\n", | |
"def _resize_weights(\n", | |
" old_size: int, new_size: int, reflect: bool = False) -> np.ndarray:\n", | |
" \"\"\"Create a weight matrix for resizing with the local mean along an axis.\n", | |
"\n", | |
" Args:\n", | |
" old_size: old size.\n", | |
" new_size: new size.\n", | |
" reflect: whether or not there are reflecting boundary conditions.\n", | |
"\n", | |
" Returns:\n", | |
" NumPy array with shape (new_size, old_size). Rows sum to 1.\n", | |
" \"\"\"\n", | |
" if not reflect:\n", | |
" old_breaks = np.linspace(0, old_size, num=old_size + 1)\n", | |
" new_breaks = np.linspace(0, old_size, num=new_size + 1)\n", | |
" else:\n", | |
" old_breaks = _reflect_breaks(old_size)\n", | |
" new_breaks = (old_size - 1) / (new_size - 1) * _reflect_breaks(new_size)\n", | |
"\n", | |
" weights = _interval_overlap(new_breaks, old_breaks)\n", | |
" weights /= np.sum(weights, axis=1, keepdims=True)\n", | |
" assert weights.shape == (new_size, old_size)\n", | |
" return weights\n", | |
"\n", | |
def resize(array: np.ndarray,
           shape: Tuple[int, ...],
           reflect_axes: Iterable[int] = ()) -> np.ndarray:
  """Resize an array with the local mean / bilinear scaling.

  Works for both upsampling and downsampling in a fashion equivalent to
  block_mean and zoom, but allows for resizing by non-integer multiples. Prefer
  block_mean and zoom when possible, as this implementation is probably slower.

  Each axis is resized in turn by contracting it against a
  (new_size, old_size) weight matrix from _resize_weights.

  Args:
    array: array (or array-like accepted by np.asarray) to resize.
    shape: shape of the resized array; must have one entry per axis of array.
    reflect_axes: iterable of axis numbers with reflecting boundary conditions,
      mirrored over the center of the first and last cell.

  Returns:
    Array resized to shape.

  Raises:
    ValueError: if any values in reflect_axes fall outside the interval
      [-array.ndim, array.ndim), or if len(shape) != array.ndim.
  """
  array = np.asarray(array)
  shape = tuple(shape)

  reflect_axes_set = set()
  for axis in reflect_axes:
    if not -array.ndim <= axis < array.ndim:
      raise ValueError('invalid axis: {}'.format(axis))
    # Normalize negative axes so the membership test below uses 0..ndim-1.
    reflect_axes_set.add(axis % array.ndim)

  # zip() below would silently truncate a mismatched shape, leaving trailing
  # axes unresized or ignoring extra target sizes.
  if len(shape) != array.ndim:
    raise ValueError(
        'shape {} has {} entries but array has {} dimensions'.format(
            shape, len(shape), array.ndim))

  output = array
  for axis, (old_size, new_size) in enumerate(zip(array.shape, shape)):
    reflect = axis in reflect_axes_set
    weights = _resize_weights(old_size, new_size, reflect=reflect)
    # tensordot contracts this axis with the weights' old_size axis and appends
    # the new_size axis last; move it back to its original position.
    product = np.tensordot(output, weights, [[axis], [-1]])
    output = np.moveaxis(product, -1, axis)
  return output
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"id": "WCOBQufI0g_U", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 125 | |
}, | |
"outputId": "fd76e2ca-044c-43a5-bec5-15f24ed962cf" | |
}, | |
"cell_type": "code", | |
"source": [ | |
"import unittest\n", | |
"import skimage.measure\n", | |
"\n", | |
class ResizeLocalMeanTest(unittest.TestCase):
  """Checks the local-mean resize helpers against hand-computed weights and
  against simple reference implementations (block_reduce, repeat)."""

  @staticmethod
  def _noise(*shape):
    # Fixed seed so every run sees identical data.
    return np.random.RandomState(0).randn(*shape)

  def test_reflect_breaks(self):
    np.testing.assert_array_equal(
        np.array([0, 0.5, 1.5, 2.5, 3]), _reflect_breaks(4))

  def test_interval_overlap(self):
    first = np.array([0, 3, 7, 10])
    second = np.array([0, 1, 4, 6, 9, 10])
    want = np.array([[1, 2, 0, 0, 0],
                     [0, 1, 2, 1, 0],
                     [0, 0, 0, 2, 1]])
    np.testing.assert_array_equal(want, _interval_overlap(first, second))

  def test_resize_weights_downscale(self):
    want = np.array([[0.4, 0.4, 0.2, 0, 0],
                     [0, 0, 0.2, 0.4, 0.4]])
    np.testing.assert_allclose(want, _resize_weights(5, 2))

  def test_resize_weights_upscale(self):
    want = np.array([[1, 0], [1, 0], [0.5, 0.5], [0, 1], [0, 1]])
    np.testing.assert_allclose(want, _resize_weights(2, 5))

  def test_resize_weights_downscale_reflect(self):
    want = np.array([[0.5, 0.5, 0, 0, 0],
                     [0, 0.25, 0.5, 0.25, 0],
                     [0, 0, 0, 0.5, 0.5]])
    np.testing.assert_allclose(want, _resize_weights(5, 3, reflect=True))

  def test_resize_identity(self):
    data = self._noise(10)
    np.testing.assert_allclose(data, resize(data, (10,)))

  def test_resize_identity_reflect(self):
    data = self._noise(10)
    np.testing.assert_allclose(data, resize(data, (10,), reflect_axes={0}))

  def test_resize_downscale(self):
    data = self._noise(10, 15)
    want = skimage.measure.block_reduce(data, (5, 5), func=np.mean)
    np.testing.assert_allclose(want, resize(data, (2, 3)))

  def test_resize_upscale(self):
    data = self._noise(5, 5)
    want = data.repeat(2, axis=0).repeat(3, axis=1)
    np.testing.assert_allclose(want, resize(data, (10, 15)))

  def test_resize_3d(self):
    target = (1, 2, 3)
    got = resize(self._noise(5, 6, 7), target)
    self.assertEqual(got.shape, target)

  def test_resize_reflect_axes_invalid(self):
    with self.assertRaisesRegex(ValueError, 'invalid axis'):
      resize(np.zeros(5), (10,), reflect_axes={1})


suite = unittest.defaultTestLoader.loadTestsFromTestCase(ResizeLocalMeanTest)
unittest.TextTestRunner().run(suite)
], | |
"execution_count": 17, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"...........\n", | |
"----------------------------------------------------------------------\n", | |
"Ran 11 tests in 0.015s\n", | |
"\n", | |
"OK\n" | |
], | |
"name": "stderr" | |
}, | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"<unittest.runner.TextTestResult run=11 errors=0 failures=0>" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 17 | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"id": "i-3ksd4G02NF", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"cell_type": "code", | |
"source": [ | |
"" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
} | |
] | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment