Created
July 3, 2020 13:58
-
-
Save saulshanabrook/e354f6d5a7fcbb9fea2494ab95c364ae to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Functions and tools to unify types.

Note that this uses dicts as ordered sets in many places so that input and output
have the same ordering, which keeps git diffs small.
""" | |
from __future__ import annotations | |
import typing | |
import collections | |
import pydantic | |
import abc | |
import ast | |
# Public API of this module (what ``from module import *`` exposes).
__all__ = [
    # Entry points
    "create_type",
    "OutputType",
    "unify",
    "annotation",
    # Concrete output types
    "NoneOutput",
    "StringOutput",
    "ListOutput",
    "TupleOutput",
    "DictOutput",
    "ObjectOutput",
    "OtherOutput",
    "ModuleOutput",
    "SliceOutput",
    "TypeOutput",
    "FunctionOutput",
    "MethodWithoutSelfOutput",
    "MethodOutput",
    "ClassMethodOutput",
    "UnionOutput",
    "BottomOutput",
]
# A union with more distinct members than this is widened to ``object``
MAX_UNION_ITEMS = 5

# A string literal with more distinct values than this is widened to ``str``
MAX_STRING_ITEMS = 5

# More than this and it will be tuple of arbitrary length
MAX_TUPLE_ITEMS = 2
def annotation(tp: OutputType) -> typing.Optional[ast.AST]:
    """
    Return the AST annotation for ``tp``, or ``None`` when ``tp`` is unknown.

    "Unknown" is whatever ``is_unknown`` reports as such.
    """
    return None if is_unknown(tp) else tp.annotation
def create_type(o: object) -> OutputType:
    """
    Validate ``o`` against ``InputType`` with pydantic and convert it to an output type.

    Raises:
        ValueError: carrying ``o`` when pydantic validation fails. The original
            validation error is attached as ``__cause__`` so the reason for the
            rejection is not lost.
    """
    try:
        tp = pydantic.parse_obj_as(InputType, o)  # type: ignore
    except pydantic.error_wrappers.ValidationError as err:
        raise ValueError(o) from err
    return to_output(tp)
def unify_inputs(types: typing.Iterable[InputType]) -> OutputType:
    """Convert each input type to its output form and unify the results."""
    outputs = (to_output(tp) for tp in types)
    return unify(outputs)
def unify(types: typing.Iterable[OutputType]) -> OutputType:
    """
    Merge several output types into a single output type.

    Unions in the input are flattened into their options. The remaining types
    are grouped by class and each group is collapsed with that class's own
    ``unify``. The result is:

    * ``ObjectOutput`` as soon as any ``ObjectOutput`` is encountered, or when
      more than ``MAX_UNION_ITEMS`` distinct types remain;
    * ``BottomOutput`` when nothing remains;
    * the single remaining type when there is exactly one;
    * a ``UnionOutput`` over the remaining types otherwise.
    """
    # Bucket types by class; the inner dicts are used as ordered sets
    by_kind: typing.DefaultDict[
        typing.Type[OutputType], typing.Dict[OutputType, None]
    ] = collections.defaultdict(dict)

    # Worked as a stack: items are taken from the end of the list
    pending = list(types)
    while pending:
        current = pending.pop()
        if isinstance(current, ObjectOutput):
            return ObjectOutput()
        if isinstance(current, UnionOutput):
            # Push the members of the union back so they get bucketed themselves
            pending.extend(current.options)
        else:
            by_kind[type(current)][current] = None

    # by_kind now holds no unions. Collapse every bucket and collect the
    # results, again using a dict as an ordered set.
    merged: typing.Dict[OutputType, None] = {}
    for kind, members in by_kind.items():
        result: OutputType = _unify_output_tp(kind, members.keys())
        # A per-class unify may itself hand back a union; flatten it
        parts = result.options if isinstance(result, UnionOutput) else (result,)
        for part in parts:
            merged[part] = None

    if not merged:
        return BottomOutput()
    if len(merged) > MAX_UNION_ITEMS:
        return ObjectOutput()
    if len(merged) == 1:
        return next(iter(merged))
    return UnionOutput(options=tuple(merged))
# Ties the class handed to ``_unify_output_tp`` to the instances it receives
OUTPUT_TYPE = typing.TypeVar("OUTPUT_TYPE", bound="OutputTypeBase")


def _unify_output_tp(
    cls: typing.Type[OUTPUT_TYPE], tps: typing.Iterable[OUTPUT_TYPE]
) -> OutputType:
    """Call ``cls.unify(tps)``; exists only so the call can be typed."""
    unified = cls.unify(tps)
    return unified
class BaseModel(pydantic.BaseModel):
    """Pydantic base model that forbids extra fields and is hashable."""

    class Config:
        extra = "forbid"

    # https://github.com/samuelcolvin/pydantic/issues/1303#issuecomment-599712964
    def __hash__(self):
        # Hash on the concrete class plus every field value, in field order
        field_values = tuple(self.__dict__.values())
        return hash((type(self), *field_values))
class OutputTypeBase(BaseModel, abc.ABC):
    """Abstract base for every output type."""

    @classmethod
    @abc.abstractmethod
    def unify(
        cls: typing.Type[OUTPUT_TYPE], tps: typing.Iterable[OUTPUT_TYPE]
    ) -> OutputType:
        """Collapse several instances of this class into one output type."""

    @property
    @abc.abstractmethod
    def annotation(self) -> ast.AST:
        """AST node used as the annotation for this type."""

    @property
    def module(self) -> typing.Optional[str]:
        """Module associated with this type; ``None`` unless a subclass overrides it."""
        return None
class InputTypeBase(BaseModel, abc.ABC):
    """Abstract base for every input type; each one converts to an output type."""

    @abc.abstractmethod
    def to_output(self) -> OutputType:
        """Return the output type corresponding to this input."""
def to_output(tp: InputType) -> OutputType:
    """
    Convert an input type to its output type.

    ``None`` maps to ``NoneOutput()``; anything else is asked to convert itself
    through its own ``to_output`` method.
    """
    return NoneOutput() if tp is None else tp.to_output()
def is_unknown(tp: object) -> bool:
    """Return True when ``tp`` is an ``ObjectOutput`` or a ``BottomOutput``."""
    unknown_kinds = (ObjectOutput, BottomOutput)
    return isinstance(tp, unknown_kinds)
class NoneOutput(OutputTypeBase):
    """Output type for ``None``."""

    type: typing.Literal["None"] = "None"

    @classmethod
    def unify(cls, tps: typing.Iterable[NoneOutput]) -> NoneOutput:
        # Instances carry no data, so the inputs are not inspected
        return NoneOutput()

    @property
    def annotation(self):
        """The constant ``None`` as an AST node."""
        return ast.Constant(value=None, kind=None)
class StringInput(InputTypeBase):
    """Input that is a bare string value."""

    __root__: str

    def to_output(self) -> StringOutput:
        """Wrap the string as the single literal option of a ``StringOutput``."""
        literal = self.__root__
        return StringOutput(options=[literal])
class StringOutput(OutputTypeBase):
    """
    A string type, optionally narrowed to a set of literal values.

    When ``options`` is ``None`` or empty the annotation is the name ``str``;
    otherwise it is ``Literal[...]`` subscripted with the options.
    """

    type: typing.Literal["str"] = "str"
    # Known literal values, or None for an arbitrary string
    options: typing.Union[typing.Tuple[str, ...], None] = None

    @property
    def annotation(self) -> ast.AST:
        """AST for ``str`` or for ``Literal[<options>]``."""
        if not self.options:
            return ast.Name("str", ast.Load())
        return ast.Subscript(
            ast.Name("Literal", ast.Load()),
            ast.Index(
                ast.Tuple([ast.Constant(s, None) for s in self.options], ast.Load())
            ),
            ast.Load(),
        )

    @classmethod
    def unify(cls, tps: typing.Iterable[StringOutput]) -> StringOutput:
        """
        Union the literal options of all ``tps``.

        Falls back to a plain ``StringOutput()`` when any input is unrestricted
        or when more than ``MAX_STRING_ITEMS`` distinct options accumulate.
        """
        # dict used as an ordered set so option order follows first appearance
        options: typing.Dict[str, None] = {}
        for tp in tps:
            if tp.options is None:
                return StringOutput()
            for o in tp.options:
                options[o] = None
        if len(options) > MAX_STRING_ITEMS:
            return StringOutput()
        return StringOutput(options=tuple(options.keys()))
class ListInput(InputTypeBase):
    """Input that is a list of input types."""

    __root__: typing.List[InputType]

    def to_output(self) -> ListOutput:
        """Unify the element types into the single item type of a ``ListOutput``."""
        return ListOutput(item=unify_inputs(self.__root__))
class ListOutput(OutputTypeBase):
    """A list whose elements share the single type ``item``."""

    type: typing.Literal["list"] = "list"
    item: OutputType

    @property
    def annotation(self) -> ast.AST:
        """AST for ``list`` when the item type is unknown, else ``List[<item>]``."""
        if is_unknown(self.item):
            return ast.Name("list", ast.Load())
        generic = ast.Name("List", ast.Load())
        element = ast.Index(self.item.annotation)
        return ast.Subscript(generic, element, ast.Load())

    @classmethod
    def unify(cls, tps: typing.Iterable[ListOutput]) -> ListOutput:
        """Unify the item types of all ``tps`` into one list type."""
        item_types = (tp.item for tp in tps)
        return ListOutput(item=unify(item_types))
class TupleInput(InputTypeBase):
    """Input tagged ``"tuple"`` holding the input types of its elements."""

    t: typing.Literal["tuple"]
    v: typing.List[InputType]

    def to_output(self) -> TupleOutput:
        """Convert each element and keep them positionally in a ``TupleOutput``."""
        return TupleOutput(items=[to_output(element) for element in self.v])
class TupleOutput(OutputTypeBase):
    """
    A tuple type.

    ``items`` is either a tuple of per-position types (fixed length) or a
    single type, meaning a tuple of arbitrary length whose elements all have
    that type.
    """

    type: typing.Literal["tuple"] = "tuple"
    # If just one item then tuple of arbitrary length of all the same type
    items: typing.Union[OutputType, typing.Tuple[OutputType, ...]]

    @property
    def annotation(self) -> ast.AST:
        """
        AST for ``tuple`` (unknown element type), ``Tuple[a, b, ...]`` (fixed
        length) or ``Tuple[a, ...]`` with a literal ellipsis (arbitrary length).
        """
        if is_unknown(self.items):
            return ast.Name("tuple", ast.Load())
        if isinstance(self.items, tuple):
            elements = [s.annotation for s in self.items]
        else:
            elements = [self.items.annotation, ast.Constant(..., None)]
        return ast.Subscript(
            ast.Name("Tuple", ast.Load()),
            ast.Index(ast.Tuple(elements, ast.Load())),
            ast.Load(),
        )

    @classmethod
    def unify(cls, tps: typing.Iterable[TupleOutput]) -> TupleOutput:
        """
        Unify several tuple types.

        When every input is a fixed-length tuple of the same length (and that
        length is within the limit) the result is unified position by position.
        Otherwise all element types are pooled into one arbitrary-length tuple.
        """
        # tps is traversed more than once below; materialize it so a one-shot
        # iterator is not silently seen as empty on the second pass.
        tps = list(tps)
        lengths = {len(tp.items) if isinstance(tp.items, tuple) else None for tp in tps}
        # length stays None unless all inputs are fixed-length with one shared length
        length = None if None in lengths or len(lengths) != 1 else lengths.pop()
        # NOTE(review): MAX_TUPLE_ITEMS is described as the limit for this case
        # but is unused; MAX_UNION_ITEMS is what is applied here. Confirm intent.
        if length is not None and length > MAX_UNION_ITEMS:
            length = None

        if length is None:
            possible_values: typing.List[OutputType] = []
            for tp in tps:
                items = tp.items
                if isinstance(items, tuple):
                    possible_values.extend(items)
                else:
                    possible_values.append(items)
            return TupleOutput(items=unify(possible_values))

        # All inputs are tuples of the same length: transpose to per-position
        # columns and unify each column separately.
        columns = zip(*(tp.items for tp in tps))
        return TupleOutput(items=tuple(unify(column) for column in columns))
class DictInput(InputTypeBase):
    """Input tagged ``"dict"`` holding (key, value) pairs of input types."""

    t: typing.Literal["dict"]
    v: typing.List[typing.Tuple[InputType, InputType]]

    def to_output(self) -> DictOutput:
        """Unify all keys and all values separately; an empty dict gets bottom types."""
        if not self.v:
            return DictOutput(key=BottomOutput(), value=BottomOutput())
        # Transpose the pairs into one sequence of keys and one of values
        keys, values = zip(*self.v)
        return DictOutput(key=unify_inputs(keys), value=unify_inputs(values))
class DictOutput(OutputTypeBase):
    """A dict type with one unified key type and one unified value type."""

    type: typing.Literal["dict"] = "dict"
    key: OutputType
    value: OutputType

    @property
    def annotation(self) -> ast.AST:
        """AST for ``dict`` when key and value are both unknown, else ``Dict[k, v]``."""
        if is_unknown(self.key) and is_unknown(self.value):
            return ast.Name("dict", ast.Load())
        return ast.Subscript(
            ast.Name("Dict", ast.Load()),
            ast.Index(
                ast.Tuple([self.key.annotation, self.value.annotation], ast.Load())
            ),
            ast.Load(),
        )

    @classmethod
    def unify(cls, tps: typing.Iterable[DictOutput]) -> DictOutput:
        """Unify the key types and the value types of all ``tps`` independently."""
        # tps is read twice (keys, then values); materialize it so a one-shot
        # iterator does not leave the value side empty.
        tps = list(tps)
        return DictOutput(
            key=unify(tp.key for tp in tps), value=unify(tp.value for tp in tps)
        )
class ObjectOutput(OutputTypeBase):
    """The ``object`` type; ``unify`` returns it as soon as it sees one."""

    type: typing.Literal["object"] = "object"

    @property
    def annotation(self) -> ast.AST:
        """AST for the name ``object``."""
        return ast.Name(id="object", ctx=ast.Load())

    @classmethod
    def unify(cls, tps: typing.Iterable[ObjectOutput]) -> ObjectOutput:
        # Instances carry no data, so the inputs are not inspected
        return ObjectOutput()
class BuiltinNamedInput(BaseModel):
    """A name given as a bare string, with no module."""

    __root__: str


class ModuleNamedInput(BaseModel):
    """A name qualified by the module it lives in."""

    module: str
    name: str


# Either form of name; pydantic tries the members in this order
NamedInput = typing.Union[BuiltinNamedInput, ModuleNamedInput]
class OtherInputType(InputTypeBase):
    """Input carrying only a named type in ``t``."""

    t: NamedInput

    def to_output(self) -> typing.Union[OtherOutput, ObjectOutput]:
        """Delegate to ``OtherOutput.safe_create`` with the named type."""
        named = self.t
        return OtherOutput.safe_create(named)
class NamedOutput(BaseModel):
    """A name, optionally qualified by a module."""

    # None if builtin
    module: typing.Optional[str] = None
    name: str

    @classmethod
    def from_input(cls, input: NamedInput) -> typing.Optional[NamedOutput]:
        """
        Build a ``NamedOutput`` from a named input.

        Returns ``None`` for a module-qualified name containing ``"<"``.
        """
        if isinstance(input, BuiltinNamedInput):
            return cls(name=input.__root__)
        return None if "<" in input.name else cls(name=input.name, module=input.module)

    @property
    def annotation(self) -> ast.AST:
        """AST for the bare name, or for ``<module>.<name>`` when a module is set."""
        if self.module is not None:
            owner = ast.Name(self.module, ast.Load())
            return ast.Attribute(owner, self.name, ast.Load())
        return ast.Name(self.name, ast.Load())
class OtherOutput(OutputTypeBase):
    """Output type identified only by its (optionally module-qualified) name."""

    type: NamedOutput

    @classmethod
    def safe_create(cls, tp_i: NamedInput) -> typing.Union[OtherOutput, ObjectOutput]:
        """Create from a named input; falls back to ``ObjectOutput`` when it has no usable name."""
        named = NamedOutput.from_input(tp_i)
        return ObjectOutput() if named is None else cls(type=named)

    @property
    def annotation(self) -> ast.AST:
        """The annotation of the underlying name."""
        return self.type.annotation

    @classmethod
    def unify(
        cls, tps: typing.Iterable[OtherOutput]
    ) -> typing.Union[OtherOutput, UnionOutput]:
        """Deduplicate ``tps``; a single survivor is returned as is, otherwise a union."""
        # dict as ordered set: first-seen order, duplicates dropped
        distinct = dict.fromkeys(tps)
        if len(distinct) == 1:
            return next(iter(distinct))
        return UnionOutput(options=tuple(distinct))

    @property
    def module(self) -> typing.Optional[str]:
        """Module of the underlying name."""
        return self.type.module
class ModuleInput(InputTypeBase):
    """Input tagged ``"module"``; ``v`` becomes the module output's name."""

    t: typing.Literal["module"]
    v: str

    def to_output(self) -> ModuleOutput:
        """Return a ``ModuleOutput`` named after ``v``."""
        module_name = self.v
        return ModuleOutput(name=module_name)
class ModuleOutput(OutputTypeBase):
    """A module object; ``name`` is ``None`` when no single module is known."""

    type: typing.Literal["module"] = "module"
    name: typing.Optional[str] = None

    @property
    def annotation(self) -> ast.AST:
        """AST for ``types.ModuleType``, regardless of ``name``."""
        owner = ast.Name("types", ast.Load())
        return ast.Attribute(owner, "ModuleType", ast.Load())

    @classmethod
    def unify(cls, tps: typing.Iterable[ModuleOutput]) -> ModuleOutput:
        """Return the one distinct module if there is exactly one, else an unnamed module."""
        distinct = set(tps)
        return distinct.pop() if len(distinct) == 1 else ModuleOutput()

    @property
    def module(self) -> typing.Optional[str]:
        """The module's own name."""
        return self.name
class SliceInput(InputTypeBase):
    """Input tagged ``"slice"`` holding the input types of start, stop and step."""

    t: typing.Literal["slice"]
    v: SliceInputValue

    def to_output(self) -> SliceOutput:
        """Convert the three components individually."""
        parts = self.v
        start = to_output(parts.start)
        stop = to_output(parts.stop)
        step = to_output(parts.step)
        return SliceOutput(start=start, stop=stop, step=step)


class SliceInputValue(BaseModel):
    """The three components of a slice, each as an input type."""

    start: InputType
    stop: InputType
    step: InputType
class SliceOutput(OutputTypeBase):
    """A slice type with separate types for start, stop and step."""

    type: typing.Literal["slice"] = "slice"
    start: OutputType
    stop: OutputType
    step: OutputType

    @property
    def annotation(self) -> ast.AST:
        """
        Doesn't exist yet as generic type, but should
        https://github.com/python/typing/issues/159
        """
        components = (self.start, self.stop, self.step)
        if all(is_unknown(component) for component in components):
            return ast.Name("slice", ast.Load())
        component_annotations = [component.annotation for component in components]
        return ast.Subscript(
            ast.Name("slice", ast.Load()),
            ast.Index(ast.Tuple(component_annotations, ast.Load())),
            ast.Load(),
        )

    @classmethod
    def unify(cls, tps: typing.Iterable[SliceOutput]) -> SliceOutput:
        """Unify the start, stop and step types of all ``tps`` independently."""
        slices = list(tps)
        starts: typing.List[OutputType] = [tp.start for tp in slices]
        stops: typing.List[OutputType] = [tp.stop for tp in slices]
        steps: typing.List[OutputType] = [tp.step for tp in slices]
        return SliceOutput(start=unify(starts), step=unify(steps), stop=unify(stops))
class TypeInput(InputTypeBase):
    """Input tagged ``"type"``; ``v`` names the class."""

    t: typing.Literal["type"]
    v: NamedInput

    def to_output(self) -> TypeOutput:
        """Return a ``TypeOutput`` for the named class."""
        named = NamedOutput.from_input(self.v)
        return TypeOutput(name=named)
class TypeOutput(OutputTypeBase):
    """A class object, optionally a specific named class."""

    type: typing.Literal["type"] = "type"
    # If none, then any type
    name: typing.Optional[NamedOutput] = None

    @property
    def annotation(self) -> ast.AST:
        """AST for ``type`` when no name is set, else ``Type[<name>]``."""
        if self.name is None:
            return ast.Name("type", ast.Load())
        generic = ast.Name("Type", ast.Load())
        return ast.Subscript(generic, ast.Index(self.name.annotation), ast.Load())

    @classmethod
    def unify(cls, tps: typing.Iterable[TypeOutput]) -> TypeOutput:
        """Keep the name only when every input agrees on it."""
        names = {tp.name for tp in tps}
        if len(names) != 1:
            return TypeOutput()
        return TypeOutput(name=names.pop())

    @property
    def module(self) -> typing.Optional[str]:
        """Module of the named class, or ``None`` when there is no name."""
        return self.name.module if self.name else None
class FunctionInput(InputTypeBase):
    """Input tagged ``"function"``; ``v`` names the function."""

    t: typing.Literal["function"]
    v: NamedInput

    def to_output(self) -> typing.Union[FunctionOutput, MethodWithoutSelfOutput]:
        """
        Convert to a function output.

        A dotted name is treated as ``<class>.<method>`` and becomes a
        ``MethodWithoutSelfOutput``; a name that ``NamedOutput.from_input``
        rejects becomes an anonymous ``FunctionOutput``.
        """
        name = NamedOutput.from_input(self.v)
        # We are in some lambda
        if not name:
            return FunctionOutput()
        if "." in name.name:
            # For some reason happens with MaskedArray.mean
            # Split on the last dot only, so a name with several dots
            # ("Outer.Inner.method") does not fail to unpack.
            classname, methodname = name.name.rsplit(".", 1)
            return MethodWithoutSelfOutput(
                name=methodname, class_=NamedOutput(name=classname, module=name.module)
            )
        return FunctionOutput(name=name)
class MethodWithoutSelfOutput(OutputTypeBase):
    """A method accessed on its class, identified by the class and method name."""

    type: typing.Literal["method_no_self"] = "method_no_self"
    class_: NamedOutput
    name: str

    @property
    def annotation(self) -> ast.AST:
        """AST for the name ``Callable``."""
        return ast.Name("Callable", ast.Load())

    @classmethod
    def unify(
        cls, tps: typing.Iterable[MethodWithoutSelfOutput]
    ) -> typing.Union[MethodWithoutSelfOutput, FunctionOutput]:
        """Return the one distinct method if there is exactly one, else any function."""
        tps = set(tps)
        if len(tps) == 1:
            return tps.pop()
        return FunctionOutput()

    @property
    def module(self) -> typing.Optional[str]:
        """Module of the owning class."""
        # Previously returned class_.name (the class name); the module lives in
        # class_.module, matching ClassMethodOutput.module.
        return self.class_.module
class FunctionOutput(OutputTypeBase):
    """A function, optionally a specific named one."""

    type: typing.Literal["function"] = "function"
    # If none, then any function
    name: typing.Optional[NamedOutput] = None

    @property
    def annotation(self) -> ast.AST:
        """AST for the name ``Callable``."""
        return ast.Name(id="Callable", ctx=ast.Load())

    @classmethod
    def unify(cls, tps: typing.Iterable[FunctionOutput]) -> FunctionOutput:
        """Keep the name only when every input agrees on it."""
        names = {tp.name for tp in tps}
        if len(names) != 1:
            return FunctionOutput()
        return FunctionOutput(name=names.pop())

    @property
    def module(self) -> typing.Optional[str]:
        """Module of the named function, or ``None`` when there is no name."""
        return self.name.module if self.name else None
class BuiltinFunctionInput(InputTypeBase):
    """Input tagged ``"builtin_function_or_method"`` whose value is a name."""

    t: typing.Literal["builtin_function_or_method"] = "builtin_function_or_method"
    v: NamedInput

    def to_output(self) -> FunctionOutput:
        """Return a ``FunctionOutput`` for the named builtin."""
        named = NamedOutput.from_input(self.v)
        return FunctionOutput(name=named)
class BuiltinMethodInput(InputTypeBase):
    """Input tagged ``"builtin_function_or_method"`` whose value is a bound method."""

    t: typing.Literal["builtin_function_or_method"] = "builtin_function_or_method"
    v: MethodInputValue

    def to_output(self) -> MethodOutput:
        """Return a ``MethodOutput`` with the receiver converted to an output type."""
        bound = self.v
        return MethodOutput(name=bound.name, self=to_output(bound.self))


class MethodInputValue(BaseModel):  # type: ignore
    """Method name plus the input type of the object it is bound to."""

    name: str
    self: InputType
class MethodOutput(OutputTypeBase):  # type: ignore
    """A bound method: its name and the type of the object it is bound to."""

    type: typing.Literal["method"] = "method"
    name: str
    self: OutputType

    @property
    def annotation(self) -> ast.AST:
        """AST for the name ``Callable``."""
        return ast.Name(id="Callable", ctx=ast.Load())

    @classmethod
    def unify(
        cls, tps: typing.Iterable[MethodOutput]
    ) -> typing.Union[MethodOutput, FunctionOutput]:
        """Return the one distinct method if there is exactly one, else any function."""
        distinct = set(tps)
        return distinct.pop() if len(distinct) == 1 else FunctionOutput()

    @property
    def module(self) -> typing.Optional[str]:
        """Module of the receiver's type."""
        receiver = self.self
        return receiver.module
class MethodInput(InputTypeBase):
    """Input tagged ``"method"`` whose value is a bound method."""

    t: typing.Literal["method"]
    v: MethodInputValue

    def to_output(self) -> MethodOutput:
        """Return a ``MethodOutput`` with the receiver converted to an output type."""
        bound = self.v
        return MethodOutput(name=bound.name, self=to_output(bound.self))
class MethodDescriptorInput(InputTypeBase):
    """Input tagged ``"method_descriptor"``: a method name and its owning class."""

    t: typing.Literal["method_descriptor"]
    v: MethodDescriptorInputValue

    def to_output(self) -> ClassMethodOutput:
        """Return a ``ClassMethodOutput`` using the name of the converted owning class."""
        descriptor = self.v
        owner = descriptor.class_.to_output().name
        return ClassMethodOutput(name=descriptor.name, class_=owner)


class MethodDescriptorInputValue(BaseModel):
    """Method name plus owning class; the class is read from the ``"class"`` key."""

    name: str
    class_: TypeInput = pydantic.Field(alias="class")
class ClassMethodOutput(OutputTypeBase):  # type: ignore
    """A method identified by its owning class and its name."""

    type: typing.Literal["classmethod"] = "classmethod"
    class_: NamedOutput
    name: str

    @property
    def annotation(self) -> ast.AST:
        """AST for the name ``Callable``."""
        return ast.Name(id="Callable", ctx=ast.Load())

    @classmethod
    def unify(
        cls, tps: typing.Iterable[ClassMethodOutput]
    ) -> typing.Union[ClassMethodOutput, FunctionOutput]:
        """Return the one distinct method if there is exactly one, else any function."""
        distinct = set(tps)
        return distinct.pop() if len(distinct) == 1 else FunctionOutput()

    @property
    def module(self) -> typing.Optional[str]:
        """Module of the owning class."""
        owner = self.class_
        return owner.module
class NumpyUfuncInput(InputTypeBase):
    """Input whose type tag is ``numpy.ufunc``; ``v`` is the function name."""

    t: NumpyUfuncInputType
    v: str

    def to_output(self) -> FunctionOutput:
        """Return a ``FunctionOutput`` named ``v`` in module ``numpy``."""
        named = NamedOutput(module="numpy", name=self.v)
        return FunctionOutput(name=named)


class NumpyUfuncInputType(BaseModel):
    """Type tag that only matches module ``numpy``, name ``ufunc``."""

    module: typing.Literal["numpy"]
    name: typing.Literal["ufunc"]
class NumpyConvert2MAInput(InputTypeBase):
    """Input whose type tag is ``numpy.ma.core._convert2ma``; ``v`` is the function name."""

    t: NumpyConvert2MAInputType
    v: str

    def to_output(self) -> FunctionOutput:
        """Return a ``FunctionOutput`` named ``v`` in module ``numpy.ma``."""
        named = NamedOutput(module="numpy.ma", name=self.v)
        return FunctionOutput(name=named)


class NumpyConvert2MAInputType(BaseModel):
    """Type tag that only matches module ``numpy.ma.core``, name ``_convert2ma``."""

    module: typing.Literal["numpy.ma.core"]
    name: typing.Literal["_convert2ma"]
class NumpyNDArrayInput(InputTypeBase):
    """Input with a module-qualified type tag and a value carrying a ``dtype``."""

    t: ModuleNamedInput
    v: NumpyNDArrayValue

    def to_output(self) -> typing.Union[OtherOutput, ObjectOutput]:
        """Convert using the type tag only; the dtype is not used."""
        tag = self.t
        return OtherOutput.safe_create(tag)


class NumpyNDArrayValue(BaseModel):
    """Value payload holding the dtype as a string."""

    dtype: str
class NumpyDTypeInput(InputTypeBase):
    """Input with a module-qualified type tag and a string value."""

    t: ModuleNamedInput
    v: str

    def to_output(self) -> typing.Union[OtherOutput, ObjectOutput]:
        """Convert using the type tag only; the string value is not used."""
        tag = self.t
        return OtherOutput.safe_create(tag)
class UnionOutput(OutputTypeBase):
    """A union of several output types."""

    type: typing.Literal["union"] = "union"
    # Can't be sets because when serializing, serialized as dicts
    options: typing.Tuple[OutputType, ...]

    @property
    def annotation(self) -> ast.AST:
        """AST for ``Union[<options>]``."""
        members = [option.annotation for option in self.options]
        return ast.Subscript(
            ast.Name("Union", ast.Load()),
            ast.Index(ast.Tuple(members, ast.Load())),
            ast.Load(),
        )

    @classmethod
    def unify(cls, unions: typing.Iterable[UnionOutput]) -> OutputType:
        # This should never be called
        raise NotImplementedError()
class BottomOutput(OutputTypeBase):
    """
    Like any but represents an unknown type, so a union with it is always the other
    """

    type: typing.Literal["bottom"] = "bottom"

    @classmethod
    def unify(cls, tps: typing.Iterable[BottomOutput]) -> UnionOutput:
        # An empty union contributes no options when the caller flattens it
        no_options = UnionOutput(options=[])
        return no_options

    @property
    def annotation(self) -> ast.AST:
        """AST for the name ``object``."""
        return ast.Name(id="object", ctx=ast.Load())
# These are closed unions rather than subclass hierarchies: pydantic walks the
# members in the order listed to find the one that validates, so the order
# below is significant.
InputType = typing.Union[
    None,
    StringInput,
    ListInput,
    TupleInput,
    DictInput,
    OtherInputType,
    ModuleInput,
    SliceInput,
    TypeInput,
    FunctionInput,
    BuiltinFunctionInput,
    BuiltinMethodInput,
    MethodInput,
    MethodDescriptorInput,
    NumpyUfuncInput,
    NumpyConvert2MAInput,
    NumpyNDArrayInput,
    NumpyDTypeInput,
]

# Every concrete output type
OutputType = typing.Union[
    NoneOutput,
    StringOutput,
    ListOutput,
    TupleOutput,
    DictOutput,
    ObjectOutput,
    OtherOutput,
    ModuleOutput,
    SliceOutput,
    TypeOutput,
    FunctionOutput,
    MethodWithoutSelfOutput,
    MethodOutput,
    ClassMethodOutput,
    UnionOutput,
    BottomOutput,
]
# Resolve the string annotations on every model now that all names exist.
# The two value models are listed explicitly because they are not members of
# either union.
for cls in (InputType.__args__ + OutputType.__args__) + (  # type: ignore
    SliceInputValue,
    MethodInputValue,
):
    # InputType includes None, whose type has no forward refs to update
    if cls is not type(None):
        cls.update_forward_refs()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment