Last active
September 14, 2024 22:54
-
-
Save betafcc/a5d97a89a9f50a1efb4000481d6b9729 to your computer and use it in GitHub Desktop.
Extensible typed records in Python (pyright)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from __future__ import annotations | |
from dataclasses import dataclass | |
from typing import Any, ClassVar, Generic, Literal, NoReturn, Type, TypeAlias, TypeVar | |
from typing_extensions import LiteralString | |
K = TypeVar("K", bound=LiteralString) | |
V = TypeVar("V", covariant=True) | |
AnyPair: TypeAlias = tuple[LiteralString, Any] | |
M = TypeVar("M", bound=AnyPair, covariant=True) | |
N = TypeVar("N", bound=AnyPair, covariant=True) | |
@dataclass(frozen=True)
class Record(Generic[M]):
    """Immutable string-keyed record whose contents are tracked in its type.

    ``M`` is a union of ``tuple[key, value]`` pairs, one per entry added, e.g.
    ``Record[tuple[Literal["a"], int] | tuple[Literal["b"], str]]``.
    Values are read back with ``Record.get(key)(record)``.
    """

    # Shared empty record. It is assigned right after the class body because
    # the value is an instance of this very class.
    empty: ClassVar[Record[NoReturn]]
    # Runtime storage of the entries. The name shadows the builtin ``dict``
    # inside the class body, but it is kept because ``Get`` (and any other
    # caller) reads ``record.dict``.
    dict: dict[Any, Any]

    def __hash__(self) -> int:
        """Hash the record by its content.

        Without this, the frozen dataclass would generate a ``__hash__`` that
        hashes the underlying ``dict`` and therefore always raises
        ``TypeError``. Hashing the item set keeps ``a == b`` implying
        ``hash(a) == hash(b)``. It still raises ``TypeError`` when a stored
        value is itself unhashable.
        """
        return hash(frozenset(self.dict.items()))

    @staticmethod
    def get(key: K) -> Get[K]:
        """Return an accessor for ``key``; call the accessor with a record.

        The lookup is curried through the generic ``Get`` class so that pyright
        binds ``K`` before it sees the record argument.
        """
        return Get(key)

    def __and__(self, other: Record[N]) -> Record[M | N]:
        """Merge two records. On duplicate keys the right-hand value wins."""
        return Record(self.dict | other.dict)

    def __add__(self, other: tuple[K, V]) -> Record[M | tuple[K, V]]:
        """Return a new record with the ``(key, value)`` pair ``other`` added.

        An existing entry with the same key is overwritten at runtime. The
        static type keeps the previous pair in the union as well.
        """
        return Record(self.dict | {other[0]: other[1]})


Record.empty = Record[NoReturn]({})
@dataclass(frozen=True)
class Get(Generic[K]):
    """Curried key accessor produced by ``Record.get(key)``.

    Calling an instance with a record returns the value stored under ``key``.
    The parameter type ``Record[tuple[K, V] | AnyPair]`` is what lets pyright
    match the entry whose key is ``K`` and infer its value type ``V``.
    """

    # Literal-string key to look up.
    key: K

    def __call__(self, record: Record[tuple[K, V] | AnyPair]) -> V:
        """Return ``record``'s value for ``self.key``; ``KeyError`` if absent."""
        storage = record.dict
        found = storage[self.key]
        return found
# Extra type-level helpers. They exist for the type checker only: each
# ``__getitem__`` body is ``...``, so any subscription evaluates to None at runtime.


class MergeMeta(type):
    """Metaclass that gives ``Merge`` its type-level subscript."""

    def __getitem__(
        cls, args: tuple[Type[Record[M]], Type[Record[N]]]
    ) -> Type[Record[M | N]]:
        """Type ``Merge[Record[A], Record[B]]`` as ``Type[Record[A | B]]``."""
        ...


class Merge(metaclass=MergeMeta):
    """Type-level merge of two ``Record`` types, written ``Merge[R1, R2]``."""


class KeyOfMeta(type):
    """Metaclass that gives ``KeyOf`` its type-level subscript."""

    def __getitem__(cls, args: Type[Record[tuple[K, Any]]]) -> Type[K]:
        """Type ``KeyOf[Record[...]]`` as the union of the record's key types."""
        ...


class KeyOf(metaclass=KeyOfMeta):
    """Type-level key extraction from a ``Record`` type, written ``KeyOf[R]``."""
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Usage examples.
# NOTE(review): no imports are shown for this snippet; it assumes Record,
# Merge, KeyOf, Literal and TypeAlias are already in scope (e.g. imported from
# the module defining Record) — confirm.
# Trailing comments show the types pyright infers. Each lookup is also
# asserted at runtime, rather than evaluated and discarded.

result = Record.empty + ("a", 1) + ("b", "hi") + ("c", 3.14)
# (variable) result: Record[tuple[Literal['a'], Literal[1]] | tuple[Literal['b'], Literal['hi']] | tuple[Literal['c'], float]]
assert Record.get("a")(result) == 1  # revealed type: Literal[1]
assert Record.get("b")(result) == "hi"  # revealed type: Literal['hi']
assert Record.get("c")(result) == 3.14  # revealed type: float

merged = result & (Record.empty + ("d", True) + ("e", 2))
# (variable) merged: Record[tuple[Literal['a'], Literal[1]] | tuple[Literal['b'], Literal['hi']] | tuple[Literal['c'], float] | tuple[Literal['d'], Literal[True]] | tuple[Literal['e'], Literal[2]]]
assert Record.get("a")(merged) == 1  # revealed type: Literal[1]
assert Record.get("b")(merged) == "hi"  # revealed type: Literal['hi']
assert Record.get("c")(merged) == 3.14  # revealed type: float
assert Record.get("d")(merged) is True  # revealed type: Literal[True]
assert Record.get("e")(merged) == 2  # revealed type: Literal[2]

# Type-level helpers: meaningful to the type checker only. At runtime
# Merge[...] and KeyOf[...] evaluate to None. The ``X | Y`` unions below are
# ordinary runtime expressions (not annotations), which requires Python 3.10+.
Merged: TypeAlias = Merge[
    Record[tuple[Literal["a"], int] | tuple[Literal["b"], str]],
    Record[tuple[Literal["c"], float] | tuple[Literal["d"], bool]],
]
# (type alias) Merged: Type[Record[tuple[Literal['a'], int] | tuple[Literal['b'], str] | tuple[Literal['c'], float] | tuple[Literal['d'], bool]]]
Keys: TypeAlias = KeyOf[Merged]
# (type alias) Keys: Type[Literal['a', 'b', 'c', 'd']]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.
The idea is to use a covariant generic type variable to keep track of the type of each 'item' in the record; that type is a union of all the (key, value) tuple pairs added so far.
Because the type variable is covariant, we can then match the record against a union with `AnyPair` to retrieve the value type for a given key.
We do have to work around pyright's TypeVar binding order for this to work, though. That's why the `get` function is curried through a generic class (`Get`); I couldn't find a way to make it a plain method on the `Record` class.