Skip to content

Instantly share code, notes, and snippets.

@ckarnell
Created May 14, 2019 15:18
Show Gist options
  • Save ckarnell/23e43ea1155adbf6121fd2f1695d106c to your computer and use it in GitHub Desktop.
Save ckarnell/23e43ea1155adbf6121fd2f1695d106c to your computer and use it in GitHub Desktop.
from typing import Any, Callable, Optional, Type, TypeVar
from mypy.plugin import Plugin, FunctionContext # pylint: disable=no-name-in-module
from mypy.types import TypedDictType # pylint: disable=no-name-in-module
A = TypeVar('A')


def safe_typeddict_cast(_: Type[A], data: Any) -> A:
    """Return ``data`` unchanged, statically typed as an instance of the given type.

    At runtime this is the identity on ``data``; the first argument is ignored.
    All checking is static and done by the companion mypy plugin
    (``TypedDictCastPlugin``, kept in ``mypy_plugin.py`` per the original note).
    The call only takes on the target type if the input's keys and key types
    pass the plugin's checks. Otherwise mypy reports the offending keys: input
    keys the target lacks, or keys whose types are incompatible.
    """
    passthrough = data
    return passthrough
class TypedDictCastPlugin(Plugin):
    """mypy plugin that type-checks calls to ``safe_typeddict_cast``.

    For a call ``safe_typeddict_cast(Target, data)`` the hook compares the
    TypedDict type of ``data`` with the TypedDict named by ``Target``. If every
    input key exists on the target with an equal item type, the hook returns
    the target type. Otherwise it reports errors and returns the input type,
    so the cast does not take effect.
    """

    @staticmethod
    def get_function_hook(fullname: str) -> Optional[Callable[[FunctionContext], TypedDictType]]:
        """Return the checking hook for ``safe_typeddict_cast``, else ``None``.

        The fully qualified name must *end* with ``.safe_typeddict_cast``. A
        plain substring test would also hook unrelated names such as
        ``pkg.safe_typeddict_cast_v2`` or ``pkg.safe_typeddict_cast.helper``.
        """
        if fullname.endswith('.safe_typeddict_cast'):
            return TypedDictCastPlugin._convert_data_type
        return None

    @staticmethod
    def _target_typeddict(callee_type: Any) -> Optional[TypedDictType]:
        """Extract the TypedDict type named by the cast's first argument.

        Returns ``None`` when that argument does not resolve to a TypedDict,
        instead of raising ``AttributeError`` inside mypy.
        """
        ret_type = getattr(callee_type, 'ret_type', None)
        # NOTE(review): some mypy versions may expose the TypedDict directly as
        # the constructor's return type; accept that form as well.
        if isinstance(ret_type, TypedDictType):
            return ret_type
        # Path used by the original code: constructor return type -> TypeInfo
        # -> its ``typeddict_type`` (None when the class is not a TypedDict).
        info = getattr(ret_type, 'type', None)
        target = getattr(info, 'typeddict_type', None)
        return target if isinstance(target, TypedDictType) else None

    @staticmethod
    def _convert_data_type(context: FunctionContext) -> TypedDictType:
        """Check a ``safe_typeddict_cast(Target, data)`` call and choose its type.

        Reports an error for input keys absent from the target TypedDict and
        for input keys whose item type differs from the target's.

        Returns:
            The target TypedDict type if no check failed; otherwise the input
            type, unchanged.
        """
        # A formal parameter has no actual argument only for a malformed call,
        # which mypy reports itself; keep the declared return type then.
        if not context.arg_types[0] or not context.arg_types[1]:
            return context.default_return_type  # type: ignore

        input_type = context.arg_types[1][0]
        target_type = TypedDictCastPlugin._target_typeddict(context.arg_types[0][0])
        target_type_name = getattr(context.args[0][0], 'name', None) or str(context.arg_types[0][0])

        if target_type is None:
            context.api.fail(
                f'First argument "{target_type_name}" is not a TypedDict type', context.context
            )
            return input_type  # type: ignore
        if not isinstance(input_type, TypedDictType):
            # E.g. Any or a plain dict: there are no declared keys to compare.
            context.api.fail(
                f'Input data type is not a TypedDict, so it cannot be checked '
                f'against type "{target_type_name}"', context.context
            )
            return input_type  # type: ignore

        target_items = target_type.items
        # Input keys that the target TypedDict does not declare at all.
        extra_keys = [key for key in input_type.items if key not in target_items]
        # Shared keys whose item types differ. Comparison is ``!=`` on the mypy
        # types rather than a subtype check, so a narrower type is reported too.
        # TODO: Recursively support checking nested typed dicts for better errors?
        mistyped_keys = [
            key for key, value in input_type.items.items()
            if key in target_items and value != target_items[key]
        ]

        if extra_keys:
            if len(extra_keys) == 1:
                message = (
                    f'Input data type has extra key "{extra_keys[0]}" '
                    f'which is missing on type "{target_type_name}"'
                )
            else:
                joined_keys = ', '.join(f'"{key}"' for key in extra_keys)
                message = (
                    f'Input data type has extra keys {joined_keys} '
                    f'which are missing on type "{target_type_name}"'
                )
            context.api.fail(message, context.context)

        if mistyped_keys:
            if len(mistyped_keys) == 1:
                # Type name is quoted here to match every other message.
                message = (
                    f'Input data type has key "{mistyped_keys[0]}" '
                    f'which has incompatible type on type "{target_type_name}"'
                )
            else:
                joined_keys = ', '.join(f'"{key}"' for key in mistyped_keys)
                message = (
                    f'Input data type has keys {joined_keys} '
                    f'which have incompatible types on type "{target_type_name}"'
                )
            context.api.fail(message, context.context)

        # TODO: All required keys in the target type must be present and subtyped on the input type
        # If any check failed, keep the input type so the cast does not take effect.
        if extra_keys or mistyped_keys:
            return input_type  # type: ignore
        return target_type
def plugin(_: str) -> Type[TypedDictCastPlugin]:
    """Module entry point that mypy calls to obtain the plugin class.

    The single argument is ignored (presumably the mypy version string, per
    mypy's plugin-loading convention). The same class is returned for every
    value.
    """
    plugin_class: Type[TypedDictCastPlugin] = TypedDictCastPlugin
    return plugin_class
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment