Skip to content

Instantly share code, notes, and snippets.

@kurtbrose
Last active December 31, 2024 02:11
Show Gist options
  • Save kurtbrose/a494f0b6dec03780eb9ee8e29ade804b to your computer and use it in GitHub Desktop.
Save kurtbrose/a494f0b6dec03780eb9ee8e29ade804b to your computer and use it in GitHub Desktop.
import functools
from typing import (
    Callable,
    Concatenate,
    ParamSpec,
    Protocol,
    TypeVar,
    overload,
    runtime_checkable,
)
P = ParamSpec("P")
T = TypeVar("T", contravariant=True) # param type variance
R = TypeVar("R", covariant=True) # return type variance
@runtime_checkable
class _MaybeWrapper(Protocol[T, P, R]):
@overload
def __call__(self, first: None, *args: P.args, **kwargs: P.kwargs) -> None: ...
@overload
def __call__(self, first: T, *args: P.args, **kwargs: P.kwargs) -> R: ...
def maybe(func: Callable[Concatenate[T, P], R]) -> _MaybeWrapper[T, P, R]:
    """Wrap *func* so that a ``None`` first argument short-circuits to ``None``.

    *func* must accept at least one positional argument. When the returned
    wrapper receives ``None`` as its first argument, *func* is not called and
    ``None`` is returned. Otherwise every argument is forwarded to *func*
    unchanged and its result is returned.

    The fancy part is that type checkers understand what is going on. They
    do not just see the lossy ``func(T | None) -> R | None``. The overloads
    tell them that a ``None`` argument yields ``None``, that a ``T`` argument
    yields ``R``, and that the remaining parameters (``P``) are checked
    exactly as for *func* itself.

    The wrapper carries *func*'s metadata (``__name__``, ``__qualname__``,
    ``__doc__``, ``__module__``, ``__wrapped__``) via ``functools.wraps``.

    Example:

    >>> @maybe
    ... def maybe_increment(n: int) -> int:
    ...     return n + 1
    >>> none_val: None = maybe_increment(None)
    >>> none_val is None
    True
    >>> int_val: int = maybe_increment(1)
    >>> int_val
    2
    >>> maybe_increment.__name__
    'maybe_increment'

    Args:
        func: Function whose first positional argument may be ``None``.

    Returns:
        The wrapping callable, typed as ``_MaybeWrapper[T, P, R]``.
    """

    @overload
    def wrapper(first: None, *args: P.args, **kwargs: P.kwargs) -> None: ...

    @overload
    def wrapper(first: T, *args: P.args, **kwargs: P.kwargs) -> R: ...

    # Without ``wraps`` the decorated function would report itself as
    # "wrapper" and hide func's docstring and signature from introspection.
    # NOTE: the first argument is bound to the name ``first`` here. A later
    # parameter of ``func`` that is also named ``first`` therefore cannot be
    # passed by keyword through the wrapper; Python raises TypeError for the
    # duplicate value.
    @functools.wraps(func)
    def wrapper(first, *args, **kwargs):
        # Identity test rather than truthiness. Falsy values such as 0, ""
        # and [] are legitimate inputs and must still reach ``func``.
        if first is None:
            return None
        return func(first, *args, **kwargs)

    return wrapper
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment