Skip to content

Instantly share code, notes, and snippets.

@has2k1
Last active February 6, 2025 15:18
Show Gist options
  • Save has2k1/45c7dde27b10df8fb10604250d9d3656 to your computer and use it in GitHub Desktop.
Save has2k1/45c7dde27b10df8fb10604250d9d3656 to your computer and use it in GitHub Desktop.
Automatic registration for single dispatch
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "e78720bf",
"metadata": {},
"outputs": [],
"source": [
"from functools import singledispatch\n",
"from contextlib import suppress"
]
},
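{
"cell_type": "markdown",
"id": "26787b9b",
"metadata": {},
"source": [
"Regular single dispatch: each implementation class is registered for its type by using `make_numeric.register(<type>)` as a class decorator."
]
},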
{
"cell_type": "code",
"execution_count": 2,
"id": "2207ec50",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"IntNumeric -> 2\n",
"FloatNumeric -> 4.0\n",
"ComplexNumeric -> (2+3j)\n"
]
}
],
"source": [
"class Base:\n",
"    \"\"\"Holds a value and reports it alongside the concrete class name.\"\"\"\n",
"\n",
"    def __init__(self, value=None):\n",
"        self.value = value\n",
"\n",
"    def who(self):\n",
"        \"\"\"Print the concrete class name next to the wrapped value.\"\"\"\n",
"        print(f\"{self.__class__.__name__} -> {self.value}\")\n",
"\n",
"\n",
"@singledispatch\n",
"def make_numeric(value):\n",
"    \"\"\"Dispatch on type(value); unregistered types raise NotImplementedError.\"\"\"\n",
"    raise NotImplementedError(\"Nada\")\n",
"\n",
"\n",
"# `register(<type>)` used as a class decorator records the class as the\n",
"# implementation for that type and returns the class unchanged.\n",
"@make_numeric.register(int)\n",
"class IntNumeric(Base):\n",
"    pass\n",
"\n",
"\n",
"@make_numeric.register(float)\n",
"class FloatNumeric(Base):\n",
"    pass\n",
"\n",
"\n",
"@make_numeric.register(complex)\n",
"class ComplexNumeric(Base):\n",
"    pass\n",
"\n",
"\n",
"# Decorating a class with an undefined type, e.g. `@make_numeric.register(unknown)`,\n",
"# raises NameError as soon as that statement runs.\n",
"\n",
"i = make_numeric(2)\n",
"f = make_numeric(4.0)\n",
"c = make_numeric(2+3j)\n",
"\n",
"i.who()\n",
"f.who()\n",
"c.who()"
]
},
{
"cell_type": "markdown",
"id": "55172db5",
"metadata": {},
"source": [
"Single dispatch with the implementation class storing the `type` it specialises on in a regular class attribute, which is then passed to explicit `register` calls."
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "87e74e21",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"IntNumeric -> 2\n",
"FloatNumeric -> 4.0\n",
"ComplexNumeric -> (2+3j)\n"
]
}
],
"source": [
"class Base:\n",
"    \"\"\"Holds a value and reports it alongside the concrete class name.\"\"\"\n",
"\n",
"    def __init__(self, value=None):\n",
"        self.value = value\n",
"\n",
"    def who(self):\n",
"        \"\"\"Print the concrete class name next to the wrapped value.\"\"\"\n",
"        print(f\"{self.__class__.__name__} -> {self.value}\")\n",
"\n",
"\n",
"@singledispatch\n",
"def make_numeric(value):\n",
"    \"\"\"Dispatch on type(value); unregistered types raise NotImplementedError.\"\"\"\n",
"    raise NotImplementedError(\"Nada\")\n",
"\n",
"\n",
"# Each implementation keeps the type it handles in a plain class attribute.\n",
"class IntNumeric(Base):\n",
"    base_class = int\n",
"\n",
"\n",
"class FloatNumeric(Base):\n",
"    base_class = float\n",
"\n",
"\n",
"class ComplexNumeric(Base):\n",
"    base_class = complex\n",
"\n",
"\n",
"# The attribute is evaluated eagerly: a class body containing\n",
"# `base_class = unknown` with `unknown` undefined raises NameError immediately.\n",
"\n",
"# Registration is a separate step that reads the type back from each class.\n",
"make_numeric.register(IntNumeric.base_class, IntNumeric)\n",
"make_numeric.register(FloatNumeric.base_class, FloatNumeric)\n",
"make_numeric.register(ComplexNumeric.base_class, ComplexNumeric)\n",
"\n",
"i = make_numeric(2)\n",
"f = make_numeric(4.0)\n",
"c = make_numeric(2+3j)\n",
"\n",
"i.who()\n",
"f.who()\n",
"c.who()"
]
},
{
"cell_type": "markdown",
"id": "77ef07e3",
"metadata": {},
"source": [
"Single dispatch with the implementation class returning the `type` it specialises on from a `staticmethod`. Because the lambda defers looking up the type's name, a class for an unknown (undefined) type can still be created without error; the `NameError` only surfaces when `base_class()` is evaluated in the explicit call to `register`."
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "f0174eb5",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"IntNumeric -> 2\n",
"FloatNumeric -> 4.0\n",
"ComplexNumeric -> (2+3j)\n"
]
}
],
"source": [
"class Base:  # equivalent to `class Base(metaclass=type)`\n",
"    \"\"\"Holds a value and reports it alongside the concrete class name.\"\"\"\n",
"\n",
"    def __init__(self, value=None):\n",
"        self.value = value\n",
"\n",
"    def who(self):\n",
"        \"\"\"Print the concrete class name next to the wrapped value.\"\"\"\n",
"        print(f\"{self.__class__.__name__} -> {self.value}\")\n",
"\n",
"\n",
"@singledispatch\n",
"def make_numeric(value):\n",
"    \"\"\"Dispatch on type(value); unregistered types raise NotImplementedError.\"\"\"\n",
"    raise NotImplementedError(\"Nada\")\n",
"\n",
"\n",
"# The type is returned from a zero-argument lambda, so its name is only looked\n",
"# up when `base_class()` is called, not when the class is defined.\n",
"class IntNumeric(Base):\n",
"    base_class = staticmethod(lambda: int)\n",
"\n",
"\n",
"class FloatNumeric(Base):\n",
"    base_class = staticmethod(lambda: float)\n",
"\n",
"\n",
"class ComplexNumeric(Base):\n",
"    base_class = staticmethod(lambda: complex)\n",
"\n",
"\n",
"class UnknownNumeric(Base):\n",
"    base_class = staticmethod(lambda: unknown)  # `unknown` is not defined in this notebook\n",
"\n",
"\n",
"make_numeric.register(IntNumeric.base_class(), IntNumeric)\n",
"make_numeric.register(FloatNumeric.base_class(), FloatNumeric)\n",
"make_numeric.register(ComplexNumeric.base_class(), ComplexNumeric)\n",
"# Registering UnknownNumeric the same way calls its lambda and raises NameError:\n",
"# make_numeric.register(UnknownNumeric.base_class(), UnknownNumeric)\n",
"\n",
"i = make_numeric(2)\n",
"f = make_numeric(4.0)\n",
"c = make_numeric(2+3j)\n",
"\n",
"i.who()\n",
"f.who()\n",
"c.who()"
]
},
{
"cell_type": "markdown",
"id": "079c7817",
"metadata": {},
"source": [
"Single dispatch using a metaclass to register every implementation class automatically when it is created. Implementations whose type is unknown (undefined) are silently left unregistered. The single-dispatch function is no longer called directly: the base class (think `VetiverHandler`) takes on that role, so `Base(value)` returns an instance of the implementation registered for the type of `value`."
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "75f9147e",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"IntNumeric -> 2\n",
"FloatNumeric -> 4.0\n",
"ComplexNumeric -> (2+3j)\n"
]
}
],
"source": [
"@singledispatch\n",
"def make_numeric(value):\n",
"    \"\"\"Dispatch on type(value); unregistered types raise NotImplementedError.\"\"\"\n",
"    raise NotImplementedError(\"Nada\")\n",
"\n",
"\n",
"class AutoRegisterHandler(type):  # subclassing `type` creates a metaclass\n",
"    \"\"\"Metaclass that registers each class it creates with `make_numeric`.\"\"\"\n",
"\n",
"    # A metaclass's __new__ runs whenever a class using it is created.\n",
"    def __new__(meta, name, bases, clsdict):\n",
"        cls = super().__new__(meta, name, bases, clsdict)\n",
"        # AttributeError: the class has no `base_class` (e.g. Base itself).\n",
"        # NameError: `base_class()` names an undefined type, so that\n",
"        # implementation is deliberately left unregistered.\n",
"        with suppress(AttributeError, NameError):\n",
"            make_numeric.register(cls.base_class(), cls)\n",
"        return cls\n",
"\n",
"\n",
"class Base(metaclass=AutoRegisterHandler):\n",
"    \"\"\"Factory base: `Base(value)` builds the implementation registered for `type(value)`.\"\"\"\n",
"\n",
"    # __new__ runs before __init__ and decides the class of the new object.\n",
"    def __new__(cls, value=None):\n",
"        # `dispatch` honours inheritance (e.g. bool resolves to the int\n",
"        # implementation) and returns the generic fallback function for\n",
"        # unregistered types; indexing `registry` only matches exact types.\n",
"        implementation_cls = make_numeric.dispatch(type(value))\n",
"        if not isinstance(implementation_cls, type):\n",
"            raise NotImplementedError(\n",
"                f\"No implementation registered for {type(value).__name__!r}\"\n",
"            )\n",
"        return super().__new__(implementation_cls)\n",
"\n",
"    def __init__(self, value=None):\n",
"        self.value = value\n",
"\n",
"    def who(self):\n",
"        \"\"\"Print the concrete class name next to the wrapped value.\"\"\"\n",
"        print(f\"{self.__class__.__name__} -> {self.value}\")\n",
"\n",
"\n",
"class IntNumeric(Base):\n",
"    base_class = staticmethod(lambda: int)\n",
"\n",
"\n",
"class FloatNumeric(Base):\n",
"    base_class = staticmethod(lambda: float)\n",
"\n",
"\n",
"class ComplexNumeric(Base):\n",
"    base_class = staticmethod(lambda: complex)\n",
"\n",
"\n",
"# `unknown` is undefined, so the metaclass skips this class and `Base('u')`\n",
"# raises NotImplementedError; defining e.g. `unknown = str` before this class\n",
"# statement would register it for str instead.\n",
"class Unknown(Base):\n",
"    base_class = staticmethod(lambda: unknown)\n",
"\n",
"\n",
"i = Base(2)\n",
"f = Base(4.0)\n",
"c = Base(2+3j)\n",
"\n",
"i.who()\n",
"f.who()\n",
"c.who()"
]
},
{
"cell_type": "markdown",
"id": "a98a2e07",
"metadata": {},
"source": [
"Using `__init_subclass__` instead of a metaclass. Thanks to [machow](https://github.com/machow)."
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "3a92c752",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"IntNumeric -> 2\n",
"FloatNumeric -> 4.0\n",
"ComplexNumeric -> (2+3j)\n"
]
}
],
"source": [
"@singledispatch\n",
"def make_numeric(value):\n",
"    \"\"\"Dispatch on type(value); unregistered types raise NotImplementedError.\"\"\"\n",
"    raise NotImplementedError(\"Nada\")\n",
"\n",
"\n",
"class Base:\n",
"    \"\"\"Factory base: `Base(value)` builds the implementation registered for `type(value)`.\"\"\"\n",
"\n",
"    # Called once for each subclass as it is created; Python treats\n",
"    # __init_subclass__ as a classmethod implicitly, so no decorator is needed.\n",
"    def __init_subclass__(cls, **kwargs):\n",
"        super().__init_subclass__(**kwargs)\n",
"        # AttributeError: the subclass has no `base_class`.\n",
"        # NameError: `base_class()` names an undefined type, so that\n",
"        # implementation is deliberately left unregistered.\n",
"        with suppress(AttributeError, NameError):\n",
"            make_numeric.register(cls.base_class(), cls)\n",
"\n",
"    # __new__ runs before __init__ and decides the class of the new object.\n",
"    def __new__(cls, value=None):\n",
"        # `dispatch` honours inheritance (e.g. bool resolves to the int\n",
"        # implementation) and returns the generic fallback function for\n",
"        # unregistered types; indexing `registry` only matches exact types.\n",
"        implementation_cls = make_numeric.dispatch(type(value))\n",
"        if not isinstance(implementation_cls, type):\n",
"            raise NotImplementedError(\n",
"                f\"No implementation registered for {type(value).__name__!r}\"\n",
"            )\n",
"        return super().__new__(implementation_cls)\n",
"\n",
"    def __init__(self, value=None):\n",
"        self.value = value\n",
"\n",
"    def who(self):\n",
"        \"\"\"Print the concrete class name next to the wrapped value.\"\"\"\n",
"        print(f\"{self.__class__.__name__} -> {self.value}\")\n",
"\n",
"\n",
"class IntNumeric(Base):\n",
"    base_class = staticmethod(lambda: int)\n",
"\n",
"\n",
"class FloatNumeric(Base):\n",
"    base_class = staticmethod(lambda: float)\n",
"\n",
"\n",
"class ComplexNumeric(Base):\n",
"    base_class = staticmethod(lambda: complex)\n",
"\n",
"\n",
"# `unknown` is undefined, so this class is not registered and `Base('u')`\n",
"# raises NotImplementedError; defining e.g. `unknown = str` before this class\n",
"# statement would register it for str instead.\n",
"class Unknown(Base):\n",
"    base_class = staticmethod(lambda: unknown)\n",
"\n",
"\n",
"i = Base(2)\n",
"f = Base(4.0)\n",
"c = Base(2+3j)\n",
"\n",
"i.who()\n",
"f.who()\n",
"c.who()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.4"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment