Torch & NumPy redistributed splitting/chunking
# BASED ON @nmichlo's ISSUE AT:
# https://github.com/pytorch/pytorch/issues/60531
# NOTE: modern features of torch and numpy make this easy to implement
#       in python as one-liners.

##### TORCH EXAMPLE #####
# chunks = 3
# dim = 0
# tensor = torch.arange(7)
# dim_size = tensor.size(dim)
#
# # drop_remainder=False, redistribute=True
# print(torch.tensor_split(tensor, chunks, dim=dim))
# # > (tensor([0, 1, 2]), tensor([3, 4]), tensor([5, 6]))
#
# # drop_remainder=True, redistribute=True
# print(torch.tensor_split(tensor[:dim_size - dim_size % chunks], chunks, dim=dim))
# # > (tensor([0, 1]), tensor([2, 3]), tensor([4, 5]))
#
# # drop_remainder=False, redistribute=False
# # NOTE: torch.split takes a chunk *size*, not a number of chunks
# print(torch.split(tensor, math.ceil(dim_size / chunks), dim=dim))
# # > (tensor([0, 1, 2]), tensor([3, 4, 5]), tensor([6]))
#
# # drop_remainder=True, redistribute=False
# kept = dim_size - dim_size % chunks
# print(torch.split(tensor[:kept], kept // chunks, dim=dim))
# # > (tensor([0, 1]), tensor([2, 3]), tensor([4, 5]))

##### NUMPY EXAMPLE #####
# chunks = 3
# axis = 0
# arr = np.arange(7)
# axis_size = arr.shape[axis]
#
# # drop_remainder=False, redistribute=True
# print(np.array_split(arr, chunks, axis=axis))
# # > [array([0, 1, 2]), array([3, 4]), array([5, 6])]
#
# # drop_remainder=True, redistribute=True
# print(np.array_split(arr[:axis_size - axis_size % chunks], chunks, axis=axis))
# # > [array([0, 1]), array([2, 3]), array([4, 5])]
#
# # drop_remainder=False, redistribute=False -- a bit more tricky, the API differs and uses split indices
# split_size = math.ceil(axis_size / chunks)
# print(np.split(arr, range(split_size, axis_size, split_size), axis=axis))
# # > [array([0, 1, 2]), array([3, 4, 5]), array([6])]
#
# # drop_remainder=True, redistribute=False -- a bit more tricky, the API differs and uses split indices
# kept = axis_size - axis_size % chunks
# print(np.split(arr[:kept], range(kept // chunks, kept, kept // chunks), axis=axis))
# # > [array([0, 1]), array([2, 3]), array([4, 5])]
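#
# Quick reference, derived from the four torch cases above (the numpy calls
# follow the same pattern; n = size along the split dimension):
#
#   drop_remainder  redistribute  one-liner
#   False           True          torch.tensor_split(tensor, chunks)
#   True            True          torch.tensor_split(tensor[:n - n % chunks], chunks)
#   False           False         torch.split(tensor, math.ceil(n / chunks))
#   True            False         torch.split(tensor[:n - n % chunks], (n - n % chunks) // chunks)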

import math

import numpy as np
import pytest
import torch


def custom_chunk_torch(
    tensor: torch.Tensor,
    chunks: int,
    dim: int = 0,
    redistributed: bool = False,
    drop_remainder: bool = False,
    zero_returns_arr: bool = False,
):
    # remove items not directly divisible into chunks
    if drop_remainder:
        remainder = tensor.size(dim) % chunks
        if remainder != 0:
            tensor = torch.moveaxis(torch.moveaxis(tensor, dim, 0)[:-remainder], 0, dim)
    # split into chunks
    dim_size = tensor.size(dim)
    if redistributed:
        if chunks == 0:
            raise ZeroDivisionError
        # evenly shift remainder across chunks, so approx evenly sized
        return tensor.tensor_split(chunks, dim=dim)
    else:
        # same result as `tensor.chunk(chunks)`
        split_size = math.ceil(dim_size / chunks)
        # CUSTOM, diverge from torch.split
        if not zero_returns_arr:
            if split_size == 0:
                return []
        # done
        return tensor.split(split_size, dim=dim)
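
# A minimal usage sketch for `custom_chunk_torch` (illustrative only, values
# mirror the header examples with tensor=arange(7) and chunks=3):
#
#   custom_chunk_torch(torch.arange(7), 3)
#   # > (tensor([0, 1, 2]), tensor([3, 4, 5]), tensor([6]))
#   custom_chunk_torch(torch.arange(7), 3, redistributed=True)
#   # > (tensor([0, 1, 2]), tensor([3, 4]), tensor([5, 6]))
#   custom_chunk_torch(torch.arange(7), 3, drop_remainder=True)
#   # > (tensor([0, 1]), tensor([2, 3]), tensor([4, 5]))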


def custom_chunk_np(
    arr: np.ndarray,
    chunks: int,
    axis: int = 0,
    drop_remainder: bool = False,
    redistributed: bool = False,
    zero_returns_arr: bool = False,
):
    # remove items not directly divisible into chunks
    if drop_remainder:
        remainder = arr.shape[axis] % chunks
        if remainder != 0:
            arr = np.moveaxis(np.moveaxis(arr, axis, 0)[:-remainder], 0, axis)
    # split into chunks
    dim_size = arr.shape[axis]
    if redistributed:
        if chunks == 0:
            raise ZeroDivisionError
        # evenly shift remainder across chunks, so approx evenly sized
        return np.array_split(arr, chunks, axis=axis)
    else:
        # same chunk size as `torch.chunk(chunks)` / `custom_chunk_torch` above
        split_size = math.ceil(dim_size / chunks)
        # CUSTOM, diverge from torch.split
        if not zero_returns_arr:
            if split_size == 0:
                return []
        # ----- START mimic torch.split ----- #
        indices = range(split_size, dim_size, max(1, split_size))
        # ----- END mimic torch.split ----- #
        return np.split(arr, indices, axis=axis)
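
# A matching sketch for `custom_chunk_np` (illustrative only; note the numpy
# variants return lists rather than tuples):
#
#   custom_chunk_np(np.arange(7), 3)
#   # > [array([0, 1, 2]), array([3, 4, 5]), array([6])]
#   custom_chunk_np(np.arange(7), 3, redistributed=True)
#   # > [array([0, 1, 2]), array([3, 4]), array([5, 6])]
#   custom_chunk_np(np.arange(7), 3, drop_remainder=True)
#   # > [array([0, 1]), array([2, 3]), array([4, 5])]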


@pytest.mark.parametrize(
    ('fn', 'arange'),
    [
        (custom_chunk_np, np.arange),
        (custom_chunk_torch, torch.arange),
    ]
)
def test_chunk(fn, arange):
    def _check(n, chunks, r, d, expected, z=False):
        chunked = fn(arange(n), chunks, redistributed=r, drop_remainder=d, zero_returns_arr=z)
        chunked = [chunk.tolist() for chunk in chunked]
        assert chunked == expected

    _check(3, 1, False, False, [[0, 1, 2]])
    _check(3, 1, True, False, [[0, 1, 2]])
    _check(3, 1, False, True, [[0, 1, 2]])  # BUG? should this not be []
    _check(3, 1, True, True, [[0, 1, 2]])

    # TODO: should maybe return empty array always???
    with pytest.raises(ZeroDivisionError):
        _check(3, 0, False, False, None)
    with pytest.raises(ZeroDivisionError):
        _check(3, 0, True, False, None)
    with pytest.raises(ZeroDivisionError):
        _check(3, 0, False, True, None)  # BUG? should this not be []
    with pytest.raises(ZeroDivisionError):
        _check(3, 0, True, True, None)

    # TODO: should maybe return empty array always???
    with pytest.raises(ZeroDivisionError):
        _check(0, 0, False, False, None)
    with pytest.raises(ZeroDivisionError):
        _check(0, 0, True, False, None)
    with pytest.raises(ZeroDivisionError):
        _check(0, 0, False, True, None)  # BUG? should this not be []
    with pytest.raises(ZeroDivisionError):
        _check(0, 0, True, True, None)

    _check(0, 5, False, False, [])
    _check(0, 5, False, False, [[]], z=True)  # BUG? or correct
    _check(0, 5, True, False, [[], [], [], [], []])
    _check(0, 5, False, True, [])
    _check(0, 5, False, True, [[]], z=True)  # BUG? or correct
    _check(0, 5, True, True, [[], [], [], [], []])

    _check(2, 5, False, False, [[0], [1]])
    _check(2, 5, True, False, [[0], [1], [], [], []])
    _check(2, 5, False, True, [])
    _check(2, 5, False, True, [[]], z=True)  # BUG? or correct
    _check(2, 5, True, True, [[], [], [], [], []])

    _check(4, 5, False, False, [[0], [1], [2], [3]])
    _check(4, 5, True, False, [[0], [1], [2], [3], []])
    _check(4, 5, False, True, [])
    _check(4, 5, False, True, [[]], z=True)  # BUG? or correct
    _check(4, 5, True, True, [[], [], [], [], []])

    _check(6, 5, False, False, [[0, 1], [2, 3], [4, 5]])
    _check(6, 5, True, False, [[0, 1], [2], [3], [4], [5]])
    _check(6, 5, False, True, [[0], [1], [2], [3], [4]])
    _check(6, 5, True, True, [[0], [1], [2], [3], [4]])

    _check(8, 5, False, False, [[0, 1], [2, 3], [4, 5], [6, 7]])
    _check(8, 5, True, False, [[0, 1], [2, 3], [4, 5], [6], [7]])
    _check(8, 5, False, True, [[0], [1], [2], [3], [4]])
    _check(8, 5, True, True, [[0], [1], [2], [3], [4]])

    _check(10, 5, False, False, [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]])
    _check(10, 5, True, False, [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]])
    _check(10, 5, False, True, [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]])
    _check(10, 5, True, True, [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]])

    _check(12, 5, False, False, [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]])
    _check(12, 5, True, False, [[0, 1, 2], [3, 4, 5], [6, 7], [8, 9], [10, 11]])
    _check(12, 5, False, True, [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]])
    _check(12, 5, True, True, [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]])
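
# To run the checks above, assuming this file is saved as e.g. `test_chunking.py`
# (the filename is just an example):
#
#   pytest test_chunking.py -v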