Created
July 2, 2024 15:45
-
-
Save amjames/59ae75ca5d197be78cc5e7f7e1c2c909 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
""" | |
from typing import Type, Tuple, Union | |
from functools import reduce | |
Scalar = Union[int, float] | |
ScalarType = Union[Type[int], Type[float]] | |
class Array:
    """A dense N-dimensional array of ``int`` or ``float`` scalars.

    Elements are stored in a flat list, ``_data``, in row-major (C) order:
    the last index varies fastest. Indexing follows NumPy's basic rules for
    integers and slices, except that any selection of more than a single
    element returns a *copy*, not a view.

    Attributes:
        shape: Size of each dimension (a tuple of non-negative ints).
        dtype: ``int`` or ``float``; every stored element is ``dtype(x)``.
        ndim: Number of dimensions, ``len(shape)``.
    """

    def __init__(
        self,
        shape: Union[int, Tuple[int, ...]],
        value: "Scalar" = 0,
        dtype: "ScalarType" = float,
    ):
        """Create an array of ``shape`` with every element set to ``dtype(value)``.

        Args:
            shape: Dimension sizes. A bare ``int`` is treated as ``(shape,)``;
                ``()`` creates a 0-d array holding exactly one element.
            value: Fill value, converted with ``dtype``.
            dtype: ``int`` or ``float``.

        Raises:
            ValueError: If any dimension is negative.
        """
        if isinstance(shape, int):
            shape = (shape,)
        shape = tuple(shape)
        if any(dim < 0 for dim in shape):
            raise ValueError(f"negative dimensions are not allowed: {shape}")
        self.shape = shape
        self.dtype = dtype
        self.ndim = len(shape)
        # Start the product at 1 so a 0-d array (shape == ()) has one element
        # instead of reduce() raising on an empty sequence.
        self._nele = reduce(lambda x, y: x * y, shape, 1)
        # Sharing one object across all slots is safe: int/float are immutable.
        self._data = [dtype(value)] * self._nele

    @classmethod
    def ones(cls, shape: Tuple[int, ...], dtype: "ScalarType" = float) -> "Array":
        """Return an array of ``shape`` filled with ``dtype(1)``."""
        return cls(shape, 1, dtype)

    @classmethod
    def ones_like(cls, other: "Array") -> "Array":
        """Return an array of ones with the same shape and dtype as ``other``."""
        return cls(other.shape, 1, other.dtype)

    def _select(self, idx) -> Tuple[list, Tuple[int, ...]]:
        """Resolve ``idx`` into flat ``_data`` offsets and the selected shape.

        Offsets are returned in row-major order of the selection. An integer
        entry drops its axis from the selected shape, a slice keeps it, and
        missing trailing entries select whole axes.

        Raises:
            IndexError: On too many indices or an out-of-range integer.
            TypeError: If an entry is neither an ``int`` nor a ``slice``.
        """
        if not isinstance(idx, tuple):
            idx = (idx,)
        if len(idx) > self.ndim:
            raise IndexError(
                f"too many indices for array: array is {self.ndim}-dimensional, "
                f"but {len(idx)} were indexed"
            )
        idx = idx + (slice(None),) * (self.ndim - len(idx))

        # Row-major strides: an axis's stride is the product of the sizes after it.
        strides = [1] * self.ndim
        for axis in range(self.ndim - 2, -1, -1):
            strides[axis] = strides[axis + 1] * self.shape[axis + 1]

        offsets = [0]
        selected_shape = []
        for axis, (key, dim, stride) in enumerate(zip(idx, self.shape, strides)):
            if isinstance(key, slice):
                positions = range(*key.indices(dim))
                selected_shape.append(len(positions))
            elif isinstance(key, int):
                pos = key + dim if key < 0 else key
                if not 0 <= pos < dim:
                    raise IndexError(
                        f"index {key} is out of bounds for axis {axis} with size {dim}"
                    )
                positions = (pos,)
            else:
                raise TypeError(
                    "array indices must be integers or slices, "
                    f"not {type(key).__name__}"
                )
            # Earlier axes vary slowest, so existing offsets form the outer loop.
            offsets = [base + p * stride for base in offsets for p in positions]
        return offsets, tuple(selected_shape)

    def __getitem__(
        self, idx: Union[int, slice, Tuple[Union[int, slice], ...]]
    ) -> Union["Scalar", "Array"]:
        """Return one element, or a copy of a sub-array.

        One integer per axis returns the scalar element. Fewer integers and/or
        slices return a new ``Array`` holding a copy of the selected elements.
        Negative integers count from the end of their axis.

        Raises:
            IndexError: On too many indices or an out-of-range integer. This is
                also what stops ``iter(a)`` / ``x in a`` at the end of axis 0.
            TypeError: If an index is neither an ``int`` nor a ``slice``.
        """
        offsets, shape = self._select(idx)
        if not shape:
            return self._data[offsets[0]]
        sub = type(self)(shape, 0, self.dtype)
        sub._data = [self._data[o] for o in offsets]
        return sub

    def __setitem__(
        self,
        idx: Union[int, slice, Tuple[Union[int, slice], ...]],
        value: Union["Scalar", "Array"],
    ) -> None:
        """Assign ``value`` to the elements selected by ``idx``.

        A scalar ``value`` is written to every selected element. An ``Array``
        ``value`` must have exactly the selected shape; its elements are
        copied in row-major order. Values are converted with ``self.dtype``.

        Raises:
            IndexError, TypeError: As for ``__getitem__``.
            ValueError: If an ``Array`` value's shape differs from the selection.
        """
        offsets, shape = self._select(idx)
        if isinstance(value, Array):
            if tuple(value.shape) != shape:
                raise ValueError(
                    "could not broadcast input array from shape "
                    f"{tuple(value.shape)} into shape {shape}"
                )
            new_values = [self.dtype(v) for v in value._data]
        else:
            new_values = [self.dtype(value)] * len(offsets)
        for offset, v in zip(offsets, new_values):
            self._data[offset] = v
def cos(a: Array) -> Array:
    """Return the element-wise cosine of ``a``.

        B[i] = cos(A[i])    for every (multi-)index i

    The result is a new ``Array`` with the same shape as ``a`` and dtype
    ``float`` (the cosine of an int is not an int).
    """
    import math

    out = Array(a.shape, 0, float)
    out._data = [math.cos(x) for x in a._data]
    return out
def add(a: Array, b: Array) -> Array:
    """Return the element-wise sum ``a + b`` with NumPy-style broadcasting.

    Shapes are aligned on their trailing dimensions. Each aligned pair must be
    equal or contain a 1 (a missing leading dimension counts as 1), and a
    size-1 dimension is repeated to match the other operand. This covers::

        C[i]    = A[i]    + B[i]       # same shape
        C[i, j] = A[i, j] + B[i, j]    # same shape
        C[i, j] = A[i, j] + B[i, 0]    # B has shape (n, 1)
        C[i, j] = A[i, j] + B[j]       # B has shape (m,)

    The result is a new ``Array``. Its dtype is ``int`` only when both inputs
    are ``int``, otherwise ``float``.

    Raises:
        ValueError: If the shapes cannot be broadcast together.
    """
    import itertools

    ndim = max(a.ndim, b.ndim)
    # Left-pad the shorter shape with 1s so the axes line up from the right.
    a_shape = (1,) * (ndim - a.ndim) + tuple(a.shape)
    b_shape = (1,) * (ndim - b.ndim) + tuple(b.shape)

    out_shape = []
    for a_dim, b_dim in zip(a_shape, b_shape):
        if a_dim == b_dim or b_dim == 1:
            out_shape.append(a_dim)
        elif a_dim == 1:
            out_shape.append(b_dim)
        else:
            raise ValueError(
                "operands could not be broadcast together with shapes "
                f"{tuple(a.shape)} {tuple(b.shape)}"
            )

    def broadcast_strides(shape):
        # Row-major strides, but 0 on size-1 axes so every output index along
        # such an axis reads that axis's single element.
        strides, step = [], 1
        for dim in reversed(shape):
            strides.append(0 if dim == 1 else step)
            step *= dim
        return strides[::-1]

    a_strides = broadcast_strides(a_shape)
    b_strides = broadcast_strides(b_shape)
    dtype = int if (a.dtype is int and b.dtype is int) else float

    data = []
    for index in itertools.product(*(range(dim) for dim in out_shape)):
        a_off = sum(i * s for i, s in zip(index, a_strides))
        b_off = sum(i * s for i, s in zip(index, b_strides))
        data.append(dtype(a._data[a_off] + b._data[b_off]))

    out = Array(tuple(out_shape), 0, dtype)
    out._data = data
    return out
def dot(a: Array, b: Array):
    r"""Sum the element-wise product of ``a`` and ``b`` over their last axis.

    For 3-D inputs::

        C[i, j] = \sum_{k} A[i, j, k] * B[i, j, k]

    and analogously for any number of dimensions. ``a`` and ``b`` must have
    the same shape, and the result is a new ``Array`` of shape
    ``a.shape[:-1]``. For 1-D inputs the result is a plain scalar (the inner
    product) instead of an ``Array``. The dtype is ``int`` only when both
    inputs are ``int``, otherwise ``float``.

    Raises:
        ValueError: If the shapes differ or the inputs are 0-d.
    """
    shape = tuple(a.shape)
    if shape != tuple(b.shape):
        raise ValueError(f"dot: shape mismatch {shape} vs {tuple(b.shape)}")
    if not shape:
        raise ValueError("dot: inputs must have at least one dimension")
    dtype = int if (a.dtype is int and b.dtype is int) else float

    n = shape[-1]
    rows = 1
    for dim in shape[:-1]:
        rows *= dim
    # In row-major storage each run of n consecutive elements is one
    # last-axis vector, so contract run by run.
    sums = []
    for r in range(rows):
        start = r * n
        a_vec = a._data[start:start + n]
        b_vec = b._data[start:start + n]
        sums.append(dtype(sum(x * y for x, y in zip(a_vec, b_vec))))

    if len(shape) == 1:
        return sums[0]
    out = Array(shape[:-1], 0, dtype)
    out._data = sums
    return out
def matmul(a: Array, b: Array):
    r"""Return the matrix product of 2-D arrays ``a`` (n x k) and ``b`` (k x m).

        C[i, j] = \sum_{k} A[i, k] * B[k, j]

    The result is a new ``Array`` of shape ``(n, m)``. Its dtype is ``int``
    only when both inputs are ``int``, otherwise ``float``.

    Raises:
        ValueError: If either input is not 2-D or the inner dimensions differ.
    """
    if a.ndim != 2 or b.ndim != 2:
        raise ValueError(
            f"matmul: expected 2-D inputs, got {a.ndim}-D and {b.ndim}-D"
        )
    n, k = a.shape
    k_b, m = b.shape
    if k != k_b:
        raise ValueError(f"matmul: inner dimensions differ ({k} vs {k_b})")
    dtype = int if (a.dtype is int and b.dtype is int) else float

    # In row-major storage, column j of b is every m-th element starting at j.
    cols = [b._data[j::m] for j in range(m)]
    data = []
    for i in range(n):
        row = a._data[i * k:(i + 1) * k]
        data.extend(dtype(sum(x * y for x, y in zip(row, col))) for col in cols)

    out = Array((n, m), 0, dtype)
    out._data = data
    return out
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment