Created
November 3, 2018 20:18
-
-
Save skrungly/e0c539f172440d661225b90a8dd72c32 to your computer and use it in GitHub Desktop.
flag enums
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from numbers import Integral | |
class Flag(int):
    """An ``int`` subclass representing a single flag value.

    Each flag stores ``parent_cls``, the class that ``|`` uses to wrap the
    combined integer. When no parent is given, the flag's own class is used,
    so combining two plain flags yields another flag.
    """

    def __new__(cls, value: int, parent_cls: type = None):
        """Create a flag with integer ``value``.

        Args:
            value: the integer value of the flag.
            parent_cls: callable class invoked with the combined integer by
                ``__or__``; defaults to ``cls`` when omitted.
        """
        flag = super().__new__(cls, value)
        if parent_cls is None:
            parent_cls = cls
        flag.parent_cls = parent_cls
        return flag

    def __or__(self, other):
        """Return ``parent_cls(int(self) | other)``.

        ``other`` may be an ``Integral`` or any object whose type defines
        ``__int__`` (for example a ``FlagEnum`` instance). For anything else,
        ``NotImplemented`` is returned so Python raises its usual TypeError.
        """
        if not isinstance(other, Integral):
            # int.__or__ only accepts ints, so an object that merely defines
            # __int__ would otherwise make the expression fail with TypeError.
            if not hasattr(type(other), "__int__"):
                return NotImplemented
            other = int(other)
        return self.parent_cls(int(self) | other)

    def __str__(self):
        # Render as the bare integer value.
        return str(int(self))

    def __repr__(self):
        cls_name = type(self).__name__
        return "{0}(value={1}, parent_cls={2})".format(
            cls_name, self, self.parent_cls.__name__
        )
class MetaFlags(type):
    """Metaclass that validates integer class attributes and wraps them as flags.

    When a class is created, each name returned by ``dir()`` on it (own and
    inherited attributes alike) that does not begin with ``_`` and whose value
    is an ``Integral`` is treated as a flag. Each such flag:

    * must be a positive power of two;
    * has its raw value recorded in the class's ``_flags_`` dict
      (name -> value);
    * is replaced on the class by ``Flag(value, new_class)``.

    Other public attributes are left untouched.

    Raises:
        ValueError: if a flag value is below 1 or is not a power of two.
    """

    def __new__(cls, name, bases, namespace, **kwargs):
        new_cls = super().__new__(cls, name, bases, namespace, **kwargs)
        registry = {}

        public_names = (n for n in dir(new_cls) if not n.startswith("_"))
        for flag_name in public_names:
            raw = getattr(new_cls, flag_name)
            if not isinstance(raw, Integral):
                continue  # methods, strings, floats, ... are not flags

            bits = int(raw)
            # A positive integer is a power of two iff exactly one bit is
            # set, i.e. clearing its lowest set bit leaves zero.
            if raw < 1 or bits & (bits - 1):
                raise ValueError(
                    "flag {0} has an invalid value ({1})".format(flag_name, raw)
                )

            registry[flag_name] = raw
            setattr(new_cls, flag_name, Flag(raw, new_cls))

        new_cls._flags_ = registry
        return new_cls
class FlagEnum(metaclass=MetaFlags):
    """Base class for flag sets; an instance wraps an integer bitmask.

    Subclasses declare flags as public integer class attributes, which
    ``MetaFlags`` validates and records in ``_flags_``. On an instance,
    reading a flag's name returns ``True`` if any of that flag's bits are set
    in the wrapped mask, else ``False``. Instances combine with ``|`` on
    either side.
    """

    def __init__(self, int_flags):
        """Wrap ``int_flags`` as this set's bitmask.

        A ``FlagEnum`` argument is unwrapped to its integer value, so that
        ``__int__`` always returns the stored mask rather than a nested
        instance (which ``int()`` would reject).
        """
        if isinstance(int_flags, FlagEnum):
            int_flags = int(int_flags)
        self._int_flags = int_flags

    def __int__(self):
        return self._int_flags

    def __or__(self, other):
        """Return a new instance of this class with the bits of ``other`` added."""
        cls = type(self)
        return cls(int(self) | int(other))

    # ``|`` is commutative here. Defining the reflected form lets
    # ``int | FlagEnum`` work: int.__or__ returns NotImplemented for
    # non-int operands, so Python falls back to this method.
    __ror__ = __or__

    def __getattribute__(self, name):
        # Every instance attribute lookup passes through here. ``_flags_`` is
        # fetched with object.__getattribute__ to avoid recursing into this
        # method.
        _flags_ = object.__getattribute__(self, "_flags_")
        if name in _flags_:
            # Flag names are answered with a bool: is the flag's bit set?
            flag_int = _flags_[name] & int(self)
            return bool(flag_int)
        return object.__getattribute__(self, name)

    def __repr__(self):
        cls_name = type(self).__name__
        return "{0}(int_flags={1})".format(cls_name, self._int_flags)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment