"""Numpy-like tensor implementation written in pure Python."""

from __future__ import annotations

import itertools
import math
from collections.abc import Callable, Iterable, Iterator, Sequence
from functools import partial, reduce
from operator import mul
from types import EllipsisType
from typing import Generic, Type, TypeAlias, TypeVar, Union, cast, overload

Numeric: TypeAlias = bool | int | float | complex
DType = TypeVar("DType", bool, int, float, complex)
T_DType = TypeVar("T_DType", bool, int, float, complex)


class Tensor(Generic[DType]):
    def __init__(
        self,
        data: DType | Sequence[DType],
        shape: Sequence[int] | None = None,
    ) -> None:
        self._data: tuple[DType, ...] = tuple(data) if isinstance(data, Sequence) else (data,)
        self._shape: tuple[int, ...] = (
            tuple(shape) if shape is not None else (len(data),) if isinstance(data, Sequence) else ()
        )
        if not self.shape:
            if self.data and len(self.data) != 1:
                raise ValueError("shape must be specified if data has more than one element")
        if any(k < 0 for k in self.shape):
            raise ValueError("shape must be non-negative")
        else:
            if len(self.data) != reduce(mul, self.shape, 1):
                raise ValueError("data size does not match shape")
    @property
    def data(self) -> tuple[DType, ...]:
        return self._data
    @property
    def shape(self) -> tuple[int, ...]:
        return self._shape
    @property
    def ndim(self) -> int:
        return len(self.shape) if self.shape else 0
    @property
    def size(self) -> int:
        return len(self.data)
    @property
    def value(self) -> DType:
        if self.size != 1:
            raise ValueError("tensor must have exactly one element")
        return self.data[0]
    @staticmethod
    def get_nested_index(flattened_index: int, shape: Sequence[int]) -> tuple[int, ...]:
        index = flattened_index
        multi_index = []
        for dim_size in reversed(shape):
            multi_index.append(index % dim_size)
            index //= dim_size
        return tuple(reversed(multi_index))
    @staticmethod
    def get_flattened_index(multi_index: Sequence[int], shape: Sequence[int]) -> int:
        return sum(dim_index * reduce(mul, shape[i + 1 :], 1) for i, dim_index in enumerate(multi_index))
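    # Worked example (illustrative): for shape (2, 3), flattened index 4
    # round-trips with multi-index (1, 1):
    #   get_nested_index(4, (2, 3))          -> (1, 1)
    #   get_flattened_index((1, 1), (2, 3))  -> 1 * 3 + 1 = 4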
    @staticmethod
    def normalize_index(
        index: tuple[int | slice | Sequence[int] | EllipsisType | None, ...], shape: Sequence[int]
    ) -> tuple[int | slice | Sequence[int] | None, ...]:
        ndim = len(shape)
        num_specified_indices = sum(i not in (None, Ellipsis) for i in index)
        if num_specified_indices > ndim:
            raise IndexError("too many indices for tensor")
        ellipsis_count = index.count(Ellipsis)
        if ellipsis_count == 0:
            num_slices_for_rest_dims = ndim - num_specified_indices
            index = index + (slice(None),) * num_slices_for_rest_dims
        elif ellipsis_count == 1:
            ellipsis_index = index.index(Ellipsis)
            num_slices_for_ellipsis = ndim - num_specified_indices
            index = index[:ellipsis_index] + (slice(None),) * num_slices_for_ellipsis + index[ellipsis_index + 1 :]
        else:
            raise IndexError("only one ellipsis allowed")
        # The cast must mirror the return annotation, including Sequence[int].
        return cast(tuple[int | slice | Sequence[int] | None, ...], index)
    def __repr__(self) -> str:
        def format_tensor(level: int, shape: Sequence[int], data: Sequence[Numeric]) -> str:
            if not shape:
                return str(data[0])
            dim = shape[0]
            extra_size = reduce(mul, shape[1:], 1)
            sub_tensors = [
                format_tensor(level + 1, shape[1:], data[i * extra_size : (i + 1) * extra_size]) for i in range(dim)
            ]
            if level == self.ndim - 1:
                return f"[{', '.join(sub_tensors)}]"
            newline = "\n"
            indent = " " * level
            return f"[{newline}{(',' + newline).join(indent + ' ' + sub for sub in sub_tensors)}{newline + indent}]"
        return f"Tensor({format_tensor(0, self.shape, self.data)})"
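    # For reference, repr(Tensor([1, 2, 3, 4], (2, 2))) renders as:
    #   Tensor([
    #    [1, 2],
    #    [3, 4]
    #   ])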
    @overload
    def __add__(self: Tensor[bool] | Tensor[int], other: bool | int | Tensor[int]) -> Tensor[int]:
        ...
    @overload
    def __add__(self: Tensor[float], other: bool | int | float | Tensor[int] | Tensor[float]) -> Tensor[float]:
        ...
    @overload
    def __add__(self: Tensor[bool] | Tensor[int] | Tensor[float], other: float | Tensor[float]) -> Tensor[float]:
        ...
    @overload
    def __add__(self: Tensor[complex], other: Numeric | Tensor) -> Tensor[complex]:
        ...
    @overload
    def __add__(self, other: complex | Tensor[complex]) -> Tensor[complex]:
        ...
    def __add__(self, other: Numeric | Tensor) -> Tensor:
        if isinstance(other, (int, float, complex)):
            return Tensor(tuple(a + other for a in self.data), self.shape)
        if isinstance(other, Tensor):
            left, right = self.broadcast_tensors(self, other)
            return Tensor(tuple(a + b for a, b in zip(left.data, right.data)), left.shape)
        return NotImplemented  # type: ignore[unreachable]
    @overload
    def __radd__(self: Tensor[bool] | Tensor[int], other: bool | int | Tensor[int]) -> Tensor[int]:
        ...
    @overload
    def __radd__(self: Tensor[float], other: bool | int | float | Tensor[int] | Tensor[float]) -> Tensor[float]:
        ...
    @overload
    def __radd__(self: Tensor[bool] | Tensor[int] | Tensor[float], other: float | Tensor[float]) -> Tensor[float]:
        ...
    @overload
    def __radd__(self: Tensor[complex], other: Numeric | Tensor) -> Tensor[complex]:
        ...
    @overload
    def __radd__(self, other: complex | Tensor[complex]) -> Tensor[complex]:
        ...
    def __radd__(self, other: Numeric | Tensor) -> Tensor:
        return self + other
    @overload
    def __sub__(self: Tensor[bool] | Tensor[int], other: bool | int | Tensor[int]) -> Tensor[int]:
        ...
    @overload
    def __sub__(self: Tensor[float], other: bool | int | float | Tensor[int] | Tensor[float]) -> Tensor[float]:
        ...
    @overload
    def __sub__(self: Tensor[bool] | Tensor[int] | Tensor[float], other: float | Tensor[float]) -> Tensor[float]:
        ...
    @overload
    def __sub__(self: Tensor[complex], other: Numeric | Tensor) -> Tensor[complex]:
        ...
    @overload
    def __sub__(self, other: complex | Tensor[complex]) -> Tensor[complex]:
        ...
    def __sub__(self, other: Numeric | Tensor) -> Tensor:
        if isinstance(other, (int, float, complex, Tensor)):
            return self + (-other)
        return NotImplemented  # type: ignore[unreachable]
    @overload
    def __rsub__(self: Tensor[bool] | Tensor[int], other: bool | int | Tensor[int]) -> Tensor[int]:
        ...
    @overload
    def __rsub__(self: Tensor[float], other: bool | int | float | Tensor[int] | Tensor[float]) -> Tensor[float]:
        ...
    @overload
    def __rsub__(self: Tensor[bool] | Tensor[int] | Tensor[float], other: float | Tensor[float]) -> Tensor[float]:
        ...
    @overload
    def __rsub__(self: Tensor[complex], other: Numeric | Tensor) -> Tensor[complex]:
        ...
    @overload
    def __rsub__(self, other: complex | Tensor[complex]) -> Tensor[complex]:
        ...
    def __rsub__(self, other: Numeric | Tensor) -> Tensor:
        return -(self - other)
    @overload
    def __mul__(self: Tensor[bool] | Tensor[int], other: bool | int | Tensor[int]) -> Tensor[int]:
        ...
    @overload
    def __mul__(self: Tensor[float], other: bool | int | float | Tensor[int] | Tensor[float]) -> Tensor[float]:
        ...
    @overload
    def __mul__(self: Tensor[bool] | Tensor[int] | Tensor[float], other: float | Tensor[float]) -> Tensor[float]:
        ...
    @overload
    def __mul__(self: Tensor[complex], other: Numeric | Tensor) -> Tensor[complex]:
        ...
    @overload
    def __mul__(self, other: complex | Tensor[complex]) -> Tensor[complex]:
        ...
    def __mul__(self, other: Numeric | Tensor) -> Tensor:
        if isinstance(other, (int, float, complex)):
            return Tensor(tuple(a * other for a in self.data), self.shape)
        if isinstance(other, Tensor):
            # Broadcast unconditionally; broadcast_to returns a copy when the
            # shapes already match, and a conditional guard would leave
            # left/right unbound in that case.
            left, right = self.broadcast_tensors(self, other)
            return Tensor(tuple(a * b for a, b in zip(left.data, right.data)), left.shape)
        return NotImplemented  # type: ignore[unreachable]
    @overload
    def __rmul__(self: Tensor[bool] | Tensor[int], other: bool | int | Tensor[int]) -> Tensor[int]:
        ...
    @overload
    def __rmul__(self: Tensor[float], other: bool | int | float | Tensor[int] | Tensor[float]) -> Tensor[float]:
        ...
    @overload
    def __rmul__(self: Tensor[bool] | Tensor[int] | Tensor[float], other: float | Tensor[float]) -> Tensor[float]:
        ...
    @overload
    def __rmul__(self: Tensor[complex], other: Numeric | Tensor) -> Tensor[complex]:
        ...
    @overload
    def __rmul__(self, other: complex | Tensor[complex]) -> Tensor[complex]:
        ...
    def __rmul__(self, other: Numeric | Tensor) -> Tensor:
        return self * other
    @overload
    def __truediv__(
        self: Tensor[bool] | Tensor[int] | Tensor[float],
        other: bool | int | float | Tensor[bool] | Tensor[int] | Tensor[float],
    ) -> Tensor[float]:
        ...
    @overload
    def __truediv__(self: Tensor[complex], other: Numeric | Tensor) -> Tensor[complex]:
        ...
    @overload
    def __truediv__(self, other: complex | Tensor[complex]) -> Tensor[complex]:
        ...
    def __truediv__(self, other: Numeric | Tensor) -> Tensor:
        if isinstance(other, (int, float, complex)):
            return Tensor(tuple(a / other for a in self.data), self.shape)
        if isinstance(other, Tensor):
            left, right = self.broadcast_tensors(self, other)
            # Use the broadcast shape, not self.shape, which may be smaller.
            return Tensor(tuple(a / b for a, b in zip(left.data, right.data)), left.shape)
        return NotImplemented  # type: ignore[unreachable]
    @overload
    def __rtruediv__(
        self: Tensor[bool] | Tensor[int] | Tensor[float],
        other: bool | int | float | Tensor[bool] | Tensor[int] | Tensor[float],
    ) -> Tensor[float]:
        ...
    @overload
    def __rtruediv__(self: Tensor[complex], other: Numeric | Tensor) -> Tensor[complex]:
        ...
    @overload
    def __rtruediv__(self, other: complex | Tensor[complex]) -> Tensor[complex]:
        ...
    def __rtruediv__(self, other: Numeric | Tensor) -> Tensor:
        if isinstance(other, (int, float, complex)):
            return Tensor(tuple(other / a for a in self.data), self.shape)
        if isinstance(other, Tensor):
            left, right = self.broadcast_tensors(other, self)
            return Tensor(tuple(a / b for a, b in zip(left.data, right.data)), left.shape)
        return NotImplemented  # type: ignore[unreachable]
    @overload
    def __floordiv__(self: Tensor[bool] | Tensor[int], other: bool | int | Tensor[bool] | Tensor[int]) -> Tensor[int]:
        ...
    @overload
    def __floordiv__(
        self: Tensor[float], other: bool | int | float | Tensor[bool] | Tensor[int] | Tensor[float]
    ) -> Tensor[float]:
        ...
    def __floordiv__(
        self: Tensor[bool] | Tensor[int] | Tensor[float], other: Numeric | Tensor
    ) -> Tensor[int] | Tensor[float]:
        if isinstance(other, (int, float)):
            return cast(Union[Tensor[int], Tensor[float]], Tensor(tuple(a // other for a in self.data), self.shape))
        if isinstance(other, Tensor):
            left, right = self.broadcast_tensors(self, other)
            return cast(
                Union[Tensor[int], Tensor[float]],
                Tensor(tuple(a // b for a, b in zip(left.data, right.data)), left.shape),
            )
        return NotImplemented
    @overload
    def __rfloordiv__(self: Tensor[bool] | Tensor[int], other: bool | int | Tensor[bool] | Tensor[int]) -> Tensor[int]:
        ...
    @overload
    def __rfloordiv__(
        self: Tensor[float], other: bool | int | float | Tensor[bool] | Tensor[int] | Tensor[float]
    ) -> Tensor[float]:
        ...
    def __rfloordiv__(
        self: Tensor[bool] | Tensor[int] | Tensor[float], other: Numeric | Tensor
    ) -> Tensor[int] | Tensor[float]:
        if isinstance(other, (int, float)):
            return cast(Union[Tensor[int], Tensor[float]], Tensor(tuple(other // a for a in self.data), self.shape))
        if isinstance(other, Tensor):
            # left is the broadcast `other`, right is the broadcast `self`,
            # so the reflected quotient is a // b.
            left, right = self.broadcast_tensors(other, self)
            return cast(
                Union[Tensor[int], Tensor[float]],
                Tensor(tuple(a // b for a, b in zip(left.data, right.data)), left.shape),
            )
        return NotImplemented
    @overload
    def __mod__(self: Tensor[bool] | Tensor[int], other: bool | int | Tensor[bool] | Tensor[int]) -> Tensor[int]:
        ...
    @overload
    def __mod__(
        self: Tensor[float], other: bool | int | float | Tensor[bool] | Tensor[int] | Tensor[float]
    ) -> Tensor[float]:
        ...
    def __mod__(
        self: Tensor[bool] | Tensor[int] | Tensor[float], other: Numeric | Tensor
    ) -> Tensor[int] | Tensor[float]:
        if isinstance(other, (int, float)):
            return cast(Union[Tensor[int], Tensor[float]], Tensor(tuple(a % other for a in self.data), self.shape))
        if isinstance(other, Tensor):
            left, right = self.broadcast_tensors(self, other)
            return cast(
                Union[Tensor[int], Tensor[float]],
                Tensor(tuple(a % b for a, b in zip(left.data, right.data)), left.shape),
            )
        return NotImplemented
    @overload
    def __rmod__(self: Tensor[bool] | Tensor[int], other: bool | int | Tensor[bool] | Tensor[int]) -> Tensor[int]:
        ...
    @overload
    def __rmod__(
        self: Tensor[float], other: bool | int | float | Tensor[bool] | Tensor[int] | Tensor[float]
    ) -> Tensor[float]:
        ...
    def __rmod__(
        self: Tensor[bool] | Tensor[int] | Tensor[float], other: Numeric | Tensor
    ) -> Tensor[int] | Tensor[float]:
        if isinstance(other, (int, float)):
            return cast(Union[Tensor[int], Tensor[float]], Tensor(tuple(other % a for a in self.data), self.shape))
        if isinstance(other, Tensor):
            # left is the broadcast `other`, right is the broadcast `self`,
            # so the reflected remainder is a % b.
            left, right = self.broadcast_tensors(other, self)
            return cast(
                Union[Tensor[int], Tensor[float]],
                Tensor(tuple(a % b for a, b in zip(left.data, right.data)), left.shape),
            )
        return NotImplemented
    @overload
    def __pow__(self: Tensor[bool] | Tensor[int], other: bool | int | Tensor[bool] | Tensor[int]) -> Tensor[int]:
        ...
    @overload
    def __pow__(
        self: Tensor[float], other: bool | int | float | Tensor[bool] | Tensor[int] | Tensor[float]
    ) -> Tensor[float]:
        ...
    @overload
    def __pow__(self: Tensor[complex], other: Numeric | Tensor[complex]) -> Tensor[complex]:
        ...
    @overload
    def __pow__(self, other: complex | Tensor[complex]) -> Tensor[complex]:
        ...
    def __pow__(self, other: Numeric | Tensor) -> Tensor:
        if isinstance(other, (int, float, complex)):
            return Tensor(tuple(a**other for a in self.data), self.shape)
        if isinstance(other, Tensor):
            left, right = self.broadcast_tensors(self, other)
            return Tensor(tuple(a**b for a, b in zip(left.data, right.data)), left.shape)
        return NotImplemented  # type: ignore[unreachable]
    @overload
    def __rpow__(self: Tensor[bool] | Tensor[int], other: bool | int | Tensor[bool] | Tensor[int]) -> Tensor[int]:
        ...
    @overload
    def __rpow__(
        self: Tensor[float], other: bool | int | float | Tensor[bool] | Tensor[int] | Tensor[float]
    ) -> Tensor[float]:
        ...
    @overload
    def __rpow__(self: Tensor[complex], other: Numeric | Tensor[complex]) -> Tensor[complex]:
        ...
    @overload
    def __rpow__(self, other: complex | Tensor[complex]) -> Tensor[complex]:
        ...
    def __rpow__(self, other: Numeric | Tensor) -> Tensor:
        if isinstance(other, (int, float, complex)):
            return Tensor(tuple(other**a for a in self.data), self.shape)
        if isinstance(other, Tensor):
            left, right = self.broadcast_tensors(other, self)
            return Tensor(tuple(a**b for a, b in zip(left.data, right.data)), left.shape)
        return NotImplemented  # type: ignore[unreachable]
    @overload
    def __matmul__(self: Tensor[bool] | Tensor[int], other: Tensor[bool] | Tensor[int]) -> Tensor[int]:
        ...
    @overload
    def __matmul__(self: Tensor[float], other: Tensor[bool] | Tensor[int] | Tensor[float]) -> Tensor[float]:
        ...
    @overload
    def __matmul__(self: Tensor[complex], other: Tensor) -> Tensor[complex]:
        ...
    @overload
    def __matmul__(self, other: Tensor[complex]) -> Tensor[complex]:
        ...
    def __matmul__(self, other: Tensor) -> Tensor:
        if not isinstance(other, Tensor):
            return NotImplemented  # type: ignore[unreachable]
        if self.ndim > 2 or other.ndim > 2:
            raise ValueError("matmul supports only 1D and 2D tensors")
        left, right = self, other
        if left.ndim == 1:
            left = left.reshape((1, left.shape[0]))
        if right.ndim == 1:
            right = right.reshape((right.shape[0], 1))
        if left.shape[1] != right.shape[0]:
            raise ValueError(f"shape mismatch: {self.shape} @ {other.shape}")
        output_shape = (left.shape[0], right.shape[1])
        output_data: list[Numeric] = [0] * reduce(mul, output_shape, 1)
        for i in range(output_shape[0]):
            for j in range(output_shape[1]):
                for k in range(left.shape[1]):
                    # Index via the (possibly reshaped) left/right operands;
                    # self/other may still be 1-D at this point.
                    output_data[i * output_shape[1] + j] += (
                        left.data[i * left.shape[1] + k] * right.data[k * right.shape[1] + j]
                    )
        return Tensor(output_data, output_shape)
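    # Note: the product is the textbook O(n*m*k) triple loop over flattened
    # storage. 1-D operands are promoted to (1, n) / (n, 1) matrices and,
    # unlike NumPy, the promoted axes are kept in the result, e.g.
    # (3,) @ (3,) -> shape (1, 1) rather than a scalar.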
    def __and__(self: Tensor[bool], other: Tensor[bool]) -> Tensor[bool]:
        left, right = self.broadcast_tensors(self, other)
        return Tensor(tuple(a and b for a, b in zip(left.data, right.data)), left.shape)
    def __or__(self: Tensor[bool], other: Tensor[bool]) -> Tensor[bool]:
        left, right = self.broadcast_tensors(self, other)
        return Tensor(tuple(a or b for a, b in zip(left.data, right.data)), left.shape)
    def __xor__(self: Tensor[bool], other: Tensor[bool]) -> Tensor[bool]:
        left, right = self.broadcast_tensors(self, other)
        return Tensor(tuple(a ^ b for a, b in zip(left.data, right.data)), left.shape)
    def __eq__(self, other: Numeric | Tensor) -> Tensor[bool]:  # type: ignore[override]
        if isinstance(other, (int, float, complex)):
            return Tensor(tuple(a == other for a in self.data), self.shape)
        if isinstance(other, Tensor):
            left, right = self.broadcast_tensors(self, other)
            return Tensor(tuple(a == b for a, b in zip(left.data, right.data)), left.shape)
        return NotImplemented
    # Note: __req__ and __rne__ are not part of Python's data model (reflected
    # equality dispatches to __eq__/__ne__ themselves); they are kept for
    # symmetry with the other reflected operators but are never invoked by
    # the interpreter.
    def __req__(self, other: Numeric | Tensor) -> Tensor[bool]:  # type: ignore[override]
        if isinstance(other, (int, float, complex)):
            return Tensor(tuple(other == a for a in self.data), self.shape)
        if isinstance(other, Tensor):
            left, right = self.broadcast_tensors(other, self)
            return Tensor(tuple(b == a for a, b in zip(left.data, right.data)), left.shape)
        return NotImplemented
    def __ne__(self, other: Numeric | Tensor) -> Tensor[bool]:  # type: ignore[override]
        if isinstance(other, (int, float, complex)):
            return Tensor(tuple(a != other for a in self.data), self.shape)
        if isinstance(other, Tensor):
            left, right = self.broadcast_tensors(self, other)
            return Tensor(tuple(a != b for a, b in zip(left.data, right.data)), left.shape)
        return NotImplemented
    def __rne__(self, other: Numeric | Tensor) -> Tensor[bool]:  # type: ignore[override]
        if isinstance(other, (int, float, complex)):
            return Tensor(tuple(other != a for a in self.data), self.shape)
        if isinstance(other, Tensor):
            left, right = self.broadcast_tensors(other, self)
            return Tensor(tuple(b != a for a, b in zip(left.data, right.data)), left.shape)
        return NotImplemented
    @overload
    def __lt__(
        self: Tensor[bool] | Tensor[int] | Tensor[float], other: Tensor[bool] | Tensor[int] | Tensor[float]
    ) -> Tensor[bool]:
        ...
    @overload
    def __lt__(self: Tensor[complex], other: Tensor[complex]) -> Tensor[bool]:
        ...
    def __lt__(self, other: Numeric | Tensor) -> Tensor[bool]:
        if isinstance(other, (int, float, complex)):
            return Tensor(tuple(a < other for a in self.data), self.shape)  # type: ignore[operator]
        if isinstance(other, Tensor):
            left, right = self.broadcast_tensors(self, other)
            return Tensor(tuple(a < b for a, b in zip(left.data, right.data)), left.shape)
        return NotImplemented
    @overload
    def __le__(
        self: Tensor[bool] | Tensor[int] | Tensor[float], other: Tensor[bool] | Tensor[int] | Tensor[float]
    ) -> Tensor[bool]:
        ...
    @overload
    def __le__(self: Tensor[complex], other: Tensor[complex]) -> Tensor[bool]:
        ...
    def __le__(self, other: Numeric | Tensor) -> Tensor[bool]:
        if isinstance(other, (int, float, complex)):
            return Tensor(tuple(a <= other for a in self.data), self.shape)  # type: ignore[operator]
        if isinstance(other, Tensor):
            left, right = self.broadcast_tensors(self, other)
            return Tensor(tuple(a <= b for a, b in zip(left.data, right.data)), left.shape)
        return NotImplemented
    @overload
    def __gt__(
        self: Tensor[bool] | Tensor[int] | Tensor[float], other: Tensor[bool] | Tensor[int] | Tensor[float]
    ) -> Tensor[bool]:
        ...
    @overload
    def __gt__(self: Tensor[complex], other: Tensor[complex]) -> Tensor[bool]:
        ...
    def __gt__(self, other: Numeric | Tensor) -> Tensor[bool]:
        if isinstance(other, (int, float, complex)):
            return Tensor(tuple(a > other for a in self.data), self.shape)  # type: ignore[operator]
        if isinstance(other, Tensor):
            left, right = self.broadcast_tensors(self, other)
            return Tensor(tuple(a > b for a, b in zip(left.data, right.data)), left.shape)
        return NotImplemented
    @overload
    def __ge__(
        self: Tensor[bool] | Tensor[int] | Tensor[float], other: Tensor[bool] | Tensor[int] | Tensor[float]
    ) -> Tensor[bool]:
        ...
    @overload
    def __ge__(self: Tensor[complex], other: Tensor[complex]) -> Tensor[bool]:
        ...
    def __ge__(self, other: Numeric | Tensor) -> Tensor[bool]:
        if isinstance(other, (int, float, complex)):
            return Tensor(tuple(a >= other for a in self.data), self.shape)  # type: ignore[operator]
        if isinstance(other, Tensor):
            left, right = self.broadcast_tensors(self, other)
            return Tensor(tuple(a >= b for a, b in zip(left.data, right.data)), left.shape)
        return NotImplemented
    def __lshift__(self: Tensor[bool] | Tensor[int], other: bool | int | Tensor[bool] | Tensor[int]) -> Tensor[int]:
        if isinstance(other, (int, bool)):
            return Tensor(tuple(a << other for a in self.data), self.shape)
        if isinstance(other, Tensor):
            left, right = self.broadcast_tensors(self, other)
            return Tensor(tuple(a << b for a, b in zip(left.data, right.data)), left.shape)
        return NotImplemented
    def __rshift__(self: Tensor[bool] | Tensor[int], other: bool | int | Tensor[bool] | Tensor[int]) -> Tensor[int]:
        # __rshift__ implements `self >> other`; it is not a reflected op.
        if isinstance(other, (int, bool)):
            return Tensor(tuple(a >> other for a in self.data), self.shape)
        if isinstance(other, Tensor):
            left, right = self.broadcast_tensors(self, other)
            return Tensor(tuple(a >> b for a, b in zip(left.data, right.data)), left.shape)
        return NotImplemented
    @overload
    def __pos__(self: Tensor[bool]) -> Tensor[int]:
        ...
    @overload
    def __pos__(self: Tensor[int]) -> Tensor[int]:
        ...
    @overload
    def __pos__(self: Tensor[float]) -> Tensor[float]:
        ...
    @overload
    def __pos__(self: Tensor[complex]) -> Tensor[complex]:
        ...
    def __pos__(self) -> Tensor:
        return Tensor(tuple(+a for a in self.data), self.shape)
    @overload
    def __neg__(self: Tensor[bool]) -> Tensor[int]:
        ...
    @overload
    def __neg__(self: Tensor[int]) -> Tensor[int]:
        ...
    @overload
    def __neg__(self: Tensor[float]) -> Tensor[float]:
        ...
    @overload
    def __neg__(self: Tensor[complex]) -> Tensor[complex]:
        ...
    def __neg__(self) -> Tensor:
        return Tensor(tuple(-a for a in self.data), self.shape)
    def __invert__(self: Tensor[bool]) -> Tensor[bool]:
        return Tensor(tuple(not a for a in self.data), self.shape)
    def __getitem__(
        self,
        idx: int | slice | Sequence[int] | Tensor[bool] | tuple[int | slice | Sequence[int] | EllipsisType | None, ...],
    ) -> Tensor[DType]:
        if not self.shape:
            raise IndexError("scalar tensor cannot be indexed")
        flattened_indices: Iterable[int]
        output_shape: tuple[int, ...]
        extra_size = reduce(mul, self.shape[1:], 1)
        if isinstance(idx, int):
            # Normalize negative indices first, then bounds-check, so that
            # out-of-range negative indices raise instead of wrapping twice.
            if idx < 0:
                idx += self.shape[0]
            if not 0 <= idx < self.shape[0]:
                raise IndexError("index out of range")
            flattened_indices = iter(range(idx * extra_size, (idx + 1) * extra_size))
            output_shape = self.shape[1:]
        elif isinstance(idx, slice):
            start, stop, step = idx.indices(self.shape[0])
            flattened_indices = itertools.chain.from_iterable(
                range(i * extra_size, (i + 1) * extra_size) for i in range(start, stop, step)
            )
            output_shape = (math.ceil((stop - start) / step),) + self.shape[1:]
        elif isinstance(idx, tuple):
            current_dim = 0
            output_shape = ()
            index_iterators: list[Iterable[int]] = []
            for index in self.normalize_index(idx, self.shape):
                if index is None:
                    output_shape = output_shape + (1,)
                    continue
                elif isinstance(index, int):
                    if index < 0:
                        index += self.shape[current_dim]
                    if not 0 <= index < self.shape[current_dim]:
                        raise IndexError("index out of range")
                    index_iterators.append((index,))
                elif isinstance(index, slice):
                    start, stop, step = index.indices(self.shape[current_dim])
                    index_iterators.append(range(start, stop, step))
                    output_shape = output_shape + (math.ceil((stop - start) / step),)
                elif isinstance(index, Sequence):
                    index_iterators.append(index)
                    output_shape = output_shape + (len(index),)
                else:
                    raise IndexError("index must be int, slice, Ellipsis, or None, but got " + repr(index))
                current_dim += 1
            flattened_indices = map(
                partial(self.get_flattened_index, shape=self.shape),
                itertools.product(*index_iterators),
            )
            if output_shape == (1,) and all(isinstance(i, int) for i in output_shape):
                output_shape = ()
        elif isinstance(idx, Sequence):
            idx = cast(Sequence[int], idx)
            if any(i >= self.shape[0] for i in idx):
                raise IndexError("index out of range")
            flattened_indices = itertools.chain.from_iterable(range(i * extra_size, (i + 1) * extra_size) for i in idx)
            output_shape = (len(idx),) + self.shape[1:]
        elif isinstance(idx, Tensor):
            if idx.shape != self.shape:
                raise ValueError("shape mismatch")
            flattened_indices = [i for i in range(self.size) if idx.data[i]]
            output_shape = (len(flattened_indices),)
        return Tensor([self.data[i] for i in flattened_indices], output_shape)
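    # Supported indexing forms (illustrative), with t = Tensor([1, 2, 3, 4], (2, 2)):
    #   t[0]       -> Tensor([1, 2])   (int drops the leading axis)
    #   t[0:1]     -> shape (1, 2)     (slice keeps it)
    #   t[[1, 0]]  -> rows reordered   (sequence of row indices)
    #   t[t > 2]   -> Tensor([3, 4])   (boolean mask flattens)
    #   t[0, 1]    -> Tensor(2)        (tuple; may also contain ... and None)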
    def __iter__(self) -> Iterator[Tensor[DType]]:
        if not self.shape:
            raise TypeError("'Tensor' object is not iterable")
        for i in range(self.shape[0]):
            yield self[i]
    def copy(self) -> Tensor[DType]:
        return Tensor(self.data, self.shape)
    def reshape(self, shape: Sequence[int]) -> Tensor[DType]:
        neg_count = 0
        neg_index = -1
        for i, s in enumerate(shape):
            if s < 0:
                neg_count += 1
                neg_index = i
        if neg_count > 1:
            raise ValueError("can only specify one unknown dimension")
        elif neg_count == 1:
            shape = list(shape)
            shape[neg_index] = -1
            # With the placeholder pinned to -1, -reduce(mul, shape, 1) is the
            # product of the known dimensions.
            shape[neg_index] = self.size // -reduce(mul, shape, 1)
        if reduce(mul, shape, 1) != self.size:
            raise ValueError("cannot reshape tensor of size {} into shape {}".format(self.size, shape))
        return Tensor(self.data, shape)
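    # e.g. Tensor([1, 2, 3, 4, 5, 6], (6,)).reshape((2, -1)) infers the
    # unknown dimension: -1 becomes 6 // 2 = 3, giving shape (2, 3).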
    def transpose(self, dim0: int, dim1: int) -> Tensor[DType]:
        if not (0 <= dim0 < self.ndim and 0 <= dim1 < self.ndim):
            raise ValueError("dimension out of range")
        if dim0 == dim1:
            return self.copy()
        transposed_shape = list(self.shape)
        transposed_shape[dim0], transposed_shape[dim1] = transposed_shape[dim1], transposed_shape[dim0]
        def get_transposed_index(i: int) -> int:
            multi_index = list(self.get_nested_index(i, self.shape))
            multi_index[dim0], multi_index[dim1] = multi_index[dim1], multi_index[dim0]
            transposed_index = self.get_flattened_index(multi_index, transposed_shape)
            return transposed_index
        transposed_data = list(self.data)
        for original_index, transposed_index in enumerate(map(get_transposed_index, range(self.size))):
            transposed_data[transposed_index] = self.data[original_index]
        return Tensor(transposed_data, transposed_shape)
    def is_equal(self, other: object) -> bool:
        # A plain method, not __eq__, so return a real bool for non-tensors.
        if not isinstance(other, Tensor):
            return False
        return self.shape == other.shape and self.data == other.data
    @overload
    def sum(self: Tensor[bool] | Tensor[int], dim: int | None = ...) -> Tensor[int]:
        ...
    @overload
    def sum(self: Tensor[float], dim: int | None = ...) -> Tensor[float]:
        ...
    @overload
    def sum(self: Tensor[complex], dim: int | None = ...) -> Tensor[complex]:
        ...
    def sum(self, dim: int | None = None) -> Tensor:
        if dim is None:
            return Tensor(sum(self.data))
        if not (0 <= dim < self.ndim):
            raise ValueError("dimension out of range")
        output: Tensor = Tensor(0)
        for i in range(self.shape[dim]):
            index: list[slice | int] = [slice(None)] * self.ndim
            index[dim] = i
            output += self[tuple(index)]
        return output
    def apply(self, fn: Callable[[DType], T_DType]) -> Tensor[T_DType]:
        return Tensor(tuple(map(fn, self.data)), self.shape)
    def exp(self: Tensor[bool] | Tensor[int] | Tensor[float]) -> Tensor[float]:
        return self.apply(math.exp)
    def expm1(self: Tensor[bool] | Tensor[int] | Tensor[float]) -> Tensor[float]:
        return self.apply(math.expm1)
    def log(self: Tensor[bool] | Tensor[int] | Tensor[float]) -> Tensor[float]:
        return self.apply(math.log)
    def log10(self: Tensor[bool] | Tensor[int] | Tensor[float]) -> Tensor[float]:
        return self.apply(math.log10)
    def log1p(self: Tensor[bool] | Tensor[int] | Tensor[float]) -> Tensor[float]:
        return self.apply(math.log1p)
    def log2(self: Tensor[bool] | Tensor[int] | Tensor[float]) -> Tensor[float]:
        return self.apply(math.log2)
    @overload
    def astype(self, dtype: Type[bool]) -> Tensor[bool]:  # type: ignore[misc]
        ...
    @overload
    def astype(self: Tensor[bool] | Tensor[int] | Tensor[float], dtype: Type[int]) -> Tensor[int]:
        ...
    @overload
    def astype(self: Tensor[bool] | Tensor[int] | Tensor[float], dtype: Type[float]) -> Tensor[float]:
        ...
    @overload
    def astype(self, dtype: Type[complex]) -> Tensor[complex]:
        ...
    def astype(self, dtype: Type[Numeric]) -> Tensor[bool] | Tensor[int] | Tensor[float] | Tensor[complex]:
        return Tensor(tuple(dtype(a) for a in self.data), self.shape)  # type: ignore[arg-type, misc, call-overload]
    def broadcast_to(self, shape: Sequence[int]) -> Tensor[DType]:
        if self.shape == shape:
            return self.copy()
        if len(self.shape) > len(shape):
            raise ValueError("cannot broadcast tensor of shape {} to shape {}".format(self.shape, shape))
        if self.shape == ():
            return Tensor([self.data[0]] * reduce(mul, shape, 1), shape)
        num_padded_dims = len(shape) - len(self.shape)
        original_shape = [1] * num_padded_dims + list(self.shape)
        broadcasted_shape = shape
        index_iterators: list[Iterable[int]] = []
        for original_dim_size, broadcasted_dim_size in zip(original_shape, broadcasted_shape):
            if original_dim_size == broadcasted_dim_size:
                index_iterators.append(range(original_dim_size))
            elif original_dim_size == 1:
                if broadcasted_dim_size < 1:
                    raise ValueError("cannot broadcast tensor of shape {} to shape {}".format(self.shape, shape))
                index_iterators.append([0] * broadcasted_dim_size)
            else:
                raise ValueError("cannot broadcast tensor of shape {} to shape {}".format(self.shape, shape))
        # The padded leading axes always carry index 0; drop them so the
        # flattened index is computed against self.shape's strides.
        broadcasted_indices = map(
            partial(self.get_flattened_index, shape=self.shape),
            (multi_index[num_padded_dims:] for multi_index in itertools.product(*index_iterators)),
        )
        return Tensor([self.data[i] for i in broadcasted_indices], broadcasted_shape)
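    # Broadcasting example (illustrative): a shape-(3,) tensor broadcast to
    # (2, 3) gains a leading size-1 axis that is then repeated, so
    # Tensor([1, 2, 3]).broadcast_to((2, 3)).data == (1, 2, 3, 1, 2, 3).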
    @staticmethod
    def broadcast_shapes(*shapes: tuple[int, ...]) -> tuple[int, ...]:
        max_ndim = max(len(shape) for shape in shapes)
        shapes = tuple((1,) * (max_ndim - len(shape)) + shape for shape in shapes)
        output_shape = []
        for dim_sizes in zip(*shapes):
            if len(set(dim_sizes) - {1}) > 1:
                raise ValueError("cannot broadcast shapes {} to a common shape".format(shapes))
            output_shape.append(max(dim_sizes))
        return tuple(output_shape)
    @staticmethod
    def broadcast_tensors(*tensors: Tensor) -> tuple[Tensor, ...]:
        output_shape = Tensor.broadcast_shapes(*[tensor.shape for tensor in tensors])
        return tuple(tensor.broadcast_to(output_shape) for tensor in tensors)
    @classmethod
    def zeros(cls, shape: Sequence[int]) -> Tensor[int]:
        return Tensor([0] * reduce(mul, shape, 1), shape)
    @classmethod
    def ones(cls, shape: Sequence[int]) -> Tensor[int]:
        return Tensor([1] * reduce(mul, shape, 1), shape)
    @classmethod
    def eye(cls, n: int) -> Tensor[int]:
        return Tensor([1 if i == j else 0 for i in range(n) for j in range(n)], (n, n))
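

if __name__ == "__main__":
    # Quick usage sketch (illustrative, not part of the gist itself):
    # construction, broadcast arithmetic, matmul, and boolean-mask indexing.
    a = Tensor([1, 2, 3, 4, 5, 6], (2, 3))
    b = Tensor([10, 20, 30], (3,))
    print(a + b)                    # (3,) broadcasts across the rows of (2, 3)
    print(a @ b.reshape((3, 1)))    # (2, 3) @ (3, 1) -> shape (2, 1)
    print(a[a > 3])                 # boolean mask -> Tensor([4, 5, 6])
    print(a.transpose(0, 1).shape)  # (3, 2)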