Created
December 16, 2015 06:42
-
-
Save ziplus4/3bf8cc14541a16c65206 to your computer and use it in GitHub Desktop.
Flask + SQLAlchemy sample: sharding
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# -*- coding:utf8 -*- | |
import re
import zlib
from datetime import datetime

from flask import Flask
from flask_sqlalchemy import SQLAlchemy as BaseSQLAlchemy
from flask_sqlalchemy import _SignallingSession as BaseSignallingSession
from flask_sqlalchemy import orm, partial, get_state
class _BindingKeyPattern(object):
    """Bind key expressed as a regular expression over bind names.

    The pattern is tested (with ``re.match``, i.e. anchored at the start
    only) against the keys of ``db.app.config['SQLALCHEMY_BINDS']``.

    :param db: object exposing ``app.config``; only read lazily, the first
        time :meth:`get_shard_key` is called.
    :param pattern: regular expression source string.
    """

    def __init__(self, db, pattern):
        self.db = db
        self.raw_pattern = pattern
        self.compiled_pattern = re.compile(pattern)
        # Sorted list of matching bind names, filled in on first use.
        self._shard_keys = None

    def __repr__(self):
        return "%s<%s>" % (self.__class__.__name__, self.raw_pattern)

    def match(self, key):
        """Return the regex match for ``key``, or ``None``.

        ``None`` (meaning "no binding selected") never matches; passing it
        straight to ``re`` would raise ``TypeError``.
        """
        if key is None:
            return None
        return self.compiled_pattern.match(key)

    def get_shard_key(self, hash_num):
        """Pick one matching bind name for the integer ``hash_num``.

        The matching names are sorted, and the one at index
        ``hash_num % count`` is returned, so the same number always maps to
        the same name for a given configuration.

        :raises LookupError: if no configured bind name matches the pattern.
        """
        shard_keys = self._shard_keys
        if shard_keys is None:
            binds = self.db.app.config['SQLALCHEMY_BINDS'] or ()
            # Iterating the mapping yields its keys on Python 2 and 3 alike.
            shard_keys = sorted(
                key for key in binds if self.compiled_pattern.match(key)
            )
            if not shard_keys:
                # Not cached, so a bind added to the config later is still seen.
                raise LookupError(
                    'no SQLALCHEMY_BINDS key matches %r' % (self.raw_pattern,)
                )
            self._shard_keys = shard_keys
        return shard_keys[hash_num % len(shard_keys)]
class _BoundSection(object):
    """Context manager that runs a block under one bind key.

    ``db_session_cls`` is called with no arguments to obtain the session.
    Entering pushes ``name`` onto that session's binding stack; leaving pops
    it again and then closes the session.
    """

    def __init__(self, db_session_cls, name):
        session = db_session_cls()
        self.db_session = session
        self.name = name

    def __enter__(self):
        # Nothing is returned, so ``with ... as x`` binds ``None``.
        session = self.db_session
        session.push_binding(self.name)

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Returns None, so an exception raised in the block propagates.
        session = self.db_session
        session.pop_binding()
        session.close()
class _SignallingSession(BaseSignallingSession):
    """Session that chooses its engine from a stack of bind keys.

    ``push_binding`` / ``pop_binding`` maintain a "current" bind key plus
    the keys that were current before it.  ``get_bind`` resolves a mapper to
    one of those keys (or to the mapper's own static bind key) and returns
    the matching engine; when nothing resolves it defers to the base class.
    """

    def __init__(self, *args, **kwargs):
        BaseSignallingSession.__init__(self, *args, **kwargs)
        # Previously-current keys, oldest first.  The first entry pushed is
        # the initial ``None``.
        self._binding_keys = []
        self._binding_key = None

    def push_binding(self, key):
        """Make ``key`` the current bind key, remembering the previous one."""
        self._binding_keys.append(self._binding_key)
        self._binding_key = key

    def pop_binding(self):
        """Restore the bind key that was current before the last push.

        :raises IndexError: if there is no pushed binding to restore.
        """
        self._binding_key = self._binding_keys.pop()

    def get_bind(self, mapper=None, clause=None):
        """Return the engine for ``mapper`` under the active bindings.

        ``mapper`` defaults to ``None`` so the method can also be called
        with only ``clause``.
        """
        binding_key = self.__find_binding_key(mapper)
        if binding_key is None:
            return BaseSignallingSession.get_bind(self, mapper, clause)
        # NOTE(review): self.app is presumably set by the base initializer.
        state = get_state(self.app)
        return state.db.get_engine(self.app, bind=binding_key)

    def __find_binding_key(self, mapper):
        """Resolve the bind key to use for ``mapper`` (``None`` if none).

        :raises Exception: if the mapper's table has a pattern bind key and
            neither the current nor any pushed binding matches it.
        """
        if mapper is None:  # no mapper
            return self._binding_key

        mapper_info = getattr(mapper.mapped_table, 'info', {})
        mapped_binding_key = mapper_info.get('bind_key')
        if not mapped_binding_key:
            # table has no mapped bind key: use the default binding
            return self._binding_key

        if not isinstance(mapped_binding_key, _BindingKeyPattern):
            # static binding: a plain bind name (str, or unicode on Python 2)
            return mapped_binding_key

        # dynamic binding: try the current binding first.  The truthiness
        # guard avoids handing None to the regex, which raises TypeError.
        current_key = self._binding_key
        if current_key and mapped_binding_key.match(current_key):
            return current_key

        # then the pushed bindings, most recently pushed first
        for pushed_binding_key in reversed(self._binding_keys):
            if pushed_binding_key and mapped_binding_key.match(pushed_binding_key):
                return pushed_binding_key

        # [1:] skips the first stack entry, which is the initial None.
        raise Exception('NOT_FOUND_MAPPED_BINDING:%s CURRENT_BINDING:%s PUSHED_BINDINGS:%s' % (repr(mapped_binding_key), repr(self._binding_key), repr(self._binding_keys[1:])))
class SQLAlchemy(BaseSQLAlchemy):
    """Flask-SQLAlchemy handle extended with pattern bind keys and binding sections."""

    def BindingKeyPattern(self, pattern):
        """Return a ``_BindingKeyPattern`` for ``pattern`` tied to this handle."""
        return _BindingKeyPattern(self, pattern)

    def binding(self, key):
        """Return a context manager that runs its block under bind ``key``."""
        return _BoundSection(self.session, key)

    def create_scoped_session(self, options=None):
        """Build the scoped session using ``_SignallingSession``.

        ``options`` may carry a ``scopefunc`` entry; the remaining entries
        are forwarded to the session constructor.  The caller's dict is
        copied, not modified.
        """
        options = dict(options) if options else {}
        scopefunc = options.pop('scopefunc', None)
        return orm.scoped_session(
            partial(_SignallingSession, self, **options), scopefunc=scopefunc
        )

    def get_binds(self, app=None):
        """Return the table-to-engine mapping.

        Starts from the base class result, then maps every table that has
        no bind key to the default engine.
        """
        retval = BaseSQLAlchemy.get_binds(self, app)
        default_engine = self.get_engine(app, None)
        for table in self.get_tables_for_bind(None):
            retval[table] = default_engine
        return retval

    def get_tables_for_bind(self, bind=None):
        """Return the tables that belong to ``bind``.

        A table belongs to ``bind`` when its ``bind_key`` equals it (both
        may be ``None``), or when its ``bind_key`` is a pattern that matches
        a non-empty ``bind``.
        """
        result = []
        # .values() exists on Python 2 and 3; .itervalues() only on Python 2.
        for table in self.Model.metadata.tables.values():
            table_bind_key = table.info.get('bind_key')
            if table_bind_key == bind:
                result.append(table)
            elif (bind and isinstance(table_bind_key, _BindingKeyPattern)
                    and table_bind_key.match(bind)):
                result.append(table)
        return result
# Flask application and the sharding-aware database handle bound to it.
app = Flask(__name__)
db = SQLAlchemy(app)
class Notice(db.Model):
    """Notice message, always stored on the ``'global'`` bind."""

    __bind_key__ = 'global'

    id = db.Column(db.Integer, primary_key=True)
    msg = db.Column(db.String, nullable=False)
    # Pass the callable itself: ``datetime.now()`` would be evaluated once,
    # when the class is defined, and every row would share that timestamp.
    ctime = db.Column(db.DateTime, default=datetime.now, nullable=False)

    def __repr__(self):
        # %s (not %d) for id: it is None until the row has been flushed.
        return "%s<id=%s,msg='%s'>" % (self.__class__.__name__, self.id, self.msg)
class User(db.Model):
    """User row, stored on one of the binds named ``<prefix>_user_NN``."""

    # Dynamic bind: any SQLALCHEMY_BINDS name matching the pattern qualifies.
    __bind_key__ = db.BindingKeyPattern(r'[^_]+_user_\d\d')

    id = db.Column(db.Integer, primary_key=True)
    nickname = db.Column(db.String(80), unique=True)
    # The lambda defers looking up LoginLog, which is defined after this class.
    login_logs = db.relationship(lambda: LoginLog, backref='owner')

    def __repr__(self):
        # %s (not %d) for id: it is None until the row has been flushed.
        return "%s<id=%s, nickname='%s'>" % (self.__class__.__name__, self.id, self.nickname)

    @classmethod
    def get_shard_key(cls, nickname):
        """Return the bind name of the shard that holds ``nickname``.

        Uses CRC-32 of the UTF-8 encoded nickname rather than the builtin
        ``hash()``: string hashes are salted per process on Python 3, which
        would route the same nickname to different shards on different runs.
        """
        if not isinstance(nickname, bytes):
            nickname = nickname.encode('utf-8')
        # Masking keeps the checksum unsigned on Python 2 as well.
        stable_hash = zlib.crc32(nickname) & 0xffffffff
        return cls.__bind_key__.get_shard_key(stable_hash)
class LoginLog(db.Model):
    """Login record for a user, stored on a bind named ``<prefix>_log``."""

    # Dynamic bind: any SQLALCHEMY_BINDS name matching the pattern qualifies.
    __bind_key__ = db.BindingKeyPattern('[^_]+_log')

    id = db.Column(db.Integer, primary_key=True)
    user_id = db.Column(db.Integer, db.ForeignKey(User.id))
    # Pass the callable itself: ``datetime.now()`` would be evaluated once,
    # when the class is defined, and every row would share that timestamp.
    ctime = db.Column(db.DateTime, default=datetime.now, nullable=False)
if __name__ == '__main__':
    # Demo: one global database, two user shards, and master/slave log databases.
    app.config.update(
        SQLALCHEMY_ECHO=True,
        SQLALCHEMY_BINDS={
            'global': 'sqlite:///./global.db',
            'master_user_01': 'sqlite:///./master_user_01.db',
            'master_user_02': 'sqlite:///./master_user_02.db',
            'slave_user': 'sqlite:///./slave_user.db',
            'master_log': 'sqlite:///./master_log.db',
            'slave_log': 'sqlite:///./slave_log.db',
        },
    )

    # Start from empty tables on every bind.
    db.drop_all()
    db.create_all()

    def _save(row):
        """Add ``row`` to the session and commit immediately."""
        db.session.add(row)
        db.session.commit()
        return row

    # No binding is active here; Notice carries its own static bind key.
    notice = _save(Notice(msg='NOTICE1'))

    nickname = 'jaru'
    with db.binding(User.get_shard_key(nickname)):
        notice = _save(Notice(msg='NOTICE2'))
        user = _save(User(nickname=nickname))

        # NOTE(review): nesting reconstructed from flattened source; the
        # pushed-binding lookup lets User resolve to the outer shard here.
        with db.binding('master_log'):
            notice = _save(Notice(msg='NOTICE3'))
            login_log = _save(LoginLog(owner=user))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment