Created
November 26, 2014 21:22
-
-
Save dcollien/6759e46cbca0acd5726c to your computer and use it in GitHub Desktop.
MongoAlchemy shortcut layer
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from mongoalchemy.fields import DocumentField, StringField | |
from mongoalchemy.session import Session | |
from mongoalchemy.query import BadQueryException, BadResultException | |
from mongoalchemy.query_expression import QueryField | |
from mongoalchemy.document import DocumentMeta | |
from pymongo.errors import DuplicateKeyError | |
from pytz import utc | |
import functools | |
import types | |
# The magical database shortcut layer
"""
Usage, e.g.:

    db.get(document_type, filter_query=None, fields=None, id=None, default='My Default')
        # `default` is optional: when given, it is returned instead of
        # re-raising the BadResultException from `query.one()`
        # (e.g. when no document matches)
    db.pymongo(document_type)             # the raw pymongo collection
    db.eval('JS Code')
    db.aggregate(document_type, pipeline)
    db.SomeDocument.filter(...)
    db[SomeDocument].filter(...)
    db.save(doc)
    db.remove(doc)

    @db
    class SomeDocumentCollection(Document):
        pass
"""
class DBSession(object):
    """Shortcut wrapper around a lazily created mongoalchemy ``Session``.

    Attribute access is overloaded (see ``__getattr__``):

    * ``db.session`` -- the shared session, connected on first access.
    * ``db.save`` / ``db.remove`` / ``db.remove_query`` -- the session's
      methods of the same name.
    * ``db.SomeDocument`` -- a query for a document class registered with
      ``@db``; for an unregistered name, ``session.query('SomeDocument')``.

    ``db[SomeDocument]`` / ``db['SomeDocument']`` behave like
    ``db.SomeDocument``, and ``db(SomeDocument)`` registers a class (so
    ``@db`` works as a class decorator).
    """

    # Text types accepted as collection names: ``basestring`` on Python 2,
    # ``str`` on Python 3 (where ``basestring`` no longer exists).
    try:
        _STRING_TYPES = (basestring,)  # noqa: F821
    except NameError:
        _STRING_TYPES = (str,)

    # (settings attribute, Session.connect keyword) pairs that are forwarded
    # to the connection only when the setting exists and is not None.
    _OPTIONAL_CONNECT_SETTINGS = (
        ('MONGO_HOST', 'host'),
        ('MONGO_PORT', 'port'),
        ('MONGO_REPLICA_SET', 'replica_set'),
        ('MONGO_READ_PREFERENCE', 'read_preference'),
    )

    def __init__(self):
        # the database session/connection (lazily created on first use)
        self._session = None
        # document classes registered via db(...)/@db, keyed by class name
        self._models = {}

    @classmethod
    def _iter_subclasses(cls, document_type):
        """Yield every direct and indirect subclass of ``document_type``,
        depth-first, in ``__subclasses__()`` order."""
        for subclass in document_type.__subclasses__():
            yield subclass
            for descendant in cls._iter_subclasses(subclass):
                yield descendant

    def _polymorphic(self, document_type, query):
        """Filter ``query`` down to documents of ``document_type`` or of any
        of its subclasses, at any depth.

        Matches the field named by ``document_type.config_polymorphic``
        against the ``config_polymorphic_identity`` of the class and of
        every descendant class.

        :raises NotImplementedError: if a subclass resolves to no identity.
        """
        identity_field = getattr(document_type, document_type.config_polymorphic)
        identities = [document_type.config_polymorphic_identity]
        for subclass in self._iter_subclasses(document_type):
            identity = getattr(subclass, 'config_polymorphic_identity', None)
            if identity is None:
                raise NotImplementedError(
                    "No polymorphic identity found on " + str(subclass.__name__)
                    + " (a subclass of " + str(document_type.__name__) + ")")
            # The same identity can show up twice: a subclass that does not
            # set its own identity inherits its parent's via getattr, and
            # diamond inheritance reaches a class more than once.
            # Duplicates add nothing to the in_ filter.
            if identity not in identities:
                identities.append(identity)
        return query.filter(identity_field.in_(*identities))

    def _query_for(self, document_type):
        """Start a query on ``document_type``, narrowed to its polymorphic
        family when the class lives in a polymorphic collection."""
        query = self.session.query(document_type)
        if document_type.config_polymorphic_collection:
            query = self._polymorphic(document_type, query)
        return query

    def get(self, document_type, filter_query=None, fields=None, id=None, **kwargs):
        """Get a single document of the given document class from the db.

        :param document_type: the document class to query (polymorphic
            classes are narrowed to themselves and their subclasses).
        :param filter_query: a single query expression, or a list / tuple /
            generator of expressions, applied via ``query.filter``.
        :param fields: optional list of fields to extract.
        :param id: optional ``mongo_id`` to match.
        :param default: optional keyword; returned instead of re-raising
            when ``query.one()`` raises ``BadResultException`` (e.g. when no
            document matches).  Other keyword arguments are ignored.
        :raises BadResultException: from ``query.one()`` when no ``default``
            was given.

        e.g.
            db.get(SomeDocument, id='someID', fields=['fieldA', 'fieldB'], default=None)
        or
            db.get(SomeDocument, filter_query=(SomeDocument.fieldA == 42, SomeDocument.fieldB == 27))
        """
        query = self._query_for(document_type)
        if filter_query:
            if isinstance(filter_query, (list, tuple, types.GeneratorType)):
                query = query.filter(*filter_query)
            else:
                query = query.filter(filter_query)
        if id is not None:
            query = query.filter(document_type.mongo_id == id)
        if fields:
            query = query.fields(*fields)
        try:
            return query.one()
        except BadResultException:
            if 'default' not in kwargs:
                raise  # bare raise keeps the original traceback (Python 2 too)
            return kwargs['default']

    def eval(self, code, *args):
        """Run server-side JavaScript through the pymongo database's ``eval``.

        NOTE(review): ``Database.eval`` was removed in PyMongo 4 (and the
        server command in MongoDB 4.2) -- confirm the versions in use.
        """
        return self.session.db.eval(code, *args)

    def aggregate(self, collection, pipeline):
        """Run an aggregation ``pipeline`` on ``collection`` (a collection
        name or document class) via pymongo's ``Collection.aggregate``.
        The return type is whatever the installed pymongo returns."""
        return self.pymongo(collection).aggregate(pipeline)

    def pymongo(self, collection):
        """Break out into pymongo for a collection name or document class.

        Document classes are resolved via ``get_collection_name()`` when
        they have it, otherwise via their ``__name__``.

        e.g.
            db.pymongo(SomeDocument).pymongo_command
        """
        if not isinstance(collection, self._STRING_TYPES):
            if hasattr(collection, 'get_collection_name'):
                collection = collection.get_collection_name()
            else:
                collection = collection.__name__
        return self.session.db[collection]

    def _connect(self):
        """Open a new mongoalchemy Session configured from ``settings``.

        NOTE(review): ``settings`` is not imported in the visible file; it is
        presumably Django's ``django.conf.settings`` -- confirm and import
        it, otherwise this raises NameError.
        """
        kwargs = {}
        for setting_name, kwarg_name in self._OPTIONAL_CONNECT_SETTINGS:
            value = getattr(settings, setting_name, None)  # noqa: F821
            if value is not None:
                kwargs[kwarg_name] = value
        return Session.connect(
            settings.MONGO_DATABASE_NAME,  # noqa: F821
            timezone=utc,
            tz_aware=True,
            safe=True,
            **kwargs
        )

    def __getattr__(self, attr):
        """Resolve the shortcut attributes described in the class docstring.

        Only called for names not found by normal attribute lookup.

        :raises AttributeError: for names starting with ``_``.
        """
        if attr.startswith('_'):
            # _private (and dunder) names are never treated as shortcuts.
            raise AttributeError(attr)
        if attr == 'session':
            # db.session connects on first access, then reuses the session.
            if self._session is None:
                self._session = self._connect()
            return self._session
        if attr in ('save', 'remove', 'remove_query'):
            # pass through to db.session.save() / .remove() / .remove_query()
            return getattr(self.session, attr)
        if attr in self._models:
            return self._query_for(self._models[attr])
        # Unregistered name: query the collection of that name anyway.
        return self.session.query(attr)

    def __getitem__(self, key):
        """db[DocumentClass] or db['DocumentClassName'] do the same thing as
        db.DocumentClassName (kinda like this is now JavaScript)."""
        if isinstance(key, self._STRING_TYPES):
            return getattr(self, key)
        return self._query_for(key)

    def __call__(self, model):
        """Register a document class under its ``__name__`` and return it
        unchanged, so this works as a decorator:

            @db
            class MyDocument(Document):
                pass
        """
        self._models[model.__name__] = model
        return model
# pylint: disable=C0103
# Shared, module-wide shortcut object (use as ``db.Something`` or ``@db``).
# Creating it is cheap: DBSession only connects to MongoDB when
# ``db.session`` is first accessed.
db = DBSession()
def polymorphic(cls=None, type_field='_type'):
    """Decorator for polymorphic document classes; also registers the class
    with ``db`` (as ``@db`` would).

    Usable bare (``@polymorphic``) or with arguments
    (``@polymorphic(type_field='kind')``).  The decorated class is rebuilt
    with a metaclass that, for it and for every subclass created from it:

    * sets ``config_polymorphic_collection = True``,
    * defaults ``config_polymorphic_identity`` to the class name,
    * defaults ``config_polymorphic`` to ``type_field``,
    * adds a ``StringField`` under that name, defaulting to the identity.

    :param cls: the class being decorated (``None`` when the decorator is
        called with keyword arguments only).
    :param type_field: name of the field that stores the identity.
    :returns: the registered class (or, when ``cls`` is None, a decorator).
    """
    if cls is None:
        return functools.partial(polymorphic, type_field=type_field)

    # A subclass of an already-polymorphic class was built by that class's
    # PolymorphicMeta, so it already has its identity and type field.
    # Rebuilding it with a second, unrelated PolymorphicMeta would raise a
    # "metaclass conflict" TypeError, so only register it.
    if getattr(type(cls), '_polymorphic_type_field', None) is not None:
        return db(cls)

    class PolymorphicMeta(DocumentMeta):
        # Marks metaclasses created by this decorator (see the check above).
        _polymorphic_type_field = type_field

        def __new__(mcs, name, bases, dct):
            dct['config_polymorphic_collection'] = True
            if 'config_polymorphic_identity' not in dct:
                dct['config_polymorphic_identity'] = name
            if 'config_polymorphic' not in dct:
                dct['config_polymorphic'] = type_field
            # Every instance records which class it was saved as.
            dct[dct['config_polymorphic']] = StringField(default=dct['config_polymorphic_identity'])
            return super(PolymorphicMeta, mcs).__new__(mcs, name, bases, dct)

    cls_dict = dict(cls.__dict__)
    # If present, these descriptors are bound to the original class and would
    # break attribute access on the rebuilt one (six.add_metaclass drops them too).
    cls_dict.pop('__dict__', None)
    cls_dict.pop('__weakref__', None)
    cls_dict["__metaclass__"] = PolymorphicMeta
    cls_dict["__wrapped__"] = cls
    return db(PolymorphicMeta(str(cls.__name__), tuple(cls.__bases__), cls_dict))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment