Created January 17, 2013 22:18
-
-
Save michaelbartnett/4560351 to your computer and use it in GitHub Desktop.
mgeutils makes MongoEngine more fun to use ;P
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""mgeutils module | |
Decorators and convenience functions for using mongoengine | |
""" | |
from __future__ import unicode_literals | |
from __future__ import absolute_import | |
from __future__ import print_function | |
from __future__ import division | |
import logging | |
import bson | |
import mongoengine as mge | |
import datetime | |
import dateutil.parser as dateparser | |
import utils | |
# Values of these types are serialized as their unicode() text.
_MGE_SERIALIZE_TO_UNICODE = (
    bson.ObjectId,
)

# Instances of these document types are serialized through the to_dict()
# method that @superdoc mixes in.
_MGE_SERIALIZABLE_DOCUMENT = (
    mge.Document,
    mge.EmbeddedDocument,
)
def _mge_convert_to_json_builtin(obj):
    """Convert a MongoEngine-related object into something JSON can encode.

    Intended as the ``default`` hook for JSON encoding:

    * Documents / EmbeddedDocuments are converted with their ``to_dict()``
      method (mixed in by ``@superdoc``). A document without ``to_dict`` is
      logged as an error and handed to the generic converter instead.
    * ``bson.ObjectId`` values are converted with ``unicode()``.
    * ``mongoengine.ValidationError`` becomes a dict of whichever of
      ``errors``, ``field_name`` and ``message`` are truthy; string values
      that start with ``'ValidationError('`` are skipped.
    * Anything else goes to ``utils._convert_to_json_builtin``.
    """
    if isinstance(obj, _MGE_SERIALIZABLE_DOCUMENT):
        if not hasattr(obj, 'to_dict'):
            logging.error("Got a MongoEngine Document, but it wasn't a superdoc")
            return utils._convert_to_json_builtin(obj)
        return obj.to_dict()

    if isinstance(obj, _MGE_SERIALIZE_TO_UNICODE):
        return unicode(obj)

    if isinstance(obj, mge.ValidationError):
        summary = {}
        for attr_name in ('errors', 'field_name', 'message'):
            attr_value = getattr(obj, attr_name)
            # Skip strings that merely echo the exception's repr.
            is_repr_echo = (isinstance(attr_value, basestring) and
                            attr_value.startswith('ValidationError('))
            if not is_repr_echo and attr_value:
                summary[attr_name] = attr_value
        return summary

    return utils._convert_to_json_builtin(obj)
def json_encode(value, *args, **kwargs):
    """JSON-encode ``value``, with support for mongoengine Documents.

    Works like ``utils.json_encode``, but uses ``_mge_convert_to_json_builtin``
    as the fallback serializer unless the caller passes a truthy ``default``.
    All other positional and keyword arguments (e.g. ``indent``,
    ``sort_keys``) are forwarded to ``utils.json_encode``.
    """
    default = kwargs.pop('default', None) or _mge_convert_to_json_builtin
    # Forward the remaining arguments; they used to be silently dropped, so
    # options such as indent=2 had no effect.
    return utils.json_encode(value, *args, default=default, **kwargs)
class LegitDateTimeField(mge.DateTimeField):
    """DateTimeField that also accepts anything dateutil can parse.

    ``datetime``/``date`` objects are valid as-is; any other value must be
    accepted by ``dateutil.parser.parse``.
    """

    def validate(self, value):
        """Reject values that are neither dates nor dateutil-parseable.

        On failure calls ``self.error(...)`` (mongoengine's field hook for
        reporting a validation failure).
        """
        if isinstance(value, (datetime.datetime, datetime.date)):
            return
        try:
            dateparser.parse(value)
        except Exception:
            # Catch Exception instead of using a bare ``except:`` so that
            # KeyboardInterrupt / SystemExit are not converted into
            # validation errors. The net is deliberately wide because the
            # exception type dateutil raises for bad input varies with the
            # input.
            self.error('Could not parse date {0}'.format(value))

    def prepare_query_value(self, op, value):
        """Normalize ``value`` for use in a query.

        ``None`` and ``datetime`` values pass through unchanged; a plain
        ``date`` becomes a ``datetime`` at midnight of that day. Anything else
        must be parseable by dateutil and is then handed to the parent
        implementation.
        """
        if value is None or isinstance(value, datetime.datetime):
            return value
        if isinstance(value, datetime.date):
            return datetime.datetime(value.year, value.month, value.day)
        # NOTE(review): the parsed result is discarded, so this call appears
        # to act only as a validation step (an unparseable value raises here
        # before reaching the parent class). Confirm that returning the
        # parsed datetime was not the intent.
        dateparser.parse(value)
        return super(LegitDateTimeField, self).prepare_query_value(op, value)
def __ensure_class_not_hasattr(cls, attr_name):
    """Fail if ``cls`` already has an attribute called ``attr_name``.

    ``superdoc`` calls this before mixing in its helpers so that it never
    silently overwrites something the class already defines.

    Raises:
        AssertionError: if ``cls`` has an attribute named ``attr_name``.
    """
    # An explicit raise (rather than ``assert``) keeps the guard active under
    # ``python -O``; the exception type callers see is unchanged.
    if hasattr(cls, attr_name):
        raise AssertionError(
            'Class {0} already has {1} attribute. Superdoc needs that name.'
            ''.format(getattr(cls, '__name__', cls), attr_name))
def superdoc(cls):  # Decorator
    """Class decorator that mixes helper methods into a MongoEngine document.

    Adds to ``cls``:

    * ``field_names()`` -- classmethod yielding every name in ``cls._fields``.
    * ``get_field_set(include=..., exclude=...)`` -- classmethod returning the
      frozenset of selected field names; raises ``ValueError`` if the
      selection names a field ``cls`` does not define.
    * ``update_fields(**kwargs)`` -- assigns each keyword naming a field other
      than ``id``; other keys are ignored.
    * ``to_dict(include=..., exclude=..., include_nulls=True,
      include_empties=True)`` -- dict of field values, with reference fields
      reported by id.
    * ``reference_fields`` -- instance property: the set of field names whose
      field type is exactly ``mongoengine.ReferenceField``.
    * ``__getattr__`` -- resolves ``<ref>__id`` and ``<ref>__dbref``; any other
      missing attribute is delegated to the class's previous ``__getattr__``
      when it had one.

    Raises:
        AssertionError: if ``cls`` is not a mongoengine document class, or it
            already has one of the helper names above.
    """
    # Make some guarantees. Why decorate if you've
    # already defined your helper methods?
    # (Explicit raise so the check survives ``python -O``.)
    if not isinstance(cls, mge.base.DocumentMetaclass):
        raise AssertionError(
            'Class {0} decorated by superdoc must be a mongoengine.Document'
            ''.format(getattr(cls, '__name__', cls)))
    for helper_name in ('field_names', 'reference_fields', 'to_dict',
                        'update_fields', 'get_field_set'):
        __ensure_class_not_hasattr(cls, helper_name)

    # Exact type match: subclasses of ReferenceField are not included.
    the_ref_fields = {
        name for name, field in cls._fields.items()
        if type(field) is mge.ReferenceField}

    # Remember any __getattr__ the class already had so lookups this mixin
    # does not handle can still reach it.
    old_getattr = getattr(cls, '__getattr__', None)

    @classmethod
    @utils.restrict_kwargs('include', 'exclude')
    def field_names(cls, **kwargs):
        """Yield the name of every field declared on the document class.

        ``include``/``exclude`` are accepted but currently ignored; use
        ``get_field_set`` for filtering.
        """
        for name in cls._fields:
            yield name

    def __getattr__(self, attr_name):
        """Resolve ``<ref>__id`` / ``<ref>__dbref`` on reference fields.

        Python calls this only after normal attribute lookup fails.

        * ``<ref>__id``: the referenced id, whether ``_data`` holds an
          ObjectId, a DBRef or a loaded Document.
        * ``<ref>__dbref``: the stored DBRef.
        * Either returns None when ``<ref>`` is a reference field with no
          value.

        Any other name, or a stored value of an unsupported type, goes to the
        class's original ``__getattr__`` if it had one; otherwise
        AttributeError is raised.
        """
        if attr_name.endswith('__id'):
            realattr = attr_name[:-len('__id')]
            ref_value = self._data.get(realattr, None)
            if ref_value is None and realattr in the_ref_fields:
                # Sometimes a ReferenceField may be present but not set.
                return None
            if type(ref_value) is bson.ObjectId:
                return ref_value
            if type(ref_value) is bson.DBRef:
                return ref_value.id
            if isinstance(ref_value, mge.Document):
                # NOTE(review): once a reference has been dereferenced,
                # mongoengine appears to keep the Document itself in _data.
                # Returning its pk keeps to_dict() working in that case;
                # confirm against the mongoengine version in use.
                return ref_value.pk
        elif attr_name.endswith('__dbref'):
            realattr = attr_name[:-len('__dbref')]
            ref_value = self._data.get(realattr, None)
            if ref_value is None and realattr in the_ref_fields:
                return None
            # Only a stored DBRef can be returned as-is. (This previously read
            # ``type(ref_field is bson.DBRef)``, which is always truthy.)
            if type(ref_value) is bson.DBRef:
                return ref_value
        if old_getattr is not None:
            return old_getattr(self, attr_name)
        raise AttributeError(
            "'{0}' object has no attribute '{1}'"
            "".format(type(self).__name__, attr_name))

    def update_fields(self, **kwargs):
        """Assign each keyword argument that names a field of this document.

        ``id`` is never assigned; keys that are not fields are ignored.
        """
        assignable = self.get_field_set(exclude='id')
        for name, value in kwargs.items():
            if name in assignable:
                setattr(self, name, value)

    def _as_name_set(names):
        """Return ``names`` as a frozenset, wrapping a lone name (or None)."""
        if utils.is_iter_not_str(names):
            return frozenset(names)
        return frozenset((names,))

    @classmethod
    @utils.restrict_kwargs('include', 'exclude')
    def get_field_set(cls, **kwargs):
        """Return the frozenset of field names chosen by include/exclude.

        ``include`` defaults to every field. Both options accept one name or
        an iterable of names.

        Raises:
            ValueError: if the result contains a name that is not a field of
                ``cls``.
        """
        defined_fields = frozenset(cls.field_names())
        include_set = _as_name_set(kwargs.get('include', defined_fields))
        exclude_set = _as_name_set(kwargs.get('exclude', None))
        field_set = include_set - exclude_set
        if not field_set <= defined_fields:
            raise ValueError(
                'The fields {0} are not defined in {1}'
                ''.format(field_set - defined_fields, cls))
        return field_set

    @utils.restrict_kwargs('include', 'exclude', 'include_nulls', 'include_empties')
    def to_dict(self, **kwargs):
        """Return the selected fields of this document as a plain dict.

        Keyword args:
            include / exclude: forwarded to ``get_field_set``.
            include_nulls (default True): keep fields whose value is None.
            include_empties (default True): keep falsy non-None values such
                as 0, '' or [].

        Fields rejected by include_nulls / include_empties are omitted.
        Reference fields are reported by id (via ``<name>__id``). Other
        non-None values go through their field's ``to_python``.
        """
        include_nulls = kwargs.pop('include_nulls', True)
        include_empties = kwargs.pop('include_empties', True)

        def keep(value):
            # None is governed by include_nulls, other falsy values by
            # include_empties; truthy values are always kept.
            if value is None:
                return include_nulls
            return bool(value) or include_empties

        result = {}
        for fieldname in self.get_field_set(**kwargs):
            if fieldname in the_ref_fields:
                value = getattr(self, '{0}__id'.format(fieldname))
                if keep(value):
                    result[fieldname] = value
                continue
            value = getattr(self, fieldname)
            if not keep(value):
                continue
            if value is None:
                result[fieldname] = None
            else:
                # Look the field up for every name (including 'id'); the old
                # code reused a stale ``field`` from a previous iteration.
                result[fieldname] = self._fields[fieldname].to_python(value)
        return result

    # A property only resolves on instances; reading ``cls.reference_fields``
    # on the class itself yields the property object.
    cls.reference_fields = property(lambda self: the_ref_fields)
    cls.__getattr__ = __getattr__
    cls.field_names = field_names
    cls.to_dict = to_dict
    cls.update_fields = update_fields
    cls.get_field_set = get_field_set
    return cls
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""utils module

Helper functions used throughout the project that are not tied to a
specific library. Apart from the stdlib (2.7), the only dependency is
``tornado.escape``, which ``json_encode`` uses.
"""
from __future__ import unicode_literals | |
from __future__ import absolute_import | |
from __future__ import print_function | |
from __future__ import division | |
import inspect | |
import functools | |
import json | |
import collections | |
import datetime | |
from tornado import escape | |
# str.format() template for the TypeError raised by restrict_kwargs() when a
# caller passes a keyword argument the decorated function does not accept.
_RESTRICT_KWARGS_MSG = (
    '{fn_name}() does not support a kwarg named {kwarg_name}\n'
    'Supported kwargs are: {supported_kwargs}\n'
    'File "{filename}", line {lineno}, in {caller_fn_name}\n'
    'Context:\n'
    '{context}\n'
)

# str.format() template for the TypeError raised by require_kwargs() when a
# required keyword argument is missing from the call.
_REQUIRE_KWARGS_MSG = (
    '{fn_name}() missing a kwarg named {kwarg_name}.\n'
    'Required kwargs are: {required_kwargs}\n'
    'File "{filename}", line {lineno}, in {caller_fn_name}\n'
    'Context:\n'
    '{context}\n'
)
try:  # Python 3.3+: the ABCs live in collections.abc
    from collections.abc import Iterable as _Iterable
except ImportError:  # Python 2.7
    _Iterable = collections.Iterable

try:
    _STRING_TYPES = (basestring,)
except NameError:  # Python 3 has no basestring
    _STRING_TYPES = (str, bytes)


def is_iter_not_str(arg):
    """Return True if ``arg`` is iterable but is not a string.

    Text and byte strings are iterable, but they are treated as scalars here.
    ``flatten`` relies on this to avoid splitting strings into characters.
    """
    return isinstance(arg, _Iterable) and not isinstance(arg, _STRING_TYPES)
def flatten(iter):
    """Lazily flatten arbitrarily nested iterables into one stream.

    Strings are yielded whole rather than split into characters, since
    ``is_iter_not_str`` treats them as scalars. Returns a generator.
    """
    for element in iter:
        if not is_iter_not_str(element):
            yield element
            continue
        for leaf in flatten(element):
            yield leaf
def flatten_and_split(strings, separator=','):
    """Split every string in a (possibly nested) iterable on ``separator``.

    ``strings`` is flattened with ``flatten`` first. Each string is then split
    on ``separator`` and the non-empty pieces are yielded in order, so empty
    pieces from leading, trailing or doubled separators are dropped. For
    example, ``['a,b', ['c,,d']]`` yields 'a', 'b', 'c', 'd'. Returns a
    generator.
    """
    for string in flatten(strings):
        # Split on the caller's separator (previously a hard-coded ',' was
        # used and the parameter was ignored).
        for piece in string.split(separator):
            if piece:
                yield piece
def _convert_to_json_builtin(value):
    """Fallback ``default`` hook for ``json.dumps``.

    * ``datetime.datetime``, ``datetime.date`` and ``datetime.time`` values
      become strings via ``isoformat()``.
    * Any ``Set`` (e.g. ``set``, ``frozenset``) becomes a list, in the set's
      iteration order.

    Raises:
        TypeError: for every other type, which is what ``json.dumps`` expects
            from a ``default`` hook.
    """
    try:  # Python 3.3+
        from collections.abc import Set
    except ImportError:  # Python 2.7
        from collections import Set
    # datetime.datetime subclasses datetime.date, so it is covered here too.
    if isinstance(value, (datetime.date, datetime.time)):
        return value.isoformat()
    if isinstance(value, Set):
        return list(value)
    raise TypeError('Could not serialize type {0}'.format(type(value)))
def json_encode(value, *args, **kwargs):
    """JSON-encode ``value`` and return the resulting string.

    A truthy ``default`` keyword replaces ``_convert_to_json_builtin`` as the
    fallback serializer. All other arguments go straight to ``json.dumps``.
    ``value`` is first passed through tornado's ``escape.recursive_unicode``
    (presumably to decode byte strings recursively).

    Every ``</`` in the output is rewritten as ``<\\/``. JSON permits, but
    does not require, escaping forward slashes. Doing so lets the output be
    embedded in an HTML ``<script>`` tag without a literal ``</script>``
    ending the script early. Python's json module does not do this itself;
    see stackoverflow.com/questions/1580647/json-why-are-forward-slashes-escaped
    """
    fallback = kwargs.pop('default', None)
    if not fallback:
        fallback = _convert_to_json_builtin
    encoded = json.dumps(escape.recursive_unicode(value),
                         default=fallback, *args, **kwargs)
    return encoded.replace("</", "<\\/")
def json_decode(s):
    """Parse the JSON document in ``s`` and return the resulting Python value.

    Thin wrapper around ``json.loads``. The result can be any JSON value
    (dict, list, str, number, bool or None), not only a dict.
    """
    decoded = json.loads(s)
    return decoded
def restrict_kwargs(*supported_kwargs):
    """Decorator factory: reject keyword arguments not in ``supported_kwargs``.

    The wrapped function raises ``TypeError`` (from ``_RESTRICT_KWARGS_MSG``)
    naming the offending kwarg, the supported ones, and the caller's file,
    line, function and source context. Gets you part of the effect of
    keyword-only arguments in Python 2.7.x.
    """
    def decorator(fn):
        @functools.wraps(fn)
        def fn_with_kwargs_restriction(*args, **kwargs):
            for kwarg_name in kwargs:
                if kwarg_name not in supported_kwargs:
                    # inspect.stack()[1] is the caller's frame record:
                    # (frame, filename, lineno, function, code_context, index)
                    stackinfo_keys = ('_frame', 'filename', 'lineno',
                                      'caller_fn_name', 'context', '_index')
                    stackinfo = inspect.stack()[1]
                    msg_dict = dict(zip(stackinfo_keys, stackinfo))
                    msg_dict.update({
                        # code_context is None when the caller's source is
                        # unavailable (e.g. exec'd code); use '' instead of
                        # crashing in ''.join(None).
                        'context': ''.join(msg_dict['context'] or ()),
                        'fn_name': fn.__name__,
                        'kwarg_name': kwarg_name,
                        'supported_kwargs': supported_kwargs,
                    })
                    raise TypeError(_RESTRICT_KWARGS_MSG.format(**msg_dict))
            return fn(*args, **kwargs)
        return fn_with_kwargs_restriction
    return decorator
def require_kwargs(*required_kwargs):
    """Decorator factory: demand that every name in ``required_kwargs`` is passed.

    The wrapped function raises ``TypeError`` (from ``_REQUIRE_KWARGS_MSG``)
    for the first missing kwarg, naming it, the required ones, and the
    caller's file, line, function and source context. Gets you part of the
    effect of keyword-only arguments in Python 2.7.x.
    """
    def decorator(fn):
        @functools.wraps(fn)
        def fn_with_kwargs_requirement(*args, **kwargs):
            for kwarg_name in required_kwargs:
                # Direct dict membership: O(1), no intermediate list.
                if kwarg_name not in kwargs:
                    # inspect.stack()[1] is the caller's frame record:
                    # (frame, filename, lineno, function, code_context, index)
                    stackinfo_keys = ('_frame', 'filename', 'lineno',
                                      'caller_fn_name', 'context', '_index')
                    stackinfo = inspect.stack()[1]
                    msg_dict = dict(zip(stackinfo_keys, stackinfo))
                    msg_dict.update({
                        # code_context is None when the caller's source is
                        # unavailable (e.g. exec'd code); use '' instead of
                        # crashing in ''.join(None).
                        'context': ''.join(msg_dict['context'] or ()),
                        'fn_name': fn.__name__,
                        'kwarg_name': kwarg_name,
                        'required_kwargs': required_kwargs,
                    })
                    raise TypeError(_REQUIRE_KWARGS_MSG.format(**msg_dict))
            return fn(*args, **kwargs)
        return fn_with_kwargs_requirement
    return decorator
def shallow_memoize(fn):
    """Cache ``fn``'s results, keyed on its call arguments.

    The key is the positional-argument tuple plus the keyword arguments
    sorted by name, so ``f(a=1, b=2)`` and ``f(b=2, a=1)`` share one entry.
    Passing a value positionally and by keyword produces separate entries, and
    arguments that compare equal (e.g. ``1`` and ``True``) share one. Every
    argument must be hashable, otherwise ``TypeError`` is raised. The cache is
    unbounded and lives as long as the decorated function.
    """
    lookup = {}

    @functools.wraps(fn)
    def memoized_func(*args, **kwargs):
        # Kwarg names are unique strings, so sorting never compares values.
        arg_key = (args, tuple(sorted(kwargs.items())))
        if arg_key in lookup:
            return lookup[arg_key]
        result = fn(*args, **kwargs)
        lookup[arg_key] = result
        return result
    return memoized_func
@shallow_memoize
def sparse_bitcount(val, abs_when_negative=True):
    """Return the number of set bits in the integer ``val``.

    Uses Kernighan's trick: ``val &= val - 1`` clears the lowest set bit, so
    the loop runs once per set bit, which is fast for sparse values. Results
    are memoized by ``shallow_memoize``.

    When ``abs_when_negative`` is true, ``abs(val)`` is counted. Otherwise a
    negative ``val`` raises ``ValueError``.
    """
    if abs_when_negative:
        val = abs(val)
    elif val < 0:
        # Message now names the real parameter and function.
        raise ValueError(
            "Either specify abs_when_negative=True, or "
            "don't pass negative values to sparse_bitcount.")
    count = 0
    while val:
        val &= val - 1
        count += 1
    return count
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Future readers, please note that this thing is a great sadness that I would not wish upon anyone. It has destroyed so many lives, and none yet know the full extent of the damage it has wrought.