Created
May 14, 2019 15:18
-
-
Save ckarnell/23e43ea1155adbf6121fd2f1695d106c to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from typing import Any, Callable, List, Optional, Type, TypeVar

from mypy.plugin import Plugin, FunctionContext  # pylint: disable=no-name-in-module
from mypy.types import CallableType, Instance, TypedDictType  # pylint: disable=no-name-in-module
A = TypeVar('A')


def safe_typeddict_cast(_: Type[A], data: Any) -> A:
    """Return *data* unchanged, statically typed as the first argument's type.

    At runtime this is the identity function on ``data``; the first argument
    is ignored. All the value lives in ``TypedDictCastPlugin``: during type
    checking it compares ``data``'s TypedDict type with the target type and
    only converts the type when the checks pass, otherwise reporting
    missing keys or incompatible key types.
    """
    result = data
    return result
class TypedDictCastPlugin(Plugin):
    """mypy plugin that type-checks calls to ``safe_typeddict_cast``.

    For a call ``safe_typeddict_cast(Target, data)`` the hook compares the
    TypedDict type of ``data`` with the TypedDict ``Target`` and reports:

    * keys of ``data`` that ``Target`` does not declare,
    * keys whose value type differs from ``Target``'s declaration, and
    * keys ``Target`` requires that ``data`` does not declare.

    When every check passes, the call expression is typed as ``Target``.
    Otherwise the errors are reported and the expression keeps ``data``'s type.
    """

    @staticmethod
    def get_function_hook(fullname: str) -> Optional[Callable[[FunctionContext], TypedDictType]]:
        """Return the checking hook for ``safe_typeddict_cast``, else ``None``.

        Only names whose last dotted component is exactly
        ``safe_typeddict_cast`` are hooked. A substring match would also
        catch unrelated functions such as ``safe_typeddict_cast_v2``.
        """

        def quoted(keys: List[str]) -> str:
            # Renders '"a", "b"', the key-list format used in the error messages.
            return ', '.join(f'"{key}"' for key in keys)

        def target_typeddict(arg_type: Any) -> Optional[TypedDictType]:
            # The first argument references the TypedDict class, so its type
            # is the class constructor. The TypedDict is read off an Instance
            # return type (as originally assumed) or taken directly when the
            # constructor returns a TypedDictType.
            # NOTE(review): which shape appears depends on the mypy version;
            # confirm against the version in use.
            if not isinstance(arg_type, CallableType):
                return None
            constructed = arg_type.ret_type
            if isinstance(constructed, TypedDictType):
                return constructed
            if isinstance(constructed, Instance):
                return constructed.type.typeddict_type
            return None

        def convert_data_type(context: FunctionContext) -> TypedDictType:
            # Without both arguments there is nothing to compare, so keep the
            # type mypy inferred on its own.
            if len(context.arg_types) < 2 or not context.arg_types[0] or not context.arg_types[1]:
                return context.default_return_type  # type: ignore

            target_type = target_typeddict(context.arg_types[0][0])
            if target_type is None:
                context.api.fail(
                    'First argument to "safe_typeddict_cast" must be a TypedDict class',
                    context.context,
                )
                return context.default_return_type  # type: ignore
            # Use the name as written at the call site. Fall back to the type
            # itself for expressions that carry no name.
            target_type_name = getattr(context.args[0][0], 'name', str(target_type))

            input_type = context.arg_types[1][0]
            if not isinstance(input_type, TypedDictType):
                # Non-TypedDict input (e.g. Any) cannot be verified.
                # Report it instead of crashing on a missing `.items`.
                context.api.fail(
                    'Input data to "safe_typeddict_cast" must have a TypedDict type',
                    context.context,
                )
                return input_type  # type: ignore

            target_items = target_type.items
            input_items = input_type.items

            # Keys the input declares but the target does not.
            extra_keys = [key for key in input_items if key not in target_items]
            # Keys both declare, but with different value types. This is an
            # exact comparison (==), not a subtype check.
            # TODO: Recursively support checking nested typed dicts for better errors?
            mistyped_keys = [
                key
                for key, value in input_items.items()
                if key in target_items and value != target_items[key]
            ]
            # Keys the target requires that the input does not declare at all.
            # Target order is kept so the messages are deterministic.
            absent_required_keys = [
                key
                for key in target_items
                if key in target_type.required_keys and key not in input_items
            ]

            if len(extra_keys) == 1:
                context.api.fail(
                    f'Input data type has extra key "{extra_keys[0]}" '
                    f'which is missing on type "{target_type_name}"', context.context
                )
            elif extra_keys:
                context.api.fail(
                    f'Input data type has extra keys {quoted(extra_keys)} '
                    f'which are missing on type "{target_type_name}"', context.context
                )

            if len(mistyped_keys) == 1:
                context.api.fail(
                    f'Input data type has key "{mistyped_keys[0]}" '
                    f'which has incompatible type on type "{target_type_name}"', context.context
                )
            elif mistyped_keys:
                context.api.fail(
                    f'Input data type has keys {quoted(mistyped_keys)} '
                    f'which have incompatible types on type "{target_type_name}"', context.context
                )

            if len(absent_required_keys) == 1:
                context.api.fail(
                    f'Input data type is missing required key "{absent_required_keys[0]}" '
                    f'of type "{target_type_name}"', context.context
                )
            elif absent_required_keys:
                context.api.fail(
                    f'Input data type is missing required keys {quoted(absent_required_keys)} '
                    f'of type "{target_type_name}"', context.context
                )

            # Only convert the type when every check passed.
            if extra_keys or mistyped_keys or absent_required_keys:
                return input_type
            return target_type

        if fullname.endswith('.safe_typeddict_cast'):
            return convert_data_type
        return None
def plugin(_: str) -> Type[TypedDictCastPlugin]:
    """Plugin entry point: hand back the ``TypedDictCastPlugin`` class.

    The string argument (presumably the mypy version; not inspected here) is
    ignored, so the same class is returned for every value.
    """
    plugin_class = TypedDictCastPlugin
    return plugin_class
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment