Created October 25, 2025 22:31
-
-
Save csvance/97a32099729ca6f51aed0625f9642392 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| from typing import * | |
| import abc | |
class ArgsMatcher(abc.ABC):
    """Predicate over a call's positional arguments, used for dynamic dispatch.

    Subclasses decide whether a call matches (``__call__``) and report how
    specific a successful match is (``strength``); the dispatcher prefers
    entries whose strongest matcher reports a larger value.
    """

    @abc.abstractmethod
    def __call__(self, *args, **kwargs):
        """Return a truthy value when ``args`` satisfy this matcher."""
        raise NotImplementedError

    def strength(self):
        """Return the specificity of a successful match.

        Intentionally not abstract: a subclass without its own ``strength``
        can still be instantiated, and only fails when this is called.
        """
        raise NotImplementedError
class ArgsMatcherValue(ArgsMatcher):
    """Match when the positional argument at ``idx`` equals a fixed value.

    ``idx`` starts as ``None`` and is assigned when the matcher is registered
    through ``DynamicDispatcher.dispatch_decorator``. The default strength is
    very large, so an exact-value match normally outranks an
    ``ArgsMatcherType`` match (whose strength is an MRO length).
    """

    def __init__(self, v: Any, strength: int = 999999):
        self.idx: Optional[int] = None  # positional index, set at registration
        self.v = v
        self._strength = strength

    def __call__(self, *args, **kwargs):
        """Return ``args[idx] == v``; False when ``args`` has no position ``idx``."""
        try:
            arg = args[self.idx]
        except IndexError:
            # Called with fewer positional arguments than this matcher's
            # position: treat as "no match" so dispatch can try other entries.
            return False
        return arg == self.v

    def strength(self):
        """Return the strength given at construction."""
        return self._strength

    def __repr__(self):
        return f"ArgsMatcherValue({self.idx}, {self.v})"
class ArgsMatcherType(ArgsMatcher):
    """Match when the positional argument at ``idx`` is an instance of ``t``.

    ``idx`` starts as ``None`` and is assigned when the matcher is registered
    through ``DynamicDispatcher.dispatch_decorator``.
    """

    def __init__(self, t: Type):
        self.idx: Optional[int] = None  # positional index, set at registration
        self.t = t
        # The more specialized, the stronger the match: a subclass's MRO
        # contains its bases' MROs plus itself, so it is always longer.
        self._strength = len(self.t.__mro__)

    def __call__(self, *args, **kwargs):
        """Return ``isinstance(args[idx], t)``; False when ``args`` is too short."""
        try:
            arg = args[self.idx]
        except IndexError:
            # Fewer positional arguments than this matcher's position: no match.
            return False
        return isinstance(arg, self.t)

    def strength(self):
        """Return the length of ``t.__mro__``."""
        return self._strength

    def __repr__(self):
        return f"ArgsMatcherType({self.idx}, {self.t})"
class ArgsMatcherExpr(ArgsMatcher):
    """Match when a caller-supplied predicate over all positional args is truthy.

    Unlike the value/type matchers, ``expr`` receives every positional
    argument, so ``idx`` (assigned at registration) is not used for matching.
    The strength is supplied explicitly by the caller.
    """

    def __init__(self, expr: Callable, strength: int):
        self.idx: Optional[int] = None  # set at registration; unused by __call__
        self.expr = expr
        self._strength = strength

    def __call__(self, *args, **kwargs):
        # Keyword arguments are accepted for interface compatibility but are
        # not forwarded to ``expr``.
        return self.expr(*args)

    def strength(self):
        """Return the strength given at construction."""
        return self._strength

    def __repr__(self):
        # Show the stored strength value, not the bound ``strength`` method.
        return f"ArgsMatcherExpr({self.expr}, {self._strength})"
class DynamicDispatcher:
    """Choose and call a registered function by matching positional arguments.

    Each table entry pairs a sequence of matchers with a function. An entry
    applies to a call when every one of its matchers returns a truthy value
    for the call's positional arguments. Among applicable entries, the winner
    is the one whose strongest matcher has the highest ``strength()``. Ties on
    that value go to the entry with more matchers at that strength. Any tie
    that remains is reported as an error rather than resolved arbitrarily.
    """

    def __init__(self):
        # (matchers, fn) pairs, in registration order.
        self.table: List[Tuple[Sequence[ArgsMatcher], Callable]] = []

    def register(self, match_args, fn):
        """Append ``fn`` to the table, guarded by the matchers in ``match_args``."""
        self.table.append((match_args, fn))

    @staticmethod
    def _rank(matchers, args):
        """Return an entry's rank for ``args``, or None if the entry does not apply.

        The rank is ``(top_strength, count_at_top_strength)`` and compares
        lexicographically. An entry with no matchers applies to every call and
        ranks below every other applicable entry.
        """
        counts = {}
        for matcher in matchers:
            if not matcher(*args):
                return None  # stop at the first failing matcher
            strength = matcher.strength()
            counts[strength] = counts.get(strength, 0) + 1
        if not counts:
            return (float("-inf"), 0)
        top = max(counts)
        return (top, counts[top])

    def dispatch(self, *args):
        """Call the best-ranked applicable function with ``*args``.

        Raises:
            ValueError: several applicable entries share the best rank.
            NotImplementedError: no registered entry applies to ``args``.
        """
        best_rank = None
        best_fn: Optional[Callable] = None
        tied = []  # matcher sequences of every entry sharing best_rank
        for matchers, fn in self.table:
            rank = self._rank(matchers, args)
            if rank is None:
                continue
            if best_rank is None or rank > best_rank:
                best_rank, best_fn, tied = rank, fn, [matchers]
            elif rank == best_rank:
                tied.append(matchers)
        if len(tied) > 1:
            raise ValueError(
                "Non unique dynamic dispatch solution (cardinality=%d): %s"
                % (len(tied), tied)
            )
        if best_fn is None:
            raise NotImplementedError("No valid dynamic dispatch could be found.")
        return best_fn(*args)

    def dispatch_decorator(self, *args_matchers):
        """Return a decorator that registers a function in this dispatcher.

        ``args_matchers[i]`` is an iterable of matchers that all constrain
        positional argument ``i``. Each matcher is shallow-copied before its
        ``idx`` is set. One matcher instance can therefore be reused at several
        positions, or in several registrations, without later assignments
        overwriting earlier ones.

        The decorated name remains directly callable. The returned wrapper
        calls the original function and does not go through the dispatcher.
        """
        import copy
        import functools

        def decorator(func):
            @functools.wraps(func)
            def wrapper(*args, **kwargs):
                return func(*args, **kwargs)

            bound_matchers = []
            for position, arg_matchers in enumerate(args_matchers):
                for arg_matcher in arg_matchers:
                    bound = copy.copy(arg_matcher)
                    bound.idx = position
                    bound_matchers.append(bound)
            self.register(bound_matchers, func)
            return wrapper

        return decorator
class AddDynamicDispatcher(DynamicDispatcher):
    """Binary front-end: calling the instance dispatches on the pair ``(x, y)``."""

    # NOTE: memoizing with ``functools.lru_cache`` on ``dispatch`` was sketched
    # here as a dead string literal. It would cache the chosen function's
    # return value, require hashable arguments, and keep ``self`` alive for
    # the cache's lifetime (ruff B019). Resolution therefore runs on every call.

    def __call__(self, x, y):
        pair = (x, y)
        return self.dispatch(*pair)
# Shared dispatcher; each decorated function below adds one entry to it.
add = AddDynamicDispatcher()


@add.dispatch_decorator([ArgsMatcherType(str)], [ArgsMatcherType(str)])
def add_str_str(x, y):
    """Concatenate two strings."""
    joined = f"{x}{y}"
    return joined


# NOTE(review): this reuses the name ``add_str_str``. The module-level name now
# refers to this function, and the plain str+str one above is reachable only
# through ``add``. Consider a distinct name such as ``add_abc_str``.
@add.dispatch_decorator([ArgsMatcherValue("ABC")], [ArgsMatcherType(str)])
def add_str_str(x, y):
    """Concatenate ``x`` with ``y`` upper-cased (registered for ``x == "ABC"``)."""
    joined = f"{x}{y.upper()}"
    return joined


@add.dispatch_decorator([ArgsMatcherType(int)], [ArgsMatcherType(int)])
def add_int_int(x, y):
    """Add two integers."""
    total = x + y
    return total


print(add("abc", "def"))  # str+str entry -> abcdef
print(add(1, 1))  # int+int entry -> 2
print(add("ABC", "def"))  # value match outranks type match -> ABCDEF
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment