Created
August 5, 2022 10:07
-
-
Save gmanny/c809f06c49a2ce96017be181a0c29c9e to your computer and use it in GitHub Desktop.
Python generic type argument resolution
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from types import GenericAlias
from typing import TypeVar, get_args, get_origin


def resolve_type_argument(query_type: type, target_type: type | GenericAlias, argument: TypeVar) -> type | TypeVar:
    """
    Resolve a single type parameter of the generic type ``query_type`` as supplied by ``target_type``.

    :param query_type: The generic class whose parameter is resolved, supplied without args
        (e.g. ``Mapping``, not ``Mapping[KT, VT]``).
    :param target_type: A subclass of ``query_type``, with or without args.
    :param argument: One of ``query_type.__parameters__``.
    :return: The type (or TypeVar) that ``target_type`` supplies for ``argument``.
    :raises KeyError: If ``argument`` is not a type parameter of ``query_type``.
    :raises ValueError: If ``target_type`` is not a subclass of ``query_type``
        (raised by ``resolve_type_arguments``).
    """
    # Non-generic classes have no __parameters__; treat them as having none.
    type_params = getattr(query_type, "__parameters__", ())
    type_arguments = resolve_type_arguments(query_type, target_type)
    params_to_args = dict(zip(type_params, type_arguments))
    try:
        return params_to_args[argument]
    except KeyError:
        raise KeyError(f"{argument!r} is not a type parameter of {query_type!r}") from None
def resolve_type_arguments(query_type: type, target_type: type | GenericAlias) -> tuple[type | TypeVar, ...]: | |
""" | |
This code was taken from https://stackoverflow.com/a/69862817/579817 | |
Resolves the type arguments of the query type as supplied by the target type of any of its bases. | |
Operates in a tail-recursive fashion, and drills through the hierarchy of generic base types breadth-first in left-to-right order to correctly identify the type arguments that need to be supplied to the next recursive call. | |
raises a TypeError if they target type was not an instance of the query type. | |
:param query_type: Must be supplied without args (e.g. Mapping not Mapping[KT,VT] | |
:param target_type: Must be supplied with args (e.g. Mapping[KT, T] or Mapping[str, int] not Mapping) | |
:return: A tuple of the arguments given via target_type for the type parameters of for the query_type, if it has any parameters, otherwise an empty tuple. These arguments may themselves be TypeVars. | |
""" | |
target_origin = get_origin(target_type) | |
if target_origin is None: | |
if target_type is query_type: | |
return target_type.__parameters__ | |
else: | |
target_origin = target_type | |
supplied_args = None | |
else: | |
supplied_args = get_args(target_type) | |
if target_origin is query_type: | |
return supplied_args | |
param_set = set() | |
param_list = [] | |
for each_base in target_origin.__orig_bases__: | |
each_origin = get_origin(each_base) | |
if each_origin is not None: | |
# each base is of the form class[T], which is a private type _GenericAlias, but it is formally documented to have __parameters__ | |
for each_param in each_base.__parameters__: | |
if each_param not in param_set: | |
param_set.add(each_param) | |
param_list.append(each_param) | |
if issubclass(each_origin, query_type): | |
if supplied_args is not None and len(supplied_args) > 0: | |
params_to_args = {key: value for (key, value) in zip(param_list, supplied_args)} | |
resolved_args = tuple(params_to_args[each] for each in each_base.__parameters__) | |
return resolve_type_arguments( | |
query_type, each_base[resolved_args] | |
) # each_base[args] fowards the args to each_base, it is not quite equivalent to GenericAlias(each_origin, resolved_args) | |
else: | |
return resolve_type_arguments(query_type, each_base) | |
elif issubclass(each_base, query_type): | |
return resolve_type_arguments(query_type, each_base) | |
if not issubclass(target_origin, query_type): | |
raise ValueError(f"{target_type} is not a subclass of {query_type}") | |
else: | |
return () |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from types import GenericAlias | |
from typing import Generic, TypeVar | |
from resolve_type_arguments import resolve_type_argument, resolve_type_arguments | |
# Type variables shared by the fixture classes. Their names matter: the tests
# compare resolved arguments against their str() form, e.g. "~T".
T = TypeVar("T")
U = TypeVar("U")
Q = TypeVar("Q")
R = TypeVar("R")
W = TypeVar("W")
X = TypeVar("X")
Y = TypeVar("Y")
Z = TypeVar("Z")


# Linear hierarchy: A <- B <- C <- D <- E <- F, with extra generic and
# non-generic bases mixed into B.
class A(Generic[T, U, Q, R]):
    """Generic root with four type parameters."""


class NestedA(Generic[T, U, Q]):
    """Three-parameter generic, used both as a base of B and as a type argument to A."""


class NestedB(Generic[T]):
    """Single-parameter generic base of B."""


class NoParams:
    """Non-generic base mixed into B."""


class B(NoParams, NestedA[U, Q, U], A[int, NestedA[Q, Q, Q], Q, U], NestedB[R]):
    """Generic over its own TypeVars, feeding them into A in a rearranged, nested form."""


class C(B[T, str, int]):
    """Binds two of B's parameters to str and int, leaving one free as T."""


class D(C[int]):
    """Binds C's remaining parameter to int."""


class E(D):
    """Plain subclass of D with no subscripted bases of its own."""


class F(E):
    """Plain subclass of E, two levels below the last subscripted base."""


# Second hierarchy: J reaches G through the plain base I, and H directly.
class G(Generic[T]):
    """Single-parameter generic."""


class H(Generic[T]):
    """Single-parameter generic."""


class I(G[int]):
    """Binds G's parameter to int."""


class J(I, H[str]):
    """Plain base I first, then H[str]."""
def test_resolve_type_arguments():
    """
    Table-driven checks for resolve_type_arguments, based on the examples in
    https://stackoverflow.com/a/69862817/579817

    Each case is (query_type, target_type, expected), where ``expected`` holds the
    str() of every resolved argument. Fixture classes print with their module
    prefix, so "__main__" in the expected text is rewritten to this module's
    ``__name__`` before comparing.
    """
    int_s = "<class 'int'>"
    str_s = "<class 'str'>"
    # Fully concrete results shared by D, E and F, which all inherit C[int].
    a_concrete = (int_s, "__main__.NestedA[str, str, str]", str_s, int_s)
    b_concrete = (int_s, str_s, int_s)
    cases = [
        (A, A, ("~T", "~U", "~Q", "~R")),
        (A, A[W, X, Y, Z], ("~W", "~X", "~Y", "~Z")),
        (A, B, (int_s, "__main__.NestedA[~Q, ~Q, ~Q]", "~Q", "~U")),
        (A, B[W, X, Y], (int_s, "__main__.NestedA[~X, ~X, ~X]", "~X", "~W")),
        (B, B, ("~U", "~Q", "~R")),
        (B, B[W, X, Y], ("~W", "~X", "~Y")),
        (A, C, (int_s, "__main__.NestedA[str, str, str]", str_s, "~T")),
        (A, C[W], (int_s, "__main__.NestedA[str, str, str]", str_s, "~W")),
        (B, C, ("~T", str_s, int_s)),
        (B, C[W], ("~W", str_s, int_s)),
        (C, C, ("~T",)),
        (C, C[W], ("~W",)),
        (A, D, a_concrete),
        (B, D, b_concrete),
        (C, D, (int_s,)),
        (D, D, ()),
        (A, E, a_concrete),
        (B, E, b_concrete),
        (C, E, (int_s,)),
        (D, E, ()),
        (E, E, ()),
        (A, F, a_concrete),
        (B, F, b_concrete),
        (C, F, (int_s,)),
        (D, F, ()),
        (E, F, ()),
        (F, F, ()),
        (G, J, (int_s,)),
    ]
    for query_type, target_type, expected in cases:
        resolved = resolve_type_arguments(query_type, target_type)
        # Rebuild the str() of a tuple: "()", "(x,)" or "(x, y, ...)".
        trailing = "," if len(expected) == 1 else ""
        wanted = f"({', '.join(expected)}{trailing})"
        assert str(resolved) == wanted.replace("__main__", __name__)
def test_resolve_type_argument():
    """
    Table-driven checks for resolve_type_argument.

    Every type parameter of the query type is resolved individually and its str()
    compared with the expected text ("__main__" is rewritten to this module's
    ``__name__`` because fixture classes print with their module prefix).
    """
    int_s = "<class 'int'>"
    str_s = "<class 'str'>"
    # Fully concrete results shared by D, E and F, which all inherit C[int].
    a_concrete = (int_s, "__main__.NestedA[str, str, str]", str_s, int_s)
    b_concrete = (int_s, str_s, int_s)
    cases = [
        (A, A, ("~T", "~U", "~Q", "~R")),
        (A, A[W, X, Y, Z], ("~W", "~X", "~Y", "~Z")),
        (A, B, (int_s, "__main__.NestedA[~Q, ~Q, ~Q]", "~Q", "~U")),
        (A, B[W, X, Y], (int_s, "__main__.NestedA[~X, ~X, ~X]", "~X", "~W")),
        (B, B, ("~U", "~Q", "~R")),
        (B, B[W, X, Y], ("~W", "~X", "~Y")),
        (A, C, (int_s, "__main__.NestedA[str, str, str]", str_s, "~T")),
        (A, C[W], (int_s, "__main__.NestedA[str, str, str]", str_s, "~W")),
        (B, C, ("~T", str_s, int_s)),
        (B, C[W], ("~W", str_s, int_s)),
        (C, C, ("~T",)),
        (C, C[W], ("~W",)),
        (A, D, a_concrete),
        (B, D, b_concrete),
        (C, D, (int_s,)),
        (D, D, ()),
        (A, E, a_concrete),
        (B, E, b_concrete),
        (C, E, (int_s,)),
        (D, E, ()),
        (E, E, ()),
        (A, F, a_concrete),
        (B, F, b_concrete),
        (C, F, (int_s,)),
        (D, F, ()),
        (E, F, ()),
        (F, F, ()),
        (G, J, (int_s,)),
    ]
    for query_type, target_type, expected in cases:
        parameters = query_type.__parameters__
        assert len(parameters) == len(expected)
        for parameter, wanted in zip(parameters, expected):
            resolved = resolve_type_argument(query_type, target_type, parameter)
            assert str(resolved) == wanted.replace("__main__", __name__)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.