Last active
December 23, 2021 01:20
-
-
Save brandonwillard/963393206f48f4dc46e5e9b82f5caed9 to your computer and use it in GitHub Desktop.
Lift advanced indices through concatenate/stack
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from typing import Sequence, Tuple, Union

import numpy as np
def is_basic_idx(x):
    """Return ``True`` if ``x`` is a basic index, i.e. a ``slice`` or ``None``.

    Everything else (integers, integer arrays, lists, ...) is classified as
    an advanced index by this helper. As a result, integers are broadcast
    together with the array indices when indices are expanded, rather than
    being handled like slices.
    """
    if x is None:
        return True
    return isinstance(x, slice)
def expand_indices(
    indices: Tuple[Union[np.ndarray, int, slice], ...], shape: Sequence[int]
) -> Tuple[np.ndarray, ...]:
    """Convert basic and/or advanced indices into one broadcasted advanced index.

    ``A[expand_indices(indices, A.shape)]`` selects the same elements, in the
    same layout, as ``A[indices]``. However, every entry of the returned tuple
    is an integer array with the shape of the indexing result.

    Slices are turned into ``np.arange`` ranges. Every index is then padded
    with length-1 axes so that all of them broadcast against each other the
    way NumPy lays out a mixed basic/advanced indexing result:

    - if the advanced indices are contiguous, their broadcast dimensions sit
      in place, between the basic dimensions that precede and follow them;
    - if a basic index separates two advanced indices, the advanced-index
      dimensions come first, followed by all basic dimensions in order.

    Integer indices are classified as advanced (0-d) indices. Missing
    trailing indices are treated as ``slice(None)``.

    Example
    -------
    >>> A = np.arange(5 * 4 * 3 * 3).reshape(5, 4, 3, 3)
    >>> indices = (slice(1, 3), 1, slice(None), np.array([2, 1]))
    >>> full_indices = expand_indices(indices, A.shape)
    >>> [idx.shape for idx in full_indices]
    [(2, 2, 3), (2, 2, 3), (2, 2, 3), (2, 2, 3)]
    >>> full_indices[0][:, :, 0]
    array([[1, 2],
           [1, 2]])
    >>> np.array_equal(A[full_indices], A[indices])
    True

    Parameters
    ----------
    indices
        The indices to convert: slices, integers, and integer arrays (or
        integer array-likes).
    shape
        The shape of the array being indexed.

    Returns
    -------
    A tuple containing one broadcasted integer index array per dimension of
    ``shape``.

    Raises
    ------
    NotImplementedError
        If ``indices`` contains ``None`` (new axes), ``Ellipsis``, or a
        boolean index.
    IndexError
        If there are more indices than dimensions in ``shape``.
    """

    def _pad_dims(arr, n_before, n_after):
        # Surround ``arr``'s own axes with length-1 axes so it broadcasts
        # into its slot of the output index space.
        return arr.reshape((1,) * n_before + arr.shape + (1,) * n_after)

    # Reject index kinds this expansion does not model; otherwise they would
    # be silently misinterpreted as integer advanced indices.
    for idx in indices:
        if idx is None:
            raise NotImplementedError("New axes not supported")
        if idx is Ellipsis:
            raise NotImplementedError("Ellipsis indices not supported")
        if not isinstance(idx, slice) and np.asarray(idx).dtype.kind == "b":
            raise NotImplementedError("Boolean indices not supported")

    if len(indices) > len(shape):
        # Without this check the ``zip`` below would silently drop the extra
        # indices.
        raise IndexError(
            f"too many indices for array: array is {len(shape)}-dimensional, "
            f"but {len(indices)} were indexed"
        )

    full_indices = list(indices) + [slice(None)] * (len(shape) - len(indices))
    index_types = [is_basic_idx(idx) for idx in full_indices]
    adv_positions = [d for d, is_basic in enumerate(index_types) if not is_basic]

    # NumPy moves the advanced-index subspace to the front of the result when
    # the advanced indices are not all adjacent, i.e. when some basic index
    # lies between the first and the last advanced index.
    moved_subspace = bool(adv_positions) and (
        adv_positions[-1] - adv_positions[0] + 1 != len(adv_positions)
    )
    first_adv_idx = adv_positions[0] if adv_positions else len(full_indices)

    n_basic_indices = sum(index_types)
    # Number of dimensions of the broadcast subspace created by the advanced
    # indices (broadcasting yields the maximum of their ranks).
    n_subspace_dims = max(
        (
            np.ndim(idx)
            for idx, is_basic in zip(full_indices, index_types)
            if not is_basic
        ),
        default=0,
    )
    # Every expanded index ends up with exactly this many dimensions.
    n_output_dims = n_subspace_dims + n_basic_indices

    expanded = []
    n_preceding_basics = 0
    for d, (idx, size, is_basic) in enumerate(zip(full_indices, shape, index_types)):
        if is_basic:
            # Only slices reach this point (``None`` was rejected above).
            idx = np.arange(*idx.indices(size))
            # A basic index sits after the advanced subspace when the subspace
            # was moved to the front, or when this index follows the
            # contiguous group of advanced indices.
            after_subspace = moved_subspace or d > first_adv_idx
            n_before = n_preceding_basics + (n_subspace_dims if after_subspace else 0)
            n_after = n_output_dims - n_before - 1
            n_preceding_basics += 1
        else:
            idx = np.asarray(idx)
            # Lower-rank advanced indices are left-padded so that they align
            # with the trailing axes of the broadcast subspace.
            n_lead = n_subspace_dims - idx.ndim
            if moved_subspace:
                # The subspace occupies the leading output dimensions; every
                # basic index follows it.
                n_before = n_lead
                n_after = n_basic_indices
            else:
                # The subspace sits between the basic indices that precede
                # the advanced group and those that follow it.
                n_before = n_preceding_basics + n_lead
                n_after = n_basic_indices - n_preceding_basics
        expanded.append(_pad_dims(idx, n_before, n_after))

    if not expanded:
        # Indexing a 0-d array with an empty tuple.
        return ()
    return tuple(np.broadcast_arrays(*expanded))
def test_expand_indices():
    """Check that ``expand_indices`` reproduces NumPy's indexing results.

    Each case pairs an array shape with an index tuple. The expanded, fully
    broadcast index arrays must select exactly what the original mixed
    basic/advanced indices select, in the same layout.
    """
    # A seeded generator keeps failures reproducible.
    rng = np.random.default_rng(2021)
    idx_2x2 = np.array([[0, 1], [2, 2]])
    cases = [
        # Advanced index followed by a basic index
        ((3, 4, 3), (idx_2x2, slice(2, 3))),
        # Basic index followed by an advanced index
        ((3, 4, 3), (slice(2, 3), idx_2x2)),
        # Advanced indices separated by a slice. NumPy places the broadcast
        # advanced-index dimensions first, so each slice's ``arange`` has to be
        # padded to sit *after* that subspace.
        (
            (3, 5, 4, 3),
            (np.array([[0], [2], [1]]), slice(None), np.array([2, 1]), slice(2, 3)),
        ),
        ((3, 4, 3), (slice(2, 3), np.array([0, 1, 2]))),
        # Only advanced indices (the missing trailing index is a full slice)
        ((3, 4, 3), (idx_2x2, idx_2x2)),
        ((3, 4, 3), (idx_2x2, idx_2x2, idx_2x2)),
        ((3, 4, 3), (idx_2x2, idx_2x2, 1)),
        # No advanced indices at all
        ((2, 5, 4, 3), (slice(0, 2),)),
        ((2, 5, 4, 3), (slice(0, 2), rng.integers(3, size=(2, 3)))),
    ]
    for shape, indices in cases:
        A = rng.normal(size=shape)
        full_indices = expand_indices(indices, A.shape)
        assert len(full_indices) == A.ndim
        assert np.array_equal(A[full_indices], A[indices]), indices
def test_expand_indices_moved_subspaces():
    """Check ``expand_indices`` when a basic index separates advanced indices.

    In that case NumPy places the broadcast advanced-index dimensions first in
    the result, followed by the sliced dimensions in their original order.
    The expanded indices must reproduce that layout.
    """
    # A seeded generator keeps failures reproducible.
    rng = np.random.default_rng(2021)
    idx_2x2 = np.array([[0, 1], [2, 2]])
    cases = [
        (
            (3, 6, 5, 4, 3),
            (
                slice(None),
                np.array([[0], [2], [1]]),
                slice(None),
                np.array([2, 1]),
                slice(2, 3),
            ),
        ),
        # The slice in the middle is "moved" after the 2-D advanced subspace.
        ((3, 4, 3), (idx_2x2, slice(None), idx_2x2)),
    ]
    for shape, indices in cases:
        A = rng.normal(size=shape)
        full_indices = expand_indices(indices, A.shape)
        assert len(full_indices) == A.ndim
        assert np.array_equal(A[full_indices], A[indices]), indices
def test_expand_indices_single_indices():
    """Check ``expand_indices`` when integer indices are mixed in.

    Integers are treated as 0-d advanced indices. They broadcast together
    with the array indices instead of contributing a dimension of their own.
    """
    # A seeded generator keeps failures reproducible.
    rng = np.random.default_rng(2021)
    arr = np.array([0, 1, 2])
    cases = [
        ((3, 4, 3), (slice(2, 3), arr, 1)),
        ((3, 4, 3), (slice(2, 3), 1, arr)),
        # The integer and the array are separated by a slice, so the subspace
        # moves to the front: the result has shape (3, 1).
        ((3, 4, 3), (1, slice(2, 3), arr)),
        # Cover every valid position (0, 1 and 2) along the first axis.
        ((3, 4, 3), (rng.integers(3, size=(4, 3)), 1, 0)),
    ]
    for shape, indices in cases:
        A = rng.normal(size=shape)
        full_indices = expand_indices(indices, A.shape)
        assert len(full_indices) == A.ndim
        assert np.array_equal(A[full_indices], A[indices]), indices
def reorder_index(A_parts, indices, join_index=0):
    """Compute ``A[indices]`` for ``A = np.stack(A_parts, axis=join_index)``.

    The stacked array ``A`` is never built. Instead:

    1. ``indices`` are expanded into fully broadcast integer index arrays
       with ``expand_indices``;
    2. the index array for the join axis routes each output entry to the
       part it comes from;
    3. the remaining index arrays select that entry within the part.

    Parameters
    ----------
    A_parts
        Sequence of arrays, assumed to share a common shape (as
        ``np.stack`` requires). The shape of the first part is used for all.
    indices
        The indices to apply to the stacked array. They take the same forms
        accepted by ``expand_indices``.
    join_index
        The axis of the stacked array along which ``A_parts`` are stacked.
        Negative values count from the end of the stacked array's axes, as
        in ``np.stack``.

    Returns
    -------
    An array equal to ``np.stack(A_parts, axis=join_index)[indices]``. Its
    dtype is the promotion of the parts' dtypes.

    Raises
    ------
    IndexError
        If an index along the join axis is out of range for ``len(A_parts)``.
    """
    n_parts = len(A_parts)
    A_shape = list(A_parts[0].shape)
    if join_index < 0:
        # ``np.stack`` interprets negative axes relative to the *stacked*
        # array, which has one more dimension than each part.
        join_index += len(A_shape) + 1
    A_shape.insert(join_index, n_parts)

    bcast_indices = expand_indices(indices, A_shape)

    join_idx = bcast_indices[join_index]
    # Entries whose part number matches no part would otherwise remain
    # uninitialized in ``res``.
    if np.any((join_idx < -n_parts) | (join_idx >= n_parts)):
        raise IndexError(
            f"index out of bounds for join axis {join_index} with size {n_parts}"
        )
    # Map negative part numbers (e.g. ``-1`` for the last part) to their
    # non-negative equivalents so that they match ``m`` below.
    join_idx = np.where(join_idx < 0, join_idx + n_parts, join_idx)

    # Match the dtype ``np.stack`` would produce. A bare ``np.empty`` would be
    # float64, upcasting integers and truncating complex values.
    out_dtype = A_parts[0].dtype
    for A_part in A_parts[1:]:
        out_dtype = np.promote_types(out_dtype, A_part.dtype)
    res = np.empty(join_idx.shape, dtype=out_dtype)

    for m, A_part in enumerate(A_parts):
        # Output positions whose entries come from part ``m``
        m_0 = np.nonzero(join_idx == m)
        # The corresponding indices into part ``m`` for every non-join axis
        m_idx = tuple(v[m_0] for i, v in enumerate(bcast_indices) if i != join_index)
        res[m_0] = A_part[m_idx]
    return res
def test_reorder_index():
    """Check ``reorder_index`` against indexing the explicitly stacked array."""
    # A seeded generator keeps failures reproducible.
    rng = np.random.default_rng(2021)
    A_parts = tuple(rng.normal(size=(4, 3)) for _ in range(3))

    # Join along the first axis. The random index draws from all three parts
    # (``randint(2)`` would never select the third one).
    A = np.stack(A_parts)
    indices = (rng.integers(3, size=(4, 3)), 1, 0)
    assert np.array_equal(reorder_index(A_parts, indices), A[indices])

    # Mixed basic and advanced indices along the first axis
    for indices in (
        (slice(None), np.array([0, 2])),
        (np.array([2, 0]), slice(1, 3)),
    ):
        assert np.array_equal(reorder_index(A_parts, indices), A[indices]), indices

    # Join along the second axis
    A = np.stack(A_parts, axis=1)
    indices = (rng.integers(4, size=(4, 3)), 1, 0)
    assert np.array_equal(reorder_index(A_parts, indices, join_index=1), A[indices])

    # The join axis is only implicitly (fully) indexed
    indices = (rng.integers(4, size=(4, 3)),)
    assert np.array_equal(reorder_index(A_parts, indices, join_index=1), A[indices])
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment