Django auto-sharding core.
# encoding: utf-8

"""
Exports for main models and also a little bit of injection magic for
django.contrib.auth.models.User.
"""

__author__ = 'Jay Taylor [@jtaylor]'

# NB: ShardedUser and ShardedUserManager transparently extend the django User
# and UserManager classes and add distributed query support.

from django.contrib.auth import models as _djangoModels

from .MainModel import MainModel, MainManager

# Keep a backup reference to the original User and UserManager classes.
_djangoModels._User = _djangoModels.User
_djangoModels._UserManager = _djangoModels.UserManager
#class _DebugModelState(object):
#    def __init__(self, db=None):
#        self._db = db
#        self.adding = True
#
#        from sh_util import findVariableByNameInFrame
#
#        self.parent = findVariableByNameInFrame('self')
#
#    def _setDb(self, db):
#        if self.parent is not None and hasattr(self.parent, '_meta') and \
#            hasattr(self.parent._meta, 'db_table') and \
#            (self.parent._meta.db_table == 'main_contact' or
#                self.parent._meta.db_table == 'main_extendeduser'):
#            from django_util.log_errors import print_stack
#            print ':::::::: Setting db to {0} for {1}={2}'.format(
#                db,
#                self.parent._meta.db_table if self.parent is not None \
#                    else 'PARENTISNONE',
#                self.parent.id if self.parent is not None else 'PARENTISNONE'
#            )
#            print_stack()
#            print '------'
#        self._db = db
#
#    def _getDb(self):
#        return self._db
#
#    db = property(_getDb, _setDb)
#
#from django.db.models import base as _base
#_base.ModelState = _DebugModelState
class ShardedUserManager(_djangoModels._UserManager, MainManager):
    """Override the default django User manager."""


class ShardedUser(_djangoModels._User):
    """ShardedUser extends the base django.contrib.auth.models.User class."""

    class Meta(object):
        """
        Part of the django models interface.

        Here is where ShardedUser is designated as a simple proxy of the django
        User class.
        """
        # ShardedUser is a proxy of the django User class.
        # @see https://docs.djangoproject.com/en/dev/topics/db/models/
        proxy = True

    objects = ShardedUserManager()

    def delete(self, *args, **kw):
        """Custom delete."""
        from sh_util.db.data import deleteUser
        deleteUser(self.id, self._state.db)


# Inject Sh* proxy classes in place of the django models.
_djangoModels.User = ShardedUser
_djangoModels.UserManager = ShardedUserManager
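
################################################################################
# A hedged usage sketch (not part of the original gist): once this module has
# been imported, ``django.contrib.auth.models.User`` resolves to the
# ``ShardedUser`` proxy above, so ordinary auth imports pick up the distributed
# query support provided by MainManager/DistributedManager.  The
# ``distributed=True`` kwarg and the automatic shard inference are behaviour
# defined elsewhere in this gist; treat the calls below as illustrative only.
################################################################################
def _exampleShardedUserLookups(userId):
    """Illustrative only."""
    from django.contrib.auth.models import User
    assert User is ShardedUser
    # Shard inferred automatically from the id/pk lookup on auth_user:
    user = User.objects.get(id=userId)
    # Or fan the same lookup out across every shard:
    everywhere = User.objects.filter(distributed=True, id=userId)
    return user, everywhere
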
# -*- coding: utf-8 -*-

"""LogicalShard model."""

__author__ = 'Jay Taylor [@jtaylor]'

from django.db.models import (
    CharField,
    DateTimeField,
    IntegerField,
    TextField,
)

from .MainModel import MainModel
from sh_util.db import db_query
import settings


class LogicalShard(MainModel):
    """LogicalShard model."""

    class Meta:
        """Default Meta."""
        app_label = ''
        db_table = 'LogicalShard'

    physicalShardId = IntegerField()
    status = CharField(max_length=64, null=False, blank=True, default='')
    note = TextField(null=False, blank=True, default='')
    createdTs = DateTimeField(auto_now_add=True)
    modifiedTs = DateTimeField(auto_now=True, null=True, default=None)

    def userIds(self):
        """
        Get all the user-ids in a logical shard.

        @return list(int) of user-ids.
        """
        res = db_query(
            '''SELECT "id" FROM "auth_user" WHERE "id" %% {0} = {1}'''.format(
                settings.NUM_LOGICAL_SHARDS,
                self.id
            ),
            using='shard_{0}'.format(self.physicalShardId)
        )
        userIds = map(lambda tup: tup[0], res)
        return userIds
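
################################################################################
# A hedged sketch (not part of the original gist) of the id -> logical-shard
# mapping that the SQL in ``userIds`` relies on: a user's logical shard id is
# ``user_id % settings.NUM_LOGICAL_SHARDS``, and the LogicalShard row's primary
# key is assumed to be that residue.
################################################################################
def _exampleLogicalShardForUser(userId):
    """Illustrative only."""
    logicalShardId = userId % settings.NUM_LOGICAL_SHARDS
    shard = LogicalShard.objects.get(id=logicalShardId)
    # Every user-id co-located on this logical shard (userId among them).
    return shard.userIds()
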
# encoding: utf-8

"""All models in `main` should extend this class."""

__author__ = 'Jay Taylor [@jtaylor]'

import datetime, simplejson as json

from django.db import transaction
from django.db.models.fields.files import ImageFieldFile

from .sharding import distributed

################################################################################
# Useful for debugging sharding issues w/ ORM:                                 #
################################################################################
# Override django.db.models.base.ModelState
#from django.db.models import base as _base
## Keep reference to original ModelState class.
#_base.ModelStateOrig = _base.ModelState
#
#class ShardAwareModelState(_base.ModelStateOrig):
#    """A class for storing instance state.  Overrides default django class."""
#    def __init__(self, db=None):
#        """Invoke parent constructor."""
#        super(ShardAwareModelState, self).__init__(db)
#
#    def _getDb(self):
#        """Return self.db."""
#        return self._db
#
#    def _setDb(self, db):
#        """Set new value for self.db."""
#        from sh_util.sharding import ShardedResource
#
##        if db == 'default':
##            db = ShardedResource.getCurrentShard()
#
#        print 'SET DB TO %s' % db
#        self._db = db
#
#    db = property(_getDb, _setDb)
#
#_base.ModelState = ShardAwareModelState
class BulkInsertable(object):
    """Bulk insertable trait."""

    @staticmethod
    @transaction.commit_manually
    def insertMany(objects, using='default'):
        """
        Note: Modified from http://people.iola.dk/olau/python/bulkops.py

        Insert a list of Django objects in one SQL query.  Objects must be
        of the same Django model.  Note that save() is not called and signals
        on the model are not raised.
        """
        if not objects:
            return

        import django.db.models
        from django.db import connections

        con = connections[using]

        model = objects[0].__class__
        fields = [f for f in model._meta.fields if not isinstance(f, django.db.models.AutoField)]
        parameters = []
        for o in objects:
            parameters.append(
                tuple(
                    f.get_db_prep_save(f.pre_save(o, True), connection=con) for f in fields
                )
            )

        table = model._meta.db_table
        columns = ','.join(con.ops.quote_name(f.column) for f in fields)
        placeholders = ','.join(('%s',) * len(fields))
        #print 'table', table, 'columns', columns, 'placeholders', \
        #    placeholders, 'con', con
        con.cursor().executemany(
            'INSERT INTO "{0}" ({1}) VALUES ({2})'.format(
                table,
                columns,
                placeholders
            ),
            parameters
        )
        from sh_util.db import db_exec
        db_exec('COMMIT', using=using)
        #transaction.commit()
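
################################################################################
# A hedged usage sketch (not part of the original gist).  ``Contact`` is a
# hypothetical model whose default manager mixes in BulkInsertable (as
# MainManager below does); insertMany writes every instance with a single
# executemany() INSERT on the named connection, skipping save() and model
# signals as noted in the docstring above.
################################################################################
def _exampleInsertMany(rows, using='shard_0'):
    """Illustrative only: ``rows`` is a list of dicts of field values."""
    from main.models import Contact  # Hypothetical model.
    objects = [Contact(**row) for row in rows]
    Contact.objects.insertMany(objects, using=using)
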
class MainManager(distributed.DistributedManager, BulkInsertable):
    """Abstract base class for Managers in app.main."""


class SoftDeleteCopyable(object):
    """Soft-delete deep-copy trait."""

    def __deepcopy__(self, memo):
        """Deep copy of a QuerySet doesn't populate the cache."""
        import copy
        obj = self.__class__(model=self.model, query=self.query, using=self._db)
        items = filter(
            lambda (k, v): k not in ('model', 'query', '_db'),
            self.__dict__.items()
        )
        for k, v in items:
            if k in ('_iter', '_result_cache'):
                obj.__dict__[k] = None
            else:
                obj.__dict__[k] = copy.deepcopy(v, memo)
        return obj


class SoftDeleteDistributedQuerySet(SoftDeleteCopyable, distributed.DistributedQuerySet):
    """Soft-delete support for distributed query sets."""

    def __init__(self, model=None, query=None, using=None):
        """Automatically add the deleted filter in the constructor."""
        super(SoftDeleteDistributedQuerySet, self).__init__(
            model=model,
            query=query,
            using=using
        )
        from django.db.models.query_utils import Q
        if query is None:
            self.query = distributed.DistributedQuery(model)
            if 'deleted' in map(lambda field: field.name, model._meta.fields):
                self.query.add_q(Q(deleted=False))

    def cancel(self):
        """Reset the query."""
        clone = self._clone()
        clone.query = distributed.DistributedQuery(clone.model)
        return clone


class SoftDeleteAutoShardingQuerySet(SoftDeleteCopyable, distributed.AutoShardingQuerySet):
    """Soft-delete support for auto-sharding query sets."""

    def __init__(self, model=None, query=None, using=None):
        """Automatically add the deleted filter in the constructor."""
        super(SoftDeleteAutoShardingQuerySet, self).__init__(model=model, query=query, using=using)
        if query is None:
            from django.db.models.query_utils import Q
            from django.db.models import sql
            self.query = sql.Query(model)
            if 'deleted' in map(lambda field: field.name, model._meta.fields):
                self.query.add_q(Q(deleted=False))

    def cancel(self):
        """Reset the query."""
        from django.db.models import sql
        clone = self._clone()
        clone.query = sql.Query(clone.model)
        return clone


class SoftDeleteManager(distributed.DistributedManager, BulkInsertable):
    """
    Manager base class for models which support soft delete (via a `deleted`
    field).
    """

    def __init__(self, *args, **kw):
        """Override the default distributedQuerySetImpl."""
        super(SoftDeleteManager, self).__init__(*args, **kw)
        self.distributedQuerySetImpl = SoftDeleteDistributedQuerySet
        self.autoShardingQuerySetImpl = SoftDeleteAutoShardingQuerySet

    def deleted(self, **kw):
        """Return a QuerySet filtered down to soft-deleted rows."""
        querySet = self.all(
            **({'distributed': kw['distributed']} if 'distributed' in kw else {})
        ).cancel()
        return querySet.filter(deleted=True)

    def all_with_deleted(self, using=None, **kw):
        """Return a QuerySet without the deleted filter."""
        querySet = self.all(
            **({'distributed': kw['distributed']} if 'distributed' in kw else {})
        ).cancel()
        if using is not None:
            querySet._db = using
        # Check for a get_or_create_with_deleted override and if one is found
        # replace the QuerySet's get_or_create with it.
        if hasattr(self, 'get_or_create_with_deleted'):
            querySet.get_or_create = self.get_or_create_with_deleted
        return querySet
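
################################################################################
# A hedged usage sketch (not part of the original gist).  ``Note`` is a
# hypothetical model with a boolean ``deleted`` field whose default manager is
# a SoftDeleteManager: regular querysets get ``deleted=False`` added for free,
# while the helpers below expose the soft-deleted rows.
################################################################################
def _exampleSoftDeleteQueries(userId):
    """Illustrative only."""
    from main.models import Note  # Hypothetical model.
    live = Note.objects.filter(user_id=userId)              # deleted=False implied.
    trashed = Note.objects.deleted(distributed=True)        # Only soft-deleted rows.
    everything = Note.objects.all_with_deleted(distributed=True)
    return live, trashed, everything
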
##
# NB:
# This turned out to be unnecessary, but it was such a PITA to implement that
# I'm leaving it here commented out.
################################################################################
#def _attrIsSubclassOfField(instance, attrName):
#    """
#    @return True if the named attribute on the instance's class is a subclass
#        of Field, otherwise False.
#    """
#    from django.db.models.fields import Field
#
#    if not hasattr(instance.__class__, attrName):
#        return False
#
#    attr = getattr(instance.__class__, attrName)
#
#    if issubclass(attr.__class__, Field) or (
#        hasattr(attr, 'field') and
#        issubclass(getattr(attr, 'field').__class__, Field)
#    ):
#        return True
#
#    return False

#_ignoreAttrs = ('_meta', '_state', 'pk', 'id',)

_jsonIgnoreFields = (
    'objects',
)
class MainModel(distributed.DistributedModel):
    """Abstract base class for Models in app.main."""

    class Meta:
        """Default Meta."""
        abstract = True
        app_label = 'main'

    ##
    # NB:
    # This turned out to be unnecessary, but it was such a PITA to implement
    # that I'm leaving it here commented out.
    ############################################################################
    # def __getattribute__(self, k):
    #     """d"""
    #     from django.db.models.fields import Field
    #
    #     if k.startswith('__') or k in _ignoreAttrs:
    #         return super(MainModel, self).__getattribute__(k)
    #
    #     print 'hi %s hasattr?%s' % (k, hasattr(self.__class__, k))
    #
    #     if _attrIsSubclassOfField(self, k) is True:
    #         from sh_util.sharding import ShardedConnection
    #
    #         with ShardedConnection(self._state.db) as _:
    #             print 'HIIIIIIIIIIIIIIIIIIIIIII!!!!!!!!!!!!!!!!!!!!!!!!!!!!!'
    #             return super(MainModel, self).__getattribute__(k)
    #
    #     else:
    #         out = super(MainModel, self).__getattribute__(k)
    #         print type(out), out.__class__
    #         return super(MainModel, self).__getattribute__(k)

    @property
    def __json__(self):
        """Convert the model to a simple dict which can be serialized to JSON."""
        attrs = dict(
            filter(lambda (k, v): not k.startswith('_'), self.__dict__.items())
        )

        # Find and add any *_str attribute string aliases.
        for attr in attrs.keys():
            probe = '{0}_str'.format(attr)
            if probe not in attrs and hasattr(self, probe) and not callable(getattr(self, probe)):
                attrs[probe] = getattr(self, probe)

        # Convert things into JSON when we know how.
        for k in attrs.keys():
            tK = type(attrs[k])
            #print 'k={0}, type(k)={1}, attrs[k]={2}'.format(k, tK, attrs[k])
            if attrs[k] is None:
                # Don't include nulls.
                del attrs[k]
            elif tK is int or tK is long:
                # All numbers will be coerced into strings.
                attrs[k] = str(attrs[k])
            elif hasattr(attrs[k], '__json__') and callable(attrs[k].__json__):
                attrs[k] = attrs[k].__json__()
                # If it came back as just a string, see if it parses as JSON.
                if isinstance(attrs[k], str) or isinstance(attrs[k], unicode):
                    try:
                        attrs[k] = json.loads(attrs[k])
                    except json.decoder.JSONDecodeError:
                        pass
            elif hasattr(attrs[k], '__json__'):
                # Not callable.
                attrs[k] = json.loads(attrs[k].__json__)
            elif tK is datetime.datetime:
                attrs[k] = attrs[k].strftime('%Y-%m-%dT%H:%M:%S')
            elif tK is datetime.date:
                attrs[k] = attrs[k].strftime('%Y-%m-%d')
            elif tK is ImageFieldFile:
                attrs[k] = attrs[k].url if attrs[k] else ''

        return attrs
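
################################################################################
# A hedged sketch (not part of the original gist) of how ``__json__`` above is
# typically consumed: the dict it returns already contains only JSON-friendly
# values (numbers coerced to strings, datetimes formatted, nulls dropped), so
# it can be dumped directly with simplejson.
################################################################################
def _exampleModelToJson(instance):
    """Illustrative only: ``instance`` is any concrete MainModel subclass."""
    return json.dumps(instance.__json__)
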
# -*- coding: utf-8 -*-

"""Django model utilities."""

__author__ = 'Jay Taylor [@jtaylor]'


def generateDistributedCachingProperty(
    idFieldName,
    modelClassName,
    basePath='main.models'
):
    """
    @return a property whose getter performs a distributed get of the requested
        model class (caching the result) and whose setter stores the related
        object's id.
    """
    from sh_util import toId

    def getter(self):
        """Generated getter function."""
        from sh_util import dynImport

        model = dynImport('{0}.{1}'.format(basePath, modelClassName))

        cacheVarName = '__{0}'.format(modelClassName)

        if cacheVarName not in self.__dict__:
            setattr(self, cacheVarName, None)

        cache = getattr(self, cacheVarName)
        currentId = getattr(self, idFieldName)

        if currentId is None:
            return None

        if cache is not None and cache.id == currentId:
            return cache

        cache = model.objects.get(distributed=True, id=currentId)
        setattr(self, cacheVarName, cache)
        return cache

    return property(
        getter,
        lambda self, value: setattr(self, idFieldName, toId(value))
    )
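
################################################################################
# A hedged usage sketch (not part of the original gist).  ``Invoice`` and the
# import path below are hypothetical; the point is that a model stores only the
# raw ``user_id`` column while exposing a cached, distributed-get ``user``
# property generated by the helper above:
#
#     from django.db.models import IntegerField
#     from main.models import MainModel                # Hypothetical paths.
#     from main.models.modelUtils import generateDistributedCachingProperty
#
#     class Invoice(MainModel):
#         user_id = IntegerField()
#         user = generateDistributedCachingProperty('user_id', 'ShardedUser')
################################################################################
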
# encoding: utf-8

"""Distributed trait for django models."""

__author__ = 'Jay Taylor [@jtaylor]'

import logging
import settings

from django.db import connections
from django.db.models.sql.constants import SINGLE
from django.db.models.sql.datastructures import EmptyResultSet
from django.db.models import Manager, Model
from django.db.models.query import QuerySet
from django.db.models.sql.compiler import SQLCompiler as SqlCompiler, SQLAggregateCompiler as SqlAggregateCompiler
from django.db.models.sql.query import Query
from django.db.models.query import get_klass_info, get_cached_row
from django.db.models.sql.constants import MULTI

from sh_util.db import getRealShardConnectionName, connections as shConnections
from sh_util.db.distributed import distributedSelect, pgInitializeDbLinks
from sh_util.functional import memoize
## NB: ``Distributed`` class is deprecated.
#_tailOperators = {
#    'eq': '=',
#    'ne': '!=',
#    'gt': '>',
#    'gte': '>=',
#    'lt': '<',
#    'lte': '<=',
#    'in': 'IN',
#    'nin': 'NOT IN',
#    'contains': 'LIKE',
#    'like': 'LIKE',
#}
#
#class Distributed(object):
#    """Distributed trait."""
#
#    @classmethod
#    def buildFromDict(clazz, **kwargs):
#        """Build a contact parent object from a dict."""
#        db = kwargs.pop('shard')
#
#        obj = clazz(**kwargs)
#
#        obj._state.db = db
#
#        return obj
#
#    @staticmethod
#    def _parseFilter(**kw):
#        """Convert keyword arguments into a where clause."""
#        keyOpValueList = []
#
#        for key in kw:
#            value = kw[key]
#
#            operator = '='
#
#            for tail in _tailOperators:
#                if key.endswith('__{0}'.format(tail)):
#                    operator = _tailOperators[tail]
#
#                    key = key[0:key.rindex('__')]
#
#                    if tail == 'contains':
#                        value = '%{0}%'.format(value)
#
#                    break
#
#            if operator is '=' and (value is None or type(value) is bool):
#                operator = 'IS'
#            elif operator is '!=' and (value is None or type(value) is bool):
#                operator = 'IS NOT'
#
#            if value is None:
#                value = 'NULL'
#            elif isinstance(value, int) or isinstance(value, long):
#                value = str(value)
#            elif isinstance(value, bool):
#                value = str(value).upper()
#            elif operator == 'IN' and hasattr(value, '__iter__'):
#                if len(value) > 0 and (isinstance(value[0], int) or \
#                        isinstance(value[0], long) or isinstance(value[0], bool)):
#                    value = '({0})'.format(', '.join(map(str, value)))
#                else:
#                    value = '({0})'.format(', '.join(map(
#                        lambda s: "'{0}'".format(str(s)),
#                        value
#                    )))
#            else:
#                print 'yEP'
#                value = "'{0}'".format(value.strip("'"))
#
#            keyOpValueList.append((key, operator, value))
#
#        out = ' AND '.join(map(
#            lambda (key, op, value): '"{0}" {1} {2}' \
#                .format(key, op, value),
#            keyOpValueList
#        ))
#        print 'OUT=%s' % out
#        return out
#
#    @classmethod
#    def distributedFilter(clazz, **kw):
#        """Distributed filter."""
#        where = Distributed._parseFilter(**kw)
#        if len(where):
#            where = 'WHERE {0}'.format(where)
#
#        table = clazz()._meta.db_table
#
#        data = _evaluatedDistributedSelect(
#            '''SELECT * FROM "{0}" {1}'''.format(table, where),
#            asDict=True,
#            includeShardInfo=True
#        )
#
#        def dataToObj(data):
#            """Convert data to a class instance."""
#            # Build and return obj.
#            db = data.pop('shard')
#
#            obj = clazz(**data)
#
#            obj._state.db = db
#
#            return obj
#
#        return map(dataToObj, data)
#
#    @classmethod
#    def distributedGet(clazz, **kw):
#        """Distributed get."""
#        where = Distributed._parseFilter(**kw)
#        if len(where):
#            where = 'WHERE {0}'.format(where)
#
#        table = clazz()._meta.db_table
#
#        data = _evaluatedDistributedSelect(
#            '''SELECT * FROM "{0}" {1}'''.format(table, where),
#            asDict=True,
#            includeShardInfo=True
#        )
#
#        if len(data) != 1:
#            raise clazz.DoesNotExist()
#
#        def dataToObj(data):
#            """Convert data to a class instance."""
#            # Build and return obj.
#            db = data.pop('shard')
#
#            obj = clazz(**data)
#
#            obj._state.db = db
#
#            return obj
#
#        return dataToObj(data[0])
class DistributedSqlCompiler(SqlCompiler):
    """Distributed pg compiler."""

    def as_sql(
        self,
        with_limits=True,
        with_col_aliases=False,
        includeShardInfo=True
    ):
        """Create the SQL for this query."""
        sql, args = super(DistributedSqlCompiler, self).as_sql(
            with_limits=with_limits,
            with_col_aliases=with_col_aliases
        )
        distributedSql, distributedArgs = distributedSelect(
            sql,
            args,
            includeShardInfo=includeShardInfo
        )
        return (distributedSql, distributedArgs)

    def execute_sql(self, result_type=MULTI):
        """If persistent dblink is enabled, ensure persistent dblinks are initialized."""
        self.using = getRealShardConnectionName(self.using)
        self.connection = shConnections()[self.using]
        if getattr(settings, 'SH_UTIL_USE_PERSISTENT_DBLINK', False):
            pgInitializeDbLinks(self.using)
        return super(DistributedSqlCompiler, self).execute_sql(result_type)


class DistributedSqlAggregateCompiler(
    SqlAggregateCompiler,
    DistributedSqlCompiler
):
    """Distributed aggregate SQL compiler."""

    def as_sql(self, qn=None):
        """Only use the aggregate compiler when self.subquery has been set."""
        if hasattr(self, 'subquery'):
            # Get aggregate SQL from SqlAggregateCompiler.
            sql, args = super(DistributedSqlAggregateCompiler, self).as_sql(qn=qn)
            distributedSql, distributedArgs = distributedSelect(sql, args, includeShardInfo=False)
        else:
            distributedSql, distributedArgs = DistributedSqlCompiler.as_sql(self, False, False, False)
        return (distributedSql, distributedArgs)


# Inject our compilers into django.db.models.sql.compiler so they will be
# available to django.
from django.db.models.sql import compiler as _compiler
_compiler.DistributedSqlCompiler = DistributedSqlCompiler
_compiler.DistributedSqlAggregateCompiler = DistributedSqlAggregateCompiler
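
################################################################################
# A minimal sketch (not part of the original gist) of why the injection above
# works: Django 1.4 resolves a query's string ``compiler`` attribute by name
# against ``connection.ops.compiler_module`` (``django.db.models.sql.compiler``
# by default), so attaching the classes to that module makes the names used by
# ``DistributedQuery`` below resolvable.  Assumes a configured connection.
################################################################################
def _exampleCompilerResolution(using='default'):
    """Illustrative only."""
    from django.db import connections as _connections
    # Returns the DistributedSqlCompiler class injected above.
    return _connections[using].ops.compiler('DistributedSqlCompiler')
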
class DistributedQuery(Query):
    """Distributed Query."""

    compiler = 'DistributedSqlCompiler'
    aggregateCompiler = 'DistributedSqlAggregateCompiler'

    def get_count(self, using):
        """
        Perform a COUNT() query using the current filter constraints.

        Pasted from django 1.4, with the compiler overridden inside.
        """
        obj = self.clone()
        if len(self.select) > 1 or self.aggregate_select or (self.distinct and self.distinct_fields):
            # If a select clause exists, then the query has already started to
            # specify the columns that are to be returned.
            # In this case, we need to use a subquery to evaluate the count.
            from django.db.models.sql.subqueries import AggregateQuery
            subquery = obj
            subquery.clear_ordering(True)
            subquery.clear_limits()

            obj = AggregateQuery(obj.model)
            # Set compiler to be the distributed aggregate query compiler.
            obj.compiler = DistributedQuery.aggregateCompiler
            try:
                obj.add_subquery(subquery, using=using)
            except EmptyResultSet:
                # add_subquery evaluates the query; if it raises
                # EmptyResultSet then there can be no results, and therefore
                # the count is 0.
                return 0
        else:
            # Set compiler to be the distributed aggregate query compiler.
            obj.compiler = DistributedQuery.aggregateCompiler

        obj.add_count_column()
        number = obj.get_aggregation(using=using)[None]

        # Apply offset and limit constraints manually, since using LIMIT/OFFSET
        # in SQL (in variants that provide them) doesn't change the COUNT
        # output.
        number = max(0, number - self.low_mark)
        if self.high_mark is not None:
            number = min(number, self.high_mark - self.low_mark)

        return number

    def get_aggregation(self, using):
        """
        Return the dictionary with the values of the existing aggregations.

        Copied straight out of django 1.4 django.db.models.sql.query.Query
        in order to override which compiler is used.
        """
        if not self.aggregate_select:
            return {}

        # If there is a group by clause, aggregating does not add useful
        # information but retrieves only the first row.  Aggregate
        # over the subquery instead.
        if self.group_by is not None:
            from django.db.models.sql.subqueries import AggregateQuery
            query = AggregateQuery(self.model)
            # Set compiler to be the distributed aggregate query compiler.
            query.compiler = DistributedQuery.aggregateCompiler

            obj = self.clone()

            # Remove any aggregates marked for reduction from the subquery
            # and move them to the outer AggregateQuery.
            for alias, aggregate in self.aggregate_select.items():
                if aggregate.is_summary:
                    query.aggregate_select[alias] = aggregate
                    del obj.aggregate_select[alias]

            try:
                query.add_subquery(obj, using)
            except EmptyResultSet:
                return dict(
                    (alias, None)
                    for alias in query.aggregate_select
                )
        else:
            query = self
            self.select = []
            self.default_cols = False
            self.extra = {}
            self.remove_inherited_models()
            # Set compiler to be the distributed aggregate query compiler.
            query.compiler = DistributedQuery.aggregateCompiler

        query.clear_ordering(True)
        query.clear_limits()
        query.select_for_update = False
        query.select_related = False
        query.related_select_cols = []
        query.related_select_fields = []

        result = query.get_compiler(using).execute_sql(SINGLE)
        if result is None:
            result = [None for _ in query.aggregate_select.items()]

        return dict([
            (
                alias,
                self.resolve_aggregate(
                    val,
                    aggregate,
                    connection=connections[using]
                )
            )
            for (alias, aggregate), val
            in zip(query.aggregate_select.items(), result)
        ])


@memoize
def _numShards():
    """Cached number of physical shards."""
    from sh_util.sharding import ShardedResource
    return len(ShardedResource.allShardConnectionNames())
class DistributedQuerySet(QuerySet):
    """
    Distributed Query Set.

    Supports a limited subset of the usual QuerySet functionality.
    """

    def __init__(self, model=None, query=None, using=None):
        """Override the query's compiler."""
        if query is None:
            query = DistributedQuery(model=model)
        super(DistributedQuerySet, self).__init__(
            model=model,
            query=query,
            using=using
        )

    def iterator(self):
        """Iterate and yield each of the resulting rows."""
        def buildFromArgsOrKw(clazz, *args, **kwargs):
            """Initialize a db class instance from args or kwargs."""
            # Must be one or the other.
            assert (len(args) == 0 and len(kwargs) > 0) or (len(args) > 0 and len(kwargs) == 0)

            # NB: Extreme care must be taken here when attempting to ascertain
            # the shard value.
            if len(args) > 0:
                if _numShards() == 1:
                    # This is okay, but only if there is only 1 physical shard.
                    shard = self.db
                else:
                    shard = args[-1]
                    args = args[0:-1]
                obj = clazz(*args)
            else:
                if _numShards() == 1:
                    # This is okay, but only if there is only 1 physical shard.
                    shard = self.db
                else:
                    shard = kwargs.pop('shard')
                obj = clazz(**kwargs)

            # Store the source database of the object.
            obj._state.db = shard
            # This object came from the database; it's not being added.
            obj._state.adding = False
            return obj

        if connections[self.db].features.supports_select_related:
            fillCache = self.query.select_related
        else:
            fillCache = False

        requested = fillCache if isinstance(fillCache, dict) else None

        maxDepth = self.query.max_depth

        extraSelect = self.query.extra_select.keys()
        aggregateSelect = self.query.aggregate_select.keys()

        onlyLoad = self.query.get_loaded_field_names()

        def identifyLoadFields(onlyLoad):
            """Determine which fields will be loaded."""
            loadFields = []
            if onlyLoad:
                for field, model in self.model._meta.get_fields_with_model():
                    if model is None:
                        model = self.model
                    try:
                        if field.name in onlyLoad[model]:
                            # Add a field that has been explicitly included.
                            loadFields.append(field.name)
                    except KeyError:
                        # Model wasn't explicitly listed in the onlyLoad table.
                        # Therefore, we need to load all fields from this model.
                        loadFields.append(field.name)
            return loadFields

        loadFields = identifyLoadFields(onlyLoad)

        if loadFields and not fillCache:
            skip = set()
            initList = []
            for field in self.model._meta.fields:
                if field.name not in loadFields:
                    skip.add(field.attname)
                else:
                    initList.append(field.attname)
            if _numShards() > 1:
                initList.append('shard')
        else:
            skip = None
            initList = None

        numFields = len(loadFields or self.model._meta.fields)

        indexStart = len(extraSelect)
        # + 1 for shardId.
        aggregateStart = indexStart + numFields + 1

        compiler = self.query.get_compiler(using=self.db)

        if fillCache:
            klassInfo = get_klass_info(
                self.model,
                max_depth=maxDepth,
                requested=requested,
                only_load=onlyLoad
            )

        for row in compiler.results_iter():
            if fillCache:
                obj, _ = get_cached_row(
                    row,
                    indexStart,
                    self.db,
                    klassInfo,
                    offset=len(aggregateSelect)
                )
            elif skip is not None:
                # Build from dict of args:
                rowData = row[indexStart:aggregateStart]
                obj = buildFromArgsOrKw(
                    self.model,
                    **dict(zip(initList, rowData))
                )
            else:
                # Build from list of args.
                obj = buildFromArgsOrKw(
                    self.model,
                    *row[indexStart:aggregateStart]
                )

            if extraSelect:
                for i, k in enumerate(extraSelect):
                    setattr(obj, k, row[i])

            # Add the aggregates to the model.
            if aggregateSelect:
                for i, aggregate in enumerate(aggregateSelect):
                    setattr(obj, aggregate, row[i + aggregateStart])

            yield obj

    def get_query_set(self):
        """Return a distributed query set."""
        return DistributedQuerySet(
            model=self.model,
            query=self.query,
            using=self._db
        )
# NB: MOVE THIS TO SOME CONFIGURABLE PLACE
def shardKeyFinder(querySet, **kw):
    """
    Part of the auto-sharding interface.  This function takes a QuerySet and a
    dict of keyword arguments and returns the shard-key value if possible,
    otherwise None.
    """
    # Get at the model's table name.
    modelTable = querySet.model._meta.db_table if hasattr(querySet.model, '_meta') else querySet.model()._meta.db_table

    if modelTable == 'auth_user':
        # Search for "id" or "pk":
        #logging.debug(u'Searching for id/pk key in {0}'.format(kw))
        spec = kw.get('id', kw.get('pk', None))
    else:
        # Otherwise.. user*:
        #logging.debug(u'Searching for user-id key in {0}'.format(kw))
        spec = kw.get('user', kw.get('user_id', kw.get('userId', None)))

    return spec
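
################################################################################
# A hedged sketch (not part of the original gist) of what ``shardKeyFinder``
# extracts: querysets over ``auth_user`` are keyed on id/pk, everything else on
# a user/user_id/userId kwarg.  The ``_Fake*`` stand-ins are invented purely
# for illustration.
################################################################################
def _exampleShardKeyFinder():
    """Illustrative only."""
    class _FakeMeta(object):
        db_table = 'main_contact'

    class _FakeModel(object):
        _meta = _FakeMeta()

    class _FakeQuerySet(object):
        model = _FakeModel()

    querySet = _FakeQuerySet()
    assert shardKeyFinder(querySet, user_id=42) == 42
    assert shardKeyFinder(querySet, name='bob') is None
    return True
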
# Operations which can be (at least in theory) automatically distributed.
_distributable = (
    'none',
    'all',
    'count',
    'dates',
    'distinct',
    'extra',
    'get',
    'filter',
    'aggregate',
    'annotate',
    'complex_filter',
    'exclude',
    'in_bulk',
    'iterator',
    'latest',
    'order_by',
    'values',
    'values_list',
    'reverse',
    'defer',
    'only',
    'exists',
    'raw',
)

# Operations for which the correct shard may be automatically inferred by id or
# user-id.
_autoShardable = (
    'create',
    'get_or_create',
    'get_or_create_with_deleted',
)
class AutoSelectShard(object):
    """Automatic shard selection trait."""

    def _findShardKeySpecification(self, **kw):
        """
        Discover the user-id specifier.  Uses **kw so the original kwargs won't
        be modified.

        @return user-id value or None if not found.
        """
        spec = shardKeyFinder(self, **kw)
        if spec is not None:
            if isinstance(spec, int) or isinstance(spec, long):
                return spec
            elif hasattr(spec, 'id'):
                return spec.id
            elif (isinstance(spec, str) or isinstance(spec, unicode)) and spec.isdigit():
                return spec
        return None

    def _autoSelectShard(self, userId):
        """Automatically use the correct shard based on user-id."""
        from sh_util.sharding import ShardedResource
        physicalShardId = ShardedResource.userIdToPhysicalShardId(userId)
        self._db = 'shard_{0}'.format(physicalShardId)
        logging.debug(
            u'DistributedManager :: {0} :: Automatically selected db={1} for user-id={2}'.format(
                self.model._meta.db_table if hasattr(self.model, '_meta') and hasattr(self.model._meta, 'db_table') else 'unknown',
                self._db,
                userId
            )
        )
class AutoShardingQuerySet(AutoSelectShard, QuerySet):
    """Automatically selects a shard if a shard-key specifier is found."""

    def __getattribute__(self, name):
        """
        Proxy for all attribute requests to inject distributed query request
        detection.
        """
        attr = super(AutoShardingQuerySet, self).__getattribute__(name)

        # Allow any attribute contained in auto-shardable or distributable to
        # have the destination connection automatically selected.
        if name in _autoShardable or name in _distributable:
            def wrapped(*args, **kwargs):
                """@return The wrapped version of the attr."""
                # Attempt to infer the appropriate shard.
                specifier = self._findShardKeySpecification(**kwargs)
                if specifier is not None:
                    self._autoSelectShard(specifier)
                attr = super(AutoShardingQuerySet, self).__getattribute__(name)
                result = attr(*args, **kwargs)
                return result
            return wrapped
        else:
            return attr
class DistributedManager(AutoSelectShard, Manager):
    """
    Distributed model manager.

    NB: One way to always disable the automatic shard selector is to use
    <user-or-id-field>__in=(x,) instead of <user-or-id-field>=x.
    """

    def __init__(self, *args, **kw):
        """Set the default distributed queryset implementation."""
        super(DistributedManager, self).__init__(*args, **kw)
        # These can be swapped out by child classes.
        self.distributedQuerySetImpl = DistributedQuerySet
        self.autoShardingQuerySetImpl = AutoShardingQuerySet

    def get_or_create(self, **kwargs):
        """
        Looks up an object with the given kwargs, creating one if necessary.
        Returns a tuple of (object, created), where created is a boolean
        specifying whether an object was created.
        """
        from django.db import IntegrityError
        import sys

        assert kwargs, 'get_or_create() must be passed at least one keyword argument'
        defaults = kwargs.pop('defaults', {})
        lookup = kwargs.copy()
        for f in self.model._meta.fields:
            if f.attname in lookup:
                lookup[f.name] = lookup.pop(f.attname)
        try:
            return self.get(**lookup), False
        except self.model.DoesNotExist:
            try:
                params = dict([(k, v) for k, v in kwargs.items() if '__' not in k])
                params.update(defaults)
                obj = self.create(**params)
                return obj, True
            except IntegrityError, e:
                exc_info = sys.exc_info()
                connections[self._db]._rollback()
                try:
                    return self.get(**lookup), False
                except self.model.DoesNotExist:
                    # Re-raise the IntegrityError with its original traceback.
                    raise exc_info[1], None, exc_info[2]

    def get_query_set(self):
        """Get an auto-sharding capable query set instance."""
        logging.debug(u'Returning auto-sharding query set for model={0}'.format(self.model))
        return self.autoShardingQuerySetImpl(model=self.model, using=self._db)

    # @TODO .exists is useless here, can likely be removed.
    def exists(self):
        """More expensive but working form of .exists()."""
        return self.count() > 0

    def __getattribute__(self, name):
        """
        Proxy for all attribute requests to inject distributed query request
        detection.
        """
        attr = super(DistributedManager, self).__getattribute__(name)

        if name in _distributable and callable(attr):
            logging.debug(u'__getattribute__ returning full wrapper for model={0} name={1}'.format(self.model, name))
            def wrapped(*args, **kwargs):
                """@return The wrapped version of the attr."""
                if kwargs.pop('distributed', False) is True:
                    dqs = self.distributedQuerySetImpl(model=self.model, using=self._db)
                    attr = dqs.__getattribute__(name)
                else:
                    # Attempt to infer the appropriate shard.
                    specifier = self._findShardKeySpecification(**kwargs)
                    if specifier is not None:
                        self._autoSelectShard(specifier)
                    attr = super(DistributedManager, self).__getattribute__(name)
                result = attr(*args, **kwargs)
                return result
            return wrapped

        elif name in _autoShardable:
            def wrapped(*args, **kwargs):
                """@return The wrapped version of the attr."""
                # Attempt to infer the appropriate shard.
                specifier = self._findShardKeySpecification(**kwargs)
                if specifier is not None:
                    self._autoSelectShard(specifier)
                attr = super(DistributedManager, self).__getattribute__(name)
                result = attr(*args, **kwargs)
                return result
            return wrapped

        else:
            return attr
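
################################################################################
# A hedged usage sketch (not part of the original gist).  ``Contact`` is a
# hypothetical MainModel subclass managed by a DistributedManager; the calls
# below assume the shard connections ('shard_0', ...) configured elsewhere in
# this gist.
################################################################################
def _exampleDistributedManagerUsage(userId):
    """Illustrative only."""
    from main.models import Contact  # Hypothetical model.
    # Shard inferred automatically from the user-id lookup:
    mine = Contact.objects.filter(user_id=userId)
    # Fan the query out across every shard instead:
    everywhere = Contact.objects.filter(distributed=True, deleted=False)
    # Bypass automatic shard selection (per the NB in the class docstring):
    pinned = Contact.objects.using('shard_0').filter(user_id__in=(userId,))
    return mine, everywhere, pinned
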
class DistributedModel(Model):
    """
    Contains a special save method to automatically infer shard based on
    self.user_id when available.
    """

    class Meta:
        """Default Meta."""
        abstract = True
        app_label = 'main'

    def save(self, *args, **kw):
        """
        Special save method to automatically infer shard based on self.user_id
        when available.
        """
        # from sh_util.sharding import ShardedResource
        #
        # if hasattr(self, 'user_id'):
        #     self._state.db = 'shard_{0}'.format(
        #         ShardedResource.userIdToPhysicalShardId(self.user_id)
        #     )
        return super(DistributedModel, self).save(*args, **kw)
# -*- coding: utf-8 -*-

"""Django model utilities."""

__author__ = 'Jay Taylor [@jtaylor]'

from django.db.models.fields import IntegerField


class BigIntegerField(IntegerField):
    """Bigint field (not compatible with Oracle)."""

    empty_strings_allowed = False

    def get_internal_type(self):
        """Override IntegerField's implementation."""
        return 'BigIntegerField'

    def db_type(self, connection):
        """Override IntegerField's implementation."""
        # NB: this won't work with Oracle.
        return 'bigint'
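
################################################################################
# A minimal sketch (not part of the original gist) showing what the field
# reports to the database backend; assumes a configured connection name.
################################################################################
def _exampleBigIntegerFieldDbType(using='default'):
    """Illustrative only."""
    from django.db import connections
    field = BigIntegerField()
    # 'bigint' regardless of backend (see db_type above).
    return field.db_type(connections[using])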