Created November 16, 2022 14:41
-
-
Save qexat/2b0d845906779183a1ff4f40bdff63b9 to your computer and use it in GitHub Desktop.
The fact that it passes mypy and pyright completely, lmao.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from __future__ import annotations | |
from dataclasses import dataclass, field | |
from typing import Generic, Literal, overload, TypeGuard, TypeVar | |
T = TypeVar("T") | |
HasPopd = TypeVar("HasPopd", bound=bool) | |
@dataclass | |
class _Stack(Generic[T, HasPopd]): | |
__internal: list[T] = field(default_factory=list, init=False) | |
__popd: list[T] = field(default_factory=list, init=False) | |
@overload | |
def push(self, element: T, *, chain: Literal[True]) -> _Stack[T, HasPopd]: | |
... | |
@overload | |
def push(self, element: T, *, chain: Literal[False] = False) -> None: | |
... | |
def push(self, element: T, *, chain: bool = False) -> None | _Stack[T, HasPopd]: | |
self.__internal.append(element) | |
if chain: | |
return self | |
return None | |
@overload | |
def pop(self, *, chain: Literal[True]) -> _Stack[T, Literal[True]]: | |
... | |
@overload | |
def pop( | |
self: _Stack[T, Literal[True]], *, chain: Literal[False] = False | |
) -> list[T]: | |
... | |
@overload | |
def pop(self: _Stack[T, Literal[False]], *, chain: Literal[False] = False) -> T: | |
... | |
def pop(self, *, chain: bool = False) -> T | list[T] | _Stack[T, Literal[True]]: | |
value = self.__internal.pop() | |
self.__popd.append(value) | |
if has_chain(self, chain): | |
return self | |
if len(self.__popd) > 1: | |
popped = self.__popd.copy() | |
self.__popd = [] | |
return popped | |
self.__popd.pop() | |
return value | |
def has_chain( | |
stack: _Stack[T, HasPopd], chain: bool | |
) -> TypeGuard[_Stack[T, Literal[True]]]: | |
return chain | |
Stack = _Stack[T, Literal[False]] | |
def main() -> None:
    """Show a single pop and a chain of pops on a fresh ``Stack[int]``."""
    stack = Stack[int]()

    # One plain pop right after a chained push gives back the bare element.
    pushed_once = stack.push(3, chain=True)
    value = pushed_once.pop()

    # Load 3, 2, 1, then pop twice in chain mode; the closing plain pop
    # returns every chained value plus its own, in pop order.
    loaded = stack.push(3, chain=True).push(2, chain=True).push(1, chain=True)
    draining = loaded.pop(chain=True).pop(chain=True)
    value2 = draining.pop()

    print(f"{value=} / {value2=}")  # value=3 / value2=[1, 2, 3]


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.