Skip to content

Instantly share code, notes, and snippets.

@pyrocat101
Created June 3, 2015 00:22
Show Gist options
  • Save pyrocat101/3627f9722346fb3af371 to your computer and use it in GitHub Desktop.
Save pyrocat101/3627f9722346fb3af371 to your computer and use it in GitHub Desktop.
Record: a lightweight data container whose fields are accessed by name.
class Record(type):
    """
    Metaclass for a class with named fields.

    A class created by this metaclass must define ``__fields__``, a sequence
    of field names.  Instances are constructed from exactly one value per
    field, given positionally (in ``__fields__`` order) and/or by keyword.
    The metaclass always installs ``__new__`` and ``__init__``; it also
    installs ``__repr__``, ``__iter__``, ``__eq__``, ``__ne__`` and
    ``to_dict`` unless the class body already defines them.

    The explicit ``Record(...)`` call below works on Python 2 and 3.  In a
    class statement, use ``__metaclass__ = Record`` (Python 2) or
    ``class Point(object, metaclass=Record)`` (Python 3).

    >>> Point = Record('Point', (object,), {'__fields__': ('x', 'y')})
    >>> p = Point(11, y=22)  # instantiate with positional args or keywords
    >>> p.x + p.y  # access fields by name
    33
    >>> list(p)  # iterable by the order of fields
    [11, 22]
    >>> d = p.to_dict()  # convert to a dictionary
    >>> d['x']
    11
    >>> Point(11, 22) == Point(11, y=22)  # support comparison
    True
    >>> Point(11, 22) != Point(22, 11)
    True
    >>> p = Point(**d)  # convert from a dictionary
    >>> p  # readable repr
    Point(x=11, y=22)
    """

    def __new__(mcs, classname, bases, classdict):
        """Build a record class from ``classdict``.

        Raises:
            ValueError: if ``classdict`` does not define ``__fields__``.
        """
        if '__fields__' not in classdict:
            raise ValueError("Record must have __fields__ attribute.")

        def _new(clazz, *args, **kwargs):
            # Reject a wrong argument count before any instance is created.
            expected = len(clazz.__fields__)
            received = len(args) + len(kwargs)
            if received != expected:
                raise TypeError(
                    '__new__() takes exactly %d arguments (%d given)'
                    % (expected, received))
            # `cls` is the class built below; the closure reads it at call time.
            parent_new = super(cls, clazz).__new__
            if parent_new is object.__new__:
                # object.__new__ rejects extra arguments once __new__ is
                # overridden (TypeError on Python 3, DeprecationWarning on
                # Python 2.6+), so pass it only the class.
                return parent_new(clazz)
            return parent_new(clazz, *args, **kwargs)

        def _init(self, *args, **kwargs):
            # Positional values fill fields in declaration order; keyword
            # values must name a declared field not already filled positionally.
            fields = self.__fields__
            for field, value in zip(fields, args):
                setattr(self, field, value)
            positional = fields[:len(args)]
            for field, value in kwargs.items():
                if field not in fields:
                    raise ValueError(
                        "__init__() got an unexpected keyword argument '%s'"
                        % field)
                if field in positional:
                    raise TypeError(
                        "__init__() got multiple values for argument '%s'"
                        % field)
                setattr(self, field, value)

        # Install the shared methods, keeping any the class body defines itself.
        for name, method in mcs.METHODS.items():
            classdict.setdefault(name, method)
        classdict['__new__'] = _new
        classdict['__init__'] = _init
        cls = type.__new__(mcs, classname, bases, classdict)
        return cls

    def _repr(self):
        """Return ``ClassName(field=repr(value), ...)`` in field order."""
        return "%s(%s)" % (
            self.__class__.__name__,
            ', '.join('%s=%r' % (field, getattr(self, field))
                      for field in self.__fields__),
        )

    def _eq(self, y):
        """Equal iff ``y`` has exactly the same type and equal field values."""
        if type(self) != type(y):
            return False
        return list(self) == list(y)

    def _ne(self, y):
        """Negation of ``==``; Python 2 does not derive ``!=`` from ``__eq__``."""
        return not self == y

    def _iter(self):
        """Yield field values in ``__fields__`` order."""
        for name in self.__fields__:
            yield getattr(self, name)

    def _to_dict(self):
        """Return a new ``{field: value}`` dict."""
        return {field: getattr(self, field) for field in self.__fields__}

    # Methods copied into every record class (unless the class defines them).
    METHODS = {
        '__repr__': _repr,
        '__iter__': _iter,
        '__eq__': _eq,
        '__ne__': _ne,
        'to_dict': _to_dict,
    }
if __name__ == '__main__':
    # When executed as a script, run the docstring examples as tests.
    from doctest import testmod
    testmod()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment