Last active
February 6, 2025 15:18
-
-
Save has2k1/45c7dde27b10df8fb10604250d9d3656 to your computer and use it in GitHub Desktop.
Automatic registration for single dispatch
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"id": "e78720bf", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from functools import singledispatch\n", | |
"from contextlib import suppress" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "26787b9b", | |
"metadata": {}, | |
"source": [ | |
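"Regular single dispatch: each implementation class is registered explicitly with the `@make_numeric.register(type)` decorator." | |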
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"id": "2207ec50", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"IntNumeric -> 2\n", | |
"FloatNumeric -> 4.0\n", | |
"ComplexNumeric -> (2+3j)\n" | |
] | |
} | |
], | |
"source": [ | |
"class Base:\n", | |
"    \"\"\"Base class for the numeric implementations: stores a value and reports it.\"\"\"\n", | |
"\n", | |
"    def __init__(self, value=None):\n", | |
"        self.value = value\n", | |
"\n", | |
"    def who(self):\n", | |
"        \"\"\"Print the concrete class name followed by the stored value.\"\"\"\n", | |
"        print(f\"{self.__class__.__name__} -> {self.value}\")\n", | |
"\n", | |
"@singledispatch\n", | |
"def make_numeric(value):\n", | |
"    \"\"\"Fallback, called when no implementation is registered for type(value).\"\"\"\n", | |
"    raise NotImplementedError(\"Nada\")\n", | |
"\n", | |
"# Used as a class decorator, `register(T)` registers the class itself as the\n", | |
"# implementation for T, so `make_numeric(2)` evaluates to `IntNumeric(2)`.\n", | |
"@make_numeric.register(int)\n", | |
"class IntNumeric(Base):\n", | |
"    pass\n", | |
"\n", | |
"@make_numeric.register(float)\n", | |
"class FloatNumeric(Base):\n", | |
"    pass\n", | |
"\n", | |
"@make_numeric.register(complex)\n", | |
"class ComplexNumeric(Base):\n", | |
"    pass\n", | |
"\n", | |
"# Uncommenting this fails with a NameError as soon as the decorator line runs,\n", | |
"# because `unknown` is not defined.\n", | |
"#@make_numeric.register(unknown)\n", | |
"#class UnknownNumeric(Base):\n", | |
"#    pass\n", | |
"\n", | |
"i = make_numeric(2)\n", | |
"f = make_numeric(4.0)\n", | |
"c = make_numeric(2+3j)\n", | |
"\n", | |
"i.who()\n", | |
"f.who()\n", | |
"c.who()" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "55172db5", | |
"metadata": {}, | |
"source": [ | |
"Single dispatch with each implementation class storing the `type` it specialises on in a plain class attribute (`base_class`), which is then passed to explicit `register` calls." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"id": "87e74e21", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"IntNumeric -> 2\n", | |
"FloatNumeric -> 4.0\n", | |
"ComplexNumeric -> (2+3j)\n" | |
] | |
} | |
], | |
"source": [ | |
"class Base:\n", | |
"    \"\"\"Base class for the numeric implementations: stores a value and reports it.\"\"\"\n", | |
"\n", | |
"    def __init__(self, value=None):\n", | |
"        self.value = value\n", | |
"\n", | |
"    def who(self):\n", | |
"        \"\"\"Print the concrete class name followed by the stored value.\"\"\"\n", | |
"        print(f\"{self.__class__.__name__} -> {self.value}\")\n", | |
"\n", | |
"@singledispatch\n", | |
"def make_numeric(value):\n", | |
"    \"\"\"Fallback, called when no implementation is registered for type(value).\"\"\"\n", | |
"    raise NotImplementedError(\"Nada\")\n", | |
"\n", | |
"# Each implementation records the type it handles as a plain class attribute.\n", | |
"class IntNumeric(Base):\n", | |
"    base_class = int\n", | |
"\n", | |
"class FloatNumeric(Base):\n", | |
"    base_class = float\n", | |
"\n", | |
"class ComplexNumeric(Base):\n", | |
"    base_class = complex\n", | |
"\n", | |
"# Uncommenting this fails with a NameError when the class is created,\n", | |
"# because the class body evaluates `unknown` immediately.\n", | |
"#class UnknownNumeric(Base):\n", | |
"#    base_class = unknown\n", | |
"\n", | |
"# Registration is still explicit, but the type now comes from each class.\n", | |
"for implementation in (IntNumeric, FloatNumeric, ComplexNumeric):\n", | |
"    make_numeric.register(implementation.base_class, implementation)\n", | |
"\n", | |
"i = make_numeric(2)\n", | |
"f = make_numeric(4.0)\n", | |
"c = make_numeric(2+3j)\n", | |
"\n", | |
"i.who()\n", | |
"f.who()\n", | |
"c.who()" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "77ef07e3", | |
"metadata": {}, | |
"source": [ | |
"Single dispatch with each implementation class returning the `type` it specialises on from a `staticmethod`. Because the lambda is only evaluated when it is called, an undefined type name does not fail when the class is created; it only surfaces (as a `NameError`) in the explicit call to `register`." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"id": "f0174eb5", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"IntNumeric -> 2\n", | |
"FloatNumeric -> 4.0\n", | |
"ComplexNumeric -> (2+3j)\n" | |
] | |
} | |
], | |
"source": [ | |
"class Base:  # This is the same as `class Base(metaclass=type)`\n", | |
"    \"\"\"Base class for the numeric implementations: stores a value and reports it.\"\"\"\n", | |
"\n", | |
"    def __init__(self, value=None):\n", | |
"        self.value = value\n", | |
"\n", | |
"    def who(self):\n", | |
"        \"\"\"Print the concrete class name followed by the stored value.\"\"\"\n", | |
"        print(f\"{self.__class__.__name__} -> {self.value}\")\n", | |
"\n", | |
"@singledispatch\n", | |
"def make_numeric(value):\n", | |
"    \"\"\"Fallback, called when no implementation is registered for type(value).\"\"\"\n", | |
"    raise NotImplementedError(\"Nada\")\n", | |
"\n", | |
"# The handled type is returned by a zero-argument staticmethod, so the type name\n", | |
"# is only looked up when `base_class()` is called, not when the class is created.\n", | |
"class IntNumeric(Base):\n", | |
"    base_class = staticmethod(lambda: int)\n", | |
"\n", | |
"class FloatNumeric(Base):\n", | |
"    base_class = staticmethod(lambda: float)\n", | |
"\n", | |
"class ComplexNumeric(Base):\n", | |
"    base_class = staticmethod(lambda: complex)\n", | |
"\n", | |
"# Creating this class succeeds even though `unknown` is not defined...\n", | |
"class UnknownNumeric(Base):\n", | |
"    base_class = staticmethod(lambda: unknown)\n", | |
"\n", | |
"for implementation in (IntNumeric, FloatNumeric, ComplexNumeric):\n", | |
"    make_numeric.register(implementation.base_class(), implementation)\n", | |
"# ...but registering it raises NameError as soon as `base_class()` is called.\n", | |
"#make_numeric.register(UnknownNumeric.base_class(), UnknownNumeric)\n", | |
"\n", | |
"i = make_numeric(2)\n", | |
"f = make_numeric(4.0)\n", | |
"c = make_numeric(2+3j)\n", | |
"\n", | |
"i.who()\n", | |
"f.who()\n", | |
"c.who()" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "079c7817", | |
"metadata": {}, | |
"source": [ | |
"Single dispatch using a metaclass to register all implementation classes automatically. Implementations whose type cannot be resolved are skipped rather than registered. The single-dispatch helper function is hidden, and the base class (think `VetiverHandler`) takes on its role." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"id": "75f9147e", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"IntNumeric -> 2\n", | |
"FloatNumeric -> 4.0\n", | |
"ComplexNumeric -> (2+3j)\n" | |
] | |
} | |
], | |
"source": [ | |
"@singledispatch\n", | |
"def make_numeric(value):\n", | |
"    \"\"\"Holds the type -> implementation registry; this fallback means nothing matched.\"\"\"\n", | |
"    raise NotImplementedError(\"Nada\")\n", | |
"\n", | |
"class AutoRegisterHandler(type):  # subclassing `type` creates a new metaclass\n", | |
"    \"\"\"Metaclass that registers every class it creates with `make_numeric`.\n", | |
"\n", | |
"    A class is registered for the type returned by its `base_class()`. Classes\n", | |
"    without `base_class` (AttributeError), or whose `base_class()` names an\n", | |
"    undefined type (NameError), are skipped instead of registered.\n", | |
"    \"\"\"\n", | |
"\n", | |
"    # A metaclass's __new__ runs each time a class using this metaclass is created.\n", | |
"    def __new__(meta, name, bases, clsdict):\n", | |
"        cls = super().__new__(meta, name, bases, clsdict)\n", | |
"        with suppress(AttributeError, NameError):\n", | |
"            make_numeric.register(cls.base_class(), cls)\n", | |
"        return cls\n", | |
"\n", | |
"class Base(metaclass=AutoRegisterHandler):\n", | |
"    \"\"\"Entry point: `Base(value)` builds the implementation registered for type(value).\"\"\"\n", | |
"\n", | |
"    # A regular class's __new__ runs before the object is instantiated; the new\n", | |
"    # object is an instance of whichever class is passed to object.__new__.\n", | |
"    def __new__(cls, value=None):\n", | |
"        # dispatch() resolves through the MRO, so subclasses of a registered type\n", | |
"        # (e.g. bool, a subclass of int) resolve too; a raw `registry[...]`\n", | |
"        # lookup would raise KeyError for them.\n", | |
"        implementation_cls = make_numeric.dispatch(type(value))\n", | |
"        if not isinstance(implementation_cls, type):\n", | |
"            # Only the fallback function matched: nothing is registered for this\n", | |
"            # type. KeyError keeps the exception type of a plain registry lookup.\n", | |
"            raise KeyError(f\"No implementation registered for {type(value).__name__}\")\n", | |
"        return super().__new__(implementation_cls)\n", | |
"\n", | |
"    def __init__(self, value=None):\n", | |
"        self.value = value\n", | |
"\n", | |
"    def who(self):\n", | |
"        \"\"\"Print the concrete class name followed by the stored value.\"\"\"\n", | |
"        print(f\"{self.__class__.__name__} -> {self.value}\")\n", | |
"\n", | |
"class IntNumeric(Base):\n", | |
"    base_class = staticmethod(lambda: int)\n", | |
"\n", | |
"class FloatNumeric(Base):\n", | |
"    base_class = staticmethod(lambda: float)\n", | |
"\n", | |
"class ComplexNumeric(Base):\n", | |
"    base_class = staticmethod(lambda: complex)\n", | |
"\n", | |
"# With `unknown` undefined, this class is created but silently not registered.\n", | |
"# Uncomment `unknown = str` to register it for str instead; `del unknown`\n", | |
"# removes a definition left over from an earlier run of this cell.\n", | |
"#unknown = str\n", | |
"#del unknown\n", | |
"class Unknown(Base):\n", | |
"    base_class = staticmethod(lambda: unknown)\n", | |
"\n", | |
"i = Base(2)\n", | |
"f = Base(4.0)\n", | |
"c = Base(2+3j)\n", | |
"# Raises KeyError unless an implementation is registered for str.\n", | |
"#u = Base('u')\n", | |
"\n", | |
"i.who()\n", | |
"f.who()\n", | |
"c.who()\n", | |
"#u.who()" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "a98a2e07", | |
"metadata": {}, | |
"source": [ | |
"Using `__init_subclass__` instead of a metaclass. Thanks to [machow](https://github.com/machow)." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"id": "3a92c752", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"IntNumeric -> 2\n", | |
"FloatNumeric -> 4.0\n", | |
"ComplexNumeric -> (2+3j)\n" | |
] | |
} | |
], | |
"source": [ | |
"@singledispatch\n", | |
"def make_numeric(value):\n", | |
"    \"\"\"Holds the type -> implementation registry; this fallback means nothing matched.\"\"\"\n", | |
"    raise NotImplementedError(\"Nada\")\n", | |
"\n", | |
"class Base:\n", | |
"    \"\"\"Entry point: `Base(value)` builds the implementation registered for type(value).\"\"\"\n", | |
"\n", | |
"    # Register each implementation subclass when it is created. __init_subclass__\n", | |
"    # is implicitly a classmethod, so no @classmethod decorator is needed.\n", | |
"    def __init_subclass__(cls, **kwargs):\n", | |
"        super().__init_subclass__(**kwargs)\n", | |
"        # Subclasses without `base_class` (AttributeError), or whose\n", | |
"        # `base_class()` names an undefined type (NameError), are skipped.\n", | |
"        with suppress(AttributeError, NameError):\n", | |
"            make_numeric.register(cls.base_class(), cls)\n", | |
"\n", | |
"    # A regular class's __new__ runs before the object is instantiated; the new\n", | |
"    # object is an instance of whichever class is passed to object.__new__.\n", | |
"    def __new__(cls, value=None):\n", | |
"        # dispatch() resolves through the MRO, so subclasses of a registered type\n", | |
"        # (e.g. bool, a subclass of int) resolve too; a raw `registry[...]`\n", | |
"        # lookup would raise KeyError for them.\n", | |
"        implementation_cls = make_numeric.dispatch(type(value))\n", | |
"        if not isinstance(implementation_cls, type):\n", | |
"            # Only the fallback function matched: nothing is registered for this\n", | |
"            # type. KeyError keeps the exception type of a plain registry lookup.\n", | |
"            raise KeyError(f\"No implementation registered for {type(value).__name__}\")\n", | |
"        return super().__new__(implementation_cls)\n", | |
"\n", | |
"    def __init__(self, value=None):\n", | |
"        self.value = value\n", | |
"\n", | |
"    def who(self):\n", | |
"        \"\"\"Print the concrete class name followed by the stored value.\"\"\"\n", | |
"        print(f\"{self.__class__.__name__} -> {self.value}\")\n", | |
"\n", | |
"class IntNumeric(Base):\n", | |
"    base_class = staticmethod(lambda: int)\n", | |
"\n", | |
"class FloatNumeric(Base):\n", | |
"    base_class = staticmethod(lambda: float)\n", | |
"\n", | |
"class ComplexNumeric(Base):\n", | |
"    base_class = staticmethod(lambda: complex)\n", | |
"\n", | |
"# With `unknown` undefined, this class is created but silently not registered.\n", | |
"# Uncomment `unknown = str` to register it for str instead; `del unknown`\n", | |
"# removes a definition left over from an earlier run of this cell.\n", | |
"#unknown = str\n", | |
"#del unknown\n", | |
"class Unknown(Base):\n", | |
"    base_class = staticmethod(lambda: unknown)\n", | |
"\n", | |
"i = Base(2)\n", | |
"f = Base(4.0)\n", | |
"c = Base(2+3j)\n", | |
"# Raises KeyError unless an implementation is registered for str.\n", | |
"#u = Base('u')\n", | |
"\n", | |
"i.who()\n", | |
"f.who()\n", | |
"c.who()\n", | |
"#u.who()" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3 (ipykernel)", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.10.4" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 5 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment