# GitHub gist kurtbrose/a494f0b6dec03780eb9ee8e29ade804b (last active December 31, 2024 02:11).
# Save it to your computer and use it in GitHub Desktop.
# Note from GitHub: this file contains hidden or bidirectional Unicode text that may be
# interpreted or compiled differently than what appears below. To review, open the file
# in an editor that reveals hidden Unicode characters. Learn more about bidirectional
# Unicode characters.

import functools
from typing import (
    Callable,
    Concatenate,
    ParamSpec,
    Protocol,
    TypeVar,
    overload,
    runtime_checkable,
)

P = ParamSpec("P") | |
T = TypeVar("T", contravariant=True) # param type variance | |
R = TypeVar("R", covariant=True) # return type variance | |
@runtime_checkable | |
class _MaybeWrapper(Protocol[T, P, R]): | |
@overload | |
def __call__(self, first: None, *args: P.args, **kwargs: P.kwargs) -> None: ... | |
@overload | |
def __call__(self, first: T, *args: P.args, **kwargs: P.kwargs) -> R: ... | |
def maybe(func: Callable[Concatenate[T, P], R]) -> _MaybeWrapper[T, P, R]: | |
""" | |
Add a short-circuit to a function with at least one positional argument, None returns None. | |
The fancy part: type checkers will understand what is going on. | |
Not just func(T | None) -> R | None, they will understand that a None-parameter yields a None-return, | |
and otherwise the same return and type parameters apply. | |
Example: | |
@maybe | |
def maybe_increment(n: int) -> int: | |
return n + 1 | |
none_val: None = maybe_increment(None) | |
int_val: int = maybe_increment(1) | |
""" | |
@overload | |
def wrapper(first: None, *args: P.args, **kwargs: P.kwargs) -> None: ... | |
@overload | |
def wrapper(first: T, *args: P.args, **kwargs: P.kwargs) -> R: ... | |
def wrapper(first, *args, **kwargs): | |
if first is None: | |
return None | |
return func(first, *args, **kwargs) | |
return wrapper |

# Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.