Last active
January 16, 2024 15:45
-
-
Save mdellavo/443d206b1800887b0b64cad92f734472 to your computer and use it in GitHub Desktop.
WIP: Document mapper for FoundationDB
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import dataclasses | |
import datetime | |
import itertools | |
import struct | |
from typing import TypeVar, Optional, Any, Type, TypedDict | |
import fdb | |
from fdb.impl import Transaction | |
from fdb.subspace_impl import Subspace | |
import varint | |
# Select the FoundationDB client API version. The Python bindings expect this
# call before the rest of the fdb module is used (see the fdb documentation).
fdb.api_version(710)

# Maps a Python type to the Field class that encodes/decodes values of that
# type. Field.__init_subclass__ adds an entry for every Field subclass.
TYPE_REGISTRY = {}
class Field:
    """Base codec: stores ``bytes`` values unchanged.

    Subclasses set ``TYPE`` to the Python type they handle and override
    ``encode``/``decode``. An instance optionally carries ``key``, the name
    under which the field is stored (callers fall back to the dataclass field
    name when it is not set).
    """

    TYPE = bytes

    def __init_subclass__(cls, **kwargs):
        # Each subclass registers itself under the type it declares.
        TYPE_REGISTRY.update({cls.TYPE: cls})

    def __init__(self, key=None):
        self.key = key

    @classmethod
    def encode(cls, value: bytes) -> bytes:
        """Return ``value`` as-is; bytes need no conversion."""
        encoded = value
        return encoded

    @classmethod
    def decode(cls, value: bytes) -> bytes:
        """Return the stored bytes as-is."""
        decoded = value
        return decoded
# __init_subclass__ only fires for subclasses, so the base class is registered
# for ``bytes`` by hand.
TYPE_REGISTRY[bytes] = Field

# Type variable covering Field and any of its subclasses.
FieldType = TypeVar('FieldType', bound=Field)
class BoolField(Field):
    """Codec for ``bool`` values, stored as a single byte (0x01 / 0x00)."""

    TYPE = bool

    @classmethod
    def encode(cls, value: bool) -> bytes:
        """Return ``b'\\x01'`` for a truthy value and ``b'\\x00'`` otherwise."""
        # Test the argument, not the builtin ``bool`` type (which is always
        # truthy and made every value encode as True).
        return b'\x01' if value else b'\x00'

    @classmethod
    def decode(cls, value: bytes) -> bool:
        """Return True when the first stored byte is non-zero."""
        return bool(value[0])
class IntegerField(Field):
    """Codec for non-negative ``int`` values, stored as a varint."""

    TYPE = int

    @classmethod
    def encode(cls, value: int) -> bytes:
        """Return the varint encoding of ``value``.

        Raises:
            ValueError: if ``value`` is negative.
        """
        # NOTE(review): the varint package encodes unsigned integers only; its
        # encode loop appears never to terminate for a negative input, so
        # reject those up front instead of hanging.
        if value < 0:
            raise ValueError(f"cannot varint-encode negative integer {value}")
        return varint.encode(value)

    @classmethod
    def decode(cls, value: bytes) -> int:
        """Return the integer held in the varint-encoded ``value``."""
        return varint.decode_bytes(value)
class FloatField(Field):
    """Codec for ``float`` values, stored as an 8-byte IEEE 754 double."""

    TYPE = float

    # Explicit little-endian layout: the stored bytes must not depend on the
    # byte order of the machine that wrote them. On little-endian hosts this
    # is byte-identical to the previous native "d" format.
    _FORMAT = "<d"

    @classmethod
    def encode(cls, value: float) -> bytes:
        """Return ``value`` packed as a little-endian double."""
        return struct.pack(cls._FORMAT, value)

    @classmethod
    def decode(cls, value: bytes) -> float:
        """Return the float unpacked from 8 little-endian bytes."""
        return struct.unpack(cls._FORMAT, value)[0]
class StringField(Field):
    """Codec for ``str`` values, stored as UTF-8 bytes."""

    TYPE = str

    @classmethod
    def encode(cls, value: str) -> bytes:
        """Return ``value`` encoded as UTF-8."""
        raw = value.encode(encoding="utf-8")
        return raw

    @classmethod
    def decode(cls, value: bytes) -> str:
        """Return the text obtained by decoding ``value`` as UTF-8."""
        text = value.decode(encoding="utf-8")
        return text
class DateTimeField(StringField):
    """Codec for ``datetime.datetime`` values, stored as ISO 8601 text."""

    TYPE = datetime.datetime

    @classmethod
    def encode(cls, value: datetime.datetime) -> bytes:
        """Return the ISO-format string of ``value``, encoded by StringField."""
        iso_text = value.isoformat()
        return StringField.encode(iso_text)

    @classmethod
    def decode(cls, value: bytes) -> datetime.datetime:
        """Parse the ISO-format text stored in ``value`` back into a datetime."""
        iso_text = StringField.decode(value)
        return datetime.datetime.fromisoformat(iso_text)
# Optional support for bson.ObjectId: the codec is only defined (and thereby
# registered) when the ``bson`` module can be imported.
try:
    import bson

    class BsonField(Field):
        """Codec for ``bson.ObjectId`` values, stored via their ``binary`` form."""

        TYPE = bson.ObjectId

        @classmethod
        def encode(cls, value: bson.ObjectId) -> bytes:
            """Return the ``binary`` attribute of ``value``."""
            raw = value.binary
            return raw

        @classmethod
        def decode(cls, value: bytes) -> bson.ObjectId:
            """Build an ObjectId from the stored bytes."""
            object_id = bson.ObjectId(value)
            return object_id

except ImportError:
    # bson is not installed; carry on without ObjectId support.
    pass
class Key(tuple):
    """A tuple with convenience wrappers around the fdb tuple layer."""

    def pack(self):
        """Return this tuple packed by ``fdb.tuple.pack``."""
        packed = fdb.tuple.pack(self)
        return packed

    def unpack(self):
        """Return the result of ``fdb.tuple.unpack`` applied to this tuple."""
        # NOTE(review): unpack is handed the tuple itself rather than packed
        # bytes -- confirm this is what was intended.
        unpacked = fdb.tuple.unpack(self)
        return unpacked
def fdb_field(field: "Field", *args, **kwargs) -> dataclasses.Field:
    """Declare a dataclass field that is stored with the given codec.

    Args:
        field: the Field instance to attach; it is stored in the dataclass
            field's metadata under the ``"fdb"`` key.
        *args, **kwargs: forwarded unchanged to ``dataclasses.field``.

    Returns:
        The object returned by ``dataclasses.field``.
    """
    # Work on a copy so a metadata dict supplied (and possibly shared) by the
    # caller is never mutated; ``metadata=None`` is treated as "no metadata".
    metadata = dict(kwargs.get("metadata") or {})
    metadata["fdb"] = field
    kwargs["metadata"] = metadata
    return dataclasses.field(*args, **kwargs)
def get_fdb_field(field: dataclasses.Field) -> "Optional[Field]":
    """Return the codec to use for a dataclass field.

    The Field stored under the ``"fdb"`` metadata key (see ``fdb_field``)
    wins. Otherwise a fresh instance of the Field class registered in
    TYPE_REGISTRY for the field's annotated type is returned, or None when
    that type has no registered codec.
    """
    explicit = (field.metadata or {}).get("fdb")
    if explicit is not None:
        return explicit
    # Fall back to the registry even when the field carries unrelated
    # metadata; the default instance is only built when it is actually needed.
    f_cls = TYPE_REGISTRY.get(field.type)
    return f_cls() if f_cls is not None else None
# Marker path elements: to_tuples emits (EMPTY_OBJECT, None) for an empty dict
# and (EMPTY_ARRAY, None) for an empty list; from_tuples maps them back.
EMPTY_OBJECT = -2
EMPTY_ARRAY = -1
def to_tuples(item, encode=None):
    """Flatten ``item`` into a list of ``(path..., encoded_value)`` tuples.

    Dicts contribute their keys, lists/tuples their indexes and dataclasses
    their storage keys (the codec's ``key`` or the field name) as path
    elements; the last element of every tuple is the encoded leaf value.
    Empty dicts and empty sequences produce a single marker tuple.

    Args:
        item: the value to flatten.
        encode: callable used to encode ``item`` when it is a leaf; when
            omitted the codec registered for ``type(item)`` is used.

    Raises:
        ValueError: if a leaf value has no encoder.
    """
    if isinstance(item, dict):
        if not item:
            return [(EMPTY_OBJECT, None)]
        return [(k,) + sub for k, v in item.items() for sub in to_tuples(v)]

    if isinstance(item, (list, tuple)):
        # An empty tuple gets the marker too; otherwise it would produce no
        # rows at all and the field would vanish from storage.
        if not item:
            return [(EMPTY_ARRAY, None)]
        return [(i,) + sub for i, v in enumerate(item) for sub in to_tuples(v)]

    if dataclasses.is_dataclass(item):
        rv = []
        for field in dataclasses.fields(item):
            f_field = get_fdb_field(field)
            key = f_field.key if f_field and f_field.key else field.name
            # Fields without a codec (e.g. a nested dataclass) fall back to
            # per-value lookup instead of failing on ``None.encode``.
            field_encode = f_field.encode if f_field else None
            # Read the attribute directly rather than via dataclasses.asdict,
            # so nested dataclasses keep their own storage keys and codecs.
            value = getattr(item, field.name)
            rv.extend((key,) + sub for sub in to_tuples(value, field_encode))
        return rv

    if encode is None:
        f_cls = TYPE_REGISTRY.get(type(item))
        if f_cls is None:
            raise ValueError(f"could not serialize {item}")
        encode = f_cls.encode
    return [(encode(item),)]
def from_tuples(tuples: list[tuple]):
    """Rebuild the nested structure described by ``(path..., value)`` tuples.

    This is the inverse of ``to_tuples`` for the container structure only;
    leaf values are returned exactly as given (no decoding).

    Args:
        tuples: rows whose leading elements are path components and whose last
            element is the value. Rows sharing a leading element must be
            adjacent, because they are grouped with ``itertools.groupby``.

    Returns:
        ``{}`` for no rows, the bare value for a single-element row, otherwise
        a list (when the first path element is 0) or a dict.
    """
    if not tuples:
        return {}

    first = tuples[0]
    if len(first) == 1:
        return first[0]

    # Empty-container markers carry no payload: None when they come straight
    # from to_tuples, b"" when the marker was persisted with an empty value.
    if len(first) == 2 and first[1] in (None, b""):
        if first[0] == EMPTY_OBJECT:
            return {}
        if first[0] == EMPTY_ARRAY:
            return []

    groups = [
        list(group)
        for _, group in itertools.groupby(tuples, key=lambda t: t[0])
    ]
    children = [from_tuples([t[1:] for t in group]) for group in groups]

    # A leading index of 0 identifies a list; anything else is a dict key.
    if first[0] == 0:
        return children
    return {group[0][0]: child for group, child in zip(groups, children)}
def hydrate_document(storage: dict, doc_class: Any) -> Any:
    """Build a ``doc_class`` instance from the structure made by from_tuples.

    Each dataclass field is looked up in ``storage`` under its storage key
    (the codec's ``key`` or the field name) and decoded according to the
    field's codec and type annotation.

    Args:
        storage: mapping of storage key to raw (still encoded) value.
        doc_class: the dataclass to instantiate.

    Raises:
        ValueError: if ``doc_class`` is not a dataclass, or a leaf value has
            no codec.
        KeyError: if a field's storage key is missing from ``storage``.
    """
    if not dataclasses.is_dataclass(doc_class):
        raise ValueError("doc is not a dataclass")

    doc = {}
    for field in dataclasses.fields(doc_class):
        f_field = get_fdb_field(field)
        key = f_field.key if f_field and f_field.key else field.name
        doc[field.name] = _hydrate_value(storage[key], field.type, f_field)
    return doc_class(**doc)


def _hydrate_value(raw: Any, annotation: Any, f_field: Any = None) -> Any:
    """Decode one raw value, using ``annotation`` to decode container elements."""
    from typing import get_args, get_origin, get_type_hints

    if dataclasses.is_dataclass(annotation):
        return hydrate_document(raw, annotation)

    origin = get_origin(annotation)
    args = get_args(annotation)

    # list[X]: decode every element as X.
    if isinstance(raw, list) and origin is list and args:
        return [_hydrate_value(element, args[0]) for element in raw]

    if isinstance(raw, dict):
        # TypedDict classes expose __total__; decode each value by the
        # annotation declared for its key, leaving undeclared keys untouched.
        if hasattr(annotation, "__total__") and hasattr(annotation, "__annotations__"):
            hints = get_type_hints(annotation)
            return {
                name: _hydrate_value(value, hints[name]) if name in hints else value
                for name, value in raw.items()
            }
        # dict[K, V]: decode every value as V.
        if origin is dict and len(args) == 2:
            return {name: _hydrate_value(value, args[1]) for name, value in raw.items()}

    codec = f_field if f_field is not None else TYPE_REGISTRY.get(annotation)
    if codec is None:
        # Nothing describes the contents: hand untyped containers and
        # Any-typed values back as stored rather than guessing.
        if isinstance(raw, (list, dict)) or annotation is Any:
            return raw
        raise ValueError(f"could not deserialize value of type {annotation!r}")
    return codec.decode(raw)
@fdb.transactional
def store_document(tr: Transaction, space: Subspace, key: tuple, doc: Any):
    """Write the dataclass instance ``doc`` under ``key`` inside ``space``.

    Every row produced by ``to_tuples`` becomes one key/value pair: the row's
    path is appended to ``key`` and packed with ``space``, and the row's last
    element is the stored value.

    Raises:
        ValueError: if ``doc`` is not a dataclass.
    """
    if not dataclasses.is_dataclass(doc):
        raise ValueError("doc is not a dataclass")

    # Remove whatever an earlier version of the document left under this key;
    # otherwise rows that no longer exist (e.g. trailing list elements) would
    # still be read back by load_document.
    del tr[space.range(key)]

    for row in to_tuples(doc):
        value = row[-1]
        # Empty-container marker rows have no payload (None); store an empty
        # value for them. NOTE(review): the fdb bindings are expected to
        # reject non-bytes values -- confirm.
        tr[space.pack(key + row[:-1])] = b"" if value is None else value
@fdb.transactional
def load_document(tr: Transaction, space: Subspace, key: tuple, doc_class: Type) -> Any:
    """Read the document stored under ``key`` in ``space`` as a ``doc_class``.

    All key/value pairs in ``space.range(key)`` are turned back into
    ``(path..., value)`` rows, reassembled with ``from_tuples`` and passed to
    ``hydrate_document``.
    """
    # Strip the whole document key from each unpacked key, not just its first
    # element, so keys with more than one component load correctly.
    prefix_len = len(key)
    tuples = [
        space.unpack(k)[prefix_len:] + (v,)
        for k, v in tr[space.range(key)]
    ]
    storage = from_tuples(tuples)
    return hydrate_document(storage, doc_class)
# Demo: store a document covering every field kind, load it back and print
# both versions for comparison.
if __name__ == "__main__":
    import pprint

    @dataclasses.dataclass
    class SubDoc:
        foo: int
        bar: str
        baz: float

    class SomeDict(TypedDict):
        aaa: int
        bbb: str
        ccc: float

    @dataclasses.dataclass
    class Example:
        raw_field: bytes = fdb_field(Field("r"))
        bool_field: bool = fdb_field(BoolField("b"))
        int_field: int = fdb_field(IntegerField("i"))
        float_field: float = fdb_field(FloatField("f"))
        str_field: str = fdb_field(StringField("t"))
        datetime_field: datetime.datetime = fdb_field(DateTimeField("dt"))
        doc_field: SubDoc = fdb_field(Field("sd"))
        # FIXME
        list_field: list[str] = fdb_field(Field("l"))
        dict_field: SomeDict = fdb_field(Field("d"))

    now_utc = datetime.datetime.now(tz=datetime.timezone.utc)
    example_doc = Example(
        raw_field=b'\xde\xad\xbe\xef',
        bool_field=True,
        int_field=42,
        float_field=3.14,
        str_field="hello world",
        datetime_field=now_utc,
        doc_field=SubDoc(1, "hello", 3.14),
        list_field=["a", "b", "c"],
        dict_field={"aaa": 1, "bbb": "xxx", "ccc": 2.718},
    )

    db = fdb.open()
    doc_space = Subspace(('D',))
    doc_key = ("abc123",)

    # Round-trip the document through the database.
    store_document(db, doc_space, doc_key, example_doc)
    test_doc = load_document(db, doc_space, doc_key, Example)

    for shown in (example_doc, test_doc):
        pprint.pprint(shown)
Current output (with broken list, dict fields)
❯ python document.py
Example(raw_field=b'\xde\xad\xbe\xef',
bool_field=True,
int_field=42,
float_field=3.14,
str_field='hello world',
datetime_field=datetime.datetime(2024, 1, 16, 4, 51, 32, 265293, tzinfo=datetime.timezone.utc),
doc_field=SubDoc(foo=1, bar='hello', baz=3.14),
list_field=['a', 'b', 'c'],
dict_field={'aaa': 1, 'bbb': 'xxx', 'ccc': 2.718})
Example(raw_field=b'\xde\xad\xbe\xef',
bool_field=True,
int_field=42,
float_field=3.14,
str_field='hello world',
datetime_field=datetime.datetime(2024, 1, 16, 4, 51, 32, 265293, tzinfo=datetime.timezone.utc),
doc_field=SubDoc(foo=1, bar='hello', baz=3.14),
list_field=[b'a', b'b', b'c'],
dict_field={'aaa': b'\x01',
'bbb': b'xxx',
'ccc': b'X9\xb4\xc8v\xbe\x05@'})
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Previous output showing example key structure