Skip to content

Instantly share code, notes, and snippets.

@kalefranz
Created May 5, 2015 21:45
Show Gist options
  • Save kalefranz/a386d747b7b404d29450 to your computer and use it in GitHub Desktop.
Save kalefranz/a386d747b7b404d29450 to your computer and use it in GitHub Desktop.
Extended Python Enum
"""The Enum class here is a python27 port of the python34 Enum, but extended with extra
functionality.
TODO: The API should be a strict superset of the python34 Enum.
"""
from transcomm.common.crypt import as_base64, from_base64
class EnumType(type):
    """MetaClass for Enum. See Enum for documentation.

    When a class using this metaclass is created, every attribute in its body
    whose name does not start with an underscore and which is not a property,
    classmethod, staticmethod or other callable is registered as a constant
    through ``cls.register_constant(name, value, item)``. A 2-element tuple or
    list is unpacked as ``(value, item)``; anything else is the value with no
    associated item.
    """

    def __init__(cls, name, bases, attr):
        super(EnumType, cls).__init__(name, bases, attr)
        # Copy (rather than alias) any inherited tables so constants registered
        # on a subclass are not written into its base class's tables.
        cls.__options__ = dict(cls.__options__) if hasattr(cls, '__options__') else dict()
        cls.__revoptions__ = dict(cls.__revoptions__) if hasattr(cls, '__revoptions__') else dict()
        cls.__items__ = dict(cls.__items__) if hasattr(cls, '__items__') else dict()
        # Snapshot the namespace: items() works on both Python 2 and 3
        # (iteritems() is Python 2 only), and the list copy keeps the loop
        # safe even if register_constant sets class attributes.
        for attr_name, packed in list(cls.__dict__.items()):
            value, item = EnumType.__unpack_options(packed)
            if not (attr_name.startswith('_')
                    or isinstance(value, (property, classmethod, staticmethod))
                    or callable(value)):
                cls.register_constant(attr_name, value, item)

    @staticmethod
    def __unpack_options(packed):
        """Split a class-body value into ``(value, item)``.

        A tuple or list of exactly two elements is ``(value, item)``; any other
        object is returned as the value with ``item`` set to None.
        """
        if isinstance(packed, (tuple, list)) and len(packed) == 2:
            value = packed[0]
            item = packed[1]
        else:
            value = packed
            item = None
        return value, item

    def __call__(cls, value):
        """``Cls('NAME')`` looks a constant up by name (see ``from_name``)."""
        return cls.from_name(value)

    @staticmethod
    def set_once(dikt, key, value):
        """Store ``dikt[key] = value``, refusing to overwrite an existing key.

        Raises:
            ValueError: if ``key`` is already present; ``dikt`` is left unchanged.
        """
        # Membership test instead of a __getitem__ probe: on mappings that
        # define __missing__ (e.g. collections.defaultdict) a probe would
        # insert a default entry and then always report a duplicate.
        if key in dikt:
            raise ValueError("duplicate key '{0}' found with value [{1}]".format(key, value))
        dikt[key] = value

    def __getitem__(cls, value):
        """``Cls[value]`` returns the item associated with ``value`` (KeyError if none)."""
        return cls.__items__[value]

    def __getattribute__(cls, value):
        """Resolve ``Cls.NAME`` to the registered constant's plain value.

        This unwraps ``NAME = value, item`` declarations so attribute access
        yields ``value`` rather than the raw tuple.
        """
        # Dunder names bypass the options table; this also prevents infinite
        # recursion when the lookup below reads cls.__options__ itself.
        if value.startswith('__'):
            return super(EnumType, cls).__getattribute__(value)
        elif value in cls.__options__:
            return cls.__options__[value]
        else:
            return super(EnumType, cls).__getattribute__(value)
class Enum(object):
    """Base class for enums declared as class attributes.

    Each constant of this Enum has a name, value, and optional associated object. The name
    can be any valid python variable name. The value can be any hashable object. The optional
    associated object can be anything.

    A constant is declared as a public class attribute. Writing ``NAME = value, item``
    (a 2-tuple) attaches ``item`` to ``value``, retrievable as ``Cls[value]``.
    Names and values must each be unique within an enum.

    Examples:
        >>> class Mynum(Enum):
        ...     ZERO = 0
        ...     ONE = 1
        ...     TWO = 'two'
        >>> Mynum.name(1)
        'ONE'
        >>> Mynum.from_name('ONE')
        1
        >>> Mynum('ONE')
        1
        >>> Mynum('ONE') == Mynum.ONE
        True
        >>> Mynum.is_valid_value(1)
        True
        >>> Mynum.from_value(3)
        Traceback (most recent call last):
          ...
        TypeError: Enum Mynum has no valid value 3.
        >>> Mynum(3)
        Traceback (most recent call last):
          ...
        TypeError: Enum Mynum has no 3.
        >>> Mynum.ONE
        1
        >>> type(Mynum.ONE)
        <type 'int'>
        >>> print repr(Mynum.ONE)
        1
        >>> print str(Mynum.ONE)
        1
        >>> class Mynum(Enum):
        ...     ONE = 1, str
        ...     TWO = 'two', complex
        >>> Mynum.ONE
        1
        >>> Mynum[1]
        <type 'str'>
        >>> Mynum.name(1)
        'ONE'
    """
    # Python 2 metaclass hook (this module is a Python 2.7 port).
    __metaclass__ = EnumType

    @classmethod
    def is_valid_name(cls, name):
        """Return True if ``name`` is a registered constant name."""
        try:
            cls.from_name(name)
            return True
        except TypeError:
            return False

    @classmethod
    def is_valid_value(cls, value):
        """Return True if ``value`` is a registered constant value."""
        try:
            cls.from_value(value)
            return True
        except TypeError:
            return False

    @classmethod
    def from_value(cls, value):
        """Return ``value`` unchanged if it is registered.

        Raises:
            TypeError: if ``value`` is not a registered value.
        """
        if value not in cls.__revoptions__:
            raise TypeError("Enum {0} has no valid value {1}.".format(cls.__name__, value))
        return value

    @classmethod
    def name(cls, value):
        """Return the constant name registered for ``value``.

        Raises:
            TypeError: if ``value`` is not a registered value.
        """
        try:
            return cls.__revoptions__[value]
        except KeyError:
            raise TypeError("Enum {0} has no valid value {1}.".format(cls.__name__, value))

    @classmethod
    def from_name(cls, name):
        """Return the value registered under constant ``name``.

        Raises:
            TypeError: if ``name`` is not a registered constant name.
        """
        try:
            return cls.__options__[name]
        except KeyError:
            raise TypeError("Enum {0} has no {1}.".format(cls.__name__, name))

    def __getitem__(self, value):
        # implemented on EnumType; here only for a stub to IDEs
        raise NotImplementedError()  # pragma: no cover

    @classmethod
    def register_item(cls, value, item):
        """Associate ``item`` with ``value``.

        Raises:
            ValueError: if ``value`` already has an item.
        """
        EnumType.set_once(cls.__items__, value, item)

    @classmethod
    def register_constant(cls, name, value, item=None):
        """Register constant ``name`` with ``value`` and an optional associated ``item``.

        The registration is all-or-nothing. If a later table rejects a
        duplicate, the entries this call already wrote are removed before the
        ValueError propagates. A failed call therefore leaves the enum unchanged.

        Raises:
            ValueError: if ``name`` or ``value`` is already registered, or if
                ``item`` is given and ``register_item`` rejects it.
        """
        EnumType.set_once(cls.__options__, name, value)
        try:
            EnumType.set_once(cls.__revoptions__, value, name)
        except ValueError:
            # Duplicate value: undo the name entry so is_valid_name stays False.
            del cls.__options__[name]
            raise
        if item is not None:
            try:
                cls.register_item(value, item)
            except ValueError:
                del cls.__options__[name]
                del cls.__revoptions__[value]
                raise

    # TODO: __str__ != __repr__
    # TODO: override __hash__ & __eq__
class BitSet(Enum):
    """Enum whose values are bit positions, with helpers to pack constants into a bitmask.

    A constant with value ``n`` occupies bit ``n`` (i.e. ``1 << n``) of the
    integer mask. The mask can also be round-tripped through base64 text.

    Examples:
        >>> class Mybits(BitSet):
        ...     PREPOSITION = 0
        ...     NOUN = 1
        ...     VERB = 2
        >>> Mybits.values_to_int((Mybits.NOUN, Mybits.VERB))
        6
        >>> Mybits.values_to_int((Mybits.NOUN, Mybits.NOUN))
        2
        >>> Mybits.names_to_int(('VERB', 'PREPOSITION'))
        5
        >>> Mybits.int_to_values(5)
        [0, 2]
        >>> Mybits.int_to_names(6)
        ['NOUN', 'VERB']
        >>> Mybits.names_to_base64(('VERB', 'PREPOSITION'))
        'NQ=='
        >>> Mybits.base64_to_values('NQ==')
        [0, 2]
        >>> Mybits.values_to_base64((0, 1))
        'Mw=='
        >>> Mybits.base64_to_names('Mw==')
        ['PREPOSITION', 'NOUN']
    """
    # TODO: Enforce values being ints

    @classmethod
    def values_to_int(cls, iterator):
        """Return the bitmask with bit ``v`` set for every value ``v`` in ``iterator``.

        Repeated values set their bit only once.

        Raises:
            ValueError: if a value is negative.
            TypeError: if a value is not an integer.
        """
        mask = 0
        for value in iterator:
            # Bitwise OR, not addition: adding 2**v twice for a repeated value
            # carries into bit v+1 and silently selects a different member.
            mask |= 1 << value
        return mask

    @classmethod
    def int_to_values(cls, integer):
        """Return the positions of the set bits of ``integer``, lowest first.

        Raises:
            ValueError: if ``integer`` is negative.
        """
        binary = "{0:b}".format(integer)
        if binary.startswith('-'):
            raise ValueError("BitSet integer must be non-negative, got {0}".format(integer))
        # Reverse so each enumerate index equals its bit position (LSB first).
        return [position for position, digit in enumerate(binary[::-1]) if digit == '1']

    @classmethod
    def names_to_int(cls, iterator):
        """Return the bitmask for an iterable of constant names."""
        return cls.values_to_int(cls.from_name(x) for x in iterator)

    @classmethod
    def int_to_names(cls, integer):
        """Return the constant names for the set bits of ``integer``, lowest bit first."""
        return [cls.name(n) for n in cls.int_to_values(integer)]

    @classmethod
    def _int_to_base64(cls, integer):
        """Encode the decimal text of ``integer`` with ``as_base64``."""
        return as_base64(str(integer))

    @classmethod
    def values_to_base64(cls, iterator):
        """Return the base64 encoding of the bitmask for ``iterator``'s values."""
        return cls._int_to_base64(cls.values_to_int(iterator))

    @classmethod
    def names_to_base64(cls, iterator):
        """Return the base64 encoding of the bitmask for ``iterator``'s names."""
        return cls._int_to_base64(cls.names_to_int(iterator))

    @classmethod
    def _base64_to_int(cls, base64):
        """Decode text produced by ``_int_to_base64`` back to an int."""
        return int(from_base64(base64))

    @classmethod
    def base64_to_values(cls, base64):
        """Return the set bit positions encoded in ``base64``."""
        return cls.int_to_values(cls._base64_to_int(base64))

    @classmethod
    def base64_to_names(cls, base64):
        """Return the constant names encoded in ``base64``."""
        return cls.int_to_names(cls._base64_to_int(base64))
from testtools import TestCase, ExpectedException
from transcomm.common.util.enum import EnumType, Enum
class EnumTypeTests(TestCase):
    """Unit tests for EnumType's static helpers."""

    def test_set_once(self):
        d = {}
        EnumType.set_once(d, 'key', 10)
        self.assertEqual({'key': 10}, d)
        with ExpectedException(ValueError):
            EnumType.set_once(d, 'key', 22)
        # A rejected duplicate must not overwrite the stored value.
        self.assertEqual({'key': 10}, d)
class ThisEnum(Enum):
    # Fixture enum for EnumTests: three int constants, 1 through 3.
    ONE = 1
    TWO = 2
    THREE = 3
class EnumTests(TestCase):
    """Tests for Enum's lookup classmethods and metaclass call, using ThisEnum."""

    def test_is_valid_name(self):
        self.assertTrue(ThisEnum.is_valid_name('ONE'))
        self.assertTrue(ThisEnum.is_valid_name('TWO'))
        self.assertTrue(ThisEnum.is_valid_name('THREE'))
        self.assertFalse(ThisEnum.is_valid_name('ON'))
        self.assertFalse(ThisEnum.is_valid_name('NE'))
        self.assertFalse(ThisEnum.is_valid_name('one'))
        self.assertFalse(ThisEnum.is_valid_name('FOUR'))

    def test_is_valid_value(self):
        self.assertTrue(ThisEnum.is_valid_value(1))
        self.assertTrue(ThisEnum.is_valid_value(2))
        self.assertTrue(ThisEnum.is_valid_value(3))
        self.assertFalse(ThisEnum.is_valid_value('1'))
        self.assertFalse(ThisEnum.is_valid_value(0))
        self.assertFalse(ThisEnum.is_valid_value(4))
        self.assertFalse(ThisEnum.is_valid_value('ONE'))
        self.assertFalse(ThisEnum.is_valid_value("FOUR"))

    def test_from_value(self):
        self.assertEqual(1, ThisEnum.from_value(1))
        self.assertEqual(2, ThisEnum.from_value(2))
        self.assertEqual(3, ThisEnum.from_value(3))
        with ExpectedException(TypeError):
            ThisEnum.from_value(4)

    def test_name(self):
        self.assertEqual('ONE', ThisEnum.name(1))
        self.assertEqual('TWO', ThisEnum.name(2))
        self.assertEqual('THREE', ThisEnum.name(3))
        with ExpectedException(TypeError):
            ThisEnum.name(4)

    def test_from_name(self):
        self.assertEqual(1, ThisEnum.from_name('ONE'))
        self.assertEqual(2, ThisEnum.from_name('TWO'))
        self.assertEqual(3, ThisEnum.from_name('THREE'))
        with ExpectedException(TypeError):
            ThisEnum.from_name('FOUR')

    def test_call_looks_up_by_name(self):
        # Calling the class goes through EnumType.__call__ -> from_name.
        self.assertEqual(ThisEnum.TWO, ThisEnum('TWO'))
        with ExpectedException(TypeError):
            ThisEnum(2)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment