-
-
Save huzecong/df51502a8a6ec0bcc0e605a2ce109008 to your computer and use it in GitHub Desktop.
# Copyright (c) 2021 Zecong Hu
#
# Permission to use, copy, modify, and/or distribute this software for any
# purpose with or without fee is hereby granted.
#
# THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES WITH
# REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY
# AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY SPECIAL, DIRECT,
# INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM
# LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR
# OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR
# PERFORMANCE OF THIS SOFTWARE.
import collections
import typing

# Names exported by `from <module> import *`.
__all__ = [
    "Options",
]
class OptionsMeta(typing.NamedTupleMeta):
    """Metaclass producing keyword-only named tuples that can be subclassed.

    Fields are collected from the annotations of the class being created and
    from the `_fields` of any base that is a concrete `Options` subclass.
    Fields without a default are placed (alphabetically) before fields with a
    default (also alphabetically), and the generated `__new__` accepts
    keyword arguments only.

    NOTE(review): this relies on private behaviour of `typing.NamedTupleMeta`
    (the `_root` marker and passing `bases=None`), which appears to hold only
    on some Python versions -- confirm against the targeted interpreter.
    """

    def __new__(mcs, typename, bases, namespace):
        """Create the options class.

        Raises:
            TypeError: if two base classes contribute the same field and the
                new class does not redefine it.
            ValueError: if the resulting class would have no fields.
        """
        if namespace.get('_root', False):
            # The created class is `Options`, skip.
            return super().__new__(mcs, typename, bases, namespace)

        # Gather fields from annotations of current class and base classes.
        cur_fields = namespace.get('__annotations__', {})
        fields = {}
        field_sources = {}  # which base class each inherited name came from
        field_defaults = {}
        for base in bases:
            if issubclass(base, Options) and hasattr(base, '_fields'):
                # Base class is a concrete subclass of `Options`.
                for name in base._fields:
                    if name in cur_fields:
                        # Make sure not to overwrite redefined fields.
                        continue
                    if name in fields:
                        # Overlapping field that is not redefined.
                        raise TypeError(
                            f"Base class {base} contains field {name}, which "
                            f"is defined in other base class "
                            f"{field_sources[name]}")
                    fields[name] = base.__annotations__[name]
                    field_sources[name] = base
                    if name in base._field_defaults:
                        field_defaults[name] = base._field_defaults[name]
        fields.update(cur_fields)
        if not fields:
            raise ValueError("Options class must contain at least one field")
        # Inherited defaults apply only where the new class sets no value.
        for name, value in field_defaults.items():
            namespace.setdefault(name, value)

        # Reorder fields to put those without default values in front.
        fields_with_default = [name for name in fields if name in namespace]
        reordered_fields = (
            sorted(set(fields).difference(fields_with_default)) +
            sorted(fields_with_default))
        namespace['__annotations__'] = collections.OrderedDict(
            [(name, fields[name]) for name in reordered_fields])

        # Let `NamedTupleMeta` create an annotated `namedtuple` for us.
        # Note that `bases` is not used here so we just set it to `None`.
        nm_tpl = super().__new__(mcs, typename, None, namespace)

        # Rewrite `__new__` method to make all arguments keyword-only.
        # This is very hacky code. Do not try this at home.
        # Every name is followed by ", " so that a single field still yields
        # a valid one-element tuple in the generated `return` statement.
        arg_list = ''.join(name + ', ' for name in reordered_fields)
        source = "\n".join([
            f"def __new__(_cls, *args, {arg_list}):",
            "    if len(args) > 0:",
            "        raise TypeError("
            "\"Instances of Options class must be created \"",
            "                        \"with keyword arguments.\")",
            f"    return _tuple_new(_cls, ({arg_list}))",
        ])
        new_method_namespace = {'_tuple_new': tuple.__new__,
                                '__name__': f'namedtuple_{typename}'}
        exec(source, new_method_namespace)
        __new__ = new_method_namespace['__new__']
        __new__.__qualname__ = f"{typename}.__new__"
        __new__.__doc__ = nm_tpl.__new__.__doc__
        __new__.__annotations__ = nm_tpl.__new__.__annotations__
        __new__.__kwdefaults__ = {name: namespace[name]
                                  for name in fields_with_default}
        nm_tpl.__new__ = __new__

        # Wrap the return type in `OptionsMeta` so it can be subclassed.
        new_namespace = nm_tpl.__dict__.copy()
        new_namespace['_bases'] = bases
        # Also keep base classes of the `namedtuple` (i.e., the `tuple` class),
        # so we can call `tuple.__new__`.
        options_type = type.__new__(
            mcs, typename, nm_tpl.__bases__, new_namespace)
        # NOTE(review): re-assigning `__bases__` looks intended to make the
        # interpreter recompute the MRO through `mro()` now that `_bases` is
        # set on the class -- confirm.
        options_type.__bases__ = tuple(options_type.__bases__)
        return options_type

    def mro(cls):
        """Return the MRO, splicing in the user-declared bases if present."""
        default_mro = super().mro()
        # `Options` does not define `_bases`, so we don't do anything about it.
        if hasattr(cls, '_bases'):
            # `default_mro` should be `[cls, tuple, object]`.
            # `c3merge` and `c3mro` are implementations of the C3 linearization
            # algorithm, which unluckily aren't provided as APIs.
            return c3merge([
                default_mro[:1],
                *[base.__mro__ for base in cls._bases],
                default_mro[1:]])
        return default_mro
class Options(metaclass=OptionsMeta):
    """Root base class for option tuples; usable only as a base class.

    The `_root` marker tells `OptionsMeta.__new__` to skip field processing
    for this class.
    """
    _root = True

    def __new__(cls, *args, **kwargs):
        """Create an instance, refusing to instantiate `Options` itself.

        Raises:
            TypeError: if called on `Options` directly.
        """
        # Copied from typing.Generic.
        if cls is Options:
            # Prevent instantiation of `Options` class.
            raise TypeError("Type Options cannot be instantiated; "
                            "it can be used only as a base class")
        if (super().__new__ is object.__new__ and
                cls.__init__ is not object.__init__):
            # `object.__new__` rejects extra arguments when `__init__` is
            # overridden, so do not forward them.
            obj = super().__new__(cls)
        else:
            obj = super().__new__(cls, *args, **kwargs)
        return obj
def c3merge(sequences):
    r"""Merge linearizations with the C3 algorithm.

    Adapted from https://www.python.org/download/releases/2.3/mro/

    Args:
        sequences: An iterable of sequences to merge. The inputs are copied
            and never mutated.

    Returns:
        A list containing every element of the inputs once, in an order
        consistent with each input sequence.

    Raises:
        TypeError: if no consistent ordering exists.
    """
    # Make sure we don't actually mutate anything we are getting as input.
    sequences = [list(x) for x in sequences]
    result = []
    while True:
        # Clear out blank sequences.
        sequences = [x for x in sequences if x]
        if not sequences:
            return result
        # Find the first clean head.
        for seq in sequences:
            head = seq[0]
            # If this is not a bad head (i.e., not in the tail of any sequence)
            if not any(head in s[1:] for s in sequences):
                break
        else:
            # Every candidate head appears in the tail of some sequence.
            raise TypeError("inconsistent hierarchy")
        # Move the head from the front of all sequences to the end of results.
        result.append(head)
        for seq in sequences:
            if seq[0] == head:
                del seq[0]
There's no built-in way to support inheritance for `NamedTuple`s per se, mostly because it's still a tuple that supports `__getitem__` via an index, and there's no well-defined behavior for adding fields to a tuple. But, as I said in the previous comments, you could consider using `attrs` or `dataclasses` for more or less the same functionality. The `attrs` docs even have a page that compares it with namedtuples.
Thanks for your reply! Actually, I stumbled over this code by searching for `NamedTuple` with inheritance support. Any recommendation for that?
@MatthiasLohr A while back I experimented with that too. As far as I remember, I got it to work on Python 3.7+. Ultimately, I ended up going with dataclasses since support for them is simply much better compared to a custom NamedTuple implementation. If it helps you, here's what I did. No guarantees though 🙂
"""Requires Python 3.7 -> preserve dict insertion order"""
from __future__ import annotations
import sys
import typing
# Names a class body is not allowed to define; the metaclass below raises
# AttributeError when it meets one of them.
_prohibited = frozenset((
    '__new__',
    '__init__',
    '__slots__',
    '__getnewargs__',
    '_fields',
    '_field_defaults',
    '_make',
    '_replace',
    '_asdict',
    '_source',
))
# Class-body keys that are never copied onto the generated namedtuple.
_special = frozenset(('__module__', '__name__', '__annotations__'))
class NamedTupleMeta(type):
    """Metaclass that turns an annotated class body into a namedtuple.

    The `bases` argument is ignored; the result is whatever
    `typing._make_nmtuple` returns, with the non-field attributes of the
    class body copied onto it.
    """

    def __new__(cls, typename, bases, ns):
        """Build the namedtuple class from the class namespace `ns`.

        Raises:
            TypeError: if a field without a default follows one with a default.
            AttributeError: if `ns` defines a name listed in `_prohibited`.
        """
        types = ns.get('__annotations__', {})
        default_names = []
        for field_name in types:
            if field_name in ns:
                default_names.append(field_name)
            elif default_names:
                raise TypeError(f"Non-default namedtuple field {field_name} "
                                f"cannot follow default field"
                                f"{'s' if len(default_names) > 1 else ''} "
                                f"{', '.join(default_names)}")
        defaults = tuple(ns[n] for n in default_names)
        # `typing._make_nmtuple` is a private helper that is called with a
        # different set of arguments depending on the interpreter version.
        if sys.version_info >= (3, 9):
            nm_tpl = typing._make_nmtuple(typename, types.items(),
                                          defaults=defaults,
                                          module=ns['__module__'])
        else:
            nm_tpl = typing._make_nmtuple(typename, types.items())
            nm_tpl.__new__.__annotations__ = dict(types)
            nm_tpl.__new__.__defaults__ = defaults
            nm_tpl._field_defaults = {n: ns[n] for n in default_names}
        # update from user namespace without overriding special namedtuple
        # attributes
        for key in ns:
            if key in _prohibited:
                raise AttributeError(
                    "Cannot overwrite NamedTuple attribute " + key)
            if key not in _special and key not in nm_tpl._fields:
                setattr(nm_tpl, key, ns[key])
        return nm_tpl
class OptionsMeta(NamedTupleMeta):
    """Metaclass adding field inheritance on top of `NamedTupleMeta`.

    Fields are gathered from every base that exposes `_fields` and from the
    annotations of the class being created; fields without a default are
    ordered before fields with one.
    """

    def __new__(cls, typename, bases, ns):
        """Create the options class.

        Raises:
            TypeError: if two base classes contribute the same field and the
                new class does not redefine it.
            ValueError: if the resulting class would have no fields.
        """
        cur_fields = ns.get("__annotations__", {})
        fields = {}
        field_sources = {}  # which base class each inherited name came from
        field_defaults = {}
        for base in bases:
            if hasattr(base, "_fields"):
                for name in base._fields:
                    if name in cur_fields:
                        # Don't overwrite redefined fields
                        continue
                    if name in fields:
                        # Overlapping field that is not redefined.
                        raise TypeError(
                            f"Base class {base} contains field {name}, which "
                            f"is defined in other base class "
                            f"{field_sources[name]}")
                    fields[name] = base.__annotations__[name]
                    field_sources[name] = base
                    if name in base._field_defaults:
                        field_defaults[name] = base._field_defaults[name]
        fields.update(cur_fields)
        if not fields:
            raise ValueError("Options class must contain at least one field")
        # Inherited defaults apply only where the new class sets no value.
        for name, value in field_defaults.items():
            ns.setdefault(name, value)

        # Reorder fields to put those without default values in front.
        fields_with_default = [name for name in fields if name in ns]
        annotations = {name: val for name, val in fields.items()
                       if name not in fields_with_default}
        annotations.update({name: val for name, val in fields.items()
                            if name in fields_with_default})
        ns["__annotations__"] = annotations

        # `NamedTupleMeta.__new__` ignores `bases`, so `None` is passed.
        nm_tpl = super().__new__(cls, typename, None, ns)
        # Keep the declared bases and append those of the generated namedtuple.
        bases = bases + nm_tpl.__bases__
        return type.__new__(cls, typename, bases, nm_tpl.__dict__.copy())
def Options():
    """Reject direct calls; `Options` is meant to be used as a base class.

    Raises:
        TypeError: always.
    """
    raise TypeError("Options can only be used as base class")
# Bare root class built directly through `type.__new__`, which bypasses
# `OptionsMeta.__new__`.
_Options = type.__new__(OptionsMeta, "Options", (), {})
# Listing `Options` in a class's bases substitutes `_Options` for it.
Options.__mro_entries__ = lambda _bases: (_Options,)
Thanks for your reply! Actually, I stumbled over this code by searching for `NamedTuple` with inheritance support. Any recommendation for that?