Last active
January 20, 2023 19:16
-
-
Save blink1073/289ef971393a2b0b1c3035d64c2dfe66 to your computer and use it in GitHub Desktop.
Pandas Extension Types for BSON
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from __future__ import annotations | |
from bson import ObjectId, Decimal128, Binary | |
from pandas.api.extensions import ExtensionDtype, ExtensionArray, register_extension_dtype | |
from pandas._typing import type_t | |
import pyarrow as pa | |
from typing import Union, Any | |
import numpy as np | |
import pandas as pd | |
import numbers | |
class ObjectIdScalar(pa.ExtensionScalar):
    """Arrow extension scalar whose Python value is a ``bson.ObjectId``."""

    def as_py(self):
        """Return the scalar as an ``ObjectId``, or ``None`` for a null slot.

        Nulls are checked explicitly because ``ObjectId(None)`` does not
        fail: it generates a brand-new id, which would silently turn a
        missing value into fake data.
        """
        storage = self.value
        if storage is None:
            return None
        raw = storage.as_py()
        if raw is None:
            return None
        return ObjectId(raw)
class ObjectIdType(pa.PyExtensionType):
    """Arrow extension type that stores values as 12-byte fixed-size binary."""

    def __init__(self):
        # The storage type is a fixed-width binary of 12 bytes.
        storage_type = pa.binary(12)
        super().__init__(storage_type)

    def __reduce__(self):
        """Rebuild the type by calling ``ObjectIdType`` with no arguments."""
        return (ObjectIdType, ())

    def to_pandas_dtype(self):
        """Return the pandas dtype used when converting to pandas."""
        return PandasObjectId()

    def __arrow_ext_scalar_class__(self):
        """Return the scalar class used for elements of this type."""
        return ObjectIdScalar
class _ClassProperty:
    """Read-only descriptor whose value is computed from the owning class.

    Replaces the stacked ``@classmethod`` / ``@property`` decorators, which
    newer Python versions no longer support, while still letting the value
    be read from both the class and its instances.
    """

    def __init__(self, fget):
        self.fget = fget

    def __get__(self, instance, owner=None):
        if owner is None:
            owner = type(instance)
        return self.fget(owner)


class BSONDtype(ExtensionDtype):
    """Base pandas dtype for BSON scalar types.

    Subclasses set ``type`` to the BSON scalar class and implement
    ``construct_array_type``.
    """

    # Missing values are represented by NaN.
    na_value = np.nan

    @_ClassProperty
    def name(cls) -> str:
        """Dtype name, derived from the scalar ``type`` of the class."""
        return f'bson_{cls.type}'

    def __from_arrow__(
        self, array: "Union[pa.Array, pa.ChunkedArray]"
    ) -> ExtensionArray:
        """Build the pandas extension array from a pyarrow array.

        Parameters
        ----------
        array : pyarrow.Array or pyarrow.ChunkedArray
            Arrow data; a chunked array is converted chunk by chunk.

        Returns
        -------
        ExtensionArray
            Instance of ``construct_array_type()`` holding ``self.type``
            scalars, with missing entries replaced by ``self.na_value``.
        """
        if isinstance(array, pa.Array):
            chunks = [array]
        else:
            # pyarrow.ChunkedArray
            chunks = array.chunks

        arr_type = self.construct_array_type()
        typ = self.type
        results = []
        for chunk in chunks:
            # Convert low level values to the desired type.
            vals = []
            for val in np.array(chunk):
                if pd.isna(val):
                    # Normalise every kind of null to the dtype's NA value.
                    val = self.na_value
                elif not isinstance(val, typ):
                    val = typ(val)
                vals.append(val)
            # Pass the plain list on: np.array(vals) would coerce bytes-like
            # scalars into a fixed-width bytes array and drop their type.
            results.append(arr_type._from_sequence(vals))

        if results:
            return arr_type._concat_same_type(results)
        return arr_type(np.array([], dtype="object"))
class BaseExtensionArray(ExtensionArray):
    """Extension array backed by a NumPy array of Python objects.

    Subclasses provide a class-level ``dtype`` whose ``type`` is the scalar
    class the array may hold; every element is either such a scalar or a
    missing value. The backing array is stored on ``self.data``.
    """

    def __init__(self, values, dtype=None, copy=False) -> None:
        """Wrap ``values`` after validating every element.

        Parameters
        ----------
        values : numpy.ndarray
            Elements to hold.
        dtype : optional
            Accepted for signature compatibility; not used.
        copy : bool, default False
            When true, store a copy of ``values`` instead of the array itself.

        Raises
        ------
        TypeError
            If ``values`` is not a NumPy array.
        ValueError
            If an element is neither of the dtype's scalar type nor missing.
        """
        if not isinstance(values, np.ndarray):
            raise TypeError("Need to pass a numpy array as values")
        for val in values:
            if not self._is_valid_element(val):
                raise ValueError(f'Values must be either {self.dtype.type} or NA')
        self.data = values.copy() if copy else values

    @staticmethod
    def _is_na_scalar(value) -> bool:
        """Return True only when ``value`` is a scalar missing value."""
        flag = pd.isna(value)
        # pd.isna returns an array for list-likes; that is never a scalar NA.
        return isinstance(flag, (bool, np.bool_)) and bool(flag)

    def _is_valid_element(self, value) -> bool:
        """Return True when ``value`` may be stored as one element."""
        return isinstance(value, self.dtype.type) or self._is_na_scalar(value)

    @classmethod
    def _from_sequence(cls, scalars, dtype=None, copy=False):
        """Build an array from a sequence of scalars."""
        # Fill a pre-allocated object array so each scalar is stored as-is.
        data = np.empty(len(scalars), dtype=object)
        data[:] = scalars
        return cls(data)

    @classmethod
    def _from_factorized(cls, values, original):
        """Rebuild an array from factorized ``values``."""
        return cls(values, dtype=original.dtype)

    def __getitem__(self, item):
        """Return one element for an integer, otherwise a new array."""
        if isinstance(item, numbers.Integral):
            return self.data[item]
        # slice, list-like, mask
        item = pd.api.indexers.check_array_indexer(self, item)
        return type(self)(self.data[item])

    def __setitem__(self, item, value):
        """Assign ``value`` at ``item`` after validating it.

        Raises
        ------
        ValueError
            If ``value`` (or any element of a list-like ``value``) is neither
            of the dtype's scalar type nor missing.
        """
        if isinstance(value, BaseExtensionArray):
            # Assign from the backing object array directly.
            value = value.data

        if isinstance(item, numbers.Integral):
            # A single slot can only take a single valid element.
            if not self._is_valid_element(value):
                raise ValueError(
                    f'Array element must be of type {self.dtype.type} or nan'
                )
        else:
            # slice, list-like, mask
            item = pd.api.indexers.check_array_indexer(self, item)
            if pd.api.types.is_list_like(value):
                for element in value:
                    if not self._is_valid_element(element):
                        raise ValueError(
                            f'Array element must be of type {self.dtype.type} or nan'
                        )
            elif not self._is_valid_element(value):
                raise ValueError(f'Value must be of type {self.dtype.type} or nan')
        self.data[item] = value

    def __len__(self) -> int:
        return len(self.data)

    def isna(self):
        """Return a boolean mask marking the missing elements."""
        # pd.isna is used instead of np.isnan, which raises on None.
        scalar_type = self.dtype.type
        return np.array(
            [
                not isinstance(x, scalar_type) and self._is_na_scalar(x)
                for x in self.data
            ],
            dtype=bool,
        )

    def __eq__(self, other):
        return self.data == other

    @property
    def nbytes(self):
        """Number of bytes reported by the backing NumPy array."""
        return self.data.nbytes

    def take(self, indexer, allow_fill=False, fill_value=None):
        """Take elements by position.

        Parameters
        ----------
        indexer : sequence of int
            Positions to take.
        allow_fill : bool, default False
            When true, ``-1`` marks a slot to fill with ``fill_value`` and
            any smaller index is an error.
        fill_value : optional
            Fill value; defaults to ``self.dtype.na_value``.

        Raises
        ------
        ValueError
            If ``allow_fill`` is true and an index is below ``-1``.
        IndexError
            If an index is out of bounds.
        """
        # Elements are picked one by one in Python; the original notes that
        # NumPy has trouble placing sized objects into scalar slots.
        indexer = np.asarray(indexer)
        msg = (
            "Index is out of bounds or cannot do a "
            "non-empty take from an empty array."
        )

        if allow_fill:
            if fill_value is None:
                fill_value = self.dtype.na_value
            # bounds check
            if (indexer < -1).any():
                raise ValueError(
                    "Invalid value in 'indexer': with allow_fill=True "
                    "all indices must be >= -1"
                )
            try:
                output = [
                    self.data[loc] if loc != -1 else fill_value for loc in indexer
                ]
            except IndexError as err:
                raise IndexError(msg) from err
        else:
            try:
                output = [self.data[loc] for loc in indexer]
            except IndexError as err:
                raise IndexError(msg) from err

        return self._from_sequence(output)

    def copy(self):
        """Return a new array backed by a copy of the data."""
        return type(self)(self.data.copy())

    @classmethod
    def _concat_same_type(cls, to_concat):
        """Concatenate several arrays of this class into one."""
        data = np.concatenate([x.data for x in to_concat])
        return cls(data)
@register_extension_dtype
class PandasObjectId(BSONDtype):
    """Registered pandas dtype whose scalar type is ``ObjectId``."""

    type = ObjectId

    @classmethod
    def construct_array_type(cls) -> type_t[PandasObjectIdArray]:
        """Return the array class that goes with this dtype."""
        return PandasObjectIdArray
@register_extension_dtype
class PandasDecimal128(BSONDtype):
    """Registered pandas dtype whose scalar type is ``Decimal128``."""

    type = Decimal128

    @classmethod
    def construct_array_type(cls) -> type_t[PandasDecimal128Array]:
        """Return the array class that goes with this dtype."""
        return PandasDecimal128Array
@register_extension_dtype
class PandasBinary(BSONDtype):
    """Registered pandas dtype whose scalar type is ``Binary``."""

    type = Binary

    @classmethod
    def construct_array_type(cls) -> type_t[PandasBinaryArray]:
        """Return the array class that goes with this dtype."""
        return PandasBinaryArray
class PandasObjectIdArray(BaseExtensionArray):
    """Extension array whose dtype is ``PandasObjectId``."""

    dtype = PandasObjectId()
class PandasDecimal128Array(BaseExtensionArray):
    """Extension array whose dtype is ``PandasDecimal128``."""

    dtype = PandasDecimal128()
class PandasBinaryArray(BaseExtensionArray):
    """Extension array whose dtype is ``PandasBinary``.

    Comparison, containment and NumPy conversion are done element by
    element in Python.
    """

    dtype = PandasBinary()

    def __array__(self, dtype=None, copy=None):
        """Convert to a NumPy array.

        ``copy`` is accepted because NumPy 2 passes it to ``__array__``;
        only ``copy=False`` changes behaviour, by avoiding the copy where
        ``np.asarray`` can.
        """
        if copy is False:
            return np.asarray(self.data, dtype=dtype)
        return np.array(self.data, dtype)

    def __eq__(self, other):
        """Element-wise equality against a scalar or an equal-length list-like.

        Raises
        ------
        ValueError
            If ``other`` is list-like with a different length.
        """
        if isinstance(other, (pd.Series, pd.Index, pd.DataFrame)):
            # Let the pandas container unbox itself and dispatch back here.
            return NotImplemented
        if pd.api.types.is_list_like(other):
            if len(other) != len(self):
                raise ValueError("Lengths must match to compare")
            pairs = zip(self.data, other)
            return np.array([a == b for a, b in pairs], dtype=bool)
        # dtype=bool keeps an empty result boolean rather than float64.
        return np.array([a == other for a in self.data], dtype=bool)

    def __contains__(self, item: object) -> bool | np.bool_:
        """Return whether ``item`` equals any element.

        A missing ``item`` matches only when it is a float and the array
        holds at least one missing element.
        """
        if pd.isna(item):
            if not isinstance(item, float):
                return False
            return np.any([pd.isna(a) for a in self.data])
        return np.any([a == item for a in self.data])
from pandas.tests.extension import base | |
import pytest | |
# def make_data(): | |
# return [ObjectId() for _ in range(8)] + [np.nan] + [ObjectId() for _ in range(88)] + [np.nan] + [ObjectId(), ObjectId()] | |
# @pytest.fixture | |
# def dtype(): | |
# return PandasObjectId() | |
def make_datum():
    """Return a ``Binary`` (subtype 10) holding the text of a random float."""
    payload = str(np.random.rand()).encode('utf8')
    return Binary(payload, 10)
def make_data():
    """Return 100 values: 98 random datums with NaN at positions 8 and 97."""
    values = [make_datum() for _ in range(8)]
    values.append(np.nan)
    values.extend(make_datum() for _ in range(88))
    values.append(np.nan)
    values.extend([make_datum(), make_datum()])
    return values
@pytest.fixture
def dtype():
    """Dtype instance under test."""
    return PandasBinary()
@pytest.fixture
def data(dtype):
    """Array built from ``make_data()`` with the dtype under test."""
    values = make_data()
    return pd.array(values, dtype=dtype)
@pytest.fixture
def data_for_twos(dtype):
    """Skip: this dtype cannot hold numeric data.

    The array constructor rejects anything that is neither the dtype's
    scalar type nor missing, so an array of floats cannot be built. The
    fixture skips the requesting test instead of raising ``ValueError``
    during setup.
    """
    pytest.skip(f"{dtype} is not a numeric dtype")
@pytest.fixture
def data_missing(dtype):
    """Length-2 array: a missing value followed by a valid one."""
    values = [np.nan, make_datum()]
    return pd.array(values, dtype=dtype)
@pytest.fixture
def data_for_sorting(dtype):
    """Length-3 array with a known sort order: ``[B, C, A]`` with A < B < C.

    Fixed payloads are used because random ones give no guaranteed order.
    """
    b_val = Binary(b"b", 10)
    c_val = Binary(b"c", 10)
    a_val = Binary(b"a", 10)
    return pd.array([b_val, c_val, a_val], dtype=dtype)
@pytest.fixture
def data_missing_for_sorting(dtype):
    """Length-3 array with a known sort order: ``[B, NA, A]`` with A < B.

    Fixed payloads are used because random ones give no guaranteed order.
    """
    b_val = Binary(b"b", 10)
    a_val = Binary(b"a", 10)
    return pd.array([b_val, np.nan, a_val], dtype=dtype)
@pytest.fixture
def na_value():
    """Missing-value marker expected by the tests: NaN."""
    return np.nan
@pytest.fixture
def na_cmp():
    """Return a two-argument function that is true when both values are missing.

    ``pd.isna`` is used rather than ``np.isnan`` because ``np.isnan``
    raises ``TypeError`` for ``None`` and for non-numeric scalars.
    """
    def cmp(a, b):
        return pd.isna(a) and pd.isna(b)

    return cmp
@pytest.fixture(params=[True, False])
def box_in_series(request):
    """Parametrized flag: should the data be wrapped in a Series?"""
    flag = request.param
    return flag
@pytest.fixture(params=[True, False])
def as_array(request):
    """Parametrized boolean used when testing ``_from_sequence``."""
    flag = request.param
    return flag
@pytest.fixture(params=["ffill", "bfill"])
def fillna_method(request):
    """Parametrized fill method, 'ffill' or 'bfill', for ``Series.fillna`` tests."""
    method = request.param
    return method
@pytest.fixture
def invalid_scalar(data):
    """Return a bare ``object`` instance, which the array cannot hold.

    Arrays able to hold any item (object dtype) should use ``pytest.skip``
    instead.
    """
    return object()
class TestDtype(base.BaseDtypeTests):
    """Runs the inherited pandas dtype tests unchanged."""
class TestInterface(base.BaseInterfaceTests):
    """Runs the inherited pandas interface tests unchanged."""
class TestConstructors(base.BaseConstructorsTests):
    """Runs the inherited pandas constructor tests unchanged."""
class TestGetitem(base.BaseGetitemTests):
    """Runs the inherited pandas getitem tests unchanged."""
class TestSetitem(base.BaseSetitemTests):
    """Runs the inherited pandas setitem tests unchanged."""
class TestIndex(base.BaseIndexTests):
    """Runs the inherited pandas index tests unchanged."""
class TestMissing(base.BaseMissingTests):
    """Runs the inherited pandas missing-data tests unchanged."""
def test_to_pandas():
    """An Arrow table with an ``ObjectIdType`` column converts to pandas.

    Checks that the column gets the ``PandasObjectId`` dtype, that the ids
    survive the round trip, and that the null entry comes back as missing.
    """
    oids = [ObjectId() for _ in range(3)]
    schema = pa.schema([("data", ObjectIdType())])
    table = pa.Table.from_pydict(
        {"data": [oid.binary for oid in oids] + [None]},
        schema=schema,
    )

    df = table.to_pandas()

    column = df["data"]
    assert isinstance(column.dtype, PandasObjectId)
    assert len(column) == 4
    assert column.iloc[:3].tolist() == oids
    assert pd.isna(column.iloc[3])
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Rollout plan:
- `Binary` extension types
- `ObjectId` pandas extension type
- `Decimal128` as `binary(16)`, `as_py` as `Decimal128` and `PandasDecimal128`, and deprecate `Decimal128Str`
- `DBRef`, `Code`, `Int64`, `MaxKey`, `MinKey`, `Regex`, `Timestamp`
- `ArrowExtensionArray` as the base class once it becomes stable