Skip to content

Instantly share code, notes, and snippets.

@nmichlo
Last active July 3, 2025 10:35
Show Gist options
  • Save nmichlo/994bdf5c015c3cbf0b6c379a535fbde5 to your computer and use it in GitHub Desktop.
Save nmichlo/994bdf5c015c3cbf0b6c379a535fbde5 to your computer and use it in GitHub Desktop.
Torch & NumPy redistributed splitting/chunking
# BASED ON @nmichlo's ISSUE AT:
# https://github.com/pytorch/pytorch/issues/60531
# NOTE: modern features of torch and numpy make this easier to implement in python
# as modern one liners.
##### TORCH EXAMPLE #####
# chunks = 3
# dim = 0
# tensor = torch.arange(7)
# dim_size = tensor.size(dim)
#
# # drop_remainder=False, redistribute=True
# print(torch.tensor_split(tensor, chunks, dim=dim))
# # > (tensor([0, 1, 2]), tensor([3, 4]), tensor([5, 6]))
#
# # drop_remainder=True, redistribute=True
# print(torch.tensor_split(tensor[:dim_size - dim_size % chunks], chunks, dim=dim))
# # > (tensor([0, 1]), tensor([2, 3]), tensor([4, 5]))
#
# # drop_remainder=False, redistribute=False
# print(torch.split(tensor, chunks, dim=dim))
# # > (tensor([0, 1, 2]), tensor([3, 4, 5]), tensor([6]))
#
# # drop_remainder=True, redistribute=False
# print(torch.split(tensor[:dim_size - dim_size % chunks], chunks, dim=dim))
# # > (tensor([0, 1, 2]), tensor([3, 4, 5]))
##### NUMPY EXAMPLE #####
# chunks = 3
# axis = 0
# arr = np.arange(7)
# axis_size = arr.shape[axis]
#
# # drop_remainder=False, redistribute=True
# print(np.array_split(arr, chunks, axis=axis))
# # > [array([0, 1, 2]), array([3, 4]), array([5, 6])]
#
# # drop_remainder=True, redistribute=True
# print(np.array_split(arr[:axis_size - axis_size % chunks], chunks, axis=axis))
# # > [array([0, 1]), array([2, 3]), array([4, 5])]
#
# # drop_remainder=False, redistribute=False -- A bit more tricky, API differs and uses indices
# print(np.split(arr, range(math.ceil(axis_size / chunks), axis_size, chunks), axis=axis))
# # > [array([0, 1, 2]), array([3, 4, 5]), array([6])]
#
# # drop_remainder=True, redistribute=False -- A bit more tricky, API differs and uses indices
# print(np.split(arr[:axis_size - axis_size % chunks], range(math.ceil(axis_size / chunks), axis_size - axis_size % chunks, chunks), axis=axis))
# # > [array([0, 1, 2]), array([3, 4, 5])]
import math
import numpy as np
import pytest
import torch
def custom_chunk_torch(
    tensor: torch.Tensor,
    chunks: int,
    dim: int = 0,
    redistributed: bool = False,
    drop_remainder: bool = False,
    zero_returns_arr: bool = False,
):
    """Split ``tensor`` along ``dim`` into chunks.

    Args:
        tensor: The tensor to split.
        chunks: Requested number of chunks. Must be a positive integer.
        dim: Dimension to split along (negative values count from the end).
        redistributed: If True, split with ``Tensor.tensor_split`` so that
            exactly ``chunks`` pieces are returned and the remainder is
            spread over the leading pieces (some pieces may be empty).
            If False, every piece holds ``ceil(size / chunks)`` items except
            possibly a shorter last one, so fewer than ``chunks`` pieces may
            be returned (same result as ``tensor.chunk(chunks)``).
        drop_remainder: If True, the trailing ``size % chunks`` items along
            ``dim`` are discarded before splitting.
        zero_returns_arr: Only consulted when ``redistributed`` is False and
            the (possibly trimmed) dimension is empty. False returns an empty
            list; True defers to ``tensor.split(0)``, which yields a single
            empty tensor.

    Returns:
        A sequence of tensors: the tuple produced by torch, or ``[]`` in the
        empty case described above.

    Raises:
        ZeroDivisionError: If ``chunks`` is 0.
        ValueError: If ``chunks`` is negative.
    """
    # Validate once, up front. Previously zero was only rejected as a side
    # effect of `%` / `/`, and a negative count slipped through: the negative
    # modulo trimmed the wrong slice and the ceil could collapse to 0, so the
    # call silently returned wrong (often empty) results.
    if chunks == 0:
        raise ZeroDivisionError("chunks must be non-zero")
    if chunks < 0:
        raise ValueError(f"chunks must be a positive integer, got: {chunks}")

    # remove trailing items not directly divisible into chunks
    if drop_remainder:
        remainder = tensor.size(dim) % chunks
        if remainder != 0:
            tensor = tensor.narrow(dim, 0, tensor.size(dim) - remainder)

    dim_size = tensor.size(dim)

    if redistributed:
        # evenly shift remainder across chunks, so approx evenly sized
        return tensor.tensor_split(chunks, dim=dim)

    # same result as `tensor.chunk(chunks)`
    split_size = math.ceil(dim_size / chunks)
    # CUSTOM, diverge from torch.split: an empty dim yields no chunks at all
    if split_size == 0 and not zero_returns_arr:
        return []
    return tensor.split(split_size, dim=dim)
def custom_chunk_np(
    arr: np.ndarray,
    chunks: int,
    axis: int = 0,
    drop_remainder: bool = False,
    redistributed: bool = False,
    zero_returns_arr: bool = False,
):
    """Split ``arr`` along ``axis`` into chunks, mirroring the torch variant.

    NOTE: the positional order of ``drop_remainder`` and ``redistributed`` is
    the reverse of ``custom_chunk_torch``; pass them by keyword.

    Args:
        arr: The array to split.
        chunks: Requested number of chunks. Must be a positive integer.
        axis: Axis to split along (negative values count from the end).
        drop_remainder: If True, the trailing ``size % chunks`` items along
            ``axis`` are discarded before splitting.
        redistributed: If True, split with ``np.array_split`` so that exactly
            ``chunks`` pieces are returned and the remainder is spread over
            the leading pieces (some pieces may be empty). If False, every
            piece holds ``ceil(size / chunks)`` items except possibly a
            shorter last one, so fewer than ``chunks`` pieces may be returned.
        zero_returns_arr: Only consulted when ``redistributed`` is False and
            the (possibly trimmed) axis is empty. False returns an empty
            list; True returns a one-element list holding the empty array.

    Returns:
        A list of arrays.

    Raises:
        ZeroDivisionError: If ``chunks`` is 0.
        ValueError: If ``chunks`` is negative.
    """
    # Validate once, up front. Previously zero was only rejected as a side
    # effect of `%` / `/`, and a negative count slipped through: it produced
    # a negative split size whose index range made `np.split` silently return
    # meaningless pieces.
    if chunks == 0:
        raise ZeroDivisionError("chunks must be non-zero")
    if chunks < 0:
        raise ValueError(f"chunks must be a positive integer, got: {chunks}")

    # remove trailing items not directly divisible into chunks
    if drop_remainder:
        remainder = arr.shape[axis] % chunks
        if remainder != 0:
            # slice only along `axis`, leaving every other axis untouched
            keep = [slice(None)] * arr.ndim
            keep[axis] = slice(0, arr.shape[axis] - remainder)
            arr = arr[tuple(keep)]

    dim_size = arr.shape[axis]

    if redistributed:
        # evenly shift remainder across chunks, so approx evenly sized
        return np.array_split(arr, chunks, axis=axis)

    # same result as `tensor.chunk(chunks)`
    split_size = math.ceil(dim_size / chunks)
    # CUSTOM, diverge from torch.split: an empty axis yields no chunks at all
    if split_size == 0 and not zero_returns_arr:
        return []
    # mimic torch.split: np.split takes the cut positions, not a chunk size.
    # The step is clamped to 1 because range() rejects a step of 0 (only
    # reachable when the axis is empty, where the range is empty anyway).
    indices = range(split_size, dim_size, max(1, split_size))
    return np.split(arr, indices, axis=axis)
@pytest.mark.parametrize(
    ('fn', 'arange'),
    [
        (custom_chunk_np, np.arange),
        (custom_chunk_torch, torch.arange),
    ]
)
def test_chuck(fn, arange):
    """Run the same chunking expectations against the numpy and torch variants.

    Each case is ``(n, chunks, redistributed, drop_remainder,
    zero_returns_arr, expected)``, where the input is ``arange(n)`` and
    ``expected`` is the list of chunks converted with ``tolist()``.
    """

    def _chunk(n, chunks, r, d, z=False):
        pieces = fn(arange(n), chunks, redistributed=r, drop_remainder=d, zero_returns_arr=z)
        return [piece.tolist() for piece in pieces]

    # (redistributed, drop_remainder) in the order the checks are run
    flag_combos = ((False, False), (True, False), (False, True), (True, True))

    # a single chunk always returns the input unchanged
    # BUG? for (False, True): should this not be []
    for r, d in flag_combos:
        assert _chunk(3, 1, r, d) == [[0, 1, 2]]

    # zero chunks always raises, for non-empty and empty inputs alike
    # TODO: should maybe return empty array always???
    # BUG? for (False, True): should this not be []
    for n in (3, 0):
        for r, d in flag_combos:
            with pytest.raises(ZeroDivisionError):
                _chunk(n, 0, r, d)

    five_empty = [[], [], [], [], []]
    singles = [[0], [1], [2], [3], [4]]
    pairs = [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]]

    cases = [
        # n, chunks, redistributed, drop_remainder, zero_returns_arr, expected
        (0, 5, False, False, False, []),
        (0, 5, False, False, True, [[]]),  # BUG? or correct
        (0, 5, True, False, False, five_empty),
        (0, 5, False, True, False, []),
        (0, 5, False, True, True, [[]]),  # BUG? or correct
        (0, 5, True, True, False, five_empty),

        (2, 5, False, False, False, [[0], [1]]),
        (2, 5, True, False, False, [[0], [1], [], [], []]),
        (2, 5, False, True, False, []),
        (2, 5, False, True, True, [[]]),  # BUG? or correct
        (2, 5, True, True, False, five_empty),

        (4, 5, False, False, False, [[0], [1], [2], [3]]),
        (4, 5, True, False, False, [[0], [1], [2], [3], []]),
        (4, 5, False, True, False, []),
        (4, 5, False, True, True, [[]]),  # BUG? or correct
        (4, 5, True, True, False, five_empty),

        (6, 5, False, False, False, [[0, 1], [2, 3], [4, 5]]),
        (6, 5, True, False, False, [[0, 1], [2], [3], [4], [5]]),
        (6, 5, False, True, False, singles),
        (6, 5, True, True, False, singles),

        (8, 5, False, False, False, [[0, 1], [2, 3], [4, 5], [6, 7]]),
        (8, 5, True, False, False, [[0, 1], [2, 3], [4, 5], [6], [7]]),
        (8, 5, False, True, False, singles),
        (8, 5, True, True, False, singles),

        (10, 5, False, False, False, pairs),
        (10, 5, True, False, False, pairs),
        (10, 5, False, True, False, pairs),
        (10, 5, True, True, False, pairs),

        (12, 5, False, False, False, [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]]),
        (12, 5, True, False, False, [[0, 1, 2], [3, 4, 5], [6, 7], [8, 9], [10, 11]]),
        (12, 5, False, True, False, pairs),
        (12, 5, True, True, False, pairs),
    ]
    for n, chunks, r, d, z, expected in cases:
        assert _chunk(n, chunks, r, d, z) == expected
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment