Last active
September 7, 2024 20:23
-
-
Save qexat/6b04fc28146feabcbe18e1190371607b to your computer and use it in GitHub Desktop.
make pyright cook your CPU speedrun any%
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# ruff: noqa: DOC201, DOC501 | |
""" | |
Combinatorial arithmetic | |
""" | |
from __future__ import annotations | |
import abc | |
import typing | |
import attrs | |
class TypeVisitor[R_co](typing.Protocol): | |
""" | |
Represents a visitor of the Type tree. | |
""" | |
def visit_zero_type(self, typ: Zero) -> R_co: | |
""" | |
Visit the Zero type. | |
""" | |
... | |
def visit_unit_type(self, typ: Unit) -> R_co: | |
""" | |
Visit the Unit type. | |
""" | |
... | |
def visit_atomic_type(self, typ: Atomic[typing.LiteralString]) -> R_co: | |
""" | |
Visit an Atomic type. | |
""" | |
... | |
def visit_negation_type(self, typ: Negation[Type]) -> R_co: | |
""" | |
Visit a Negation type. | |
""" | |
... | |
def visit_product_type(self, typ: Product[Type, Type]) -> R_co: | |
""" | |
Visit a Product type. | |
""" | |
... | |
def visit_sum_type(self, typ: Sum[Type, Type]) -> R_co: | |
""" | |
Visit a Sum type. | |
""" | |
... | |
@attrs.frozen | |
class TypeBase(abc.ABC): | |
""" | |
Base of the type tree. | |
""" | |
def __neg__[T: Type](self: T) -> Negation[T]: # pyright: ignore[reportGeneralTypeIssues] | |
return Negation(self) | |
def __add__[T0: Type, T1: Type](self: T0, other: T1, /) -> Sum[T0, T1]: # pyright: ignore[reportGeneralTypeIssues] | |
return Sum(self, other) | |
def __mul__[T0: Type, T1: Type](self: T0, other: T1, /) -> Product[T0, T1]: # pyright: ignore[reportGeneralTypeIssues] | |
return Product(self, other) | |
@abc.abstractmethod | |
def accept[R](self, visitor: TypeVisitor[R]) -> R: | |
""" | |
Accept of a type visitor and return the result. | |
""" | |
@attrs.frozen(init=False) | |
@typing.final | |
class Zero(TypeBase): | |
""" | |
Represents the Zero type. | |
""" | |
def __init__(self) -> None: | |
message = "the zero type is uninstantiable" | |
raise TypeError(message) | |
@typing.override | |
def accept[R](self, visitor: TypeVisitor[R]) -> R: | |
return visitor.visit_zero_type(self) | |
# TODO: make it a singleton | |
@attrs.frozen | |
@typing.final | |
class Unit(TypeBase): | |
""" | |
Represents the Unit type. | |
""" | |
@typing.override | |
def accept[R](self, visitor: TypeVisitor[R]) -> R: | |
return visitor.visit_unit_type(self) | |
@attrs.frozen | |
@typing.final | |
class Atomic[Name: typing.LiteralString](TypeBase): | |
""" | |
Represents an Atomic type. | |
""" | |
name: Name | |
@typing.override | |
def accept[R](self, visitor: TypeVisitor[R]) -> R: | |
return visitor.visit_atomic_type(self) | |
@attrs.frozen | |
@typing.final | |
class Negation[T: Type](TypeBase): | |
""" | |
Represents a Negation type. | |
""" | |
typ: T | |
@typing.override | |
def accept[R](self, visitor: TypeVisitor[R]) -> R: | |
return visitor.visit_negation_type(self) | |
@attrs.frozen | |
@typing.final | |
class Product[T0: Type, T1: Type](TypeBase): | |
""" | |
Represents a Product type. | |
""" | |
first: T0 | |
second: T1 | |
@typing.override | |
def accept[R](self, visitor: TypeVisitor[R]) -> R: | |
return visitor.visit_product_type( | |
typing.cast(Product[Type, Type], self), | |
) | |
@attrs.frozen | |
@typing.final | |
class Sum[T0: Type, T1: Type](TypeBase): | |
""" | |
Represents a Sum type. | |
""" | |
first: T0 | |
second: T1 | |
@typing.override | |
def accept[R](self, visitor: TypeVisitor[R]) -> R: | |
return visitor.visit_sum_type(typing.cast(Sum[Type, Type], self)) | |
type Type = ( | |
Zero | |
| Unit | |
| Atomic[typing.LiteralString] | |
| Negation["Type"] | |
| Product["Type", "Type"] | |
| Sum["Type", "Type"] | |
) | |
def _unsafe_zero() -> Zero:
    """
    Build a Zero instance while skipping `Zero.__init__`.

    This is unsound by design: the regular constructor always raises.
    Be EXTREMELY CAREFUL with using the returned instance.

    Returns
    -------
    Zero
    """
    # Allocate the object directly so that `__init__` is never run.
    return Zero.__new__(Zero)
def identity[T: Type](typ: T) -> T: | |
""" | |
Identity combinator. | |
""" | |
return typ | |
def sum_id_left_intro[T: Type](typ: T) -> Sum[Zero, T]: | |
""" | |
Rewriting rule to introduce the sum left identity. | |
""" | |
return _unsafe_zero() + typ | |
def sum_id_left_elim[T: Type](typ: Sum[Zero, T]) -> T: | |
""" | |
Rewriting rule to eliminate the sum left identity. | |
""" | |
match typ: | |
case Sum(Zero(), a): | |
return a | |
def sum_comm[T0: Type, T1: Type](typ: Sum[T0, T1]) -> Sum[T1, T0]: | |
""" | |
Commutativity sum combinator. | |
""" | |
match typ: | |
case Sum(a, b): | |
return b + a | |
def sum_assoc_left[T0: Type, T1: Type, T2: Type]( | |
typ: Sum[T0, Sum[T1, T2]], | |
) -> Sum[Sum[T0, T1], T2]: | |
""" | |
Left-associativity sum combinator. | |
""" | |
match typ: | |
case Sum(a, Sum(b, c)): | |
return (a + b) + c | |
def sum_assoc_right[T0: Type, T1: Type, T2: Type]( | |
typ: Sum[Sum[T0, T1], T2], | |
) -> Sum[T0, Sum[T1, T2]]: | |
""" | |
Right-associativity sum combinator. | |
""" | |
match typ: | |
case Sum(Sum(a, b), c): | |
return a + (b + c) | |
def sum_eta[T: Type](typ: Zero) -> Sum[T, Negation[T]]: # noqa: ARG001 | |
""" | |
Eta sum combinator. | |
""" | |
raise RuntimeError("unreachable") | |
def sum_eps[T: Type](typ: Sum[T, Negation[T]]) -> Zero: # noqa: ARG001 | |
""" | |
Eps sum combinator. | |
""" | |
return _unsafe_zero() | |
def prod_id_left_intro[T: Type](typ: T) -> Product[Unit, T]: | |
""" | |
Rewriting rule to introduce the product left identity. | |
""" | |
return Unit() * typ | |
def prod_id_left_elim[T: Type](typ: Product[Unit, T]) -> T: | |
""" | |
Rewriting rule to eliminate the product left identity. | |
""" | |
match typ: | |
case Product(Unit(), a): | |
return a | |
def prod_comm[T0: Type, T1: Type](typ: Product[T0, T1]) -> Product[T1, T0]: | |
""" | |
Commutativity product combinator. | |
""" | |
match typ: | |
case Product(a, b): | |
return b * a | |
def prod_assoc_left[T0: Type, T1: Type, T2: Type]( | |
typ: Product[T0, Product[T1, T2]], | |
) -> Product[Product[T0, T1], T2]: | |
""" | |
Left-associative product combinator. | |
""" | |
match typ: | |
case Product(a, Product(b, c)): | |
return (a * b) * c | |
def prod_assoc_right[T0: Type, T1: Type, T2: Type]( | |
typ: Product[Product[T0, T1], T2], | |
) -> Product[T0, Product[T1, T2]]: | |
""" | |
Right-associative product combinator. | |
""" | |
match typ: | |
case Product(Product(a, b), c): | |
return a * (b * c) | |
def prod_zero_left_intro[T: Type](typ: Zero) -> Product[Zero, T]: # noqa: ARG001 | |
""" | |
Rewriting rule to introduce a left-zero-product term. | |
""" | |
raise RuntimeError("unreachable") | |
def prod_zero_left_elim[T: Type](typ: Product[Zero, T]) -> Zero: # noqa: ARG001 | |
""" | |
Rewriting rule to eliminate a left-zero-product term. | |
""" | |
return Zero() | |
def sum_distrib_prod_left[T0: Type, T1: Type, T2: Type]( | |
typ: Product[Sum[T0, T1], T2], | |
) -> Sum[Product[T0, T2], Product[T1, T2]]: | |
""" | |
Rewriting rule to distribute a sum product. | |
""" | |
match typ: | |
case Product(Sum(a, b), c): | |
return (a * c) + (b * c) | |
class PropVisitor[R_co](typing.Protocol): | |
""" | |
Represents a visitor of the proposition tree. | |
""" | |
def visit_equality(self, prop: Equality[Type, Type]) -> R_co: | |
""" | |
Visit an equality proposition. | |
""" | |
... | |
def visit_conjunction(self, prop: Conjunction[Prop, Prop]) -> R_co: | |
""" | |
Visit a proposition conjunction. | |
""" | |
... | |
def visit_disjunction(self, prop: Disjunction[Prop, Prop]) -> R_co: | |
""" | |
Visit a proposition disjunction. | |
""" | |
... | |
def visit_implication(self, prop: Implication[Prop, Prop]) -> R_co: | |
""" | |
Visit a proposition implication. | |
""" | |
... | |
@attrs.frozen | |
class PropBase(abc.ABC): | |
""" | |
Base of the proposition tree. | |
""" | |
def __and__[P0: Prop, P1: Prop]( | |
self: P0, # pyright: ignore[reportGeneralTypeIssues] | |
other: P1, | |
/, | |
) -> Conjunction[P0, P1]: | |
return Conjunction(self, other) | |
def __or__[P0: Prop, P1: Prop]( | |
self: P0, # pyright: ignore[reportGeneralTypeIssues] | |
other: P1, | |
/, | |
) -> Disjunction[P0, P1]: | |
return Disjunction(self, other) | |
def __rshift__[P0: Prop, P1: Prop]( | |
self: P0, # pyright: ignore[reportGeneralTypeIssues] | |
other: P1, | |
/, | |
) -> Implication[P0, P1]: | |
return Implication(self, other) | |
@abc.abstractmethod | |
def accept[R](self, visitor: PropVisitor[R]) -> R: | |
""" | |
Accept a proposition visitor and return the result. | |
""" | |
@attrs.frozen | |
class Equality[T0: Type, T1: Type](PropBase): | |
""" | |
Represents an equality. | |
""" | |
left: T0 | |
right: T1 | |
@typing.override | |
def accept[R](self, visitor: PropVisitor[R]) -> R: | |
return visitor.visit_equality(self) | |
@attrs.frozen | |
class Conjunction[P0: Prop, P1: Prop](PropBase): | |
""" | |
Represents a proposition conjunction. | |
""" | |
left: P0 | |
right: P1 | |
@typing.override | |
def accept[R](self, visitor: PropVisitor[R]) -> R: | |
return visitor.visit_conjunction(self) | |
@attrs.frozen | |
class Disjunction[P0: Prop, P1: Prop](PropBase): | |
""" | |
Represents a proposition disjunction. | |
""" | |
left: P0 | |
right: P1 | |
@typing.override | |
def accept[R](self, visitor: PropVisitor[R]) -> R: | |
return visitor.visit_disjunction( | |
typing.cast(Disjunction[Prop, Prop], self), | |
) | |
@attrs.frozen | |
class Implication[P0: Prop, P1: Prop](PropBase): | |
""" | |
Represents a proposition implication. | |
""" | |
left: P0 | |
right: P1 | |
@typing.override | |
def accept[R](self, visitor: PropVisitor[R]) -> R: | |
return visitor.visit_implication( | |
typing.cast(Implication[Prop, Prop], self), | |
) | |
type Prop = ( | |
Equality[Type, Type] | |
| Conjunction[Prop, Prop] | |
| Disjunction[Prop, Prop] | |
| Implication[Prop, Prop] | |
) | |
def sum_eq_inj[T0: Type, T1: Type, T2: Type, T3: Type]( | |
left: Sum[T0, T1], | |
right: Sum[T2, T3], | |
) -> Implication[ | |
Conjunction[Equality[T0, T2], Equality[T1, T3]], | |
Equality[Sum[T0, T1], Sum[T2, T3]], | |
]: | |
r""" | |
Sum equality injection. | |
∀ a, b, c, d : Type, a = c /\ b = d -> a + b = c + d | |
""" | |
match (left, right): | |
case Sum(a, b), Sum(c, d): | |
return Implication( | |
Conjunction(Equality(a, c), Equality(b, d)), | |
Equality(left, right), | |
) | |
def main() -> None:
    """
    Entry point of the program.

    Builds `foo + (-foo)` from an atom, reduces it with `sum_eps` and
    reveals the type of the result.
    """
    foo = Atomic("foo")
    cancelling_sum = foo + (-foo)
    zero = sum_eps(cancelling_sum)
    typing.reveal_type(zero)  # Revealed type is "Zero"
# Run the entry point only when the file is executed as a script.
if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment