Created
January 14, 2025 17:44
-
-
Save RandallPittmanOrSt/0c049b2604a3260c9811f54249a028f2 to your computer and use it in GitHub Desktop.
ndarray shape and type narrowing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""nptypes.py - Various useful types and type aliases""" | |
# FUTURE: Integrate optype | |
from typing import Any, Literal, Protocol, TypeAlias, TypeVar | |
import numpy as np | |
import numpy.typing as npt | |
from typing_extensions import TypeIs | |
# --- Type aliases ---
# float64 ndarray with no dimensionality constraint (contrast with the
# shape-typed NDArr*D / F64Arr*D aliases defined further down).
type F64Arr = npt.NDArray[np.float64]
# ---stuff copied from numpy._typing--- | |
_DType_co = TypeVar("_DType_co", covariant=True, bound=np.dtype[Any]) | |
class _SupportsDType(Protocol[_DType_co]): | |
@property | |
def dtype(self) -> _DType_co: ... | |
type _DTypeLike[_SCT: np.generic] = ( | |
np.dtype[_SCT] | type[_SCT] | _SupportsDType[np.dtype[_SCT]] | |
) | |
_ScalarType_co = TypeVar("_ScalarType_co", bound=np.generic, covariant=True) | |
# ---End of stuff from numpy._typing--- | |
# We have to use the old syntax for the base array type aliases: the new
# type-alias syntax (PEP 695 `type` statement) has no way to specify the variance
# of its type parameters, and old-style TypeVars can't be used inside it.
# Shape-typed ndarray aliases, generic in a covariant scalar type
# (e.g. ``NDArr2D[np.int32]``).
NDArr1D: TypeAlias = np.ndarray[tuple[int], np.dtype[_ScalarType_co]]
"""1-D NumPy array, parameterized by its scalar type."""
NDArr2D: TypeAlias = np.ndarray[tuple[int, int], np.dtype[_ScalarType_co]]
"""2-D NumPy array, parameterized by its scalar type."""
Nx2Arr: TypeAlias = np.ndarray[tuple[int, Literal[2]], np.dtype[_ScalarType_co]]
"""2-D NumPy array of shape (N, 2): any number of rows, exactly two columns."""
Nx3Arr: TypeAlias = np.ndarray[tuple[int, Literal[3]], np.dtype[_ScalarType_co]]
"""2-D NumPy array of shape (N, 3): any number of rows, exactly three columns."""
NDArr3D: TypeAlias = np.ndarray[tuple[int, int, int], np.dtype[_ScalarType_co]]
"""3-D NumPy array, parameterized by its scalar type."""
# float64 specializations of the generic shape-typed aliases NDArr1D/2D/3D.
type F64Arr1D = NDArr1D[np.float64]
"""1-D float64 array."""
type F64Arr2D = NDArr2D[np.float64]
"""2-D float64 array."""
type F64Arr3D = NDArr3D[np.float64]
"""3-D float64 array."""
def is_arr_dtype[DT: np.generic]( | |
arr: np.ndarray, dtype: _DTypeLike[DT] | |
) -> TypeIs[npt.NDArray[DT]]: | |
"""Ensure an array's dtype is dtype.""" | |
return bool(np.issubdtype(arr.dtype, dtype)) | |
def is_ndarr1d[DT: np.generic]( | |
arr: np.ndarray, dtype: _DTypeLike[DT] | None = None | |
) -> TypeIs[NDArr1D[DT]]: | |
"""Check if a variable is a 1-D NumPy array (possibly with a particular type).""" | |
return ( | |
(dtype is None or is_arr_dtype(arr, dtype)) | |
and arr.ndim == 1 | |
) | |
def is_ndarr2d[DT: np.generic]( | |
arr: np.ndarray, dtype: _DTypeLike[DT] | None = None | |
) -> TypeIs[NDArr2D[DT]]: | |
"""Check if a variable is a 2-D NumPy array (possibly with a particular type).""" | |
return ( | |
(dtype is None or is_arr_dtype(arr, dtype)) | |
and arr.ndim == 2 | |
) | |
def is_nx2_arr[DT: np.generic]( | |
arr: np.ndarray, dtype: _DTypeLike[DT] | None = None | |
) -> TypeIs[Nx2Arr[DT]]: | |
"""Check if a variable is an Nx2 NumPy array (possibly with a particular type).""" | |
return is_ndarr2d(arr, dtype) and arr.shape[1] == 2 | |
def is_nx3_arr[DT: np.generic]( | |
arr: np.ndarray, dtype: _DTypeLike[DT] | None = None | |
) -> TypeIs[Nx3Arr[DT]]: | |
"""Check if a variable is an Nx3 NumPy array (possibly with a particular type).""" | |
return is_ndarr2d(arr, dtype) and arr.shape[1] == 3 | |
def is_ndarr3d[DT: np.generic]( | |
arr: np.ndarray, dtype: _DTypeLike[DT] | None = None | |
) -> TypeIs[NDArr3D[DT]]: | |
"""Check if a variable is a 3-D NumPy array (possibly with a particular type).""" | |
return ( | |
(dtype is None or is_arr_dtype(arr, dtype)) | |
and arr.ndim == 3 | |
) | |
def is_F64Arr1D(arr: object) -> TypeIs[F64Arr1D]:
    """Check whether *arr* is a 1-D NumPy array with dtype float64.

    Non-array values yield False instead of raising.
    """
    # Delegate the dtype test to is_ndarr1d's own `dtype` parameter rather
    # than repeating it; the isinstance guard keeps non-arrays out.
    return isinstance(arr, np.ndarray) and is_ndarr1d(arr, np.float64)
def is_F64Arr2D(arr: object) -> TypeIs[F64Arr2D]:
    """Check whether *arr* is a 2-D NumPy array with dtype float64.

    Non-array values yield False instead of raising.
    """
    # Delegate the dtype test to is_ndarr2d's own `dtype` parameter rather
    # than repeating it; the isinstance guard keeps non-arrays out.
    return isinstance(arr, np.ndarray) and is_ndarr2d(arr, np.float64)
def is_F64Arr3D(arr: object) -> TypeIs[F64Arr3D]:
    """Check whether *arr* is a 3-D NumPy array with dtype float64.

    Non-array values yield False instead of raising.
    """
    # Delegate the dtype test to is_ndarr3d's own `dtype` parameter rather
    # than repeating it; the isinstance guard keeps non-arrays out.
    return isinstance(arr, np.ndarray) and is_ndarr3d(arr, np.float64)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment