Skip to content

Instantly share code, notes, and snippets.

@jaytaylor
Created October 14, 2013 22:29
Show Gist options
  • Select an option

  • Save jaytaylor/6983333 to your computer and use it in GitHub Desktop.

Select an option

Save jaytaylor/6983333 to your computer and use it in GitHub Desktop.
Django auto-sharding core.
# encoding: utf-8
"""
Exports for main models and also a little bit of injection magic for
django.contrib.auth.models.User.
"""
__author__ = 'Jay Taylor [@jtaylor]'
# NB: ShardedUser and ShardedUserManager transparently extend the django User and
# UserManager classes and add distributed query support.
from django.contrib.auth import models as _djangoModels
from .MainModel import MainModel, MainManager
# Keep a backup reference to the original User and UserManager classes. The
# Sharded* classes defined later in this module subclass these saved
# references, and the public `User`/`UserManager` names on the django module
# are rebound at the bottom of this module.
# NOTE(review): if this module body were executed twice in one process, the
# second run would presumably capture the already-injected classes here --
# confirm the module is only ever imported under a single name.
_djangoModels._User = _djangoModels.User
_djangoModels._UserManager = _djangoModels.UserManager
#class _DebugModelState(object):
# def __init__(self, db=None):
# self._db = db
# self.adding = True
#
# from sh_util import findVariableByNameInFrame
#
# self.parent = findVariableByNameInFrame('self')
#
# def _setDb(self, db):
# if self.parent is not None and hasattr(self.parent, '_meta') and \
# hasattr(self.parent._meta, 'db_table') and \
# (self.parent._meta.db_table == 'main_contact' or
# self.parent._meta.db_table == 'main_extendeduser'):
# from django_util.log_errors import print_stack
# print ':::::::: Setting db to {0} for {1}={2}'.format(
# db,
# self.parent._meta.db_table if self.parent is not None \
# else 'PARENTISNONE',
# self.parent.id if self.parent is not None else 'PARENTISNONE'
# )
# print_stack()
# print '------'
# self._db = db
#
# def _getDb(self):
# return self._db
#
# db = property(_getDb, _setDb)
#
#from django.db.models import base as _base
#_base.ModelState = _DebugModelState
class ShardedUserManager(_djangoModels._UserManager, MainManager):
    """
    Override default django User manager.

    Combines the original django ``UserManager`` (saved as
    ``_djangoModels._UserManager``) with ``MainManager``; it adds no methods
    of its own.
    """
class ShardedUser(_djangoModels._User):
    """ShardedUser extends the base django.contrib.auth.models.User class."""

    class Meta(object):
        """
        Part of the django models interface.

        Here is where ShardedUser is designated as a simple proxy of the
        django User class.
        """
        # ShardedUser is a proxy of the django User class.
        # @see https://docs.djangoproject.com/en/dev/topics/db/models/
        proxy = True

    # Manager used for ShardedUser.objects.
    objects = ShardedUserManager()

    def delete(self, *args, **kw):
        """
        Custom delete.

        Delegates to ``sh_util.db.data.deleteUser``, passing this user's id
        and the database alias stored in ``self._state.db``. ``*args`` and
        ``**kw`` are accepted but not used, and the parent class ``delete`` is
        not invoked here.
        """
        # Imported at call time rather than at module import time.
        from sh_util.db.data import deleteUser
        deleteUser(self.id, self._state.db)
# Inject Sh* proxy classes in place of the django models: after these two
# assignments, looking up `User` / `UserManager` on the
# django.contrib.auth.models module yields the sharded classes. The originals
# remain reachable as `_User` / `_UserManager` on that module.
_djangoModels.User = ShardedUser
_djangoModels.UserManager = ShardedUserManager
# -*- coding: utf-8 -*-
"""LogicalShard model."""
__author__ = 'Jay Taylor [@jtaylor]'
from django.db.models import (
CharField,
DateTimeField,
IntegerField,
TextField,
)
from .MainModel import MainModel
from sh_util.db import db_query
import settings
class LogicalShard(MainModel):
    """
    LogicalShard model.

    Each row carries the id of the physical shard it belongs to
    (``physicalShardId``) along with a status, a free-form note and
    created/modified timestamps.
    """

    class Meta:
        """Default Meta."""
        app_label = ''
        db_table = 'LogicalShard'

    physicalShardId = IntegerField()
    status = CharField(max_length=64, null=False, blank=True, default='')
    note = TextField(null=False, blank=True, default='')
    createdTs = DateTimeField(auto_now_add=True)
    modifiedTs = DateTimeField(auto_now=True, null=True, default=None)

    def userIds(self):
        """
        Get all the user-ids in a logical shard.

        Selects every ``auth_user`` id whose remainder modulo
        ``settings.NUM_LOGICAL_SHARDS`` equals this row's id, querying the
        connection named ``shard_<physicalShardId>``.

        @return list(int) of user-ids.
        """
        # NOTE(review): str.format leaves the doubled percent sign as-is, so
        # db_query presumably collapses it to a single '%' -- confirm.
        statement = '''SELECT "id" FROM "auth_user" WHERE "id" %% {0} = {1}'''.format(
            settings.NUM_LOGICAL_SHARDS,
            self.id
        )
        connectionName = 'shard_{0}'.format(self.physicalShardId)
        rows = db_query(statement, using=connectionName)
        # Each row is a tuple; the id is its first (only selected) column.
        return [row[0] for row in rows]
# encoding: utf-8
"""All models in `main` should extend this class."""
__author__ = 'Jay Taylor [@jtaylor]'
import datetime, simplejson as json
from django.db import transaction
from django.db.models.fields.files import ImageFieldFile
from .sharding import distributed
################################################################################
# Useful for debugging sharding issues w/ ORM: #
################################################################################
# Override django.db.models.base.ModelState
#from django.db.models import base as _base
## Keep reference to original ModelState class.
#_base.ModelStateOrig = _base.ModelState
#
#class ShardAwareModelState(_base.ModelStateOrig):
# """A class for storing instance state. Overrides default django class."""
# def __init__(self, db=None):
# """Invoke parent constructor."""
# super(ShardAwareModelState, self).__init__(db)
#
# def _getDb(self):
# """Return self.db."""
# return self._db
#
# def _setDb(self, db):
# """Set new value for self.db."""
# from sh_util.sharding import ShardedResource
#
## if db == 'default':
## db = ShardedResource.getCurrentShard()
#
# print 'SET DB TO %s' % db
# self._db = db
#
# db = property(_getDb, _setDb)
#
#_base.ModelState = ShardAwareModelState
class BulkInsertable(object):
    """Bulk insertable trait."""

    @staticmethod
    @transaction.commit_manually
    def insertMany(objects, using='default'):
        """
        Insert a list of Django objects with a single ``executemany`` call.

        Note: Modified From http://people.iola.dk/olau/python/bulkops.py

        All objects must be of the same Django model (the model of the first
        object decides the table and columns). ``save()`` is not called and no
        model signals are raised. AutoField columns are left out of the
        INSERT.

        @param objects Sequence of model instances; nothing happens when it is
            empty.
        @param using Name of the connection to insert through.
        """
        if not objects:
            return

        from django.db import connections
        from django.db.models import AutoField

        connection = connections[using]
        meta = objects[0].__class__._meta

        # Everything except auto-generated primary keys gets written.
        insertFields = [
            field for field in meta.fields
            if not isinstance(field, AutoField)
        ]

        # One tuple of db-prepared values per object, in field order.
        rows = [
            tuple(
                field.get_db_prep_save(
                    field.pre_save(instance, True),
                    connection=connection
                )
                for field in insertFields
            )
            for instance in objects
        ]

        quotedColumns = ','.join(
            connection.ops.quote_name(field.column) for field in insertFields
        )
        valueSlots = ','.join('%s' for _ in insertFields)
        statement = 'INSERT INTO "{0}" ({1}) VALUES ({2})'.format(
            meta.db_table,
            quotedColumns,
            valueSlots
        )

        connection.cursor().executemany(statement, rows)

        # The commit is issued as a raw statement on the target connection
        # (instead of transaction.commit()).
        from sh_util.db import db_exec
        db_exec('COMMIT', using=using)
class MainManager(distributed.DistributedManager, BulkInsertable):
    """
    Abstract base class for Managers in app.main.

    Combines ``distributed.DistributedManager`` with the ``BulkInsertable``
    trait (``insertMany``); it defines nothing of its own.
    """
class SoftDeleteCopyable(object):
    """Soft-delete deep-copy trait."""

    def __deepcopy__(self, memo):
        """
        Deep copy of a QuerySet doesn't populate the cache.

        A new instance of the same class is built from this one's ``model``,
        ``query`` and ``_db``. Every other instance attribute is deep-copied,
        except ``_iter`` and ``_result_cache`` which are reset to None on the
        copy.
        """
        import copy

        duplicate = self.__class__(
            model=self.model,
            query=self.query,
            using=self._db
        )

        passedToConstructor = ('model', 'query', '_db')
        neverCopied = ('_iter', '_result_cache')

        for name, value in list(self.__dict__.items()):
            if name in passedToConstructor:
                # Already handed to the constructor above.
                continue
            if name in neverCopied:
                duplicate.__dict__[name] = None
            else:
                duplicate.__dict__[name] = copy.deepcopy(value, memo)

        return duplicate
class SoftDeleteDistributedQuerySet(SoftDeleteCopyable, distributed.DistributedQuerySet):
    """Soft-delete support for distributed query sets."""

    def __init__(self, model=None, query=None, using=None):
        """
        Automatically add deleted filter in constructor.

        When no ``query`` is supplied, a fresh ``DistributedQuery`` is
        installed and, if the model has a field named ``deleted``, it is
        restricted to ``deleted=False``. A supplied ``query`` is left
        untouched.
        """
        super(SoftDeleteDistributedQuerySet, self).__init__(
            model=model,
            query=query,
            using=using
        )

        from django.db.models.query_utils import Q

        if query is not None:
            return

        self.query = distributed.DistributedQuery(model)
        fieldNames = [field.name for field in model._meta.fields]
        if 'deleted' in fieldNames:
            self.query.add_q(Q(deleted=False))

    def cancel(self):
        """Reset the query: return a clone carrying a brand new, unfiltered DistributedQuery."""
        unfiltered = self._clone()
        unfiltered.query = distributed.DistributedQuery(unfiltered.model)
        return unfiltered
class SoftDeleteAutoShardingQuerySet(SoftDeleteCopyable, distributed.AutoShardingQuerySet):
    """Soft-delete support for auto-sharding query sets."""

    def __init__(self, model=None, query=None, using=None):
        """
        Automatically add deleted filter in constructor.

        When no ``query`` is supplied, a fresh ``sql.Query`` is installed and,
        if the model has a field named ``deleted``, it is restricted to
        ``deleted=False``. A supplied ``query`` is left untouched.
        """
        super(SoftDeleteAutoShardingQuerySet, self).__init__(
            model=model,
            query=query,
            using=using
        )

        if query is not None:
            return

        from django.db.models.query_utils import Q
        from django.db.models import sql

        self.query = sql.Query(model)
        fieldNames = [field.name for field in model._meta.fields]
        if 'deleted' in fieldNames:
            self.query.add_q(Q(deleted=False))

    def cancel(self):
        """Reset the query: return a clone carrying a brand new, unfiltered sql.Query."""
        from django.db.models import sql

        unfiltered = self._clone()
        unfiltered.query = sql.Query(unfiltered.model)
        return unfiltered
class SoftDeleteManager(distributed.DistributedManager, BulkInsertable):
    """
    Manager base class for models which support soft delete (via `deleted`
    field).
    """

    def __init__(self, *args, **kw):
        """Swap in the soft-delete aware query set implementations."""
        super(SoftDeleteManager, self).__init__(*args, **kw)
        self.distributedQuerySetImpl = SoftDeleteDistributedQuerySet
        self.autoShardingQuerySetImpl = SoftDeleteAutoShardingQuerySet

    def deleted(self, **kw):
        """
        Returns QuerySet with delete filter.

        Only the ``distributed`` keyword (when given) is forwarded to
        ``all()``; the default soft-delete filter is cancelled and replaced
        by ``deleted=True``.
        """
        forwarded = {}
        if 'distributed' in kw:
            forwarded['distributed'] = kw['distributed']

        unfiltered = self.all(**forwarded).cancel()
        return unfiltered.filter(deleted=True)

    def all_with_deleted(self, using=None, **kw):
        """
        Return QuerySet without delete filter.

        Only the ``distributed`` keyword (when given) is forwarded to
        ``all()``.

        @param using Optional connection name to force onto the query set.
        """
        forwarded = {}
        if 'distributed' in kw:
            forwarded['distributed'] = kw['distributed']

        unfiltered = self.all(**forwarded).cancel()

        if using is not None:
            unfiltered._db = using

        # Check for get_or_create_with_deleted override and if one is found
        # replace the QuerySet's get_or_create with it.
        if hasattr(self, 'get_or_create_with_deleted'):
            unfiltered.get_or_create = self.get_or_create_with_deleted

        return unfiltered
##
# NB:
# This turned out to be unnecessary, but it was such a PITA to implement I'm
# leaving it here commented out.
################################################################################
#def _attrIsSubclassOfField(instance, attrName):
# """
# @return True if the named attribute on the instance's class is a subclass
# of Field, otherwise False.
# """
# from django.db.models.fields import Field
#
# if not hasattr(instance.__class__, attrName):
# return False
#
# attr = getattr(instance.__class__, attrName)
#
# if issubclass(attr.__class__, Field) or (
# hasattr(attr, 'field') and
# issubclass(getattr(attr, 'field').__class__, Field)
# ):
# return True
#
# return False
#_ignoreAttrs = ('_meta', '_state', 'pk', 'id',)
# Attribute names to leave out of JSON output.
# NOTE(review): nothing in the code visible here reads this tuple -- confirm
# whether it is still used elsewhere before relying on it.
_jsonIgnoreFields = (
    'objects',
)
class MainModel(distributed.DistributedModel):
    """Abstract base class for Models in app.main."""

    class Meta:
        """Default Meta."""
        abstract = True
        app_label = 'main'

    ##
    # NB:
    # This turned out to be unnecessary, but it was such a PITA to implement
    # I'm leaving it here commented out.
    ############################################################################
    # def __getattribute__(self, k):
    #     """d"""
    #     from django.db.models.fields import Field
    #
    #     if k.startswith('__') or k in _ignoreAttrs:
    #         return super(MainModel, self).__getattribute__(k)
    #
    #     print 'hi %s hasattr?%s' % (k, hasattr(self.__class__, k))
    #
    #     if _attrIsSubclassOfField(self, k) is True:
    #         from sh_util.sharding import ShardedConnection
    #
    #         with ShardedConnection(self._state.db) as _:
    #             print 'HIIIIIIIIIIIIIIIIIIIIIII!!!!!!!!!!!!!!!!!!!!!!!!!!!!!'
    #             return super(MainModel, self).__getattribute__(k)
    #
    #     else:
    #         out = super(MainModel, self).__getattribute__(k)
    #         print type(out), out.__class__
    #         return super(MainModel, self).__getattribute__(k)

    @property
    def __json__(self):
        """
        Model to simple dict which can be serialized to JSON.

        Starts from the instance attributes whose names do not begin with an
        underscore, adds any non-callable ``<attr>_str`` alias found on the
        instance, then converts each value:

          * ``None`` values are dropped.
          * ``int``/``long`` values become strings (``bool`` is untouched).
          * Values with a callable ``__json__`` are replaced by its result;
            a string result is parsed as JSON when possible and otherwise
            kept as the string.
          * Values with a non-callable ``__json__`` (e.g. a property like this
            one) are replaced by it; a string is parsed as JSON, anything
            else is used as-is.
          * ``datetime`` -> ``%Y-%m-%dT%H:%M:%S``, ``date`` -> ``%Y-%m-%d``.
          * ``ImageFieldFile`` -> its ``url``, or ``''`` when empty.

        @return dict of attribute name to converted value.
        """
        attrs = dict(
            (name, value)
            for name, value in self.__dict__.items()
            if not name.startswith('_')
        )

        # Find and add any *_str attribute string aliases. Iterate a snapshot
        # of the keys because the dict grows inside the loop.
        for attr in list(attrs.keys()):
            probe = '{0}_str'.format(attr)
            if probe not in attrs and hasattr(self, probe) and not callable(getattr(self, probe)):
                attrs[probe] = getattr(self, probe)

        # Convert things into JSON when we know how. Iterate a snapshot of the
        # keys because entries may be deleted inside the loop.
        for k in list(attrs.keys()):
            value = attrs[k]
            tK = type(value)

            if value is None:
                # Don't include nulls.
                del attrs[k]

            elif tK is int or tK is long:
                # All numbers will be coerced into strings.
                attrs[k] = str(value)

            elif hasattr(value, '__json__') and callable(value.__json__):
                converted = value.__json__()
                # If it came back as just a string, let's see if it's JSON.
                if isinstance(converted, (str, unicode)):
                    try:
                        # Parse the returned string itself; the string has no
                        # __json__ attribute to call.
                        converted = json.loads(converted)
                    except ValueError:
                        # simplejson's JSONDecodeError is a ValueError
                        # subclass; not JSON, so keep the plain string.
                        pass
                attrs[k] = converted

            elif hasattr(value, '__json__'):
                # Not callable.
                converted = value.__json__
                if isinstance(converted, (str, unicode)):
                    converted = json.loads(converted)
                # Anything else (e.g. the dict produced by this very
                # property on a nested model) is already JSON-ready.
                attrs[k] = converted

            elif tK is datetime.datetime:
                attrs[k] = value.strftime('%Y-%m-%dT%H:%M:%S')

            elif tK is datetime.date:
                attrs[k] = value.strftime('%Y-%m-%d')

            elif tK is ImageFieldFile:
                attrs[k] = value.url if value else ''

        return attrs
# -*- coding: utf-8 -*-
"""Django model utilities."""
__author__ = 'Jay Taylor [@jtaylor]'
def generateDistributedCachingProperty(
    idFieldName,
    modelClassName,
    basePath='main.models'
):
    """
    Build a property that resolves an id field to its model instance using a
    distributed get, caching the instance on the owning object.

    @param idFieldName Name of the attribute on the owner holding the id.
    @param modelClassName Name of the model class to fetch.
    @param basePath Dotted module path the model class is imported from;
        the class is loaded from ``<basePath>.<modelClassName>``.

    @return a property of getter/setter functions which will do a distributed
        get on the requested model class. The setter stores ``toId(value)``
        into the id field.
    """
    from sh_util import toId

    # Honor basePath instead of a hard-coded 'main.models' (the default keeps
    # the previous behavior).
    modelPath = '{0}.{1}'.format(basePath, modelClassName)
    cacheVarName = '__{0}'.format(modelClassName)

    def getter(self):
        """Generated function."""
        from sh_util import dynImport

        model = dynImport(modelPath)

        if cacheVarName not in self.__dict__:
            setattr(self, cacheVarName, None)
        cache = getattr(self, cacheVarName)

        currentId = getattr(self, idFieldName)
        if currentId is None:
            return None

        # Reuse the cached instance only while it still matches the id field.
        if cache is not None and cache.id == currentId:
            return cache

        cache = model.objects.get(distributed=True, id=currentId)
        setattr(self, cacheVarName, cache)
        return cache

    def setter(self, value):
        """Generated function."""
        setattr(self, idFieldName, toId(value))

    return property(getter, setter)
# encoding: utf-8
"""Distributed trait for django models."""
__author__ = 'Jay Taylor [@jtaylor]'
import logging
import settings
from django.db import connections
from django.db.models.sql.constants import SINGLE
from django.db.models.sql.datastructures import EmptyResultSet
from django.db.models import Manager, Model
from django.db.models.query import QuerySet
from django.db.models.sql.compiler import SQLCompiler as SqlCompiler, SQLAggregateCompiler as SqlAggregateCompiler
from django.db.models.sql.query import Query
from django.db.models.query import get_klass_info, get_cached_row
from django.db.models.sql.constants import MULTI
from sh_util.db import getRealShardConnectionName, connections as shConnections
from sh_util.db.distributed import distributedSelect, pgInitializeDbLinks
from sh_util.functional import memoize
## NB: ``Distributed`` class is deprecated
#_tailOperators = {
# 'eq': '=',
# 'ne': '!=',
# 'gt': '>',
# 'gte': '>=',
# 'lt': '<',
# 'lte': '<=',
# 'in': 'IN',
# 'nin': 'NOT IN',
# 'contains': 'LIKE',
# 'like': 'LIKE',
#}
#
#class Distributed(object):
# """Distributed trait."""
#
# @classmethod
# def buildFromDict(clazz, **kwargs):
# """Build a contact parent object from a dict."""
# db = kwargs.pop('shard')
#
# obj = clazz(**kwargs)
#
# obj._state.db = db
#
# return obj
#
# @staticmethod
# def _parseFilter(**kw):
# """Convert keyword arguments into a where clause."""
# keyOpValueList = []
#
# for key in kw:
# value = kw[key]
#
# operator = '='
#
# for tail in _tailOperators:
# if key.endswith('__{0}'.format(tail)):
# operator = _tailOperators[tail]
#
# key = key[0:key.rindex('__')]
#
# if tail == 'contains':
# value = '%{0}%'.format(value)
#
# break
#
# if operator is '=' and (value is None or type(value) is bool):
# operator = 'IS'
# elif operator is '!=' and (value is None or type(value) is bool):
# operator = 'IS NOT'
#
# if value is None:
# value = 'NULL'
# elif isinstance(value, int) or isinstance(value, long):
# value = str(value)
# elif isinstance(value, bool):
# value = str(value).upper()
# elif operator == 'IN' and hasattr(value, '__iter__'):
# if len(value) > 0 and (isinstance(value[0], int) or \
# isinstance(value[0], long) or isinstance(value[0], bool)):
# value = '({0})'.format(', '.join(map(str, value)))
# else:
# value = '({0})'.format(', '.join(map(
# lambda s: "'{0}'".format(str(s)),
# value
# )))
# else:
# print 'yEP'
# value = "'{0}'".format(value.strip("'"))
#
# keyOpValueList.append((key, operator, value))
#
# out = ' AND '.join(map(
# lambda (key, op, value): '"{0}" {1} {2}' \
# .format(key, op, value),
# keyOpValueList
# ))
# print 'OUT=%s' % out
# return out
#
# @classmethod
# def distributedFilter(clazz, **kw):
# """Distributed filter."""
# where = Distributed._parseFilter(**kw)
# if len(where):
# where = 'WHERE {0}'.format(where)
#
# table = clazz()._meta.db_table
#
# data = _evaluatedDistributedSelect(
# '''SELECT * FROM "{0}" {1}'''.format(table, where),
# asDict=True,
# includeShardInfo=True
# )
#
# def dataToObj(data):
# """Convert data to a class instance."""
# # Build and return obj.
# db = data.pop('shard')
#
# obj = clazz(**data)
#
# obj._state.db = db
#
# return obj
#
# return map(dataToObj, data)
#
# @classmethod
# def distributedGet(clazz, **kw):
# """Distributed get."""
# where = Distributed._parseFilter(**kw)
# if len(where):
# where = 'WHERE {0}'.format(where)
#
# table = clazz()._meta.db_table
#
# data = _evaluatedDistributedSelect(
# '''SELECT * FROM "{0}" {1}'''.format(table, where),
# asDict=True,
# includeShardInfo=True
# )
#
# if len(data) != 1:
# raise clazz.DoesNotExist()
#
# def dataToObj(data):
# """Convert data to a class instance."""
# # Build and return obj.
# db = data.pop('shard')
#
# obj = clazz(**data)
#
# obj._state.db = db
#
# return obj
#
# return dataToObj(data[0])
class DistributedSqlCompiler(SqlCompiler):
    """Distributed pg compiler."""

    def as_sql(self, with_limits=True, with_col_aliases=False, includeShardInfo=True):
        """
        Creates the SQL for this query.

        The SQL/args produced by the stock compiler are passed through
        ``distributedSelect``.

        @return tuple(sql, args) as returned by ``distributedSelect``.
        """
        plainSql, plainArgs = super(DistributedSqlCompiler, self).as_sql(
            with_limits=with_limits,
            with_col_aliases=with_col_aliases
        )
        wrappedSql, wrappedArgs = distributedSelect(
            plainSql,
            plainArgs,
            includeShardInfo=includeShardInfo
        )
        return (wrappedSql, wrappedArgs)

    def execute_sql(self, result_type=MULTI):
        """
        Point the compiler at the real shard connection, then execute.

        If persistent dblink is enabled
        (``settings.SH_UTIL_USE_PERSISTENT_DBLINK``), ensure persistent
        dblinks are initialized first.
        """
        realName = getRealShardConnectionName(self.using)
        self.using = realName
        self.connection = shConnections()[realName]

        usePersistentDblink = getattr(settings, 'SH_UTIL_USE_PERSISTENT_DBLINK', False)
        if usePersistentDblink:
            pgInitializeDbLinks(realName)

        return super(DistributedSqlCompiler, self).execute_sql(result_type)
class DistributedSqlAggregateCompiler(SqlAggregateCompiler, DistributedSqlCompiler):
    """Distributed aggregate SQL compiler."""

    def as_sql(self, qn=None):
        """
        Only use the AggregateCompiler when self.subquery has been set.

        Without a ``subquery`` attribute the plain distributed compiler is
        used, with limits, column aliases and shard info all switched off.

        @return tuple(sql, args).
        """
        if not hasattr(self, 'subquery'):
            wrappedSql, wrappedArgs = DistributedSqlCompiler.as_sql(self, False, False, False)
            return (wrappedSql, wrappedArgs)

        # Get aggregate SQL from SqlAggregateCompiler.
        aggregateSql, aggregateArgs = super(DistributedSqlAggregateCompiler, self).as_sql(qn=qn)
        wrappedSql, wrappedArgs = distributedSelect(
            aggregateSql,
            aggregateArgs,
            includeShardInfo=False
        )
        return (wrappedSql, wrappedArgs)
# Inject our compilers into django.db.models.sql.compiler so they will be
# available to django. DistributedQuery refers to them by these attribute
# names through its `compiler` / `aggregateCompiler` strings.
from django.db.models.sql import compiler as _compiler
_compiler.DistributedSqlCompiler = DistributedSqlCompiler
_compiler.DistributedSqlAggregateCompiler = DistributedSqlAggregateCompiler
class DistributedQuery(Query):
    """
    Distributed Query.

    A django ``Query`` whose compiler names point at the distributed
    compilers, with ``get_count`` and ``get_aggregation`` copied from django
    1.4 so the aggregate compiler can be swapped in.
    """
    # Names of the compiler classes to use (plain and aggregate).
    compiler = 'DistributedSqlCompiler'
    aggregateCompiler = 'DistributedSqlAggregateCompiler'

    def get_count(self, using):
        """
        Performs a COUNT() query using the current filter constraints.
        Pasted this from django 1.4 and then override the compiler inside.

        @param using Connection name passed on to the subquery/aggregation.

        @return The count, adjusted for this query's low/high marks.
        """
        obj = self.clone()
        if len(self.select) > 1 or self.aggregate_select or (self.distinct and self.distinct_fields):
            # If a select clause exists, then the query has already started to
            # specify the columns that are to be returned.
            # In this case, we need to use a subquery to evaluate the count.
            from django.db.models.sql.subqueries import AggregateQuery
            subquery = obj
            subquery.clear_ordering(True)
            subquery.clear_limits()
            obj = AggregateQuery(obj.model)
            # Set compiler to be the distributed aggregate query compiler.
            obj.compiler = DistributedQuery.aggregateCompiler
            try:
                obj.add_subquery(subquery, using=using)
            except EmptyResultSet:
                # add_subquery evaluates the query, if it's an EmptyResultSet
                # then there can be no results, and therefore the count is
                # obviously 0.
                return 0
        else:
            # Set compiler to be the distributed aggregate query compiler.
            obj.compiler = DistributedQuery.aggregateCompiler
        obj.add_count_column()
        number = obj.get_aggregation(using=using)[None]
        # Apply offset and limit constraints manually, since using LIMIT/OFFSET
        # in SQL (in variants that provide them) doesn't change the COUNT
        # output.
        number = max(0, number - self.low_mark)
        if self.high_mark is not None:
            number = min(number, self.high_mark - self.low_mark)
        return number

    def get_aggregation(self, using):
        """
        Returns the dictionary with the values of the existing aggregations.
        Copied straight out of django 1.4 django.db.models.sql.query.Query
        in order to override which compiler is used.

        @param using Connection name used to obtain the compiler and to
            resolve the aggregate values.

        @return dict of aggregate alias to resolved value; empty when there is
            no aggregate select.
        """
        if not self.aggregate_select:
            return {}
        # If there is a group by clause, aggregating does not add useful
        # information but retrieves only the first row. Aggregate
        # over the subquery instead.
        if self.group_by is not None:
            from django.db.models.sql.subqueries import AggregateQuery
            query = AggregateQuery(self.model)
            # Set compiler to be the distributed aggregate query compiler.
            query.compiler = DistributedQuery.aggregateCompiler
            obj = self.clone()
            # Remove any aggregates marked for reduction from the subquery
            # and move them to the outer AggregateQuery.
            for alias, aggregate in self.aggregate_select.items():
                if aggregate.is_summary:
                    query.aggregate_select[alias] = aggregate
                    del obj.aggregate_select[alias]
            try:
                query.add_subquery(obj, using)
            except EmptyResultSet:
                return dict(
                    (alias, None)
                    for alias in query.aggregate_select
                )
        else:
            # No grouping: aggregate directly on this query, which is
            # modified in place (select, default_cols, extra and compiler).
            query = self
            self.select = []
            self.default_cols = False
            self.extra = {}
            self.remove_inherited_models()
            # Set compiler to be the distributed aggregate query compiler.
            query.compiler = DistributedQuery.aggregateCompiler
        query.clear_ordering(True)
        query.clear_limits()
        query.select_for_update = False
        query.select_related = False
        query.related_select_cols = []
        query.related_select_fields = []
        result = query.get_compiler(using).execute_sql(SINGLE)
        if result is None:
            # No row came back: one None per aggregate.
            result = [None for _ in query.aggregate_select.items()]
        return dict([
            (
                alias,
                self.resolve_aggregate(
                    val,
                    aggregate,
                    connection=connections[using]
                )
            )
            for (alias, aggregate), val
            in zip(query.aggregate_select.items(), result)
        ])
@memoize
def _numShards():
    """
    Cached number of physical shards.

    @return The length of ``ShardedResource.allShardConnectionNames()``.
    """
    # Imported at call time rather than at module import time.
    from sh_util.sharding import ShardedResource
    return len(ShardedResource.allShardConnectionNames())
class DistributedQuerySet(QuerySet):
    """
    Distributed Query Set.
    Supports a limited-subset of the usual QuerySet functionality.
    """

    def __init__(self, model=None, query=None, using=None):
        """Override query's compiler by defaulting to a DistributedQuery when no query is given."""
        if query is None:
            query = DistributedQuery(model=model)
        super(DistributedQuerySet, self).__init__(
            model=model,
            query=query,
            using=using
        )

    def iterator(self):
        """
        Iterate and yield each of the resulting rows as a model instance.

        Every instance gets ``_state.db`` set to the shard it was read from
        (or to ``self.db`` when only one physical shard exists) unless it was
        built through ``get_cached_row``.
        """
        def buildFromArgsOrKw(clazz, *args, **kwargs):
            """
            Initialize a db class instance from args or kwargs.

            With more than one physical shard the shard value is taken from
            the last positional argument, or from the ``shard`` keyword, and
            removed before the instance is constructed.
            """
            # Must be one or the other.
            assert (len(args) == 0 and len(kwargs) > 0) or (len(args) > 0 and len(kwargs) == 0)
            # NB: Extreme care must be taken here when attempting to ascertain
            # the shard value.
            if len(args) > 0:
                if _numShards() == 1:
                    # This is okay, but only if there is only 1 physical shard.
                    shard = self.db
                else:
                    shard = args[-1]
                    args = args[0:-1]
                obj = clazz(*args)
            else:
                if _numShards() == 1:
                    # This is okay, but only if there is only 1 physical shard.
                    shard = self.db
                else:
                    shard = kwargs.pop('shard')
                obj = clazz(**kwargs)
            # Store the source database of the object.
            obj._state.db = shard
            # This object came from the database; it's not being added.
            obj._state.adding = False
            return obj

        if connections[self.db].features.supports_select_related:
            fillCache = self.query.select_related
        else:
            fillCache = False
        requested = fillCache if isinstance(fillCache, dict) else None
        maxDepth = self.query.max_depth
        extraSelect = self.query.extra_select.keys()
        aggregateSelect = self.query.aggregate_select.keys()
        onlyLoad = self.query.get_loaded_field_names()

        def identifyLoadFields(onlyLoad):
            """
            Determine which fields will be loaded.

            @return list of field names; empty when ``onlyLoad`` is empty.
            """
            loadFields = []
            if onlyLoad:
                for field, model in self.model._meta.get_fields_with_model():
                    if model is None:
                        model = self.model
                    try:
                        if field.name in onlyLoad[model]:
                            # Add a field that has been explicitly included.
                            loadFields.append(field.name)
                    except KeyError:
                        # Model wasn't explicitly listed in the onlyLoad table
                        # Therefore, we need to load all fields from this model.
                        loadFields.append(field.name)
            return loadFields

        loadFields = identifyLoadFields(onlyLoad)
        if loadFields and not fillCache:
            # Only some fields are loaded, so instances are built by keyword.
            skip = set()
            initList = []
            for field in self.model._meta.fields:
                if field.name not in loadFields:
                    skip.add(field.attname)
                else:
                    initList.append(field.attname)
            if _numShards() > 1:
                # The extra trailing column is handed over as the 'shard'
                # keyword.
                initList.append('shard')
        else:
            skip = None
            initList = None
        numFields = len(loadFields or self.model._meta.fields)
        indexStart = len(extraSelect)
        # + 1 for shardId.
        # NOTE(review): the +1 is applied regardless of _numShards(), while
        # buildFromArgsOrKw only strips a shard value when there is more than
        # one shard -- confirm what the row looks like with a single shard.
        aggregateStart = indexStart + numFields + 1
        compiler = self.query.get_compiler(using=self.db)
        if fillCache:
            klassInfo = get_klass_info(
                self.model,
                max_depth=maxDepth,
                requested=requested,
                only_load=onlyLoad
            )
        for row in compiler.results_iter():
            if fillCache:
                obj, _ = get_cached_row(
                    row,
                    indexStart,
                    self.db,
                    klassInfo,
                    offset=len(aggregateSelect)
                )
            elif skip is not None:
                # Build from dict of args:
                rowData = row[indexStart:aggregateStart]
                obj = buildFromArgsOrKw(
                    self.model,
                    **dict(zip(initList, rowData))
                )
            else:
                # Build from list of args.
                obj = buildFromArgsOrKw(
                    self.model,
                    *row[indexStart:aggregateStart]
                )
            # Extra-select values occupy the leading columns of the row.
            if extraSelect:
                for i, k in enumerate(extraSelect):
                    setattr(obj, k, row[i])
            # Add the aggregates to the model.
            if aggregateSelect:
                for i, aggregate in enumerate(aggregateSelect):
                    setattr(obj, aggregate, row[i + aggregateStart])
            yield obj

    def get_query_set(self):
        """Return a distributed query set built from this one's model, query and _db."""
        return DistributedQuerySet(
            model=self.model,
            query=self.query,
            using=self._db
        )
# NB: MOVE THIS TO SOME CONFIGURABLE PLACE
def shardKeyFinder(querySet, **kw):
    """
    Part of the auto-sharding interface. This function takes a QuerySet and
    dict of keyword arguments and returns the shard-key value if possible,
    otherwise None.

    For the ``auth_user`` table the key is looked up under ``id`` then
    ``pk``; for every other table under ``user``, ``user_id`` then
    ``userId``. The first keyword that is present wins, whatever its value.

    @param querySet Object exposing ``model`` (a model class, or a callable
        returning an instance with ``_meta``).

    @return The shard-key value, or None when no candidate keyword was given.
    """
    # Get at the model's table name.
    model = querySet.model
    if hasattr(model, '_meta'):
        modelTable = model._meta.db_table
    else:
        modelTable = model()._meta.db_table

    if modelTable == 'auth_user':
        # Search for "id" or "pk".
        candidates = ('id', 'pk')
    else:
        # Otherwise.. user*.
        candidates = ('user', 'user_id', 'userId')

    for keyword in candidates:
        if keyword in kw:
            return kw[keyword]
    return None
# Operations which can be (at least in theory) automatically distributed.
# The manager and query set `__getattribute__` hooks in this module wrap
# attributes whose name appears here.
_distributable = (
    'none',
    'all',
    'count',
    'dates',
    'distinct',
    'extra',
    'get',
    'filter',
    'aggregate',
    'annotate',
    'complex_filter',
    'exclude',
    'in_bulk',
    'iterator',
    'latest',
    'order_by',
    'values',
    'values_list',
    'reverse',
    'defer',
    'only',
    'exists',
    'raw',
)
# Operations for which the correct shard may be automatically inferred by id or
# user-id. Attributes with these names are wrapped so the shard is selected
# from the call's keyword arguments before the operation runs.
_autoShardable = (
    'create',
    'get_or_create',
    'get_or_create_with_deleted',
)
class AutoSelectShard(object):
    """Automatic shard selection trait."""

    def _findShardKeySpecification(self, **kw):
        """
        Discover the user-id specifier. Uses **kw so the original kws won't be
        messed with.

        The candidate value comes from ``shardKeyFinder``. It is accepted when
        it is an integer, an object with an ``id`` attribute (its ``id`` is
        returned), or a string made up only of digits (returned unchanged, as
        a string).

        @return user-id value or None if not found.
        """
        spec = shardKeyFinder(self, **kw)
        if spec is None:
            return None

        if isinstance(spec, (int, long)):
            return spec

        if hasattr(spec, 'id'):
            return spec.id

        # The digit check must apply to byte strings and unicode strings
        # alike. Written as `A or B and C` it only guarded the unicode case,
        # letting any non-numeric str through as a "user id".
        if isinstance(spec, (str, unicode)) and spec.isdigit():
            return spec

        return None

    def _autoSelectShard(self, userId):
        """
        Automatically use the correct shard based on user-id.

        Sets ``self._db`` to ``shard_<physicalShardId>`` where the physical
        shard id comes from ``ShardedResource.userIdToPhysicalShardId``.
        """
        from sh_util.sharding import ShardedResource

        physicalShardId = ShardedResource.userIdToPhysicalShardId(userId)
        self._db = 'shard_{0}'.format(physicalShardId)

        logging.debug(
            u'DistributedManager :: {0} :: Automatically selected db={1} for user-id={2}'.format(
                self.model._meta.db_table if hasattr(self.model, '_meta') and hasattr(self.model._meta, 'db_table') else 'unknown',
                self._db,
                userId
            )
        )
class AutoShardingQuerySet(AutoSelectShard, QuerySet):
    """Automatically selects shard if a shard-key specifier is found."""

    def __getattribute__(self, name):
        """
        Proxy for all attribute requests to inject distributed query request
        detection.

        Attributes named in ``_autoShardable`` or ``_distributable`` are
        returned wrapped: the wrapper first tries to pick the shard from the
        call's keyword arguments, then looks the attribute up again and calls
        it. Every other attribute is returned untouched.
        """
        attr = super(AutoShardingQuerySet, self).__getattribute__(name)

        if name not in _autoShardable and name not in _distributable:
            return attr

        def wrapped(*args, **kwargs):
            """@return The wrapped version of the attr."""
            # Attempt to infer appropriate shard.
            shardKey = self._findShardKeySpecification(**kwargs)
            if shardKey is not None:
                self._autoSelectShard(shardKey)
            target = super(AutoShardingQuerySet, self).__getattribute__(name)
            return target(*args, **kwargs)

        return wrapped
class DistributedManager(AutoSelectShard, Manager):
"""
Distributed model manager.
NB: One way to always disable the automatic shard selector is to use
<user-or-id-field>__in=(x,) instead of <user-or-id-field>=x.
"""
def __init__(self, *args, **kw):
"""Set default dist queryset implementation."""
super(DistributedManager, self).__init__(*args, **kw)
# These can be swapped out by child classes.
self.distributedQuerySetImpl = DistributedQuerySet
self.autoShardingQuerySetImpl = AutoShardingQuerySet
def get_or_create(self, **kwargs):
"""
Looks up an object with the given kwargs, creating one if necessary.
Returns a tuple of (object, created), where created is a boolean
specifying whether an object was created.
"""
from django.db import IntegrityError
import sys
assert kwargs, 'get_or_create() must be passed at least one keyword argument'
defaults = kwargs.pop('defaults', {})
lookup = kwargs.copy()
for f in self.model._meta.fields:
if f.attname in lookup:
lookup[f.name] = lookup.pop(f.attname)
try:
return self.get(**lookup), False
except self.model.DoesNotExist:
try:
params = dict([(k, v) for k, v in kwargs.items() if '__' not in k])
params.update(defaults)
obj = self.create(**params)
return obj, True
except IntegrityError, e:
exc_info = sys.exc_info()
connections[self._db]._rollback()
try:
return self.get(**lookup), False
except self.model.DoesNotExist:
# Re-raise the IntegrityError with its original traceback.
raise exc_info[1], None, exc_info[2]
def get_query_set(self):
"""Get an auto-sharding capable query set instance."""
logging.debug(u'Returning auto-sharding query set for model={0}'.format(self.model))
return self.autoShardingQuerySetImpl(model=self.model, using=self._db)
# @TODO .exists is useless here, can likely be removed.
def exists(self):
"""More expensive but working form of .exists()."""
return self.count() > 0
def __getattribute__(self, name):
"""
Proxy for all attribute requests to inject distributed query request
detection.
"""
attr = super(DistributedManager, self).__getattribute__(name)
if name in _distributable and callable(attr):
logging.debug(u'__getattribute__ returning full wrapper for model={0} name={1}'.format(self.model, name))
def wrapped(*args, **kwargs):
"""@return The wrapped version of the attr."""
if kwargs.pop('distributed', False) is True:
dqs = self.distributedQuerySetImpl(model=self.model, using=self._db)
attr = dqs.__getattribute__(name)
else:
# Attempt to infer appropriate shard.
specifier = self._findShardKeySpecification(**kwargs)
if specifier is not None:
self._autoSelectShard(specifier)
attr = super(DistributedManager, self).__getattribute__(name)
result = attr(*args, **kwargs)
return result
return wrapped
elif name in _autoShardable:
def wrapped(*args, **kwargs):
"""@return The wrapped version of the attr."""
# Attempt to infer appropriate shard.
specifier = self._findShardKeySpecification(**kwargs)
if specifier is not None:
self._autoSelectShard(specifier)
attr = super(DistributedManager, self).__getattribute__(name)
result = attr(*args, **kwargs)
return result
return wrapped
else:
return attr
class DistributedModel(Model):
    """
    Abstract base model for the distributed/sharded models.

    Contains a special save method intended to automatically infer shard
    based on self.user_id when available. That inference is currently
    commented out, so ``save`` simply defers to the parent implementation.
    """

    class Meta:
        """Default Meta."""
        abstract = True
        app_label = 'main'

    def save(self, *args, **kw):
        """
        Special save method to automatically infer shard based on self.user_id
        when available.

        The shard inference below is disabled; all arguments are passed
        straight through to the parent ``save`` and its result is returned.
        """
        # from sh_util.sharding import ShardedResource
        #
        # if hasattr(self, 'user_id'):
        #     self._state.db = 'shard_{0}'.format(
        #         ShardedResource.userIdToPhysicalShardId(self.user_id)
        #     )
        return super(DistributedModel, self).save(*args, **kw)
# -*- coding: utf-8 -*-
"""Django model utilities."""
__author__ = 'Jay Taylor [@jtaylor]'
from django.db.models.fields import IntegerField
class BigIntegerField(IntegerField):
    """
    Bigint field (not compatible with Oracle).

    An ``IntegerField`` whose database column type is always ``bigint``.
    """
    empty_strings_allowed = False

    def get_internal_type(self):
        """Override IntegerField's implementation; always returns 'BigIntegerField'."""
        return 'BigIntegerField'

    def db_type(self, connection):
        """
        Override IntegerField's implementation.

        @param connection Ignored; the same type is returned for every
            connection.

        @return The literal column type 'bigint'.
        """
        # NB: this won't work with Oracle.
        return 'bigint'
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment