# Gist mattjj/d1d6d951882406625785db26adc76e96, created September 8, 2021 23:14.
# GitHub notice: this file may contain hidden or bidirectional Unicode text that
# could be interpreted or compiled differently than it appears; review it in an
# editor that reveals hidden Unicode characters.
from typing import Callable, TypeVar | |
from collections import defaultdict | |
def ensure_tuple(x):
    """Return `x` unchanged if it is a tuple, otherwise wrap it in a 1-tuple.

    Lets subscripts such as `Typeclass[a]` and `Typeclass[a, b]` be handled
    uniformly, since Python passes a single key un-tupled.
    """
    if isinstance(x, tuple):
        return x
    return (x,)
def safe_zip(*args):
    """Zip sequences that must all have the same length; return a list of tuples.

    Unlike the builtin `zip`, which silently truncates to the shortest input,
    a length mismatch is treated as an error. With no arguments, returns [].

    Raises:
        ValueError: if the sequences do not all have the same length.
    """
    if not args:
        return []
    first_len = len(args[0])
    # An explicit raise (rather than `assert`) keeps the check under `python -O`.
    if any(len(arg) != first_len for arg in args[1:]):
        lengths = [len(arg) for arg in args]
        raise ValueError(f"safe_zip got sequences of unequal lengths: {lengths}")
    return list(zip(*args))
def _unique_type(xs):
    """Return the single `gettype` shared by every element of `xs`, or None.

    None covers both an empty iterable and mixed element types, since neither
    has one well-defined element type.
    """
    types = set(map(gettype, xs))
    return types.pop() if len(types) == 1 else None


def gettype(x):
    """Return the runtime type of `x` as a (possibly parameterized) type.

    Used as the dispatch key, so results are compared against the argument
    annotations of registered methods:
      * tuples are typed element-wise, e.g. (1, 'a') -> tuple[int, str];
      * lists and dicts with homogeneous elements get their element type,
        e.g. [1, 2] -> list[int], {'k': 1} -> dict[str, int];
      * empty or mixed-type lists/dicts fall back to bare `list` / `dict`;
      * anything else maps to `type(x)`.
    """
    if isinstance(x, tuple):
        return tuple[tuple(map(gettype, x))]
    elif isinstance(x, list):
        elt_type = _unique_type(x)
        return list if elt_type is None else list[elt_type]
    elif isinstance(x, dict):
        key_type = _unique_type(x)
        val_type = _unique_type(x.values())
        if key_type is None or val_type is None:
            return dict
        return dict[key_type, val_type]
    else:
        return type(x)
def _split_signature(fn):
    """Return ``(arg_types, return_type)`` read from ``fn``'s annotations.

    The return type is taken from the ``'return'`` key instead of being
    assumed to be the last annotation. That way a missing return annotation
    is reported clearly rather than the last argument type being mistaken
    for it.

    Raises:
        Exception: if ``fn`` has no return annotation.
    """
    hints = dict(fn.__annotations__)
    if 'return' not in hints:
        raise Exception(
            f"instance method {fn.__name__} is missing a return annotation")
    return_type = hints.pop('return')
    return tuple(hints.values()), return_type


class Typeclass(type):
    """Metaclass for declaring Haskell-style typeclasses as Python classes.

    A typeclass is declared by subscripting ``Typeclass`` with its type
    variables and annotating its methods::

        a = TypeVar('a')
        class Show(metaclass=Typeclass[a]):
            show: Callable[[a], str]

    The resulting class (``Show``) is itself a metaclass. An instance of the
    typeclass is declared by subscripting it with concrete types and defining
    every method with matching annotations::

        class IntShow(metaclass=Show[int]):
            def show(x: int) -> str:
                return str(x)

    ``Show.show`` is a dispatcher: it selects the registered implementation
    whose annotated argument types equal ``gettype`` of the call's arguments.
    """

    @classmethod
    def __class_getitem__(metacls, key):
        # `Typeclass[a, b]` records the type variables (binders) for the
        # typeclass definition that immediately follows. The record lives on
        # the shared metaclass, so `__new__` snapshots it right away.
        key = ensure_tuple(key)
        assert all(isinstance(k, TypeVar) for k in key)
        metacls.types = key
        return metacls

    def __new__(metacls, classname, bases, classdict):
        """Create a typeclass from its annotated class body.

        Raises:
            Exception: if an annotation uses a type variable that was not
                introduced by ``Typeclass[...]``.
        """
        # Snapshot the binders: `metacls.types` is overwritten by the next
        # `Typeclass[...]`, which may happen before instances of this
        # typeclass are declared.
        typevars = tuple(metacls.types)
        # NOTE(review): relies on the class namespace holding the raw
        # annotations under '__annotations__'; confirm on Python versions
        # that defer annotation evaluation.
        expected_method_types = dict(classdict['__annotations__'])

        # `__parameters__` collects type variables at any nesting depth (e.g.
        # the `a` in `Callable[[tuple[a, a]], bool]`); plain types have none.
        for ty in expected_method_types.values():
            for t in getattr(ty, '__parameters__', ()):
                if isinstance(t, TypeVar) and t not in typevars:
                    raise Exception(f"typevar not in scope: {t}")

        # Add `type` to the bases so the typeclass is itself a metaclass,
        # usable as `metaclass=ThisTypeclass[...]` when declaring instances.
        result = type.__new__(metacls, classname, (type, *bases), classdict)

        @classmethod
        def getitem(metacls_, key):
            # `Show[int]` binds this typeclass's type variables to concrete
            # types for the instance declaration that immediately follows.
            metacls_.types = dict(safe_zip(typevars, ensure_tuple(key)))
            return metacls_
        result.__class_getitem__ = getitem

        # handlers[method_name][arg_types] -> implementation, filled in by
        # `new` as instances are declared.
        handlers = defaultdict(dict)

        def make_handler(name):
            def handle(*args):
                types = tuple(map(gettype, args))
                handler = handlers[name].get(types)
                if handler is not None:
                    return handler(*args)
                raise Exception(
                    f"no handler for method {name} and types {types}")
            handle.__name__ = handle.__qualname__ = name
            return handle

        for name in expected_method_types:
            setattr(result, name, make_handler(name))

        def new(metacls_, classname, bases, classdict):
            # Runs for `class X(metaclass=ThisTypeclass[...])`. It validates
            # the instance, registers its methods, and returns the new class.
            bindings = metacls_.types
            if not isinstance(bindings, dict):
                raise Exception(
                    f"instance {classname} must bind type variables, "
                    f"e.g. metaclass={metacls_.__name__}[...]")
            methods = {name: v for name, v in classdict.items() if callable(v)}

            # Check that we don't have too few or too many methods.
            extra_methods = set(methods) - set(expected_method_types)
            if extra_methods:
                raise Exception(
                    f"extra methods: {', '.join(map(str, extra_methods))}")
            missing_methods = set(expected_method_types) - set(methods)
            if missing_methods:
                raise Exception(
                    f"missing methods: {', '.join(map(str, missing_methods))}")

            # Compare each method's annotations with the typeclass
            # declaration after substituting bound types for type variables.
            # Every method is checked before any is registered, so a bad
            # instance registers nothing.
            signatures = {name: _split_signature(v)
                          for name, v in methods.items()}
            for name, (arg_types, return_type) in signatures.items():
                declared = expected_method_types[name]
                params = getattr(declared, '__parameters__', ())
                # A declaration with no type variables cannot be subscripted.
                expected_ty = (declared[tuple(map(bindings.get, params))]
                               if params else declared)
                observed_ty = Callable[list(arg_types), return_type]
                if expected_ty != observed_ty:
                    raise Exception(
                        "instance method has incorrect type annotation: "
                        f"expected {expected_ty}, observed {observed_ty}")

            for name, (arg_types, _) in signatures.items():
                handlers[name][arg_types] = methods[name]
            # Build a real class so the `class` statement binds a usable
            # object rather than None.
            return type.__new__(metacls_, classname, bases, classdict)

        result.__new__ = new
        return result
### | |
# The single type variable shared by every typeclass declared below.
a = TypeVar('a')


# typeclass Show a
class Show(metaclass=Typeclass[a]):
    show: Callable[[a], str]


# Module-level dispatcher: picks an implementation from argument types.
show = Show.show


# instance Show Int where
#   show i = str(i)
class IntShow(metaclass=Show[int]):
    def show(x: int) -> str:
        return str(x)


print(show(3))

# No `Show str` instance exists yet, so this dispatch fails; print the reason.
try:
    show('hi')
except Exception as e:
    print(e)


# instance Show String where
#   show s = s
class StrShow(metaclass=Show[str]):
    def show(x: str) -> str:
        return x


print(show('hi'))
# typeclass Eq a
class Eq(metaclass=Typeclass[a]):
    eq: Callable[[a, a], bool]


# Dispatcher over both arguments' runtime types.
eq = Eq.eq


# instance Eq int where
#   eq i j = i == j
class IntEq(metaclass=Eq[int]):
    def eq(x: int, y: int) -> bool:
        return x == y


print('====')
for lhs, rhs in [(1, 2), (1, 1)]:
    print(eq(lhs, rhs))
# typeclass Eq2 a, whose method type nests the type variable inside a tuple.
class Eq2(metaclass=Typeclass[a]):
    eq2: Callable[[tuple[a, a]], bool]


eq2 = Eq2.eq2


class IntEq2(metaclass=Eq2[int]):
    def eq2(xy: tuple[int, int]) -> bool:
        x, y = xy
        return x == y


print('=====')
print(eq2((1, 1)))

# No `Eq2 str` instance exists yet, so this dispatch fails; print the reason.
try:
    eq2(('hi', 'hi'))
except Exception as e:
    print(e)


class StrEq2(metaclass=Eq2[str]):
    def eq2(xy: tuple[str, str]) -> bool:
        x, y = xy
        return x == y


print(eq2(('hi', 'hi')))
###
# 1. Declare typeclasses in a nice way:
#    (a) in a block,
#    (b) using type annotations (including type variables).
# 2. Declare instances in a nice way:
#    (a) in a block?
#    (b) getting checks automatically.
#
# 1a basically means we should use a class.
# 1b means metaclasses.
# 2a seems reasonable; the alternative is just to register handlers directly.
#
# What's not handled now?
# * type variables in instances' methods, like a Show instance for List
#   (i.e. solving a type-inference problem on dispatch)
# * output-type-based dispatch
# Sign up for free to join this conversation on GitHub. Already have an
# account? Sign in to comment.