Created
October 4, 2023 12:05
-
-
Save cqfd/daabbd0eadfd16a4cf43622b56789886 to your computer and use it in GitHub Desktop.
Click, but with nice types.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import copy | |
from abc import ABC, abstractmethod | |
from contextlib import contextmanager, nullcontext | |
from typing import ( | |
TYPE_CHECKING, | |
Any, | |
Callable, | |
ContextManager, | |
Generic, | |
Iterator, | |
Self, | |
TypeVar, | |
overload, | |
) | |
import click | |
# Covariant type variables used by ``Parameter``:
#   T - the type of a parameter's processed value (``Parameter[T]``);
#   U - the result type produced by a chained transformation
#       (``callback``/``validate``/``map`` return ``Parameter[U]``).
# Covariance means ``Parameter[Sub]`` is accepted where ``Parameter[Base]`` is
# expected.
T = TypeVar("T", covariant=True)
U = TypeVar("U", covariant=True)
# NOTE(review): not referenced elsewhere in this snippet; it appears to be a
# plain-click baseline for comparison with the typed ``Command`` API below.
@click.command()
@click.option("--count", default=1, help="The number of greetings to do")
@click.option("--name", prompt="Your name", help="The person to greet.")
def say_hello(name: str, count: int) -> None:
    """Print a greeting for NAME, COUNT times."""
    # ``--count`` is documented as the number of greetings, so honour it
    # instead of always greeting exactly once.
    for _ in range(count):
        print(f"hi {name}!")
class Parameter(ABC, Generic[T]):
    """Base class for a typed click parameter with composable post-processing.

    ``T`` is the type of the value after the callback chain has run. Every
    combinator (``callback``, ``validate``, ``map``, ``error_if``,
    ``map_resourcefully``) returns a shallow copy with one more step appended
    and leaves ``self`` unchanged, so one parameter can be shared and extended.

    A parameter can also decorate a ``Command`` subclass (see ``__call__``)
    to add itself to that subclass's already-built click command.
    """

    # Explicit command-line name; ``None`` means subclasses derive one.
    _name: str | None
    # Composed click callback: (ctx, param, raw value) -> processed value.
    _callback: Callable[[Any, Any, Any], T]

    def __init__(self) -> None:
        super().__init__()
        self._name = None
        # Identity step: pass click's parsed value through unchanged.
        self._callback = lambda _ctx, _param, x: x

    @abstractmethod
    def to_click_decorator(self) -> Callable[..., Any]:
        """Return the click decorator that registers this parameter."""
        ...

    def callback(
        self, f: Callable[[click.Context, click.Parameter, T], U]
    ) -> "Parameter[U]":
        """Return a copy that feeds this parameter's result through ``f``.

        ``f`` is called with the click context, the click parameter object,
        and the value produced by the existing callback chain. Its return
        value becomes the new processed value.
        """
        u = copy.copy(self)
        # Bind the current chain now instead of reading ``self._callback``
        # lazily at parse time. This fixes the copy's behavior at creation,
        # so it cannot drift if ``self._callback`` is reassigned later.
        previous = self._callback
        u._callback = lambda ctx, param, t: f(ctx, param, previous(ctx, param, t))  # type: ignore
        return u  # type: ignore

    def validate(self, validator: Callable[[T], U]) -> "Parameter[U]":
        """Return a copy whose processed value is ``validator(value)``.

        ``validator`` may raise ``click.BadParameter`` to reject the value.
        """
        return self.callback(lambda ctx, param, t: validator(t))  # type: ignore

    def map(self, f: Callable[[T], U]) -> "Parameter[U]":
        """Return a copy whose processed value is ``f(value)`` (same as ``validate``)."""
        return self.validate(f)

    def error_if(self, condition: Callable[[T], bool], msg: str) -> "Parameter[T]":
        """Return a copy that raises ``click.BadParameter(msg)`` when ``condition`` holds.

        Otherwise the value passes through unchanged.
        """

        def v(t: T) -> T:  # type: ignore
            if condition(t):
                raise click.BadParameter(msg)
            else:
                return t

        return self.validate(v)

    def map_resourcefully(
        self, resource: Callable[[T], ContextManager[U]]
    ) -> "Parameter[U]":
        """Return a copy whose value is the result of entering ``resource(value)``.

        The context manager is entered through ``ctx.with_resource``, so its
        cleanup is handled by the click context rather than by the caller.
        """
        return self.callback(lambda ctx, _param, t: ctx.with_resource(resource(t)))

    @overload
    def __call__(self) -> Callable[[type["SomeCommand"]], type["SomeCommand"]]:
        ...

    @overload
    def __call__(self, cls: type["SomeCommand"]) -> type["SomeCommand"]:
        ...

    def __call__(
        self, cls: type["SomeCommand"] | None = None
    ) -> type["SomeCommand"] | Callable[[type["SomeCommand"]], type["SomeCommand"]]:
        """Add this parameter to a ``Command`` subclass's click command.

        Usable both as ``@param`` (``cls`` is given and returned) and as
        ``@param()`` (the decorator is returned). The only change to the class
        is that ``cls.command`` is wrapped with this parameter's decorator.
        """

        def decorator(cls: type[SomeCommand]) -> type[SomeCommand]:
            cls.command = self.to_click_decorator()(cls.command)
            return cls

        if cls is None:
            return decorator
        return decorator(cls)

    if TYPE_CHECKING:
        # Typing-only descriptor signatures. Accessed on the class, a parameter
        # is itself. Accessed on a ``Command`` instance, it is seen as its
        # processed value ``T``. There is no ``__get__`` at runtime.
        @overload
        def __get__(self, instance: None, owner: type["Command"]) -> Self:
            ...

        @overload
        def __get__(self, instance: "Command", owner: type["Command"]) -> T:
            ...

        def __get__(self, instance: Any, owner: Any) -> T | Self:
            ...
class Argument(Parameter[T]):
    """Placeholder for a positional click argument.

    NOTE(review): ``to_click_decorator`` is not implemented here, so this class
    is still abstract and cannot be instantiated directly.
    """

    def __init__(self) -> None:
        # Set up the base state (``_name`` and the identity ``_callback``).
        # Without it, combinators such as ``callback``/``map``/``validate``
        # would raise AttributeError on any concrete subclass.
        super().__init__()
class Option(Parameter[T]):
    """A typed ``--name`` click option, usually declared as a class attribute.

    The command-line name is ``name`` if given; otherwise it is the attribute
    name the option is assigned to (captured by ``__set_name__``). Without
    ``default`` the value is typed ``str | None``; with ``default`` the value
    type follows the default.
    """

    # Explicit command-line name, if passed to ``__init__``.
    _name: str | None
    # Name of the class attribute this option was assigned to, if any.
    _variable_name: str | None

    @overload
    def __init__(
        self: "Option[str | None]",
        *,
        name: str | None = None,
        help: str | None = None,
        expose_value: bool = True,
    ) -> None:
        ...

    @overload
    def __init__(
        self: "Option[T]",
        *,
        name: str | None = None,
        help: str | None = None,
        expose_value: bool = True,
        default: T,
    ) -> None:
        ...

    def __init__(
        self,
        *,
        name: str | None = None,
        help: str | None = None,
        expose_value: bool = True,
        default: T | None = None,
    ) -> None:
        super().__init__()
        self._name = name
        self._variable_name = None
        # These are forwarded verbatim to ``click.option`` in ``to_click_decorator``.
        self.help = help
        self.default = default
        self.expose_value = expose_value

    if TYPE_CHECKING:
        # Typing-only constructor overloads so ``Option(...)`` without a default
        # is inferred as ``Option[str | None]`` and with one as ``Option[T]``.
        @overload
        def __new__(
            cls,
            *,
            name: str | None = None,
            help: str | None = None,
            expose_value: bool = True,
        ) -> "Option[str | None]":
            ...

        @overload
        def __new__(
            cls,
            *,
            name: str | None = None,
            help: str | None = None,
            expose_value: bool = True,
            default: T,
        ) -> "Option[T]":
            ...

        def __new__(
            cls,
            *,
            name: str | None = None,
            help: str | None = None,
            expose_value: bool = True,
            default: T | None = ...,
        ) -> "Option[T | str | None]":
            ...

    def to_click_decorator(self) -> Callable[..., Any]:
        """Build the ``click.option`` decorator for this option.

        Raises:
            ValueError: if the option has neither an explicit ``name`` nor an
                attribute name. Otherwise it would be registered as ``--None``.
        """
        name = self._name or self._variable_name
        if not name:
            raise ValueError(
                "Option has no name: pass name=... or assign it to a class attribute"
            )
        return click.option(
            f"--{name}",
            help=self.help,
            default=self.default,
            expose_value=self.expose_value,
            callback=self._callback,
        )

    def __set_name__(self, owner: Any, name: str) -> None:
        # Called when the option is assigned as a class attribute; remember the
        # attribute name as the fallback command-line name.
        self._variable_name = name
class Command(ABC):
    """Base class for declaratively defined click commands.

    Subclasses declare ``Parameter`` class attributes and implement ``run``.
    When a subclass is created, a click command is built and stored on
    ``cls.command``. Invoking it parses the command line, instantiates the
    subclass, assigns each processed parameter value to the instance attribute
    of the same name, and calls ``run()``.

    The command name defaults to the lowercased class name. Override it with
    a class keyword: ``class Foo(Command, name="bar")``.
    """

    # The click command built in ``__init_subclass__`` (a click command
    # object at runtime, despite this narrower annotation).
    command: Callable[[], None]

    def __init_subclass__(cls, *, name: str | None = None, **kwargs: Any) -> None:
        # ``name`` is consumed here. Forwarding it to ``object.__init_subclass__``
        # would raise TypeError, because that accepts no keyword arguments.
        super().__init_subclass__(**kwargs)
        command_name = name if name is not None else cls.__name__.lower()
        # Only parameters declared directly on this class are collected.
        parameters = {k: v for k, v in cls.__dict__.items() if isinstance(v, Parameter)}

        def _run(**values: Any) -> None:
            instance = cls()
            for attr, param in parameters.items():
                # click derives the Python-side parameter name by replacing
                # dashes with underscores and lowercasing ("--dry-run" -> dry_run),
                # so look values up under that normalized key.
                key = (param._name or attr).replace("-", "_").lower()
                setattr(instance, attr, values[key])
            return instance.run()

        # Decorate in reverse declaration order. click reverses the collected
        # params again when building the command, so the options end up in
        # declaration order.
        for param in reversed(parameters.values()):
            _run = param.to_click_decorator()(_run)
        cls.command = click.command(name=command_name)(_run)

    @abstractmethod
    def run(self) -> None:
        """Execute the command. Parameter values are already set on ``self``."""
        ...
# Type variable for "some subclass of Command". ``Parameter.__call__`` uses it
# to return the exact class it decorated.
SomeCommand = TypeVar("SomeCommand", bound=Command)
@contextmanager
def runin(location: str) -> Iterator[str]:
    """Print a message on entering and on leaving ``location``.

    Yields ``location`` unchanged. The "Leaving" message is printed even if
    the ``with`` body raises, so the entry and exit messages always pair up.
    """
    print(f"Traveling to {location=}")
    try:
        yield location
    finally:
        # Without try/finally, an exception in the body would skip this line.
        print(f"Leaving {location=}")
# Reusable ``--runin`` option (default "local"), meant to be applied as a class
# decorator on Command subclasses. For any location other than "local", the
# value is replaced by the result of entering ``runin(location)`` via
# ``map_resourcefully``. "local" goes through a ``nullcontext`` and comes back
# unchanged.
runin_capable = (
    Option(name="runin", help="run in a location", default="local")
    .map_resourcefully(
        lambda where: nullcontext(where) if where == "local" else runin(where)
    )
)
@runin_capable
class Example(Command):
    """Demo command with validated ``--name`` and ``--age`` options."""

    # --name: a missing value becomes "". The result must be 5 to 10
    # characters long.
    name = (
        Option(help="enter your name")
        .map(lambda raw: raw or "")
        .error_if(lambda text: len(text) < 5, "Too short")
        .error_if(lambda text: len(text) > 10, "Too long")
    )

    # --age: values below 18 are rejected.
    # NOTE(review): the default of 1 itself fails this check, so in practice
    # --age must be supplied.
    age = Option(help="enter your age", default=1).error_if(
        lambda years: years < 18, "Too young"
    )

    def run(self) -> None:
        """Print the processed name, then the processed age."""
        for value in (self.name, self.age):
            print(value)
if __name__ == "__main__":
    # Invoke the click command that Command.__init_subclass__ built for Example.
    main = Example.command
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment