Last active
January 17, 2023 17:46
-
-
Save Finndersen/353747470f375c7c17147214f357af7a to your computer and use it in GitHub Desktop.
Definition of VectorArray for Pandas Extension Types example
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class VectorArray(ExtensionScalarOpsMixin, ExtensionArray):
    """
    Custom Extension Array type for an array of Vectors

    Needs to define:
    - Associated Dtype it is used with
    - How to construct array from sequence of scalars
    - How data is stored and accessed
    - Any custom array methods

    The vectors are stored column-wise in two float64 NumPy arrays,
    ``x_values`` and ``y_values``.
    """

    def __init__(self, x_values, y_values, copy=False):
        """
        Initialise array of vectors from component X and Y values
        (Allows efficient initialisation from existing lists/arrays)

        :param x_values: Sequence/array of vector x-component values
        :param y_values: Sequence/array of vector y-component values
        :param copy: If True, always copy the inputs; if False, existing
            float64 arrays are reused and other inputs are converted
        """
        if copy:
            self.x_values = np.array(x_values, dtype=np.float64, copy=True)
            self.y_values = np.array(y_values, dtype=np.float64, copy=True)
        else:
            # np.asarray copies only when a conversion is required.
            # np.array(..., copy=False) raises ValueError under NumPy 2 whenever
            # a copy cannot be avoided (e.g. for list or tuple input).
            self.x_values = np.asarray(x_values, dtype=np.float64)
            self.y_values = np.asarray(y_values, dtype=np.float64)

    @classmethod
    def _from_sequence(cls, scalars, *, dtype=None, copy=False):
        """
        Construct a new ExtensionArray from a sequence of scalars.

        Each element will be an instance of the scalar type for this array,
        or be converted into this type in this method.

        :param scalars: Sequence of values accepted by ``create_vector``
        :param dtype: Accepted for interface compatibility; not used
        :param copy: Passed through to the constructor
        :return: New array holding one vector per element of ``scalars``
        """
        components = [create_vector(val).as_tuple() for val in scalars]
        if not components:
            # zip(*[]) yields nothing to unpack, so build the empty array directly
            return cls([], [], copy=copy)
        # Unzip (x, y) pairs into separate x and y component sequences
        x_values, y_values = zip(*components)
        return cls(x_values, y_values, copy=copy)

    @classmethod
    def from_vectors(cls, vectors):
        """
        Construct array from sequence of values (vectors)
        Can be provided as Vector instances or list/tuple like (x, y) pairs
        """
        return cls._from_sequence(vectors)

    @classmethod
    def _concat_same_type(cls, to_concat):
        """
        Concatenate multiple arrays of this dtype

        :param to_concat: Sequence of VectorArray instances
        :return: New array containing the vectors of all inputs, in order
        """
        # np.concatenate needs a real sequence; a generator raises TypeError
        return cls(
            np.concatenate([arr.x_values for arr in to_concat]),
            np.concatenate([arr.y_values for arr in to_concat]),
        )

    @property
    def dtype(self):
        """
        Return Dtype instance (not class) associated with this Array
        """
        return VectorDtype()

    @property
    def nbytes(self):
        """
        The number of bytes needed to store this object in memory.
        """
        return self.x_values.nbytes + self.y_values.nbytes

    def __getitem__(self, item):
        """
        Retrieve single item or slice

        :param item: Integer position (Python or NumPy integer), slice,
            boolean mask or list/array of integer positions
        :return: A Vector for an integer position, otherwise a VectorArray
        """
        if isinstance(item, (int, np.integer)):
            # Get single vector (NumPy integers such as np.int64 included)
            return Vector(self.x_values[item], self.y_values[item])
        # Validate list-like indexers and convert them to NumPy arrays;
        # slices and other non-list-like keys are passed through unchanged
        item = pd.api.indexers.check_array_indexer(self, item)
        # Get subset from slice, boolean mask or integer array
        return VectorArray(self.x_values[item], self.y_values[item])

    def __eq__(self, other):
        """
        Perform element-wise equality with a given vector value

        :param other: A VectorArray (compared component-wise), or a single
            value whose x and y components are ``other[0]`` and ``other[1]``
        :return: Boolean array, or NotImplemented for pandas containers

        NOTE(review): the module calls ``VectorArray._add_comparison_ops()``
        after the class body; that presumably installs its own ``__eq__`` on
        the class and so replaces this method - confirm which is intended.
        """
        if isinstance(other, (pd.Index, pd.Series, pd.DataFrame)):
            # Let the pandas container handle the comparison
            return NotImplemented
        if isinstance(other, VectorArray):
            # other[0] would be the first vector rather than the x-components
            return (self.x_values == other.x_values) & (self.y_values == other.y_values)
        return (self.x_values == other[0]) & (self.y_values == other[1])

    def __len__(self):
        """
        Return the number of vectors in the array
        """
        return self.x_values.size

    def isna(self):
        """
        Returns a 1-D array indicating if each value is missing

        Only the x-component is inspected: an element counts as missing when
        its x value is NaN.
        """
        return np.isnan(self.x_values)

    def take(self, indices, *, allow_fill=False, fill_value=None):
        """
        Take element from array using positional indexing

        :param indices: Positions to take
        :param allow_fill: Forwarded to pandas ``take``, which defines how
            ``indices`` and ``fill_value`` are interpreted
        :param fill_value: Value forwarded to pandas ``take`` for both
            components; defaults to the dtype's ``na_value`` when
            ``allow_fill`` is set
        :return: New VectorArray with the selected vectors
        """
        from pandas.core.algorithms import take

        if allow_fill and fill_value is None:
            fill_value = self.dtype.na_value
        x_result = take(self.x_values, indices, fill_value=fill_value, allow_fill=allow_fill)
        y_result = take(self.y_values, indices, fill_value=fill_value, allow_fill=allow_fill)
        return VectorArray(x_result, y_result)

    def copy(self):
        """
        Return copy of array
        """
        return VectorArray(np.copy(self.x_values), np.copy(self.y_values))

    def magnitude(self):
        """
        Return array of magnitude values for vectors.
        """
        # np.hypot computes sqrt(x**2 + y**2) element-wise without the
        # intermediate squares overflowing for very large components
        return np.hypot(self.x_values, self.y_values)

    def dot(self, other):
        """
        Calculate dot product with single Vector or VectorArray of same length

        :param other: A Vector, or a VectorArray with the same length as self
        :return: Array holding one dot product per element
        :raises TypeError: If ``other`` is neither of the above (including a
            VectorArray of a different length)
        """
        if isinstance(other, VectorArray) and len(self) == len(other):
            # Element-wise dot product with other VectorArray
            return self.x_values * other.x_values + self.y_values * other.y_values
        if isinstance(other, Vector):
            # Dot product with single Vector
            return self.x_values * other.x + self.y_values * other.y
        raise TypeError('Cannot perform dot product with {}'.format(other))
# Register operator overloads using logic defined in Vector class.
# These calls must run after the class body, because they are invoked on the
# finished VectorArray class.
VectorArray._add_arithmetic_ops()
# NOTE(review): VectorArray also defines __eq__ by hand; this call presumably
# installs its own comparison methods and replaces it - confirm intent.
VectorArray._add_comparison_ops()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment